Coverage report for src/lilbee/model_manager.py: 100% of 231 statements covered.
« prev ^ index » next — coverage.py v7.13.4, created at 2026-04-29 19:16 +0000
1"""Model lifecycle management across native GGUF and SDK-backed sources."""
3import json
4import logging
5import os
6import time
7from collections.abc import Callable
8from dataclasses import dataclass
9from enum import Enum
10from pathlib import Path
12import httpx
14from lilbee.config import DEFAULT_HTTP_TIMEOUT
15from lilbee.models import ModelTask
16from lilbee.providers.sdk_backend import (
17 PROVIDER_KEYS,
18 REMOTE_BACKEND_NAME,
19 detect_backend_name,
20)
21from lilbee.security import validate_path_within
23# lilbee.config imported lazily; see lilbee/catalog.py for the rationale.
25log = logging.getLogger(__name__)
28class ModelSource(Enum):
29 """Where a model is stored."""
31 NATIVE = "native" # lilbee's GGUF files in cfg.models_dir
32 REMOTE = "remote" # Models managed by a remote SDK-backed service
34 @classmethod
35 def parse(cls, value: str | None) -> "ModelSource | None":
36 """Parse a user-supplied source string. Empty or None means 'all'.
38 Raises ValueError on any other non-empty input so callers get a
39 consistent error type regardless of entry point (CLI, MCP, server).
40 """
41 if value is None or value == "":
42 return None
43 try:
44 return cls(value)
45 except ValueError as exc:
46 valid = ", ".join(s.value for s in cls)
47 raise ValueError(f"invalid source {value!r}; expected one of: {valid}") from exc
class ModelNotFoundError(RuntimeError):
    """A model ref could not be resolved against any source or the catalog.

    Derives from RuntimeError so pre-existing ``except RuntimeError``
    call sites continue to catch it unchanged.
    """
