geno_lewm.training.collapse
collapse
RFC-0005 collapse monitoring for training batches.
The monitor accepts plain Python nested sequences and common tensor-like
objects that expose detach(), cpu(), and/or tolist(). The core math is
dependency-free, so the diagnostics remain available before a full ML
runtime is installed.
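The duck-typed input contract can be pictured as a small coercion step. The sketch below is a hypothetical helper (`to_rows` is not part of the module). It assumes each hook is applied only when present, in the order detach(), cpu(), tolist(), and that the result must be a rectangular [N, D] batch of floats.

```python
from typing import Any, List


def to_rows(batch: Any) -> List[List[float]]:
    """Hypothetical helper: coerce a tensor-like or nested-sequence batch to float rows."""
    obj = batch
    # Drop autograd state when the object supports it (e.g. a torch.Tensor).
    if hasattr(obj, "detach"):
        obj = obj.detach()
    # Move to host memory when the object supports it.
    if hasattr(obj, "cpu"):
        obj = obj.cpu()
    # Convert to plain Python lists when available (torch.Tensor, numpy.ndarray, ...).
    if hasattr(obj, "tolist"):
        obj = obj.tolist()
    rows = [[float(x) for x in row] for row in obj]
    # Assumption: the diagnostics expect every row to have the same length D.
    if rows and any(len(row) != len(rows[0]) for row in rows):
        raise ValueError("expected a rectangular [N, D] batch")
    return rows
```

Because only hasattr checks are used, the same path serves plain lists (no hooks) and framework tensors without importing any ML library.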
CollapseMetrics
dataclass
CollapseMetrics(pred_cos_mean: float, pred_l2_mean: float, target_var_per_dim: float, pred_var_per_dim: float, pred_target_corr: float, pairwise_pred_dist_mean: float, kl_reg: float)
Scalar RFC-0005 §3.6 collapse diagnostics for one batch.
CollapseThresholds
dataclass
CollapseThresholds(pred_var_to_target_var: float = 0.5, pairwise_to_initial: float = 0.5, kl_reg_max: float = 10.0)
Alert thresholds from RFC-0005 §3.6.
CollapseAlert
dataclass
One tripped collapse criterion.
CollapseCheck
dataclass
Metrics and alerts produced by one collapse-monitor observation.
CollapseMonitor
dataclass
CollapseMonitor(log_every_steps: int = 500, thresholds: CollapseThresholds = CollapseThresholds(), initial_pairwise_pred_dist_mean: float | None = None)
Compute, register, and optionally log collapse diagnostics.
observe returns None on non-logging steps and a CollapseCheck otherwise.
The first logged batch establishes the pairwise-distance baseline unless
the caller supplies one.
should_log
Return whether step is a scheduled collapse-monitor step.
observe
observe(prediction: object, target: object, *, kl_reg: float, step: int, logger: GenoLeWMLogger | None = None, force: bool = False) -> CollapseCheck | None
Observe a validation batch at step.
Metrics are computed, written to the registry, and logged only when
step is a scheduled monitoring step or force is true.
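The scheduling, force, and baseline behaviour described for CollapseMonitor can be sketched with a simplified stand-in. MonitorSketch is hypothetical and tracks only the pairwise-distance criterion. It assumes "scheduled" means step % log_every_steps == 0 and that the check trips when the pairwise distance falls below pairwise_to_initial times the baseline. Neither assumption is confirmed by the reference above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MonitorSketch:
    """Hypothetical stand-in for CollapseMonitor's scheduling and baseline logic."""

    log_every_steps: int = 500
    pairwise_to_initial: float = 0.5
    initial_pairwise_pred_dist_mean: Optional[float] = None

    def should_log(self, step: int) -> bool:
        # Assumption: scheduled steps are multiples of log_every_steps, starting at 0.
        return step % self.log_every_steps == 0

    def observe(self, pairwise_pred_dist_mean: float, *, step: int, force: bool = False) -> Optional[bool]:
        """Return None when skipped, else whether the pairwise criterion tripped."""
        if not (force or self.should_log(step)):
            return None  # non-logging step: nothing is computed or recorded
        if self.initial_pairwise_pred_dist_mean is None:
            # The first observed batch becomes the baseline unless one was supplied.
            self.initial_pairwise_pred_dist_mean = pairwise_pred_dist_mean
        limit = self.pairwise_to_initial * self.initial_pairwise_pred_dist_mean
        return pairwise_pred_dist_mean < limit
```

For example, with the default schedule, observe(2.0, step=0) sets the baseline to 2.0, observe(0.5, step=1) is skipped, and observe(0.5, step=1, force=True) trips because 0.5 is below half the baseline.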
compute_collapse_metrics
Compute RFC-0005 §3.6 collapse metrics for one [N, D] batch.
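The reference does not spell out the formulas behind each CollapseMetrics field. The dependency-free sketch below uses one plausible reading of the field names, and every definition in it is an assumption: mean row-wise cosine similarity and L2 distance between prediction and target, population variance per dimension averaged over D, Pearson correlation over the flattened batches, and mean Euclidean distance over distinct prediction pairs. compute_metrics_sketch is hypothetical and returns a plain dict rather than a CollapseMetrics.

```python
import math
from itertools import combinations
from typing import Dict, List, Sequence

Rows = Sequence[Sequence[float]]


def _mean(xs: Sequence[float]) -> float:
    return sum(xs) / len(xs) if xs else 0.0


def _var_per_dim(rows: Rows) -> float:
    # Population variance of each column, averaged over the D columns.
    variances: List[float] = []
    for col in zip(*rows):
        mu = _mean(col)
        variances.append(_mean([(x - mu) ** 2 for x in col]))
    return _mean(variances)


def _cos(a: Sequence[float], b: Sequence[float]) -> float:
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return sum(x * y for x, y in zip(a, b)) / (na * nb) if na and nb else 0.0


def _pearson(xs: Sequence[float], ys: Sequence[float]) -> float:
    mx, my = _mean(xs), _mean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy) if sx and sy else 0.0  # 0.0 when either side is constant


def compute_metrics_sketch(pred: Rows, target: Rows, *, kl_reg: float) -> Dict[str, float]:
    """Hypothetical per-batch diagnostics for [N, D] prediction and target rows."""
    if len(pred) != len(target):
        raise ValueError("prediction and target must have the same number of rows")
    flat_pred = [x for row in pred for x in row]
    flat_target = [y for row in target for y in row]
    return {
        "pred_cos_mean": _mean([_cos(p, t) for p, t in zip(pred, target)]),
        "pred_l2_mean": _mean([math.dist(p, t) for p, t in zip(pred, target)]),
        "target_var_per_dim": _var_per_dim(target),
        "pred_var_per_dim": _var_per_dim(pred),
        "pred_target_corr": _pearson(flat_pred, flat_target),
        # Mean over distinct pairs; 0.0 when N < 2.
        "pairwise_pred_dist_mean": _mean([math.dist(a, b) for a, b in combinations(pred, 2)]),
        "kl_reg": float(kl_reg),
    }
```

Under these definitions a collapsed predictor that emits the same row for every input shows pred_var_per_dim and pairwise_pred_dist_mean at 0.0 regardless of how the targets vary.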
detect_collapse
detect_collapse(metrics: CollapseMetrics, *, thresholds: CollapseThresholds | None = None, initial_pairwise_pred_dist_mean: float | None = None) -> tuple[CollapseAlert, ...]
Return the RFC-0005 §3.6 alert criteria tripped by metrics.
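The threshold names in CollapseThresholds suggest three criteria. The sketch below is a hypothetical reading of them, not the module's detect_collapse. It assumes the variance alert trips when pred_var_per_dim is below pred_var_to_target_var times target_var_per_dim, the pairwise alert trips only when a baseline is available and the distance is below pairwise_to_initial times that baseline, and the KL alert trips when kl_reg exceeds kl_reg_max. AlertSketch and its fields are invented for illustration; the defaults match the documented threshold values.

```python
from dataclasses import dataclass
from typing import List, Mapping, Optional, Tuple


@dataclass(frozen=True)
class AlertSketch:
    """Hypothetical stand-in for CollapseAlert."""

    criterion: str
    value: float
    limit: float


def detect_collapse_sketch(
    metrics: Mapping[str, float],
    *,
    pred_var_to_target_var: float = 0.5,
    pairwise_to_initial: float = 0.5,
    kl_reg_max: float = 10.0,
    initial_pairwise_pred_dist_mean: Optional[float] = None,
) -> Tuple[AlertSketch, ...]:
    alerts: List[AlertSketch] = []
    # Assumed rule 1: prediction spread per dimension is too small relative to the targets.
    var_limit = pred_var_to_target_var * metrics["target_var_per_dim"]
    if metrics["pred_var_per_dim"] < var_limit:
        alerts.append(AlertSketch("pred_var_to_target_var", metrics["pred_var_per_dim"], var_limit))
    # Assumed rule 2: predictions have contracted relative to the baseline (skipped without one).
    if initial_pairwise_pred_dist_mean is not None:
        pair_limit = pairwise_to_initial * initial_pairwise_pred_dist_mean
        if metrics["pairwise_pred_dist_mean"] < pair_limit:
            alerts.append(AlertSketch("pairwise_to_initial", metrics["pairwise_pred_dist_mean"], pair_limit))
    # Assumed rule 3: the KL regulariser has grown past its ceiling.
    if metrics["kl_reg"] > kl_reg_max:
        alerts.append(AlertSketch("kl_reg_max", metrics["kl_reg"], kl_reg_max))
    return tuple(alerts)
```

Returning a tuple mirrors the documented tuple[CollapseAlert, ...] return type, so an empty tuple means no criterion tripped.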
record_collapse_metrics
record_collapse_metrics(metrics: CollapseMetrics, *, alerts: Iterable[CollapseAlert] = (), logger: GenoLeWMLogger | None = None, step: int | None = None) -> None
Write collapse metrics to the registry and optional structured logs.
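The registry and GenoLeWMLogger interfaces are not documented here, so the sketch below substitutes a plain dict for the registry and a standard-library logging.Logger for the structured logger. record_sketch, the "collapse/" key prefix, and the alert_count entry are all hypothetical. The sketch only shows the documented split: registry writes always happen, and logging happens only when a logger is supplied.

```python
import logging
from typing import Iterable, Mapping, MutableMapping, Optional


def record_sketch(
    metrics: Mapping[str, float],
    *,
    registry: MutableMapping[str, float],
    alerts: Iterable[str] = (),
    logger: Optional[logging.Logger] = None,
    step: Optional[int] = None,
) -> None:
    """Hypothetical recorder: always write to the registry, log only if asked."""
    for name, value in metrics.items():
        registry[f"collapse/{name}"] = value  # hypothetical key naming
    alert_list = list(alerts)
    registry["collapse/alert_count"] = float(len(alert_list))  # hypothetical summary key
    if logger is not None:
        # Structured fields travel via `extra` and become attributes on the LogRecord.
        logger.info(
            "collapse metrics",
            extra={"step": step, "collapse_metrics": dict(metrics), "collapse_alerts": alert_list},
        )
```

Keeping the registry write unconditional means dashboards still receive metrics when structured logging is disabled.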