geno_lewm.training

Training helpers for GenoLeWM.

CollapseAlert dataclass

CollapseAlert(criterion: str, value: float, threshold: float)

One tripped collapse criterion.

CollapseCheck dataclass

CollapseCheck(metrics: CollapseMetrics, alerts: tuple[CollapseAlert, ...])

Metrics and alerts produced by one collapse-monitor observation.

tripped property

tripped: bool

Return whether any collapse alert tripped.

CollapseMetrics dataclass

CollapseMetrics(pred_cos_mean: float, pred_l2_mean: float, target_var_per_dim: float, pred_var_per_dim: float, pred_target_corr: float, pairwise_pred_dist_mean: float, kl_reg: float)

Scalar RFC-0005 §3.6 collapse diagnostics for one batch.

CollapseMonitor dataclass

CollapseMonitor(log_every_steps: int = 500, thresholds: CollapseThresholds = CollapseThresholds(), initial_pairwise_pred_dist_mean: float | None = None)

Compute, register, and optionally log collapse diagnostics.

observe returns None on non-logging steps and a CollapseCheck otherwise. The first logged batch establishes the pairwise-distance baseline unless the caller supplies one.

should_log

should_log(step: int) -> bool

Return whether step is a scheduled collapse-monitor step.

Source code in geno_lewm/training/collapse.py
def should_log(self, step: int) -> bool:
    """Return whether ``step`` is a scheduled collapse-monitor step."""
    _require_nonnegative_int("step", step)
    return step > 0 and step % self.log_every_steps == 0
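
The schedule can be checked by hand. The following is a minimal standalone sketch that copies the return expression above into a free function; `scheduled_steps` is a hypothetical helper added for illustration and is not part of geno_lewm.

```python
def should_log(step: int, log_every_steps: int = 500) -> bool:
    """Standalone copy of the schedule rule: positive multiples of the interval."""
    return step > 0 and step % log_every_steps == 0


def scheduled_steps(total_steps: int, log_every_steps: int = 500) -> list[int]:
    """Hypothetical helper: every step in 1..total_steps that would log."""
    return [s for s in range(1, total_steps + 1) if should_log(s, log_every_steps)]


print(scheduled_steps(1600))  # [500, 1000, 1500]
print(should_log(0))          # False: step 0 is never a scheduled step
```

With the default interval of 500, a run shorter than 500 steps never reaches a scheduled step, so observe would only produce a check when called with force=True.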

observe

observe(prediction: object, target: object, *, kl_reg: float, step: int, logger: GenoLeWMLogger | None = None, force: bool = False) -> CollapseCheck | None

Observe a validation batch at step.

Metrics are computed, written to the registry, and logged only when step is a scheduled monitoring step or force is true; on any other step observe returns None without computing anything.

Source code in geno_lewm/training/collapse.py
def observe(
    self,
    prediction: object,
    target: object,
    *,
    kl_reg: float,
    step: int,
    logger: GenoLeWMLogger | None = None,
    force: bool = False,
) -> CollapseCheck | None:
    """Observe a validation batch at ``step``.

    Metrics are computed, written to the registry, and logged only
    when ``step`` is a scheduled monitoring step unless ``force`` is
    true.
    """
    _require_nonnegative_int("step", step)
    if not force and not self.should_log(step):
        return None

    metrics = compute_collapse_metrics(prediction, target, kl_reg=kl_reg)
    if self.initial_pairwise_pred_dist_mean is None:
        self.initial_pairwise_pred_dist_mean = metrics.pairwise_pred_dist_mean
    alerts = detect_collapse(
        metrics,
        thresholds=self.thresholds,
        initial_pairwise_pred_dist_mean=self.initial_pairwise_pred_dist_mean,
    )
    record_collapse_metrics(metrics, alerts=alerts, logger=logger, step=step)
    return CollapseCheck(metrics=metrics, alerts=alerts)
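
The baseline handling in observe can be illustrated in isolation. This is a sketch under stated assumptions: `BaselineTracker` is a hypothetical stand-in that keeps only the pairwise-distance logic and uses the default ratio of 0.5 from CollapseThresholds; it is not the library class.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BaselineTracker:
    """Hypothetical stand-in for the baseline handling in CollapseMonitor.observe."""

    pairwise_to_initial: float = 0.5
    initial: Optional[float] = None

    def check(self, pairwise_dist_mean: float) -> bool:
        """Return True when the distance fell below the allowed fraction of the baseline."""
        if self.initial is None:
            # The first logged batch becomes the baseline.
            self.initial = pairwise_dist_mean
        return pairwise_dist_mean < self.pairwise_to_initial * self.initial


tracker = BaselineTracker()
first = tracker.check(2.0)   # baseline is set to 2.0; 2.0 < 1.0 is False
second = tracker.check(1.2)  # 1.2 < 1.0 is False
third = tracker.check(0.8)   # 0.8 < 1.0 is True
```

One consequence worth noting: when the baseline comes from the first logged batch and the ratio is below 1, that first batch is compared against itself and cannot trip the pairwise criterion. Supplying initial_pairwise_pred_dist_mean up front avoids this.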

CollapseThresholds dataclass

CollapseThresholds(pred_var_to_target_var: float = 0.5, pairwise_to_initial: float = 0.5, kl_reg_max: float = 10.0)

Alert thresholds from RFC-0005 §3.6.

FixtureTrainingReport dataclass

FixtureTrainingReport(run_id: str, run_dir: Path, steps_requested: int, steps_completed: int, resumed_from_step: int, final_loss: float, checkpoint_path: Path, metrics_path: Path, log_path: Path, config_path: Path, dataset_manifest_path: Path, training_metadata_path: Path)

Summary returned by the deterministic fixture trainer.

AcceleratorProbe dataclass

AcceleratorProbe(requested_device: str | None, required: bool, available: bool, device_count: int, device_name: str | None, total_memory_bytes: int | None, min_memory_bytes: int, reason: str, issue_code: str | None = None)

CUDA accelerator readiness probe for Carbon-backed training.

DependencyProbe dataclass

DependencyProbe(import_name: str, package: str, required: bool, available: bool, version: str | None, reason: str)

Importability probe for one training dependency.

TrainingPreflightIssue dataclass

TrainingPreflightIssue(severity: Severity, code: str, path: str, message: str)

One preflight issue.

TrainingPreflightReport dataclass

TrainingPreflightReport(schema_version: str, generated_by: str, generated_at: str, ok: bool, dataset_snapshot_id: str | None, training_config: dict[str, object], run_dir: dict[str, object], dataset: dict[str, object], carbon: dict[str, object], accelerator: AcceleratorProbe, dependencies: tuple[DependencyProbe, ...], issues: tuple[TrainingPreflightIssue, ...])

Machine-readable readiness evidence for the real training path.

TrainingPreflightRequest dataclass

TrainingPreflightRequest(dataset_dir: Path, carbon_model_dir: Path, training_config: Path, run_dir: Path, allow_fixture_dataset: bool = False, require_native_runtime: bool = True, require_accelerator: bool = True, min_cuda_vram_gb: float = MIN_CUDA_VRAM_GB)

Inputs needed before launching a Carbon-backed training run.

CarbonTrainingReport dataclass

CarbonTrainingReport(run_id: str, run_dir: Path, dataset_snapshot_id: str, steps_requested: int, steps_completed: int, resumed_from_step: int, sample_count: int, final_loss: float, checkpoint_path: Path, resume_checkpoint_path: Path | None, metrics_path: Path, log_path: Path, config_path: Path, preflight_path: Path | None, training_metadata_path: Path)

Summary emitted by the real Carbon-backed trainer.

EditTypeWeight dataclass

EditTypeWeight(edit_type: EditType, weight: float)

One RFC-0005 edit-type sampling weight.

RolloutStepWeight dataclass

RolloutStepWeight(steps: int, weight: float)

One rollout-step-count sampling weight.

TorchDeterminismReport dataclass

TorchDeterminismReport(seed: int, deterministic: bool, cublas_workspace_config: str | None, torch_deterministic_algorithms: bool)

Runtime settings applied before a torch training run.

TorchTrainer

TorchTrainer(*, predictor: object, action_encoder: object, optimizer: object, config: GenoLeWMConfig, total_steps: int)

Minimal optimizer loop for Carbon-state predictor training.

Source code in geno_lewm/training/trainer.py
def __init__(
    self,
    *,
    predictor: object,
    action_encoder: object,
    optimizer: object,
    config: GenoLeWMConfig,
    total_steps: int,
) -> None:
    _require_torch("TorchTrainer")
    _require_positive_int("total_steps", total_steps)
    self.predictor = predictor
    self.action_encoder = action_encoder
    self.optimizer = optimizer
    self.config = config
    self.total_steps = total_steps
    self.collapse_monitor = CollapseMonitor(
        log_every_steps=config.training.collapse_log_every_steps,
    )
    self.last_collapse_alerts: tuple[dict[str, object], ...] = ()

train_step

train_step(batch: TorchTrainerBatch, *, step: int) -> TorchTrainerStepResult

Run one optimizer step over an encoded Carbon-state batch.

Source code in geno_lewm/training/trainer.py
def train_step(self, batch: TorchTrainerBatch, *, step: int) -> TorchTrainerStepResult:
    """Run one optimizer step over an encoded Carbon-state batch."""
    _require_positive_int("step", step)
    if step > self.total_steps:
        raise InputError(
            "step cannot exceed total_steps",
            details={"step": step, "total_steps": self.total_steps},
        )
    lr_multiplier = set_optimizer_lr(
        self.optimizer,
        step=step,
        total_steps=self.total_steps,
        warmup_steps=self.config.optimizer.warmup_steps,
        schedule=self.config.optimizer.schedule,
    )
    _zero_grad(self.optimizer)
    action_embeddings = _call_action_encoder(self.action_encoder, batch.rel_edits)
    prediction = _call_predictor(
        self.predictor,
        state=batch.state,
        actions=action_embeddings,
        action_mask=batch.action_mask,
    )
    loss_result = predictor_loss(
        prediction,
        batch.target,
        phase=self.config.phase,
        mask=batch.action_mask,
    )
    action_count = int(batch.action_mask.sum().item())
    self.last_collapse_alerts = ()
    if action_count > 0:
        collapse_check = self.collapse_monitor.observe(
            _masked_training_rows(prediction, batch.action_mask),
            _masked_training_rows(batch.target, batch.action_mask),
            kl_reg=_scalar(loss_result.kl_reg),
            step=step,
        )
        if collapse_check is not None:
            self.last_collapse_alerts = tuple(
                {
                    "criterion": alert.criterion,
                    "value": alert.value,
                    "threshold": alert.threshold,
                }
                for alert in collapse_check.alerts
            )
    loss_result.loss.backward()
    if self.config.optimizer.grad_clip > 0:
        parameters = _trainable_parameters((self.predictor, self.action_encoder))
        torch.nn.utils.clip_grad_norm_(parameters, self.config.optimizer.grad_clip)
    _optimizer_step(self.optimizer)
    return TorchTrainerStepResult(
        step=step,
        lr_multiplier=lr_multiplier,
        loss=_scalar(loss_result.loss),
        pred_loss=_scalar(loss_result.pred_loss),
        kl_reg=_scalar(loss_result.kl_reg),
        action_count=action_count,
        pred_var_per_dim=_pred_var_per_dim(prediction),
    )

TorchTrainerBatch dataclass

TorchTrainerBatch(state: Tensor, target: Tensor, rel_edits: tuple[tuple[RelEdit, ...], ...], action_mask: Tensor, window_ids: tuple[str, ...])

One encoded minibatch consumed by TorchTrainer.

TorchTrainerStepResult dataclass

TorchTrainerStepResult(step: int, lr_multiplier: float, loss: float, pred_loss: float, kl_reg: float, action_count: int, pred_var_per_dim: float)

Scalar outputs from one optimizer step.

TrainerSeeds dataclass

TrainerSeeds(data: int, predictor: int, lora: int)

Distinct RNG seeds consumed by the real training stack.

compute_collapse_metrics

compute_collapse_metrics(prediction: object, target: object, *, kl_reg: float) -> CollapseMetrics

Compute RFC-0005 §3.6 collapse metrics for one [N, D] batch.

Source code in geno_lewm/training/collapse.py
def compute_collapse_metrics(
    prediction: object,
    target: object,
    *,
    kl_reg: float,
) -> CollapseMetrics:
    """Compute RFC-0005 §3.6 collapse metrics for one ``[N, D]`` batch."""
    pred_rows = _as_rows(prediction, "prediction")
    target_rows = _as_rows(target, "target")
    if len(pred_rows) != len(target_rows):
        raise InputError(
            "prediction and target must have the same batch size",
            details={"prediction_rows": len(pred_rows), "target_rows": len(target_rows)},
        )
    dim = len(pred_rows[0])
    if dim != len(target_rows[0]):
        raise InputError(
            "prediction and target must have the same latent dimension",
            details={"prediction_dim": dim, "target_dim": len(target_rows[0])},
        )

    kl_value = _finite_float(kl_reg, "kl_reg")
    return CollapseMetrics(
        pred_cos_mean=_mean(
            _cosine(pred, tgt) for pred, tgt in zip(pred_rows, target_rows, strict=True)
        ),
        pred_l2_mean=_mean(
            _euclidean(pred, tgt) for pred, tgt in zip(pred_rows, target_rows, strict=True)
        ),
        target_var_per_dim=_mean_variance_per_dim(target_rows),
        pred_var_per_dim=_mean_variance_per_dim(pred_rows),
        pred_target_corr=_pearson_corr(_flatten(pred_rows), _flatten(target_rows)),
        pairwise_pred_dist_mean=_pairwise_dist_mean(pred_rows),
        kl_reg=kl_value,
    )
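
To make two of these metrics concrete, here is a small pure-Python sketch. The private helpers (`_mean_variance_per_dim`, `_pairwise_dist_mean`) are not shown on this page, so the definitions below are assumptions: population variance averaged over dimensions, and mean Euclidean distance over all unordered row pairs. The library may differ in detail.

```python
import math
from itertools import combinations
from statistics import pvariance


def mean_variance_per_dim(rows):
    """Average the per-dimension variance (population variance is assumed)."""
    columns = list(zip(*rows))
    return sum(pvariance(col) for col in columns) / len(columns)


def pairwise_dist_mean(rows):
    """Mean Euclidean distance over all unordered row pairs (assumed definition)."""
    pairs = list(combinations(rows, 2))
    return sum(math.dist(a, b) for a, b in pairs) / len(pairs)


# Four distinct predictions at the corners of a square.
healthy = [[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0]]
# A fully collapsed predictor: every row is the same vector.
collapsed = [[1.0, 1.0]] * 4

print(mean_variance_per_dim(healthy))    # 1.0
print(mean_variance_per_dim(collapsed))  # 0.0
print(pairwise_dist_mean(collapsed))     # 0.0
```

Both quantities drop to zero when every prediction is identical, which is why they serve as collapse indicators.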

detect_collapse

detect_collapse(metrics: CollapseMetrics, *, thresholds: CollapseThresholds | None = None, initial_pairwise_pred_dist_mean: float | None = None) -> tuple[CollapseAlert, ...]

Return the RFC-0005 §3.6 alert criteria tripped by metrics.

Source code in geno_lewm/training/collapse.py
def detect_collapse(
    metrics: CollapseMetrics,
    *,
    thresholds: CollapseThresholds | None = None,
    initial_pairwise_pred_dist_mean: float | None = None,
) -> tuple[CollapseAlert, ...]:
    """Return the RFC-0005 §3.6 alert criteria tripped by ``metrics``."""
    active_thresholds = thresholds if thresholds is not None else CollapseThresholds()
    if initial_pairwise_pred_dist_mean is not None:
        _require_nonnegative_finite(
            "initial_pairwise_pred_dist_mean",
            initial_pairwise_pred_dist_mean,
        )

    alerts: list[CollapseAlert] = []
    pred_var_threshold = active_thresholds.pred_var_to_target_var * metrics.target_var_per_dim
    if metrics.pred_var_per_dim < pred_var_threshold:
        alerts.append(
            CollapseAlert(
                criterion="pred_var_per_dim",
                value=metrics.pred_var_per_dim,
                threshold=pred_var_threshold,
            )
        )

    if initial_pairwise_pred_dist_mean is not None:
        pairwise_threshold = active_thresholds.pairwise_to_initial * initial_pairwise_pred_dist_mean
        if metrics.pairwise_pred_dist_mean < pairwise_threshold:
            alerts.append(
                CollapseAlert(
                    criterion="pairwise_pred_dist_mean",
                    value=metrics.pairwise_pred_dist_mean,
                    threshold=pairwise_threshold,
                )
            )

    if metrics.kl_reg > active_thresholds.kl_reg_max:
        alerts.append(
            CollapseAlert(
                criterion="kl_reg",
                value=metrics.kl_reg,
                threshold=active_thresholds.kl_reg_max,
            )
        )

    return tuple(alerts)
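
The three checks reduce to simple inequalities. The sketch below restates them as a standalone function using the default CollapseThresholds values; `tripped_criteria` is a hypothetical name, and it returns criterion names rather than CollapseAlert objects.

```python
def tripped_criteria(
    pred_var,
    target_var,
    pairwise,
    initial_pairwise,
    kl_reg,
    *,
    var_ratio=0.5,
    pairwise_ratio=0.5,
    kl_reg_max=10.0,
):
    """Standalone restatement of the three checks, returning criterion names."""
    tripped = []
    if pred_var < var_ratio * target_var:
        tripped.append("pred_var_per_dim")
    # The pairwise check is skipped entirely when no baseline is known.
    if initial_pairwise is not None and pairwise < pairwise_ratio * initial_pairwise:
        tripped.append("pairwise_pred_dist_mean")
    if kl_reg > kl_reg_max:
        tripped.append("kl_reg")
    return tripped


print(tripped_criteria(0.4, 1.0, 1.5, 2.0, 3.0))    # ['pred_var_per_dim']
print(tripped_criteria(0.5, 1.0, 0.9, 2.0, 10.0))   # ['pairwise_pred_dist_mean']
print(tripped_criteria(0.1, 1.0, 0.1, None, 11.0))  # ['pred_var_per_dim', 'kl_reg']
```

All comparisons are strict, so a value sitting exactly on a threshold (prediction variance at half the target variance, or kl_reg equal to 10.0) does not raise an alert.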

record_collapse_metrics

record_collapse_metrics(metrics: CollapseMetrics, *, alerts: Iterable[CollapseAlert] = (), logger: GenoLeWMLogger | None = None, step: int | None = None) -> None

Write collapse metrics to the registry and optional structured logs.

Source code in geno_lewm/training/collapse.py
def record_collapse_metrics(
    metrics: CollapseMetrics,
    *,
    alerts: Iterable[CollapseAlert] = (),
    logger: GenoLeWMLogger | None = None,
    step: int | None = None,
) -> None:
    """Write collapse metrics to the registry and optional structured logs."""
    if step is not None:
        _require_nonnegative_int("step", step)

    get_gauge("geno_lewm.training.collapse.pred_cos_mean").set(metrics.pred_cos_mean)
    get_gauge("geno_lewm.training.collapse.pred_l2_mean").set(metrics.pred_l2_mean)
    get_gauge("geno_lewm.training.collapse.target_var_per_dim").set(metrics.target_var_per_dim)
    get_gauge("geno_lewm.training.collapse.pred_var_per_dim").set(metrics.pred_var_per_dim)
    get_gauge("geno_lewm.training.collapse.pred_target_corr").set(metrics.pred_target_corr)
    get_gauge("geno_lewm.training.collapse.pairwise_pred_dist_mean").set(
        metrics.pairwise_pred_dist_mean
    )
    get_gauge("geno_lewm.training.collapse.kl_reg").set(metrics.kl_reg)

    if logger is not None:
        for name, value in _metric_items(metrics):
            _log_training_metric(logger, name=name, value=value, step=step)

    for alert in alerts:
        get_counter("geno_lewm.training.collapse.alert").inc()
        if logger is not None:
            _log_collapse_alert(logger, alert, step=step)

run_fixture_training

run_fixture_training(*, config: GenoLeWMConfig, run_dir: Path, steps: int = 50, resume_from: Path | None = None, command: str, commit_sha: str, package_version: str) -> FixtureTrainingReport

Run a deterministic scalar smoke trainer and write release artifacts.

steps is the target total step count. When resume_from is supplied, the checkpoint's current step must be lower than steps; the resumed run continues with the same deterministic sample stream.

Source code in geno_lewm/training/fixture.py
def run_fixture_training(
    *,
    config: GenoLeWMConfig,
    run_dir: Path,
    steps: int = 50,
    resume_from: Path | None = None,
    command: str,
    commit_sha: str,
    package_version: str,
) -> FixtureTrainingReport:
    """Run a deterministic scalar smoke trainer and write release artifacts.

    ``steps`` is the target total step count. When ``resume_from`` is
    supplied, the checkpoint's current step must be lower than ``steps``;
    the resumed run continues with the same deterministic sample stream.
    """
    _require_positive_steps(steps)
    run_dir.mkdir(parents=True, exist_ok=True)
    config_path = write_resolved_config(config, run_dir / "config.resolved.yaml")
    dataset_manifest_path = _write_fixture_dataset_manifest(config, run_dir)

    state = (
        _load_checkpoint(resume_from, expected_seed=config.seed)
        if resume_from is not None
        else _initial_state(config.seed)
    )
    resumed_from_step = state.step
    if resumed_from_step >= steps:
        raise InputError(
            "fixture resume checkpoint is already at or beyond --steps",
            details={"checkpoint_step": resumed_from_step, "steps": steps},
        )

    log_path = run_dir / FIXTURE_LOG_NAME
    mode = "a" if resumed_from_step else "w"
    with log_path.open(mode, encoding="utf-8") as log:
        if resumed_from_step:
            log.write(
                f"resume_from_step={resumed_from_step} target_steps={steps} seed={config.seed}\n"
            )
        for step in range(resumed_from_step + 1, steps + 1):
            state = _train_step(state, step=step)
            log.write(
                f"step={step} loss={state.loss:.12f} weight={state.weight:.12f} "
                "collapse_var_min=1.0 nan_loss=false\n"
            )

    checkpoint_path = run_dir / FIXTURE_CHECKPOINT_NAME
    _write_checkpoint(config, state, checkpoint_path)
    metrics_path = run_dir / FIXTURE_METRICS_NAME
    _write_metrics(config, state, metrics_path, steps=steps, resumed_from_step=resumed_from_step)
    training_metadata_path = run_dir / FIXTURE_TRAINING_METADATA_NAME
    _write_training_metadata(
        config,
        training_metadata_path,
        command=command,
        commit_sha=commit_sha,
        package_version=package_version,
    )
    return FixtureTrainingReport(
        run_id=config.run_id,
        run_dir=run_dir,
        steps_requested=steps,
        steps_completed=state.step,
        resumed_from_step=resumed_from_step,
        final_loss=state.loss,
        checkpoint_path=checkpoint_path,
        metrics_path=metrics_path,
        log_path=log_path,
        config_path=config_path,
        dataset_manifest_path=dataset_manifest_path,
        training_metadata_path=training_metadata_path,
    )
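
Because steps is a target total rather than a number of additional steps, the range of steps actually executed depends on the checkpoint. A minimal sketch of that arithmetic follows; `steps_to_run` is a hypothetical helper, and it raises ValueError where the library raises its own InputError.

```python
def steps_to_run(resumed_from_step: int, steps: int) -> range:
    """Hypothetical helper: the step numbers executed for a target total of `steps`."""
    if resumed_from_step >= steps:
        # Stand-in for the InputError raised by run_fixture_training.
        raise ValueError("checkpoint is already at or beyond the target step count")
    return range(resumed_from_step + 1, steps + 1)


fresh = list(steps_to_run(0, 5))    # [1, 2, 3, 4, 5]
resumed = list(steps_to_run(3, 5))  # [4, 5]
```

Resuming a checkpoint saved at step 3 with steps=5 therefore runs two more steps, not five.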

build_training_preflight_report

build_training_preflight_report(request: TrainingPreflightRequest, *, generated_at: str | None = None, dependency_probe: DependencyProbeFn | None = None, accelerator_probe: AcceleratorProbeFn | None = None) -> TrainingPreflightReport

Build clean-machine readiness evidence for Carbon-backed training.

Source code in geno_lewm/training/preflight.py
def build_training_preflight_report(
    request: TrainingPreflightRequest,
    *,
    generated_at: str | None = None,
    dependency_probe: DependencyProbeFn | None = None,
    accelerator_probe: AcceleratorProbeFn | None = None,
) -> TrainingPreflightReport:
    """Build clean-machine readiness evidence for Carbon-backed training."""
    issues: list[TrainingPreflightIssue] = []
    dependency_probe = dependency_probe or _probe_dependency
    accelerator_probe = accelerator_probe or _probe_accelerator
    dataset = _inspect_dataset(request.dataset_dir, request.allow_fixture_dataset, issues)
    carbon = _inspect_carbon_model_dir(request.carbon_model_dir, issues)
    training_config = _inspect_training_config(request.training_config, issues)
    run_dir = _inspect_run_dir(request.run_dir)
    min_cuda_memory_bytes = _min_cuda_memory_bytes(request.min_cuda_vram_gb)
    accelerator = accelerator_probe(
        _requested_training_device(training_config),
        request.require_accelerator,
        min_cuda_memory_bytes,
    )
    dependencies = tuple(
        dependency_probe(name, request.require_native_runtime) for name in REQUIRED_TRAINING_MODULES
    )
    if accelerator.required and not accelerator.available:
        _issue(
            issues,
            "error",
            accelerator.issue_code or "training.accelerator_unavailable",
            request.training_config,
            accelerator.reason,
        )
    for probe in dependencies:
        if probe.required and not probe.available:
            _issue(
                issues,
                "error",
                "training.dependency_unavailable",
                probe.import_name,
                probe.reason,
            )
    snapshot_id = dataset.get("snapshot_id")
    return TrainingPreflightReport(
        schema_version=SCHEMA_VERSION,
        generated_by=GENERATED_BY,
        generated_at=_utc_now() if generated_at is None else generated_at,
        ok=not any(issue.severity == "error" for issue in issues),
        dataset_snapshot_id=snapshot_id if isinstance(snapshot_id, str) else None,
        training_config=training_config,
        run_dir=run_dir,
        dataset=dataset,
        carbon=carbon,
        accelerator=accelerator,
        dependencies=dependencies,
        issues=tuple(issues),
    )
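
The ok flag is derived purely from issue severities. The sketch below isolates that rule; `Issue` is a hypothetical stand-in for TrainingPreflightIssue, and the "warning" severity value is an assumption used only to show a non-blocking issue.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Issue:
    """Hypothetical stand-in for TrainingPreflightIssue (severity and code only)."""

    severity: str
    code: str


def report_ok(issues) -> bool:
    """A report is ok unless at least one issue has severity 'error'."""
    return not any(issue.severity == "error" for issue in issues)


no_issues = report_ok(())
# "warning" is an assumed severity value, used here for illustration.
only_warning = report_ok((Issue("warning", "example.warning"),))
with_error = report_ok(
    (
        Issue("warning", "example.warning"),
        Issue("error", "training.dependency_unavailable"),
    )
)
```

Here no_issues and only_warning are True while with_error is False: a single error-level issue is enough to mark the report as not ok.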

write_training_preflight_report

write_training_preflight_report(request: TrainingPreflightRequest, output: Path | None = None, *, generated_at: str | None = None, dependency_probe: DependencyProbeFn | None = None, accelerator_probe: AcceleratorProbeFn | None = None) -> TrainingPreflightReport

Write training_preflight_report.json and return the report.

Source code in geno_lewm/training/preflight.py
def write_training_preflight_report(
    request: TrainingPreflightRequest,
    output: Path | None = None,
    *,
    generated_at: str | None = None,
    dependency_probe: DependencyProbeFn | None = None,
    accelerator_probe: AcceleratorProbeFn | None = None,
) -> TrainingPreflightReport:
    """Write ``training_preflight_report.json`` and return the report."""
    report = build_training_preflight_report(
        request,
        generated_at=generated_at,
        dependency_probe=dependency_probe,
        accelerator_probe=accelerator_probe,
    )
    output = request.run_dir / REPORT_NAME if output is None else output
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(
        json.dumps(report.to_dict(), indent=2, sort_keys=True) + "\n", encoding="utf-8"
    )
    return report
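
The write path uses json.dumps with indent=2 and sort_keys=True plus a trailing newline, which gives byte-stable output for equal payloads. A minimal sketch with a plain dictionary in place of report.to_dict() follows; the payload contents are illustrative, not the real report schema.

```python
import json
import tempfile
from pathlib import Path

# Illustrative payload; the real report has many more fields.
payload = {
    "ok": False,
    "issues": [{"severity": "error", "code": "training.dependency_unavailable"}],
}

with tempfile.TemporaryDirectory() as tmp:
    output = Path(tmp) / "run" / "training_preflight_report.json"
    # Mirror the source: create the parent directory before writing.
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(
        json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8"
    )
    text = output.read_text(encoding="utf-8")
    loaded = json.loads(text)
```

Sorted keys mean "issues" is written before "ok" regardless of insertion order, so two reports with the same content diff cleanly.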

run_carbon_training

run_carbon_training(*, config: GenoLeWMConfig, dataset_dir: Path, carbon_model_dir: Path, run_dir: Path, steps: int, command: str, commit_sha: str, package_version: str, preflight_report: TrainingPreflightReport | None = None, resume_from: Path | None = None) -> CarbonTrainingReport

Run a single-process Carbon-backed training job.

Source code in geno_lewm/training/real.py
def run_carbon_training(
    *,
    config: GenoLeWMConfig,
    dataset_dir: Path,
    carbon_model_dir: Path,
    run_dir: Path,
    steps: int,
    command: str,
    commit_sha: str,
    package_version: str,
    preflight_report: TrainingPreflightReport | None = None,
    resume_from: Path | None = None,
) -> CarbonTrainingReport:
    """Run a single-process Carbon-backed training job."""
    _require_positive_int("steps", steps)
    _require_positive_int("data.batch_size", config.data.batch_size)
    run_dir.mkdir(parents=True, exist_ok=True)
    config_path = write_resolved_config(config, run_dir / _RESOLVED_CONFIG_NAME)
    log_path = run_dir / CARBON_LOG_NAME
    metrics_path = run_dir / CARBON_METRICS_NAME
    checkpoint_path = run_dir / CARBON_CHECKPOINT_NAME
    metadata_path = run_dir / CARBON_TRAINING_METADATA_NAME
    run_dataset_manifest_path = run_dir / "dataset_manifest.json"
    preflight_path = run_dir / REPORT_NAME if (run_dir / REPORT_NAME).is_file() else None

    dataset_manifest = _load_dataset_manifest(dataset_dir)
    shutil.copy2(dataset_dir / "dataset_manifest.json", run_dataset_manifest_path)
    dataset_snapshot_id = _required_text(dataset_manifest, "snapshot_id")
    dataset_files = _dataset_files(dataset_manifest)
    windows = tuple(_load_windows(dataset_dir, dataset_files))
    if not windows:
        raise InputError("Carbon training requires at least one source window")
    gnomad_edits = tuple(_load_gnomad_edits(dataset_dir, dataset_files))
    clinvar_edits = tuple(_load_clinvar_edits(dataset_dir, dataset_files))
    if not gnomad_edits:
        raise InputError("Carbon training requires at least one gnomAD edit")

    device = _training_device(config)
    seeds = TrainerSeeds.from_base_seed(config.seed)
    determinism = configure_torch_reproducibility(
        seed=seeds.predictor, deterministic=config.deterministic
    )
    providers = {
        SOURCE_GNOMAD_COMMON: variant_provider(gnomad_edits),
        SOURCE_SYNTHETIC_SNV: synthetic_snv_provider,
        SOURCE_SYNTHETIC_INDEL: synthetic_indel_provider,
        SOURCE_CLINVAR: variant_provider(clinvar_edits),
    }
    iterator = _repeat_training_items(
        windows,
        providers,
        seed=seeds.data,
        fallback_sources=_dataset_fallback_sources(windows),
    )
    resumed_from_step = 0
    resume_checkpoint: _ResumeCheckpoint | None = None
    if resume_from is not None:
        resume_checkpoint = _validate_resume_checkpoint_payload(
            _load_torch_checkpoint(resume_from),
            path=resume_from,
            config=config,
            dataset_snapshot_id=dataset_snapshot_id,
            seeds=seeds,
            target_steps=steps,
        )
        resumed_from_step = resume_checkpoint.steps_completed
        _skip_training_items(
            iterator,
            item_count=resumed_from_step * config.data.batch_size,
        )

    encoder = CarbonStateEncoder(
        str(carbon_model_dir),
        config.encoder.revision,
        dtype=config.encoder.dtype,
        state_layer=config.encoder.state_layer,
        pool_type=config.encoder.pool_type,
        pool_radius=config.encoder.pool_radius,
        normalize=config.encoder.normalize,
        encoder_hash=_carbon_weights_hash(carbon_model_dir),
        device=device,
        local_files_only=True,
        trust_remote_code=config.encoder.trust_remote_code,
    )
    first_items = _next_batch(iterator, config.data.batch_size)
    first_batch = _encode_items(encoder, first_items, device=device)
    observed_d_state = int(first_batch.state.shape[1])
    if observed_d_state != config.predictor.d_state:
        raise InputError(
            "predictor.d_state must match the Carbon encoder state width",
            details={"predictor.d_state": config.predictor.d_state, "observed": observed_d_state},
            remediation="set predictor.d_state to the encoder output width in the training config",
        )

    action_encoder = _move_trainable_to_device(
        ActionEncoder(d_action=config.action.d_action),
        device,
        label="action_encoder",
    )
    predictor = _move_trainable_to_device(build_predictor(config), device, label="predictor")
    optimizer = build_adamw_optimizer(
        predictor=predictor, action_encoder=action_encoder, config=config
    )
    if resume_checkpoint is not None:
        _restore_resume_checkpoint(
            resume_checkpoint.payload,
            predictor=predictor,
            action_encoder=action_encoder,
            optimizer=optimizer,
        )
    trainer = TorchTrainer(
        predictor=predictor,
        action_encoder=action_encoder,
        optimizer=optimizer,
        config=config,
        total_steps=steps,
    )
    progress_every = max(1, int(config.training.collapse_log_every_steps))

    step_results = []
    collapse_alert_count = 0
    sample_count = resumed_from_step * config.data.batch_size
    log_mode = "a" if resumed_from_step else "w"
    progress_logger = get_logger("training", run_id=config.run_id, log_dir=run_dir)
    with log_path.open(log_mode, encoding="utf-8") as log:
        if resumed_from_step:
            log.write(
                json.dumps(
                    {
                        "event": "train.resume",
                        "run_id": config.run_id,
                        "resume_from": None
                        if resume_from is None
                        else _public_resume_path(resume_from),
                        "resumed_from_step": resumed_from_step,
                        "target_steps": steps,
                    },
                    sort_keys=True,
                )
                + "\n"
            )
        else:
            log.write(json.dumps({"event": "train.start", "run_id": config.run_id}) + "\n")
        current_batch = first_batch
        first_step = resumed_from_step + 1
        for step in range(first_step, steps + 1):
            if step > first_step:
                current_batch = _encode_items(
                    encoder,
                    _next_batch(iterator, config.data.batch_size),
                    device=device,
                )
            result = trainer.train_step(current_batch, step=step)
            step_results.append(result)
            collapse_alerts = _last_collapse_alerts(trainer)
            collapse_alert_count += len(collapse_alerts)
            sample_count += len(current_batch.window_ids)
            log.write(json.dumps({"event": "train.step", **result.to_dict()}) + "\n")
            if step in (first_step, steps) or step % progress_every == 0:
                progress_logger.info(
                    "training.metric",
                    step=step,
                    name="sample_count",
                    value=sample_count,
                    unit="samples",
                    kind="counter",
                )
                progress_logger.info(
                    "training.metric",
                    step=step,
                    name="loss",
                    value=result.loss,
                    unit="unitless",
                    kind="gauge",
                )
                progress_logger.info(
                    "training.metric",
                    step=step,
                    name="pred_var_per_dim",
                    value=result.pred_var_per_dim,
                    unit="unitless",
                    kind="gauge",
                )
            for alert in collapse_alerts:
                log.write(
                    json.dumps(
                        {"event": "training.collapse.alert", "step": step, **alert},
                        sort_keys=True,
                    )
                    + "\n"
                )
        log.write(json.dumps({"event": "train.end", "steps_completed": steps}) + "\n")

    final = step_results[-1]
    _write_metrics(
        metrics_path,
        config=config,
        steps=steps,
        resumed_from_step=resumed_from_step,
        sample_count=sample_count,
        final_loss=final.loss,
        step_results=step_results,
        collapse_alert_count=collapse_alert_count,
        dataset_snapshot_id=dataset_snapshot_id,
        resume_checkpoint_path=resume_from,
    )
    _write_checkpoint(
        checkpoint_path,
        predictor=predictor,
        action_encoder=action_encoder,
        optimizer=optimizer,
        config=config,
        dataset_snapshot_id=dataset_snapshot_id,
        steps=steps,
        seeds=seeds,
    )
    _write_training_metadata(
        metadata_path,
        config=config,
        command=command,
        commit_sha=commit_sha,
        package_version=package_version,
        dataset_snapshot_id=dataset_snapshot_id,
        seeds=seeds,
        determinism=determinism.to_dict(),
        artifacts={
            "training_config": config_path.name,
            "metrics": metrics_path.name,
            "logs": [log_path.name],
            "checkpoint_files": [checkpoint_path.name],
            "dataset_manifest": run_dataset_manifest_path.name,
        },
        preflight_report=preflight_report,
        final_loss=final.loss,
        sample_count=sample_count,
        resumed_from_step=resumed_from_step,
        resume_checkpoint_path=resume_from,
    )
    return CarbonTrainingReport(
        run_id=config.run_id,
        run_dir=run_dir,
        dataset_snapshot_id=dataset_snapshot_id,
        steps_requested=steps,
        steps_completed=steps,
        resumed_from_step=resumed_from_step,
        sample_count=sample_count,
        final_loss=final.loss,
        checkpoint_path=checkpoint_path,
        resume_checkpoint_path=resume_from,
        metrics_path=metrics_path,
        log_path=log_path,
        config_path=config_path,
        preflight_path=preflight_path,
        training_metadata_path=metadata_path,
    )

draw_edit_type_counts

draw_edit_type_counts(n: int, *, rng: Random, weights: Sequence[EditTypeWeight] = DEFAULT_EDIT_TYPE_WEIGHTS) -> dict[EditType, int]

Draw n edit types and return counts by EditType.

Source code in geno_lewm/training/sampling.py
def draw_edit_type_counts(
    n: int,
    *,
    rng: random.Random,
    weights: Sequence[EditTypeWeight] = DEFAULT_EDIT_TYPE_WEIGHTS,
) -> dict[EditType, int]:
    """Draw ``n`` edit types and return counts by :class:`EditType`."""
    _require_nonnegative_int("n", n)
    entries = _validate_edit_type_weights(weights)
    counts = {entry.edit_type: 0 for entry in entries}
    for _ in range(n):
        counts[_sample_weighted(rng, entries).edit_type] += 1
    return counts
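
A minimal, standard-library sketch of the counting pattern above. It assumes a plain label-to-weight mapping in place of `EditTypeWeight` entries and uses `random.Random.choices` in place of the private `_sample_weighted` helper, whose implementation is not shown here, so the exact random stream may differ. The labels and weights are hypothetical, not the values in `DEFAULT_EDIT_TYPE_WEIGHTS`. It illustrates two properties the listing does guarantee: every configured type appears in the result even with a zero count, and a seeded `rng` makes the draw reproducible.

```python
import random

# Hypothetical labels and weights, for illustration only.
EXAMPLE_WEIGHTS = {"snv": 0.5, "insertion": 0.25, "deletion": 0.25}


def draw_counts(n: int, *, rng: random.Random) -> dict[str, int]:
    """Draw ``n`` labels and return counts, keeping zero-count labels."""
    labels = list(EXAMPLE_WEIGHTS)
    weights = list(EXAMPLE_WEIGHTS.values())
    counts = {label: 0 for label in labels}
    for _ in range(n):
        counts[rng.choices(labels, weights=weights, k=1)[0]] += 1
    return counts


# Two generators built from the same seed produce identical counts.
first = draw_counts(1000, rng=random.Random(7))
second = draw_counts(1000, rng=random.Random(7))
# n = 0 still returns every label, each with a count of zero.
empty = draw_counts(0, rng=random.Random(7))
```

Passing the generator explicitly, as the real signature does, keeps the draw independent of the global `random` state.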

draw_rollout_step_counts

draw_rollout_step_counts(n: int, *, rng: Random, mix: Sequence[RolloutStepWeight] = DEFAULT_ROLLOUT_STEP_MIX) -> dict[int, int]

Draw n rollout lengths and return counts by step count.

Source code in geno_lewm/training/sampling.py
def draw_rollout_step_counts(
    n: int,
    *,
    rng: random.Random,
    mix: Sequence[RolloutStepWeight] = DEFAULT_ROLLOUT_STEP_MIX,
) -> dict[int, int]:
    """Draw ``n`` rollout lengths and return counts by step count."""
    _require_nonnegative_int("n", n)
    entries = _validate_rollout_mix(mix)
    counts = {entry.steps: 0 for entry in entries}
    for _ in range(n):
        counts[_sample_weighted(rng, entries).steps] += 1
    return counts

sample_edit_type

sample_edit_type(rng: Random, *, weights: Sequence[EditTypeWeight] = DEFAULT_EDIT_TYPE_WEIGHTS) -> EditType

Sample one edit type from the RFC-0005 edit-balanced distribution.

Source code in geno_lewm/training/sampling.py
def sample_edit_type(
    rng: random.Random,
    *,
    weights: Sequence[EditTypeWeight] = DEFAULT_EDIT_TYPE_WEIGHTS,
) -> EditType:
    """Sample one edit type from the RFC-0005 edit-balanced distribution."""
    return _sample_weighted(rng, _validate_edit_type_weights(weights)).edit_type

sample_rollout_steps

sample_rollout_steps(rng: Random, *, mix: Sequence[RolloutStepWeight] = DEFAULT_ROLLOUT_STEP_MIX) -> int

Sample a rollout length K from the Phase-1 RFC-0005 mix.

Source code in geno_lewm/training/sampling.py
def sample_rollout_steps(
    rng: random.Random,
    *,
    mix: Sequence[RolloutStepWeight] = DEFAULT_ROLLOUT_STEP_MIX,
) -> int:
    """Sample a rollout length ``K`` from the Phase-1 RFC-0005 mix."""
    return _sample_weighted(rng, _validate_rollout_mix(mix)).steps

build_adamw_optimizer

build_adamw_optimizer(*, predictor: object, action_encoder: object, config: GenoLeWMConfig) -> object

Build AdamW groups for predictor/action-encoder trainable parameters.

Source code in geno_lewm/training/trainer.py
def build_adamw_optimizer(
    *,
    predictor: object,
    action_encoder: object,
    config: GenoLeWMConfig,
) -> object:
    """Build AdamW groups for predictor/action-encoder trainable parameters."""
    _require_torch("build_adamw_optimizer")
    if config.optimizer.name != "adamw":
        raise InputError(
            "real trainer currently supports AdamW only",
            details={"optimizer": config.optimizer.name},
        )
    groups = _adamw_param_groups(
        (("predictor", predictor), ("action_encoder", action_encoder)),
        lr=float(config.optimizer.lr),
        weight_decay=float(config.optimizer.weight_decay),
    )
    if not groups:
        raise InputError("no trainable predictor/action-encoder parameters found")
    return torch.optim.AdamW(
        groups,
        betas=(float(config.optimizer.beta1), float(config.optimizer.beta2)),
        eps=1.0e-8,
    )

configure_torch_reproducibility

configure_torch_reproducibility(*, seed: int, deterministic: bool) -> TorchDeterminismReport

Seed Python/NumPy/PyTorch and optionally enable deterministic torch kernels.

Source code in geno_lewm/training/trainer.py
def configure_torch_reproducibility(
    *,
    seed: int,
    deterministic: bool,
) -> TorchDeterminismReport:
    """Seed Python/NumPy/PyTorch and optionally enable deterministic torch kernels."""
    _require_torch("configure_torch_reproducibility")
    _require_nonnegative_int("seed", seed)
    random.seed(seed)
    _seed_numpy(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():  # pragma: no cover - depends on host accelerator.
        torch.cuda.manual_seed_all(seed)
    cublas_config: str | None = None
    if deterministic:
        cublas_config = os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.use_deterministic_algorithms(True)
        cudnn = getattr(torch.backends, "cudnn", None)
        if cudnn is not None:
            cudnn.benchmark = False
    else:
        use_deterministic = getattr(torch, "use_deterministic_algorithms", None)
        if callable(use_deterministic):
            use_deterministic(False)
    return TorchDeterminismReport(
        seed=seed,
        deterministic=deterministic,
        cublas_workspace_config=cublas_config,
        torch_deterministic_algorithms=bool(torch.are_deterministic_algorithms_enabled()),
    )
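
The `setdefault` call above means a `CUBLAS_WORKSPACE_CONFIG` value already present in the environment is kept and reported rather than overwritten. A small sketch of that behavior, using a plain dict as a stand-in for `os.environ` so the example does not touch the real process environment; `resolve_cublas_config` is a hypothetical helper name:

```python
def resolve_cublas_config(env: dict[str, str]) -> str:
    """Return the effective cuBLAS workspace setting, filling in a default."""
    # setdefault writes only when the key is absent and returns the stored value.
    return env.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")


fresh_env: dict[str, str] = {}
preset_env = {"CUBLAS_WORKSPACE_CONFIG": ":16:8"}

fresh_value = resolve_cublas_config(fresh_env)    # default is written and returned
preset_value = resolve_cublas_config(preset_env)  # existing value is preserved
```

One consequence for callers: the `cublas_workspace_config` field of the returned report reflects whatever the launcher exported, which may differ from `:4096:8`.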

encode_training_batch

encode_training_batch(*, encoder: object, tuples: Sequence[TrainingTuple], source_windows: Mapping[str, str], device: str | object | None = None, dtype: object | None = None) -> TorchTrainerBatch

Encode source/target windows for a real predictor-training minibatch.

Source code in geno_lewm/training/trainer.py
def encode_training_batch(
    *,
    encoder: object,
    tuples: Sequence[TrainingTuple],
    source_windows: Mapping[str, str],
    device: str | object | None = None,
    dtype: object | None = None,
) -> TorchTrainerBatch:
    """Encode source/target windows for a real predictor-training minibatch."""
    _require_torch("encode_training_batch")
    if not tuples:
        raise InputError("training batch must contain at least one tuple")
    source_sequences: list[str] = []
    target_sequences: list[str] = []
    rel_edits: list[tuple[RelEdit, ...]] = []
    window_ids: list[str] = []
    for item in tuples:
        if not isinstance(item, TrainingTuple):
            raise InputError(
                "tuples must contain TrainingTuple values",
                details={"type": type(item).__name__},
            )
        try:
            source_sequences.append(source_windows[item.window_id])
        except KeyError as exc:
            raise InputError(
                "source window sequence missing for training tuple",
                details={"window_id": item.window_id},
            ) from exc
        target_sequences.append(item.target_window)
        rel_edits.append(tuple(item.rel_edits))
        window_ids.append(item.window_id)
    target_loci = [edits[0].rel_pos if edits else None for edits in rel_edits]
    encoder_runtime = _encoder_runtime(encoder)
    source_states = _source_states(encoder_runtime, source_sequences)
    target_states = encoder_runtime.encode_batch(target_sequences, target_loci)
    state = torch.tensor(source_states, dtype=dtype or torch.float32, device=device)
    target_single = torch.tensor(target_states, dtype=state.dtype, device=state.device)
    mask = make_action_mask(rel_edits, device=state.device)
    target = target_single.unsqueeze(1).expand(-1, mask.shape[1], -1).clone()
    target = target.masked_fill(~mask.unsqueeze(-1), 0.0)
    return TorchTrainerBatch(
        state=state,
        target=target,
        rel_edits=tuple(rel_edits),
        action_mask=mask,
        window_ids=tuple(window_ids),
    )
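
The last tensor steps above broadcast one encoded target per batch item across every action slot, then zero the padded slots. A pure-Python sketch of that shape logic, with nested lists standing in for tensors; `expand_and_mask` is a hypothetical name and the numbers are made up:

```python
def expand_and_mask(
    target_single: list[list[float]],
    mask: list[list[bool]],
) -> list[list[list[float]]]:
    """Repeat each row's target vector per action slot, zeroing padded slots."""
    expanded = []
    for vector, row_mask in zip(target_single, mask):
        expanded.append(
            [list(vector) if keep else [0.0] * len(vector) for keep in row_mask]
        )
    return expanded


# Two batch items with embedding width 2; the first has one edit, the second two.
target_single = [[1.0, 2.0], [3.0, 4.0]]
mask = [[True, False], [True, True]]
target = expand_and_mask(target_single, mask)
```

The result has shape `(batch, max_len, dim)`, matching the `target` field of `TorchTrainerBatch` in the listing above.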

make_action_mask

make_action_mask(rel_edits: Sequence[Sequence[object]], *, device: object | None = None) -> Tensor

Return a boolean action mask for a ragged batch of relative edits.

Source code in geno_lewm/training/trainer.py
def make_action_mask(
    rel_edits: Sequence[Sequence[object]],
    *,
    device: object | None = None,
) -> Tensor:
    """Return a boolean action mask for a ragged batch of relative edits."""
    _require_torch("make_action_mask")
    if not rel_edits:
        raise InputError("rel_edits must contain at least one batch item")
    lengths = []
    for edits in rel_edits:
        if isinstance(edits, str | bytes) or not isinstance(edits, Sequence):
            raise InputError(
                "rel_edits entries must be sequences",
                details={"type": type(edits).__name__},
            )
        if not edits:
            raise InputError("each training batch item must include at least one edit")
        lengths.append(len(edits))
    max_len = max(lengths)
    mask = torch.zeros((len(lengths), max_len), dtype=torch.bool, device=device)
    for row, length in enumerate(lengths):
        mask[row, :length] = True
    return mask
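
A list-based sketch of the same mask layout, assuming only the per-item edit counts matter. `ragged_mask` is a hypothetical name, and the real function returns a `torch.bool` tensor of shape `(batch, max_len)` rather than nested lists:

```python
def ragged_mask(lengths: list[int]) -> list[list[bool]]:
    """Mark the first ``length`` slots of each row as real edits."""
    max_len = max(lengths)
    return [[col < length for col in range(max_len)] for length in lengths]


# Three batch items carrying 1, 3, and 2 edits respectively.
mask = ragged_mask([1, 3, 2])
for row in mask:
    print(row)
```

Rows shorter than the longest item are padded with `False` on the right, which is what lets downstream code ignore the padded action slots.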

set_optimizer_lr

set_optimizer_lr(optimizer: object, *, step: int, total_steps: int, warmup_steps: int, schedule: ScheduleName = 'wsd') -> float

Set each optimizer group's LR from its initial_lr and return the schedule multiplier.

Source code in geno_lewm/training/trainer.py
def set_optimizer_lr(
    optimizer: object,
    *,
    step: int,
    total_steps: int,
    warmup_steps: int,
    schedule: ScheduleName = "wsd",
) -> float:
    """Set optimizer group LRs from each group's ``initial_lr`` and return multiplier."""
    multiplier = wsd_lr_multiplier(
        step,
        total_steps=total_steps,
        warmup_steps=warmup_steps,
        schedule=schedule,
    )
    groups = getattr(optimizer, "param_groups", None)
    if not isinstance(groups, list) or not groups:
        raise InputError("optimizer must expose non-empty param_groups")
    for group in groups:
        if not isinstance(group, dict):
            raise InputError("optimizer param_groups must contain dictionaries")
        base_lr = group.setdefault("initial_lr", group.get("lr"))
        if isinstance(base_lr, bool) or not isinstance(base_lr, int | float):
            raise InputError("optimizer param group lr must be numeric")
        group["lr"] = float(base_lr) * multiplier
    return multiplier
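
Because the base rate is cached in `initial_lr` on first use, repeated calls rescale from the original rate instead of compounding. A sketch with a stand-in optimizer object (a `SimpleNamespace` holding `param_groups`), assuming the multiplier is passed in directly rather than computed by `wsd_lr_multiplier`; `apply_multiplier` is a hypothetical name:

```python
from types import SimpleNamespace


def apply_multiplier(optimizer: SimpleNamespace, multiplier: float) -> None:
    """Scale each group's LR from its cached base rate."""
    for group in optimizer.param_groups:
        # The first call records the configured LR; later calls reuse it.
        base_lr = group.setdefault("initial_lr", group["lr"])
        group["lr"] = float(base_lr) * multiplier


optimizer = SimpleNamespace(param_groups=[{"lr": 1.0e-3}])
apply_multiplier(optimizer, 0.5)
apply_multiplier(optimizer, 0.5)  # still half the base rate, not a quarter
group = optimizer.param_groups[0]
```

After both calls `group["lr"]` is half of the original `1.0e-3` and `group["initial_lr"]` still holds the original value.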

wsd_lr_multiplier

wsd_lr_multiplier(step: int, *, total_steps: int, warmup_steps: int, schedule: ScheduleName = 'wsd') -> float

Return the RFC-0005 WSD learning-rate multiplier for a 1-indexed step.

Source code in geno_lewm/training/trainer.py
def wsd_lr_multiplier(
    step: int,
    *,
    total_steps: int,
    warmup_steps: int,
    schedule: ScheduleName = "wsd",
) -> float:
    """Return the RFC-0005 WSD learning-rate multiplier for a 1-indexed step."""
    _require_positive_int("step", step)
    _require_positive_int("total_steps", total_steps)
    _require_nonnegative_int("warmup_steps", warmup_steps)
    if step > total_steps:
        raise InputError(
            "step cannot exceed total_steps",
            details={"step": step, "total_steps": total_steps},
        )
    if schedule == "constant":
        return 1.0
    if schedule == "cosine":
        if total_steps == 1:
            return 1.0
        progress = (step - 1) / (total_steps - 1)
        return 0.5 * (1.0 + float(torchless_cos(progress)))
    if schedule != "wsd":
        raise InputError("unsupported learning-rate schedule", details={"schedule": schedule})
    if warmup_steps > 0 and step <= warmup_steps:
        return step / warmup_steps
    if total_steps <= warmup_steps:
        return 1.0
    post_warmup = total_steps - warmup_steps
    decay_start = warmup_steps + max(1, int(post_warmup * 0.80))
    final_taper_start = warmup_steps + max(1, int(post_warmup * 0.98))
    if step <= decay_start:
        return 1.0
    if step <= final_taper_start:
        span = max(1, final_taper_start - decay_start)
        progress = (step - decay_start) / span
        return 1.0 - 0.9 * progress
    span = max(1, total_steps - final_taper_start)
    progress = (step - final_taper_start) / span
    return max(0.01, 0.1 - 0.09 * progress)
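
A worked example of the `wsd` branch, restated without input validation so the phase boundaries can be checked by hand. `wsd_sketch` is a hypothetical name, and the run shape (`total_steps=1100`, `warmup_steps=100`) is chosen only to give round numbers: the plateau ends at step 900, the main decay reaches roughly 0.1 at step 1080, and the final taper reaches roughly 0.01 at step 1100.

```python
def wsd_sketch(step: int, *, total_steps: int, warmup_steps: int) -> float:
    """Mirror the warmup / stable / decay / taper phases of the listing above."""
    if warmup_steps > 0 and step <= warmup_steps:
        return step / warmup_steps  # linear warmup to 1.0
    post_warmup = total_steps - warmup_steps
    decay_start = warmup_steps + max(1, int(post_warmup * 0.80))
    taper_start = warmup_steps + max(1, int(post_warmup * 0.98))
    if step <= decay_start:
        return 1.0  # stable plateau
    if step <= taper_start:
        progress = (step - decay_start) / max(1, taper_start - decay_start)
        return 1.0 - 0.9 * progress  # linear decay from 1.0 to 0.1
    progress = (step - taper_start) / max(1, total_steps - taper_start)
    return max(0.01, 0.1 - 0.09 * progress)  # final taper from 0.1 to 0.01


# Print the multiplier at the edges of each phase and at the decay midpoint.
for step in (50, 100, 900, 990, 1080, 1100):
    value = wsd_sketch(step, total_steps=1100, warmup_steps=100)
    print(f"step {step:>4}: multiplier {value:.3f}")
```

The 80% and 98% breakpoints are fractions of the post-warmup steps, not of `total_steps`, so lengthening warmup shifts both boundaries later.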