Skip to content

geno_lewm.training.fixture

fixture

Deterministic fixture-tier trainer for the geno-lewm-train smoke path.

This module is intentionally not the Carbon-backed trainer. It gives the release workflow a clean-machine, dependency-light training command that exercises config resolution, checkpointing, deterministic resume, metrics, logs, and training-run metadata before the heavy ML stack is available.

FixtureTrainingReport dataclass

FixtureTrainingReport(run_id: str, run_dir: Path, steps_requested: int, steps_completed: int, resumed_from_step: int, final_loss: float, checkpoint_path: Path, metrics_path: Path, log_path: Path, config_path: Path, dataset_manifest_path: Path, training_metadata_path: Path)

Summary returned by the deterministic fixture trainer.

run_fixture_training

run_fixture_training(*, config: GenoLeWMConfig, run_dir: Path, steps: int = 50, resume_from: Path | None = None, command: str, commit_sha: str, package_version: str) -> FixtureTrainingReport

Run a deterministic scalar smoke trainer and write release artifacts.

`steps` is the target total step count. When `resume_from` is supplied, the checkpoint's current step must be lower than `steps`; the resumed run continues with the same deterministic sample stream.

Source code in geno_lewm/training/fixture.py
def run_fixture_training(
    *,
    config: GenoLeWMConfig,
    run_dir: Path,
    steps: int = 50,
    resume_from: Path | None = None,
    command: str,
    commit_sha: str,
    package_version: str,
) -> FixtureTrainingReport:
    """Drive the deterministic scalar smoke trainer and emit its artifacts.

    Args:
        config: Resolved run configuration; supplies ``seed`` and ``run_id``.
        run_dir: Output directory. Created (with parents) if missing.
        steps: Target *total* step count, not the number of additional
            steps to run on top of a checkpoint.
        resume_from: Optional checkpoint to continue from. It is loaded with
            ``config.seed`` as the expected seed.
        command: Forwarded to the training-metadata writer.
        commit_sha: Forwarded to the training-metadata writer.
        package_version: Forwarded to the training-metadata writer.

    Returns:
        A ``FixtureTrainingReport`` carrying the final step and loss together
        with the paths of every artifact written under ``run_dir``.

    Raises:
        InputError: If the starting state is already at or beyond ``steps``.
    """
    _require_positive_steps(steps)
    run_dir.mkdir(parents=True, exist_ok=True)
    config_path = write_resolved_config(config, run_dir / "config.resolved.yaml")
    dataset_manifest_path = _write_fixture_dataset_manifest(config, run_dir)

    # Fresh runs start from the seeded initial state; resumed runs pick up
    # whatever the checkpoint recorded.
    if resume_from is None:
        trainer_state = _initial_state(config.seed)
    else:
        trainer_state = _load_checkpoint(resume_from, expected_seed=config.seed)

    resumed_from_step = trainer_state.step
    if resumed_from_step >= steps:
        raise InputError(
            "fixture resume checkpoint is already at or beyond --steps",
            details={"checkpoint_step": resumed_from_step, "steps": steps},
        )

    # All remaining artifacts live directly under the run directory.
    log_path = run_dir / FIXTURE_LOG_NAME
    checkpoint_path = run_dir / FIXTURE_CHECKPOINT_NAME
    metrics_path = run_dir / FIXTURE_METRICS_NAME
    training_metadata_path = run_dir / FIXTURE_TRAINING_METADATA_NAME

    # A non-zero starting step means earlier log lines exist and must be
    # kept, so append; otherwise truncate and start a new log.
    resuming = bool(resumed_from_step)
    first_step = resumed_from_step + 1
    with log_path.open("a" if resuming else "w", encoding="utf-8") as log_file:
        if resuming:
            log_file.write(
                f"resume_from_step={resumed_from_step} target_steps={steps} seed={config.seed}\n"
            )
        for step_index in range(first_step, steps + 1):
            trainer_state = _train_step(trainer_state, step=step_index)
            log_file.write(
                f"step={step_index} loss={trainer_state.loss:.12f} weight={trainer_state.weight:.12f} "
                "collapse_var_min=1.0 nan_loss=false\n"
            )

    _write_checkpoint(config, trainer_state, checkpoint_path)
    _write_metrics(
        config,
        trainer_state,
        metrics_path,
        steps=steps,
        resumed_from_step=resumed_from_step,
    )
    _write_training_metadata(
        config,
        training_metadata_path,
        command=command,
        commit_sha=commit_sha,
        package_version=package_version,
    )

    return FixtureTrainingReport(
        run_id=config.run_id,
        run_dir=run_dir,
        steps_requested=steps,
        steps_completed=trainer_state.step,
        resumed_from_step=resumed_from_step,
        final_loss=trainer_state.loss,
        checkpoint_path=checkpoint_path,
        metrics_path=metrics_path,
        log_path=log_path,
        config_path=config_path,
        dataset_manifest_path=dataset_manifest_path,
        training_metadata_path=training_metadata_path,
    )