# How long (seconds) a memoized list_installed() result stays fresh before
# the registry walk / backend HTTP call is repeated; see
# ModelManager._installed_cache for the memoization itself.
_INSTALLED_CACHE_TTL_SECONDS = 30.0
class ModelManager:
    """Manages model lifecycle with distinct sources.

    Native models are GGUF files tracked by the registry under
    ``models_dir``; remote models live behind an SDK backend reached over
    HTTP at ``remote_base_url``.
    """

    def __init__(self, models_dir: Path, remote_base_url: str = "http://localhost:11434") -> None:
        """Create a manager over *models_dir* and the backend at *remote_base_url*.

        The trailing slash is stripped from *remote_base_url* so the URL
        composition below can safely append ``/api/...`` paths.
        """
        self._models_dir = models_dir
        self._remote_base_url = remote_base_url.rstrip("/")

        # Deferred import. NOTE(review): presumably avoids an import cycle
        # with lilbee.registry — confirm against that module's imports.
        from lilbee.registry import ModelRegistry

        self._registry = ModelRegistry(self._models_dir)
        # Memoize list_installed results to avoid walking the registry
        # filesystem and hitting the backend HTTP endpoint on every call.
        # The catalog filter path fires this per request. Time-based TTL
        # plus explicit invalidation on pull/remove keeps freshness.
        self._installed_cache: dict[ModelSource | None, tuple[float, list[str]]] = {}

    def list_installed(self, source: ModelSource | None = None) -> list[str]:
        """List installed model names. ``source=None`` lists all sources.

        Memoized with a ``_INSTALLED_CACHE_TTL_SECONDS`` TTL and
        invalidated eagerly by ``pull``/``remove``.
        """
        now = time.monotonic()
        cached = self._installed_cache.get(source)
        if cached is not None:
            cached_at, cached_result = cached
            # Serve from cache while within the TTL window.
            if now - cached_at < _INSTALLED_CACHE_TTL_SECONDS:
                return cached_result

        if source is None:
            # Union of both sources, deduplicated and sorted for stable output.
            native = set(self._list_native())
            remote = set(self._list_remote())
            result = sorted(native | remote)
        elif source is ModelSource.NATIVE:
            result = self._list_native()
        else:
            result = self._list_remote()

        self._installed_cache[source] = (now, result)
        return result

    def _invalidate_installed_cache(self) -> None:
        """Drop all cached list_installed results."""
        self._installed_cache.clear()

    def _list_native(self) -> list[str]:
        """List native models from the registry only."""
        return sorted(m.ref for m in self._registry.list_installed())

    def _list_remote(self) -> list[str]:
        """List models from the SDK backend via its HTTP API.

        Best-effort: HTTP errors are logged and an unreachable backend is
        reported as "no remote models" rather than raised to the caller.
        """
        url = f"{self._remote_base_url}/api/tags"
        try:
            resp = httpx.get(url, timeout=DEFAULT_HTTP_TIMEOUT)
            resp.raise_for_status()
            data = resp.json()
            return [m["name"] for m in data.get("models", [])]
        except httpx.HTTPStatusError as exc:
            log.warning("SDK backend HTTP error listing models: %s", exc)
            return []
        except (httpx.ConnectError, httpx.TimeoutException) as exc:
            # A down backend is routine, so keep it at debug level.
            log.debug("SDK backend not reachable: %s", exc)
            return []

    def is_installed(self, model: str, source: ModelSource | None = None) -> bool:
        """Check if model exists in specified source (``None`` = any source)."""
        if source is None:
            return self._is_native(model) or self._is_remote(model)
        if source is ModelSource.NATIVE:
            return self._is_native(model)
        return self._is_remote(model)

    def _is_native(self, model: str) -> bool:
        """True if *model* is in the registry, or is a file under models_dir."""
        if self._registry.is_installed(model):
            return True
        try:
            # Reject refs that would escape models_dir (e.g. "../x").
            validate_path_within(self._models_dir / model, self._models_dir)
        except ValueError:
            return False
        return (self._models_dir / model).is_file()

    def _is_remote(self, model: str) -> bool:
        # Goes through list_installed so the TTL cache is reused.
        return model in self.list_installed(ModelSource.REMOTE)

    def get_source(self, model: str) -> ModelSource | None:
        """Find which source a model lives in. Native takes precedence."""
        if self._is_native(model):
            return ModelSource.NATIVE
        if self._is_remote(model):
            return ModelSource.REMOTE
        return None

    def pull(
        self,
        model: str,
        source: ModelSource,
        *,
        on_progress: Callable[[dict], None] | None = None,
        on_bytes: Callable[[int, int], None] | None = None,
    ) -> Path | None:
        """Pull/download model to specified source.

        Returns the Path for native downloads, None for backend-managed pulls.

        *on_progress* receives dict events from the SDK backend.
        *on_bytes* receives (downloaded_bytes, total_bytes) from native
        HuggingFace downloads. The two sources report progress in different
        shapes, so callers pass whichever matches the chosen source.
        """
        try:
            if source is ModelSource.NATIVE:
                return self._pull_native(model, on_bytes=on_bytes)
            self._pull_remote(model, on_progress=on_progress)
            return None
        finally:
            # Invalidate even on failure: a partial pull may have changed state.
            self._invalidate_installed_cache()

    def _pull_native(
        self,
        model: str,
        *,
        on_bytes: Callable[[int, int], None] | None = None,
    ) -> Path:
        """Download a featured or ad-hoc HuggingFace model to the native GGUF directory.

        Raises ModelNotFoundError when *model* resolves to no pull target.
        """
        from lilbee.catalog import download_model, resolve_pull_target

        entry = resolve_pull_target(model)
        if entry is None:
            raise ModelNotFoundError(
                f"Model '{model}' not recognized. "
                "Pass a HuggingFace repo id (owner/name) or a featured model name."
            )
        path = download_model(entry, on_progress=on_bytes)
        log.info("Downloaded %s to %s", model, path)
        return path

    def _pull_remote(
        self, model: str, *, on_progress: Callable[[dict], None] | None = None
    ) -> None:
        """Pull model via the SDK backend's HTTP API with streaming progress.

        Raises RuntimeError when the backend streams an error line or is
        unreachable.
        """
        url = f"{self._remote_base_url}/api/pull"
        try:
            with (
                # Model pulls stream progress over minutes; an overall
                # timeout would cut the download. Connect/write timeouts
                # still apply via httpx defaults when timeout=None.
                httpx.Client(timeout=None) as client,  # noqa: S113
                client.stream("POST", url, json={"name": model, "stream": True}) as resp,
            ):
                resp.raise_for_status()
                for line in resp.iter_lines():
                    if not line:
                        # Skip keep-alive blank lines between JSON events.
                        continue
                    data = json.loads(line)
                    if "error" in data:
                        raise RuntimeError(f"Failed to pull '{model}': {data['error']}")
                    if on_progress is not None:
                        on_progress(data)
        except httpx.ConnectError as exc:
            raise RuntimeError(
                f"Cannot connect to SDK backend: {exc}. Is the server running?"
            ) from exc
        log.info("Pulled %s via SDK backend", model)

    def remove(self, model: str, source: ModelSource | None = None) -> bool:
        """Remove installed model. Returns True if removed.

        With ``source=None`` both sources are attempted and True is
        returned if either succeeded. The installed-names cache is always
        invalidated, even when a removal raises.
        """
        try:
            if source is None:
                native_removed = self._remove_native(model)
                backend_removed = self._remove_remote(model)
                return native_removed or backend_removed
            if source is ModelSource.NATIVE:
                return self._remove_native(model)
            return self._remove_remote(model)
        finally:
            self._invalidate_installed_cache()

    def _remove_native(self, model: str) -> bool:
        """Remove a native model via the registry, else as a bare file on disk."""
        if self._registry.remove(model):
            log.info("Removed native model %s from registry", model)
            return True
        try:
            # Path traversal guard: refuse refs that escape models_dir.
            path = validate_path_within(self._models_dir / model, self._models_dir)
        except ValueError:
            log.warning("Path traversal blocked: %s escapes %s", model, self._models_dir)
            return False
        if path.is_file():
            path.unlink()
            log.info("Removed native model %s", model)
            return True
        return False

    def _remove_remote(self, model: str) -> bool:
        """Delete a backend-managed model; True on 200, False on 404 or other.

        Raises RuntimeError when the backend is unreachable.
        """
        url = f"{self._remote_base_url}/api/delete"
        try:
            # NOTE(review): DELETE is built via httpx.request, presumably
            # because the httpx.delete shortcut takes no body — confirm.
            resp = httpx.request(
                "DELETE",
                url,
                content=json.dumps({"model": model}).encode(),
                headers={"Content-Type": "application/json"},
                timeout=DEFAULT_HTTP_TIMEOUT,
            )
            if resp.status_code == 200:
                log.info("Removed backend model %s", model)
                return True
            if resp.status_code == 404:
                # Not installed remotely; not an error.
                return False
            log.warning("Unexpected status %d removing %s", resp.status_code, model)
            return False
        except httpx.ConnectError as exc:
            raise RuntimeError(
                f"Cannot connect to SDK backend: {exc}. Is the server running?"
            ) from exc
