Coverage for src / lilbee / providers / llama_cpp_provider.py: 100%
524 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-29 19:16 +0000
1"""Llama.cpp provider for local GGUF inference.
3Includes a thread-safe batching queue for embeddings so that concurrent
4ingest threads don't hit the non-thread-safe Llama object simultaneously.
5When subprocess_embed is enabled, embedding and vision calls are delegated
6to a persistent child process to avoid GIL contention.
7"""
9from __future__ import annotations
11import logging
12import os
13import queue
14import threading
15import time
16from collections.abc import Callable
17from concurrent.futures import Future
18from dataclasses import dataclass
19from pathlib import Path
20from typing import TYPE_CHECKING, Any
22from gguf import GGUFReader, GGUFValueType
24from lilbee.catalog import is_rerank_ref
25from lilbee.config import DEFAULT_NUM_CTX, KV_CACHE_TYPE_BYTES, KvCacheType, cfg
26from lilbee.providers.base import ClosableIterator, LLMProvider, ProviderError, filter_options
27from lilbee.providers.model_cache import (
28 MODE_CHAT,
29 MODE_EMBED,
30 MODE_RERANK,
31 LoaderMode,
32 compute_dynamic_ctx,
33 get_available_memory,
34 kv_bytes_per_token,
35)
36from lilbee.services import get_services
38if TYPE_CHECKING:
39 from lilbee.providers.worker_process import WorkerProcess
log = logging.getLogger(__name__)
# Dedicated logger for messages forwarded from the native llama.cpp library.
_llama_log = logging.getLogger("lilbee.llama_cpp")
# ggml.h log levels (not exposed by llama-cpp-python).
_GGML_LOG_LEVEL_INFO = 1
_GGML_LOG_LEVEL_WARN = 2
_GGML_LOG_LEVEL_ERROR = 3
_GGML_LOG_LEVEL_DEBUG = 4
_GGML_LOG_LEVEL_CONT = 5
# WARN demotes to INFO so noisy auto-corrections stay silent at the default WARNING level.
# CONT is deliberately absent: continuation chunks are coalesced in
# _llama_log_dispatch before a level is ever resolved.
_GGML_TO_PY_LEVEL = {
    _GGML_LOG_LEVEL_INFO: logging.DEBUG,
    _GGML_LOG_LEVEL_WARN: logging.INFO,
    _GGML_LOG_LEVEL_ERROR: logging.ERROR,
    _GGML_LOG_LEVEL_DEBUG: logging.DEBUG,
}
# Substrings llama.cpp emits at GGML_LOG_LEVEL_ERROR but which are
# advisory: the model still loads correctly. Demoted to WARNING so users
# don't think their setup is broken.
_GGML_ERROR_SOFT_DEMOTE = (
    "special_eos_id is not in special_eog_ids",
    "embeddings required but some input tokens were not marked as outputs",
    "n_ctx_seq",  # 'n_ctx_seq (X) > n_ctx_train (Y)' -- our embed clamp prevents this
    "tokenizer config may be incorrect",
)
_BATCH_WINDOW_S = 0.01  # 10ms, collect concurrent requests before dispatching
# Cap on tokens consumed during _LockedStreamIterator.close()'s drain.
# A runaway model (e.g. Qwen3-0.6B in a never-closing <think> loop)
# would otherwise block close() indefinitely.
_LOCKED_STREAM_DRAIN_CAP = 1024
_EMBED_FUTURE_TIMEOUT_S = 300.0  # Safety net: max wait for embed result
_RERANK_FUTURE_TIMEOUT_S = 300.0  # Safety net: max wait for rerank result
# Chat-load OOM retry knobs.
_MAX_OOM_RETRIES = 2
_CTX_QUANTUM = 256  # retried n_ctx values are rounded down to this multiple
_CTX_FLOOR = 512  # never retry below this context size
# Sentinel passed to llama-cpp-python for "offload all layers".
_N_GPU_LAYERS_AUTO = -1
# Settings baked into Llama() at load time, or whose change picks a
# different model file. Sampling params are read per-call and excluded.
LOAD_AFFECTING_KEYS = frozenset(
    {
        "num_ctx",
        "chat_model",
        "embedding_model",
        "vision_model",
        "reranker_model",
    }
)
@dataclass
class _EmbedRequest:
    """A single embedding request submitted to the batch queue."""

    texts: list[str]  # input strings to embed; one vector is produced per string
    future: Future[list[list[float]]]  # resolved by the embed worker with the vectors
@dataclass
class _RerankRequest:
    """A single rerank request submitted to the batch queue."""

    query: str  # the query to score candidates against
    candidates: list[str]  # candidate documents, scored in order
    future: Future[list[float]]  # resolved by the rerank worker with one score per candidate
