Skip to content

geno_lewm.training.trainer

trainer

Torch trainer core for Carbon-backed GenoLeWM runs.

The public CLI still gates real training behind explicit preflight and fixture modes. This module is the optional-runtime trainer boundary: it is importable without PyTorch, but constructing batches, optimizers, or train steps requires a geno-lewm[train] environment.

TrainerSeeds dataclass

TrainerSeeds(data: int, predictor: int, lora: int)

Distinct RNG seeds consumed by the real training stack.

TorchDeterminismReport dataclass

TorchDeterminismReport(seed: int, deterministic: bool, cublas_workspace_config: str | None, torch_deterministic_algorithms: bool)

Runtime settings applied before a torch training run.

TorchTrainerBatch dataclass

TorchTrainerBatch(state: Tensor, target: Tensor, rel_edits: tuple[tuple[RelEdit, ...], ...], action_mask: Tensor, window_ids: tuple[str, ...])

One encoded minibatch consumed by `TorchTrainer`.

TorchTrainerStepResult dataclass

TorchTrainerStepResult(step: int, lr_multiplier: float, loss: float, pred_loss: float, kl_reg: float, action_count: int, pred_var_per_dim: float)

Scalar outputs from one optimizer step.

TorchTrainer

TorchTrainer(*, predictor: object, action_encoder: object, optimizer: object, config: GenoLeWMConfig, total_steps: int)

Minimal optimizer loop for Carbon-state predictor training.

Source code in geno_lewm/training/trainer.py
def __init__(
    self,
    *,
    predictor: object,
    action_encoder: object,
    optimizer: object,
    config: GenoLeWMConfig,
    total_steps: int,
) -> None:
    """Bind the model components, optimizer, and schedule length for a run.

    Args:
        predictor: Model handed to ``_call_predictor`` on every train step.
        action_encoder: Model handed to ``_call_action_encoder`` on every
            train step.
        optimizer: Optimizer whose learning rate is rescheduled and which is
            stepped once per ``train_step`` call.
        config: Run configuration; ``config.training.collapse_log_every_steps``
            is read here to configure the collapse monitor.
        total_steps: Length of the run in optimizer steps; validated with
            ``_require_positive_int``.
    """
    # Both guards run before any attribute is assigned.
    _require_torch("TorchTrainer")
    _require_positive_int("total_steps", total_steps)

    self.predictor, self.action_encoder = predictor, action_encoder
    self.optimizer = optimizer
    self.config = config
    self.total_steps = total_steps

    log_interval = config.training.collapse_log_every_steps
    self.collapse_monitor = CollapseMonitor(log_every_steps=log_interval)
    # Alerts from the most recent train step; empty until a step reports some.
    self.last_collapse_alerts: tuple[dict[str, object], ...] = ()

train_step

train_step(batch: TorchTrainerBatch, *, step: int) -> TorchTrainerStepResult

Run one optimizer step over an encoded Carbon-state batch.

