Skip to content

geno_lewm.deploy.export

export

Export a trained checkpoint to deployable safetensors weights (RFC-0018 §3.3).

Phase 1 converts the training-produced predictor_checkpoint.pt into the predictor.safetensors + action_encoder.safetensors artifacts that the deploy runtime (`geno_lewm.deploy.runtime`) loads with strict=True, plus an export_report.json recording artifact identities. The ONNX / Core ML / GGUF targets and int8/int4 quantization land later (#67–#70); this is the minimal serialize step that unblocks packaging, scoring, eval, and the demo.

export_checkpoint

export_checkpoint(checkpoint_path: Path, output_dir: Path, *, overwrite: bool = False) -> dict[str, Any]

Convert a training checkpoint into deploy-ready safetensors artifacts.

Reads the torch.save checkpoint at checkpoint_path, writes the predictor and action-encoder state_dict tensors as safetensors files under output_dir, and emits export_report.json. Returns the report.

Source code in geno_lewm/deploy/export.py
def export_checkpoint(
    checkpoint_path: Path,
    output_dir: Path,
    *,
    overwrite: bool = False,
) -> dict[str, Any]:
    """Export a training checkpoint as safetensors files plus a JSON report.

    For every ``(component, filename)`` pair in ``_STATE_COMPONENTS`` the
    matching state dict is pulled out of the loaded checkpoint and written to
    ``output_dir / filename``. A report describing the source checkpoint and
    each written artifact (SHA-256, size in bytes, tensor count) is then
    serialized to ``output_dir / EXPORT_REPORT_NAME``.

    Args:
        checkpoint_path: Location of the checkpoint file to export.
        output_dir: Directory that receives the artifacts and the report.
        overwrite: Passed through unchanged to ``_prepare_output_dir``, which
            decides how an already-existing output directory is handled.

    Returns:
        The report dictionary that was written to disk.

    Raises:
        InputError: If ``checkpoint_path`` does not point to an existing file.
    """
    source = Path(checkpoint_path)
    out_root = Path(output_dir)

    # Guard clause: fail before touching the output directory.
    if not source.is_file():
        raise InputError(
            "checkpoint file does not exist",
            details={"path": str(source)},
        )
    _prepare_output_dir(out_root, overwrite=overwrite)

    loaded = _load_checkpoint(source)

    def _export_component(component: str, filename: str) -> dict[str, Any]:
        # Write one state dict to disk and describe the resulting file.
        tensors = _require_state_dict(loaded, component)
        target = out_root / filename
        _save_safetensors(tensors, target)
        # Hash and size are taken from the file as actually written.
        digest = sha256_file(target)
        byte_count = target.stat().st_size
        return {
            "component": component,
            "file": filename,
            "sha256": digest,
            "size_bytes": byte_count,
            "tensors": len(tensors),
        }

    exported = [
        _export_component(component, filename)
        for component, filename in _STATE_COMPONENTS
    ]

    # Provenance of the source checkpoint; the metadata fields are looked up
    # with ``.get`` so a checkpoint lacking one of them records ``None``.
    checkpoint_info: dict[str, Any] = {
        "file": source.name,
        "sha256": sha256_file(source),
        "size_bytes": source.stat().st_size,
        "schema_version": loaded.get("schema_version"),
        "run_id": loaded.get("run_id"),
        "dataset_snapshot_id": loaded.get("dataset_snapshot_id"),
        "steps_completed": loaded.get("steps_completed"),
    }
    report: dict[str, Any] = {
        "schema_version": SCHEMA_VERSION,
        "generated_by": GENERATED_BY,
        "format": "safetensors",
        "checkpoint": checkpoint_info,
        "artifacts": exported,
    }

    # Sorted keys and a trailing newline keep the on-disk report stable.
    serialized = json.dumps(report, indent=2, sort_keys=True) + "\n"
    (out_root / EXPORT_REPORT_NAME).write_text(serialized, encoding="utf-8")
    return report