Skip to content

geno_lewm.training.collapse

collapse

RFC-0005 collapse monitoring for training batches.

The monitor accepts plain Python nested sequences and common tensor-like objects that expose detach(), cpu(), and/or tolist(). Core math stays dependency-free so the diagnostics remain available before a full ML runtime is installed.

CollapseMetrics dataclass

CollapseMetrics(pred_cos_mean: float, pred_l2_mean: float, target_var_per_dim: float, pred_var_per_dim: float, pred_target_corr: float, pairwise_pred_dist_mean: float, kl_reg: float)

Scalar RFC-0005 §3.6 collapse diagnostics for one batch.

CollapseThresholds dataclass

CollapseThresholds(pred_var_to_target_var: float = 0.5, pairwise_to_initial: float = 0.5, kl_reg_max: float = 10.0)

Alert thresholds from RFC-0005 §3.6.

CollapseAlert dataclass

CollapseAlert(criterion: str, value: float, threshold: float)

One tripped collapse criterion.

CollapseCheck dataclass

CollapseCheck(metrics: CollapseMetrics, alerts: tuple[CollapseAlert, ...])

Metrics and alerts produced by one collapse-monitor observation.

tripped property

tripped: bool

Return whether any collapse alert tripped.

CollapseMonitor dataclass

CollapseMonitor(log_every_steps: int = 500, thresholds: CollapseThresholds = CollapseThresholds(), initial_pairwise_pred_dist_mean: float | None = None)

Compute, register, and optionally log collapse diagnostics.

`observe` returns `None` on non-logging steps (unless `force` is true); otherwise it returns a `CollapseCheck`. The first logged batch establishes the pairwise-distance baseline unless the caller supplies one.

should_log

should_log(step: int) -> bool

Return whether step is a scheduled collapse-monitor step.

Source code in geno_lewm/training/collapse.py
def should_log(self, step: int) -> bool:
    """Tell whether the collapse monitor is scheduled to run at ``step``.

    ``step`` is first checked by ``_require_nonnegative_int``. Step 0 is
    never scheduled; any positive step is scheduled when it is an exact
    multiple of ``self.log_every_steps``.
    """
    _require_nonnegative_int("step", step)
    if step <= 0:
        # Step 0 never counts as a monitoring step.
        return False
    return step % self.log_every_steps == 0

observe

observe(prediction: object, target: object, *, kl_reg: float, step: int, logger: GenoLeWMLogger | None = None, force: bool = False) -> CollapseCheck | None

Observe a validation batch at step.

Metrics are computed, written to the registry, and logged only when step is a scheduled monitoring step or force is true; on any other step the call returns None without computing metrics.

Source code in geno_lewm/training/collapse.py
def observe(
    self,
    prediction: object,
    target: object,
    *,
    kl_reg: float,
    step: int,
    logger: GenoLeWMLogger | None = None,
    force: bool = False,
) -> CollapseCheck | None:
    """Run the collapse diagnostics for one batch at ``step``.

    Nothing is computed on steps that ``should_log`` rejects, unless
    ``force`` is true; in that case ``None`` is returned. Otherwise the
    metrics are computed, checked against ``self.thresholds``, recorded
    via ``record_collapse_metrics`` and returned together with any
    alerts as a ``CollapseCheck``.

    When ``self.initial_pairwise_pred_dist_mean`` is still ``None`` the
    pairwise distance of this batch is stored on the monitor and used
    as the baseline from this call onward.
    """
    _require_nonnegative_int("step", step)
    # ``force`` short-circuits the schedule lookup entirely.
    scheduled = force or self.should_log(step)
    if not scheduled:
        return None

    batch_metrics = compute_collapse_metrics(prediction, target, kl_reg=kl_reg)

    baseline = self.initial_pairwise_pred_dist_mean
    if baseline is None:
        # First observed batch: remember its spread as the reference point.
        baseline = batch_metrics.pairwise_pred_dist_mean
        self.initial_pairwise_pred_dist_mean = baseline

    tripped_alerts = detect_collapse(
        batch_metrics,
        thresholds=self.thresholds,
        initial_pairwise_pred_dist_mean=baseline,
    )
    record_collapse_metrics(
        batch_metrics,
        alerts=tripped_alerts,
        logger=logger,
        step=step,
    )
    return CollapseCheck(metrics=batch_metrics, alerts=tripped_alerts)

