geno_lewm.predictor
Action-conditioned predictor modules for GenoLeWM.
ARPredictor
Bases: Module
Inference-time autoregressive rollout over a base Predictor.
The wrapper defines the public RFC-0004 rollout contract: each
action is scored against the state predicted for the previous
action, so a sequence of K actions yields [s_hat[t+1], ..., s_hat[t+K]]. When the
wrapped predictor exposes rollout-cache hooks, static action
projections are encoded once before the autoregressive loop;
otherwise the wrapper falls back to repeated forward calls.
Source code in geno_lewm/predictor/ar.py
rollout
rollout(state: Tensor, action_sequence: Tensor | Sequence[Tensor], action_mask: Tensor | None = None) -> tuple[Tensor, ...]
Return one predicted state per autoregressive action step.
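The rollout contract can be sketched in plain Python. This is a minimal illustration, not the ar.py implementation: it handles a single sequence, uses lists instead of Tensors, omits the rollout-cache fast path, and takes a hypothetical step_fn standing in for one call to the wrapped Predictor.

```python
from typing import Callable, List, Sequence, Tuple

# A state or action is a plain list of floats here; the real module uses Tensors.
Vector = List[float]


def ar_rollout(
    step_fn: Callable[[Vector, Vector], Vector],
    state: Vector,
    actions: Sequence[Vector],
) -> Tuple[Vector, ...]:
    """Feed each prediction back in as the next input state (hypothetical sketch)."""
    predictions: List[Vector] = []
    current = state
    for action in actions:
        # Action k is applied to the state predicted after action k-1.
        current = step_fn(current, action)
        predictions.append(current)
    return tuple(predictions)
```

With a toy additive step, three actions produce three chained predictions, each built on the previous one rather than on the original state.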
rollout_tensor
rollout_tensor(state: Tensor, action_sequence: Tensor | Sequence[Tensor], action_mask: Tensor | None = None) -> Tensor
Return autoregressive rollout as (batch, steps, d_state).
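Assuming rollout_tensor is the per-step tuple from rollout stacked along a new steps axis (in PyTorch terms, torch.stack(steps, dim=1)), the reordering from step-major output to (batch, steps, d_state) looks like this. The helper stack_steps is hypothetical and works on nested lists.

```python
from typing import List, Sequence


def stack_steps(per_step: Sequence[Sequence[List[float]]]) -> List[List[List[float]]]:
    """Reorder step-major predictions into (batch, steps, d_state) nesting.

    per_step[k][b] is the d_state vector predicted at step k for batch item b.
    """
    if not per_step:
        return []
    batch = len(per_step[0])
    if any(len(step) != batch for step in per_step):
        raise ValueError("every step must have the same batch size")
    # Outer index becomes batch, middle index becomes step.
    return [[list(step[b]) for step in per_step] for b in range(batch)]
```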
predict_single
Return the one-step predicted state for action.
predict_trajectory
predict_trajectory(state: Tensor, action_sequence: Tensor | Sequence[Tensor], action_mask: Tensor | None = None) -> tuple[Tensor, ...]
Alias for rollout, matching RFC-0004 terminology.
predict_haplotype
predict_haplotype(state: Tensor, action_sequence: Tensor | Sequence[Tensor], action_mask: Tensor | None = None) -> Tensor
Return the final predicted state after all valid actions.
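The page does not spell out how action_mask is interpreted. A minimal sketch, assuming True marks a valid action and a masked (padding) step carries the previous state forward unchanged, so the result is the state after the last valid action. final_state and step_fn are hypothetical names.

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]


def final_state(
    step_fn: Callable[[Vector, Vector], Vector],
    state: Vector,
    actions: Sequence[Vector],
    mask: Optional[Sequence[bool]] = None,
) -> Vector:
    """Return the state after the last valid action (hypothetical masking rule)."""
    if mask is None:
        mask = [True] * len(actions)
    if len(mask) != len(actions):
        raise ValueError("mask must have one entry per action")
    current = state
    for action, valid in zip(actions, mask):
        if valid:
            current = step_fn(current, action)
        # Masked steps leave `current` unchanged (assumed semantics).
    return current
```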
PredictionLossResult
dataclass
PredictionLossResult(loss: Tensor, pred_loss: Tensor, kl_reg: Tensor, phase: Literal['phase1', 'phase2'])
Phase-aware predictor loss components.
Predictor
Predictor(*, d_state: int = 1024, d_action: int = 512, d_hidden: int = 768, n_heads: int = 8, n_cross_layers: int = 4, n_self_layers: int = 2, ffn_dim: int = 768, max_actions: int = 16)
Bases: Module
Cross-attention Transformer predictor from RFC-0004.
The defaults keep the RFC-0004 topology of 4 cross-attention and
2 self-attention layers and the Carbon-compatible d_state=1024 output,
while using the RFC's target-size d_hidden=768 variant so that the
trainable parameter count is close to the documented ~22M target.
Source code in geno_lewm/predictor/model.py
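The keyword-only defaults can be mirrored in a small config object to check derived sizes. PredictorShape is hypothetical, and the divisibility check is an assumption based on standard multi-head attention (torch.nn.MultiheadAttention requires embed_dim to be divisible by num_heads). With the defaults, each of the 8 heads gets 768 / 8 = 96 dimensions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PredictorShape:
    """Hypothetical mirror of the Predictor keyword-only defaults."""

    d_state: int = 1024
    d_action: int = 512
    d_hidden: int = 768
    n_heads: int = 8
    n_cross_layers: int = 4
    n_self_layers: int = 2
    ffn_dim: int = 768
    max_actions: int = 16

    def __post_init__(self) -> None:
        # Multi-head attention splits d_hidden evenly across heads (assumed here).
        if self.d_hidden % self.n_heads != 0:
            raise ValueError("d_hidden must be divisible by n_heads")

    @property
    def head_dim(self) -> int:
        return self.d_hidden // self.n_heads

    @property
    def n_layers(self) -> int:
        # 4 cross-attention layers followed by 2 self-attention layers by default.
        return self.n_cross_layers + self.n_self_layers
```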
reset_parameters
Initialize layers per RFC-0004 §3.4.
forward
Return per-action next-state predictions of shape (B, K, d_state).
lejepa_kl_regularizer
Return the closed-form KL from empirical state distribution to N(0, I).
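For a Gaussian N(mu, Sigma) in d dimensions, the closed form is KL(N(mu, Sigma) || N(0, I)) = 0.5 * (tr(Sigma) + mu^T mu - d - log det(Sigma)). The sketch below assumes a diagonal Gaussian fitted to the batch with the population (biased) variance and an eps floor on each variance; the real regularizer may use a full covariance or a different estimator. The function name is hypothetical.

```python
import math
from typing import Sequence


def diag_gaussian_kl_to_standard(states: Sequence[Sequence[float]], eps: float = 1e-6) -> float:
    """KL(N(mu, diag(var)) || N(0, I)) for a diagonal Gaussian fitted to `states`."""
    n = len(states)
    if n == 0:
        raise ValueError("need at least one state")
    d = len(states[0])
    kl = 0.0
    for j in range(d):
        column = [row[j] for row in states]
        mu = sum(column) / n
        # Population variance, floored at eps so log() stays finite.
        var = max(sum((x - mu) ** 2 for x in column) / n, eps)
        # Per-dimension term: 0.5 * (var + mu^2 - 1 - log var).
        kl += 0.5 * (var + mu * mu - 1.0 - math.log(var))
    return kl
```

A batch that already has zero mean and unit variance per dimension gives a KL of zero; shifting the mean or shrinking the variance makes it positive.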
Source code in geno_lewm/predictor/losses.py
prediction_loss
prediction_loss(prediction: Tensor, target: Tensor, *, alpha: float = 1.0, beta: float = 0.1, mask: Tensor | None = None, eps: float = 1e-08) -> Tensor
Return RFC-0005 alpha * (1 - cos) + beta * MSE / d_state.
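A plain-Python sketch of this formula for a batch of rows. It assumes that "MSE / d_state" means the summed squared error of a row divided by d_state (the per-dimension mean), that eps floors the product of norms in the cosine, and that mask selects which rows are averaged. Raising on an empty mask is a choice of this sketch only; the function name is hypothetical.

```python
import math
from typing import Optional, Sequence


def prediction_loss_sketch(
    prediction: Sequence[Sequence[float]],
    target: Sequence[Sequence[float]],
    alpha: float = 1.0,
    beta: float = 0.1,
    mask: Optional[Sequence[bool]] = None,
    eps: float = 1e-8,
) -> float:
    """alpha * (1 - cos) + beta * ||p - t||^2 / d_state, averaged over valid rows."""
    if mask is None:
        mask = [True] * len(prediction)
    total, count = 0.0, 0
    for p, t, valid in zip(prediction, target, mask):
        if not valid:
            continue  # masked rows contribute nothing
        d_state = len(p)
        dot = sum(a * b for a, b in zip(p, t))
        norms = math.sqrt(sum(a * a for a in p)) * math.sqrt(sum(b * b for b in t))
        cos = dot / max(norms, eps)
        sq_err = sum((a - b) ** 2 for a, b in zip(p, t))
        total += alpha * (1.0 - cos) + beta * sq_err / d_state
        count += 1
    if count == 0:
        raise ValueError("mask selects no rows")
    return total / count
```

For orthogonal unit vectors in two dimensions, cos = 0 and the squared error is 2, so the default weights give 1.0 + 0.1 * 2 / 2 = 1.1.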
predictor_loss
predictor_loss(prediction: Tensor, target: Tensor, *, phase: Literal['phase1', 'phase2'] = 'phase1', alpha: float = 1.0, beta: float = 0.1, gamma: float = 0.5, mask: Tensor | None = None, regularizer_states: Tensor | None = None, eps: float = 1e-06) -> PredictionLossResult
Return phase-conditional total loss and monitorable components.
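The page does not say which phase applies the KL term. This sketch assumes, hypothetically, that gamma * kl_reg is added only in the phases listed in kl_phases (defaulting to phase1), while kl_reg is still reported for monitoring in every phase. LossParts mirrors PredictionLossResult with floats and is not the real class.

```python
from dataclasses import dataclass
from typing import FrozenSet, Literal

Phase = Literal["phase1", "phase2"]


@dataclass(frozen=True)
class LossParts:
    """Hypothetical float-valued mirror of PredictionLossResult."""

    loss: float
    pred_loss: float
    kl_reg: float
    phase: Phase


def combine_phase_loss(
    pred_loss: float,
    kl_reg: float,
    phase: Phase,
    gamma: float = 0.5,
    # ASSUMPTION: which phases weight the KL term is not confirmed by the source.
    kl_phases: FrozenSet[str] = frozenset({"phase1"}),
) -> LossParts:
    if phase not in ("phase1", "phase2"):
        raise ValueError(f"unknown phase: {phase!r}")
    weight = gamma if phase in kl_phases else 0.0
    # kl_reg is returned either way so it can be logged as a component.
    return LossParts(loss=pred_loss + weight * kl_reg, pred_loss=pred_loss, kl_reg=kl_reg, phase=phase)
```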
build_predictor
Construct the predictor from a GenoLeWMConfig.
Single source of truth shared by training (geno_lewm.training.real)
and the deploy runtime (geno_lewm.deploy.runtime) so a trained
checkpoint always loads back into an identically-shaped predictor. The two
call sites previously built the predictor with different hyperparameters
(d_hidden / n_cross_layers / ffn_dim), so a real exported checkpoint could
not be loaded for scoring. Keep both on this builder.
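The failure mode described above, where two call sites drift apart, can be sketched with a single builder plus a load-time comparison. ToyConfig, build_predictor_kwargs, and mismatched_hparams are hypothetical; the field names follow the Predictor signature, but GenoLeWMConfig's real fields are not shown on this page.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Tuple


@dataclass(frozen=True)
class ToyConfig:
    """Hypothetical stand-in for GenoLeWMConfig (field names are illustrative)."""

    d_state: int = 1024
    d_hidden: int = 768
    n_cross_layers: int = 4
    ffn_dim: int = 768


def build_predictor_kwargs(cfg: ToyConfig) -> Dict[str, Any]:
    """Single place that maps config fields to predictor hyperparameters."""
    return asdict(cfg)


def mismatched_hparams(saved: Dict[str, Any], rebuilt: Dict[str, Any]) -> Dict[str, Tuple[Any, Any]]:
    """Return {name: (saved, rebuilt)} for every hyperparameter that differs."""
    keys = sorted(set(saved) | set(rebuilt))
    return {k: (saved.get(k), rebuilt.get(k)) for k in keys if saved.get(k) != rebuilt.get(k)}


# Training and deploy both call the same builder, so their kwargs cannot drift.
cfg = ToyConfig()
train_kwargs = build_predictor_kwargs(cfg)
deploy_kwargs = build_predictor_kwargs(cfg)
```

A checkpoint saved with a different d_hidden would show up as a one-entry mismatch instead of a shape error deep inside state-dict loading.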