class LlamaCppProvider(LLMProvider):
    """Provider backed by llama-cpp-python for local GGUF model inference.

    Embedding calls are funnelled through a single background worker thread
    that batches concurrent requests into one ``create_embedding`` call.
    Chat calls are serialized via a lock (no batching possible).
    Vision models are loaded with a CLIP chat handler for image understanding.
    """

    def __init__(self) -> None:
        from lilbee.providers.model_cache import MemoryAwareModelCache

        self._cache = MemoryAwareModelCache(
            max_memory_fraction=cfg.gpu_memory_fraction,
            keep_alive_seconds=cfg.model_keep_alive,
            loader=load_llama,
        )
        # ``None`` on either queue is the shutdown sentinel for its worker.
        self._embed_queue: queue.Queue[_EmbedRequest | None] = queue.Queue()
        self._rerank_queue: queue.Queue[_RerankRequest | None] = queue.Queue()
        self._chat_lock = threading.Lock()
        self._embed_thread = threading.Thread(target=self._embed_worker, daemon=True)
        self._embed_thread.start()
        self._rerank_thread = threading.Thread(target=self._rerank_worker, daemon=True)
        self._rerank_thread.start()
        self._subprocess_worker: WorkerProcess | None = None
        self._subprocess_enabled = cfg.subprocess_embed

    def _embed_worker(self) -> None:
        """Background thread: drain queue, batch, inference, dispatch results."""
        while True:
            first = self._embed_queue.get()
            if first is None:
                break
            batch: list[_EmbedRequest] = [first]
            shutting_down = False
            # Collect whatever else arrives within the batching window so
            # concurrent ingest threads share a single dispatch.
            deadline = time.monotonic() + _BATCH_WINDOW_S
            while time.monotonic() < deadline:
                try:
                    req = self._embed_queue.get_nowait()
                    if req is None:
                        shutting_down = True
                        break
                    batch.append(req)
                except queue.Empty:
                    time.sleep(0.001)
                    continue
            self._dispatch_batch(batch)
            if shutting_down:
                break

    def _dispatch_batch(self, batch: list[_EmbedRequest]) -> None:
        """Serialize embedding requests and resolve all futures.

        Embeds one text at a time because some model architectures (e.g.
        nomic-bert) fail with llama_decode -1 on multi-text batches.
        """
        try:
            llm = self._get_embed_llm()
        except Exception as exc:
            # Model load failed: fail every queued request, not just the first.
            for req in batch:
                if not req.future.done():
                    req.future.set_exception(exc)
            return
        for req in batch:
            try:
                vectors: list[list[float]] = []
                for text in req.texts:
                    response = embed_one(llm, text)
                    vectors.append(response)
                req.future.set_result(vectors)
            except Exception as exc:
                if not req.future.done():
                    req.future.set_exception(exc)

    def _rerank_worker(self) -> None:
        """Background thread: drain rerank queue, serialize through the model.

        The queue is unbounded; back-pressure comes from callers awaiting
        their futures synchronously.
        """
        while True:
            req = self._rerank_queue.get()
            if req is None:
                break
            self._dispatch_rerank(req)

    def _dispatch_rerank(self, req: _RerankRequest) -> None:
        """Run a single rerank request and resolve its future."""
        try:
            llm = self._get_rerank_llm()
        except Exception as exc:
            if not req.future.done():
                req.future.set_exception(exc)
            return
        try:
            scores = compute_rerank_scores(llm, req.query, req.candidates)
            req.future.set_result(scores)
        except Exception as exc:
            if not req.future.done():
                req.future.set_exception(exc)

    def _get_chat_llm(self, model: str | None = None) -> Any:
        """Load or return a cached Llama instance for chat.

        Vision OCR has its own entry point (``vision_ocr``); the chat path
        never substitutes a vision model, even if the chat pick is multimodal.
        """
        resolved_model = model or cfg.chat_model
        model_path = resolve_model_path(resolved_model)
        return self._cache.load_model(model_path, mode=MODE_CHAT)

    def _get_embed_llm(self) -> Any:
        """Load or return a cached Llama instance for embeddings."""
        model_path = resolve_model_path(cfg.embedding_model)
        return self._cache.load_model(model_path, mode=MODE_EMBED)

    def _get_rerank_llm(self) -> Any:
        """Load or return a cached Llama instance for reranking.

        Raises:
            ProviderError: when no reranker model is configured.
        """
        model_name = cfg.reranker_model
        if not model_name:
            raise ProviderError(
                "No reranker model configured. Set cfg.reranker_model first.",
                provider="llama-cpp",
            )
        model_path = resolve_model_path(model_name)
        return self._cache.load_model(model_path, mode=MODE_RERANK)

    def _get_subprocess_worker(self) -> WorkerProcess:
        """Lazy-create and return the subprocess worker."""
        if self._subprocess_worker is None:
            from lilbee.providers.worker_process import WorkerProcess as WP  # noqa: N817

            self._subprocess_worker = WP()
        return self._subprocess_worker

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed texts. Delegates to subprocess worker if enabled, with fallback."""
        if self._subprocess_enabled:
            try:
                return self._get_subprocess_worker().embed(texts)
            except (OSError, RuntimeError) as exc:
                # Permanently disable the subprocess path for this provider
                # instance; subsequent calls go straight in-process.
                log.warning("Subprocess embed failed, falling back to in-process: %s", exc)
                self._subprocess_enabled = False
        fut: Future[list[list[float]]] = Future()
        self._embed_queue.put(_EmbedRequest(texts=texts, future=fut))
        return fut.result(timeout=_EMBED_FUTURE_TIMEOUT_S)

    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Score *candidates* by relevance to *query*, queued through a single worker."""
        if not candidates:
            return []
        fut: Future[list[float]] = Future()
        self._rerank_queue.put(_RerankRequest(query=query, candidates=candidates, future=fut))
        return fut.result(timeout=_RERANK_FUTURE_TIMEOUT_S)

    def supports_rerank(self) -> bool:
        """llama-cpp can rerank iff llama-cpp-python exposes the rank pooling type."""
        return _llama_cpp_has_rank_pooling()

    def vision_ocr(self, png_bytes: bytes, model: str, prompt: str = "") -> str:
        """Run vision OCR via the subprocess worker."""
        return self._get_subprocess_worker().vision_ocr(png_bytes, model, prompt)

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        """Chat completion — serialized via lock (Llama is not thread-safe).

        For ``stream=True`` the returned ``_LockedStreamIterator`` takes
        ownership of the chat lock and releases it when iteration finishes
        or the iterator is closed.
        """
        lock_transferred = False
        self._chat_lock.acquire()
        try:
            llm = self._get_chat_llm(model)
            kwargs: dict[str, Any] = {}
            if options:
                filtered = filter_options(options)
                if "num_predict" in filtered:
                    filtered["max_tokens"] = filtered.pop("num_predict")
                filtered.pop("num_ctx", None)  # model-load param, not per-call
                kwargs.update(filtered)
            response = llm.create_chat_completion(messages=messages, stream=stream, **kwargs)
            if stream:
                iterator = _LockedStreamIterator(response, self._chat_lock)
                lock_transferred = True  # iterator now owns the lock
                return iterator
            result: str = response["choices"][0]["message"]["content"] or ""
            return result
        finally:
            # Release unless ownership actually transferred to the stream
            # iterator. Keying on `stream` alone would leak the lock (and
            # deadlock every later chat call) when model load or the
            # completion call raises before the iterator is constructed.
            if not lock_transferred:
                self._chat_lock.release()

    def list_models(self) -> list[str]:
        """List installed models from registry."""
        registry = get_services().registry
        return sorted(m.ref for m in registry.list_installed())

    def list_chat_models(self, provider: str) -> list[str]:
        """llama-cpp has no frontier-provider catalog; always ``[]``."""
        return []

    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Not supported directly — catalog.py handles downloads."""
        raise NotImplementedError(
            f"llama-cpp provider cannot pull model {model!r}. "
            "Download GGUF files manually or use the catalog."
        )

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Return model metadata from GGUF headers, or None if unresolvable."""
        try:
            path = resolve_model_path(model)
        except ProviderError:
            return None
        return read_gguf_metadata(path)

    def get_capabilities(self, model: str) -> list[str]:
        """Detect capabilities from local GGUF files.

        Rerank models return ``["rerank"]``; cross-encoder GGUFs cannot
        generate text. Other models report ``"completion"``, plus
        ``"vision"`` when an mmproj sidecar is present.
        """
        if _is_rerank_model(model):
            return ["rerank"]
        caps: list[str] = ["completion"]
        try:
            path = resolve_model_path(model)
        except ProviderError:
            log.debug("resolve_model_path failed for %s", model, exc_info=True)
            return caps
        try:
            find_mmproj_for_model(path)
            caps.append("vision")
        except ProviderError:
            log.debug("no mmproj for %s", model, exc_info=True)
        return caps

    def shutdown(self) -> None:
        """Stop workers and unload all cached models."""
        self._embed_queue.put(None)
        self._embed_thread.join(timeout=2)
        self._rerank_queue.put(None)
        self._rerank_thread.join(timeout=2)
        if self._subprocess_worker is not None:
            self._subprocess_worker.stop()
            self._subprocess_worker = None
        self._cache.unload_all()

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """Evict cached models so the next call reloads with current settings."""
        if model_path is None:
            self._cache.unload_all()
        else:
            self._cache.unload_path(model_path)
