
geno_lewm.action


Action representation for GenoLeWM.

Public surface defined by RFC-0003. The package ships the canonical edit types, pure-Python apply functions, synthetic samplers, and the optional PyTorch action encoder.

ActionEncoder

ActionEncoder(*, d_action: int = 512, d_pos: int = 128, d_type: int = 64, d_seq: int = 256, max_window_bp: int = 12288, carbon_tokenizer: Any | None = None)

Bases: Module

Encode RelEdit objects into learned action embeddings.

Source code in geno_lewm/action/encoder.py
def __init__(
    self,
    *,
    d_action: int = 512,
    d_pos: int = 128,
    d_type: int = 64,
    d_seq: int = 256,
    max_window_bp: int = 12_288,
    carbon_tokenizer: Any | None = None,
) -> None:
    super().__init__()
    _require_positive("d_action", d_action)
    _require_positive("d_pos", d_pos)
    _require_positive("d_type", d_type)
    _require_positive("d_seq", d_seq)
    _require_positive("max_window_bp", max_window_bp)
    if d_pos % 2 != 0:
        raise InputError("d_pos must be even for sinusoidal position embeddings")
    self._d_action = d_action
    self.d_pos = d_pos
    self.max_window_bp = max_window_bp
    self.carbon_tokenizer = carbon_tokenizer
    self.type_embedding = nn.Embedding(len(EditType), d_type)
    self.seq_encoder = SeqMicroEncoder(d_seq=d_seq)
    projection_in = d_pos + d_type + (2 * d_seq)
    self.projection = nn.Sequential(
        nn.Linear(projection_in, 1024),
        nn.GELU(),
        nn.LayerNorm(1024),
        nn.Linear(1024, d_action),
    )
    self.padding_embedding = nn.Parameter(torch.zeros(d_action))
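The evenness check on d_pos exists because sinusoidal embeddings come in sin/cos pairs, one pair per frequency. Below is a minimal sketch of the standard Transformer-style construction in plain Python. The helper name, the interleaved sin/cos layout and the base of 10,000 are assumptions; the encoder's actual position code is not shown on this page.

```python
import math


def sinusoidal_position_embedding(rel_pos: int, d_pos: int, base: float = 10_000.0) -> list[float]:
    """Hypothetical helper: map a window-relative offset to a d_pos-dim vector."""
    if d_pos % 2 != 0:
        raise ValueError("d_pos must be even for sinusoidal position embeddings")
    half = d_pos // 2
    emb: list[float] = []
    for i in range(half):
        # One frequency per (sin, cos) pair; frequencies fall off geometrically with i.
        freq = 1.0 / (base ** (i / half))
        emb.append(math.sin(rel_pos * freq))
        emb.append(math.cos(rel_pos * freq))
    return emb
```

For example, `sinusoidal_position_embedding(37, 128)` returns 128 floats. With the constructor defaults, the concatenated projection input is d_pos + d_type + 2 * d_seq = 128 + 64 + 512 = 704 features, which `nn.Linear(704, 1024)` maps into the hidden layer before the final projection to d_action = 512.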

EditSpec dataclass

EditSpec(chrom: str, pos: int, ref: str, alt: str, edit_type: EditType = EditType.SNV)

A canonical, frozen genomic edit (RFC-0003 §3.1).

Construct with absolute VCF-style coordinates; the derived edit_type is filled in by __post_init__.

pos is 1-based per VCF convention. Both ref and alt are explicit base strings; symbolic alleles such as <DEL> and <INS> are not supported and are deferred to v2.

relative_to

relative_to(window_start_bp: int, window_end_bp: int) -> RelEdit

Return the window-relative form (RFC-0003 §3.3).

window_start_bp and window_end_bp are 0-based inclusive coordinates on the same chromosome as chrom. The predictor sees only the relative offset; absolute coordinates never enter the model.

Source code in geno_lewm/action/spec.py
def relative_to(self, window_start_bp: int, window_end_bp: int) -> RelEdit:
    """Return the window-relative form (RFC-0003 §3.3).

    ``window_start_bp`` and ``window_end_bp`` are 0-based inclusive
    coordinates on the same chromosome as :attr:`chrom`. The
    predictor sees only the relative offset; absolute coordinates
    never enter the model.
    """
    if window_end_bp < window_start_bp:
        raise InvalidEditError(
            "window_end_bp must be >= window_start_bp",
            details={"start": window_start_bp, "end": window_end_bp},
        )
    rel_pos = self.pos - 1 - window_start_bp  # convert 1-based VCF → 0-based offset
    if rel_pos < 0 or rel_pos + len(self.ref) > (window_end_bp - window_start_bp + 1):
        raise OutOfWindowError(
            "edit falls outside the window",
            details={
                "pos": self.pos,
                "ref_len": len(self.ref),
                "window_start_bp": window_start_bp,
                "window_end_bp": window_end_bp,
                "rel_pos": rel_pos,
            },
            remediation="re-center the encoder window over the edit, or skip the edit",
        )
    return RelEdit(
        rel_pos=rel_pos,
        edit_type=self.edit_type,
        ref_bases=self.ref,
        alt_bases=self.alt,
    )
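The conversion in relative_to mixes a 1-based VCF position with 0-based inclusive window bounds. The sketch below mirrors that arithmetic in isolation. RelEditSketch and relative_to_sketch are hypothetical stand-ins, and the built-in ValueError replaces InvalidEditError and OutOfWindowError.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RelEditSketch:
    """Hypothetical stand-in for RelEdit, for illustration only."""
    rel_pos: int
    ref_bases: str
    alt_bases: str


def relative_to_sketch(pos: int, ref: str, alt: str,
                       window_start_bp: int, window_end_bp: int) -> RelEditSketch:
    # pos is 1-based (VCF); the window bounds are 0-based and inclusive.
    if window_end_bp < window_start_bp:
        raise ValueError("window_end_bp must be >= window_start_bp")
    rel_pos = pos - 1 - window_start_bp
    window_len = window_end_bp - window_start_bp + 1  # +1 because both bounds are inclusive
    # The whole ref span must fit inside the window.
    if rel_pos < 0 or rel_pos + len(ref) > window_len:
        raise ValueError("edit falls outside the window")
    return RelEditSketch(rel_pos, ref, alt)
```

For a 12,288 bp window spanning 0-based [1000, 13287], VCF position 1001 maps to offset 0 and position 13288 maps to offset 12287 (the last base); position 13289 is rejected.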

EditType

Bases: IntEnum

The six v1 edit categories (RFC-0003 §3.2).

Members are deterministic functions of (len(ref), len(alt)) — callers do not pass this value; it is computed during construction.
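Because every category is a function of (len(ref), len(alt)), classification needs no base-by-base comparison. The sketch below covers only the four categories this page names (SNV, MNV, INS, DEL, as produced by the synthetic samplers). EditKind, its numeric values, and the OTHER bucket standing in for the two categories not documented here are all hypothetical.

```python
from enum import IntEnum


class EditKind(IntEnum):
    """Hypothetical sketch; numeric values are arbitrary."""
    SNV = 0
    MNV = 1
    INS = 2
    DEL = 3
    OTHER = 4  # placeholder for the two v1 categories not named on this page


def classify(ref: str, alt: str) -> EditKind:
    r, a = len(ref), len(alt)
    if r == 0 or a == 0:
        raise ValueError("ref and alt must be non-empty base strings")
    if r == a:
        # Length-preserving: one base is an SNV, several bases an MNV.
        return EditKind.SNV if r == 1 else EditKind.MNV
    if r == 1:
        # Anchor base in ref, anchor plus inserted bases in alt.
        return EditKind.INS
    if a == 1:
        # Anchor plus deleted bases in ref, anchor alone in alt.
        return EditKind.DEL
    return EditKind.OTHER
```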

RelEdit dataclass

RelEdit(rel_pos: int, edit_type: EditType, ref_bases: str, alt_bases: str)

Window-relative form consumed by the action encoder.

apply_edit

apply_edit(window: str, edit: RelEdit, *, preserve_length: bool = False) -> str

Return window with edit applied.

window is the pre-edit base string (uppercase ACGTN). The function does not validate window contents beyond what the edit locus requires; that is the caller's responsibility.

The reference bases at the edit locus must match edit.ref_bases case-insensitively; otherwise WindowMismatchError is raised with the locus context attached.

Pass preserve_length=True to truncate / pad the result back to the original window length on the side opposite the edit. The default leaves the indel length change intact (length-preserving is the trainer's responsibility for s_{t+1} encoding).

Source code in geno_lewm/action/apply.py
def apply_edit(window: str, edit: RelEdit, *, preserve_length: bool = False) -> str:
    """Return ``window`` with ``edit`` applied.

    ``window`` is the pre-edit base string (uppercase ACGTN). The
    function does not validate window contents beyond what the edit
    locus requires; that is the caller's responsibility.

    The reference bases at the edit locus must match ``edit.ref_bases``
    case-insensitively — otherwise :class:`WindowMismatchError` is
    raised with the locus context attached.

    Pass ``preserve_length=True`` to truncate / pad the result back to
    the original window length on the side opposite the edit. The
    default leaves the indel length change intact (length-preserving
    is the trainer's responsibility for ``s_{t+1}`` encoding).
    """
    original_len = len(window)
    end = edit.rel_pos + len(edit.ref_bases)
    if edit.rel_pos < 0 or end > original_len:
        raise OutOfWindowError(
            "edit locus is outside the window",
            details={
                "rel_pos": edit.rel_pos,
                "ref_len": len(edit.ref_bases),
                "window_len": original_len,
            },
        )

    observed = window[edit.rel_pos : end]
    if observed.upper() != edit.ref_bases.upper():
        raise WindowMismatchError(
            "window bases do not match edit.ref_bases at locus",
            details={
                "rel_pos": edit.rel_pos,
                "expected_ref": edit.ref_bases,
                "observed_ref": observed,
            },
            remediation="re-fetch the window, or correct the EditSpec.ref",
        )

    edited = window[: edit.rel_pos] + edit.alt_bases + window[end:]

    if not preserve_length:
        return edited

    return _truncate_or_pad(edited, original_len, edit_locus=edit.rel_pos)
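The core of apply_edit is a single string splice guarded by a locus check. The sketch below reproduces the splice and pairs it with one possible reading of the preserve_length behaviour. _truncate_or_pad is not shown on this page, so the rule used here (trim or pad the end farther from the edit, padding with N) is an assumption, and splice and truncate_or_pad are hypothetical names.

```python
def splice(window: str, rel_pos: int, ref: str, alt: str) -> str:
    end = rel_pos + len(ref)
    if rel_pos < 0 or end > len(window):
        raise ValueError("edit locus is outside the window")
    # Case-insensitive reference check, as in apply_edit.
    if window[rel_pos:end].upper() != ref.upper():
        raise ValueError("window bases do not match ref at locus")
    return window[:rel_pos] + alt + window[end:]


def truncate_or_pad(seq: str, target_len: int, edit_locus: int, pad: str = "N") -> str:
    # Hypothetical rule: an edit in the left half is compensated at the right end,
    # and vice versa.
    trim_right = edit_locus < target_len / 2
    if len(seq) > target_len:
        extra = len(seq) - target_len
        return seq[:target_len] if trim_right else seq[extra:]
    missing = target_len - len(seq)
    return seq + pad * missing if trim_right else pad * missing + seq
```

A 2 bp insertion near the left edge grows the window by two bases, which this rule trims from the right; a 2 bp deletion near the right edge is padded back on the left.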

apply_edits

apply_edits(window: str, edits: Sequence[RelEdit], *, preserve_length: bool = False) -> str

Apply a sequence of edits to window.

The edits are sorted by descending rel_pos and applied in that order (INV-ARCH-4). Edits must not overlap in genomic coordinates; overlap raises OverlappingEditsError.

The same set of edits supplied in any order produces the same output: the internal sort makes the function order-invariant, which is the property the training pipeline relies on.

The preserve_length flag truncates / pads back to the input window length using the position of the first (left-most) edit as the reference locus, so the side opposite the edit cluster is the one trimmed.

Source code in geno_lewm/action/apply.py
def apply_edits(
    window: str,
    edits: Sequence[RelEdit],
    *,
    preserve_length: bool = False,
) -> str:
    """Apply a sequence of edits to ``window``.

    The edits are sorted by descending ``rel_pos`` and applied in that
    order (INV-ARCH-4). Edits must not overlap in genomic coordinates;
    overlap raises :class:`OverlappingEditsError`.

    Equivalent inputs (same set of edits in any caller-supplied order)
    produce equivalent outputs — the function is order-invariant after
    the internal sort, which is the property the training pipeline
    relies on.

    The ``preserve_length`` flag truncates / pads back to the input
    window length using the position of the **first** (left-most)
    edit as the reference locus, so the side opposite the edit cluster
    is the one trimmed.
    """
    if not edits:
        return window

    _assert_disjoint(edits)

    # Apply right-to-left so the offsets of edits further left stay valid
    # after an indel changes the length. Inner calls use
    # preserve_length=False so truncation happens once, at the end.
    ordered = sorted(edits, key=lambda e: e.rel_pos, reverse=True)
    out = window
    for edit in ordered:
        out = apply_edit(out, edit, preserve_length=False)

    if not preserve_length:
        return out

    leftmost = min(e.rel_pos for e in edits)
    return _truncate_or_pad(out, len(window), edit_locus=leftmost)
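Sorting by descending rel_pos is what lets apply_edits splice indels without recomputing offsets: a splice changes only bases at or after its own locus, so every edit still to be applied (all of which lie further left) keeps a valid offset. A minimal sketch under that reading, with edits as hypothetical (rel_pos, ref, alt) tuples and assert_disjoint as a hypothetical stand-in for the unshown _assert_disjoint:

```python
def assert_disjoint(edits: list[tuple[int, str, str]]) -> None:
    # Ref spans are half-open [rel_pos, rel_pos + len(ref)); neighbours must not overlap.
    spans = sorted((pos, pos + len(ref)) for pos, ref, _ in edits)
    for (_, prev_end), (start, _) in zip(spans, spans[1:]):
        if start < prev_end:
            raise ValueError("overlapping edits")


def apply_edits_sketch(window: str, edits: list[tuple[int, str, str]]) -> str:
    assert_disjoint(edits)
    out = window
    # Descending rel_pos: each splice leaves everything left of its locus untouched.
    for pos, ref, alt in sorted(edits, reverse=True):
        end = pos + len(ref)
        if out[pos:end].upper() != ref.upper():
            raise ValueError("window bases do not match ref at locus")
        out = out[:pos] + alt + out[end:]
    return out
```

With window "AAAACCCCGGGG", an insertion (1, "A", "ATT") and an SNV (8, "G", "T") give "AATTAACCCCTGGG" in either input order. Applying them left to right instead would find C at offset 8 after the insertion and trip the mismatch check.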

indel

indel(window: str, n: int, *, rng: Random, length_dist: Mapping[int, float] | Sequence[float] | None = None, type_mix: tuple[float, float] = (0.5, 0.5), edge_margin: int = DEFAULT_EDGE_MARGIN) -> list[RelEdit]

Sample n indels (INS or DEL).

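length_dist is the distribution of the event length (the number of bases inserted or deleted, excluding the VCF anchor base). The default is a truncated geometric distribution over [1, V1_MAX_LEN-1].

type_mix is (p_ins, p_del); the weights are normalized internally, so they need only be non-negative with a positive sum. The default is 50/50. A deletion that would cross the right edge margin is emitted as an insertion of the same length instead, so the realized INS share can exceed p_ins.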

Source code in geno_lewm/action/synthetic.py
def indel(
    window: str,
    n: int,
    *,
    rng: random.Random,
    length_dist: Mapping[int, float] | Sequence[float] | None = None,
    type_mix: tuple[float, float] = (0.5, 0.5),
    edge_margin: int = DEFAULT_EDGE_MARGIN,
) -> list[RelEdit]:
    """Sample ``n`` indels (INS or DEL).

    ``length_dist`` is the *event* length (number of bases inserted or
    deleted, exclusive of the VCF anchor base). Default is a truncated
    geometric over ``[1, V1_MAX_LEN-1]``.

    ``type_mix`` is ``(p_ins, p_del)``. Default 50/50.
    """
    _validate_window(window, edge_margin)
    if n < 0:
        raise InputError("n must be non-negative", details={"n": n})
    if any(p < 0 for p in type_mix) or sum(type_mix) <= 0:
        raise InputError(
            "type_mix must contain non-negative probs that sum > 0",
            details={"type_mix": list(type_mix)},
        )

    p_ins = type_mix[0] / sum(type_mix)

    out: list[RelEdit] = []
    # Each requested indel resamples on a non-ACGT anchor or an N-containing
    # deletion segment so the sampler reliably returns ``n`` edits on windows
    # with occasional N bases (e.g. the Carbon pretraining corpus), matching
    # uniform_snv. Without this, a single N hit dropped a slot and returned
    # fewer than ``n`` edits, which the data builder treats as a hard error for
    # sources (synthetic_indel) that have no fallback. Bound the total attempts
    # so a pathological all-N window fails loudly instead of looping forever.
    # On all-ACGT windows every attempt succeeds first try, so the draw sequence
    # (and output) is identical to a plain ``for _ in range(n)`` loop.
    max_attempts = n * 16 + 16
    attempts = 0
    while len(out) < n and attempts < max_attempts:
        attempts += 1
        pos = _pick_position(rng, len(window), edge_margin)
        ref_anchor = window[pos]
        if ref_anchor not in _OTHER_BASE:
            continue  # non-ACGT anchor; resample
        # Event length in [1, V1_MAX_LEN-1] so total ref or alt length ≤ V1_MAX_LEN.
        # We respect the caller's distribution but clip to V1_MAX_LEN-1.
        ev_len = min(_draw_indel_length(rng, length_dist), V1_MAX_LEN - 1)

        if rng.random() < p_ins:
            # Insertion: ref = anchor, alt = anchor + ev_len random bases.
            inserted = _rand_bases(rng, ev_len)
            out.append(
                RelEdit(
                    rel_pos=pos,
                    edit_type=EditType.INS,
                    ref_bases=ref_anchor,
                    alt_bases=ref_anchor + inserted,
                )
            )
            continue

        # Deletion: ref = anchor + ev_len following bases, alt = anchor.
        end = pos + 1 + ev_len
        if end > len(window) - edge_margin:
            # Cannot fit deletion without crossing right margin; emit INS instead.
            inserted = _rand_bases(rng, ev_len)
            out.append(
                RelEdit(
                    rel_pos=pos,
                    edit_type=EditType.INS,
                    ref_bases=ref_anchor,
                    alt_bases=ref_anchor + inserted,
                )
            )
            continue
        ref_seg = window[pos:end]
        # Resample when the ref segment contains N's (cannot build a valid RelEdit).
        if any(c not in _OTHER_BASE for c in ref_seg):
            continue
        out.append(
            RelEdit(
                rel_pos=pos,
                edit_type=EditType.DEL,
                ref_bases=ref_seg,
                alt_bases=ref_anchor,
            )
        )
    if len(out) < n:
        raise InputError(
            "could not sample enough indels in the window's interior (too many N bases)",
            details={
                "requested": n,
                "produced": len(out),
                "window_len": len(window),
                "edge_margin": edge_margin,
            },
        )
    return out
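The default length_dist is described as a truncated geometric over [1, V1_MAX_LEN-1]. One way to draw from such a distribution with the standard library is to weight each allowed length by its geometric probability and call random.Random.choices. The success probability p, the placeholder value 50 for V1_MAX_LEN and the helper names are assumptions, because neither _draw_indel_length nor the constant is shown here. draw_indel_kind mirrors how the sampler normalizes type_mix.

```python
import random


def truncated_geometric(rng: random.Random, p: float, lo: int, hi: int) -> int:
    # P(k) proportional to (1 - p) ** (k - 1) * p, restricted to lo <= k <= hi.
    if not 0.0 < p < 1.0 or lo > hi:
        raise ValueError("need 0 < p < 1 and lo <= hi")
    support = list(range(lo, hi + 1))
    weights = [(1.0 - p) ** (k - 1) * p for k in support]
    return rng.choices(support, weights=weights, k=1)[0]


def draw_indel_kind(rng: random.Random, type_mix: tuple[float, float] = (0.5, 0.5)) -> str:
    # Normalize (p_ins, p_del) the same way indel does before drawing.
    p_ins = type_mix[0] / sum(type_mix)
    return "INS" if rng.random() < p_ins else "DEL"
```

With p = 0.3 and V1_MAX_LEN = 50 (placeholder values), `truncated_geometric(rng, 0.3, 1, 49)` favours short events: length 1 is roughly four times as likely as length 5.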

mnv

mnv(window: str, n: int, *, rng: Random, length_dist: Mapping[int, float] | Sequence[float] | None = None, edge_margin: int = DEFAULT_EDGE_MARGIN) -> list[RelEdit]

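Sample up to n MNVs (length-preserving multi-base substitutions).

The length is drawn from length_dist (default: uniform over [2, 8], per the RFC text) and clipped to [2, V1_MAX_LEN]. The alt is guaranteed to differ from ref at every base; otherwise EditSpec validation could reject the resulting ref/alt pair. Unlike indel, this sampler does not resample: a draw whose span would cross the right edge margin, or that covers a non-ACGT base, is skipped, so the result can contain fewer than n edits.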

Source code in geno_lewm/action/synthetic.py
def mnv(
    window: str,
    n: int,
    *,
    rng: random.Random,
    length_dist: Mapping[int, float] | Sequence[float] | None = None,
    edge_margin: int = DEFAULT_EDGE_MARGIN,
) -> list[RelEdit]:
    """Sample ``n`` MNVs (length-preserving multi-base substitutions).

    Length is drawn from ``length_dist`` (default uniform over [2, 8]
    per RFC text). The alt is guaranteed different from ref at every
    base (otherwise constructing a RelEdit with that ref/alt would be
    rejected by EditSpec validation).
    """
    _validate_window(window, edge_margin)
    if n < 0:
        raise InputError("n must be non-negative", details={"n": n})

    if length_dist is None:
        length_dist = dict.fromkeys(range(2, 9), 1.0)  # uniform on [2, 8]

    out: list[RelEdit] = []
    for _ in range(n):
        pos = _pick_position(rng, len(window), edge_margin)
        length = max(2, min(_draw_indel_length(rng, length_dist), V1_MAX_LEN))
        end = pos + length
        if end > len(window) - edge_margin:
            continue
        ref_seg = window[pos:end]
        if any(c not in _OTHER_BASE for c in ref_seg):
            continue
        # Build alt by perturbing every base to a non-self draw.
        alt_chars = [rng.choice(_OTHER_BASE[c]) for c in ref_seg]
        alt_seg = "".join(alt_chars)
        if alt_seg == ref_seg:
            continue  # unreachable: every alt base differs from its ref base; defensive check
        out.append(
            RelEdit(
                rel_pos=pos,
                edit_type=EditType.MNV,
                ref_bases=ref_seg,
                alt_bases=alt_seg,
            )
        )
    return out
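mnv guarantees a per-base change by drawing each alt base from the bases other than the reference base. The sketch below isolates that step. OTHER_BASE is a hypothetical stand-in for the module's _OTHER_BASE table, assumed to map each of A, C, G, T to the other three.

```python
import random

# Hypothetical stand-in for _OTHER_BASE: each ACGT base maps to the three others.
OTHER_BASE = {"A": "CGT", "C": "AGT", "G": "ACT", "T": "ACG"}


def perturb_every_base(ref_seg: str, rng: random.Random) -> str:
    if any(c not in OTHER_BASE for c in ref_seg):
        raise ValueError("segment contains a non-ACGT base")
    # rng.choice never returns the reference base, so alt differs at every position.
    return "".join(rng.choice(OTHER_BASE[c]) for c in ref_seg)
```

Because no draw can return the reference base, the `alt_seg == ref_seg` branch in mnv cannot fire for an ACGT segment.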

uniform_snv

uniform_snv(window: str, n: int, *, rng: Random, edge_margin: int = DEFAULT_EDGE_MARGIN) -> list[RelEdit]

Sample n uniform SNVs anchored inside window.

Each SNV's alt is uniformly drawn from the three non-reference bases at the chosen position, so the contract "alt is always non-reference" is enforced by construction.

Returns edits in the order they were sampled. The list may contain duplicates by position — the caller (data pipeline) is responsible for deduplication if it needs disjoint edits.

Source code in geno_lewm/action/synthetic.py
def uniform_snv(
    window: str,
    n: int,
    *,
    rng: random.Random,
    edge_margin: int = DEFAULT_EDGE_MARGIN,
) -> list[RelEdit]:
    """Sample ``n`` uniform SNVs anchored inside ``window``.

    Each SNV's ``alt`` is uniformly drawn from the three non-reference
    bases at the chosen position, so the contract "alt is always
    non-reference" is enforced by construction.

    Returns edits in the order they were sampled. The list may contain
    duplicates by position — the caller (data pipeline) is responsible
    for deduplication if it needs disjoint edits.
    """
    _validate_window(window, edge_margin)
    if n < 0:
        raise InputError("n must be non-negative", details={"n": n})

    out: list[RelEdit] = []
    for _ in range(n):
        pos = _pick_position(rng, len(window), edge_margin)
        ref = window[pos]
        if ref not in _OTHER_BASE:
            # Window contains 'N' or other non-ACGT at this position; resample.
            # Simple bounded retry; if window is mostly N's the caller
            # should not be using a synthetic sampler.
            for _retry in range(10):
                pos = _pick_position(rng, len(window), edge_margin)
                ref = window[pos]
                if ref in _OTHER_BASE:
                    break
            else:  # pragma: no cover - defensive
                raise InputError(
                    "could not find an ACGT position in the window's interior",
                    details={"window_len": len(window), "edge_margin": edge_margin},
                )
        alt = rng.choice(_OTHER_BASE[ref])
        out.append(RelEdit(rel_pos=pos, edit_type=EditType.SNV, ref_bases=ref, alt_bases=alt))
    return out
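uniform_snv leaves deduplication to the caller. A minimal sketch of that caller-side step, keeping the first edit sampled at each position; Snv and dedupe_by_position are hypothetical stand-ins, not part of geno_lewm.action.

```python
from typing import NamedTuple


class Snv(NamedTuple):
    """Hypothetical stand-in for an SNV RelEdit."""
    rel_pos: int
    ref_bases: str
    alt_bases: str


def dedupe_by_position(edits: list[Snv]) -> list[Snv]:
    # Keep the first edit sampled at each position, preserving sampling order.
    seen: set[int] = set()
    out: list[Snv] = []
    for e in edits:
        if e.rel_pos not in seen:
            seen.add(e.rel_pos)
            out.append(e)
    return out
```

Since every SNV covers exactly one base, edits at distinct positions are disjoint, which is what the overlap check in apply_edits requires. Deduplication can leave fewer than n edits, so a caller that needs exactly n would have to sample again.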