def run_carbon_training(
    *,
    config: GenoLeWMConfig,
    dataset_dir: Path,
    carbon_model_dir: Path,
    run_dir: Path,
    steps: int,
    command: str,
    commit_sha: str,
    package_version: str,
    preflight_report: TrainingPreflightReport | None = None,
    resume_from: Path | None = None,
) -> CarbonTrainingReport:
    """Run a single-process Carbon-backed training job.

    The function loads the dataset described by ``dataset_manifest.json`` in
    ``dataset_dir``, encodes batches with a ``CarbonStateEncoder`` built from
    ``carbon_model_dir`` (``local_files_only=True``), trains the predictor and
    action encoder up to step ``steps``, and writes the resolved config, a
    JSON-lines log, metrics, a checkpoint, a copy of the dataset manifest and
    training metadata into ``run_dir``.

    Args:
        config: Training configuration; a resolved copy is written to
            ``run_dir``.
        dataset_dir: Directory holding ``dataset_manifest.json`` and the files
            it references.
        carbon_model_dir: Directory passed to ``CarbonStateEncoder`` and to
            ``_carbon_weights_hash``.
        run_dir: Output directory; created (with parents) when missing.
        steps: Target step number to train up to; must be a positive int.
            When resuming, training continues at the checkpoint's
            ``steps_completed + 1`` and stops after step ``steps``.
        command: Forwarded to the training metadata writer.
        commit_sha: Forwarded to the training metadata writer.
        package_version: Forwarded to the training metadata writer.
        preflight_report: Optional report forwarded to the training metadata
            writer.
        resume_from: Optional checkpoint to resume from. When given, the
            checkpoint is validated against ``config``, the dataset snapshot
            id, the seeds and ``steps`` before its state is restored.

    Returns:
        A ``CarbonTrainingReport`` with the run identifiers, step and sample
        counts, final loss and the paths of the written artifacts.

    Raises:
        InputError: If the dataset yields no source windows or no gnomAD
            edits, if a resume checkpoint leaves no steps to run, or if
            ``config.predictor.d_state`` differs from the encoder's observed
            state width. Helpers called here may raise further errors.
    """
    _require_positive_int("steps", steps)
    _require_positive_int("data.batch_size", config.data.batch_size)
    batch_size = config.data.batch_size

    # Resolve every artifact location up front.
    run_dir.mkdir(parents=True, exist_ok=True)
    config_path = write_resolved_config(config, run_dir / _RESOLVED_CONFIG_NAME)
    log_path = run_dir / CARBON_LOG_NAME
    metrics_path = run_dir / CARBON_METRICS_NAME
    checkpoint_path = run_dir / CARBON_CHECKPOINT_NAME
    metadata_path = run_dir / CARBON_TRAINING_METADATA_NAME
    run_dataset_manifest_path = run_dir / "dataset_manifest.json"
    # Only report a preflight path when such a file already exists in run_dir.
    preflight_path = run_dir / REPORT_NAME if (run_dir / REPORT_NAME).is_file() else None

    # Load the dataset and keep a copy of its manifest next to the run outputs.
    dataset_manifest = _load_dataset_manifest(dataset_dir)
    shutil.copy2(dataset_dir / "dataset_manifest.json", run_dataset_manifest_path)
    dataset_snapshot_id = _required_text(dataset_manifest, "snapshot_id")
    dataset_files = _dataset_files(dataset_manifest)
    windows = tuple(_load_windows(dataset_dir, dataset_files))
    if not windows:
        raise InputError("Carbon training requires at least one source window")
    gnomad_edits = tuple(_load_gnomad_edits(dataset_dir, dataset_files))
    clinvar_edits = tuple(_load_clinvar_edits(dataset_dir, dataset_files))
    if not gnomad_edits:
        raise InputError("Carbon training requires at least one gnomAD edit")

    device = _training_device(config)
    seeds = TrainerSeeds.from_base_seed(config.seed)
    determinism = configure_torch_reproducibility(
        seed=seeds.predictor, deterministic=config.deterministic
    )
    providers = {
        SOURCE_GNOMAD_COMMON: variant_provider(gnomad_edits),
        SOURCE_SYNTHETIC_SNV: synthetic_snv_provider,
        SOURCE_SYNTHETIC_INDEL: synthetic_indel_provider,
        SOURCE_CLINVAR: variant_provider(clinvar_edits),
    }
    iterator = _repeat_training_items(
        windows,
        providers,
        seed=seeds.data,
        fallback_sources=_dataset_fallback_sources(windows),
    )

    resumed_from_step = 0
    resume_checkpoint: _ResumeCheckpoint | None = None
    if resume_from is not None:
        resume_checkpoint = _validate_resume_checkpoint_payload(
            _load_torch_checkpoint(resume_from),
            path=resume_from,
            config=config,
            dataset_snapshot_id=dataset_snapshot_id,
            seeds=seeds,
            target_steps=steps,
        )
        resumed_from_step = resume_checkpoint.steps_completed
        # Without this guard the training loop below would run zero
        # iterations and the final-loss lookup would fail with IndexError,
        # after the encoder had already been loaded and the log appended to.
        if resumed_from_step >= steps:
            raise InputError(
                "resume checkpoint already covers the requested training steps",
                details={"resumed_from_step": resumed_from_step, "target_steps": steps},
                remediation="request more steps than the resume checkpoint has completed",
            )
        # Advance the item stream past everything the checkpoint consumed.
        _skip_training_items(
            iterator,
            item_count=resumed_from_step * batch_size,
        )

    encoder = CarbonStateEncoder(
        str(carbon_model_dir),
        config.encoder.revision,
        dtype=config.encoder.dtype,
        state_layer=config.encoder.state_layer,
        pool_type=config.encoder.pool_type,
        pool_radius=config.encoder.pool_radius,
        normalize=config.encoder.normalize,
        encoder_hash=_carbon_weights_hash(carbon_model_dir),
        device=device,
        local_files_only=True,
        trust_remote_code=config.encoder.trust_remote_code,
    )

    # Encode the first batch before building the predictor so the configured
    # state width can be checked against what the encoder actually produces.
    first_items = _next_batch(iterator, batch_size)
    first_batch = _encode_items(encoder, first_items, device=device)
    observed_d_state = int(first_batch.state.shape[1])
    if observed_d_state != config.predictor.d_state:
        raise InputError(
            "predictor.d_state must match the Carbon encoder state width",
            details={"predictor.d_state": config.predictor.d_state, "observed": observed_d_state},
            remediation="set predictor.d_state to the encoder output width in the training config",
        )

    action_encoder = _move_trainable_to_device(
        ActionEncoder(d_action=config.action.d_action),
        device,
        label="action_encoder",
    )
    predictor = _move_trainable_to_device(build_predictor(config), device, label="predictor")
    optimizer = build_adamw_optimizer(
        predictor=predictor, action_encoder=action_encoder, config=config
    )
    if resume_checkpoint is not None:
        _restore_resume_checkpoint(
            resume_checkpoint.payload,
            predictor=predictor,
            action_encoder=action_encoder,
            optimizer=optimizer,
        )
    trainer = TorchTrainer(
        predictor=predictor,
        action_encoder=action_encoder,
        optimizer=optimizer,
        config=config,
        total_steps=steps,
    )

    progress_every = max(1, int(config.training.collapse_log_every_steps))
    step_results = []
    collapse_alert_count = 0
    sample_count = resumed_from_step * batch_size
    # A resumed run appends to the existing log; a fresh run truncates it.
    log_mode = "a" if resumed_from_step else "w"
    progress_logger = get_logger("training", run_id=config.run_id, log_dir=run_dir)
    with log_path.open(log_mode, encoding="utf-8") as log:
        if resumed_from_step:
            log.write(
                json.dumps(
                    {
                        "event": "train.resume",
                        "run_id": config.run_id,
                        "resume_from": None
                        if resume_from is None
                        else _public_resume_path(resume_from),
                        "resumed_from_step": resumed_from_step,
                        "target_steps": steps,
                    },
                    sort_keys=True,
                )
                + "\n"
            )
        else:
            log.write(json.dumps({"event": "train.start", "run_id": config.run_id}) + "\n")

        current_batch = first_batch
        first_step = resumed_from_step + 1
        for step in range(first_step, steps + 1):
            # The first iteration reuses the batch encoded for the width check.
            if step > first_step:
                current_batch = _encode_items(
                    encoder,
                    _next_batch(iterator, batch_size),
                    device=device,
                )
            result = trainer.train_step(current_batch, step=step)
            step_results.append(result)
            collapse_alerts = _last_collapse_alerts(trainer)
            collapse_alert_count += len(collapse_alerts)
            sample_count += len(current_batch.window_ids)
            log.write(json.dumps({"event": "train.step", **result.to_dict()}) + "\n")

            # Progress metrics go out on the first and last step of this run
            # and on every multiple of ``progress_every``.
            if step in (first_step, steps) or step % progress_every == 0:
                progress_metrics = (
                    ("sample_count", sample_count, "samples", "counter"),
                    ("loss", result.loss, "unitless", "gauge"),
                    ("pred_var_per_dim", result.pred_var_per_dim, "unitless", "gauge"),
                )
                for metric_name, metric_value, metric_unit, metric_kind in progress_metrics:
                    progress_logger.info(
                        "training.metric",
                        step=step,
                        name=metric_name,
                        value=metric_value,
                        unit=metric_unit,
                        kind=metric_kind,
                    )

            for alert in collapse_alerts:
                log.write(
                    json.dumps(
                        {"event": "training.collapse.alert", "step": step, **alert},
                        sort_keys=True,
                    )
                    + "\n"
                )
        log.write(json.dumps({"event": "train.end", "steps_completed": steps}) + "\n")

    # At least one step ran (guarded above), so the last result exists.
    final = step_results[-1]
    _write_metrics(
        metrics_path,
        config=config,
        steps=steps,
        resumed_from_step=resumed_from_step,
        sample_count=sample_count,
        final_loss=final.loss,
        step_results=step_results,
        collapse_alert_count=collapse_alert_count,
        dataset_snapshot_id=dataset_snapshot_id,
        resume_checkpoint_path=resume_from,
    )
    _write_checkpoint(
        checkpoint_path,
        predictor=predictor,
        action_encoder=action_encoder,
        optimizer=optimizer,
        config=config,
        dataset_snapshot_id=dataset_snapshot_id,
        steps=steps,
        seeds=seeds,
    )
    _write_training_metadata(
        metadata_path,
        config=config,
        command=command,
        commit_sha=commit_sha,
        package_version=package_version,
        dataset_snapshot_id=dataset_snapshot_id,
        seeds=seeds,
        determinism=determinism.to_dict(),
        artifacts={
            "training_config": config_path.name,
            "metrics": metrics_path.name,
            "logs": [log_path.name],
            "checkpoint_files": [checkpoint_path.name],
            "dataset_manifest": run_dataset_manifest_path.name,
        },
        preflight_report=preflight_report,
        final_loss=final.loss,
        sample_count=sample_count,
        resumed_from_step=resumed_from_step,
        resume_checkpoint_path=resume_from,
    )
    return CarbonTrainingReport(
        run_id=config.run_id,
        run_dir=run_dir,
        dataset_snapshot_id=dataset_snapshot_id,
        steps_requested=steps,
        steps_completed=steps,
        resumed_from_step=resumed_from_step,
        sample_count=sample_count,
        final_loss=final.loss,
        checkpoint_path=checkpoint_path,
        resume_checkpoint_path=resume_from,
        metrics_path=metrics_path,
        log_path=log_path,
        config_path=config_path,
        preflight_path=preflight_path,
        training_metadata_path=metadata_path,
    )