Coverage for src/lilbee/providers/model_cache.py: 100%

165 statements  

coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""Memory-aware LRU model cache for llama-cpp-python instances. 

2 

3Tracks loaded Llama models in an OrderedDict, evicting least-recently-used 

4entries when memory is tight or keep-alive TTL expires. Thread-safe via Lock. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import platform 

11import threading 

12import time 

13from collections import OrderedDict 

14from dataclasses import dataclass, field 

15from pathlib import Path 

16from typing import Any, Literal 

17 

18log = logging.getLogger(__name__) 

19 

20LoaderMode = Literal["chat", "embed", "rerank"] 

21MODE_CHAT: LoaderMode = "chat" 

22MODE_EMBED: LoaderMode = "embed" 

23MODE_RERANK: LoaderMode = "rerank" 

24 

25# Fallback KV cache estimate when GGUF metadata can't be read. 

26# 2048 bytes/token undershoots real KV size for modern models (Gemma3-4B is 

27# ~640 KB/token f16) but is fine as a coarse pre-load eviction signal. 

28_KV_BYTES_PER_CTX_TOKEN = 2048 

29 

30# Metal/CUDA buffer overhead as fraction of model weight memory 

31_BUFFER_OVERHEAD_FRACTION = 0.10 

32 

33# Default context length for estimation when metadata unavailable 

34_DEFAULT_CTX_LEN = 2048 

35 

36# Floor for the dynamic n_ctx computation (smaller is unusable for chat) 

37_DYNAMIC_CTX_FLOOR = 512 

38 

39# Round dynamic n_ctx down to a multiple of this (clean batch sizes) 

40_DYNAMIC_CTX_QUANTUM = 256 

41 

42# KV cache element size for f16 (bytes). Quantized KV reduces this. 

43_KV_ELEM_BYTES_F16 = 2 

44 

45 

@dataclass
class _CacheEntry:
    """A loaded model with metadata for eviction decisions."""

    model: Any
    path: Path
    mode: LoaderMode
    estimated_bytes: int
    loaded_at: float = field(default_factory=time.monotonic)
    last_used: float = field(default_factory=time.monotonic)

    def touch(self) -> None:
        """Update the last-used timestamp."""
        self.last_used = time.monotonic()

    @property
    def embedding(self) -> bool:
        """True if the underlying Llama was opened with embedding=True."""
        return self.mode in (MODE_EMBED, MODE_RERANK)


def kv_bytes_per_token(
    meta: dict[str, str] | None,
    kv_elem_bytes: int = _KV_ELEM_BYTES_F16,
) -> int:
    """Estimate per-token KV cache size in bytes from GGUF metadata.

    Formula: 2 (K + V) * n_layers * n_kv_heads * head_dim * elem_bytes.
    Falls back to ``_KV_BYTES_PER_CTX_TOKEN`` when metadata is missing.
    """
    if not meta:
        return _KV_BYTES_PER_CTX_TOKEN
    try:
        n_layers = int(meta["block_count"])
        head_count_kv = int(meta.get("head_count_kv") or meta["head_count"])
        if "key_length" in meta and "value_length" in meta:
            kv_dim = int(meta["key_length"]) + int(meta["value_length"])
        else:
            embed = int(meta["embedding_length"])
            head_count = int(meta.get("head_count") or head_count_kv)
            head_dim = embed // head_count
            kv_dim = 2 * head_dim
    except (KeyError, ValueError, ZeroDivisionError):
        return _KV_BYTES_PER_CTX_TOKEN
    return n_layers * head_count_kv * kv_dim * kv_elem_bytes

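# Worked example (illustrative only; the metadata values below are
# hypothetical, not read from any real GGUF file): a 32-layer model with
# 8 KV heads and 128-wide K/V projections at f16 costs
#     kv_bytes_per_token({
#         "block_count": "32", "head_count_kv": "8",
#         "key_length": "128", "value_length": "128",
#     }) == 32 * 8 * (128 + 128) * 2 == 131072 bytes (128 KiB) per token.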

def estimate_model_memory(
    model_path: Path,
    n_ctx: int = _DEFAULT_CTX_LEN,
    kv_bytes_per_tok: int = _KV_BYTES_PER_CTX_TOKEN,
) -> int:
    """Estimate memory consumption for a GGUF model.

    Approximation: file_size (weights) + KV cache + 10% buffer overhead.
    """
    file_bytes = model_path.stat().st_size if model_path.exists() else 0
    kv_bytes = n_ctx * kv_bytes_per_tok
    overhead = int(file_bytes * _BUFFER_OVERHEAD_FRACTION)
    return file_bytes + kv_bytes + overhead

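# Worked example (hypothetical 4 GiB model file, default arguments): weights
# 4 GiB + KV of 2048 tokens * 2048 B (4 MiB) + 10% overhead (~410 MiB) comes
# to roughly 4.4 GiB -- the estimate is dominated by the weight file, by design.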

def compute_dynamic_ctx(
    *,
    model_bytes: int,
    available_bytes: int,
    training_ctx: int,
    kv_bytes_per_tok: int,
    ceiling: int,
    floor: int = _DYNAMIC_CTX_FLOOR,
    quantum: int = _DYNAMIC_CTX_QUANTUM,
) -> int:
    """Pick the largest n_ctx that fits in available memory.

    Subtracts model weights and a 10% buffer overhead from ``available_bytes``,
    then divides the remainder by ``kv_bytes_per_tok``. Clamps to
    ``[floor, min(training_ctx, ceiling)]`` and rounds down to ``quantum``.
    """
    if kv_bytes_per_tok <= 0:
        return min(training_ctx, ceiling)
    overhead = int(model_bytes * _BUFFER_OVERHEAD_FRACTION)
    budget = available_bytes - model_bytes - overhead
    if budget <= 0:
        return floor
    raw_ctx = budget // kv_bytes_per_tok
    upper = min(training_ctx, ceiling)
    bounded = max(floor, min(raw_ctx, upper))
    quantized = (bounded // quantum) * quantum
    return max(floor, quantized)

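# Worked example (hypothetical numbers): a 4 GiB model against a 5 GiB budget
# with 131072 B/token KV, training_ctx=8192, ceiling=8192:
#     budget  = 5 GiB - 4 GiB - 0.4 GiB overhead ~= 644_245_095 bytes
#     raw_ctx = 644_245_095 // 131_072 == 4915
#     bounded to [512, 8192] -> 4915; rounded down to a 256 multiple -> 4864.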

def get_available_memory(fraction: float) -> int:
    """Return usable GPU/unified memory in bytes, scaled by *fraction*.

    - macOS (Apple Silicon): unified memory via psutil
    - Linux/Windows with an NVIDIA GPU: pynvml -> nvidia-smi -> psutil fallback
    - Other: psutil system memory
    """
    import psutil

    system = platform.system()

    if system == "Darwin":
        total = psutil.virtual_memory().total
        return int(total * fraction)

    if system in ("Linux", "Windows"):
        nvidia_mem = _try_nvidia_memory()
        if nvidia_mem is not None:
            return int(nvidia_mem * fraction)

    total = psutil.virtual_memory().total
    return int(total * fraction)

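# Example (assumed hardware): on a 16 GiB Apple Silicon machine,
# get_available_memory(0.75) reports 12 GiB -- the whole unified-memory pool
# scaled by the fraction, since macOS has no separate VRAM to query.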

def _try_nvidia_memory() -> int | None:
    """Try to get NVIDIA GPU total memory via pynvml, then nvidia-smi."""
    try:
        import pynvml  # type: ignore[import-untyped]

        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        pynvml.nvmlShutdown()
        return int(info.total)
    except Exception:  # noqa: S110 -- optional GPU detect; absence is expected on non-NVIDIA hosts
        pass

    try:
        import subprocess

        # nvidia-smi ships with the NVIDIA driver and is on PATH whenever it
        # is present; hard-coding an absolute path would break across install
        # layouts.
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],  # noqa: S607
            capture_output=True,
            text=True,
            timeout=5,
        )
        if result.returncode == 0:
            mib = int(result.stdout.strip().split("\n")[0])
            return mib * 1024 * 1024
    except Exception:  # noqa: S110 -- optional GPU detect; same rationale as above
        pass

    return None


class MemoryAwareModelCache:
    """LRU cache for Llama model instances with memory-aware eviction.

    Models are evicted when:
    - a new model won't fit in the memory budget (LRU entries evicted first), or
    - a model's keep-alive TTL has expired (checked on load and via evict_stale).
    """

    def __init__(
        self,
        max_memory_fraction: float = 0.75,
        keep_alive_seconds: int = 300,
        loader: Any = None,
    ) -> None:
        self._fraction = max_memory_fraction
        self._keep_alive = keep_alive_seconds
        self._cache: OrderedDict[str, _CacheEntry] = OrderedDict()
        self._lock = threading.Lock()
        self._loader = loader

    def load_model(self, model_path: Path, mode: LoaderMode) -> Any:
        """Load or return a cached Llama model instance.

        Evicts stale entries first, then evicts LRU entries if memory is
        tight. The cache key includes the mode so chat, embed, and rerank
        loaders stay isolated (they construct Llama with different flags and
        must not be aliased).
        """
        key = f"{model_path}:{mode}"

        with self._lock:
            self._evict_stale_locked()

            if key in self._cache:
                entry = self._cache[key]
                entry.touch()
                self._cache.move_to_end(key)
                log.debug("Cache hit: %s (%s)", model_path.name, mode)
                return entry.model

            estimated = estimate_model_memory(model_path)
            available = get_available_memory(self._fraction)
            self._evict_for_space_locked(estimated, available)

            log.info(
                "Loading model %s in %s mode (est. %d MB, available %d MB)",
                model_path.name,
                mode,
                estimated // (1024 * 1024),
                available // (1024 * 1024),
            )
            model = self._loader(model_path, mode=mode)

            self._cache[key] = _CacheEntry(
                model=model,
                path=model_path,
                mode=mode,
                estimated_bytes=estimated,
            )
            return model

    def _evict_stale_locked(self) -> int:
        """Remove models past the keep-alive TTL. Must hold self._lock."""
        if self._keep_alive <= 0:
            return 0
        now = time.monotonic()
        stale_keys = [
            k for k, entry in self._cache.items() if (now - entry.last_used) > self._keep_alive
        ]
        for k in stale_keys:
            self._unload_entry(k)
        return len(stale_keys)

    def _evict_for_space_locked(self, needed: int, available: int) -> None:
        """Evict LRU entries until *needed* bytes fit within *available*."""
        current_usage = sum(e.estimated_bytes for e in self._cache.values())
        while self._cache and (current_usage + needed) > available:
            oldest_key = next(iter(self._cache))
            oldest = self._cache[oldest_key]
            current_usage -= oldest.estimated_bytes
            log.info("Evicting LRU model %s to free memory", oldest.path.name)
            self._unload_entry(oldest_key)

    def _unload_entry(self, key: str) -> None:
        """Remove and close a single cache entry. Must hold self._lock."""
        entry = self._cache.pop(key, None)
        if entry is not None:
            try:
                entry.model.close()
            except AttributeError:
                pass
            except Exception:
                log.debug("Error closing model %s", entry.path.name, exc_info=True)

    def evict_stale(self) -> int:
        """Remove models past the keep-alive TTL. Returns the count evicted."""
        with self._lock:
            return self._evict_stale_locked()

    def unload_all(self) -> None:
        """Clear the entire cache, closing all models."""
        with self._lock:
            keys = list(self._cache.keys())
            for k in keys:
                self._unload_entry(k)

    def unload_path(self, model_path: Path) -> int:
        """Evict every cache entry for *model_path* (across all modes). Returns the count."""
        with self._lock:
            keys = [k for k, entry in self._cache.items() if entry.path == model_path]
            for k in keys:
                self._unload_entry(k)
            return len(keys)

    def get_stats(self) -> dict[str, Any]:
        """Return cache statistics for monitoring."""
        with self._lock:
            entries = []
            for _key, entry in self._cache.items():
                entries.append(
                    {
                        "path": str(entry.path),
                        "mode": entry.mode,
                        "embedding": entry.embedding,
                        "estimated_mb": entry.estimated_bytes // (1024 * 1024),
                        "age_seconds": int(time.monotonic() - entry.loaded_at),
                        "idle_seconds": int(time.monotonic() - entry.last_used),
                    }
                )
            return {
                "loaded_models": len(self._cache),
                "total_estimated_mb": sum(e.estimated_bytes for e in self._cache.values())
                // (1024 * 1024),
                "keep_alive_seconds": self._keep_alive,
                "memory_fraction": self._fraction,
                "models": entries,
            }
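

# --- Usage sketch (illustrative; not part of the covered module) ------------
# A minimal smoke test with a stub loader standing in for the real
# llama-cpp-python factory. FakeModel and fake_loader are hypothetical names
# invented for this sketch; it assumes psutil is installed, since
# load_model() calls get_available_memory().
if __name__ == "__main__":

    class FakeModel:
        def close(self) -> None:
            pass

    def fake_loader(path: Path, *, mode: LoaderMode) -> FakeModel:
        return FakeModel()

    cache = MemoryAwareModelCache(
        max_memory_fraction=0.5,
        keep_alive_seconds=60,
        loader=fake_loader,
    )
    first = cache.load_model(Path("model.gguf"), MODE_CHAT)
    second = cache.load_model(Path("model.gguf"), MODE_CHAT)  # cache hit
    assert first is second
    print(cache.get_stats())
    cache.unload_all()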