Coverage for src / lilbee / model_defaults.py: 100%
63 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-29 19:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-29 19:16 +0000
1"""Per-model default generation settings.
3Parses and caches generation parameters from key-value parameter text
4or GGUF file metadata so that model-specific defaults (temperature, num_ctx,
5etc.) are applied automatically when switching models.
6"""
8from __future__ import annotations
10import contextlib
11import logging
12from dataclasses import dataclass, fields
13from typing import Any
# Module-level logger; used below to report parameters that fail to parse.
log = logging.getLogger(__name__)
# Parameter keys we recognise and their target types.
# Each value is called as a constructor on the raw string when parsing.
_KNOWN_PARAM_TYPES: dict[str, type] = {
    "temperature": float,
    "top_p": float,
    "top_k": int,
    "repeat_penalty": float,
    "num_ctx": int,
    "max_tokens": int,
}
# GGUF metadata keys mapped to ModelDefaults field names.
_GGUF_KEY_MAP: dict[str, str] = {
    "general.temperature": "temperature",
    "general.top_p": "top_p",
    "general.top_k": "top_k",
    "general.repeat_penalty": "repeat_penalty",
}
@dataclass(frozen=True)
class ModelDefaults:
    """Frozen snapshot of a model's default generation parameters.

    Every field defaults to ``None``, so an instance built with no arguments
    carries no values at all. Instances are immutable (``frozen=True``):
    assigning to a field after construction raises.
    """

    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    repeat_penalty: float | None = None
    num_ctx: int | None = None
    max_tokens: int | None = None
class _DefaultsCache:
    """Encapsulates the per-model defaults cache (no module-level mutable global)."""

    def __init__(self) -> None:
        # Maps a model name to the defaults stored for it.
        self._data: dict[str, ModelDefaults] = {}

    def get(self, model_name: str) -> ModelDefaults | None:
        """Return the defaults stored for *model_name*, or ``None`` if absent."""
        return self._data.get(model_name)

    def set(self, model_name: str, defaults: ModelDefaults) -> None:
        """Store *defaults* for *model_name*, replacing any previous entry."""
        self._data[model_name] = defaults

    def clear(self) -> None:
        """Remove every stored entry."""
        self._data.clear()
# Single shared cache instance backing the module-level functions below.
_defaults_cache = _DefaultsCache()

# Public API — preserves existing call sites.
get_defaults = _defaults_cache.get
set_defaults = _defaults_cache.set
clear_cache = _defaults_cache.clear
def parse_kv_parameters(text: str) -> ModelDefaults:
    """Parse multiline ``key value`` parameter format.

    Example input::

        temperature 0.7
        top_p 0.9
        num_ctx 4096
        stop <|im_end|>

    Unknown keys (like ``stop``) are silently skipped, as are blank lines and
    lines that do not split into a key and a value. A value that cannot be
    converted to the key's target type is logged at debug level and skipped.
    If a key appears more than once, the last parseable value wins.

    Args:
        text: Parameter text, one ``key value`` pair per line.

    Returns:
        A ``ModelDefaults`` holding every recognised, parseable parameter;
        fields with no parsed value keep their default.
    """
    values: dict[str, Any] = {}
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        # Split on the first run of whitespace only: key, then the rest.
        parts = line.split(None, 1)
        if len(parts) != 2:
            continue
        key, raw_value = parts
        if key not in _KNOWN_PARAM_TYPES:
            continue
        try:
            values[key] = _KNOWN_PARAM_TYPES[key](raw_value)
        except (ValueError, TypeError):
            log.debug("Skipping unparseable param %s=%r", key, raw_value)
    return ModelDefaults(**values)
def read_gguf_defaults(metadata: dict[str, str]) -> ModelDefaults:
    """Extract generation defaults from a GGUF metadata dict.

    Reads the keys listed in ``_GGUF_KEY_MAP`` (such as
    ``general.temperature``) plus a plain ``context_length`` key, which is
    stored as ``num_ctx``. The caller is expected to have already resolved
    the architecture-prefixed context-length key into ``context_length``.

    Args:
        metadata: GGUF metadata keyed by metadata key name.

    Returns:
        A ``ModelDefaults`` holding every value that was present and could
        be converted; fields with no converted value keep their default.
    """
    values: dict[str, Any] = {}
    for gguf_key, field_name in _GGUF_KEY_MAP.items():
        if gguf_key in metadata:
            try:
                target_type = _field_type(field_name)
                values[field_name] = target_type(metadata[gguf_key])
            except (ValueError, TypeError):
                log.debug("Skipping unparseable GGUF key %s=%r", gguf_key, metadata[gguf_key])
    if "context_length" in metadata:
        # Best-effort: unlike the mapped keys above, a bad context_length is
        # dropped without a log message.
        with contextlib.suppress(ValueError, TypeError):
            values["num_ctx"] = int(metadata["context_length"])
    return ModelDefaults(**values)
def _field_type(field_name: str) -> type:
    """Return the base type for a ModelDefaults field.

    Args:
        field_name: Name of a ``ModelDefaults`` field.

    Returns:
        ``int`` when the field's annotation text contains ``"int"``,
        otherwise ``float``. A name that matches no field also yields
        ``float``.
    """
    for f in fields(ModelDefaults):
        if f.name == field_name:
            # NOTE: this is a substring match on the annotation's string form,
            # so it only distinguishes int-like from float-like annotations.
            return int if "int" in str(f.type) else float
    return float  # pragma: no cover