Skip to content

geno_lewm.predictor.model

model

Cross-attention predictor for action-conditioned latent transitions.

The PyTorch runtime is optional. Importing `geno_lewm.predictor` is lightweight; instantiating `Predictor` requires a training environment with PyTorch installed.

Predictor

Predictor(*, d_state: int = 1024, d_action: int = 512, d_hidden: int = 768, n_heads: int = 8, n_cross_layers: int = 4, n_self_layers: int = 2, ffn_dim: int = 768, max_actions: int = 16)

Bases: Module

Cross-attention Transformer predictor from RFC-0004.

The defaults keep the RFC-0004 4-cross/2-self topology and the Carbon-compatible `d_state=1024` output, while using the RFC's target-size `d_hidden=768` variant so that the trainable parameter budget stays close to the documented ~22M target.

Source code in geno_lewm/predictor/model.py
def __init__(
    self,
    *,
    d_state: int = 1024,
    d_action: int = 512,
    d_hidden: int = 768,
    n_heads: int = 8,
    n_cross_layers: int = 4,
    n_self_layers: int = 2,
    ffn_dim: int = 768,
    max_actions: int = 16,
) -> None:
    """Build the input projections, attention stacks and output head.

    Args:
        d_state: Width of the incoming state; also the width produced by the
            last layer of ``output_mlp``.
        d_action: Width of each action vector consumed by
            ``action_projection``.
        d_hidden: Internal token width shared by every block; must be a
            multiple of ``n_heads``.
        n_heads: Number of attention heads handed to every cross/self block.
        n_cross_layers: Number of cross-attention blocks. Even positions
            (0, 2, ...) are ``_StateToActionCrossBlock``; odd positions are
            ``_ActionToStateCrossBlock``.
        n_self_layers: Number of ``_SelfAttentionBlock`` layers.
        ffn_dim: Feed-forward width handed to every block.
        max_actions: Sizes the step-position embedding table, which has
            ``max_actions + 1`` rows.

    Raises:
        InputError: If ``d_hidden`` is not a multiple of ``n_heads``. Every
            size argument is first passed through ``_require_positive``,
            which presumably raises for non-positive values (confirm).
    """
    super().__init__()

    # Size checks, performed in signature order.
    positive_args = (
        ("d_state", d_state),
        ("d_action", d_action),
        ("d_hidden", d_hidden),
        ("n_heads", n_heads),
        ("n_cross_layers", n_cross_layers),
        ("n_self_layers", n_self_layers),
        ("ffn_dim", ffn_dim),
        ("max_actions", max_actions),
    )
    for arg_name, arg_value in positive_args:
        _require_positive(arg_name, arg_value)

    # Multi-head attention splits d_hidden evenly across the heads.
    if d_hidden % n_heads:
        raise InputError(
            "d_hidden must be divisible by n_heads",
            details={"d_hidden": d_hidden, "n_heads": n_heads},
        )

    self.d_state = d_state
    self.d_action = d_action
    self.d_hidden = d_hidden
    self.max_actions = max_actions

    # NOTE: submodules are created in a fixed order on purpose -- both their
    # default init and reset_parameters() draw from the global RNG in
    # registration order, so reordering would change seeded weights.
    if d_state == d_hidden:
        # Already at working width: no projection needed.
        self.state_projection = nn.Identity()
    else:
        self.state_projection = nn.Linear(d_state, d_hidden)
    self.action_projection = nn.Linear(d_action, d_hidden)
    # Two token types -- presumably state vs. action tokens (confirm
    # against _encode_state_token / _encode_forward_actions).
    self.token_type_embedding = nn.Embedding(2, d_hidden)
    # max_actions + 1 rows -- presumably one slot for the state plus one per
    # action step (TODO confirm).
    self.step_position_embedding = nn.Embedding(max_actions + 1, d_hidden)

    # Alternate attention direction layer by layer, starting state->action.
    cross_layers = []
    for layer_idx in range(n_cross_layers):
        block_cls = (
            _StateToActionCrossBlock
            if layer_idx % 2 == 0
            else _ActionToStateCrossBlock
        )
        cross_layers.append(
            block_cls(d_hidden=d_hidden, n_heads=n_heads, ffn_dim=ffn_dim)
        )
    self.cross_blocks = nn.ModuleList(cross_layers)

    self.self_blocks = nn.ModuleList(
        [
            _SelfAttentionBlock(d_hidden=d_hidden, n_heads=n_heads, ffn_dim=ffn_dim)
            for _ in range(n_self_layers)
        ]
    )

    # Output head: hidden -> GELU -> LayerNorm -> back to the state width.
    head_layers = [
        nn.Linear(d_hidden, d_hidden),
        nn.GELU(),
        nn.LayerNorm(d_hidden),
        nn.Linear(d_hidden, d_state),
    ]
    self.output_mlp = nn.Sequential(*head_layers)

    self.reset_parameters()

