Skip to content

geno_lewm.encoder.pooling

pooling

Pooling strategies for Carbon hidden states.

Defined by RFC-0002 §3.4. This module is intentionally independent of torch so the pooling contract, cache metadata, and downstream schema behavior can be validated before the Carbon runtime wrapper lands.

PoolingResult dataclass

PoolingResult(vector: tuple[float, ...], pool_type: Literal['centered_mean', 'global_mean'], pool_radius: int, untargeted: bool, center_token: int | None, token_count: int)

Pooled state vector plus cache-key metadata.

d_state property

d_state: int

Return the pooled vector width.

as_cache_fields

as_cache_fields() -> Mapping[str, object]

Return fields shared with the window-cache schema.

Source code in geno_lewm/encoder/pooling.py
def as_cache_fields(self) -> Mapping[str, object]:
    """Expose the pooling metadata that the window-cache schema also stores.

    Returns:
        A new dict holding this result's ``pool_type``, ``pool_radius`` and
        ``untargeted`` attributes under keys of the same names.
    """
    shared_fields = ("pool_type", "pool_radius", "untargeted")
    return {field_name: getattr(self, field_name) for field_name in shared_fields}

global_mean

global_mean(hidden_states: Sequence[Sequence[float]]) -> tuple[float, ...]

Mean-pool every token vector in hidden_states.

Source code in geno_lewm/encoder/pooling.py
def global_mean(hidden_states: Sequence[Sequence[float]]) -> tuple[float, ...]:
    """Average all token vectors of ``hidden_states`` into a single vector.

    The input is first passed through ``_coerce_hidden_states``; the rows it
    returns are then averaged by ``_mean_rows``.
    """
    return _mean_rows(_coerce_hidden_states(hidden_states))

centered_mean

centered_mean(hidden_states: Sequence[Sequence[float]], *, center_token: int, pool_radius: int = DEFAULT_POOL_RADIUS_TOKENS) -> tuple[float, ...]

Mean-pool the inclusive token span center_token ± pool_radius.

Source code in geno_lewm/encoder/pooling.py
def centered_mean(
    hidden_states: Sequence[Sequence[float]],
    *,
    center_token: int,
    pool_radius: int = DEFAULT_POOL_RADIUS_TOKENS,
) -> tuple[float, ...]:
    """Average the token vectors in a window around ``center_token``.

    The window is the inclusive index range
    ``[center_token - pool_radius, center_token + pool_radius]``, clipped to
    the bounds of the token sequence so that a center near either edge simply
    pools over fewer tokens.

    Args:
        hidden_states: Per-token vectors; normalised by
            ``_coerce_hidden_states`` before use.
        center_token: Index of the window's middle token; checked by
            ``_validate_center_token`` against the number of tokens.
        pool_radius: Number of tokens taken on each side of the center;
            checked by ``_validate_pool_radius``.

    Returns:
        The mean of the selected rows, as computed by ``_mean_rows``.
    """
    token_rows = _coerce_hidden_states(hidden_states)
    n_tokens = len(token_rows)
    mid = _validate_center_token(center_token, n_tokens)
    half_width = _validate_pool_radius(pool_radius)

    # Clip both ends; the upper bound is exclusive, hence the "+ 1" that
    # keeps the token at ``mid + half_width`` inside the window.
    first = max(0, mid - half_width)
    stop = min(n_tokens, mid + half_width + 1)
    window = token_rows[first:stop]
    return _mean_rows(window)

pool_hidden_states

pool_hidden_states(hidden_states: Sequence[Sequence[float]], *, edit_locus: int | None = None, pool_type: Literal['centered_mean', 'global_mean'] = POOL_CENTERED_MEAN, pool_radius: int = DEFAULT_POOL_RADIUS_TOKENS, token_bp: int = CARBON_TOKEN_BP) -> PoolingResult

Pool token-level hidden states into a state vector.

edit_locus is a 0-based base-pair offset within the encoder window. When it is absent, RFC-0002 requires a global-mean fallback tagged as untargeted=True so cache consumers do not mix arbitrary reference-window embeddings with edit-local embeddings.

Source code in geno_lewm/encoder/pooling.py
def pool_hidden_states(
    hidden_states: Sequence[Sequence[float]],
    *,
    edit_locus: int | None = None,
    pool_type: Literal["centered_mean", "global_mean"] = POOL_CENTERED_MEAN,
    pool_radius: int = DEFAULT_POOL_RADIUS_TOKENS,
    token_bp: int = CARBON_TOKEN_BP,
) -> PoolingResult:
    """Reduce token-level hidden states to one pooled state vector.

    ``edit_locus`` is a 0-based base-pair offset within the encoder window.
    If it is ``None``, RFC-0002 requires falling back to a global mean that
    is tagged ``untargeted=True``, so cache consumers do not mix arbitrary
    reference-window embeddings with edit-local embeddings.

    Args:
        hidden_states: Per-token vectors to pool.
        edit_locus: Base-pair offset of the edit, or ``None`` when there is
            no edit to center on.
        pool_type: Requested strategy. Only honoured when ``edit_locus`` is
            given; otherwise the global-mean fallback is used.
        pool_radius: Tokens on each side of the center token for
            ``centered_mean`` pooling.
        token_bp: Passed to ``_edit_locus_to_token`` to map the base-pair
            offset onto a token index.

    Returns:
        A ``PoolingResult``. Global-mean results (requested or fallback)
        carry ``pool_radius=0`` and ``center_token=None``; centered results
        carry the validated radius and the resolved center token.
    """
    # The validation helpers run in this fixed order no matter which branch
    # is taken below, so the first failing argument is always the same.
    token_rows = _coerce_hidden_states(hidden_states)
    strategy = _validate_pool_type(pool_type)
    window_radius = _validate_pool_radius(pool_radius)
    n_tokens = len(token_rows)

    targeted = edit_locus is not None
    # A supplied locus is converted even if a global mean was requested, so
    # whatever checks the conversion performs still apply to it.
    center = (
        _edit_locus_to_token(edit_locus, token_count=n_tokens, token_bp=token_bp)
        if targeted
        else None
    )

    if not targeted or strategy == POOL_GLOBAL_MEAN:
        # Fallback (no locus) and explicit global mean differ only in the
        # ``untargeted`` flag.
        return PoolingResult(
            vector=_mean_rows(token_rows),
            pool_type=POOL_GLOBAL_MEAN,
            pool_radius=0,
            untargeted=not targeted,
            center_token=None,
            token_count=n_tokens,
        )

    pooled = centered_mean(token_rows, center_token=center, pool_radius=window_radius)
    return PoolingResult(
        vector=pooled,
        pool_type=POOL_CENTERED_MEAN,
        pool_radius=window_radius,
        untargeted=False,
        center_token=center,
        token_count=n_tokens,
    )