Skip to content

geno_lewm.predictor.ar

ar

Autoregressive rollout wrapper for action-conditioned predictors.

ARPredictor

ARPredictor(predictor: object)

Bases: Module

Inference-time autoregressive rollout over a base Predictor.

The wrapper defines the public RFC-0004 rollout contract: each action is scored against the state predicted by the previous action, producing [s_hat[t+1], ..., s_hat[t+K]]. When the wrapped predictor exposes rollout-cache hooks, static action projections are encoded once before the autoregressive loop; otherwise the wrapper falls back to repeated forward calls.

Source code in geno_lewm/predictor/ar.py
def __init__(self, predictor: object) -> None:
    """Wrap ``predictor`` and mirror its dimension attributes.

    Args:
        predictor: The base one-step predictor. It must be callable and
            expose ``d_state``, ``d_action`` and ``max_actions``.

    Raises:
        InputError: If ``predictor`` is not callable.
    """
    super().__init__()
    if not callable(predictor):
        raise InputError("predictor must be callable")
    self.predictor: Any = predictor
    # Copy the three size attributes from the wrapped predictor, each
    # read through ``_required_positive_int`` (presumably rejecting
    # missing or non-positive values -- confirm against the helper).
    self.d_state, self.d_action, self.max_actions = (
        _required_positive_int(predictor, attribute)
        for attribute in ("d_state", "d_action", "max_actions")
    )

rollout

rollout(state: Tensor, action_sequence: Tensor | Sequence[Tensor], action_mask: Tensor | None = None) -> tuple[Tensor, ...]

Return one predicted state per autoregressive action step.

Source code in geno_lewm/predictor/ar.py
@_INFERENCE_MODE
def rollout(
    self,
    state: Tensor,
    action_sequence: Tensor | Sequence[Tensor],
    action_mask: Tensor | None = None,
) -> tuple[Tensor, ...]:
    """Return the rollout as a tuple with one predicted state per step.

    This is a thin view over :meth:`rollout_tensor`: the stacked result is
    split along its step dimension (``dim=1``).
    """
    stacked = self.rollout_tensor(state, action_sequence, action_mask)
    per_step = stacked.unbind(dim=1)
    return tuple(per_step)

rollout_tensor

rollout_tensor(state: Tensor, action_sequence: Tensor | Sequence[Tensor], action_mask: Tensor | None = None) -> Tensor

Return autoregressive rollout as (batch, steps, d_state).

Source code in geno_lewm/predictor/ar.py
@_INFERENCE_MODE
def rollout_tensor(
    self,
    state: Tensor,
    action_sequence: Tensor | Sequence[Tensor],
    action_mask: Tensor | None = None,
) -> Tensor:
    """Return autoregressive rollout as ``(batch, steps, d_state)``.

    Each step feeds the state produced by the previous step back into the
    wrapped predictor together with the next action.

    Args:
        state: Initial state the rollout starts from.
        action_sequence: Actions to apply, normalised through
            ``_normalize_actions`` before use.
        action_mask: Optional per-step validity mask. When given, inactive
            steps leave the carried state unchanged and are written as
            zeros in the output.

    Returns:
        A tensor of shape ``(batch, steps, d_state)``. Its dtype is
        ``float32`` when there are more than 20 steps, otherwise the dtype
        of ``state``.
    """
    actions = self._normalize_actions(action_sequence)
    self._validate_state(state, actions)

    current = state
    # Every predictor call advances exactly one step, so the per-call mask
    # is a constant all-True ``(batch, 1)`` tensor.
    call_mask = torch.ones((actions.shape[0], 1), dtype=torch.bool, device=actions.device)
    # ``None`` here selects the plain ``self.predictor(...)`` fallback below
    # (presumably the predictor exposes no rollout-cache hooks -- confirm).
    action_tokens = self._cached_action_tokens(actions)
    # The rollout action cache is only built for the unmasked path.
    action_cache = (
        self._cached_rollout_action_cache(action_tokens) if action_mask is None else None
    )
    state_token_bias = self._cached_state_token_bias(state)
    # Rollouts longer than 20 steps use a float32 output buffer and ask the
    # step hooks to upcast (presumably to limit accumulated precision loss
    # over long horizons -- TODO confirm the rationale for the threshold).
    upcast_output_mlp = actions.shape[1] > 20
    output_dtype = torch.float32 if upcast_output_mlp else state.dtype
    outputs = state.new_empty(
        (actions.shape[0], actions.shape[1], self.d_state),
        dtype=output_dtype,
    )
    if action_mask is None:
        if action_tokens is None:
            # Fallback: no cached action tokens, call the predictor directly
            # on a one-step slice of the raw actions.
            for step in range(actions.shape[1]):
                prediction = self.predictor(
                    current,
                    actions[:, step : step + 1, :],
                    call_mask,
                )
                current = prediction[:, 0, :]
                outputs[:, step, :] = current
            return outputs

        # Probe the predictor for optional fast-path hooks. They are tried
        # in order of preference below: state-returning hook first, then
        # the unmasked hook, then the generic masked one-step hook.
        one_step_state = getattr(
            self.predictor,
            "_forward_one_step_unmasked_state_from_action_token",
            None,
        )

        one_step_unmasked = getattr(
            self.predictor,
            "_forward_one_step_unmasked_from_action_token",
            None,
        )
        select_action_cache = getattr(
            self.predictor,
            "_slice_rollout_action_cache",
            None,
        )
        if callable(one_step_state):
            for step in range(actions.shape[1]):
                # A per-step cache slice is only passed when both the cache
                # and the slicing hook are available.
                step_cache = (
                    select_action_cache(action_cache, step)
                    if action_cache is not None and callable(select_action_cache)
                    else None
                )
                # This hook's return value is used directly as the next
                # state (no ``[:, 0, :]`` indexing, unlike the other paths).
                current = one_step_state(
                    current,
                    action_tokens[:, step, :],
                    step_cache,
                    state_token_bias=state_token_bias,
                    upcast_output_mlp=upcast_output_mlp,
                )
                outputs[:, step, :] = current
            return outputs

        if callable(one_step_unmasked):
            for step in range(actions.shape[1]):
                step_cache = (
                    select_action_cache(action_cache, step)
                    if action_cache is not None and callable(select_action_cache)
                    else None
                )
                prediction = one_step_unmasked(
                    current,
                    action_tokens[:, step, :],
                    step_cache,
                    upcast_output_mlp=upcast_output_mlp,
                )
                current = prediction[:, 0, :]
                outputs[:, step, :] = current
            return outputs

        # Last resort for the unmasked path: the generic one-step hook with
        # the constant all-True call mask.
        for step in range(actions.shape[1]):
            prediction = self.predictor._forward_one_step_from_action_token(
                current,
                action_tokens[:, step, :],
                call_mask,
                upcast_output_mlp=upcast_output_mlp,
            )
            current = prediction[:, 0, :]
            outputs[:, step, :] = current
        return outputs

    # Masked path: every step is still evaluated for the whole batch, and the
    # mask decides per row whether the result is kept.
    mask = self._normalize_mask(actions, action_mask)
    for step in range(actions.shape[1]):
        active = mask[:, step].unsqueeze(-1)
        if action_tokens is None:
            prediction = self.predictor(
                current,
                actions[:, step : step + 1, :],
                call_mask,
            )
        else:
            prediction = self.predictor._forward_one_step_from_action_token(
                current,
                action_tokens[:, step, :],
                call_mask,
                upcast_output_mlp=upcast_output_mlp,
            )
        next_state = prediction[:, 0, :]
        # Inactive rows keep their previous state for the next step...
        current = torch.where(active, next_state, current)
        # ...and are recorded as zeros in the output for this step.
        outputs[:, step, :] = torch.where(
            active,
            next_state,
            torch.zeros_like(next_state),
        )
    return outputs