373class _LockedStreamIterator:
374 """Wraps a streaming response so the chat lock is held until iteration ends.
375 The lock must already be acquired by the caller; this iterator releases it
376 when the underlying stream is exhausted (or on explicit close).
377 """
379 def __init__(self, response: Any, lock: threading.Lock) -> None:
380 self._response = response
381 self._lock = lock
382 self._released = False
384 def __iter__(self) -> _LockedStreamIterator:
385 return self
387 def __next__(self) -> str:
388 try:
389 while True:
390 try:
391 chunk = next(self._response)
392 except StopIteration:
393 self._release()
394 raise
395 delta = chunk.get("choices", [{}])[0].get("delta", {})
396 content: str | None = delta.get("content")
397 if content:
398 return content
399 except StopIteration:
400 raise
401 except Exception:
402 self._release()
403 raise
405 def _release(self) -> None:
406 if not self._released:
407 self._released = True
408 self._lock.release()
410 def close(self) -> None:
411 """Drain (capped) the underlying C iterator, then release the lock.
413 Simply releasing the lock without finishing inference leaves the
414 llama-cpp model in an inconsistent state. Draining lets inference
415 complete cleanly. The cap (``_LOCKED_STREAM_DRAIN_CAP``) keeps a
416 runaway think loop from blocking close() indefinitely; once the
417 cap fires we accept the inconsistent state in exchange for not
418 hanging the UI.
419 """
420 if not self._released:
421 try:
422 for i, _ in enumerate(self._response):
423 if i >= _LOCKED_STREAM_DRAIN_CAP:
424 break
425 except Exception: # noqa: S110 -- best-effort drain during release; ignore partial-read errors
426 pass
427 self._release()
429 def __del__(self) -> None: # pragma: no cover
430 self._release()
# Serializes redirection of fd 2 in suppress_native_stderr across threads.
_STDERR_LOCK = threading.Lock()
# ctypes does not retain a Python reference to the wrapped callback;
# this module-level handle keeps it alive for the process lifetime.
_llama_log_callback: Any = None
_llama_log_installed = False
# Single-slot buffer (only key 0 is used) accumulating log text until a
# newline arrives; CONT chunks append to it.
_llama_log_pending: dict[int, str] = {}
# ggml level of the message currently buffered in _llama_log_pending.
_llama_log_pending_level: int = _GGML_LOG_LEVEL_INFO
def _llama_log_dispatch(level: int, text_bytes: bytes, _user_data: Any) -> None:
    """Dispatch one llama.cpp log message; CONT chunks are coalesced on newline."""
    global _llama_log_pending_level
    try:
        text = text_bytes.decode("utf-8", errors="replace") if text_bytes else ""
    except Exception:  # pragma: no cover
        return
    if level == _GGML_LOG_LEVEL_CONT:
        # Continuation chunk: append to the pending buffer (single slot, key 0).
        _llama_log_pending[0] = _llama_log_pending.get(0, "") + text
    else:
        # A new message begins: flush anything still buffered first, at the
        # level the buffered message started with.
        if 0 in _llama_log_pending:
            buffered = _llama_log_pending.pop(0).rstrip()
            if buffered:
                _llama_log.log(_resolve_ggml_level(_llama_log_pending_level, buffered), buffered)
        _llama_log_pending_level = level
        _llama_log_pending[0] = text
    # Emit as soon as the buffered text contains a newline.
    if "\n" in _llama_log_pending.get(0, ""):
        full = _llama_log_pending.pop(0).rstrip()
        if full:
            _llama_log.log(_resolve_ggml_level(_llama_log_pending_level, full), full)
