Skip to content

geno_lewm.training.sampling

sampling

RFC-0005 edit-balanced and rollout-length samplers.

EditTypeWeight dataclass

EditTypeWeight(edit_type: EditType, weight: float)

One RFC-0005 edit-type sampling weight.

RolloutStepWeight dataclass

RolloutStepWeight(steps: int, weight: float)

One rollout-step-count sampling weight.

sample_edit_type

sample_edit_type(rng: Random, *, weights: Sequence[EditTypeWeight] = DEFAULT_EDIT_TYPE_WEIGHTS) -> EditType

Sample one edit type from the RFC-0005 edit-balanced distribution.

Source code in geno_lewm/training/sampling.py
def sample_edit_type(
    rng: random.Random,
    *,
    weights: Sequence[EditTypeWeight] = DEFAULT_EDIT_TYPE_WEIGHTS,
) -> EditType:
    """Draw a single edit type at random according to ``weights``.

    ``weights`` is first passed through ``_validate_edit_type_weights``; one
    of the resulting entries is then picked by ``_sample_weighted`` using
    ``rng``, and that entry's ``edit_type`` is returned. When ``weights`` is
    omitted, the RFC-0005 edit-balanced defaults are used.
    """
    validated = _validate_edit_type_weights(weights)
    chosen = _sample_weighted(rng, validated)
    return chosen.edit_type

draw_edit_type_counts

draw_edit_type_counts(n: int, *, rng: Random, weights: Sequence[EditTypeWeight] = DEFAULT_EDIT_TYPE_WEIGHTS) -> dict[EditType, int]

Draw `n` edit types and return counts keyed by `EditType`.

Source code in geno_lewm/training/sampling.py
def draw_edit_type_counts(
    n: int,
    *,
    rng: random.Random,
    weights: Sequence[EditTypeWeight] = DEFAULT_EDIT_TYPE_WEIGHTS,
) -> dict[EditType, int]:
    """Sample ``n`` edit types and tally how often each one was drawn.

    ``n`` is checked with ``_require_nonnegative_int`` and ``weights`` with
    ``_validate_edit_type_weights`` before any sampling happens. The returned
    dict has one key per validated entry's ``edit_type``, so edit types that
    were never drawn are present with a count of ``0``.
    """
    _require_nonnegative_int("n", n)
    validated = _validate_edit_type_weights(weights)
    # Seed every known edit type with zero so undrawn types still appear.
    tally = dict.fromkeys((item.edit_type for item in validated), 0)
    for _ in range(n):
        drawn = _sample_weighted(rng, validated)
        tally[drawn.edit_type] += 1
    return tally

sample_rollout_steps

sample_rollout_steps(rng: Random, *, mix: Sequence[RolloutStepWeight] = DEFAULT_ROLLOUT_STEP_MIX) -> int

Sample a rollout length K from the Phase-1 RFC-0005 mix.

Source code in geno_lewm/training/sampling.py
def sample_rollout_steps(
    rng: random.Random,
    *,
    mix: Sequence[RolloutStepWeight] = DEFAULT_ROLLOUT_STEP_MIX,
) -> int:
    """Draw a single rollout length ``K`` at random according to ``mix``.

    ``mix`` is first passed through ``_validate_rollout_mix``; one of the
    resulting entries is then picked by ``_sample_weighted`` using ``rng``,
    and that entry's ``steps`` value is returned. When ``mix`` is omitted,
    the Phase-1 RFC-0005 default mix is used.
    """
    validated = _validate_rollout_mix(mix)
    chosen = _sample_weighted(rng, validated)
    return chosen.steps

draw_rollout_step_counts

draw_rollout_step_counts(n: int, *, rng: Random, mix: Sequence[RolloutStepWeight] = DEFAULT_ROLLOUT_STEP_MIX) -> dict[int, int]

Draw n rollout lengths and return counts by step count.

Source code in geno_lewm/training/sampling.py
def draw_rollout_step_counts(
    n: int,
    *,
    rng: random.Random,
    mix: Sequence[RolloutStepWeight] = DEFAULT_ROLLOUT_STEP_MIX,
) -> dict[int, int]:
    """Sample ``n`` rollout lengths and tally how often each one was drawn.

    ``n`` is checked with ``_require_nonnegative_int`` and ``mix`` with
    ``_validate_rollout_mix`` before any sampling happens. The returned dict
    has one key per validated entry's ``steps`` value, so step counts that
    were never drawn are present with a count of ``0``.
    """
    _require_nonnegative_int("n", n)
    validated = _validate_rollout_mix(mix)
    # Seed every known step count with zero so undrawn lengths still appear.
    tally = dict.fromkeys((item.steps for item in validated), 0)
    for _ in range(n):
        drawn = _sample_weighted(rng, validated)
        tally[drawn.steps] += 1
    return tally