Source code in geno_lewm/training/trainer.py
def train_step(self, batch: TorchTrainerBatch, *, step: int) -> TorchTrainerStepResult:
    """Run a single optimization step on one encoded batch.

    The scheduled learning rate is applied first, then the action encoder and
    predictor run forward, the predictor loss is computed, masked rows are
    shown to the collapse monitor, and finally gradients are backpropagated,
    optionally clipped, and applied.

    Args:
        batch: Encoded minibatch providing ``state``, ``target``,
            ``rel_edits`` and ``action_mask``.
        step: Step index, validated with ``_require_positive_int`` and
            bounded above by ``self.total_steps``.

    Returns:
        Scalar summaries of the step (loss terms, LR multiplier, number of
        unmasked actions, prediction variance).

    Raises:
        InputError: If ``step`` is greater than ``self.total_steps``.
    """
    _require_positive_int("step", step)
    if step > self.total_steps:
        raise InputError(
            "step cannot exceed total_steps",
            details={"step": step, "total_steps": self.total_steps},
        )

    optimizer_config = self.config.optimizer
    multiplier = set_optimizer_lr(
        self.optimizer,
        step=step,
        total_steps=self.total_steps,
        warmup_steps=optimizer_config.warmup_steps,
        schedule=optimizer_config.schedule,
    )
    _zero_grad(self.optimizer)

    # Forward pass: edits -> action embeddings -> predicted states -> loss.
    mask = batch.action_mask
    embedded_actions = _call_action_encoder(self.action_encoder, batch.rel_edits)
    predicted = _call_predictor(
        self.predictor,
        state=batch.state,
        actions=embedded_actions,
        action_mask=mask,
    )
    losses = predictor_loss(
        predicted,
        batch.target,
        phase=self.config.phase,
        mask=mask,
    )

    valid_actions = int(mask.sum().item())
    # Alerts never carry over from a previous step.
    self.last_collapse_alerts = ()
    if valid_actions > 0:
        check = self.collapse_monitor.observe(
            _masked_training_rows(predicted, mask),
            _masked_training_rows(batch.target, mask),
            kl_reg=_scalar(losses.kl_reg),
            step=step,
        )
        if check is not None:
            summaries = []
            for alert in check.alerts:
                summaries.append(
                    {
                        "criterion": alert.criterion,
                        "value": alert.value,
                        "threshold": alert.threshold,
                    }
                )
            self.last_collapse_alerts = tuple(summaries)

    # Backward pass; clipping is skipped when grad_clip is not positive.
    losses.loss.backward()
    max_norm = optimizer_config.grad_clip
    if max_norm > 0:
        trainable = _trainable_parameters((self.predictor, self.action_encoder))
        torch.nn.utils.clip_grad_norm_(trainable, max_norm)
    _optimizer_step(self.optimizer)

    return TorchTrainerStepResult(
        step=step,
        lr_multiplier=multiplier,
        loss=_scalar(losses.loss),
        pred_loss=_scalar(losses.pred_loss),
        kl_reg=_scalar(losses.kl_reg),
        action_count=valid_actions,
        pred_var_per_dim=_pred_var_per_dim(predicted),
    )

configure_torch_reproducibility

configure_torch_reproducibility(*, seed: int, deterministic: bool) -> TorchDeterminismReport

Seed Python/NumPy/PyTorch and optionally enable deterministic torch kernels.