def _resolve_ggml_level(ggml_level: int, text: str) -> int:
    """Translate ggml log level to Python, demoting known-advisory ERRORs to WARNING."""
    mapped = _GGML_TO_PY_LEVEL.get(ggml_level, logging.DEBUG)
    if mapped != logging.ERROR:
        return mapped
    # ERROR-level text matching a known-advisory substring is only a WARNING.
    for marker in _GGML_ERROR_SOFT_DEMOTE:
        if marker in text:
            return logging.WARNING
    return mapped
def install_llama_log_handler() -> None:
    """Route llama.cpp logs through Python logging. Idempotent."""
    global _llama_log_callback, _llama_log_installed
    if _llama_log_installed:
        return
    import llama_cpp

    wrapped = llama_cpp.llama_log_callback(_llama_log_dispatch)
    llama_cpp.llama_log_set(wrapped, None)
    # Module-level reference keeps the ctypes callback alive.
    _llama_log_callback = wrapped
    _llama_log_installed = True
def suppress_native_stderr(fn: Any, *args: Any, **kwargs: Any) -> Any:
    """Call *fn* with C-level stderr suppressed.

    llama.cpp prints noisy messages (e.g. 'init: embeddings required...')
    that bypass Python logging. This redirects fd 2 to /dev/null for the
    duration of the call. A lock serializes access to fd 2 so concurrent
    threads don't corrupt each other's file descriptors.

    Returns whatever *fn* returns; exceptions propagate after fd 2 is restored.
    """
    with _STDERR_LOCK:
        devnull = os.open(os.devnull, os.O_WRONLY)
        old_stderr = os.dup(2)  # keep a handle to the real stderr
        os.dup2(devnull, 2)  # point fd 2 at /dev/null
        try:
            return fn(*args, **kwargs)
        finally:
            # Restore fd 2 first, then close both temporary descriptors.
            os.dup2(old_stderr, 2)
            os.close(devnull)
            os.close(old_stderr)
def embed_one(llm: Any, text: str) -> list[float]:
    """Embed a single text with llama.cpp stderr noise suppressed."""
    payload = suppress_native_stderr(llm.create_embedding, input=[text])
    vector: list[float] = payload["data"][0]["embedding"]
    return vector
