Skip to content

geno_lewm.training.real

real

Single-process Carbon-backed training launcher.

This module owns the real training orchestration boundary used by `geno-lewm-train --carbon-train`. It remains optional-runtime code: importing it is lightweight, but executing it requires a `geno-lewm[train]` environment with local Carbon model files and a packaged dataset.

CarbonTrainingReport dataclass

CarbonTrainingReport(run_id: str, run_dir: Path, dataset_snapshot_id: str, steps_requested: int, steps_completed: int, resumed_from_step: int, sample_count: int, final_loss: float, checkpoint_path: Path, resume_checkpoint_path: Path | None, metrics_path: Path, log_path: Path, config_path: Path, preflight_path: Path | None, training_metadata_path: Path)

Summary emitted by the real Carbon-backed trainer.

run_carbon_training

run_carbon_training(*, config: GenoLeWMConfig, dataset_dir: Path, carbon_model_dir: Path, run_dir: Path, steps: int, command: str, commit_sha: str, package_version: str, preflight_report: TrainingPreflightReport | None = None, resume_from: Path | None = None) -> CarbonTrainingReport

Run a single-process Carbon-backed training job.

Source code in geno_lewm/training/real.py
def run_carbon_training(
    *,
    config: GenoLeWMConfig,
    dataset_dir: Path,
    carbon_model_dir: Path,
    run_dir: Path,
    steps: int,
    command: str,
    commit_sha: str,
    package_version: str,
    preflight_report: TrainingPreflightReport | None = None,
    resume_from: Path | None = None,
) -> CarbonTrainingReport:
    """Run a single-process Carbon-backed training job.

    The function loads the packaged dataset, builds the Carbon state encoder,
    the action encoder, the predictor and the optimizer, optionally restores
    them from a resume checkpoint, runs training steps up to ``steps`` and
    writes the run artifacts into ``run_dir``.

    Args:
        config: Resolved training configuration. The fields read here include
            ``run_id``, ``seed``, ``deterministic``, ``data.batch_size``,
            ``encoder.*``, ``predictor.d_state``, ``action.d_action`` and
            ``training.collapse_log_every_steps``.
        dataset_dir: Directory holding ``dataset_manifest.json`` and the
            dataset files it references.
        carbon_model_dir: Directory with the Carbon model files; the encoder
            is created with ``local_files_only=True``.
        run_dir: Output directory for the run; created if missing.
        steps: Target step number. When resuming, training continues from the
            checkpoint's completed step count up to this value.
        command: Forwarded to the training metadata writer.
        commit_sha: Forwarded to the training metadata writer.
        package_version: Forwarded to the training metadata writer.
        preflight_report: Optional preflight report forwarded to the training
            metadata writer.
        resume_from: Optional checkpoint path to resume from.

    Returns:
        A ``CarbonTrainingReport`` with the artifact paths, the sample count
        and the loss of the last executed step. ``preflight_path`` is set only
        when a preflight report file already exists in ``run_dir``.

    Raises:
        InputError: If the dataset yields no source windows or no gnomAD
            edits, or if ``config.predictor.d_state`` differs from the state
            width observed in the first encoded batch. The validation and
            loading helpers called here may raise as well (not visible in
            this block).
    """
    # Validate the basic numeric inputs before touching the filesystem.
    # NOTE(review): presumably these raise on non-positive values - confirm in
    # _require_positive_int.
    _require_positive_int("steps", steps)
    _require_positive_int("data.batch_size", config.data.batch_size)
    run_dir.mkdir(parents=True, exist_ok=True)
    config_path = write_resolved_config(config, run_dir / _RESOLVED_CONFIG_NAME)
    # Artifact locations inside the run directory.
    log_path = run_dir / CARBON_LOG_NAME
    metrics_path = run_dir / CARBON_METRICS_NAME
    checkpoint_path = run_dir / CARBON_CHECKPOINT_NAME
    metadata_path = run_dir / CARBON_TRAINING_METADATA_NAME
    run_dataset_manifest_path = run_dir / "dataset_manifest.json"
    # Only report a preflight path when the report file is already present in
    # the run directory; this function does not write it.
    preflight_path = run_dir / REPORT_NAME if (run_dir / REPORT_NAME).is_file() else None

    dataset_manifest = _load_dataset_manifest(dataset_dir)
    # Keep a copy of the dataset manifest next to the other run artifacts.
    shutil.copy2(dataset_dir / "dataset_manifest.json", run_dataset_manifest_path)
    dataset_snapshot_id = _required_text(dataset_manifest, "snapshot_id")
    dataset_files = _dataset_files(dataset_manifest)
    windows = tuple(_load_windows(dataset_dir, dataset_files))
    if not windows:
        raise InputError("Carbon training requires at least one source window")
    gnomad_edits = tuple(_load_gnomad_edits(dataset_dir, dataset_files))
    clinvar_edits = tuple(_load_clinvar_edits(dataset_dir, dataset_files))
    # gnomAD edits are mandatory; an empty ClinVar edit set is accepted here.
    if not gnomad_edits:
        raise InputError("Carbon training requires at least one gnomAD edit")

    device = _training_device(config)
    # Derive the per-component seeds from the single base seed in the config.
    seeds = TrainerSeeds.from_base_seed(config.seed)
    determinism = configure_torch_reproducibility(
        seed=seeds.predictor, deterministic=config.deterministic
    )
    # One provider per edit source, keyed by the source constant.
    providers = {
        SOURCE_GNOMAD_COMMON: variant_provider(gnomad_edits),
        SOURCE_SYNTHETIC_SNV: synthetic_snv_provider,
        SOURCE_SYNTHETIC_INDEL: synthetic_indel_provider,
        SOURCE_CLINVAR: variant_provider(clinvar_edits),
    }
    iterator = _repeat_training_items(
        windows,
        providers,
        seed=seeds.data,
        fallback_sources=_dataset_fallback_sources(windows),
    )
    resumed_from_step = 0
    resume_checkpoint: _ResumeCheckpoint | None = None
    if resume_from is not None:
        # Validate the checkpoint against the current config, dataset snapshot,
        # seeds and target step count before any state is restored.
        resume_checkpoint = _validate_resume_checkpoint_payload(
            _load_torch_checkpoint(resume_from),
            path=resume_from,
            config=config,
            dataset_snapshot_id=dataset_snapshot_id,
            seeds=seeds,
            target_steps=steps,
        )
        resumed_from_step = resume_checkpoint.steps_completed
        # Advance the data iterator by one batch per already-completed step.
        # NOTE(review): presumably this keeps the resumed run on the same item
        # sequence as an uninterrupted run - confirm against
        # _repeat_training_items.
        _skip_training_items(
            iterator,
            item_count=resumed_from_step * config.data.batch_size,
        )

    encoder = CarbonStateEncoder(
        str(carbon_model_dir),
        config.encoder.revision,
        dtype=config.encoder.dtype,
        state_layer=config.encoder.state_layer,
        pool_type=config.encoder.pool_type,
        pool_radius=config.encoder.pool_radius,
        normalize=config.encoder.normalize,
        encoder_hash=_carbon_weights_hash(carbon_model_dir),
        device=device,
        local_files_only=True,
        trust_remote_code=config.encoder.trust_remote_code,
    )
    # Encode the first batch before building the predictor so the configured
    # state width can be checked against what the encoder actually produces.
    first_items = _next_batch(iterator, config.data.batch_size)
    first_batch = _encode_items(encoder, first_items, device=device)
    observed_d_state = int(first_batch.state.shape[1])
    if observed_d_state != config.predictor.d_state:
        raise InputError(
            "predictor.d_state must match the Carbon encoder state width",
            details={"predictor.d_state": config.predictor.d_state, "observed": observed_d_state},
            remediation="set predictor.d_state to the encoder output width in the training config",
        )

    action_encoder = _move_trainable_to_device(
        ActionEncoder(d_action=config.action.d_action),
        device,
        label="action_encoder",
    )
    predictor = _move_trainable_to_device(build_predictor(config), device, label="predictor")
    optimizer = build_adamw_optimizer(
        predictor=predictor, action_encoder=action_encoder, config=config
    )
    # Restore predictor, action encoder and optimizer from the validated
    # checkpoint payload once all three objects exist.
    if resume_checkpoint is not None:
        _restore_resume_checkpoint(
            resume_checkpoint.payload,
            predictor=predictor,
            action_encoder=action_encoder,
            optimizer=optimizer,
        )
    trainer = TorchTrainer(
        predictor=predictor,
        action_encoder=action_encoder,
        optimizer=optimizer,
        config=config,
        total_steps=steps,
    )
    # Progress metrics are emitted every `progress_every` steps (at least 1).
    progress_every = max(1, int(config.training.collapse_log_every_steps))

    step_results = []
    collapse_alert_count = 0
    # Start the sample counter at one full batch per previously completed step.
    sample_count = resumed_from_step * config.data.batch_size
    # Append to the existing log when continuing from a completed step;
    # otherwise start a fresh log file.
    log_mode = "a" if resumed_from_step else "w"
    progress_logger = get_logger("training", run_id=config.run_id, log_dir=run_dir)
    with log_path.open(log_mode, encoding="utf-8") as log:
        if resumed_from_step:
            log.write(
                json.dumps(
                    {
                        "event": "train.resume",
                        "run_id": config.run_id,
                        "resume_from": None
                        if resume_from is None
                        else _public_resume_path(resume_from),
                        "resumed_from_step": resumed_from_step,
                        "target_steps": steps,
                    },
                    sort_keys=True,
                )
                + "\n"
            )
        else:
            log.write(json.dumps({"event": "train.start", "run_id": config.run_id}) + "\n")
        # The batch encoded for the width check above is reused for the first
        # training step; later steps pull and encode a new batch.
        current_batch = first_batch
        first_step = resumed_from_step + 1
        for step in range(first_step, steps + 1):
            if step > first_step:
                current_batch = _encode_items(
                    encoder,
                    _next_batch(iterator, config.data.batch_size),
                    device=device,
                )
            result = trainer.train_step(current_batch, step=step)
            step_results.append(result)
            collapse_alerts = _last_collapse_alerts(trainer)
            collapse_alert_count += len(collapse_alerts)
            sample_count += len(current_batch.window_ids)
            log.write(json.dumps({"event": "train.step", **result.to_dict()}) + "\n")
            # Emit progress metrics on the first and last executed step and on
            # every multiple of `progress_every`.
            if step in (first_step, steps) or step % progress_every == 0:
                progress_logger.info(
                    "training.metric",
                    step=step,
                    name="sample_count",
                    value=sample_count,
                    unit="samples",
                    kind="counter",
                )
                progress_logger.info(
                    "training.metric",
                    step=step,
                    name="loss",
                    value=result.loss,
                    unit="unitless",
                    kind="gauge",
                )
                progress_logger.info(
                    "training.metric",
                    step=step,
                    name="pred_var_per_dim",
                    value=result.pred_var_per_dim,
                    unit="unitless",
                    kind="gauge",
                )
            # Collapse alerts for this step are written after its step record.
            for alert in collapse_alerts:
                log.write(
                    json.dumps(
                        {"event": "training.collapse.alert", "step": step, **alert},
                        sort_keys=True,
                    )
                    + "\n"
                )
        log.write(json.dumps({"event": "train.end", "steps_completed": steps}) + "\n")

    # NOTE(review): this indexing needs at least one executed step. If a
    # checkpoint with steps_completed >= steps were accepted, the loop above
    # would be empty and this would raise IndexError. Presumably
    # _validate_resume_checkpoint_payload(target_steps=steps) rejects that
    # case - confirm.
    final = step_results[-1]
    _write_metrics(
        metrics_path,
        config=config,
        steps=steps,
        resumed_from_step=resumed_from_step,
        sample_count=sample_count,
        final_loss=final.loss,
        step_results=step_results,
        collapse_alert_count=collapse_alert_count,
        dataset_snapshot_id=dataset_snapshot_id,
        resume_checkpoint_path=resume_from,
    )
    _write_checkpoint(
        checkpoint_path,
        predictor=predictor,
        action_encoder=action_encoder,
        optimizer=optimizer,
        config=config,
        dataset_snapshot_id=dataset_snapshot_id,
        steps=steps,
        seeds=seeds,
    )
    # The metadata lists artifacts by file name only, not by full path.
    _write_training_metadata(
        metadata_path,
        config=config,
        command=command,
        commit_sha=commit_sha,
        package_version=package_version,
        dataset_snapshot_id=dataset_snapshot_id,
        seeds=seeds,
        determinism=determinism.to_dict(),
        artifacts={
            "training_config": config_path.name,
            "metrics": metrics_path.name,
            "logs": [log_path.name],
            "checkpoint_files": [checkpoint_path.name],
            "dataset_manifest": run_dataset_manifest_path.name,
        },
        preflight_report=preflight_report,
        final_loss=final.loss,
        sample_count=sample_count,
        resumed_from_step=resumed_from_step,
        resume_checkpoint_path=resume_from,
    )
    # Reaching this point means the loop ran through `steps`, so the requested
    # and completed step counts are reported as the same value.
    return CarbonTrainingReport(
        run_id=config.run_id,
        run_dir=run_dir,
        dataset_snapshot_id=dataset_snapshot_id,
        steps_requested=steps,
        steps_completed=steps,
        resumed_from_step=resumed_from_step,
        sample_count=sample_count,
        final_loss=final.loss,
        checkpoint_path=checkpoint_path,
        resume_checkpoint_path=resume_from,
        metrics_path=metrics_path,
        log_path=log_path,
        config_path=config_path,
        preflight_path=preflight_path,
        training_metadata_path=metadata_path,
    )