Source code in geno_lewm/training/trainer.py
def configure_torch_reproducibility(
    *,
    seed: int,
    deterministic: bool,
) -> TorchDeterminismReport:
    """Seed the Python, NumPy and PyTorch RNGs and set torch determinism.

    Args:
        seed: Seed applied to ``random``, NumPy (via ``_seed_numpy``) and
            torch (plus all CUDA devices when CUDA is available).
        deterministic: When true, deterministic torch algorithms are
            switched on, ``CUBLAS_WORKSPACE_CONFIG`` is given a default if
            unset, and cuDNN benchmarking is disabled. When false,
            deterministic algorithms are switched off if torch exposes the
            toggle.

    Returns:
        A report of the settings applied, including the determinism flag as
        read back from torch.
    """
    _require_torch("configure_torch_reproducibility")
    _require_nonnegative_int("seed", seed)

    random.seed(seed)
    _seed_numpy(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():  # pragma: no cover - depends on host accelerator.
        torch.cuda.manual_seed_all(seed)

    # Only reported when determinism was requested; stays None otherwise.
    workspace_config: str | None = None
    if not deterministic:
        toggle = getattr(torch, "use_deterministic_algorithms", None)
        if callable(toggle):
            toggle(False)
    else:
        # setdefault keeps any value the environment already provides.
        workspace_config = os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.use_deterministic_algorithms(True)
        cudnn_backend = getattr(torch.backends, "cudnn", None)
        if cudnn_backend is not None:
            cudnn_backend.benchmark = False

    algorithms_enabled = bool(torch.are_deterministic_algorithms_enabled())
    return TorchDeterminismReport(
        seed=seed,
        deterministic=deterministic,
        cublas_workspace_config=workspace_config,
        torch_deterministic_algorithms=algorithms_enabled,
    )

encode_training_batch

encode_training_batch(*, encoder: object, tuples: Sequence[TrainingTuple], source_windows: Mapping[str, str], device: str | object | None = None, dtype: object | None = None) -> TorchTrainerBatch

Encode source/target windows for a real predictor-training minibatch.

Source code in geno_lewm/training/trainer.py
def encode_training_batch(
    *,
    encoder: object,
    tuples: Sequence[TrainingTuple],
    source_windows: Mapping[str, str],
    device: str | object | None = None,
    dtype: object | None = None,
) -> TorchTrainerBatch:
    """Turn training tuples into the tensors a predictor train step consumes.

    Args:
        encoder: Object adapted by ``_encoder_runtime`` to encode sequences.
        tuples: Non-empty sequence of ``TrainingTuple`` items.
        source_windows: Source sequence for each tuple, keyed by ``window_id``.
        device: Device for the created tensors (torch default when ``None``).
        dtype: Dtype for the state/target tensors; ``torch.float32`` when
            not given.

    Returns:
        A batch whose ``target`` repeats each item's encoded target along the
        action axis and is zeroed wherever ``action_mask`` is false.

    Raises:
        InputError: If ``tuples`` is empty, holds a non-``TrainingTuple``
            value, or names a window missing from ``source_windows``.
    """
    _require_torch("encode_training_batch")
    if not tuples:
        raise InputError("training batch must contain at least one tuple")

    source_sequences: list[str] = []
    target_sequences: list[str] = []
    rel_edits: list[tuple[RelEdit, ...]] = []
    window_ids: list[str] = []
    for entry in tuples:
        if not isinstance(entry, TrainingTuple):
            raise InputError(
                "tuples must contain TrainingTuple values",
                details={"type": type(entry).__name__},
            )
        window_id = entry.window_id
        try:
            source_sequence = source_windows[window_id]
        except KeyError as exc:
            raise InputError(
                "source window sequence missing for training tuple",
                details={"window_id": window_id},
            ) from exc
        source_sequences.append(source_sequence)
        target_sequences.append(entry.target_window)
        rel_edits.append(tuple(entry.rel_edits))
        window_ids.append(window_id)

    # The first edit's relative position is the locus passed with each target.
    target_loci = []
    for edits in rel_edits:
        target_loci.append(edits[0].rel_pos if edits else None)

    runtime = _encoder_runtime(encoder)
    source_states = _source_states(runtime, source_sequences)
    target_states = runtime.encode_batch(target_sequences, target_loci)

    state = torch.tensor(source_states, dtype=dtype or torch.float32, device=device)
    per_item_target = torch.tensor(target_states, dtype=state.dtype, device=state.device)
    mask = make_action_mask(rel_edits, device=state.device)

    # Repeat each target across the action axis, then blank the padded slots.
    action_slots = mask.shape[1]
    repeated = per_item_target.unsqueeze(1).expand(-1, action_slots, -1).clone()
    padding = ~mask.unsqueeze(-1)
    target = repeated.masked_fill(padding, 0.0)

    return TorchTrainerBatch(
        state=state,
        target=target,
        rel_edits=tuple(rel_edits),
        action_mask=mask,
        window_ids=tuple(window_ids),
    )

make_action_mask

make_action_mask(rel_edits: Sequence[Sequence[object]], *, device: object | None = None) -> Tensor

Return a boolean action mask for a ragged batch of relative edits.

Source code in geno_lewm/training/trainer.py
def make_action_mask(
    rel_edits: Sequence[Sequence[object]],
    *,
    device: object | None = None,
) -> Tensor:
    """Build a boolean mask marking the real edits in a ragged batch.

    Args:
        rel_edits: One non-empty sequence of edits per batch item.
        device: Device for the returned tensor (torch default when ``None``).

    Returns:
        A ``(len(rel_edits), max_edit_count)`` bool tensor where entry
        ``[row, col]`` is true exactly when ``col < len(rel_edits[row])``.

    Raises:
        InputError: If ``rel_edits`` is empty, or any entry is a string,
            bytes, a non-sequence, or an empty sequence.
    """
    _require_torch("make_action_mask")
    if not rel_edits:
        raise InputError("rel_edits must contain at least one batch item")

    edit_counts = []
    for entry in rel_edits:
        # str/bytes are Sequences too, but never a valid list of edits.
        is_edit_sequence = isinstance(entry, Sequence) and not isinstance(entry, (str, bytes))
        if not is_edit_sequence:
            raise InputError(
                "rel_edits entries must be sequences",
                details={"type": type(entry).__name__},
            )
        if not entry:
            raise InputError("each training batch item must include at least one edit")
        edit_counts.append(len(entry))

    # Compare a row of column indices with a column of per-item lengths.
    columns = torch.arange(max(edit_counts), device=device).unsqueeze(0)
    counts = torch.tensor(edit_counts, device=device).unsqueeze(1)
    return columns < counts

build_adamw_optimizer

build_adamw_optimizer(*, predictor: object, action_encoder: object, config: GenoLeWMConfig) -> object

Build AdamW groups for predictor/action-encoder trainable parameters.

Source code in geno_lewm/training/trainer.py
def build_adamw_optimizer(
    *,
    predictor: object,
    action_encoder: object,
    config: GenoLeWMConfig,
) -> object:
    """Create an AdamW optimizer over predictor and action-encoder parameters.

    Args:
        predictor: Module contributing the ``"predictor"`` parameter groups.
        action_encoder: Module contributing the ``"action_encoder"`` groups.
        config: Run configuration; ``config.optimizer`` supplies ``name``,
            ``lr``, ``weight_decay``, ``beta1`` and ``beta2``.

    Returns:
        A ``torch.optim.AdamW`` instance built with ``eps=1.0e-8``.

    Raises:
        InputError: If the configured optimizer is not ``"adamw"`` or
            ``_adamw_param_groups`` returns no groups.
    """
    _require_torch("build_adamw_optimizer")
    optimizer_config = config.optimizer
    if optimizer_config.name != "adamw":
        raise InputError(
            "real trainer currently supports AdamW only",
            details={"optimizer": optimizer_config.name},
        )

    named_modules = (("predictor", predictor), ("action_encoder", action_encoder))
    param_groups = _adamw_param_groups(
        named_modules,
        lr=float(optimizer_config.lr),
        weight_decay=float(optimizer_config.weight_decay),
    )
    if not param_groups:
        raise InputError("no trainable predictor/action-encoder parameters found")

    betas = (float(optimizer_config.beta1), float(optimizer_config.beta2))
    return torch.optim.AdamW(param_groups, betas=betas, eps=1.0e-8)

wsd_lr_multiplier

wsd_lr_multiplier(step: int, *, total_steps: int, warmup_steps: int, schedule: ScheduleName = 'wsd') -> float

Return the RFC-0005 WSD learning-rate multiplier for a 1-indexed step.

Source code in geno_lewm/training/trainer.py
def wsd_lr_multiplier(
    step: int,
    *,
    total_steps: int,
    warmup_steps: int,
    schedule: ScheduleName = "wsd",
) -> float:
    """Compute the learning-rate multiplier for a 1-indexed training step.

    Schedules:
        ``"constant"``: always 1.0.
        ``"cosine"``: ``0.5 * (1 + torchless_cos(progress))`` with
            ``progress = (step - 1) / (total_steps - 1)``; 1.0 for a
            single-step run.
        ``"wsd"`` (RFC-0005 warmup/stable/decay): linear warmup to 1.0 over
            ``warmup_steps``; 1.0 until about 80% of the post-warmup steps;
            linear decay from 1.0 to 0.1 until about 98%; then a final taper
            from 0.1 toward 0.01, never going below 0.01.

    Args:
        step: Current step, at least 1 and at most ``total_steps``.
        total_steps: Total number of steps in the run.
        warmup_steps: Number of warmup steps (only used by ``"wsd"``).
        schedule: Schedule name.

    Raises:
        InputError: If ``step`` exceeds ``total_steps`` or ``schedule`` is
            not one of the supported names.
    """
    _require_positive_int("step", step)
    _require_positive_int("total_steps", total_steps)
    _require_nonnegative_int("warmup_steps", warmup_steps)
    if step > total_steps:
        raise InputError(
            "step cannot exceed total_steps",
            details={"step": step, "total_steps": total_steps},
        )

    if schedule == "constant":
        return 1.0
    if schedule == "cosine":
        if total_steps == 1:
            return 1.0
        fraction = (step - 1) / (total_steps - 1)
        # NOTE(review): fraction lies in [0, 1]; a half-cosine decay to zero
        # needs the angle scaled by pi - confirm torchless_cos does that.
        return 0.5 * (1.0 + float(torchless_cos(fraction)))
    if schedule != "wsd":
        raise InputError("unsupported learning-rate schedule", details={"schedule": schedule})

    # Warmup phase: linear ramp that reaches 1.0 at step == warmup_steps.
    in_warmup = warmup_steps > 0 and step <= warmup_steps
    if in_warmup:
        return step / warmup_steps
    if total_steps <= warmup_steps:
        return 1.0

    # Phase boundaries, each at least one step past the end of warmup.
    remaining = total_steps - warmup_steps
    stable_end = warmup_steps + max(1, int(remaining * 0.80))
    taper_start = warmup_steps + max(1, int(remaining * 0.98))

    if step <= stable_end:
        return 1.0
    if step > taper_start:
        # Final taper: continue from 0.1 toward the 0.01 floor.
        tail = max(1, total_steps - taper_start)
        fraction = (step - taper_start) / tail
        return max(0.01, 0.1 - 0.09 * fraction)
    # Main decay: 1.0 down to 0.1 across (stable_end, taper_start].
    window = max(1, taper_start - stable_end)
    fraction = (step - stable_end) / window
    return 1.0 - 0.9 * fraction

set_optimizer_lr

set_optimizer_lr(optimizer: object, *, step: int, total_steps: int, warmup_steps: int, schedule: ScheduleName = 'wsd') -> float

Set optimizer group LRs from each group's initial_lr and return multiplier.

Source code in geno_lewm/training/trainer.py
def set_optimizer_lr(
    optimizer: object,
    *,
    step: int,
    total_steps: int,
    warmup_steps: int,
    schedule: ScheduleName = "wsd",
) -> float:
    """Set optimizer group LRs from each group's ``initial_lr`` and return multiplier.

    Each param group's base rate is its ``"initial_lr"`` entry when present,
    otherwise its current ``"lr"`` (which is then recorded as ``"initial_lr"``
    so later calls keep scaling from the same base).

    All groups are validated before any of them is modified, so a failing
    call leaves the optimizer untouched: no group is rescaled and no invalid
    ``"initial_lr"`` is cached.

    Args:
        optimizer: Object exposing ``param_groups`` as a non-empty list of
            dictionaries.
        step: Current 1-indexed step, forwarded to ``wsd_lr_multiplier``.
        total_steps: Total number of steps, forwarded likewise.
        warmup_steps: Warmup length, forwarded likewise.
        schedule: Schedule name, forwarded likewise.

    Returns:
        The multiplier applied to every group's base learning rate.

    Raises:
        InputError: If ``param_groups`` is missing, empty, not a list, holds
            a non-dict entry, or a group's base rate is not an int/float
            (bools are rejected explicitly).
    """
    multiplier = wsd_lr_multiplier(
        step,
        total_steps=total_steps,
        warmup_steps=warmup_steps,
        schedule=schedule,
    )
    groups = getattr(optimizer, "param_groups", None)
    if not isinstance(groups, list) or not groups:
        raise InputError("optimizer must expose non-empty param_groups")

    # Pass 1: validate every group and collect base rates without mutating.
    base_lrs = []
    for group in groups:
        if not isinstance(group, dict):
            raise InputError("optimizer param_groups must contain dictionaries")
        base_lr = group.get("initial_lr", group.get("lr"))
        # bool is a subclass of int, so it has to be excluded explicitly.
        if isinstance(base_lr, bool) or not isinstance(base_lr, int | float):
            raise InputError("optimizer param group lr must be numeric")
        base_lrs.append(base_lr)

    # Pass 2: everything is valid, so record the base rate and rescale.
    for group, base_lr in zip(groups, base_lrs):
        group.setdefault("initial_lr", base_lr)
        group["lr"] = float(base_lr) * multiplier
    return multiplier