def read_gguf_metadata(model_path: Path) -> dict[str, str] | None:
    """Read metadata from a GGUF file's headers via llama-cpp-python.

    Loads the model with ``vocab_only=True`` (no weights) and returns a dict
    with keys like 'architecture', 'context_length', 'embedding_length',
    'chat_template', 'file_type', plus the KV-cache-shape fields
    ('block_count', 'head_count_kv', 'key_length', 'value_length') used to
    size n_ctx against host memory. Returns None when nothing was found.
    """
    from llama_cpp import Llama

    install_llama_log_handler()
    llm = suppress_native_stderr(
        Llama, model_path=str(model_path), vocab_only=True, verbose=False, n_gpu_layers=0
    )
    try:
        raw = llm.metadata or {}
        arch = raw.get("general.architecture", "llama")
        # (GGUF header key, output key) pairs, in fixed output order.
        key_pairs = (
            ("general.architecture", "architecture"),
            (f"{arch}.context_length", "context_length"),
            (f"{arch}.embedding_length", "embedding_length"),
            (f"{arch}.block_count", "block_count"),
            (f"{arch}.attention.head_count_kv", "head_count_kv"),
            (f"{arch}.attention.head_count", "head_count"),
            (f"{arch}.attention.key_length", "key_length"),
            (f"{arch}.attention.value_length", "value_length"),
            ("tokenizer.chat_template", "chat_template"),
            ("general.file_type", "file_type"),
            ("general.name", "name"),
        )
        result: dict[str, str] = {
            out_key: str(raw[src_key]) for src_key, out_key in key_pairs if src_key in raw
        }
        return result or None
    finally:
        llm.close()
def resolve_model_path(model: str) -> Path:
    """Resolve a model name to a .gguf file path.

    Resolution order:
    1. Registry (canonical source for installed models)
    2. Absolute path (if it points to an existing file)

    Raises:
        ProviderError: when the name is unknown and not an existing absolute path.
    """
    registry = get_services().registry
    try:
        return registry.resolve(model)
    except (KeyError, ValueError):
        pass

    candidate = Path(model)
    if not candidate.is_absolute():
        raise ProviderError(
            f"Model {model!r} not found in registry. "
            f"Install it via the catalog or 'lilbee models install'.",
            provider="llama-cpp",
        )
    if candidate.exists():
        return candidate
    raise ProviderError(f"Model file not found: {model}", provider="llama-cpp")
584def _llama_cpp_has_rank_pooling() -> bool:
585 """Return True iff the installed llama-cpp-python exposes ``LLAMA_POOLING_TYPE_RANK``."""
586 try:
587 from llama_cpp import LLAMA_POOLING_TYPE_RANK # noqa: F401
588 except ImportError:
589 return False
590 return True
def load_llama(model_path: Path, *, mode: LoaderMode) -> Any:
    """Load a llama_cpp.Llama in chat, embed, or rerank mode.

    Embed and rerank both load with ``embedding=True``; rerank additionally
    sets the RANK pooling type. Construction goes through
    ``_construct_llama`` for flash-attn fallback and OOM retries.
    """
    from llama_cpp import Llama

    install_llama_log_handler()
    embedding = mode in (MODE_EMBED, MODE_RERANK)
    kwargs: dict[str, Any] = {
        "model_path": str(model_path),
        "embedding": embedding,
        "verbose": False,
        "n_gpu_layers": _resolve_n_gpu_layers(embedding=embedding),
    }
    if embedding:
        # Embedding/rerank: clamp n_ctx to the model's training context.
        # Passing a chat-sized cfg.num_ctx through here triggers
        # ``n_ctx_seq > n_ctx_train`` warnings and wastes KV memory.
        embed_meta = _safe_read_gguf_metadata(model_path)
        embed_train_ctx = int((embed_meta or {}).get("context_length", "2048"))
        if cfg.num_ctx is not None:
            kwargs["n_ctx"] = min(cfg.num_ctx, embed_train_ctx)
        else:
            kwargs["n_ctx"] = 0  # 0 -> llama.cpp uses the model's training context
    elif cfg.num_ctx is not None:
        kwargs["n_ctx"] = cfg.num_ctx
    else:
        # No explicit num_ctx configured for chat: size dynamically.
        meta = _safe_read_gguf_metadata(model_path)
        kwargs["n_ctx"] = _resolve_chat_ctx(model_path, meta)
        log.info(
            "Chat n_ctx=%d for %s (dynamic, training_ctx=%s)",
            kwargs["n_ctx"],
            model_path.name,
            (meta or {}).get("context_length", "unknown"),
        )
    if embedding:
        # llama-cpp-python defaults n_batch = min(n_ctx, 512), silently
        # truncating embeddings to 512 tokens. Set n_batch = n_ctx so each
        # text can use the model's full context window.
        ctx_len = embed_train_ctx if kwargs["n_ctx"] == 0 else kwargs["n_ctx"]
        kwargs["n_batch"] = ctx_len
        kwargs["n_ubatch"] = ctx_len
    if mode == MODE_RERANK:
        from llama_cpp import LLAMA_POOLING_TYPE_RANK

        kwargs["pooling_type"] = LLAMA_POOLING_TYPE_RANK
    if not embedding:
        # Chat-only tuning: flash attention and quantized KV cache.
        _apply_flash_attention(kwargs)
        _apply_kv_cache_type(kwargs)
    return _construct_llama(Llama, model_path, kwargs)
