Skip to content

geno_lewm.data

data

Data pipeline helpers for GenoLeWM.

DEFAULT_EDIT_SOURCE_COUNTS module-attribute

DEFAULT_EDIT_SOURCE_COUNTS: tuple[EditSourceCount, ...] = (EditSourceCount(SOURCE_GNOMAD_COMMON, 3), EditSourceCount(SOURCE_SYNTHETIC_SNV, 3), EditSourceCount(SOURCE_SYNTHETIC_INDEL, 1), EditSourceCount(SOURCE_CLINVAR, 1))

RFC-0006 §3.3 per-window source allocation for N_edits = 8.

DEFAULT_SOURCE_FALLBACKS module-attribute

DEFAULT_SOURCE_FALLBACKS: dict[str, str] = {SOURCE_CLINVAR: SOURCE_SYNTHETIC_SNV, SOURCE_GNOMAD_COMMON: SOURCE_SYNTHETIC_SNV}

Default fallback when an absolute VCF edit is unavailable for a window.

ClinVar hard-negatives and gnomAD common variants are placed (absolute) sources: they only apply to windows that carry genome coordinates. On unplaced windows (the synthetic Carbon pretraining corpus) the absolute providers yield nothing and the builder draws synthetic SNVs instead, so pretraining-corpus windows still produce full edit tuples.

EditSourceCount dataclass

EditSourceCount(source: str, count: int)

Number of edits to draw from one RFC-0006 source per window.

GenoLeWMDataset

GenoLeWMDataset(windows: _WindowSource, providers: Mapping[str, _EditProvider], *, seed: int, mix: Sequence[EditSourceCount] = DEFAULT_EDIT_SOURCE_COUNTS, holdouts: HoldoutPolicy | None = None, fallback_sources: Mapping[str, str] | None = DEFAULT_SOURCE_FALLBACKS, preserve_length: bool = True)

Bases: _load_iterable_dataset_base()

Deterministic iterable dataset over windows and edit-source providers.

The class subclasses torch.utils.data.IterableDataset when torch is installed, but falls back to a plain Python iterable in core/dev environments. That keeps the data contract testable without pulling in the full training extra.

Source code in geno_lewm/data/builder.py
def __init__(
    self,
    windows: _WindowSource,
    providers: Mapping[str, _EditProvider],
    *,
    seed: int,
    mix: Sequence[EditSourceCount] = DEFAULT_EDIT_SOURCE_COUNTS,
    holdouts: HoldoutPolicy | None = None,
    fallback_sources: Mapping[str, str] | None = DEFAULT_SOURCE_FALLBACKS,
    preserve_length: bool = True,
) -> None:
    """Capture the window source, edit providers, and sampling configuration.

    Args:
        windows: Source of ``WindowContext`` values, re-read on every iteration.
        providers: Mapping from edit-source name to provider; must be non-empty.
            A shallow copy is stored.
        seed: Base seed for the per-worker RNG; validated as a non-negative int.
        mix: Per-window edit allocation, normalized once here.
        holdouts: Optional holdout policy forwarded to tuple building.
        fallback_sources: Source-to-source fallbacks; ``None`` stores an empty
            mapping (no fallbacks).
        preserve_length: Forwarded to tuple construction.

    Raises:
        InputError: If ``providers`` is empty.
    """
    _require_nonnegative_int("seed", seed)
    if not providers:
        raise InputError("providers must contain at least one edit source")
    # Build the derived values in the same order as they are validated, then
    # publish them on the instance.
    provider_table = dict(providers)
    normalized_mix = _normalize_mix(mix)
    fallback_table = dict(fallback_sources or {})
    self.windows = windows
    self.providers = provider_table
    self.seed = seed
    self.mix = normalized_mix
    self.holdouts = holdouts
    self.fallback_sources = fallback_table
    self.preserve_length = preserve_length

__iter__

__iter__() -> Iterator[TrainingTuple]

Yield training tuples suitable for a PyTorch DataLoader.

Source code in geno_lewm/data/builder.py
def __iter__(self) -> Iterator[TrainingTuple]:
    """Yield only the training tuples, suitable for a PyTorch DataLoader.

    This is ``iter_with_source_windows`` with each item's source window
    dropped.
    """
    yield from (entry.training_tuple for entry in self.iter_with_source_windows())

iter_with_source_windows

iter_with_source_windows() -> Iterator[TrainingDatasetItem]

Yield tuples together with their source windows for trainer encoding.

Source code in geno_lewm/data/builder.py
def iter_with_source_windows(self) -> Iterator[TrainingDatasetItem]:
    """Yield training tuples paired with the window each was built from.

    Windows are split round-robin across DataLoader workers by their position
    in the window stream. Each worker draws edits from its own
    ``random.Random`` seeded with ``self.seed + worker.id``.

    Raises:
        InputError: If the window source yields a non-``WindowContext`` value.
            Only windows owned by the current worker are checked.
    """
    worker = _torch_worker_info()
    worker_rng = random.Random(self.seed + worker.id)
    owned_windows = (
        window
        for position, window in enumerate(_iter_window_source(self.windows))
        if position % worker.num_workers == worker.id
    )
    for window in owned_windows:
        if not isinstance(window, WindowContext):
            raise InputError(
                "window source must yield WindowContext values",
                details={"type": type(window).__name__},
            )
        built = build_training_tuples(
            window,
            self.providers,
            rng=worker_rng,
            mix=self.mix,
            holdouts=self.holdouts,
            fallback_sources=self.fallback_sources,
            preserve_length=self.preserve_length,
        )
        for training_tuple in built:
            yield TrainingDatasetItem(source_window=window, training_tuple=training_tuple)

HoldoutInterval dataclass

HoldoutInterval(chrom: str, start_bp: int, end_bp: int)

0-based half-open genomic interval excluded from training.

intersects

intersects(chrom: str | None, start_bp: int, end_bp: int) -> bool

Return whether [start_bp, end_bp) intersects this interval.

Source code in geno_lewm/data/builder.py
def intersects(self, chrom: str | None, start_bp: int, end_bp: int) -> bool:
    """Return whether ``[start_bp, end_bp)`` intersects this interval."""
    if chrom != self.chrom:
        return False
    return start_bp < self.end_bp and self.start_bp < end_bp

HoldoutPolicy dataclass

HoldoutPolicy(holdout_chroms: tuple[str, ...] = (), intervals: tuple[HoldoutInterval, ...] = (), edit_keys: tuple[str, ...] = (), record_ids: tuple[str, ...] = ())

Holdout exclusions enforced before a tuple reaches the trainer.

excludes_window

excludes_window(window: WindowContext) -> bool

Return whether the source window as a whole is held out: its record ID or chromosome is listed in the policy, or its span overlaps any holdout interval.

Source code in geno_lewm/data/builder.py
def excludes_window(self, window: WindowContext) -> bool:
    """Return whether the source window as a whole is held out.

    A window is excluded in any of these cases:

    - its ``record_id`` is in ``record_ids``;
    - its ``chrom`` is in ``holdout_chroms``;
    - its ``[start_bp, end_bp)`` span overlaps any interval in ``intervals``.

    Raises:
        InputError: If ``window`` is not a ``WindowContext``.
    """
    if isinstance(window, WindowContext):
        listed = window.record_id in self.record_ids or window.chrom in self.holdout_chroms
        if listed:
            return True
        for interval in self.intervals:
            if interval.intersects(window.chrom, window.start_bp, window.end_bp):
                return True
        return False
    raise InputError(
        "window must be a WindowContext",
        details={"type": type(window).__name__},
    )

excludes_edit

excludes_edit(window: WindowContext, edit: RelEdit) -> bool

Return whether one relative edit intersects an edit-level holdout.

Source code in geno_lewm/data/builder.py
def excludes_edit(self, window: WindowContext, edit: RelEdit) -> bool:
    """Return whether one relative edit falls in an edit-level holdout.

    Edits on unplaced windows (``chrom is None``) are never excluded here.
    On a placed window, the edit is excluded in any of these cases:

    - the window's chromosome is held out;
    - the edit's reference span overlaps a holdout interval;
    - its key ``(chrom, 1-based pos, ref, alt)`` is listed in ``edit_keys``.
    """
    chrom = window.chrom
    if chrom is None:
        return False
    if chrom in self.holdout_chroms:
        return True
    # window.start_bp is 0-based, so this is the edit's 0-based absolute start.
    abs_start = window.start_bp + edit.rel_pos
    abs_end = abs_start + len(edit.ref_bases)
    if any(iv.intersects(chrom, abs_start, abs_end) for iv in self.intervals):
        return True
    # Edit keys use the 1-based VCF position.
    key = _edit_key(chrom, abs_start + 1, edit.ref_bases, edit.alt_bases)
    return key in set(self.edit_keys)

TrainingDatasetItem dataclass

TrainingDatasetItem(source_window: WindowContext, training_tuple: TrainingTuple)

One stream item with the source window needed for trainer encoding.

TrainingTuple dataclass

TrainingTuple(window_id: str, source_record_id: str, edit_source: str, rel_edits: tuple[RelEdit, ...], target_window: str, window_start_bp: int, window_end_bp: int)

One RFC-0006 (window_id, action, target_window) training item.

WindowContext dataclass

WindowContext(record_id: str, source: str, sequence: str, start_bp: int = 0, chrom: str | None = None)

One reference window plus source coordinates for tuple building.

Coordinates are 0-based half-open: start_bp is inclusive and end_bp is exclusive. chrom is required for absolute variant providers and chromosome/interval holdouts, but synthetic providers can operate on unplaced Carbon windows.

end_bp property

end_bp: int

Return the 0-based exclusive end coordinate.

window_id property

window_id: str

Return the content hash used for cache lookup.

ClinvarPrepareReport dataclass

ClinvarPrepareReport(output_path: Path, release: str, records_read: int, allele_records_seen: int, records_written: int, skipped_allele: int, size_bytes: int, already_exists: bool = False)

Summary emitted by geno-lewm-prepare-clinvar.

ClinvarVariant dataclass

ClinvarVariant(chrom: str, pos: int, ref: str, alt: str, clinical_significance: str, review_status: str, gene_symbol: str | None, clinvar_id: int, schema_version: str = CLINVAR_SCHEMA_VERSION)

One normalized ClinVar row.

CarbonCorpusConfig dataclass

CarbonCorpusConfig(dataset_id: str = DEFAULT_CARBON_DATASET_ID, dataset_config: str | None = None, revision: str | None = None, default_source: str | None = None, skip_invalid: bool = False, split: str = 'train', streaming: bool = True, subset_fraction: float = DEFAULT_PHASE1_SUBSET_FRACTION, subset_seed: int = 0, sequence_field: str = DEFAULT_SEQUENCE_FIELD, source_field: str = DEFAULT_SOURCE_FIELD, source_id_field: str = DEFAULT_SOURCE_ID_FIELD, window_bp: int = DEFAULT_WINDOW_BP, margin_bp: int = DEFAULT_CORPUS_MARGIN_BP, stride_bp: int = DEFAULT_CORPUS_STRIDE_BP)

Configuration for reading and windowing the Carbon pretraining corpus.

CarbonRecord dataclass

CarbonRecord(record_id: str, source: str, sequence: str)

Canonicalized source sequence record from the Carbon corpus.

length_bp property

length_bp: int

Return the canonical DNA sequence length in base pairs.

CarbonSourceMix dataclass

CarbonSourceMix(source: str, fraction: float)

One source bucket in the RFC-0006 Carbon sub-mix.

CarbonWindow dataclass

CarbonWindow(record_id: str, source: str, start_bp: int, end_bp: int, sequence: str)

A fixed-width training window sampled from a Carbon corpus record.

window_bp property

window_bp: int

Return the window length in base pairs.

window_id property

window_id: str

Return the content-addressed window hash as lowercase hex.

GnomadPrepareReport dataclass

GnomadPrepareReport(output_path: Path, release: str, records_read: int, allele_records_seen: int, records_written: int, skipped_filter: int, skipped_af: int, skipped_allele: int, size_bytes: int, already_exists: bool = False)

Summary emitted by geno-lewm-prepare-gnomad.

GnomadVariant dataclass

GnomadVariant(chrom: str, pos: int, ref: str, alt: str, af_global: float, af_afr: float | None, af_ami: float | None, af_amr: float | None, af_asj: float | None, af_eas: float | None, af_fin: float | None, af_nfe: float | None, af_oth: float | None, af_sas: float | None, filter: str, schema_version: str = GNOMAD_SCHEMA_VERSION)

One normalized common-variant row for the gnomAD shard.

build_training_tuples

build_training_tuples(window: WindowContext, providers: Mapping[str, _EditProvider], *, rng: Random, mix: Sequence[EditSourceCount] = DEFAULT_EDIT_SOURCE_COUNTS, holdouts: HoldoutPolicy | None = None, fallback_sources: Mapping[str, str] | None = DEFAULT_SOURCE_FALLBACKS, preserve_length: bool = True) -> tuple[TrainingTuple, ...]

Build per-window training tuples with source mix and holdout checks.

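providers map source names to callables returning relative edits for that window. The default mix encodes RFC-0006's 3/3/1/1 gnomAD/synthetic-SNV/synthetic-indel/ClinVar allocation. If a source cannot produce enough edits, only the configured fallback_sources are consulted. The default, DEFAULT_SOURCE_FALLBACKS, maps both gnomAD and ClinVar to synthetic SNVs so that unplaced windows still yield full edit tuples. To disable fallbacks, pass fallback_sources=None (or {}); do this when missing gnomAD data must not be silently replaced by synthetic edits.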

Source code in geno_lewm/data/builder.py
def build_training_tuples(
    window: WindowContext,
    providers: Mapping[str, _EditProvider],
    *,
    rng: random.Random,
    mix: Sequence[EditSourceCount] = DEFAULT_EDIT_SOURCE_COUNTS,
    holdouts: HoldoutPolicy | None = None,
    fallback_sources: Mapping[str, str] | None = DEFAULT_SOURCE_FALLBACKS,
    preserve_length: bool = True,
) -> tuple[TrainingTuple, ...]:
    """Build per-window training tuples with source mix and holdout checks.

    ``providers`` map source names to callables returning relative edits
    for that window. The default mix encodes RFC-0006's 3/3/1/1
    gnomAD/synthetic-SNV/synthetic-indel/ClinVar allocation.

    If a source cannot produce enough edits, only the configured
    ``fallback_sources`` are consulted. The default,
    ``DEFAULT_SOURCE_FALLBACKS``, maps both gnomAD and ClinVar to synthetic
    SNVs so that unplaced windows still yield full edit tuples. To disable
    fallbacks, pass ``fallback_sources=None`` (or ``{}``); do this when
    missing gnomAD data must not be silently replaced by synthetic edits.

    Args:
        window: Reference window to edit.
        providers: Edit providers keyed by source name.
        rng: Random source shared by all providers for this window.
        mix: Per-source edit counts; entries with ``count == 0`` are skipped.
        holdouts: Holdout policy; ``None`` means an empty ``HoldoutPolicy``.
        fallback_sources: Source-to-source fallbacks (``None`` disables them).
        preserve_length: Forwarded to ``_tuple_for_edit``.

    Returns:
        One ``TrainingTuple`` per sampled edit. Tuples are grouped by mix
        entry, in the order produced by ``_normalize_mix``. An empty tuple is
        returned when the holdout policy excludes the whole window.

    Raises:
        InputError: If ``window`` is not a ``WindowContext`` or ``rng`` is not
            a ``random.Random``.
    """
    if not isinstance(window, WindowContext):
        raise InputError(
            "window must be a WindowContext",
            details={"type": type(window).__name__},
        )
    if not isinstance(rng, random.Random):
        raise InputError(
            "rng must be a random.Random instance",
            details={"type": type(rng).__name__},
        )
    active_holdouts = holdouts if holdouts is not None else HoldoutPolicy()
    if active_holdouts.excludes_window(window):
        return ()

    source_mix = _normalize_mix(mix)
    fallbacks = dict(fallback_sources or {})
    tuples: list[TrainingTuple] = []
    for entry in source_mix:
        if entry.count == 0:
            continue
        # _sample_edits yields (source, edit) pairs; the reported source may
        # differ from entry.source when a fallback supplied the edit.
        edits = _sample_edits(
            source=entry.source,
            count=entry.count,
            window=window,
            providers=providers,
            rng=rng,
            holdouts=active_holdouts,
            fallback_sources=fallbacks,
        )
        tuples.extend(
            _tuple_for_edit(
                window,
                edit,
                source=source,
                preserve_length=preserve_length,
            )
            for source, edit in edits
        )
    return tuple(tuples)

synthetic_indel_provider

synthetic_indel_provider(window: WindowContext, count: int, rng: Random) -> tuple[RelEdit, ...]

Provider for RFC-0006 synthetic indels.

Source code in geno_lewm/data/builder.py
def synthetic_indel_provider(
    window: WindowContext, count: int, rng: random.Random
) -> tuple[RelEdit, ...]:
    """Provider for RFC-0006 synthetic indels.

    Asks the synthetic ``indel`` generator for ``count`` edits on
    ``window.sequence`` using ``rng``, and returns them as a tuple.
    ``count`` is validated by ``_require_nonnegative_int`` first.
    """
    _require_nonnegative_int("count", count)
    drawn = indel(window.sequence, count, rng=rng)
    return tuple(drawn)

synthetic_snv_provider

synthetic_snv_provider(window: WindowContext, count: int, rng: Random) -> tuple[RelEdit, ...]

Provider for RFC-0006 uniform synthetic SNVs.

Source code in geno_lewm/data/builder.py
def synthetic_snv_provider(
    window: WindowContext, count: int, rng: random.Random
) -> tuple[RelEdit, ...]:
    """Provider for RFC-0006 uniform synthetic SNVs.

    Asks ``uniform_snv`` for ``count`` edits on ``window.sequence`` using
    ``rng``, and returns them as a tuple. ``count`` is validated by
    ``_require_nonnegative_int`` first.
    """
    _require_nonnegative_int("count", count)
    drawn = uniform_snv(window.sequence, count, rng=rng)
    return tuple(drawn)

variant_provider

variant_provider(variants: Sequence[EditSpec]) -> _EditProvider

Return a provider backed by absolute VCF-style variants.

Source code in geno_lewm/data/builder.py
def variant_provider(variants: Sequence[EditSpec]) -> _EditProvider:
    """Return a provider backed by absolute VCF-style variants.

    Variants are validated once, grouped by chromosome and sorted by
    position. Each provider call then locates a window's candidates with two
    binary searches instead of scanning every variant.

    Args:
        variants: Absolute variants whose ``pos`` is 1-based (VCF convention).

    Returns:
        A callable ``(window, count, rng) -> tuple[RelEdit, ...]``. It returns
        up to ``count`` variants whose full reference allele lies inside the
        window, in an ``rng``-shuffled order. It returns ``()`` when
        ``count == 0``, when the window is unplaced, or when no variants exist
        on its chromosome.
    """
    normalized = tuple(_require_edit_spec(value) for value in variants)

    # Group in a single O(n) pass rather than filtering the full variant list
    # once per chromosome.
    grouped: dict[str, list[EditSpec]] = {}
    for variant in normalized:
        grouped.setdefault(variant.chrom, []).append(variant)

    by_chrom: dict[str, tuple[tuple[int, ...], tuple[EditSpec, ...]]] = {}
    for chrom, members in grouped.items():
        # sorted() is stable, so variants at the same position keep input order.
        ordered = tuple(sorted(members, key=_edit_pos))
        by_chrom[chrom] = (tuple(item.pos for item in ordered), ordered)

    def _provider(window: WindowContext, count: int, rng: random.Random) -> tuple[RelEdit, ...]:
        _require_nonnegative_int("count", count)
        if count == 0:
            return ()
        if window.chrom is None:
            # Unplaced windows (e.g. the synthetic Carbon pretraining corpus)
            # carry no genome coordinates, so absolute VCF variants cannot be
            # mapped onto them. Yield nothing and let the source fallback supply
            # synthetic edits (see DEFAULT_SOURCE_FALLBACKS). Placed windows with
            # a chrom still receive their real gnomAD/ClinVar variants.
            return ()
        indexed = by_chrom.get(window.chrom)
        if indexed is None:
            return ()
        positions, chrom_variants = indexed
        # Positions are 1-based: ``pos > start_bp`` keeps variants whose 0-based
        # start is >= start_bp, and ``pos <= end_bp`` keeps starts < end_bp.
        start = bisect_right(positions, window.start_bp)
        stop = bisect_right(positions, window.end_bp)
        candidates = [
            variant.relative_to(window.start_bp, window.end_bp - 1)
            for variant in chrom_variants[start:stop]
            # Drop variants whose reference allele runs past the window end.
            if variant.pos - 1 + len(variant.ref) <= window.end_bp
        ]
        rng.shuffle(candidates)
        return tuple(candidates[:count])

    return _provider

iter_clinvar_shard

iter_clinvar_shard(path: str | Path) -> Iterator[ClinvarVariant]

Yield normalized ClinVar rows from a Parquet shard.

Source code in geno_lewm/data/clinvar.py
def iter_clinvar_shard(path: str | Path) -> Iterator[ClinvarVariant]:
    """Yield normalized ClinVar rows from a Parquet shard.

    The shard is read into an Arrow table once. Rows are then converted to
    Python dicts one bounded record batch at a time, so peak memory never
    holds a dict for every row of the shard.

    Args:
        path: Path to a ClinVar Parquet shard, e.g. one written by
            ``prepare_clinvar_shard``.

    Yields:
        One ``ClinvarVariant`` per row, in table order.
    """
    _pa, pq = _require_pyarrow()
    table = pq.read_table(Path(path))
    # Upper bound on rows converted to Python objects at once.
    batch_rows = 65_536
    for batch in table.to_batches(max_chunksize=batch_rows):
        for row in batch.to_pylist():
            yield ClinvarVariant(
                chrom=str(row["chrom"]),
                pos=int(row["pos"]),
                ref=str(row["ref"]),
                alt=str(row["alt"]),
                clinical_significance=str(row["clinical_significance"]),
                review_status=str(row["review_status"]),
                gene_symbol=None if row.get("gene_symbol") is None else str(row["gene_symbol"]),
                clinvar_id=int(row["clinvar_id"]),
                schema_version=str(row["schema_version"]),
            )

iter_clinvar_vcf_variants

iter_clinvar_vcf_variants(input_vcf: str | Path, *, max_allele_len: int = 16) -> Iterator[ClinvarVariant]

Yield normalized ClinVar rows from a local VCF without writing a shard.

Source code in geno_lewm/data/clinvar.py
def iter_clinvar_vcf_variants(
    input_vcf: str | Path,
    *,
    max_allele_len: int = 16,
) -> Iterator[ClinvarVariant]:
    """Stream normalized ClinVar rows directly from a local VCF.

    Each multi-allelic record is split into one row per ALT allele. An ALT is
    dropped when either REF or ALT fails ``is_supported_allele`` with
    ``max_len=max_allele_len``. Nothing is written to disk.
    """
    _require_positive_int("max_allele_len", max_allele_len)
    for record in iter_vcf_rows(input_vcf):
        for alt_index, alt in enumerate(record.alts):
            usable = is_supported_allele(
                record.ref, max_len=max_allele_len
            ) and is_supported_allele(alt, max_len=max_allele_len)
            if usable:
                yield ClinvarVariant(
                    chrom=record.chrom,
                    pos=record.pos,
                    ref=record.ref,
                    alt=alt,
                    clinical_significance=_clinical_significance(record.info, alt_index),
                    review_status=_review_status(record.info),
                    gene_symbol=_gene_symbol(record.info),
                    clinvar_id=_clinvar_id(record.info, record.variant_id, alt_index),
                )

label_set

label_set(variants: Iterable[ClinvarVariant]) -> tuple[ClinvarVariant, ...]

Return ClinVar rows usable for labelled eval, excluding VUS/OTHER.

Source code in geno_lewm/data/clinvar.py
def label_set(variants: Iterable[ClinvarVariant]) -> tuple[ClinvarVariant, ...]:
    """Keep the ClinVar rows usable for labelled eval (VUS/OTHER excluded).

    A row is kept when its ``clinical_significance`` is one of
    ``CLINVAR_LABELLED_CLASSES``.
    """
    labelled = CLINVAR_LABELLED_CLASSES
    return tuple(filter(lambda row: row.clinical_significance in labelled, variants))

prepare_clinvar_shard

prepare_clinvar_shard(input_vcf: str | Path, output_dir: str | Path, *, release: str, max_allele_len: int = 16, overwrite: bool = False) -> ClinvarPrepareReport

Normalize a local ClinVar VCF/VCF.gz into the release shard schema.

Source code in geno_lewm/data/clinvar.py
def prepare_clinvar_shard(
    input_vcf: str | Path,
    output_dir: str | Path,
    *,
    release: str,
    max_allele_len: int = 16,
    overwrite: bool = False,
) -> ClinvarPrepareReport:
    """Normalize a local ClinVar VCF/VCF.gz into the release shard schema.

    The shard path is ``<output_dir>/clinvar/<release>/variants.parquet``.

    If that file already exists and ``overwrite`` is false, the VCF is not
    read. The report then has ``already_exists=True``, zero read and skip
    counters, and the existing shard's row count.

    Otherwise each VCF record is split per ALT allele. Alleles whose REF or
    ALT fails ``is_supported_allele`` are counted in ``skipped_allele``; the
    remaining alleles are written with ``_write_parquet``.
    """
    _require_release(release)
    _require_positive_int("max_allele_len", max_allele_len)
    target = Path(output_dir) / "clinvar" / release / "variants.parquet"
    if target.exists() and not overwrite:
        return ClinvarPrepareReport(
            output_path=target,
            release=release,
            records_read=0,
            allele_records_seen=0,
            records_written=_parquet_num_rows(target),
            skipped_allele=0,
            size_bytes=target.stat().st_size,
            already_exists=True,
        )

    # These tallies are updated as _write_parquet pulls rows from the
    # generator below, so they are read only after it returns.
    counts = {"records": 0, "alleles": 0, "skipped": 0}

    def _kept_alleles() -> Iterator[ClinvarVariant]:
        for record in iter_vcf_rows(input_vcf):
            counts["records"] += 1
            for alt_index, alt in enumerate(record.alts):
                counts["alleles"] += 1
                supported = is_supported_allele(
                    record.ref, max_len=max_allele_len
                ) and is_supported_allele(alt, max_len=max_allele_len)
                if not supported:
                    counts["skipped"] += 1
                    continue
                yield ClinvarVariant(
                    chrom=record.chrom,
                    pos=record.pos,
                    ref=record.ref,
                    alt=alt,
                    clinical_significance=_clinical_significance(record.info, alt_index),
                    review_status=_review_status(record.info),
                    gene_symbol=_gene_symbol(record.info),
                    clinvar_id=_clinvar_id(record.info, record.variant_id, alt_index),
                )

    written = _write_parquet(_kept_alleles(), target)
    return ClinvarPrepareReport(
        output_path=target,
        release=release,
        records_read=counts["records"],
        allele_records_seen=counts["alleles"],
        records_written=written,
        skipped_allele=counts["skipped"],
        size_bytes=target.stat().st_size,
    )

draw_source_counts

draw_source_counts(n: int, *, rng: Random, mix: Sequence[CarbonSourceMix] = CARBON_SUBMIX) -> dict[str, int]

Draw n source samples and return counts by normalized source key.

Source code in geno_lewm/data/corpus.py
def draw_source_counts(
    n: int,
    *,
    rng: random.Random,
    mix: Sequence[CarbonSourceMix] = CARBON_SUBMIX,
) -> dict[str, int]:
    """Sample ``n`` sources from ``mix`` and count how often each was drawn.

    Every source in the validated mix appears in the result; sources that
    were never drawn map to ``0``.
    """
    _require_nonnegative_int("n", n)
    entries = _validate_mix(mix)
    tally = dict.fromkeys((entry.source for entry in entries), 0)
    for _draw in range(n):
        drawn = _sample_source_from_entries(rng, entries)
        tally[drawn] += 1
    return tally

iter_carbon_records

iter_carbon_records(rows: Iterable[Mapping[str, Any]], *, sequence_field: str = DEFAULT_SEQUENCE_FIELD, source_field: str = DEFAULT_SOURCE_FIELD, source_id_field: str = DEFAULT_SOURCE_ID_FIELD, subset_fraction: float = 1.0, subset_seed: int = 0, default_source: str | None = None, skip_invalid: bool = False) -> Iterator[CarbonRecord]

Yield canonical Carbon records from HF-style row mappings.

Single-source corpus configs (e.g. eukaryote_generator_10B_subset) do not carry a per-row source_field. For these, pass default_source to label every record whose own source label is missing or blank; it must still be a recognized source key. With skip_invalid, rows are skipped rather than raising when any of the following holds: the sequence field is not a string, the source label is not recognized, or the sequence carries unsupported (non-ACGTN) bases. Corpus shards occasionally contain IUPAC ambiguity codes.

Source code in geno_lewm/data/corpus.py
def iter_carbon_records(
    rows: Iterable[Mapping[str, Any]],
    *,
    sequence_field: str = DEFAULT_SEQUENCE_FIELD,
    source_field: str = DEFAULT_SOURCE_FIELD,
    source_id_field: str = DEFAULT_SOURCE_ID_FIELD,
    subset_fraction: float = 1.0,
    subset_seed: int = 0,
    default_source: str | None = None,
    skip_invalid: bool = False,
) -> Iterator[CarbonRecord]:
    """Yield canonical Carbon records from HF-style row mappings.

    Single-source corpus configs (e.g. ``eukaryote_generator_10B_subset``) do
    not carry a per-row ``source_field``. For these, pass ``default_source``
    to label every record whose own label is missing or blank.
    ``default_source`` must be a recognized source key. It is validated up
    front, so a typo fails immediately instead of being swallowed row by row
    under ``skip_invalid``.

    With ``skip_invalid``, rows are skipped rather than raising when any of
    the following holds:

    - the sequence field is not a string;
    - the source label is not recognized;
    - the sequence carries unsupported (non-ACGTN) bases. Corpus shards
      occasionally contain IUPAC ambiguity codes.

    A record is kept only when ``stable_subset_includes`` selects its ID for
    ``subset_fraction``/``subset_seed``. Rows without an ID get one derived
    from the canonical sequence.

    Raises:
        InputError: If ``default_source`` is not a recognized source key.
            Also raised, when ``skip_invalid`` is false, for a row with a
            missing sequence, an unrecognized source label, or a sequence
            that cannot be canonicalized.
    """
    _require_nonempty_str("sequence_field", sequence_field)
    _require_nonempty_str("source_field", source_field)
    _require_nonempty_str("source_id_field", source_id_field)
    _validate_fraction("subset_fraction", subset_fraction)
    _require_nonnegative_int("subset_seed", subset_seed)
    if default_source is not None:
        # Validate eagerly. Otherwise an unrecognized default_source combined
        # with skip_invalid=True would silently drop every unlabelled row.
        normalize_source_label(default_source)

    for row_idx, row in enumerate(rows):
        sequence_value = row.get(sequence_field)
        if not isinstance(sequence_value, str):
            if skip_invalid:
                continue
            raise InputError(
                "Carbon corpus row is missing a DNA sequence string",
                details={"row": row_idx, "sequence_field": sequence_field},
            )
        raw_source = row.get(source_field)
        if default_source is not None and (
            raw_source is None or (isinstance(raw_source, str) and not raw_source.strip())
        ):
            raw_source = default_source
        try:
            source = normalize_source_label(raw_source)
            sequence = canonicalize_dna(sequence_value)
        except InputError:
            if skip_invalid:
                continue
            raise
        raw_record_id = row.get(source_id_field)
        record_id = (
            str(raw_record_id) if raw_record_id not in (None, "") else _fallback_id(sequence)
        )
        if not stable_subset_includes(record_id, fraction=subset_fraction, seed=subset_seed):
            continue
        yield CarbonRecord(record_id=record_id, source=source, sequence=sequence)

iter_record_windows

iter_record_windows(record: CarbonRecord, *, window_bp: int = DEFAULT_WINDOW_BP, margin_bp: int = DEFAULT_CORPUS_MARGIN_BP, stride_bp: int = DEFAULT_CORPUS_STRIDE_BP, rng: Random | None = None) -> Iterator[CarbonWindow]

Yield canonical windows for one Carbon corpus record.

Source code in geno_lewm/data/corpus.py
def iter_record_windows(
    record: CarbonRecord,
    *,
    window_bp: int = DEFAULT_WINDOW_BP,
    margin_bp: int = DEFAULT_CORPUS_MARGIN_BP,
    stride_bp: int = DEFAULT_CORPUS_STRIDE_BP,
    rng: random.Random | None = None,
) -> Iterator[CarbonWindow]:
    """Cut one Carbon corpus record into fixed-width windows.

    Window starts come from ``iter_window_starts`` with the same
    ``window_bp``, ``margin_bp``, ``stride_bp`` and ``rng``. Each window holds
    the record slice ``[start, start + window_bp)``.
    """
    starts = iter_window_starts(
        record.length_bp,
        window_bp=window_bp,
        margin_bp=margin_bp,
        stride_bp=stride_bp,
        rng=rng,
    )
    for window_start in starts:
        window_end = window_start + window_bp
        yield CarbonWindow(
            record_id=record.record_id,
            source=record.source,
            start_bp=window_start,
            end_bp=window_end,
            sequence=record.sequence[window_start:window_end],
        )

iter_window_starts

iter_window_starts(sequence_length: int, *, window_bp: int = DEFAULT_WINDOW_BP, margin_bp: int = DEFAULT_CORPUS_MARGIN_BP, stride_bp: int = DEFAULT_CORPUS_STRIDE_BP, rng: Random | None = None) -> Iterator[int]

Yield RFC-0006 window starts respecting margin and stride constraints.

Source code in geno_lewm/data/corpus.py
def iter_window_starts(
    sequence_length: int,
    *,
    window_bp: int = DEFAULT_WINDOW_BP,
    margin_bp: int = DEFAULT_CORPUS_MARGIN_BP,
    stride_bp: int = DEFAULT_CORPUS_STRIDE_BP,
    rng: random.Random | None = None,
) -> Iterator[int]:
    """Yield RFC-0006 window starts respecting margin and stride constraints.

    Starts lie in ``[margin_bp, sequence_length - window_bp - margin_bp]`` and
    are ``stride_bp`` apart. Nothing is yielded when the sequence is shorter
    than ``window_bp + 2 * margin_bp``. When ``rng`` is given, the first start
    is shifted by a random phase smaller than both ``stride_bp`` and the
    number of valid start positions.
    """
    _require_nonnegative_int("sequence_length", sequence_length)
    _require_positive_int("window_bp", window_bp)
    _require_nonnegative_int("margin_bp", margin_bp)
    _require_positive_int("stride_bp", stride_bp)

    first = margin_bp
    last = sequence_length - window_bp - margin_bp
    if last < first:
        # Equivalent to sequence_length < window_bp + 2 * margin_bp.
        return

    span = min(stride_bp, last - first + 1)
    phase = 0 if rng is None or span <= 1 else rng.randrange(span)
    yield from range(first + phase, last + 1, stride_bp)

load_hf_carbon_records

load_hf_carbon_records(config: CarbonCorpusConfig | None = None) -> Iterator[CarbonRecord]

Load Carbon corpus records through Hugging Face datasets lazily.

Source code in geno_lewm/data/corpus.py
def load_hf_carbon_records(
    config: CarbonCorpusConfig | None = None,
) -> Iterator[CarbonRecord]:
    """Open the Carbon corpus through Hugging Face ``datasets`` and iterate records.

    ``datasets.load_dataset`` is called immediately, with ``streaming`` and
    the other options taken from ``config`` (default ``CarbonCorpusConfig()``).
    Row parsing is deferred to the ``iter_carbon_records`` generator.

    Raises:
        RuntimeSetupError: If the ``datasets`` package cannot be imported.
    """
    cfg = CarbonCorpusConfig() if config is None else config
    try:
        datasets = importlib.import_module("datasets")
    except ImportError as exc:
        raise RuntimeSetupError(
            "Carbon corpus loading requires Hugging Face datasets",
            remediation="install geno-lewm[train] or install datasets",
        ) from exc

    positional: tuple[str, ...] = (
        (cfg.dataset_id,)
        if cfg.dataset_config is None
        else (cfg.dataset_id, cfg.dataset_config)
    )
    dataset = datasets.load_dataset(
        *positional,
        split=cfg.split,
        streaming=cfg.streaming,
        revision=cfg.revision,
    )
    return iter_carbon_records(
        dataset,
        sequence_field=cfg.sequence_field,
        source_field=cfg.source_field,
        source_id_field=cfg.source_id_field,
        subset_fraction=cfg.subset_fraction,
        subset_seed=cfg.subset_seed,
        default_source=cfg.default_source,
        skip_invalid=cfg.skip_invalid,
    )

normalize_source_label

normalize_source_label(value: object) -> str

Normalize a Carbon corpus source label to the RFC-0006 source key.

Source code in geno_lewm/data/corpus.py
def normalize_source_label(value: object) -> str:
    """Map a free-form Carbon corpus source label to its RFC-0006 source key.

    The label is stripped and lower-cased, ``-`` and ``/`` are treated as
    spaces, and whitespace runs are collapsed. The result is then looked up
    in ``_SOURCE_ALIASES``.

    Raises:
        InputError: If ``value`` is not a non-blank string, or if it matches
            no known alias.
    """
    if not (isinstance(value, str) and value.strip()):
        raise InputError(
            "source label must be a non-empty string",
            details={"value": value, "type": type(value).__name__},
        )
    spaced = value.strip().lower().replace("-", " ").replace("/", " ")
    lookup_key = " ".join(spaced.split())
    normalized = _SOURCE_ALIASES.get(lookup_key)
    if normalized is not None:
        return normalized
    raise InputError(
        "unsupported Carbon corpus source label",
        details={"source": value, "known_sources": [entry.source for entry in CARBON_SUBMIX]},
    )

sample_source

sample_source(rng: Random, *, mix: Sequence[CarbonSourceMix] = CARBON_SUBMIX) -> str

Sample one source key from the configured RFC-0006 sub-mix.

Source code in geno_lewm/data/corpus.py
def sample_source(
    rng: random.Random,
    *,
    mix: Sequence[CarbonSourceMix] = CARBON_SUBMIX,
) -> str:
    """Draw a single source key from the validated RFC-0006 sub-mix ``mix``."""
    validated = _validate_mix(mix)
    return _sample_source_from_entries(rng, validated)

stable_subset_includes

stable_subset_includes(record_id: str, *, fraction: float, seed: int = 0) -> bool

Return whether record_id belongs to a deterministic corpus subset.

Source code in geno_lewm/data/corpus.py
def stable_subset_includes(record_id: str, *, fraction: float, seed: int = 0) -> bool:
    """Return whether ``record_id`` belongs to a deterministic corpus subset.

    Membership hashes ``f"{seed}:{record_id}"`` with SHA-256 and maps the
    first 8 digest bytes to a float in ``[0, 1]``. The record is included when
    that value is below ``fraction``.

    The result depends only on ``seed`` and ``record_id``, so subsets are
    stable across runs. They are also nested: a record selected at
    ``fraction=f`` is selected at any larger fraction. ``fraction >= 1``
    always includes the record.
    """
    _require_nonempty_str("record_id", record_id)
    _validate_fraction("fraction", fraction)
    _require_nonnegative_int("seed", seed)
    if fraction >= 1.0:
        # Converting the 64-bit prefix to float rounds the largest prefixes up
        # to exactly 2**64, giving value == 1.0. Without this guard those
        # records would be excluded even at fraction=1.0.
        return True
    digest = hashlib.sha256(f"{seed}:{record_id}".encode()).digest()
    value = int.from_bytes(digest[:8], byteorder="big") / float(1 << 64)
    return value < fraction

iter_gnomad_shard

iter_gnomad_shard(path: str | Path) -> Iterator[GnomadVariant]

Yield normalized gnomAD rows from a Parquet shard.

Source code in geno_lewm/data/gnomad.py
def iter_gnomad_shard(path: str | Path) -> Iterator[GnomadVariant]:
    """Yield normalized gnomAD rows from a Parquet shard.

    The shard is read into an Arrow table once. Rows are then converted to
    Python dicts one bounded record batch at a time; common-variant shards
    are large, and a whole-table ``to_pylist`` would build a dict for every
    row before the first yield.

    Args:
        path: Path to a gnomAD Parquet shard.

    Yields:
        One ``GnomadVariant`` per row, in table order. Missing or null
        per-population AF columns go through ``_optional_float``.
    """
    _pa, pq = _require_pyarrow()
    table = pq.read_table(Path(path))
    # Upper bound on rows converted to Python objects at once.
    batch_rows = 65_536
    for batch in table.to_batches(max_chunksize=batch_rows):
        for row in batch.to_pylist():
            yield GnomadVariant(
                chrom=str(row["chrom"]),
                pos=int(row["pos"]),
                ref=str(row["ref"]),
                alt=str(row["alt"]),
                af_global=float(row["af_global"]),
                af_afr=_optional_float(row.get("af_afr")),
                af_ami=_optional_float(row.get("af_ami")),
                af_amr=_optional_float(row.get("af_amr")),
                af_asj=_optional_float(row.get("af_asj")),
                af_eas=_optional_float(row.get("af_eas")),
                af_fin=_optional_float(row.get("af_fin")),
                af_nfe=_optional_float(row.get("af_nfe")),
                af_oth=_optional_float(row.get("af_oth")),
                af_sas=_optional_float(row.get("af_sas")),
                filter=str(row["filter"]),
                schema_version=str(row["schema_version"]),
            )

iter_gnomad_vcf_variants

iter_gnomad_vcf_variants(input_vcf: str | Path, *, min_af: float = 0.01, max_allele_len: int = 16) -> Iterator[GnomadVariant]

Yield normalized rows from a local gnomAD VCF without writing a shard.

Source code in geno_lewm/data/gnomad.py
def iter_gnomad_vcf_variants(
    input_vcf: str | Path,
    *,
    min_af: float = 0.01,
    max_allele_len: int = 16,
) -> Iterator[GnomadVariant]:
    """Yield normalized rows from a local gnomAD VCF without writing a shard.

    Records are read with ``iter_vcf_rows``. Each ALT allele of a record is
    emitted as its own ``GnomadVariant`` only when all of the following hold:

    * the record's FILTER is exactly ``"PASS"``;
    * both REF and the ALT allele pass ``is_supported_allele`` with
      ``max_len=max_allele_len``;
    * a global allele frequency is found under the INFO keys ``AF``,
      ``AF_global`` or ``AF_GLOBAL`` and is ``>= min_af``.

    Per-population frequencies are looked up under lower- and upper-case
    ``AF_<pop>`` keys and may be ``None``.

    Args:
        input_vcf: Path to the local VCF file.
        min_af: Minimum global allele frequency; validated by
            ``_require_probability``.
        max_allele_len: Maximum allele length passed to
            ``is_supported_allele``; validated by ``_require_positive_int``.

    Yields:
        One ``GnomadVariant`` per selected (record, ALT allele) pair.
    """
    _require_probability("min_af", min_af)
    _require_positive_int("max_allele_len", max_allele_len)
    for row in iter_vcf_rows(input_vcf):
        # FILTER and REF belong to the whole record, so reject the record once
        # instead of re-checking them for every ALT allele.
        if row.filter != "PASS":
            continue
        if not is_supported_allele(row.ref, max_len=max_allele_len):
            continue
        for alt_index, alt in enumerate(row.alts):
            if not is_supported_allele(alt, max_len=max_allele_len):
                continue
            # alt_index selects this allele's entry from per-allele INFO values.
            af_global = _af_for(row.info, ("AF", "AF_global", "AF_GLOBAL"), alt_index)
            if af_global is None or af_global < min_af:
                continue
            yield GnomadVariant(
                chrom=row.chrom,
                pos=row.pos,
                ref=row.ref,
                alt=alt,
                af_global=af_global,
                af_afr=_af_for(row.info, ("AF_afr", "AF_AFR"), alt_index),
                af_ami=_af_for(row.info, ("AF_ami", "AF_AMI"), alt_index),
                af_amr=_af_for(row.info, ("AF_amr", "AF_AMR"), alt_index),
                af_asj=_af_for(row.info, ("AF_asj", "AF_ASJ"), alt_index),
                af_eas=_af_for(row.info, ("AF_eas", "AF_EAS"), alt_index),
                af_fin=_af_for(row.info, ("AF_fin", "AF_FIN"), alt_index),
                af_nfe=_af_for(row.info, ("AF_nfe", "AF_NFE"), alt_index),
                af_oth=_af_for(row.info, ("AF_oth", "AF_OTH"), alt_index),
                af_sas=_af_for(row.info, ("AF_sas", "AF_SAS"), alt_index),
                filter=row.filter,
            )

prepare_gnomad_shard

prepare_gnomad_shard(input_vcf: str | Path, output_dir: str | Path, *, release: str = 'v4.1', min_af: float = 0.01, max_allele_len: int = 16, overwrite: bool = False) -> GnomadPrepareReport

Filter a local gnomAD VCF/VCF.gz into the release shard schema.

Source code in geno_lewm/data/gnomad.py
def prepare_gnomad_shard(
    input_vcf: str | Path,
    output_dir: str | Path,
    *,
    release: str = "v4.1",
    min_af: float = 0.01,
    max_allele_len: int = 16,
    overwrite: bool = False,
) -> GnomadPrepareReport:
    """Filter a local gnomAD VCF/VCF.gz into the release shard schema.

    The shard is written to ``<output_dir>/gnomad/<release>/variants.parquet``.

    If that file already exists and ``overwrite`` is false, the VCF is not read.
    The returned report then has ``already_exists=True`` and all scan counters
    set to zero. Its ``records_written`` is the row count of the existing file.

    Otherwise every ALT allele of every VCF record is counted in
    ``allele_records_seen`` and is dropped (with the matching counter bumped)
    when:

    * the record FILTER is not ``"PASS"`` (``skipped_filter``);
    * REF or ALT fails ``is_supported_allele`` (``skipped_allele``);
    * the global AF (INFO ``AF`` / ``AF_global`` / ``AF_GLOBAL``) is missing or
      below ``min_af`` (``skipped_af``).

    The surviving alleles are written with ``_write_parquet``.

    Args:
        input_vcf: Path to the local gnomAD VCF / VCF.gz.
        output_dir: Root directory for release shards.
        release: Release label, validated by ``_require_release``.
        min_af: Minimum global allele frequency to keep.
        max_allele_len: Maximum REF/ALT length passed to ``is_supported_allele``.
        overwrite: Rebuild the shard even when the target file already exists.

    Returns:
        A ``GnomadPrepareReport`` describing the shard and the scan counters.
    """
    _require_release(release)
    _require_probability("min_af", min_af)
    _require_positive_int("max_allele_len", max_allele_len)

    shard_path = Path(output_dir).joinpath("gnomad", release, "variants.parquet")
    if shard_path.exists() and not overwrite:
        existing_rows = _parquet_num_rows(shard_path)
        return GnomadPrepareReport(
            output_path=shard_path,
            release=release,
            records_read=0,
            allele_records_seen=0,
            records_written=existing_rows,
            skipped_filter=0,
            skipped_af=0,
            skipped_allele=0,
            size_bytes=shard_path.stat().st_size,
            already_exists=True,
        )

    # Scan counters, mutated by the generator below while _write_parquet
    # consumes it; keys match the GnomadPrepareReport field names.
    tally = {
        "records_read": 0,
        "allele_records_seen": 0,
        "skipped_filter": 0,
        "skipped_af": 0,
        "skipped_allele": 0,
    }

    def _kept_variants() -> Iterator[GnomadVariant]:
        """Yield the alleles that pass every filter, updating ``tally``."""
        for vcf_row in iter_vcf_rows(input_vcf):
            tally["records_read"] += 1
            for alt_index, alt_allele in enumerate(vcf_row.alts):
                tally["allele_records_seen"] += 1
                if vcf_row.filter != "PASS":
                    tally["skipped_filter"] += 1
                    continue
                alleles_ok = is_supported_allele(
                    vcf_row.ref, max_len=max_allele_len
                ) and is_supported_allele(alt_allele, max_len=max_allele_len)
                if not alleles_ok:
                    tally["skipped_allele"] += 1
                    continue
                global_af = _af_for(vcf_row.info, ("AF", "AF_global", "AF_GLOBAL"), alt_index)
                if global_af is None or global_af < min_af:
                    tally["skipped_af"] += 1
                    continue
                info = vcf_row.info
                yield GnomadVariant(
                    chrom=vcf_row.chrom,
                    pos=vcf_row.pos,
                    ref=vcf_row.ref,
                    alt=alt_allele,
                    af_global=global_af,
                    af_afr=_af_for(info, ("AF_afr", "AF_AFR"), alt_index),
                    af_ami=_af_for(info, ("AF_ami", "AF_AMI"), alt_index),
                    af_amr=_af_for(info, ("AF_amr", "AF_AMR"), alt_index),
                    af_asj=_af_for(info, ("AF_asj", "AF_ASJ"), alt_index),
                    af_eas=_af_for(info, ("AF_eas", "AF_EAS"), alt_index),
                    af_fin=_af_for(info, ("AF_fin", "AF_FIN"), alt_index),
                    af_nfe=_af_for(info, ("AF_nfe", "AF_NFE"), alt_index),
                    af_oth=_af_for(info, ("AF_oth", "AF_OTH"), alt_index),
                    af_sas=_af_for(info, ("AF_sas", "AF_SAS"), alt_index),
                    filter=vcf_row.filter,
                )

    written = _write_parquet(_kept_variants(), shard_path)
    return GnomadPrepareReport(
        output_path=shard_path,
        release=release,
        records_written=written,
        size_bytes=shard_path.stat().st_size,
        **tally,
    )