Skip to content

geno_lewm.predictor.losses

losses

Prediction and LeJEPA monitoring losses for GenoLeWM predictors.

PredictionLossResult dataclass

PredictionLossResult(loss: Tensor, pred_loss: Tensor, kl_reg: Tensor, phase: Literal['phase1', 'phase2'])

Phase-aware predictor loss components.

prediction_loss

prediction_loss(prediction: Tensor, target: Tensor, *, alpha: float = 1.0, beta: float = 0.1, mask: Tensor | None = None, eps: float = 1e-08) -> Tensor

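Return the RFC-0005 prediction loss, alpha * (1 - cos) + beta * MSE / d_state, where the second term is the squared error summed over the feature (last) axis and divided by d_state.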

Source code in geno_lewm/predictor/losses.py
def prediction_loss(
    prediction: Tensor,
    target: Tensor,
    *,
    alpha: float = 1.0,
    beta: float = 0.1,
    mask: Tensor | None = None,
    eps: float = 1.0e-8,
) -> Tensor:  # pragma: no cover - optional torch runtime is tested separately.
    """Compute the RFC-0005 loss ``alpha * (1 - cos) + beta * MSE / d_state``.

    Both terms are evaluated along the last axis of ``prediction`` / ``target``:
    ``cos`` is their cosine similarity and the second term is the squared
    difference summed over that axis and divided by its length.

    Args:
        prediction: Predicted states; the last axis is the feature axis.
        target: Reference states compared against ``prediction``.
        alpha: Weight of the cosine-distance term; checked by
            ``_require_nonnegative``.
        beta: Weight of the squared-error term; checked by
            ``_require_nonnegative``.
        mask: Optional mask handed to ``_validate_loss_inputs`` and
            ``_masked_mean`` (its exact semantics live in those helpers).
        eps: Stabilizer forwarded to ``cosine_similarity``; checked by
            ``_require_positive_float``.

    Returns:
        The per-position loss reduced by ``_masked_mean``.
    """
    _require_torch("prediction_loss")
    _validate_loss_inputs(prediction, target, mask=mask)
    for weight_name, weight in (("alpha", alpha), ("beta", beta)):
        _require_nonnegative(weight_name, weight)
    _require_positive_float("eps", eps)

    # Directional term: 1 - cosine similarity along the feature axis.
    similarity = functional.cosine_similarity(prediction, target, dim=-1, eps=eps)
    cosine_distance = 1.0 - similarity

    # Magnitude term: summed squared residual normalized by the feature width.
    feature_dim = prediction.shape[-1]
    residual = prediction - target
    scaled_sq_error = residual.pow(2).sum(dim=-1) / feature_dim

    return _masked_mean(alpha * cosine_distance + beta * scaled_sq_error, mask)

lejepa_kl_regularizer

lejepa_kl_regularizer(states: Tensor, *, eps: float = 1e-06) -> Tensor

Return the closed-form KL divergence from the empirical state distribution to N(0, I).

Source code in geno_lewm/predictor/losses.py
def lejepa_kl_regularizer(
    states: Tensor, *, eps: float = 1.0e-6
) -> Tensor:  # pragma: no cover - optional torch runtime is tested separately.
    """Return the closed-form KL from empirical state distribution to ``N(0, I)``.

    All leading axes of ``states`` are flattened into one sample axis. The
    sample mean and the biased (divide-by-N) covariance of those samples are
    plugged into ``0.5 * (|mean|^2 + trace(cov) - logdet(cov + eps * I) - d)``.

    Args:
        states: Floating-point tensor of shape ``(..., d_state)`` with at
            least two dimensions and at least one sample.
        eps: Positive floor added to every covariance eigenvalue before the
            logarithm is taken.

    Returns:
        A scalar tensor cast back to the dtype of ``states``.

    Raises:
        InputError: If ``states`` has fewer than two dimensions, is not
            floating point, has an empty feature axis, or holds no samples.
    """
    _require_torch("lejepa_kl_regularizer")
    _require_positive_float("eps", eps)
    if states.ndim < 2:
        raise InputError(
            "states must have shape (..., d_state)",
            details={"shape": tuple(states.shape)},
        )
    if not states.is_floating_point():
        raise InputError("states must be a floating-point tensor")
    d_state = states.shape[-1]
    if d_state <= 0:
        raise InputError("states must have a non-empty feature dimension")

    # Work in float64: with fewer samples than d_state the empirical
    # covariance is rank-deficient, and a float32 slogdet becomes unreliable
    # in sign once many eigenvalues sit near the stabilizer.
    samples = states.reshape(-1, d_state).to(dtype=torch.float64)
    sample_count = samples.shape[0]
    if sample_count == 0:
        raise InputError("states must contain at least one sample")

    mu = samples.mean(dim=0)
    deviations = samples - mu
    cov = deviations.T @ deviations / sample_count
    cov_trace = torch.diagonal(cov).sum()

    # The covariance is symmetric PSD, so its log-determinant (after adding
    # eps to every eigenvalue) is the sum of log eigenvalues from eigvalsh.
    # Slightly negative eigenvalues are numerical noise and are clamped to
    # zero before the floor is applied.
    # NOTE(review): the trace above uses the raw covariance while the logdet
    # uses the eps-stabilized spectrum; confirm this asymmetry is intended.
    spectrum = torch.linalg.eigvalsh(cov)
    floored_spectrum = spectrum.clamp(min=0.0) + eps
    log_det = floored_spectrum.log().sum()

    mean_sq_norm = mu.pow(2).sum()
    divergence = 0.5 * (mean_sq_norm + cov_trace - log_det - d_state)
    return divergence.to(dtype=states.dtype)

predictor_loss

predictor_loss(prediction: Tensor, target: Tensor, *, phase: Literal['phase1', 'phase2'] = 'phase1', alpha: float = 1.0, beta: float = 0.1, gamma: float = 0.5, mask: Tensor | None = None, regularizer_states: Tensor | None = None, eps: float = 1e-06) -> PredictionLossResult

Return phase-conditional total loss and monitorable components.

Source code in geno_lewm/predictor/losses.py
def predictor_loss(
    prediction: Tensor,
    target: Tensor,
    *,
    phase: Literal["phase1", "phase2"] = "phase1",
    alpha: float = 1.0,
    beta: float = 0.1,
    gamma: float = 0.5,
    mask: Tensor | None = None,
    regularizer_states: Tensor | None = None,
    eps: float = 1.0e-6,
) -> PredictionLossResult:  # pragma: no cover - optional torch runtime is tested separately.
    """Assemble the phase-dependent total loss together with its components.

    The prediction loss and the KL regularizer are always computed so both
    can be monitored. Only ``phase2`` adds ``gamma * kl_reg`` to the total;
    in ``phase1`` the total is the prediction loss alone.

    Args:
        prediction: Predicted states passed to ``prediction_loss``.
        target: Target states passed to ``prediction_loss``; also the
            regularizer input when ``regularizer_states`` is ``None``.
        phase: Either ``"phase1"`` or ``"phase2"``.
        alpha: Cosine-term weight forwarded to ``prediction_loss``.
        beta: Squared-error weight forwarded to ``prediction_loss``.
        gamma: Weight of the KL term in ``phase2``; checked by
            ``_require_nonnegative``.
        mask: Optional mask forwarded to ``prediction_loss`` only.
        regularizer_states: Optional override for the regularizer input.
        eps: Stabilizer forwarded to ``lejepa_kl_regularizer`` only.

    Returns:
        A ``PredictionLossResult`` holding the total loss, the prediction
        loss, the KL regularizer value and the phase.

    Raises:
        InputError: If ``phase`` is neither ``"phase1"`` nor ``"phase2"``.
    """
    _require_torch("predictor_loss")
    valid_phases = ("phase1", "phase2")
    if phase not in valid_phases:
        raise InputError("phase must be either 'phase1' or 'phase2'", details={"phase": phase})
    _require_nonnegative("gamma", gamma)

    # ``eps`` is not passed here, so prediction_loss uses its own default.
    pred = prediction_loss(prediction, target, alpha=alpha, beta=beta, mask=mask)

    # The regularizer sees the full tensor; ``mask`` is not applied to it.
    if regularizer_states is None:
        kl_input = target
    else:
        kl_input = regularizer_states
    kl_reg = lejepa_kl_regularizer(kl_input, eps=eps)

    if phase == "phase1":
        total = pred
    else:
        total = pred + gamma * kl_reg
    return PredictionLossResult(loss=total, pred_loss=pred, kl_reg=kl_reg, phase=phase)