PyTorch action encoder for window-relative genomic edits.
This module keeps PyTorch optional, so importing :mod:`geno_lewm.action`
does not pull the training stack into the base package. Instantiate
``ActionEncoder`` only in a ``geno-lewm[train]`` environment.
SeqMicroEncoder
SeqMicroEncoder(*, d_seq: int)
Bases: Module
Shared 6-mer micro-encoder for reference and alternate bases.
Source code in geno_lewm/action/encoder.py
def __init__(self, *, d_seq: int) -> None:
    """Build the token embedding, width projection and transformer stack.

    Args:
        d_seq: Model width of the micro-encoder. It must pass
            ``_require_positive`` and be a multiple of 4, because the
            transformer layers below use 4 attention heads.

    Raises:
        InputError: If ``d_seq`` is not divisible by 4. ``_require_positive``
            presumably raises for non-positive values (confirm its contract).
    """
    super().__init__()
    _require_positive("d_seq", d_seq)
    # The model width must split evenly across the 4 attention heads.
    if d_seq % 4:
        raise InputError("d_seq must be divisible by 4 so attention heads divide evenly")

    # Tokens are embedded at most _TOKEN_EMBED_DIM wide. The result is lifted
    # to d_seq only when the two widths differ.
    embed_width = min(_TOKEN_EMBED_DIM, d_seq)
    # Per nn.Embedding semantics, the padding_idx row (here the OOV token) is
    # zero-initialised and receives no gradient updates.
    self.token_embedding = nn.Embedding(
        _VOCAB_SIZE, embed_width, padding_idx=_OOV_TOKEN_ID
    )
    if embed_width == d_seq:
        self.token_projection = nn.Identity()
    else:
        self.token_projection = nn.Linear(embed_width, d_seq)

    block = nn.TransformerEncoderLayer(
        d_model=d_seq,
        nhead=4,
        dim_feedforward=d_seq,  # feed-forward kept at model width
        dropout=0.0,
        activation="gelu",
        norm_first=False,  # post-norm residual blocks
        batch_first=True,  # inputs laid out as (batch, seq, feature)
    )
    # nn.TransformerEncoder stacks clones of ``block`` two layers deep.
    self.encoder = nn.TransformerEncoder(block, num_layers=2)
|
ActionEncoder
ActionEncoder(*, d_action: int = 512, d_pos: int = 128, d_type: int = 64, d_seq: int = 256, max_window_bp: int = 12288, carbon_tokenizer: Any | None = None)
Bases: Module
Encode :class:`RelEdit` objects into learned action embeddings.
Source code in geno_lewm/action/encoder.py
def __init__(
    self,
    *,
    d_action: int = 512,
    d_pos: int = 128,
    d_type: int = 64,
    d_seq: int = 256,
    max_window_bp: int = 12_288,
    carbon_tokenizer: Any | None = None,
) -> None:
    """Create the sub-encoders and the fusion MLP for edit actions.

    Args:
        d_action: Width of the final action embedding.
        d_pos: Width of the sinusoidal position features. Must be even.
        d_type: Width of the learned edit-type embedding.
        d_seq: Width forwarded to the shared ``SeqMicroEncoder``. That
            encoder validates it (positive, divisible by 4).
        max_window_bp: Stored on the instance. Presumably the largest
            window-relative coordinate expected (TODO confirm in forward).
        carbon_tokenizer: Stored as-is; how it is used is not visible here.

    Raises:
        InputError: If ``d_pos`` is odd, or (via ``SeqMicroEncoder``) if
            ``d_seq`` is not divisible by 4. ``_require_positive``
            presumably raises for non-positive sizes (confirm its contract).
    """
    super().__init__()
    # Validated in a fixed order, so the first offending argument is reported.
    for label, value in (
        ("d_action", d_action),
        ("d_pos", d_pos),
        ("d_type", d_type),
        ("max_window_bp", max_window_bp),
    ):
        _require_positive(label, value)
    # Sinusoidal position features pair sin/cos channels, so the width is even.
    if d_pos % 2:
        raise InputError("d_pos must be even for sinusoidal position embeddings")

    self._d_action = d_action
    self.d_pos = d_pos
    self.max_window_bp = max_window_bp
    self.carbon_tokenizer = carbon_tokenizer

    self.type_embedding = nn.Embedding(len(EditType), d_type)
    # A single micro-encoder, shared between reference and alternate bases.
    self.seq_encoder = SeqMicroEncoder(d_seq=d_seq)

    # Fused input = position + edit type + two d_seq-wide sequence features
    # (presumably one each for ref and alt; confirm in forward).
    fused_width = d_pos + d_type + 2 * d_seq
    hidden_width = 1024
    self.projection = nn.Sequential(
        nn.Linear(fused_width, hidden_width),
        nn.GELU(),
        nn.LayerNorm(hidden_width),
        nn.Linear(hidden_width, d_action),
    )
    # Learnable d_action-wide vector, zero-initialised. Presumably stands in
    # for padded action slots (TODO confirm against forward).
    self.padding_embedding = nn.Parameter(torch.zeros(d_action))
|