Skip to content

geno_lewm.surprise.calibration

calibration

Calibration-table builder and Parquet IO for RFC-0009.

CALIBRATION_SCHEMA_VERSION module-attribute

CALIBRATION_SCHEMA_VERSION = '1.0.0'

On-disk calibration table schema version.

DEFAULT_CDF_POINTS module-attribute

DEFAULT_CDF_POINTS = 1001

Number of points in each empirical CDF grid.

DEFAULT_REFERENCE_PER_BUCKET module-attribute

DEFAULT_REFERENCE_PER_BUCKET = 10000

Default maximum number of reference variants sampled per bucket.

LOW_CONFIDENCE_BUCKET_SIZE module-attribute

LOW_CONFIDENCE_BUCKET_SIZE = 100

Buckets below this size are marked low-confidence by RFC-0009.

CalibrationExample dataclass

CalibrationExample(bucket_id: str, sigma_raw: float)

One pre-scored reference variant used to build calibration CDFs.

CalibrationWarning dataclass

CalibrationWarning(bucket_id: str, resolved_bucket_id: str, n_calibration: int, min_bucket_size: int, low_confidence: bool)

Sparse-bucket warning emitted while building a calibration table.

CalibrationBucket dataclass

CalibrationBucket(bucket_id: str, n_calibration: int, cdf: tuple[float, ...], sigma_grid: tuple[float, ...], back_off_to: str | None = None, schema_version: str = CALIBRATION_SCHEMA_VERSION)

One row in calibration.parquet.

confidence property

confidence: float

Return RFC-0009 confidence from this bucket's row count.

low_confidence property

low_confidence: bool

Return true when this bucket is below the low-confidence floor.

CalibrationTable dataclass

CalibrationTable(buckets: tuple[CalibrationBucket, ...], warnings: tuple[CalibrationWarning, ...] = (), schema_version: str = CALIBRATION_SCHEMA_VERSION)

In-memory representation of calibration.parquet.

get

get(bucket_id: str) -> CalibrationBucket | None

Return a bucket by ID, or None if absent.

Source code in geno_lewm/surprise/calibration.py
def get(self, bucket_id: str) -> CalibrationBucket | None:
    """Look up a bucket by its identifier.

    The identifier is first checked with ``_require_bucket_id``. The
    buckets are then scanned in order; if several buckets share the same
    ID, the last one wins.

    Returns:
        The matching bucket, or ``None`` if absent.
    """
    _require_bucket_id(bucket_id)
    match = None
    for candidate in self.buckets:
        if candidate.bucket_id == bucket_id:
            # Keep scanning: a later duplicate overrides an earlier one.
            match = candidate
    return match

require

require(bucket_id: str) -> CalibrationBucket

Return a bucket by ID, raising InputError when absent.

Source code in geno_lewm/surprise/calibration.py
def require(self, bucket_id: str) -> CalibrationBucket:
    """Fetch a bucket by ID, treating a missing bucket as an error.

    Delegates the lookup to ``self.get``.

    Raises:
        InputError: when ``self.get`` returns ``None``; the requested ID
            is reported under ``details["bucket_id"]``.
    """
    found = self.get(bucket_id)
    if found is not None:
        return found
    raise InputError(
        "calibration bucket is not present",
        details={"bucket_id": bucket_id},
    )

resolve

resolve(label_or_bucket: str, *, min_bucket_size: int = DEFAULT_MIN_BUCKET_SIZE) -> CalibrationBucket

Resolve a sparse bucket through the table's fixed backoff chain.

Source code in geno_lewm/surprise/calibration.py
def resolve(
    self,
    label_or_bucket: str,
    *,
    min_bucket_size: int = DEFAULT_MIN_BUCKET_SIZE,
) -> CalibrationBucket:
    """Pick the bucket to use for ``label_or_bucket`` after backoff.

    ``min_bucket_size`` is validated with ``_require_positive_int`` and
    passed to ``select_backoff_bucket`` as ``min_count``, together with a
    map of every bucket's ``n_calibration``. The selected ID is then
    fetched with ``self.require``, so a selected ID that is missing from
    the table raises ``InputError``.
    """
    threshold = _require_positive_int("min_bucket_size", min_bucket_size)

    # Map each bucket ID to its calibration row count for the selector.
    sizes: dict[str, int] = {}
    for entry in self.buckets:
        sizes[entry.bucket_id] = entry.n_calibration

    return self.require(
        select_backoff_bucket(label_or_bucket, sizes, min_count=threshold)
    )