def _safe_read_gguf_metadata(model_path: Path) -> dict[str, str] | None:
    """Best-effort GGUF metadata read, returning None on any failure."""
    try:
        meta = read_gguf_metadata(model_path)
    except Exception:
        log.debug("read_gguf_metadata failed for %s", model_path, exc_info=True)
        return None
    return meta
def _resolve_chat_ctx(model_path: Path, meta: dict[str, str] | None) -> int:
    """Pick the largest 256-multiple n_ctx that fits in available memory.

    Falls back to ``min(training_ctx, DEFAULT_NUM_CTX)`` when the file
    cannot be stat'ed or memory probing fails.
    """
    training_ctx = DEFAULT_NUM_CTX
    if meta:
        try:
            training_ctx = int(meta.get("context_length", DEFAULT_NUM_CTX))
        except (TypeError, ValueError):
            # Malformed/non-numeric header value: keep the default.
            training_ctx = DEFAULT_NUM_CTX
    ceiling = cfg.num_ctx_max
    try:
        model_bytes = model_path.stat().st_size
        available = get_available_memory(cfg.gpu_memory_fraction)
        # Per-token KV cost depends on attention shape (from metadata) and
        # the configured KV cache element size.
        kv_per_tok = kv_bytes_per_token(meta, _kv_elem_bytes_for_cfg())
        return compute_dynamic_ctx(
            model_bytes=model_bytes,
            available_bytes=available,
            training_ctx=training_ctx,
            kv_bytes_per_tok=kv_per_tok,
            ceiling=ceiling,
        )
    except (OSError, ValueError):
        log.debug("dynamic ctx sizing failed for %s, using static cap", model_path, exc_info=True)
        return min(training_ctx, DEFAULT_NUM_CTX)
def _kv_elem_bytes_for_cfg() -> int:
    """Bytes per KV element implied by the configured cache type."""
    # Direct table lookup; KV_CACHE_TYPE_BYTES covers every KvCacheType value.
    return KV_CACHE_TYPE_BYTES[cfg.kv_cache_type]
def _resolve_n_gpu_layers(*, embedding: bool) -> int:
    """Resolve ``cfg.n_gpu_layers`` (None=all) to llama-cpp's offload integer."""
    if embedding:
        # Embedding/rerank loads always request full offload.
        return _N_GPU_LAYERS_AUTO
    configured = cfg.n_gpu_layers
    return _N_GPU_LAYERS_AUTO if configured is None else configured
def _apply_flash_attention(kwargs: dict[str, Any]) -> None:
    """Set ``flash_attn`` per ``cfg.flash_attention`` (None=auto, True/False=force)."""
    if cfg.flash_attention is not False:
        # Auto (None) and forced True both request flash attention; the
        # construct loop drops the kwarg on TypeError if llama-cpp-python
        # doesn't support it.
        kwargs["flash_attn"] = True
def _apply_kv_cache_type(kwargs: dict[str, Any]) -> None:
    """Map ``cfg.kv_cache_type`` to llama-cpp-python ``type_k`` / ``type_v``."""
    if cfg.kv_cache_type is KvCacheType.F16:
        # F16 is the default layout; nothing to override.
        return
    mapping = _ggml_type_map()
    if mapping is None:
        log.debug("llama_cpp internal types unavailable; skipping KV quant")
        return
    resolved = mapping.get(cfg.kv_cache_type)
    if resolved is None:  # pragma: no cover -- defensive against new enum values
        return
    kwargs["type_k"] = resolved
    kwargs["type_v"] = resolved
def _ggml_type_map() -> dict[KvCacheType, Any] | None:
    """Resolve llama-cpp-python's GGML_TYPE_* constants, or None on older builds."""
    try:
        from llama_cpp import llama_cpp as _llc
    except Exception:  # pragma: no cover -- only fires on llama-cpp-python without _llc
        return None
    attr_names = {
        KvCacheType.F32: "GGML_TYPE_F32",
        KvCacheType.F16: "GGML_TYPE_F16",
        KvCacheType.Q8_0: "GGML_TYPE_Q8_0",
        KvCacheType.Q4_0: "GGML_TYPE_Q4_0",
    }
    # Missing constants resolve to None, matching getattr's default.
    return {cache_type: getattr(_llc, name, None) for cache_type, name in attr_names.items()}
