Skip to content

geno_lewm.predictor

predictor

Action-conditioned predictor modules for GenoLeWM.

ARPredictor

ARPredictor(predictor: object)

Bases: Module

Inference-time autoregressive rollout over a base Predictor.

The wrapper defines the public RFC-0004 rollout contract: each action is scored against the state predicted by the previous action, producing [s_hat[t+1], ..., s_hat[t+K]]. When the wrapped predictor exposes rollout-cache hooks, static action projections are encoded once before the autoregressive loop; otherwise the wrapper falls back to repeated forward calls.

Source code in geno_lewm/predictor/ar.py
def __init__(self, predictor: object) -> None:
    """Wrap ``predictor`` and copy its size attributes onto the wrapper.

    ``predictor`` must be callable, otherwise ``InputError`` is raised.
    ``d_state``, ``d_action`` and ``max_actions`` are each obtained through
    ``_required_positive_int(predictor, name)``.
    """
    super().__init__()
    if not callable(predictor):
        raise InputError("predictor must be callable")
    self.predictor: Any = predictor
    # The names are read in declaration order, so the first attribute that
    # fails the check is the one that gets reported.
    self.d_state, self.d_action, self.max_actions = (
        _required_positive_int(predictor, attribute)
        for attribute in ("d_state", "d_action", "max_actions")
    )

rollout

rollout(state: Tensor, action_sequence: Tensor | Sequence[Tensor], action_mask: Tensor | None = None) -> tuple[Tensor, ...]

Return one predicted state per autoregressive action step.

Source code in geno_lewm/predictor/ar.py
@_INFERENCE_MODE
def rollout(
    self,
    state: Tensor,
    action_sequence: Tensor | Sequence[Tensor],
    action_mask: Tensor | None = None,
) -> tuple[Tensor, ...]:
    """Return the rollout as a tuple holding one tensor per action step.

    Delegates to :meth:`rollout_tensor` and splits its result along
    dimension 1 (the step axis).
    """
    stacked = self.rollout_tensor(state, action_sequence, action_mask)
    per_step = stacked.unbind(dim=1)
    return tuple(per_step)

rollout_tensor

rollout_tensor(state: Tensor, action_sequence: Tensor | Sequence[Tensor], action_mask: Tensor | None = None) -> Tensor

Return autoregressive rollout as (batch, steps, d_state).

Source code in geno_lewm/predictor/ar.py
@_INFERENCE_MODE
def rollout_tensor(
    self,
    state: Tensor,
    action_sequence: Tensor | Sequence[Tensor],
    action_mask: Tensor | None = None,
) -> Tensor:
    """Return autoregressive rollout as ``(batch, steps, d_state)``.

    Each step feeds the state produced by the previous step back into the
    wrapped predictor together with that step's action.

    Args:
        state: Starting state of the rollout.
        action_sequence: Actions to apply in order. They are passed through
            ``_normalize_actions`` and then indexed as
            ``(batch, steps, feature)``.
        action_mask: Optional per-step validity mask. When given, rows whose
            step is inactive keep their current state and get zeros written
            to the output slot of that step.

    Returns:
        A tensor of shape ``(batch, steps, d_state)``. Its dtype is
        ``torch.float32`` when there are more than 20 steps, otherwise the
        dtype of ``state``.
    """
    actions = self._normalize_actions(action_sequence)
    self._validate_state(state, actions)

    current = state
    # Every predictor call below processes exactly one action, so the mask
    # handed to the predictor is a constant all-True (batch, 1) column.
    call_mask = torch.ones((actions.shape[0], 1), dtype=torch.bool, device=actions.device)
    # NOTE(review): presumably None when the wrapped predictor exposes no
    # rollout-cache hooks -- confirm in _cached_action_tokens. The code below
    # treats None as "fall back to plain predictor calls".
    action_tokens = self._cached_action_tokens(actions)
    # The action cache is only consumed by the unmasked fast paths.
    action_cache = (
        self._cached_rollout_action_cache(action_tokens) if action_mask is None else None
    )
    state_token_bias = self._cached_state_token_bias(state)
    # Rollouts longer than 20 steps set the upcast flag on the token-based
    # predictor hooks and collect their outputs in float32.
    upcast_output_mlp = actions.shape[1] > 20
    output_dtype = torch.float32 if upcast_output_mlp else state.dtype
    outputs = state.new_empty(
        (actions.shape[0], actions.shape[1], self.d_state),
        dtype=output_dtype,
    )
    if action_mask is None:
        if action_tokens is None:
            # No cached action tokens: call the predictor itself once per step
            # with a single-action slice.
            for step in range(actions.shape[1]):
                prediction = self.predictor(
                    current,
                    actions[:, step : step + 1, :],
                    call_mask,
                )
                current = prediction[:, 0, :]
                outputs[:, step, :] = current
            return outputs

        # Optional fast-path hooks, looked up by name so that predictors
        # without them still work. They are tried in order of preference:
        # state-returning hook, then unmasked hook, then the masked one-step
        # hook at the end of this branch.
        one_step_state = getattr(
            self.predictor,
            "_forward_one_step_unmasked_state_from_action_token",
            None,
        )

        one_step_unmasked = getattr(
            self.predictor,
            "_forward_one_step_unmasked_from_action_token",
            None,
        )
        select_action_cache = getattr(
            self.predictor,
            "_slice_rollout_action_cache",
            None,
        )
        if callable(one_step_state):
            # This hook's return value is used directly as the next state
            # (no step axis to strip).
            for step in range(actions.shape[1]):
                step_cache = (
                    select_action_cache(action_cache, step)
                    if action_cache is not None and callable(select_action_cache)
                    else None
                )
                current = one_step_state(
                    current,
                    action_tokens[:, step, :],
                    step_cache,
                    state_token_bias=state_token_bias,
                    upcast_output_mlp=upcast_output_mlp,
                )
                outputs[:, step, :] = current
            return outputs

        if callable(one_step_unmasked):
            # This hook returns a prediction with a step axis; index 0 of that
            # axis is taken as the next state.
            for step in range(actions.shape[1]):
                step_cache = (
                    select_action_cache(action_cache, step)
                    if action_cache is not None and callable(select_action_cache)
                    else None
                )
                prediction = one_step_unmasked(
                    current,
                    action_tokens[:, step, :],
                    step_cache,
                    upcast_output_mlp=upcast_output_mlp,
                )
                current = prediction[:, 0, :]
                outputs[:, step, :] = current
            return outputs

        # Neither unmasked hook is available: use the masked one-step hook
        # with the constant all-True call mask.
        for step in range(actions.shape[1]):
            prediction = self.predictor._forward_one_step_from_action_token(
                current,
                action_tokens[:, step, :],
                call_mask,
                upcast_output_mlp=upcast_output_mlp,
            )
            current = prediction[:, 0, :]
            outputs[:, step, :] = current
        return outputs

    # Masked rollout. NOTE(review): mask is indexed as (batch, steps) and used
    # as a torch.where condition, so _normalize_mask presumably returns a
    # boolean tensor -- confirm.
    mask = self._normalize_mask(actions, action_mask)
    for step in range(actions.shape[1]):
        active = mask[:, step].unsqueeze(-1)
        if action_tokens is None:
            prediction = self.predictor(
                current,
                actions[:, step : step + 1, :],
                call_mask,
            )
        else:
            prediction = self.predictor._forward_one_step_from_action_token(
                current,
                action_tokens[:, step, :],
                call_mask,
                upcast_output_mlp=upcast_output_mlp,
            )
        next_state = prediction[:, 0, :]
        # Inactive rows carry their state forward unchanged ...
        current = torch.where(active, next_state, current)
        # ... and report zeros for this step instead of a prediction.
        outputs[:, step, :] = torch.where(
            active,
            next_state,
            torch.zeros_like(next_state),
        )
    return outputs

predict_single

predict_single(state: Tensor, action: Tensor) -> Tensor

Return the one-step predicted state for action.

Source code in geno_lewm/predictor/ar.py
@_INFERENCE_MODE
def predict_single(self, state: Tensor, action: Tensor) -> Tensor:
    """Return the one-step predicted state for ``action``.

    Raises ``InputError`` when the normalised action holds anything other
    than exactly one step.
    """
    # A 2-D action gets a step axis inserted at dim 1 before normalisation.
    candidate = action.unsqueeze(1) if action.ndim == 2 else action
    actions = self._normalize_actions(candidate)
    steps = actions.shape[1]
    if steps == 1:
        return self.rollout(state, actions)[0]
    raise InputError(
        "predict_single expects exactly one action",
        details={"steps": steps},
    )

predict_trajectory

predict_trajectory(state: Tensor, action_sequence: Tensor | Sequence[Tensor], action_mask: Tensor | None = None) -> tuple[Tensor, ...]

Alias for `rollout`, matching RFC-0004 terminology.

Source code in geno_lewm/predictor/ar.py
@_INFERENCE_MODE
def predict_trajectory(
    self,
    state: Tensor,
    action_sequence: Tensor | Sequence[Tensor],
    action_mask: Tensor | None = None,
) -> tuple[Tensor, ...]:
    """Alias for :meth:`rollout` matching RFC-0004 terminology.

    All arguments are forwarded unchanged and the per-step tuple produced by
    :meth:`rollout` is returned as is.
    """
    trajectory = self.rollout(state, action_sequence, action_mask=action_mask)
    return trajectory

predict_haplotype

predict_haplotype(state: Tensor, action_sequence: Tensor | Sequence[Tensor], action_mask: Tensor | None = None) -> Tensor

Return the final predicted state after all valid actions.

Source code in geno_lewm/predictor/ar.py
@_INFERENCE_MODE
def predict_haplotype(
    self,
    state: Tensor,
    action_sequence: Tensor | Sequence[Tensor],
    action_mask: Tensor | None = None,
) -> Tensor:
    """Return the final predicted state after all valid actions.

    Without a mask this is the last step of the rollout. With a mask, each
    row returns the prediction at its *last valid* step. Rows without any
    valid step return the input ``state`` (cast to the rollout dtype),
    because no action was applied to them.

    The last valid step is located explicitly rather than derived from the
    count of valid steps: a count-based index only works for masks whose
    valid steps form a non-empty prefix, and otherwise lands on a masked
    slot, which the masked rollout fills with zeros.
    """
    actions = self._normalize_actions(action_sequence)
    if action_mask is None:
        return self.rollout_tensor(state, actions)[:, -1, :]

    mask = self._normalize_mask(actions, action_mask)
    trajectory = self.rollout_tensor(state, actions, mask)

    valid = mask.to(dtype=torch.bool)
    # 1-based step numbers; masked-out slots contribute 0, so the row-wise
    # maximum minus one is the index of the last valid step, or -1 when the
    # row has no valid step at all.
    step_numbers = torch.arange(1, actions.shape[1] + 1, device=valid.device)
    last_valid = (valid * step_numbers).max(dim=1).values - 1
    has_valid = (last_valid >= 0).unsqueeze(-1)

    rows = torch.arange(actions.shape[0], device=actions.device)
    final = trajectory[rows, last_valid.clamp(min=0)]
    # Rows with nothing to apply keep the starting state.
    return torch.where(has_valid, final, state.to(dtype=final.dtype))

PredictionLossResult dataclass

PredictionLossResult(loss: Tensor, pred_loss: Tensor, kl_reg: Tensor, phase: Literal['phase1', 'phase2'])

Phase-aware predictor loss components.

Predictor

Predictor(*, d_state: int = 1024, d_action: int = 512, d_hidden: int = 768, n_heads: int = 8, n_cross_layers: int = 4, n_self_layers: int = 2, ffn_dim: int = 768, max_actions: int = 16)

Bases: Module

Cross-attention Transformer predictor from RFC-0004.

The default keeps the RFC-0004 4-cross/2-self topology and Carbon-compatible d_state=1024 output while using the RFC's target-size d_hidden=768 variant so the trainable budget is close to the documented ~22M target.

Source code in geno_lewm/predictor/model.py
def __init__(
    self,
    *,
    d_state: int = 1024,
    d_action: int = 512,
    d_hidden: int = 768,
    n_heads: int = 8,
    n_cross_layers: int = 4,
    n_self_layers: int = 2,
    ffn_dim: int = 768,
    max_actions: int = 16,
) -> None:
    """Validate the hyperparameters and build every sub-module.

    Each size is checked with ``_require_positive``. ``InputError`` is
    raised when ``d_hidden`` is not divisible by ``n_heads``. Construction
    finishes by calling :meth:`reset_parameters`.
    """
    super().__init__()
    sizes = (
        ("d_state", d_state),
        ("d_action", d_action),
        ("d_hidden", d_hidden),
        ("n_heads", n_heads),
        ("n_cross_layers", n_cross_layers),
        ("n_self_layers", n_self_layers),
        ("ffn_dim", ffn_dim),
        ("max_actions", max_actions),
    )
    for label, value in sizes:
        _require_positive(label, value)
    if d_hidden % n_heads:
        raise InputError(
            "d_hidden must be divisible by n_heads",
            details={"d_hidden": d_hidden, "n_heads": n_heads},
        )

    self.d_state = d_state
    self.d_action = d_action
    self.d_hidden = d_hidden
    self.max_actions = max_actions

    # The state only needs a learned projection when its width differs from
    # the hidden width.
    if d_state == d_hidden:
        self.state_projection = nn.Identity()
    else:
        self.state_projection = nn.Linear(d_state, d_hidden)
    self.action_projection = nn.Linear(d_action, d_hidden)
    self.token_type_embedding = nn.Embedding(2, d_hidden)
    self.step_position_embedding = nn.Embedding(max_actions + 1, d_hidden)

    # Cross-attention layers alternate direction: even positions use the
    # state-to-action block, odd positions the action-to-state block.
    block_kwargs = {"d_hidden": d_hidden, "n_heads": n_heads, "ffn_dim": ffn_dim}
    cross_layers = []
    for layer_index in range(n_cross_layers):
        block_type = (
            _ActionToStateCrossBlock if layer_index % 2 else _StateToActionCrossBlock
        )
        cross_layers.append(block_type(**block_kwargs))
    self.cross_blocks = nn.ModuleList(cross_layers)

    self.self_blocks = nn.ModuleList(
        [_SelfAttentionBlock(**block_kwargs) for _ in range(n_self_layers)]
    )

    head_layers = [
        nn.Linear(d_hidden, d_hidden),
        nn.GELU(),
        nn.LayerNorm(d_hidden),
        nn.Linear(d_hidden, d_state),
    ]
    self.output_mlp = nn.Sequential(*head_layers)
    self.reset_parameters()

reset_parameters

reset_parameters() -> None

Initialize layers per RFC-0004 §3.4.

Source code in geno_lewm/predictor/model.py
def reset_parameters(self) -> None:
    """Initialize layers per RFC-0004 §3.4.

    Attention input projections and linear layers get a truncated normal
    with ``std = sqrt(2 / fan_in)`` cut at two standard deviations, plus
    zero biases. Layer norms are reset to weight one and bias zero,
    embeddings to a normal with ``std=0.02``. The last layer of
    ``output_mlp`` is zeroed afterwards when it is a linear layer.
    """

    def _scaled_trunc_normal(weight, bias, fan_in) -> None:
        # Shared recipe for attention input projections and linear layers.
        scale = math.sqrt(2.0 / float(fan_in))
        nn.init.trunc_normal_(weight, std=scale, a=-2 * scale, b=2 * scale)
        if bias is not None:
            nn.init.zeros_(bias)

    for layer in self.modules():
        if isinstance(layer, nn.MultiheadAttention):
            _scaled_trunc_normal(layer.in_proj_weight, layer.in_proj_bias, layer.embed_dim)
        elif isinstance(layer, nn.Linear):
            _scaled_trunc_normal(layer.weight, layer.bias, layer.in_features)
        elif isinstance(layer, nn.LayerNorm):
            nn.init.ones_(layer.weight)
            nn.init.zeros_(layer.bias)
        elif isinstance(layer, nn.Embedding):
            nn.init.normal_(layer.weight, std=0.02)

    # Overrides the generic linear init applied to this layer in the loop.
    head = self.output_mlp[-1]
    if isinstance(head, nn.Linear):
        for parameter in (head.weight, head.bias):
            nn.init.zeros_(parameter)

forward

forward(state: Tensor, actions: Tensor, action_mask: Tensor) -> Tensor

Return per-action next-state predictions of shape (B, K, d_state).

Source code in geno_lewm/predictor/model.py
def forward(
    self,
    state: Tensor,
    actions: Tensor,
    action_mask: Tensor,
) -> Tensor:
    """Return per-action next-state predictions of shape ``(B, K, d_state)``.

    Inputs are validated first; the state and the actions are then encoded
    into tokens and handed to ``_predict_from_tokens`` together with the
    mask returned by the validation step.
    """
    validated_mask = self._validate_inputs(state, actions, action_mask)
    encoded_state = self._encode_state_token(state)
    encoded_actions = self._encode_forward_actions(actions)
    # More than 20 actions switches on the ``upcast_output_mlp`` flag.
    needs_upcast = actions.shape[1] > 20
    return self._predict_from_tokens(
        state,
        encoded_state,
        encoded_actions,
        validated_mask,
        upcast_output_mlp=needs_upcast,
    )

lejepa_kl_regularizer

lejepa_kl_regularizer(states: Tensor, *, eps: float = 1e-06) -> Tensor

Return the closed-form KL from empirical state distribution to N(0, I).

Source code in geno_lewm/predictor/losses.py
def lejepa_kl_regularizer(
    states: Tensor, *, eps: float = 1.0e-6
) -> Tensor:  # pragma: no cover - optional torch runtime is tested separately.
    """Return the closed-form KL from empirical state distribution to ``N(0, I)``.

    ``states`` is flattened to ``(samples, d_state)``. Its mean and biased
    covariance are computed in float64, and
    ``0.5 * (|mean|^2 + trace(cov) - logdet(cov + eps * I) - d_state)`` is
    returned in the dtype of ``states``.

    Raises ``InputError`` when ``states`` has fewer than two dimensions, is
    not floating point, has an empty feature dimension, or holds no samples.
    """
    _require_torch("lejepa_kl_regularizer")
    _require_positive_float("eps", eps)
    if states.ndim < 2:
        raise InputError(
            "states must have shape (..., d_state)",
            details={"shape": tuple(states.shape)},
        )
    if not states.is_floating_point():
        raise InputError("states must be a floating-point tensor")
    feature_dim = states.shape[-1]
    if feature_dim <= 0:
        raise InputError("states must have a non-empty feature dimension")

    # Work in float64: with fewer samples than features the empirical
    # covariance is rank-deficient, and a float32 slogdet becomes unreliable
    # in sign once many eigenvalues sit near the stabilizer.
    samples = states.reshape(-1, feature_dim).to(dtype=torch.float64)
    sample_count = samples.shape[0]
    if sample_count == 0:
        raise InputError("states must contain at least one sample")

    mu = samples.mean(dim=0)
    deviations = samples - mu
    cov = deviations.T @ deviations / sample_count

    # logdet(cov + eps * I) from the symmetric eigenvalues: the same quantity
    # slogdet would give, without its sign instability. Slightly negative
    # eigenvalues are numerical noise and are clamped to zero before the
    # floor is added.
    spectrum = torch.linalg.eigvalsh(cov)
    log_det = torch.log(torch.clamp(spectrum, min=0.0) + eps).sum()

    squared_mean_norm = mu.pow(2).sum()
    total_variance = torch.diagonal(cov).sum()
    divergence = 0.5 * (squared_mean_norm + total_variance - log_det - feature_dim)
    return divergence.to(dtype=states.dtype)

prediction_loss

prediction_loss(prediction: Tensor, target: Tensor, *, alpha: float = 1.0, beta: float = 0.1, mask: Tensor | None = None, eps: float = 1e-08) -> Tensor

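Return the RFC-0005 prediction loss `alpha * (1 - cos) + beta * ||prediction - target||^2 / d_state`, where `cos` is the cosine similarity and the squared error is summed over the last (feature) dimension; the per-step values are then reduced with a masked mean.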

Source code in geno_lewm/predictor/losses.py
def prediction_loss(
    prediction: Tensor,
    target: Tensor,
    *,
    alpha: float = 1.0,
    beta: float = 0.1,
    mask: Tensor | None = None,
    eps: float = 1.0e-8,
) -> Tensor:  # pragma: no cover - optional torch runtime is tested separately.
    """Return RFC-0005 ``alpha * (1 - cos) + beta * MSE / d_state``.

    ``cos`` is the cosine similarity over the last dimension. The second
    term is the squared error summed over the last dimension and divided by
    its size. Per-step values are reduced with ``_masked_mean``.
    """
    _require_torch("prediction_loss")
    _validate_loss_inputs(prediction, target, mask=mask)
    _require_nonnegative("alpha", alpha)
    _require_nonnegative("beta", beta)
    _require_positive_float("eps", eps)

    # Directional term: 1 - cosine similarity per step.
    cos_sim = functional.cosine_similarity(prediction, target, dim=-1, eps=eps)
    angular_term = alpha * (1.0 - cos_sim)

    # Magnitude term: squared error summed over features, scaled by width.
    residual = prediction - target
    feature_count = prediction.shape[-1]
    magnitude_term = beta * (residual.pow(2).sum(dim=-1) / feature_count)

    return _masked_mean(angular_term + magnitude_term, mask)

predictor_loss

predictor_loss(prediction: Tensor, target: Tensor, *, phase: Literal['phase1', 'phase2'] = 'phase1', alpha: float = 1.0, beta: float = 0.1, gamma: float = 0.5, mask: Tensor | None = None, regularizer_states: Tensor | None = None, eps: float = 1e-06) -> PredictionLossResult

Return phase-conditional total loss and monitorable components.

Source code in geno_lewm/predictor/losses.py
def predictor_loss(
    prediction: Tensor,
    target: Tensor,
    *,
    phase: Literal["phase1", "phase2"] = "phase1",
    alpha: float = 1.0,
    beta: float = 0.1,
    gamma: float = 0.5,
    mask: Tensor | None = None,
    regularizer_states: Tensor | None = None,
    eps: float = 1.0e-6,
) -> PredictionLossResult:  # pragma: no cover - optional torch runtime is tested separately.
    """Return phase-conditional total loss and monitorable components.

    The prediction loss and the KL regularizer are both computed in every
    phase so they can be monitored. Only ``phase2`` adds ``gamma * kl_reg``
    to the total. The regularizer is evaluated on ``regularizer_states``
    when given, otherwise on ``target``.
    """
    _require_torch("predictor_loss")
    if phase != "phase1" and phase != "phase2":
        raise InputError("phase must be either 'phase1' or 'phase2'", details={"phase": phase})
    _require_nonnegative("gamma", gamma)

    pred_term = prediction_loss(prediction, target, alpha=alpha, beta=beta, mask=mask)

    if regularizer_states is None:
        kl_input = target
    else:
        kl_input = regularizer_states
    kl_term = lejepa_kl_regularizer(kl_input, eps=eps)

    if phase == "phase1":
        total = pred_term
    else:
        total = pred_term + gamma * kl_term
    return PredictionLossResult(loss=total, pred_loss=pred_term, kl_reg=kl_term, phase=phase)

build_predictor

build_predictor(config: Any) -> Predictor

Construct the predictor from a GenoLeWMConfig.

Single source of truth shared by training (`geno_lewm.training.real`) and the deploy runtime (`geno_lewm.deploy.runtime`), so that a trained checkpoint always loads back into an identically shaped predictor. The two call sites previously built the predictor with different hyperparameters (d_hidden / n_cross_layers / ffn_dim), so a real exported checkpoint could not be loaded for scoring. Keep both on this builder.

Source code in geno_lewm/predictor/__init__.py
def build_predictor(config: Any) -> Predictor:
    """Construct the predictor from a ``GenoLeWMConfig``.

    Training (:mod:`geno_lewm.training.real`) and the deploy runtime
    (:mod:`geno_lewm.deploy.runtime`) both go through this builder, so a
    trained checkpoint always loads back into an identically-shaped
    predictor. The two call sites used to build the predictor with
    different hyperparameters (d_hidden / n_cross_layers / ffn_dim), which
    meant a real exported checkpoint could not be loaded for scoring. Keep
    both on this builder.

    Only the four hyperparameters below are taken from ``config``; every
    other ``Predictor`` argument keeps its default.
    """
    predictor_cfg = config.predictor
    hyperparameters = {
        "d_state": predictor_cfg.d_state,
        "d_action": config.action.d_action,
        "n_heads": predictor_cfg.n_heads,
        "n_cross_layers": predictor_cfg.n_layers,
    }
    return Predictor(**hyperparameters)