reset_parameters

reset_parameters() -> None

Initialize layers per RFC-0004 §3.4.

Source code in geno_lewm/predictor/model.py
def reset_parameters(self) -> None:
    """Initialize layers per RFC-0004 §3.4.

    Walks every submodule returned by ``self.modules()`` and re-initializes:

    * ``nn.MultiheadAttention`` input projections with a truncated normal,
      ``std = sqrt(2 / fan_in)`` clipped to ``[-2 * std, 2 * std]``, and a
      zero input-projection bias. Both the packed ``in_proj_weight`` layout
      and the separate ``q_proj_weight`` / ``k_proj_weight`` /
      ``v_proj_weight`` layout (used by PyTorch when ``kdim``/``vdim``
      differ from ``embed_dim``) are supported.
    * ``nn.Linear`` layers, including every attention ``out_proj``, with the
      same truncated normal (``fan_in = in_features``) and a zero bias.
    * ``nn.LayerNorm`` affine parameters, when present, to weight 1 and
      bias 0.
    * ``nn.Embedding`` tables from ``N(0, 0.02**2)``.

    Finally, the last layer of ``output_mlp`` is zeroed, so that head
    outputs zeros right after initialization.
    """

    def _fan_in_trunc_normal_(weight: nn.Parameter) -> None:
        # He-style scale from the input fan (dim 1 of a 2-D weight). For a
        # packed in_proj_weight of shape (3E, E) this is embed_dim; for a
        # Linear it is in_features. Clipping at +/-2 std (absolute bounds
        # scaled by std) keeps outliers out of the initial weights.
        std = math.sqrt(2.0 / float(weight.shape[1]))
        nn.init.trunc_normal_(weight, std=std, a=-2 * std, b=2 * std)

    for module in self.modules():
        if isinstance(module, nn.MultiheadAttention):
            if module.in_proj_weight is not None:
                _fan_in_trunc_normal_(module.in_proj_weight)
            else:
                # kdim/vdim != embed_dim: PyTorch registers in_proj_weight as
                # None and keeps separate per-projection weights instead.
                for proj_weight in (
                    module.q_proj_weight,
                    module.k_proj_weight,
                    module.v_proj_weight,
                ):
                    _fan_in_trunc_normal_(proj_weight)
            if module.in_proj_bias is not None:
                nn.init.zeros_(module.in_proj_bias)
        elif isinstance(module, nn.Linear):
            _fan_in_trunc_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # elementwise_affine=False (or bias=False) leaves these as None.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)

    # Zero the final projection so output_mlp emits zeros at initialization;
    # presumably the prediction then starts from whatever baseline
    # _predict_from_tokens combines it with (TODO confirm).
    final = self.output_mlp[-1]
    if isinstance(final, nn.Linear):
        nn.init.zeros_(final.weight)
        if final.bias is not None:
            nn.init.zeros_(final.bias)

forward

forward(state: Tensor, actions: Tensor, action_mask: Tensor) -> Tensor

Return per-action next-state predictions of shape `(B, K, d_state)`.

Source code in geno_lewm/predictor/model.py
def forward(
    self,
    state: Tensor,
    actions: Tensor,
    action_mask: Tensor,
) -> Tensor:
    """Predict the next latent state for each candidate action.

    Args:
        state: Current latent state, checked by ``_validate_inputs`` and
            encoded into a single state token.
        actions: Candidate actions; dimension 1 is read as the action count
            ``K`` when deciding whether to upcast the output MLP.
        action_mask: Mask over the actions. The mask actually forwarded to
            the prediction step is the one returned by ``_validate_inputs``,
            not this raw argument.

    Returns:
        Per-action next-state predictions of shape ``(B, K, d_state)``.
    """
    validated_mask = self._validate_inputs(state, actions, action_mask)
    encoded_state = self._encode_state_token(state)
    encoded_actions = self._encode_forward_actions(actions)

    # NOTE(review): 20 is a hard-coded threshold. If _validate_inputs caps
    # K at self.max_actions (default 16), this flag can never be True with
    # default settings -- confirm whether it should track max_actions.
    num_actions = actions.shape[1]
    needs_upcast = num_actions > 20

    return self._predict_from_tokens(
        state,
        encoded_state,
        encoded_actions,
        validated_mask,
        upcast_output_mlp=needs_upcast,
    )