predict_single

predict_single(state: Tensor, action: Tensor) -> Tensor

Return the one-step predicted state for `action`.

Source code in geno_lewm/predictor/ar.py
@_INFERENCE_MODE
def predict_single(self, state: Tensor, action: Tensor) -> Tensor:
    """Predict the state that follows ``state`` after a single ``action``.

    A 2-D ``action`` gets a step axis inserted at dimension 1 before it is
    normalised; any other rank is passed to ``_normalize_actions`` as is.

    Raises:
        InputError: If the normalised action holds more than one step.
    """
    if action.ndim == 2:
        single_step = action.unsqueeze(1)
    else:
        single_step = action
    actions = self._normalize_actions(single_step)
    step_count = actions.shape[1]
    if step_count != 1:
        raise InputError(
            "predict_single expects exactly one action",
            details={"steps": step_count},
        )
    predicted_states = self.rollout(state, actions)
    return predicted_states[0]

predict_trajectory

predict_trajectory(state: Tensor, action_sequence: Tensor | Sequence[Tensor], action_mask: Tensor | None = None) -> tuple[Tensor, ...]

Alias for `rollout`, matching RFC-0004 terminology.

Source code in geno_lewm/predictor/ar.py
@_INFERENCE_MODE
def predict_trajectory(
    self,
    state: Tensor,
    action_sequence: Tensor | Sequence[Tensor],
    action_mask: Tensor | None = None,
) -> tuple[Tensor, ...]:
    """Return the per-step predicted states under the RFC-0004 name.

    Delegates to :meth:`rollout` with the arguments unchanged.
    """
    trajectory = self.rollout(state, action_sequence, action_mask)
    return trajectory

predict_haplotype

predict_haplotype(state: Tensor, action_sequence: Tensor | Sequence[Tensor], action_mask: Tensor | None = None) -> Tensor

Return the final predicted state after all valid actions.

Source code in geno_lewm/predictor/ar.py
@_INFERENCE_MODE
def predict_haplotype(
    self,
    state: Tensor,
    action_sequence: Tensor | Sequence[Tensor],
    action_mask: Tensor | None = None,
) -> Tensor:
    """Return the final predicted state after all valid actions.

    Args:
        state: Initial state the rollout starts from.
        action_sequence: Actions to apply in order.
        action_mask: Optional per-step validity mask. When omitted, every
            step is treated as valid and the last step's state is returned.

    Returns:
        One state per batch row: the prediction at that row's last valid
        step. A row with no valid step returns its initial ``state``
        unchanged (cast to the rollout dtype), because no action was applied.
    """
    actions = self._normalize_actions(action_sequence)
    if action_mask is None:
        return self.rollout_tensor(state, actions)[:, -1, :]
    mask = self._normalize_mask(actions, action_mask)
    trajectory = self.rollout_tensor(state, actions, mask)

    # Locate the last *active* step per row instead of using ``sum - 1``:
    # the count-based index wraps to -1 for fully masked rows and points at
    # an inactive (zero-filled) step when the mask is not a contiguous prefix.
    valid = mask.to(torch.bool)
    step_index = torch.arange(valid.shape[1], device=valid.device)
    last_active = torch.where(
        valid,
        step_index,
        torch.full_like(step_index, -1),
    ).amax(dim=1)
    has_active = (last_active >= 0).unsqueeze(-1)

    rows = torch.arange(actions.shape[0], device=actions.device)
    # Clamp only to keep the gather index valid; rows without any active
    # step are overwritten with the initial state just below.
    final = trajectory[rows, last_active.clamp(min=0)]
    return torch.where(has_active, final, state.to(dtype=final.dtype))