compute_collapse_metrics

compute_collapse_metrics(prediction: object, target: object, *, kl_reg: float) -> CollapseMetrics

Compute RFC-0005 §3.6 collapse metrics for one [N, D] batch.

Source code in geno_lewm/training/collapse.py
def compute_collapse_metrics(
    prediction: object,
    target: object,
    *,
    kl_reg: float,
) -> CollapseMetrics:
    """Derive the RFC-0005 §3.6 collapse metrics from one ``[N, D]`` batch.

    ``prediction`` and ``target`` are converted with ``_as_rows`` and must
    agree on the number of rows and on the length of their first row;
    otherwise ``InputError`` is raised. ``kl_reg`` is passed through
    ``_finite_float`` before being stored on the result.
    """
    predicted = _as_rows(prediction, "prediction")
    expected = _as_rows(target, "target")

    n_predicted, n_expected = len(predicted), len(expected)
    if n_predicted != n_expected:
        raise InputError(
            "prediction and target must have the same batch size",
            details={"prediction_rows": n_predicted, "target_rows": n_expected},
        )

    predicted_dim, expected_dim = len(predicted[0]), len(expected[0])
    if predicted_dim != expected_dim:
        raise InputError(
            "prediction and target must have the same latent dimension",
            details={"prediction_dim": predicted_dim, "target_dim": expected_dim},
        )

    kl_value = _finite_float(kl_reg, "kl_reg")

    # Row-aligned (prediction, target) pairs shared by the per-row metrics.
    row_pairs = tuple(zip(predicted, expected, strict=True))

    cos_mean = _mean(_cosine(p_row, t_row) for p_row, t_row in row_pairs)
    l2_mean = _mean(_euclidean(p_row, t_row) for p_row, t_row in row_pairs)
    target_var = _mean_variance_per_dim(expected)
    pred_var = _mean_variance_per_dim(predicted)
    # Correlation is taken over all entries of the flattened batches.
    correlation = _pearson_corr(_flatten(predicted), _flatten(expected))
    pairwise_dist = _pairwise_dist_mean(predicted)

    return CollapseMetrics(
        pred_cos_mean=cos_mean,
        pred_l2_mean=l2_mean,
        target_var_per_dim=target_var,
        pred_var_per_dim=pred_var,
        pred_target_corr=correlation,
        pairwise_pred_dist_mean=pairwise_dist,
        kl_reg=kl_value,
    )

detect_collapse

detect_collapse(metrics: CollapseMetrics, *, thresholds: CollapseThresholds | None = None, initial_pairwise_pred_dist_mean: float | None = None) -> tuple[CollapseAlert, ...]

Return the RFC-0005 §3.6 alert criteria tripped by metrics.