build_calibration_table

build_calibration_table(examples: Iterable[CalibrationExample], *, seed: int = 0, per_bucket_sample: int = DEFAULT_REFERENCE_PER_BUCKET, grid_size: int = DEFAULT_CDF_POINTS, min_bucket_size: int = DEFAULT_MIN_BUCKET_SIZE, low_confidence_size: int = LOW_CONFIDENCE_BUCKET_SIZE, warn_sparse: bool = True) -> CalibrationTable

Build deterministic empirical CDF buckets from pre-scored examples.

Source code in geno_lewm/surprise/calibration.py
def build_calibration_table(
    examples: Iterable[CalibrationExample],
    *,
    seed: int = 0,
    per_bucket_sample: int = DEFAULT_REFERENCE_PER_BUCKET,
    grid_size: int = DEFAULT_CDF_POINTS,
    min_bucket_size: int = DEFAULT_MIN_BUCKET_SIZE,
    low_confidence_size: int = LOW_CONFIDENCE_BUCKET_SIZE,
    warn_sparse: bool = True,
) -> CalibrationTable:
    """Build a calibration table of empirical CDFs from pre-scored examples.

    Each example's ``sigma_raw`` is pooled into every bucket returned by
    ``backoff_chain`` for its ``bucket_id``. Each pooled bucket is then
    passed through ``_sample_bucket_values`` (with ``seed`` and
    ``per_bucket_sample``) and ``_empirical_cdf`` (with ``grid_size``).
    Buckets are emitted in sorted ``bucket_id`` order.

    Args:
        examples: Pre-scored reference variants; every item must be a
            ``CalibrationExample``.
        seed: Seed forwarded to the per-bucket sampler.
        per_bucket_sample: Sample limit forwarded to the per-bucket sampler.
        grid_size: Number of CDF grid points; must be at least 2.
        min_bucket_size: ``min_count`` used when selecting a backoff bucket.
        low_confidence_size: Low-confidence threshold; must not exceed
            ``min_bucket_size``.
        warn_sparse: When true, emit a ``RuntimeWarning`` for each entry
            returned by ``_sparse_warnings``.

    Returns:
        A ``CalibrationTable`` holding the buckets and the sparse-bucket
        warnings (the warnings are attached even when ``warn_sparse`` is
        false).

    Raises:
        InputError: for ``grid_size`` below 2, ``low_confidence_size``
            above ``min_bucket_size``, a non-``CalibrationExample`` item,
            or an empty ``examples`` iterable.
    """
    # --- argument validation (order matters for which error surfaces) ---
    _require_seed(seed)
    sample_limit = _require_positive_int("per_bucket_sample", per_bucket_sample)
    points = _require_positive_int("grid_size", grid_size)
    if points < 2:
        raise InputError("grid_size must be at least 2", details={"grid_size": points})
    min_size = _require_positive_int("min_bucket_size", min_bucket_size)
    low_size = _require_positive_int("low_confidence_size", low_confidence_size)
    if low_size > min_size:
        raise InputError(
            "low_confidence_size must be <= min_bucket_size",
            details={"low_confidence_size": low_size, "min_bucket_size": min_size},
        )

    # --- pool every score into its own bucket and all backoff ancestors ---
    pooled: dict[str, list[float]] = {}
    seen_sources: set[str] = set()
    for example in examples:
        if not isinstance(example, CalibrationExample):
            raise InputError(
                "examples must contain CalibrationExample instances",
                details={"type": type(example).__name__},
            )
        seen_sources.add(example.bucket_id)
        for chain_id in backoff_chain(example.bucket_id):
            if chain_id not in pooled:
                pooled[chain_id] = []
            pooled[chain_id].append(float(example.sigma_raw))

    if not seen_sources:
        raise InputError("calibration examples must contain at least one row")

    # --- sample each pooled bucket, visiting bucket IDs in sorted order ---
    sampled: dict[str, tuple[float, ...]] = {
        bucket_id: _sample_bucket_values(
            pooled[bucket_id],
            seed=seed,
            bucket_id=bucket_id,
            sample_limit=sample_limit,
        )
        for bucket_id in sorted(pooled)
    }
    counts = {bucket_id: len(drawn) for bucket_id, drawn in sampled.items()}

    # --- one table row per bucket ---
    # `sampled` was filled in sorted key order, so plain iteration is
    # already sorted by bucket ID.
    rows: list[CalibrationBucket] = []
    for bucket_id, drawn in sampled.items():
        cdf, sigma_grid = _empirical_cdf(drawn, grid_size=points)
        target = select_backoff_bucket(bucket_id, counts, min_count=min_size)
        row = CalibrationBucket(
            bucket_id=bucket_id,
            n_calibration=len(drawn),
            cdf=cdf,
            sigma_grid=sigma_grid,
            # Only record a backoff target when it differs from the bucket.
            back_off_to=target if target != bucket_id else None,
        )
        rows.append(row)

    sparse_warnings = _sparse_warnings(
        sorted(seen_sources),
        counts,
        min_bucket_size=min_size,
        low_confidence_size=low_size,
    )
    # Warn from this frame so stacklevel=2 points at the caller.
    for warning in sparse_warnings if warn_sparse else ():
        warnings.warn(
            "calibration bucket remains sparse after backoff: "
            f"{warning.bucket_id} -> {warning.resolved_bucket_id} "
            f"n={warning.n_calibration} min={warning.min_bucket_size}",
            RuntimeWarning,
            stacklevel=2,
        )

    return CalibrationTable(buckets=tuple(rows), warnings=sparse_warnings)

