geno_lewm.predictor.model
Cross-attention predictor for action-conditioned latent transitions.
PyTorch is an optional dependency: importing geno_lewm.predictor
is lightweight, but instantiating Predictor requires a training
environment with PyTorch installed.
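A minimal sketch of the kind of guard that keeps the import lightweight, assuming torch is imported only when a Predictor is constructed rather than at module top level. The helper names torch_available and require_torch are hypothetical and are not part of geno_lewm.

```python
import importlib
import importlib.util
from types import ModuleType


def torch_available() -> bool:
    """Report whether PyTorch can be imported, without importing it."""
    # find_spec only locates the package; it does not execute torch/__init__.py.
    return importlib.util.find_spec("torch") is not None


def require_torch() -> ModuleType:
    """Import and return torch, or fail with an actionable message."""
    # Hypothetical helper: meant to be called at construction time, not at import.
    if not torch_available():
        raise ImportError(
            "Predictor requires PyTorch; install torch in your training environment."
        )
    return importlib.import_module("torch")
```

Calling a guard like require_torch() inside the constructor, instead of writing `import torch` at the top of the module, is one way an import can succeed on machines without PyTorch while construction still fails with a clear message.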
Predictor
Predictor(*, d_state: int = 1024, d_action: int = 512, d_hidden: int = 768, n_heads: int = 8, n_cross_layers: int = 4, n_self_layers: int = 2, ffn_dim: int = 768, max_actions: int = 16)
Bases: Module
Cross-attention Transformer predictor from RFC-0004.
The defaults keep RFC-0004's topology of 4 cross-attention and
2 self-attention layers and its Carbon-compatible output width
d_state=1024, while using the RFC's target-size variant
d_hidden=768 so that the trainable-parameter count is close to
the documented ~22M target.
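The ~22M figure can be sanity-checked with a back-of-the-envelope count. The sketch below assumes a block layout that RFC-0004 may not match exactly: each of the n_cross_layers + n_self_layers blocks holds one multi-head attention sublayer (four d_hidden-by-d_hidden projections with biases), a two-layer feed-forward sublayer of width ffn_dim, and two LayerNorms. On top of that it assumes linear input projections for the state (d_state to d_hidden) and the actions (d_action to d_hidden), a learned embedding for max_actions action positions, and an output projection back to d_state. The function names and this layout are hypothetical.

```python
def linear_params(d_in: int, d_out: int) -> int:
    """Weights plus bias of a dense layer."""
    return d_in * d_out + d_out


def estimate_params(
    d_state: int = 1024,
    d_action: int = 512,
    d_hidden: int = 768,
    n_cross_layers: int = 4,
    n_self_layers: int = 2,
    ffn_dim: int = 768,
    max_actions: int = 16,
) -> int:
    """Rough trainable-parameter count under an ASSUMED block layout."""
    # Attention: Q, K, V and output projections, each d_hidden -> d_hidden.
    attention = 4 * linear_params(d_hidden, d_hidden)
    # Feed-forward: d_hidden -> ffn_dim -> d_hidden.
    ffn = linear_params(d_hidden, ffn_dim) + linear_params(ffn_dim, d_hidden)
    # Two LayerNorms per block, each with a scale and a shift vector.
    norms = 2 * 2 * d_hidden
    per_block = attention + ffn + norms
    blocks = (n_cross_layers + n_self_layers) * per_block
    # Assumed input/output projections plus learned action-position embedding.
    io = (
        linear_params(d_state, d_hidden)
        + linear_params(d_action, d_hidden)
        + linear_params(d_hidden, d_state)
        + max_actions * d_hidden
    )
    return blocks + io
```

With the defaults this returns 23,260,672 (about 23.3M), the same order as the documented ~22M. It is only a plausibility check and will not match the real module unless its layout matches the one assumed here. n_heads is not an argument because splitting d_hidden across heads does not change the projection sizes.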
Source code in geno_lewm/predictor/model.py
reset_parameters
Initialize layers per RFC-0004 §3.4.
forward
Return per-action next-state predictions of shape (B, K, d_state),
where B is the batch size and K is the number of actions.
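The output contract can be stated as a shape check. This is a minimal sketch under assumptions this page does not confirm: forward is taken to receive a current latent state of shape (B, d_state) and K candidate actions of shape (B, K, d_action), with K bounded by max_actions. The helper predicted_shape is hypothetical and works on shape tuples only, so it runs without PyTorch.

```python
from typing import Tuple


def predicted_shape(
    state_shape: Tuple[int, ...],
    action_shape: Tuple[int, ...],
    d_state: int = 1024,
    d_action: int = 512,
    max_actions: int = 16,
) -> Tuple[int, int, int]:
    """Return the (B, K, d_state) output shape, validating ASSUMED input shapes."""
    if len(state_shape) != 2 or state_shape[1] != d_state:
        raise ValueError(f"state must be (B, {d_state}), got {state_shape}")
    if len(action_shape) != 3 or action_shape[2] != d_action:
        raise ValueError(f"actions must be (B, K, {d_action}), got {action_shape}")
    batch, k = action_shape[0], action_shape[1]
    if batch != state_shape[0]:
        raise ValueError("state and actions must share the batch dimension B")
    if not 1 <= k <= max_actions:
        raise ValueError(f"K must be in [1, {max_actions}], got {k}")
    # One next-state prediction per (example, action) pair.
    return (batch, k, d_state)
```

For example, predicted_shape((8, 1024), (8, 4, 512)) returns (8, 4, 1024): each of the 4 actions in each of the 8 examples gets its own 1024-dimensional next-state prediction.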