def _construct_llama(llama_cls: Any, model_path: Path, kwargs: dict[str, Any]) -> Any:
    """Call ``llama_cls(**kwargs)`` with FA fallback and OOM-retry-with-halved-ctx.

    Each loop iteration either returns the loaded model, raises (failure
    or unrelated TypeError), or continues with halved n_ctx; the loop is
    therefore structurally exhaustive and never falls through.
    """
    fa_dropped = False
    for attempt in range(_MAX_OOM_RETRIES + 1):
        try:
            return suppress_native_stderr(llama_cls, **kwargs)
        except TypeError as exc:
            # Older llama-cpp-python rejects the flash_attn kwarg; drop it
            # once and retry. Any other TypeError is a genuine bug.
            if not _drop_flash_attn_if_unsupported(exc, kwargs, fa_dropped):
                raise
            fa_dropped = True
            continue
        except ValueError as exc:
            # Out of retries, or not an OOM-shaped failure: raise diagnostics.
            if attempt == _MAX_OOM_RETRIES or not _is_load_oom(exc):
                _raise_load_error(model_path, kwargs, exc)
            # Halving made no progress (already at the floor): give up.
            if not _halve_ctx_for_retry(kwargs, exc):
                _raise_load_error(model_path, kwargs, exc)
    raise RuntimeError("unreachable: _construct_llama loop fell through")  # pragma: no cover
