geno_lewm.predictor.losses
Prediction and LeJEPA monitoring losses for GenoLeWM predictors.
PredictionLossResult (dataclass)
PredictionLossResult(loss: Tensor, pred_loss: Tensor, kl_reg: Tensor, phase: Literal['phase1', 'phase2'])
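Phase-aware predictor loss components: the total loss, the prediction term pred_loss, the KL regularizer kl_reg, and the phase that produced them.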
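A hypothetical stand-in for the dataclass illustrates the documented fields. Plain floats replace the Tensor fields, and the as_metrics helper is an illustration only, not part of the documented API. Because typing.Literal is not enforced at runtime, the sketch validates phase in __post_init__.

```python
from dataclasses import dataclass
from typing import Literal


# Hypothetical mirror of PredictionLossResult: floats stand in for Tensors.
@dataclass(frozen=True)
class PredictionLossResultSketch:
    loss: float       # total loss
    pred_loss: float  # prediction term
    kl_reg: float     # KL regularizer term
    phase: Literal["phase1", "phase2"]

    def __post_init__(self) -> None:
        # Literal is only a static hint, so reject unknown phases explicitly.
        if self.phase not in ("phase1", "phase2"):
            raise ValueError(f"unknown phase: {self.phase!r}")

    def as_metrics(self) -> dict[str, float]:
        # Collect the numeric components, e.g. for passing to a logger.
        return {"loss": self.loss, "pred_loss": self.pred_loss, "kl_reg": self.kl_reg}
```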
prediction_loss
prediction_loss(prediction: Tensor, target: Tensor, *, alpha: float = 1.0, beta: float = 0.1, mask: Tensor | None = None, eps: float = 1e-08) -> Tensor
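Return the RFC-0005 prediction loss alpha * (1 - cos) + beta * MSE / d_state, where cos is the cosine similarity between prediction and target and d_state is the size of the state dimension.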
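The formula can be illustrated with a short NumPy sketch. It is not the package's implementation: the real function takes Tensor arguments, while the hypothetical prediction_loss_sketch below uses NumPy arrays. It assumes that state vectors lie on the last axis, that cos is computed per state vector with eps clamping the norm product, that MSE is the per-vector mean squared error (the docstring formula is then followed literally, so it is divided by d_state), and that mask is a boolean array over the leading axes selecting the positions to average.

```python
import numpy as np


def prediction_loss_sketch(prediction, target, *, alpha=1.0, beta=0.1, mask=None, eps=1e-8):
    """Hypothetical NumPy sketch of alpha * (1 - cos) + beta * MSE / d_state."""
    prediction = np.asarray(prediction, dtype=float)
    target = np.asarray(target, dtype=float)
    d_state = prediction.shape[-1]  # assumption: state vectors live on the last axis

    # Cosine similarity per state vector; eps guards against zero-norm vectors.
    dot = np.sum(prediction * target, axis=-1)
    norms = np.linalg.norm(prediction, axis=-1) * np.linalg.norm(target, axis=-1)
    cos = dot / np.maximum(norms, eps)

    # Per-vector mean squared error, divided by d_state as the docstring formula reads.
    mse = np.mean((prediction - target) ** 2, axis=-1)
    per_position = alpha * (1.0 - cos) + beta * mse / d_state

    if mask is None:
        return float(per_position.mean())
    # Assumption: mask is boolean over the leading axes; average valid positions only.
    mask = np.asarray(mask, dtype=bool)
    return float(per_position[mask].mean())
```

For orthogonal unit vectors in two dimensions, cos is 0 and the per-vector MSE is 1, so the default weights give 1.0 + 0.1 * 1 / 2 = 1.05.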
lejepa_kl_regularizer
Return the closed-form KL divergence from the empirical state distribution to the standard normal N(0, I).
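This page gives no signature for lejepa_kl_regularizer, and a closed form suggests the empirical states are summarized by a Gaussian. Under that assumption, the hypothetical NumPy helper below fits a full-covariance Gaussian by moment matching and evaluates KL(N(mu, Sigma) || N(0, I)) = 0.5 * (tr(Sigma) + mu^T mu - d - log det Sigma). Whether the real regularizer uses a full or diagonal covariance, and how it applies any eps, are assumptions here.

```python
import numpy as np


def gaussian_kl_to_standard_normal(states, eps=1e-6):
    """Hypothetical sketch: KL(N(mu, Sigma) || N(0, I)) for a Gaussian fitted to states."""
    states = np.asarray(states, dtype=float)
    states = states.reshape(-1, states.shape[-1])  # flatten leading axes into samples
    n, d = states.shape

    # Moment matching: sample mean and biased (maximum-likelihood) covariance.
    mu = states.mean(axis=0)
    centered = states - mu
    cov = centered.T @ centered / n
    cov = cov + eps * np.eye(d)  # small ridge keeps the log-determinant finite

    sign, logdet = np.linalg.slogdet(cov)
    if sign <= 0:
        raise ValueError("covariance is not positive definite; increase eps")
    return 0.5 * float(np.trace(cov) + mu @ mu - d - logdet)
```

The value is zero exactly when the fitted mean is 0 and the fitted covariance is I, which is why such a term can serve as a monitor of how far the state distribution drifts from isotropic Gaussian.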
predictor_loss
predictor_loss(prediction: Tensor, target: Tensor, *, phase: Literal['phase1', 'phase2'] = 'phase1', alpha: float = 1.0, beta: float = 0.1, gamma: float = 0.5, mask: Tensor | None = None, regularizer_states: Tensor | None = None, eps: float = 1e-06) -> PredictionLossResult
Return the phase-conditional total loss together with its monitorable components as a PredictionLossResult.
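How predictor_loss combines the components in each phase is not documented on this page. The sketch below assumes, hypothetically, that phase1 only monitors kl_reg while phase2 adds it to the prediction term with weight gamma. The helper name combine_phase_losses is illustrative; it takes already-computed scalar components and returns a dict mirroring the PredictionLossResult fields.

```python
from typing import Literal


def combine_phase_losses(
    pred_loss: float,
    kl_reg: float,
    *,
    phase: Literal["phase1", "phase2"] = "phase1",
    gamma: float = 0.5,
) -> dict:
    # Hypothetical phase rule: phase1 monitors kl_reg only; phase2 adds gamma * kl_reg.
    if phase == "phase1":
        total = pred_loss
    elif phase == "phase2":
        total = pred_loss + gamma * kl_reg
    else:
        raise ValueError(f"unknown phase: {phase!r}")
    # Keep every component so it can be logged even when it is not optimized.
    return {"loss": total, "pred_loss": pred_loss, "kl_reg": kl_reg, "phase": phase}
```

Returning kl_reg in both phases keeps the monitoring signal available regardless of whether it contributes to the gradient.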