Skip to content

geno_lewm.action.encoder

encoder

PyTorch action encoder for window-relative genomic edits.

The module keeps PyTorch optional so that importing `geno_lewm.action` does not pull the training stack into the base package. Instantiate `ActionEncoder` only in a `geno-lewm[train]` environment.

SeqMicroEncoder

SeqMicroEncoder(*, d_seq: int)

Bases: Module

Shared 6-mer micro-encoder for reference and alternate bases.

Source code in geno_lewm/action/encoder.py
def __init__(self, *, d_seq: int) -> None:
    """Assemble the token embedding, width projection and transformer stack.

    Args:
        d_seq: Model width of the transformer encoder. It is checked by
            ``_require_positive`` and must be a multiple of 4, the fixed
            number of attention heads used below.

    Raises:
        InputError: If ``d_seq`` is not divisible by 4.
    """
    super().__init__()
    _require_positive("d_seq", d_seq)
    num_heads = 4
    if d_seq % num_heads != 0:
        raise InputError("d_seq must be divisible by 4 so attention heads divide evenly")

    # The embedding table is never wider than the model itself.
    embed_width = min(_TOKEN_EMBED_DIM, d_seq)
    self.token_embedding = nn.Embedding(
        _VOCAB_SIZE,
        embed_width,
        padding_idx=_OOV_TOKEN_ID,
    )

    # Only add a learned projection when the embedding is narrower than d_seq.
    if embed_width == d_seq:
        widen = nn.Identity()
    else:
        widen = nn.Linear(embed_width, d_seq)
    self.token_projection = widen

    layer_options = {
        "d_model": d_seq,
        "nhead": num_heads,
        "dim_feedforward": d_seq,
        "activation": "gelu",
        "batch_first": True,
        "norm_first": False,
        "dropout": 0.0,
    }
    encoder_layer = nn.TransformerEncoderLayer(**layer_options)
    self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

ActionEncoder

ActionEncoder(*, d_action: int = 512, d_pos: int = 128, d_type: int = 64, d_seq: int = 256, max_window_bp: int = 12288, carbon_tokenizer: Any | None = None)

Bases: Module

Encode `RelEdit` objects into learned action embeddings.

Source code in geno_lewm/action/encoder.py
def __init__(
    self,
    *,
    d_action: int = 512,
    d_pos: int = 128,
    d_type: int = 64,
    d_seq: int = 256,
    max_window_bp: int = 12_288,
    carbon_tokenizer: Any | None = None,
) -> None:
    """Validate the configuration and build the encoder's submodules.

    Args:
        d_action: Width of the output of the final projection and of the
            ``padding_embedding`` parameter.
        d_pos: Width of the position features; must be even.
        d_type: Width of the edit-type embedding.
        d_seq: Width passed to ``SeqMicroEncoder``; it contributes twice
            to the projection input (presumably once each for reference
            and alternate bases -- confirm against ``forward``).
        max_window_bp: Stored on the instance unchanged; how it is used
            is not visible in this method.
        carbon_tokenizer: Optional tokenizer object, stored unchanged.

    Raises:
        InputError: If ``d_pos`` is odd.
    """
    super().__init__()

    # Same checks, same order: d_action, d_pos, d_type, max_window_bp.
    checked_sizes = (
        ("d_action", d_action),
        ("d_pos", d_pos),
        ("d_type", d_type),
        ("max_window_bp", max_window_bp),
    )
    for label, size in checked_sizes:
        _require_positive(label, size)
    if d_pos % 2 != 0:
        raise InputError("d_pos must be even for sinusoidal position embeddings")

    self._d_action = d_action
    self.d_pos = d_pos
    self.max_window_bp = max_window_bp
    self.carbon_tokenizer = carbon_tokenizer

    self.type_embedding = nn.Embedding(len(EditType), d_type)
    self.seq_encoder = SeqMicroEncoder(d_seq=d_seq)

    # Projection input: position + type + two sequence encodings.
    concat_width = d_pos + d_type + (2 * d_seq)
    hidden_width = 1024
    stages = [
        nn.Linear(concat_width, hidden_width),
        nn.GELU(),
        nn.LayerNorm(hidden_width),
        nn.Linear(hidden_width, d_action),
    ]
    self.projection = nn.Sequential(*stages)

    # Learned vector of width d_action, initialised to zeros.
    self.padding_embedding = nn.Parameter(torch.zeros(d_action))