# Family tags reported by the backend's /api/tags details that mark a model
# as an embedding model; matched as substrings of the lowercased family.
_EMBEDDING_FAMILIES = frozenset({"bert", "nomic-bert", "e5", "bge"})
# Name substrings that mark a model as vision-capable.
_VISION_NAME_PATTERNS = frozenset({"llava", "vision", "moondream", "ocr", "minicpm-v"})
# Reranker detection runs BEFORE embedding detection so ``bge-reranker-*``
# (family "bge" but clearly a reranker) does not get misclassified as
# EMBEDDING. ``cross-encoder`` covers the SBERT-style naming convention
# for cross-encoder rerankers.
_RERANKER_NAME_PATTERNS = frozenset({"reranker", "rerank", "cross-encoder"})
@dataclass
class RemoteModel:
    """A model from the SDK backend with inferred task classification."""

    # Model name as reported by the backend's /api/tags endpoint.
    name: str
    task: str  # "chat", "embedding", "vision", "rerank"
    # Family string from /api/tags details (may be empty for API providers).
    family: str
    # Human-readable parameter count from /api/tags details (may be empty).
    parameter_size: str
    # Display name of the backend/provider this model came from.
    provider: str = REMOTE_BACKEND_NAME
def _classify_remote_task(name: str, family: str) -> str:
    """Classify a remote model as chat, embedding, vision, or rerank.

    Order matters: rerankers are detected first so a model such as
    ``bge-reranker-base`` (family ``bge``) is not swept into the embedding
    bucket by the family match. Then embedding by family tag, vision by
    name pattern, and finally chat as the default.
    """
    lowered_name = name.lower()
    lowered_family = family.lower()

    def matches(haystack: str, patterns: frozenset[str]) -> bool:
        # Substring match against any pattern in the set.
        return any(pattern in haystack for pattern in patterns)

    if matches(lowered_name, _RERANKER_NAME_PATTERNS):
        return ModelTask.RERANK
    if matches(lowered_family, _EMBEDDING_FAMILIES):
        return ModelTask.EMBEDDING
    if matches(lowered_name, _VISION_NAME_PATTERNS):
        return ModelTask.VISION
    return ModelTask.CHAT