757def _drop_flash_attn_if_unsupported(
758 exc: TypeError, kwargs: dict[str, Any], already_dropped: bool
759) -> bool:
760 """If the TypeError is about an unsupported ``flash_attn`` kwarg, drop it."""
761 if already_dropped or "flash_attn" not in kwargs or "flash_attn" not in str(exc):
762 return False
763 log.info("llama-cpp-python rejected flash_attn=True; retrying without it")
764 kwargs.pop("flash_attn", None)
765 return True
768def _halve_ctx_for_retry(kwargs: dict[str, Any], exc: ValueError) -> bool:
769 """Halve n_ctx (and matching batch sizes) for an OOM retry. Returns False if no progress."""
770 current_ctx = int(kwargs.get("n_ctx", 0) or 0)
771 if current_ctx <= 0:
772 return False
773 new_ctx = max(_CTX_FLOOR, (current_ctx // 2 // _CTX_QUANTUM) * _CTX_QUANTUM)
774 if new_ctx >= current_ctx:
775 return False
776 log.warning(
777 "llama.cpp load failed at n_ctx=%d (%s); retrying at n_ctx=%d",
778 current_ctx,
779 str(exc).splitlines()[0],
780 new_ctx,
781 )
782 kwargs["n_ctx"] = new_ctx
783 for key in ("n_batch", "n_ubatch"):
784 if key in kwargs:
785 kwargs[key] = new_ctx
786 return True
def _raise_load_error(model_path: Path, kwargs: dict[str, Any], exc: ValueError) -> None:
    """Raise the wrapped diagnostic for a llama.cpp load failure, or re-raise as-is."""
    diagnostic = _wrap_llama_load_error(model_path, kwargs, exc)
    if diagnostic is not None:
        raise diagnostic from exc
    raise exc
797def _is_load_oom(exc: ValueError) -> bool:
798 """Does this ValueError look like a llama.cpp memory failure?"""
799 err = str(exc)
800 return "llama_context" in err or "load model from file" in err
def _wrap_llama_load_error(
    model_path: Path, kwargs: dict[str, Any], exc: ValueError
) -> ValueError | None:
    """Diagnostic ValueError for opaque llama.cpp load failures, or None to pass through.

    The replacement message reports model size, requested n_ctx, free host
    RAM (when psutil is available), and concrete remediation hints.
    """
    # Delegate to _is_load_oom so this check and the retry loop's check can
    # never drift apart (previously the substring test was duplicated here).
    if not _is_load_oom(exc):
        return None
    err = str(exc)
    try:
        size_gb = model_path.stat().st_size / (1024**3) if model_path.exists() else 0.0
    except OSError:  # pragma: no cover
        size_gb = 0.0
    n_ctx = kwargs.get("n_ctx", 0)
    n_ctx_label = n_ctx or "model default"  # 0 means llama.cpp picks the default
    parts = [
        f"Failed to load {model_path.name} ({size_gb:.1f} GB) with n_ctx={n_ctx_label}.",
    ]
    try:
        import psutil

        free_gb = psutil.virtual_memory().available / (1024**3)
        parts.append(f"Host has {free_gb:.1f} GB free RAM.")
    except Exception as psu_exc:  # pragma: no cover
        log.debug("psutil unavailable: %s", psu_exc)
    parts.append(
        "Try a smaller model, lower LILBEE_NUM_CTX, set LILBEE_KV_CACHE_TYPE=q8_0, "
        "or close other processes to free RAM. "
        f"(llama.cpp: {err})"
    )
    return ValueError(" ".join(parts))
def _is_rerank_model(model: str) -> bool:
    """Check if *model* is an exact rerank catalog entry by ref or hf_repo."""
    # Empty/None model names short-circuit to False before the catalog lookup.
    return bool(model) and is_rerank_ref(model)
# Cross-encoder pair separator: rerank GGUFs score "query</s></s>candidate".
_RERANK_PAIR_SEPARATOR = "</s></s>"
def compute_rerank_scores(llm: Any, query: str, candidates: list[str]) -> list[float]:
    """Score *candidates* against *query* via llama.cpp reranker embeddings.

    ``pooling_type=LLAMA_POOLING_TYPE_RANK`` requires the pair pre-joined
    as ``query</s></s>candidate``; passing them as two inputs makes
    ``llama_decode`` fail with ``-1``.
    """
    results: list[float] = []
    for candidate in candidates:
        joined = f"{query}{_RERANK_PAIR_SEPARATOR}{candidate}"
        raw = suppress_native_stderr(llm.create_embedding, input=joined)
        results.append(_extract_rerank_score(raw))
    return results
860def _extract_rerank_score(response: dict[str, Any]) -> float:
861 """Extract a single relevance score from a pooling_type=RANK response.
863 Raises ``ProviderError`` with the observed shape for anything other
864 than a non-empty ``list[float]`` so upstream format changes surface.
865 """
866 data = response.get("data") or []
867 if not data:
868 raise ProviderError("Reranker returned no data", provider="llama-cpp")
869 embedding = data[-1].get("embedding")
870 if isinstance(embedding, list) and embedding and isinstance(embedding[0], (int, float)):
871 return float(embedding[0])
872 raise ProviderError(
873 "Reranker returned unexpected score shape "
874 f"(got {type(embedding).__name__}: {embedding!r}); "
875 "llama-cpp-python may have changed its response format",
876 provider="llama-cpp",
877 )
# Directory names inside a HuggingFace cache tree (.../blobs, .../snapshots).
_HF_BLOBS_DIR_NAME = "blobs"
_HF_SNAPSHOTS_DIR_NAME = "snapshots"
def _find_mmproj_in_hf_snapshots(model_dir: Path) -> Path | None:
    """Walk an HF-cache ``blobs/`` dir up to its sibling ``snapshots/`` tree.

    Returns the first mmproj GGUF found in any snapshot, or None when
    *model_dir* is not a ``blobs/`` directory or no snapshot contains one.
    """
    if model_dir.name != _HF_BLOBS_DIR_NAME:
        return None
    snapshots_dir = model_dir.parent / _HF_SNAPSHOTS_DIR_NAME
    if not snapshots_dir.is_dir():
        return None
    # Sort snapshots so the pick is deterministic: iterdir() yields entries
    # in filesystem-dependent order, which could flip between runs when
    # multiple snapshots carry an mmproj file.
    for snapshot in sorted(snapshots_dir.iterdir()):
        candidates = sorted(snapshot.glob("*mmproj*.gguf"))
        if candidates:
            return candidates[0]
    return None
898def _find_mmproj_in_flat_dir(model_dir: Path) -> Path | None:
899 """Glob ``*mmproj*.gguf`` siblings of a model GGUF (sideloaded layout)."""
900 candidates = sorted(model_dir.glob("*mmproj*.gguf"))
901 return candidates[0] if candidates else None
def find_mmproj_for_model(model_path: Path) -> Path:
    """Find the mmproj (CLIP projection) file for a vision model.

    Resolution order: (1) catalog lookup scoped to ``FEATURED_VISION``,
    (2) HuggingFace-cache ``snapshots/`` sibling of ``blobs/``,
    (3) same-directory glob for flat sideloaded layouts.
    Raises ``ProviderError`` if none find a file.
    """
    from lilbee.catalog import find_mmproj_file

    # Lazy fallthrough: each lookup runs only if the previous one found nothing.
    found = find_mmproj_file(model_path.stem)
    if not found:
        found = _find_mmproj_in_hf_snapshots(model_path.parent)
    if not found:
        found = _find_mmproj_in_flat_dir(model_path.parent)
    if found is not None:
        return found

    raise ProviderError(
        f"No mmproj (CLIP projection) file found for vision model {model_path.name}. "
        f"Download the mmproj file to {model_path.parent} or re-download the vision "
        "model through the catalog to get both files.",
        provider="llama-cpp",
    )
# GGUF metadata key holding the CLIP projector type of an mmproj file.
_CLIP_PROJECTOR_TYPE_KEY = "clip.projector_type"
def read_mmproj_projector_type(mmproj_path: Path) -> str | None:
    """Read ``clip.projector_type`` from a GGUF mmproj without loading the model."""
    try:
        field = GGUFReader(str(mmproj_path)).get_field(_CLIP_PROJECTOR_TYPE_KEY)
    except Exception:
        log.debug("Failed to read mmproj metadata from %s", mmproj_path, exc_info=True)
        return None
    if field is None:
        return None
    if field.types[-1] != GGUFValueType.STRING:
        # Key present but not a string value; treat as unknown.
        return None
    raw = bytes(field.parts[field.data[0]])
    return raw.decode("utf-8", errors="replace")