Coverage for src / lilbee / providers / model_cache.py: 100%
165 statements
coverage.py v7.13.4, created at 2026-04-29 19:16 +0000
1"""Memory-aware LRU model cache for llama-cpp-python instances.
3Tracks loaded Llama models in an OrderedDict, evicting least-recently-used
4entries when memory is tight or keep-alive TTL expires. Thread-safe via Lock.
5"""
7from __future__ import annotations
9import logging
10import platform
11import threading
12import time
13from collections import OrderedDict
14from dataclasses import dataclass, field
15from pathlib import Path
16from typing import Any, Literal
18log = logging.getLogger(__name__)
20LoaderMode = Literal["chat", "embed", "rerank"]
21MODE_CHAT: LoaderMode = "chat"
22MODE_EMBED: LoaderMode = "embed"
23MODE_RERANK: LoaderMode = "rerank"
25# Fallback KV cache estimate when GGUF metadata can't be read.
26# 2048 bytes/token undershoots real KV size for modern models (Gemma3-4B is
27# ~640 KB/token f16) but is fine as a coarse pre-load eviction signal.
28_KV_BYTES_PER_CTX_TOKEN = 2048
30# Metal/CUDA buffer overhead as fraction of model weight memory
31_BUFFER_OVERHEAD_FRACTION = 0.10
33# Default context length for estimation when metadata unavailable
34_DEFAULT_CTX_LEN = 2048
36# Floor for the dynamic n_ctx computation (smaller is unusable for chat)
37_DYNAMIC_CTX_FLOOR = 512
39# Round dynamic n_ctx down to a multiple of this (clean batch sizes)
40_DYNAMIC_CTX_QUANTUM = 256
42# KV cache element size for f16 (bytes). Quantized KV reduces this.
43_KV_ELEM_BYTES_F16 = 2


@dataclass
class _CacheEntry:
    """A loaded model with metadata for eviction decisions."""

    model: Any
    path: Path
    mode: LoaderMode
    estimated_bytes: int
    loaded_at: float = field(default_factory=time.monotonic)
    last_used: float = field(default_factory=time.monotonic)

    def touch(self) -> None:
        """Update last-used timestamp."""
        self.last_used = time.monotonic()

    @property
    def embedding(self) -> bool:
        """True if the underlying Llama was opened with embedding=True."""
        return self.mode in (MODE_EMBED, MODE_RERANK)


def kv_bytes_per_token(meta: dict[str, str] | None, kv_elem_bytes: int = _KV_ELEM_BYTES_F16) -> int:
    """Estimate per-token KV cache size in bytes from GGUF metadata.

    Formula: 2 (K + V) * n_layers * n_kv_heads * head_dim * elem_bytes.
    Falls back to ``_KV_BYTES_PER_CTX_TOKEN`` when metadata is missing.
    """
    if not meta:
        return _KV_BYTES_PER_CTX_TOKEN
    try:
        n_layers = int(meta["block_count"])
        head_count_kv = int(meta.get("head_count_kv") or meta["head_count"])
        if "key_length" in meta and "value_length" in meta:
            kv_dim = int(meta["key_length"]) + int(meta["value_length"])
        else:
            embed = int(meta["embedding_length"])
            head_count = int(meta.get("head_count") or head_count_kv)
            head_dim = embed // head_count
            kv_dim = 2 * head_dim
    except (KeyError, ValueError, ZeroDivisionError):
        return _KV_BYTES_PER_CTX_TOKEN
    return n_layers * head_count_kv * kv_dim * kv_elem_bytes
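# Worked example (hypothetical 7B-class GQA config, not taken from any real
# GGUF): block_count=32, head_count_kv=8, head_dim=128, f16 KV ->
# 32 * 8 * (2 * 128) * 2 = 131,072 bytes, i.e. ~128 KiB per context token.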


def estimate_model_memory(
    model_path: Path,
    n_ctx: int = _DEFAULT_CTX_LEN,
    kv_bytes_per_tok: int = _KV_BYTES_PER_CTX_TOKEN,
) -> int:
    """Estimate memory consumption for a GGUF model.
    Approximation: file_size (weights) + KV cache + 10% buffer overhead.
    """
    file_bytes = model_path.stat().st_size if model_path.exists() else 0
    kv_bytes = n_ctx * kv_bytes_per_tok
    overhead = int(file_bytes * _BUFFER_OVERHEAD_FRACTION)
    return file_bytes + kv_bytes + overhead
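# Worked example (hypothetical file size): a 2 GiB GGUF estimated with the
# fallback kv_bytes_per_tok of 2048 and n_ctx=2048 comes out to
# 2 GiB (weights) + 4 MiB (KV) + ~205 MiB (10% overhead) ~= 2.2 GiB.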


def compute_dynamic_ctx(
    *,
    model_bytes: int,
    available_bytes: int,
    training_ctx: int,
    kv_bytes_per_tok: int,
    ceiling: int,
    floor: int = _DYNAMIC_CTX_FLOOR,
    quantum: int = _DYNAMIC_CTX_QUANTUM,
) -> int:
    """Pick the largest n_ctx that fits in available memory.

    Subtracts model weights and a 10% buffer overhead from ``available_bytes``,
    then divides the remainder by ``kv_bytes_per_tok``. Clamps to
    ``[floor, min(training_ctx, ceiling)]`` and rounds down to ``quantum``.
    """
    if kv_bytes_per_tok <= 0:
        return min(training_ctx, ceiling)
    overhead = int(model_bytes * _BUFFER_OVERHEAD_FRACTION)
    budget = available_bytes - model_bytes - overhead
    if budget <= 0:
        return floor
    raw_ctx = budget // kv_bytes_per_tok
    upper = min(training_ctx, ceiling)
    bounded = max(floor, min(raw_ctx, upper))
    quantized = (bounded // quantum) * quantum
    return max(floor, quantized)
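# Worked example (hypothetical sizes): model_bytes=4 GiB, available_bytes=8 GiB,
# kv_bytes_per_tok=655,360 (640 KiB), training_ctx=131072, ceiling=32768 ->
# budget ~= 8 - 4 - 0.4 = 3.6 GiB, raw_ctx = budget // 655,360 ~= 5898,
# clamped to [512, 32768] and rounded down to a multiple of 256 -> n_ctx = 5888.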


def get_available_memory(fraction: float) -> int:
    """Return usable GPU/unified memory in bytes, scaled by *fraction*.
    - macOS (Apple Silicon): unified memory via psutil
    - Linux/Windows with NVIDIA GPU: pynvml -> nvidia-smi -> psutil fallback
    - Other: psutil system memory
    """
    import psutil

    system = platform.system()

    if system == "Darwin":
        total = psutil.virtual_memory().total
        return int(total * fraction)

    if system in ("Linux", "Windows"):
        nvidia_mem = _try_nvidia_memory()
        if nvidia_mem is not None:
            return int(nvidia_mem * fraction)

    total = psutil.virtual_memory().total
    return int(total * fraction)


def _try_nvidia_memory() -> int | None:
    """Try to get NVIDIA GPU total memory via pynvml, then nvidia-smi."""
    try:
        import pynvml  # type: ignore[import-untyped]

        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        pynvml.nvmlShutdown()
        return int(info.total)
    except Exception:  # noqa: S110 -- optional GPU detect; absence is expected on non-NVIDIA hosts
        pass

    try:
        import subprocess

        # nvidia-smi ships with the NVIDIA driver and is always on PATH when
        # present; fully-qualifying it would break on every install layout.
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],  # noqa: S607
            capture_output=True,
            text=True,
            timeout=5,
        )
        if result.returncode == 0:
            mib = int(result.stdout.strip().split("\n")[0])
            return mib * 1024 * 1024
    except Exception:  # noqa: S110 -- optional GPU detect; same rationale as above
        pass

    return None


class MemoryAwareModelCache:
    """LRU cache for Llama model instances with memory-aware eviction.
    Models are evicted when:
    - A new model won't fit in the memory budget (LRU evicted first)
    - A model's keep-alive TTL has expired (checked on load and via evict_stale)
    """

    def __init__(
        self,
        max_memory_fraction: float = 0.75,
        keep_alive_seconds: int = 300,
        loader: Any = None,
    ) -> None:
        self._fraction = max_memory_fraction
        self._keep_alive = keep_alive_seconds
        self._cache: OrderedDict[str, _CacheEntry] = OrderedDict()
        self._lock = threading.Lock()
        self._loader = loader

    def load_model(self, model_path: Path, mode: LoaderMode) -> Any:
        """Load or return a cached Llama model instance.

        Evicts stale entries first, then evicts LRU if memory is tight.
        The cache key includes the mode so chat, embed, and rerank loaders
        stay isolated (they construct Llama with different flags and must
        not be aliased).
        """
        key = f"{model_path}:{mode}"

        with self._lock:
            self._evict_stale_locked()

            if key in self._cache:
                entry = self._cache[key]
                entry.touch()
                self._cache.move_to_end(key)
                log.debug("Cache hit: %s (%s)", model_path.name, mode)
                return entry.model

            estimated = estimate_model_memory(model_path)
            available = get_available_memory(self._fraction)
            self._evict_for_space_locked(estimated, available)

            log.info(
                "Loading model %s in %s mode (est. %d MB, available %d MB)",
                model_path.name,
                mode,
                estimated // (1024 * 1024),
                available // (1024 * 1024),
            )
            model = self._loader(model_path, mode=mode)

            self._cache[key] = _CacheEntry(
                model=model,
                path=model_path,
                mode=mode,
                estimated_bytes=estimated,
            )
            return model
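    # Illustrative note (hypothetical path): loading "/models/foo.gguf" in chat
    # mode and again in embed mode yields two separate entries keyed
    # "/models/foo.gguf:chat" and "/models/foo.gguf:embed", so a chat caller is
    # never handed a Llama opened with embedding=True.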

    def _evict_stale_locked(self) -> int:
        """Remove models past keep_alive TTL. Must hold self._lock."""
        if self._keep_alive <= 0:
            return 0
        now = time.monotonic()
        stale_keys = [
            k for k, entry in self._cache.items() if (now - entry.last_used) > self._keep_alive
        ]
        for k in stale_keys:
            self._unload_entry(k)
        return len(stale_keys)

    def _evict_for_space_locked(self, needed: int, available: int) -> None:
        """Evict LRU entries until *needed* bytes fit within *available*."""
        current_usage = sum(e.estimated_bytes for e in self._cache.values())
        while self._cache and (current_usage + needed) > available:
            oldest_key = next(iter(self._cache))
            oldest = self._cache[oldest_key]
            current_usage -= oldest.estimated_bytes
            log.info("Evicting LRU model %s to free memory", oldest.path.name)
            self._unload_entry(oldest_key)

    def _unload_entry(self, key: str) -> None:
        """Remove and close a single cache entry. Must hold self._lock."""
        entry = self._cache.pop(key, None)
        if entry is not None:
            try:
                entry.model.close()
            except AttributeError:
                pass
            except Exception:
                log.debug("Error closing model %s", entry.path.name, exc_info=True)

    def evict_stale(self) -> int:
        """Remove models past keep_alive TTL. Returns count evicted."""
        with self._lock:
            return self._evict_stale_locked()

    def unload_all(self) -> None:
        """Clear entire cache, closing all models."""
        with self._lock:
            keys = list(self._cache.keys())
            for k in keys:
                self._unload_entry(k)

    def unload_path(self, model_path: Path) -> int:
        """Evict every cache entry for *model_path* (across all modes). Returns count."""
        with self._lock:
            keys = [k for k, entry in self._cache.items() if entry.path == model_path]
            for k in keys:
                self._unload_entry(k)
            return len(keys)

    def get_stats(self) -> dict[str, Any]:
        """Return cache statistics for monitoring."""
        with self._lock:
            entries = []
            for _key, entry in self._cache.items():
                entries.append(
                    {
                        "path": str(entry.path),
                        "mode": entry.mode,
                        "embedding": entry.embedding,
                        "estimated_mb": entry.estimated_bytes // (1024 * 1024),
                        "age_seconds": int(time.monotonic() - entry.loaded_at),
                        "idle_seconds": int(time.monotonic() - entry.last_used),
                    }
                )
            return {
                "loaded_models": len(self._cache),
                "total_estimated_mb": sum(e.estimated_bytes for e in self._cache.values())
                // (1024 * 1024),
                "keep_alive_seconds": self._keep_alive,
                "memory_fraction": self._fraction,
                "models": entries,
            }
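

# Illustrative usage sketch (not part of the original module): the lambda below
# stands in for the real llama-cpp loader callable and the model path is
# hypothetical; it only shows how loader, load_model, get_stats, and unload_all
# fit together.
if __name__ == "__main__":
    cache = MemoryAwareModelCache(
        max_memory_fraction=0.75,
        keep_alive_seconds=300,
        loader=lambda path, mode: object(),  # stand-in for a Llama factory
    )
    example = Path("/models/example.gguf")  # hypothetical path
    cache.load_model(example, MODE_CHAT)
    cache.load_model(example, MODE_EMBED)  # separate cache entry (different mode)
    print(cache.get_stats())
    cache.unload_all()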