Coverage for src/lilbee/model_manager.py: 100%

231 statements  

coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""Model lifecycle management across native GGUF and SDK-backed sources.""" 

2 

3import json 

4import logging 

5import os 

6import time 

7from collections.abc import Callable 

8from dataclasses import dataclass 

9from enum import Enum 

10from pathlib import Path 

11 

12import httpx 

13 

14from lilbee.config import DEFAULT_HTTP_TIMEOUT 

15from lilbee.models import ModelTask 

16from lilbee.providers.sdk_backend import ( 

17 PROVIDER_KEYS, 

18 REMOTE_BACKEND_NAME, 

19 detect_backend_name, 

20) 

21from lilbee.security import validate_path_within 

22 

23# lilbee.config imported lazily; see lilbee/catalog.py for the rationale. 

24 

25log = logging.getLogger(__name__) 

26 

27 

28class ModelSource(Enum): 

29 """Where a model is stored.""" 

30 

31 NATIVE = "native" # lilbee's GGUF files in cfg.models_dir 

32 REMOTE = "remote" # Models managed by a remote SDK-backed service 

33 

34 @classmethod 

35 def parse(cls, value: str | None) -> "ModelSource | None": 

36 """Parse a user-supplied source string. Empty or None means 'all'. 

37 

38 Raises ValueError on any other non-empty input so callers get a 

39 consistent error type regardless of entry point (CLI, MCP, server). 

40 """ 

41 if value is None or value == "": 

42 return None 

43 try: 

44 return cls(value) 

45 except ValueError as exc: 

46 valid = ", ".join(s.value for s in cls) 

47 raise ValueError(f"invalid source {value!r}; expected one of: {valid}") from exc 

48 

49 
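# --- Usage sketch (illustrative only; not part of model_manager.py) -------
# How an entry point (CLI flag, MCP argument, server query parameter) might
# normalize a user-supplied source string with ModelSource.parse().
# "normalize_source" and the SystemExit handling are assumptions made for
# the example; only ModelSource itself comes from this module.
from lilbee.model_manager import ModelSource


def normalize_source(raw: str | None) -> ModelSource | None:
    """Treat None or '' as 'all sources'; reject any other invalid value."""
    try:
        return ModelSource.parse(raw)
    except ValueError as exc:
        # Same error text whether the caller is the CLI, MCP, or the server,
        # e.g. "invalid source 'bogus'; expected one of: native, remote".
        raise SystemExit(f"error: {exc}") from exc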

class ModelNotFoundError(RuntimeError):
    """Raised when a model ref is not found in any source or the catalog.

    Subclasses RuntimeError so pre-existing `except RuntimeError` call
    sites still catch it.
    """


_INSTALLED_CACHE_TTL_SECONDS = 30.0


class ModelManager:
    """Manages model lifecycle with distinct sources."""

    def __init__(self, models_dir: Path, remote_base_url: str = "http://localhost:11434") -> None:
        self._models_dir = models_dir
        self._remote_base_url = remote_base_url.rstrip("/")

        from lilbee.registry import ModelRegistry

        self._registry = ModelRegistry(self._models_dir)
        # Memoize list_installed results to avoid walking the registry
        # filesystem and hitting the backend HTTP endpoint on every call.
        # The catalog filter path fires this per request. Time-based TTL
        # plus explicit invalidation on pull/remove keeps freshness.
        self._installed_cache: dict[ModelSource | None, tuple[float, list[str]]] = {}

    def list_installed(self, source: ModelSource | None = None) -> list[str]:
        """List installed model names. ``source=None`` lists all sources.

        Memoized with a ``_INSTALLED_CACHE_TTL_SECONDS`` TTL and
        invalidated eagerly by ``pull``/``remove``.
        """
        now = time.monotonic()
        cached = self._installed_cache.get(source)
        if cached is not None:
            cached_at, cached_result = cached
            if now - cached_at < _INSTALLED_CACHE_TTL_SECONDS:
                return cached_result

        if source is None:
            native = set(self._list_native())
            remote = set(self._list_remote())
            result = sorted(native | remote)
        elif source is ModelSource.NATIVE:
            result = self._list_native()
        else:
            result = self._list_remote()

        self._installed_cache[source] = (now, result)
        return result

    def _invalidate_installed_cache(self) -> None:
        """Drop all cached list_installed results."""
        self._installed_cache.clear()

    def _list_native(self) -> list[str]:
        """List native models from the registry only."""
        return sorted(m.ref for m in self._registry.list_installed())

    def _list_remote(self) -> list[str]:
        """List models from the SDK backend via its HTTP API."""
        url = f"{self._remote_base_url}/api/tags"
        try:
            resp = httpx.get(url, timeout=DEFAULT_HTTP_TIMEOUT)
            resp.raise_for_status()
            data = resp.json()
            return [m["name"] for m in data.get("models", [])]
        except httpx.HTTPStatusError as exc:
            log.warning("SDK backend HTTP error listing models: %s", exc)
            return []
        except (httpx.ConnectError, httpx.TimeoutException) as exc:
            log.debug("SDK backend not reachable: %s", exc)
            return []

    def is_installed(self, model: str, source: ModelSource | None = None) -> bool:
        """Check if model exists in specified source."""
        if source is None:
            return self._is_native(model) or self._is_remote(model)
        if source is ModelSource.NATIVE:
            return self._is_native(model)
        return self._is_remote(model)

    def _is_native(self, model: str) -> bool:
        if self._registry.is_installed(model):
            return True
        try:
            validate_path_within(self._models_dir / model, self._models_dir)
        except ValueError:
            return False
        return (self._models_dir / model).is_file()

    def _is_remote(self, model: str) -> bool:
        return model in self.list_installed(ModelSource.REMOTE)

    def get_source(self, model: str) -> ModelSource | None:
        """Find which source a model lives in. Native takes precedence."""
        if self._is_native(model):
            return ModelSource.NATIVE
        if self._is_remote(model):
            return ModelSource.REMOTE
        return None

    def pull(
        self,
        model: str,
        source: ModelSource,
        *,
        on_progress: Callable[[dict], None] | None = None,
        on_bytes: Callable[[int, int], None] | None = None,
    ) -> Path | None:
        """Pull/download model to specified source.

        Returns the Path for native downloads, None for backend-managed pulls.

        *on_progress* receives dict events from the SDK backend.
        *on_bytes* receives (downloaded_bytes, total_bytes) from native
        HuggingFace downloads. The two sources report progress in different
        shapes, so callers pass whichever matches the chosen source.
        """
        try:
            if source is ModelSource.NATIVE:
                return self._pull_native(model, on_bytes=on_bytes)
            self._pull_remote(model, on_progress=on_progress)
            return None
        finally:
            self._invalidate_installed_cache()

    def _pull_native(
        self,
        model: str,
        *,
        on_bytes: Callable[[int, int], None] | None = None,
    ) -> Path:
        """Download a featured or ad-hoc HuggingFace model to the native GGUF directory."""
        from lilbee.catalog import download_model, resolve_pull_target

        entry = resolve_pull_target(model)
        if entry is None:
            raise ModelNotFoundError(
                f"Model '{model}' not recognized. "
                "Pass a HuggingFace repo id (owner/name) or a featured model name."
            )
        path = download_model(entry, on_progress=on_bytes)
        log.info("Downloaded %s to %s", model, path)
        return path

    def _pull_remote(
        self, model: str, *, on_progress: Callable[[dict], None] | None = None
    ) -> None:
        """Pull model via the SDK backend's HTTP API with streaming progress."""
        url = f"{self._remote_base_url}/api/pull"
        try:
            with (
                # Model pulls stream progress over minutes; an overall
                # timeout would cut the download. Connect/write timeouts
                # still apply via httpx defaults when timeout=None.
                httpx.Client(timeout=None) as client,  # noqa: S113
                client.stream("POST", url, json={"name": model, "stream": True}) as resp,
            ):
                resp.raise_for_status()
                for line in resp.iter_lines():
                    if not line:
                        continue
                    data = json.loads(line)
                    if "error" in data:
                        raise RuntimeError(f"Failed to pull '{model}': {data['error']}")
                    if on_progress is not None:
                        on_progress(data)
        except httpx.ConnectError as exc:
            raise RuntimeError(
                f"Cannot connect to SDK backend: {exc}. Is the server running?"
            ) from exc
        log.info("Pulled %s via SDK backend", model)

    def remove(self, model: str, source: ModelSource | None = None) -> bool:
        """Remove installed model. Returns True if removed."""
        try:
            if source is None:
                native_removed = self._remove_native(model)
                backend_removed = self._remove_remote(model)
                return native_removed or backend_removed
            if source is ModelSource.NATIVE:
                return self._remove_native(model)
            return self._remove_remote(model)
        finally:
            self._invalidate_installed_cache()

    def _remove_native(self, model: str) -> bool:
        if self._registry.remove(model):
            log.info("Removed native model %s from registry", model)
            return True
        try:
            path = validate_path_within(self._models_dir / model, self._models_dir)
        except ValueError:
            log.warning("Path traversal blocked: %s escapes %s", model, self._models_dir)
            return False
        if path.is_file():
            path.unlink()
            log.info("Removed native model %s", model)
            return True
        return False

    def _remove_remote(self, model: str) -> bool:
        url = f"{self._remote_base_url}/api/delete"
        try:
            resp = httpx.request(
                "DELETE",
                url,
                content=json.dumps({"model": model}).encode(),
                headers={"Content-Type": "application/json"},
                timeout=DEFAULT_HTTP_TIMEOUT,
            )
            if resp.status_code == 200:
                log.info("Removed backend model %s", model)
                return True
            if resp.status_code == 404:
                return False
            log.warning("Unexpected status %d removing %s", resp.status_code, model)
            return False
        except httpx.ConnectError as exc:
            raise RuntimeError(
                f"Cannot connect to SDK backend: {exc}. Is the server running?"
            ) from exc
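# --- Usage sketch (illustrative only; not part of model_manager.py) -------
# The two pull paths report progress in different shapes: native HuggingFace
# downloads call on_bytes(downloaded, total), while backend pulls stream
# dict events to on_progress. The model names, the "status" key, and the
# print formatting below are assumptions for the example; the manager API
# is as defined above.
from lilbee.model_manager import ModelSource, get_model_manager


def on_bytes(downloaded: int, total: int) -> None:
    pct = 100 * downloaded / total if total else 0.0
    print(f"\rnative download: {pct:5.1f}%", end="", flush=True)


def on_progress(event: dict) -> None:
    # Backend progress events are passed through as-is; many carry a
    # human-readable "status" field, but that shape is backend-defined.
    print(event.get("status", event))


manager = get_model_manager()
manager.pull("example-model:latest", ModelSource.REMOTE, on_progress=on_progress)
gguf_path = manager.pull("example-org/example-model-GGUF", ModelSource.NATIVE, on_bytes=on_bytes)
print(f"\nnative file at {gguf_path}")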

_EMBEDDING_FAMILIES = frozenset({"bert", "nomic-bert", "e5", "bge"})
_VISION_NAME_PATTERNS = frozenset({"llava", "vision", "moondream", "ocr", "minicpm-v"})
# Reranker detection runs BEFORE embedding detection so ``bge-reranker-*``
# (family "bge" but clearly a reranker) does not get misclassified as
# EMBEDDING. ``cross-encoder`` covers the SBERT-style naming convention
# for cross-encoder rerankers.
_RERANKER_NAME_PATTERNS = frozenset({"reranker", "rerank", "cross-encoder"})


@dataclass
class RemoteModel:
    """A model from the SDK backend with inferred task classification."""

    name: str
    task: str  # "chat", "embedding", "vision", "rerank"
    family: str
    parameter_size: str
    provider: str = REMOTE_BACKEND_NAME


def _classify_remote_task(name: str, family: str) -> str:
    """Classify a remote model as chat, embedding, vision, or rerank.

    Reranker detection runs first so ``bge-reranker-base`` (family
    ``bge``) does not get dragged into the embedding bucket by the
    family check. After reranker: embedding by family tag, vision by
    name pattern, else chat.
    """
    name_lower = name.lower()
    if any(rp in name_lower for rp in _RERANKER_NAME_PATTERNS):
        return ModelTask.RERANK
    family_lower = family.lower()
    if any(ef in family_lower for ef in _EMBEDDING_FAMILIES):
        return ModelTask.EMBEDDING
    if any(vp in name_lower for vp in _VISION_NAME_PATTERNS):
        return ModelTask.VISION
    return ModelTask.CHAT
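# --- Classification sketch (illustrative only; not part of model_manager.py)
# Why the reranker check runs before the family check: "bge-reranker-base"
# reports family "bge", which would otherwise match _EMBEDDING_FAMILIES and
# land in the embedding bucket. _classify_remote_task is module-private, so
# calling it directly here is for illustration only; the family strings are
# example inputs.
from lilbee.model_manager import _classify_remote_task
from lilbee.models import ModelTask

assert _classify_remote_task("bge-reranker-base", "bge") == ModelTask.RERANK
assert _classify_remote_task("nomic-embed-text", "nomic-bert") == ModelTask.EMBEDDING
assert _classify_remote_task("llava:13b", "llama") == ModelTask.VISION
assert _classify_remote_task("llama3.1:8b", "llama") == ModelTask.CHAT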

_CLASSIFY_DEFAULT_TIMEOUT_S = 5.0


def classify_remote_models(
    base_url: str = "http://localhost:11434",
    *,
    timeout: float = _CLASSIFY_DEFAULT_TIMEOUT_S,
) -> list[RemoteModel]:
    """Discover and classify all models from the SDK backend by task.

    Uses /api/tags family metadata for embedding detection and name
    patterns for reranker and vision detection. Returns an empty list
    on any error (including timeout) so callers in read-only code paths
    can stay responsive when the backend is down.
    """
    try:
        resp = httpx.get(f"{base_url}/api/tags", timeout=timeout)
        resp.raise_for_status()
        raw_models = resp.json().get("models", [])
    except Exception:
        log.debug("Failed to classify remote models", exc_info=True)
        return []

    provider = detect_backend_name(base_url)
    result: list[RemoteModel] = []
    for model in raw_models:
        name = model.get("name", "")
        details = model.get("details", {})
        family = details.get("family", "")
        param_size = details.get("parameter_size", "")
        task = _classify_remote_task(name, family)
        result.append(
            RemoteModel(
                name=name, task=task, family=family, parameter_size=param_size, provider=provider
            )
        )
    return result
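# --- Usage sketch (illustrative only; not part of model_manager.py) -------
# classify_remote_models() degrades to an empty list when the backend is
# unreachable, so a read-only caller can group models by task without extra
# error handling. The base URL, timeout, and grouping are example choices.
from collections import defaultdict

from lilbee.model_manager import classify_remote_models

by_task = defaultdict(list)
for remote in classify_remote_models("http://localhost:11434", timeout=2.0):
    by_task[remote.task].append(remote.name)
print(dict(by_task))  # empty dict when the backend is down or has no models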

def _has_provider_key(cfg_field: str, env_var: str) -> bool:
    """Return True if a usable API key exists via env var or lilbee config."""
    if os.environ.get(env_var):
        return True
    # circular: model_manager -> config via cfg; config.py's catalog
    # imports transitively pull in model_manager, so we defer cfg access
    # to call time to avoid a module-init cycle.
    from lilbee.config import cfg

    return bool(getattr(cfg, cfg_field, ""))


def discover_api_models() -> dict[str, list[RemoteModel]]:
    """Return frontier chat models grouped by provider.

    Checks which provider API keys are available (env vars or config),
    then asks the active provider's backend for each provider's chat
    catalog. Going through ``SdkLLMProvider.list_chat_models`` ensures
    ``cfg.json_mode`` suppression is applied before the SDK import,
    so ``lilbee --json status`` does not leak a debug banner.

    Short-circuits before touching the SDK when no keys are present.
    """
    active = [
        (prov, cfg_f, env, label)
        for prov, cfg_f, env, label in PROVIDER_KEYS
        if _has_provider_key(cfg_f, env)
    ]
    if not active:
        return {}

    from lilbee.services import get_services

    provider = get_services().provider

    result: dict[str, list[RemoteModel]] = {}
    for prov, _cfg_field, _env_var, display_name in active:
        chat_models = [
            RemoteModel(
                name=model_name,
                task=ModelTask.CHAT,
                family="",
                parameter_size="",
                provider=display_name,
            )
            for model_name in provider.list_chat_models(prov)
        ]
        if chat_models:
            result[display_name] = chat_models
    return result
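# --- Usage sketch (illustrative only; not part of model_manager.py) -------
# discover_api_models() returns {} when no provider API keys are configured,
# so a status-style command can list frontier chat models per provider
# without prompting for credentials. The printed layout is an assumption.
from lilbee.model_manager import discover_api_models

for provider_name, chat_models in discover_api_models().items():
    names = ", ".join(m.name for m in chat_models)
    print(f"{provider_name}: {names}")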

def detect_remote_embedding_models(base_url: str = "http://localhost:11434") -> list[str]:
    """Return names of models classified as embedding from the SDK backend."""
    return [m.name for m in classify_remote_models(base_url) if m.task == ModelTask.EMBEDDING]


class _ManagerHolder:
    """Encapsulates the ModelManager singleton (no module-level mutable global)."""

    def __init__(self) -> None:
        self._instance: ModelManager | None = None

    def get(self) -> ModelManager:
        if self._instance is None:
            from lilbee.config import cfg

            self._instance = ModelManager(cfg.models_dir, cfg.remote_base_url)
        return self._instance

    def reset(self) -> None:
        self._instance = None


_holder = _ManagerHolder()


def get_model_manager() -> ModelManager:
    """Get or create the singleton ModelManager."""
    return _holder.get()


def reset_model_manager() -> None:
    """Clear the singleton (for testing)."""
    _holder.reset()
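# --- Test sketch (illustrative only; not part of model_manager.py) --------
# reset_model_manager() exists so tests can rebuild the singleton between
# cases; constructing ModelManager directly over a temp directory avoids
# touching real config. The pytest fixture is an assumption about how a
# test suite might wire this up.
import pytest

from lilbee.model_manager import ModelManager, reset_model_manager


@pytest.fixture
def fresh_manager(tmp_path):
    """Yield an isolated manager and keep the module singleton clean around the test."""
    reset_model_manager()
    yield ModelManager(tmp_path, "http://localhost:11434")
    reset_model_manager()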