# Default HTTP timeout (seconds) for the classification probe; kept short so
# read-only callers stay responsive when the backend is down.
_CLASSIFY_DEFAULT_TIMEOUT_S = 5.0
def classify_remote_models(
    base_url: str = "http://localhost:11434",
    *,
    timeout: float = _CLASSIFY_DEFAULT_TIMEOUT_S,
) -> list[RemoteModel]:
    """Discover and classify all models from the SDK backend by task.

    Uses /api/tags family metadata for embedding detection and name
    patterns for reranker and vision detection. Returns an empty list
    on any error (including timeout) so callers in read-only code paths
    can stay responsive when the backend is down.
    """
    try:
        resp = httpx.get(f"{base_url}/api/tags", timeout=timeout)
        resp.raise_for_status()
        raw_models = resp.json().get("models", [])
    except Exception:
        log.debug("Failed to classify remote models", exc_info=True)
        return []

    provider = detect_backend_name(base_url)
    result: list[RemoteModel] = []
    for model in raw_models:
        # ``or ""`` / ``or {}`` guards: dict.get defaults only apply to
        # *missing* keys. An explicit JSON null would otherwise flow into
        # ``.lower()`` in _classify_remote_task and raise outside the
        # try block above, defeating the best-effort contract.
        name = model.get("name") or ""
        details = model.get("details") or {}
        family = details.get("family") or ""
        param_size = details.get("parameter_size") or ""
        task = _classify_remote_task(name, family)
        result.append(
            RemoteModel(
                name=name, task=task, family=family, parameter_size=param_size, provider=provider
            )
        )
    return result
354def _has_provider_key(cfg_field: str, env_var: str) -> bool:
355 """Return True if a usable API key exists via env var or lilbee config."""
356 if os.environ.get(env_var):
357 return True
358 # circular: model_manager -> config via cfg; config.py's catalog
359 # imports pull model_manager in transitively, so we defer cfg access
360 # to call time to avoid a module-init cycle.
361 from lilbee.config import cfg
363 return bool(getattr(cfg, cfg_field, ""))
def discover_api_models() -> dict[str, list[RemoteModel]]:
    """Return frontier chat models grouped by provider.

    Checks which provider API keys are available (env vars or config),
    then asks the active provider's backend for each provider's chat
    catalog. Going through ``SdkLLMProvider.list_chat_models`` ensures
    ``cfg.json_mode`` suppression is applied before the SDK import,
    so ``lilbee --json status`` does not leak a debug banner.

    Short-circuits before touching the SDK when no keys are present.
    """
    # Keep only providers whose key is usable (spec: prov, cfg_field, env, label).
    active = [spec for spec in PROVIDER_KEYS if _has_provider_key(spec[1], spec[2])]
    if not active:
        return {}

    from lilbee.services import get_services

    provider = get_services().provider

    result: dict[str, list[RemoteModel]] = {}
    for prov, _cfg_field, _env_var, display_name in active:
        names = provider.list_chat_models(prov)
        chat_models = [
            RemoteModel(
                name=model_name,
                task=ModelTask.CHAT,
                family="",
                parameter_size="",
                provider=display_name,
            )
            for model_name in names
        ]
        # Providers with an empty chat catalog are omitted entirely.
        if chat_models:
            result[display_name] = chat_models
    return result
def detect_remote_embedding_models(base_url: str = "http://localhost:11434") -> list[str]:
    """Return names of models classified as embedding from the SDK backend."""
    classified = classify_remote_models(base_url)
    return [model.name for model in classified if model.task == ModelTask.EMBEDDING]
class _ManagerHolder:
    """Encapsulates the ModelManager singleton (no module-level mutable global)."""

    def __init__(self) -> None:
        self._instance: ModelManager | None = None

    def get(self) -> ModelManager:
        """Return the cached manager, constructing it from cfg on first use."""
        instance = self._instance
        if instance is None:
            from lilbee.config import cfg

            instance = ModelManager(cfg.models_dir, cfg.remote_base_url)
            self._instance = instance
        return instance

    def reset(self) -> None:
        """Forget the cached instance so the next get() rebuilds it."""
        self._instance = None
# Module-level holder backing the get/reset convenience functions below.
_holder = _ManagerHolder()
def get_model_manager() -> ModelManager:
    """Return the process-wide ModelManager, creating it on first call."""
    manager = _holder.get()
    return manager
def reset_model_manager() -> None:
    """Drop the cached singleton so tests can start from a clean slate."""
    _holder.reset()