def __init__(
    self,
    model_id: str,
    revision: str,
    *,
    dtype: str = "bf16",
    state_layer: int = -1,
    pool_type: str = POOL_CENTERED_MEAN,
    pool_radius: int = DEFAULT_POOL_RADIUS_TOKENS,
    normalize: bool = True,
    lora_config: object | None = None,
    model: object | None = None,
    tokenizer: object | None = None,
    encoder_hash: bytes | str | None = None,
    local_files_only: bool = True,
    trust_remote_code: bool = False,
    device: str | None = None,
) -> None:
    """Validate encoder settings and attach a model/tokenizer pair.

    Either pass ``model`` and ``tokenizer`` together, or omit both and they
    are loaded by ``_load_transformers_components`` from ``model_id`` at
    ``revision``. The resulting model is passed to ``_eval_if_available``
    and moved to the resolved device.

    Args:
        model_id: Non-empty model identifier, forwarded to the loader.
        revision: Non-empty revision, forwarded to the loader.
        dtype: One of ``_SUPPORTED_DTYPES``; forwarded to the loader.
        state_layer: Integer layer selector (``bool`` is rejected). Only
            stored here; presumably an index into the model's hidden
            states -- confirm against the encoding path.
        pool_type: ``POOL_CENTERED_MEAN`` or ``POOL_GLOBAL_MEAN``.
        pool_radius: Non-negative integer (``bool`` is rejected).
        normalize: Must be a real ``bool``.
        lora_config: Must be ``None``; LoRA adapters are not supported yet.
        model: Pre-built model; must be supplied together with ``tokenizer``.
        tokenizer: Pre-built tokenizer; must be supplied together with
            ``model``.
        encoder_hash: Optional hash, normalized by ``_coerce_encoder_hash``.
        local_files_only: Must be a real ``bool``; forwarded to the loader.
        trust_remote_code: Must be a real ``bool``; forwarded to the loader.
        device: Optional device string, resolved by ``_resolve_device``.

    Raises:
        InputError: If any argument fails validation.
        RuntimeSetupError: If ``lora_config`` is not ``None``.
    """
    if not model_id:
        raise InputError("model_id must be non-empty")
    if not revision:
        raise InputError("revision must be non-empty")
    # Guard the membership test: an unhashable dtype (e.g. a list) would make
    # a set lookup raise TypeError instead of the documented InputError.
    try:
        dtype_supported = dtype in _SUPPORTED_DTYPES
    except TypeError:
        dtype_supported = False
    if not dtype_supported:
        raise InputError(
            "unsupported encoder dtype",
            details={"dtype": dtype, "supported": sorted(_SUPPORTED_DTYPES)},
        )
    # bool is a subclass of int, so True/False must be rejected explicitly.
    if not isinstance(state_layer, int) or isinstance(state_layer, bool):
        raise InputError(
            "state_layer must be an integer",
            details={"state_layer": state_layer, "type": type(state_layer).__name__},
        )
    # Tuple membership compares with == and never hashes pool_type, so an
    # unhashable value is reported as InputError rather than TypeError.
    supported_pools = (POOL_CENTERED_MEAN, POOL_GLOBAL_MEAN)
    if pool_type not in supported_pools:
        raise InputError(
            "unsupported pool_type",
            details={
                "pool_type": pool_type,
                "supported": list(supported_pools),
            },
        )
    if not isinstance(pool_radius, int) or isinstance(pool_radius, bool) or pool_radius < 0:
        raise InputError(
            "pool_radius must be a non-negative integer",
            details={"pool_radius": pool_radius, "type": type(pool_radius).__name__},
        )
    # All boolean switches are checked strictly. This matters most for
    # trust_remote_code: a truthy non-bool such as the string "False" must
    # not be forwarded to the loader, where it could enable remote code.
    for flag_name, flag_value in (
        ("normalize", normalize),
        ("local_files_only", local_files_only),
        ("trust_remote_code", trust_remote_code),
    ):
        if not isinstance(flag_value, bool):
            raise InputError(
                f"{flag_name} must be bool",
                details={"type": type(flag_value).__name__},
            )
    if lora_config is not None:
        raise RuntimeSetupError(
            "Carbon LoRA adapters are not supported by CarbonStateEncoder yet",
            remediation="merge LoRA adapters before loading or track the Phase 2 adapter issue",
        )
    if (model is None) != (tokenizer is None):
        raise InputError(
            "model and tokenizer must be supplied together",
            details={"model": model is not None, "tokenizer": tokenizer is not None},
        )
    self.model_id = model_id
    self.revision = revision
    self.dtype = dtype
    self.state_layer = state_layer
    # cast() is a type-checker hint only; the value was validated above.
    self.pool_type = cast(_PoolType, pool_type)
    self.pool_radius = pool_radius
    self.normalize = normalize
    self.local_files_only = local_files_only
    self.trust_remote_code = trust_remote_code
    self.device = _resolve_device(device)
    self._encoder_hash = _coerce_encoder_hash(encoder_hash)
    # Width of the state vector; stays None unless discovered from the config.
    self._d_state: int | None = None
    # After the pairing check above, model and tokenizer are either both
    # present or both None here.
    if model is None or tokenizer is None:
        tokenizer, model = _load_transformers_components(
            model_id=model_id,
            revision=revision,
            dtype=dtype,
            local_files_only=local_files_only,
            trust_remote_code=trust_remote_code,
        )
    self.tokenizer = tokenizer
    self.model = model
    _eval_if_available(self.model)
    _move_module_to_device(self.model, self.device)
    # Record the state width only when the config exposes a positive,
    # non-bool integer hidden_size; otherwise _d_state remains None.
    config = getattr(self.model, "config", None)
    hidden_size = getattr(config, "hidden_size", None)
    if isinstance(hidden_size, int) and not isinstance(hidden_size, bool) and hidden_size > 0:
        self._d_state = hidden_size