Source code in geno_lewm/training/collapse.py
def detect_collapse(
    metrics: CollapseMetrics,
    *,
    thresholds: CollapseThresholds | None = None,
    initial_pairwise_pred_dist_mean: float | None = None,
) -> tuple[CollapseAlert, ...]:
    """Evaluate ``metrics`` against the RFC-0005 §3.6 alert criteria.

    Three checks run, in this order:

    * ``pred_var_per_dim`` below ``pred_var_to_target_var`` times the
      target variance;
    * ``pairwise_pred_dist_mean`` below ``pairwise_to_initial`` times the
      baseline (skipped when no baseline is given);
    * ``kl_reg`` above ``kl_reg_max``.

    Default ``CollapseThresholds()`` are used when ``thresholds`` is
    ``None``. A supplied baseline is checked with
    ``_require_nonnegative_finite`` before any criterion is evaluated.
    Returns the tripped alerts as a tuple, possibly empty.
    """
    limits = CollapseThresholds() if thresholds is None else thresholds
    baseline = initial_pairwise_pred_dist_mean
    if baseline is not None:
        _require_nonnegative_finite("initial_pairwise_pred_dist_mean", baseline)

    tripped: list[CollapseAlert] = []

    # Criterion 1: prediction variance shrinks relative to target variance.
    variance_floor = limits.pred_var_to_target_var * metrics.target_var_per_dim
    if metrics.pred_var_per_dim < variance_floor:
        variance_alert = CollapseAlert(
            criterion="pred_var_per_dim",
            value=metrics.pred_var_per_dim,
            threshold=variance_floor,
        )
        tripped.append(variance_alert)

    # Criterion 2: predictions move closer together than the baseline allows.
    if baseline is not None:
        distance_floor = limits.pairwise_to_initial * baseline
        if metrics.pairwise_pred_dist_mean < distance_floor:
            distance_alert = CollapseAlert(
                criterion="pairwise_pred_dist_mean",
                value=metrics.pairwise_pred_dist_mean,
                threshold=distance_floor,
            )
            tripped.append(distance_alert)

    # Criterion 3: the KL regulariser exceeds its ceiling.
    kl_ceiling = limits.kl_reg_max
    if metrics.kl_reg > kl_ceiling:
        kl_alert = CollapseAlert(
            criterion="kl_reg",
            value=metrics.kl_reg,
            threshold=kl_ceiling,
        )
        tripped.append(kl_alert)

    return tuple(tripped)

record_collapse_metrics

record_collapse_metrics(metrics: CollapseMetrics, *, alerts: Iterable[CollapseAlert] = (), logger: GenoLeWMLogger | None = None, step: int | None = None) -> None

Write collapse metrics to the registry and optional structured logs.

Source code in geno_lewm/training/collapse.py
def record_collapse_metrics(
    metrics: CollapseMetrics,
    *,
    alerts: Iterable[CollapseAlert] = (),
    logger: GenoLeWMLogger | None = None,
    step: int | None = None,
) -> None:
    """Publish collapse metrics and alerts to the registry and the logger.

    Each metric is written to its ``geno_lewm.training.collapse.*`` gauge.
    Each alert increments the ``geno_lewm.training.collapse.alert``
    counter. When ``logger`` is given, every metric and every alert is
    also logged with ``step`` attached. A non-``None`` ``step`` is
    checked with ``_require_nonnegative_int`` before anything is written.
    """
    if step is not None:
        _require_nonnegative_int("step", step)

    # Gauge name -> current value, in publication order.
    gauge_updates = (
        ("geno_lewm.training.collapse.pred_cos_mean", metrics.pred_cos_mean),
        ("geno_lewm.training.collapse.pred_l2_mean", metrics.pred_l2_mean),
        ("geno_lewm.training.collapse.target_var_per_dim", metrics.target_var_per_dim),
        ("geno_lewm.training.collapse.pred_var_per_dim", metrics.pred_var_per_dim),
        ("geno_lewm.training.collapse.pred_target_corr", metrics.pred_target_corr),
        (
            "geno_lewm.training.collapse.pairwise_pred_dist_mean",
            metrics.pairwise_pred_dist_mean,
        ),
        ("geno_lewm.training.collapse.kl_reg", metrics.kl_reg),
    )
    for gauge_name, gauge_value in gauge_updates:
        get_gauge(gauge_name).set(gauge_value)

    if logger is not None:
        for metric_name, metric_value in _metric_items(metrics):
            _log_training_metric(logger, name=metric_name, value=metric_value, step=step)

    for tripped_alert in alerts:
        # The counter is bumped once per alert, with or without a logger.
        get_counter("geno_lewm.training.collapse.alert").inc()
        if logger is None:
            continue
        _log_collapse_alert(logger, tripped_alert, step=step)