write_calibration_table

write_calibration_table(table: CalibrationTable, path: str | Path) -> Path

Write a calibration table to path as a zstd-compressed Parquet file (conventionally calibration.parquet), creating parent directories as needed.

Source code in geno_lewm/surprise/calibration.py
def write_calibration_table(table: CalibrationTable, path: str | Path) -> Path:
    """Serialize ``table`` to a Parquet file at ``path``.

    Missing parent directories are created. One row is written per bucket
    using the schema from ``_arrow_schema``, with zstd compression at
    level 9.

    Returns:
        The destination as a ``Path``.
    """
    pa, pq = _require_pyarrow()
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    # Accumulate the columns in a single pass over the buckets.
    columns: dict[str, list] = {
        "bucket_id": [],
        "n_calibration": [],
        "cdf": [],
        "sigma_grid": [],
        "back_off_to": [],
        "schema_version": [],
    }
    for entry in table.buckets:
        columns["bucket_id"].append(entry.bucket_id)
        columns["n_calibration"].append(entry.n_calibration)
        columns["cdf"].append(list(entry.cdf))
        columns["sigma_grid"].append(list(entry.sigma_grid))
        columns["back_off_to"].append(entry.back_off_to)
        columns["schema_version"].append(entry.schema_version)

    arrow_table = pa.Table.from_pydict(columns, schema=_arrow_schema(pa))
    pq.write_table(arrow_table, destination, compression="zstd", compression_level=9)
    return destination

read_calibration_table

read_calibration_table(path: str | Path) -> CalibrationTable

Read a calibration Parquet file and validate that its columns match the documented schema.

Source code in geno_lewm/surprise/calibration.py
def read_calibration_table(path: str | Path) -> CalibrationTable:
    """Load a calibration table from a Parquet file.

    The file's column names must equal ``_column_names()`` exactly,
    including order. Each row is then converted into a
    ``CalibrationBucket``.

    Raises:
        SchemaCompatError: when the file cannot be read (the underlying
            error is chained and its text is included in ``details``), or
            when the columns differ from the expected ones.
    """
    _pa, pq = _require_pyarrow()
    source = Path(path)
    try:
        arrow_table = pq.read_table(source)
    except Exception as exc:
        # Any read failure is reported as a schema-compatibility problem.
        raise SchemaCompatError(
            "calibration table could not be read",
            details={"path": str(source), "error": str(exc)},
        ) from exc

    observed = tuple(arrow_table.column_names)
    expected = _column_names()
    if observed != expected:
        raise SchemaCompatError(
            "calibration table columns do not match the documented schema",
            details={"observed": list(observed), "expected": list(expected)},
        )

    loaded: list[CalibrationBucket] = []
    for row in arrow_table.to_pylist():
        loaded.append(
            CalibrationBucket(
                bucket_id=str(row["bucket_id"]),
                n_calibration=int(row["n_calibration"]),
                cdf=tuple(map(float, row["cdf"])),
                sigma_grid=tuple(map(float, row["sigma_grid"])),
                # A null backoff target stays None rather than becoming "None".
                back_off_to=(
                    str(row["back_off_to"]) if row["back_off_to"] is not None else None
                ),
                schema_version=str(row["schema_version"]),
            )
        )
    return CalibrationTable(buckets=tuple(loaded))