Coverage for src / lilbee / server / handlers.py: 100%
546 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-29 19:16 +0000
1"""Framework-agnostic route handlers for the lilbee HTTP server.
3Every public function is a plain async callable — no framework imports.
4Return types are dicts (JSON responses), lists, or async generators of SSE strings.
5"""
7from __future__ import annotations
9import asyncio
10import contextlib
11import copy
12import functools
13import json
14import logging
15import mimetypes
16import threading
17import time
18from collections.abc import AsyncGenerator, Iterator
19from pathlib import Path
20from typing import TYPE_CHECKING, Any, Literal, cast
22from pydantic import BaseModel
23from pydantic_core import PydanticUndefined
25from lilbee import settings
26from lilbee.catalog import find_catalog_entry
27from lilbee.cli.helpers import clean_result, copy_files, gather_status, get_version
28from lilbee.config import Config, cfg
29from lilbee.config_meta import (
30 MODEL_ROLE_FIELDS as _MODEL_ROLE_FIELDS,
31)
32from lilbee.config_meta import (
33 PUBLIC_CONFIG_FIELDS as _PUBLIC_CONFIG_FIELDS,
34)
35from lilbee.config_meta import (
36 REINDEX_FIELDS,
37 WRITABLE_CONFIG_FIELDS,
38)
39from lilbee.model_manager import ModelSource, get_model_manager
40from lilbee.models import ModelTask
41from lilbee.progress import DetailedProgressCallback, EventType, ProgressEvent, SseEvent
42from lilbee.providers.model_ref import parse_model_ref
43from lilbee.providers.sdk_backend import API_KEY_FIELDS
44from lilbee.providers.sdk_llm_provider import inject_provider_keys
45from lilbee.results import DocumentResult, group
46from lilbee.security import validate_path_within
47from lilbee.server.models import (
48 AddSummary,
49 AskResponse,
50 CatalogEntryResponse,
51 CleanedChunk,
52 ConfigResponse,
53 ConfigUpdateResponse,
54 DocumentInfo,
55 DocumentListResponse,
56 DocumentRemoveResponse,
57 ExternalModelsResponse,
58 HealthResponse,
59 InstalledModelEntry,
60 ModelsCatalogResponse,
61 ModelsDeleteResponse,
62 ModelsInstalledResponse,
63 ModelsShowResponse,
64 SetModelResponse,
65 SourceContentResponse,
66 StatusResponse,
67 SyncSummary,
68)
69from lilbee.services import get_services
71if TYPE_CHECKING:
72 from lilbee.catalog import CatalogModel
73 from lilbee.ingest import SyncResult
74 from lilbee.query import ChatMessage
76log = logging.getLogger(__name__)
78# Windows mimetypes reads from the registry, which may not define ``.md``
79# as ``text/markdown``. Pin the mapping at import time; ``add_type`` is
80# idempotent so repeated imports are safe.
81mimetypes.add_type("text/markdown", ".md")
83MAX_ADD_FILES = 100
class ModelCatalogEntry(BaseModel):
    """A single model in the catalog.

    One featured-catalog row enriched with the computed ``installed`` flag
    (see ``_catalog_section``).
    """

    # Human-readable display name from the catalog.
    name: str
    # Download size of the model artifact, in gigabytes.
    size_gb: float
    # Minimum host RAM the catalog recommends, in gigabytes.
    min_ram_gb: float
    # Short catalog description shown to the user.
    description: str
    # True when at least one quant of the model's hf_repo is installed locally.
    installed: bool
class ModelCatalogSection(BaseModel):
    """A single-role catalog section with active model and installed list."""

    # Currently configured model ref for this role ("" when the role is unset).
    active: str
    # Featured catalog rows for this role.
    catalog: list[ModelCatalogEntry]
    # Sorted refs of all locally installed models (see ``_catalog_section``).
    installed: list[str]
class ModelsResponse(BaseModel):
    """Response for GET /api/models: one catalog section per role."""

    # Field names match the /api/models/<role> route segments
    # (see TASK_ENDPOINT_PATH below).
    chat: ModelCatalogSection
    embedding: ModelCatalogSection
    vision: ModelCatalogSection
    reranker: ModelCatalogSection
# ``ModelTask.RERANK.value`` is ``"rerank"`` but the route is ``/api/models/reranker``,
# so this mapping is needed to build correct redirect URLs in 422 responses.
# Keep in sync with the route definitions and with ModelsResponse's fields.
TASK_ENDPOINT_PATH: dict[ModelTask, str] = {
    ModelTask.CHAT: "chat",
    ModelTask.EMBEDDING: "embedding",
    ModelTask.VISION: "vision",
    ModelTask.RERANK: "reranker",
}
def format_task_mismatch(ref: str, entry_task: ModelTask, expected_task: ModelTask) -> str:
    """Compose the 422 body for assigning a wrong-task model to a role slot."""
    target = TASK_ENDPOINT_PATH[entry_task]
    sentences = [
        f"Model '{ref}' is a {entry_task} model, not {expected_task}.",
        f"Set it via PUT /api/models/{target} instead.",
    ]
    return " ".join(sentences)
def sse_event(event: str, data: Any) -> str:
    """Serialize one Server-Sent Event frame: ``event:`` line, JSON ``data:`` line."""
    body = json.dumps(data)
    return "event: " + event + "\ndata: " + body + "\n\n"
def sse_error(message: str, *, code: str | None = None, detail: str | None = None) -> str:
    """Build an SSE error event, attaching ``code``/``detail`` only when provided."""
    payload: dict[str, Any] = {"message": message}
    for key, value in (("code", code), ("detail", detail)):
        if value is not None:
            payload[key] = value
    return sse_event(SseEvent.ERROR, payload)
147_OOM_MARKERS = ("failed to load", "free ram", "try a smaller model", "llama_context")
def classify_load_error(message: str) -> tuple[str | None, str]:
    """Return ``(code, user_message)`` for an SSE error event.

    Recognises the llama.cpp OOM diagnostic and maps it to a stable code;
    anything else falls back to the legacy generic shape (no code).
    """
    haystack = message.lower()
    for marker in _OOM_MARKERS:
        if marker in haystack:
            return "model_too_large", "Model too large for available RAM"
    return None, "Internal error"
def sse_done(data: dict[str, Any]) -> str:
    """Emit the terminal ``done`` SSE event wrapping *data*."""
    frame = sse_event(SseEvent.DONE, data)
    return frame
def _resolve_generation_options(options: dict[str, Any] | None) -> dict[str, Any] | None:
    """Build GenerationOptions from a raw request dict; falsy input yields ``None``."""
    if not options:
        return None
    return cfg.generation_options(**options)
class SseStream:
    """Context object for SSE streaming with cancellation support.

    Bundles the queue, cancel event, and progress callback that every SSE
    endpoint needs. Call :meth:`drain` to yield events until the task
    completes or the client disconnects.
    """

    def __init__(self) -> None:
        # Queue items are pre-formatted SSE payload strings; ``None`` is the
        # completion sentinel that tells drain() to stop.
        self.queue: asyncio.Queue[str | None] = asyncio.Queue()
        # threading.Event (not asyncio) so worker threads can observe cancellation.
        self.cancel = threading.Event()
        # Captured so worker threads can marshal events back onto this loop.
        self.loop = asyncio.get_running_loop()
        self.callback: DetailedProgressCallback = self._build_callback()

    def _build_callback(self) -> DetailedProgressCallback:
        """Create a progress callback that serializes events into the queue.

        Safe to call from both the event-loop thread and worker threads.
        """
        loop = self.loop
        queue = self.queue

        def _callback(event_type: EventType, data: ProgressEvent) -> None:
            serialized = data.model_dump() if isinstance(data, BaseModel) else data
            payload = f"event: {event_type}\ndata: {json.dumps(serialized)}\n\n"
            # Detect whether we're already on the owning loop's thread.
            try:
                running = asyncio.get_running_loop()
            except RuntimeError:
                running = None
            if running is loop:
                queue.put_nowait(payload)
            else:
                # Worker thread (or foreign loop): hop to the owning loop.
                loop.call_soon_threadsafe(queue.put_nowait, payload)

        return _callback

    async def drain(
        self, task: asyncio.Task[Any] | asyncio.Future[Any], label: str
    ) -> AsyncGenerator[str, None]:
        """Yield SSE strings until a sentinel arrives; cancel *task* on client disconnect.

        Emits a ``heartbeat`` event whenever the producer queue stays
        idle longer than ``cfg.sse_heartbeat_interval`` seconds so
        clients that enforce a stream-idle timeout don't abort.
        """
        last_yielded = time.monotonic()
        try:
            while True:
                try:
                    # Short poll so the heartbeat and dead-producer checks run
                    # even when no events arrive.
                    item = await asyncio.wait_for(self.queue.get(), timeout=0.1)
                except TimeoutError:
                    now = time.monotonic()
                    heartbeat_interval = cfg.sse_heartbeat_interval
                    if heartbeat_interval > 0 and now - last_yielded >= heartbeat_interval:
                        last_yielded = now
                        yield sse_event(SseEvent.HEARTBEAT, {"ts": time.time()})
                    # Fallback for producers that die without a sentinel.
                    if task.done() and self.queue.empty():
                        break
                    continue
                if item is None:  # sentinel: producer finished cleanly
                    break
                last_yielded = time.monotonic()
                yield item
        except (asyncio.CancelledError, GeneratorExit):
            # Client disconnected: signal worker threads via the threading
            # event AND cancel the asyncio side.
            log.info("%s cancelled by client", label)
            self.cancel.set()
            task.cancel()
async def health() -> HealthResponse:
    """Report liveness plus the running package version."""
    version = get_version()
    return HealthResponse(status="ok", version=version)
async def status() -> StatusResponse:
    """Expose config, sources, and chunk counts as a StatusResponse."""
    snapshot = gather_status().model_dump(exclude_none=True)
    return StatusResponse(**snapshot)
async def search(q: str, top_k: int = 5, chunk_type: str | None = None) -> list[DocumentResult]:
    """Run a similarity search and return results grouped per document.

    Raises ValueError for blank queries. Hits whose distance exceeds
    ``cfg.max_distance`` are dropped before grouping.
    """
    if not (q and q.strip()):
        raise ValueError("query must not be empty")
    hits = get_services().searcher.search(q, top_k=top_k, chunk_type=chunk_type)
    kept = [hit for hit in hits if hit.distance is None or hit.distance <= cfg.max_distance]
    return group(kept)
async def ask(
    question: str,
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: str | None = None,
) -> AskResponse:
    """One-shot RAG answer. Returns answer plus cleaned source chunks."""
    if not (question and question.strip()):
        raise ValueError("question must not be empty")
    generation = _resolve_generation_options(options)
    raw = get_services().searcher.ask_raw(
        question, top_k=top_k, options=generation, chunk_type=chunk_type
    )
    cleaned = [CleanedChunk(**clean_result(src)) for src in raw.sources]
    return AskResponse(answer=raw.answer, sources=cleaned)
def _run_llm_stream(
    messages: list[ChatMessage],
    opts: dict[str, Any] | None,
    queue: asyncio.Queue[str | None],
    cancel: threading.Event,
    error_holder: list[str],
) -> None:
    """Stream LLM tokens into a queue from a worker thread.

    Runs synchronously in an executor thread; the async side consumes
    *queue*. Failures are reported through *error_holder* (inspected by the
    caller after draining) rather than raised across the thread boundary.

    NOTE(review): events are enqueued with ``queue.put_nowait`` directly from
    this worker thread, unlike SseStream's callback which marshals via
    ``call_soon_threadsafe``. The drain loop polls every 0.1s so items are
    still picked up, but confirm this is intentional.
    """
    from lilbee.reasoning import filter_reasoning

    try:
        provider = get_services().provider
        stream = provider.chat(
            cast("list[dict[str, Any]]", messages),
            stream=True,
            options=opts or None,
            model=cfg.chat_model,
        )
        for st in filter_reasoning(cast(Iterator[str], stream), show=cfg.show_reasoning):
            # Cooperative cancellation: the async side sets this on disconnect.
            if cancel.is_set():
                break
            if st.content:
                event_type = SseEvent.REASONING if st.is_reasoning else SseEvent.TOKEN
                queue.put_nowait(sse_event(event_type, {"token": st.content}))
    except Exception as exc:
        error_holder.append(str(exc))
    finally:
        # Sentinel: tells drain() the producer is finished, success or not.
        queue.put_nowait(None)
async def _stream_rag_response(
    question: str,
    history: list[ChatMessage] | None = None,
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: str | None = None,
) -> AsyncGenerator[str, None]:
    """Shared SSE streaming for ask_stream and chat_stream.

    Yields token/reasoning events while the LLM streams, then a ``sources``
    event and a ``done`` event. Failures surface as an ``error`` event
    instead of an exception.
    """
    yield ""  # force generator: makes this an async generator before any await

    rag = get_services().searcher.build_rag_context(
        question, top_k=top_k, history=history, chunk_type=chunk_type
    )
    if rag is None:
        yield sse_error("No relevant documents found.")
        return

    results, messages = rag
    opts = _resolve_generation_options(options) or cfg.generation_options()

    sse = SseStream()
    # Worker appends the error string here; checked after the drain completes.
    error_holder: list[str] = []

    # LLM streaming is blocking, so it runs in the default executor; tokens
    # flow back through sse.queue.
    executor_fut = sse.loop.run_in_executor(
        None, _run_llm_stream, messages, opts, sse.queue, sse.cancel, error_holder
    )
    task = asyncio.ensure_future(executor_fut)
    async for event in sse.drain(task, "RAG stream"):
        yield event

    if error_holder:
        raw = error_holder[0]
        code, user_message = classify_load_error(raw)
        log.warning("Stream error: %s", raw)
        # Raw detail is only exposed for recognised (classified) errors.
        yield sse_error(user_message, code=code, detail=raw if code else None)
        sse.cancel.set()
        return

    # Ensure executor thread has finished before yielding final events
    await executor_fut

    yield sse_event(SseEvent.SOURCES, [clean_result(s) for s in results])
    yield sse_done({})
def ask_stream(
    question: str,
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: str | None = None,
) -> AsyncGenerator[str, None]:
    """Return the SSE generator for a history-free RAG stream (token/sources/done)."""
    return _stream_rag_response(question, top_k=top_k, options=options, chunk_type=chunk_type)
async def chat(
    question: str,
    history: list[ChatMessage],
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: str | None = None,
) -> AskResponse:
    """Answer *question* in the context of *history* (non-streaming)."""
    raw = get_services().searcher.ask_raw(
        question,
        top_k=top_k,
        history=history,
        options=_resolve_generation_options(options),
        chunk_type=chunk_type,
    )
    cleaned = [CleanedChunk(**clean_result(src)) for src in raw.sources]
    return AskResponse(answer=raw.answer, sources=cleaned)
def chat_stream(
    question: str,
    history: list[ChatMessage],
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: str | None = None,
) -> AsyncGenerator[str, None]:
    """Return the SSE generator for a RAG stream that carries chat history."""
    return _stream_rag_response(
        question, history=history, top_k=top_k, options=options, chunk_type=chunk_type
    )
async def _run_sync_with_sentinel(
    sse: SseStream, enable_ocr: bool | None, force_rebuild: bool = False
) -> SyncResult:
    """Run ingest.sync(), always enqueueing the drain sentinel afterwards."""
    from lilbee.cli.helpers import temporary_ocr_config
    from lilbee.ingest import sync

    try:
        with temporary_ocr_config(enable_ocr):
            result = await sync(
                quiet=True,
                on_progress=sse.callback,
                cancel=sse.cancel,
                force_rebuild=force_rebuild,
            )
        return result
    finally:
        # Without the sentinel, drain() would only stop via its slower
        # task-done fallback path.
        sse.queue.put_nowait(None)
# Registry lock serializes lock creation and the check-and-acquire step to
# avoid a TOCTOU between locked() and acquire() under concurrent /api/add.
# Per-source locks, keyed by canonical basename (see _canonical_source_name).
_INGEST_LOCKS: dict[str, asyncio.Lock] = {}
# Created lazily by _get_registry_lock so it binds to the running event loop.
_INGEST_LOCK_REGISTRY: asyncio.Lock | None = None
# Types that can carry script even within an "inline-rendered" category
# (e.g. SVG is image/* yet may embed script). Keep the deny narrow and
# explicit. Broadening this set is a security-relevant change — file an
# issue with the ``security`` label before adding entries.
_RAW_INLINE_RENDER_DENY: frozenset[str] = frozenset(
    {
        "text/html",
        "text/javascript",
        "application/javascript",
        "application/xhtml+xml",
        "text/css",
        "image/svg+xml",
    }
)
def _is_safe_for_inline_render(content_type: str) -> bool:
    """Whether ``raw=1`` may serve this Content-Type as-is.

    Trusted categories (``text/*``, ``image/*``, ``application/pdf``) pass
    through, minus the named exceptions that can embed executable script.
    Everything else degrades to ``application/octet-stream`` so an
    attacker-renamed file (e.g. ``evil.html``) cannot trick a browser into
    rendering it inline within the plugin origin.
    """
    if content_type in _RAW_INLINE_RENDER_DENY:
        return False
    return content_type == "application/pdf" or content_type.startswith(("text/", "image/"))
def _get_registry_lock() -> asyncio.Lock:
    """Return the registry lock, creating it lazily on first use."""
    global _INGEST_LOCK_REGISTRY
    lock = _INGEST_LOCK_REGISTRY
    if lock is None:
        lock = asyncio.Lock()
        _INGEST_LOCK_REGISTRY = lock
    return lock
def _reset_ingest_locks() -> None:
    """Test hook: drop every per-source lock and the registry lock itself."""
    global _INGEST_LOCK_REGISTRY
    _INGEST_LOCK_REGISTRY = None
    _INGEST_LOCKS.clear()
async def _try_acquire_source(name: str) -> asyncio.Lock | None:
    """Acquire the per-source lock for *name*, or ``None`` when contended."""
    async with _get_registry_lock():
        try:
            lock = _INGEST_LOCKS[name]
        except KeyError:
            lock = _INGEST_LOCKS[name] = asyncio.Lock()
        if lock.locked():
            # Someone else is ingesting this source right now.
            return None
        await lock.acquire()
    return lock
479def _canonical_source_name(p_str: str) -> str:
480 """Match the basename ``copy_files`` writes under ``cfg.documents_dir``."""
481 return Path(p_str).name
async def sync_stream(
    *, enable_ocr: bool | None = None, force_rebuild: bool = False
) -> AsyncGenerator[str, None]:
    """Trigger sync, yield SSE progress events, then done event.

    When ``force_rebuild`` is true, the underlying sync drops every table and
    re-ingests from ``cfg.documents_dir`` (the REST equivalent of ``lilbee rebuild``).
    """
    sse = SseStream()
    task = asyncio.create_task(_run_sync_with_sentinel(sse, enable_ocr, force_rebuild))
    async for event in sse.drain(task, "Sync stream"):
        yield event
    # Only emit a terminal event when the task actually finished — not when
    # the client disconnected (cancel set) or the task was cancelled.
    if not sse.cancel.is_set() and task.done() and not task.cancelled():
        exc = task.exception()
        if exc is not None:
            yield sse_error(str(exc))
            return
        yield sse_done(task.result().model_dump())
async def _run_add(
    paths: list[str],
    force: bool,
    enable_ocr: bool | None,
    ocr_timeout: float | None,
    sse: SseStream,
) -> AddSummary:
    """Copy files and sync, returning the summary for the final done event.

    Always enqueues the drain sentinel (``None``) so the SSE consumer
    terminates even if the copy or sync raises.
    """
    from lilbee.cli.helpers import temporary_ocr_config
    from lilbee.ingest import sync

    try:
        # Partition inputs: nonexistent paths are reported in the summary,
        # not treated as fatal.
        errors: list[str] = []
        valid: list[Path] = []
        for p_str in paths:
            p = Path(p_str)
            if not p.exists():
                errors.append(p_str)
            else:
                valid.append(p)

        copy_result = copy_files(valid, force=force)

        # Client disconnected during the copy: skip the expensive sync.
        if sse.cancel.is_set():
            return AddSummary(copied=copy_result.copied, skipped=copy_result.skipped, errors=errors)

        with temporary_ocr_config(enable_ocr, ocr_timeout):
            sync_result = await sync(quiet=True, on_progress=sse.callback, cancel=sse.cancel)

        return AddSummary(
            copied=copy_result.copied,
            skipped=copy_result.skipped,
            errors=errors,
            sync=SyncSummary(**sync_result.model_dump()),
        )
    finally:
        # Sentinel: wakes drain() so the stream can close.
        sse.queue.put_nowait(None)
async def _acquire_add_locks(
    paths: list[str],
) -> tuple[list[tuple[str, asyncio.Lock]], list[str]]:
    """Partition *paths* into ``(acquired, busy)`` by per-source lock availability."""
    acquired: list[tuple[str, asyncio.Lock]] = []
    busy: list[str] = []
    seen: set[str] = set()
    for raw in paths:
        name = _canonical_source_name(raw)
        if name in seen:
            # Duplicate basenames collapse into a single lock attempt.
            continue
        seen.add(name)
        lock = await _try_acquire_source(name)
        if lock is not None:
            acquired.append((name, lock))
        else:
            busy.append(name)
    return acquired, busy
563def _release_add_locks(acquired: list[tuple[str, asyncio.Lock]]) -> None:
564 """Release every lock in ``acquired``. Safe to call multiple times."""
565 while acquired:
566 _, lock = acquired.pop()
567 if lock.locked():
568 lock.release()
def validate_add_paths(
    data: dict[str, Any],
) -> tuple[list[str], bool, bool | None, float | None]:
    """Validate add-files input; raises ValueError on malformed requests."""
    paths = data.get("paths")
    if not (isinstance(paths, list) and paths):
        raise ValueError("'paths' must be a non-empty list of strings")
    if len(paths) > MAX_ADD_FILES:
        raise ValueError(f"Too many files: {len(paths)} exceeds limit of {MAX_ADD_FILES}")

    # Every basename must resolve inside the documents dir (no traversal).
    for raw in paths:
        validate_path_within(cfg.documents_dir / Path(raw).name, cfg.documents_dir)

    enable_ocr, ocr_timeout = _parse_ocr_params(data)
    return paths, bool(data.get("force", False)), enable_ocr, ocr_timeout
589def _parse_ocr_params(data: dict[str, Any]) -> tuple[bool | None, float | None]:
590 """Extract and coerce OCR parameters from a request dict."""
591 enable_ocr = data.get("enable_ocr")
592 ocr_timeout = data.get("ocr_timeout")
593 if enable_ocr is not None:
594 enable_ocr = bool(enable_ocr)
595 if ocr_timeout is not None:
596 ocr_timeout = float(ocr_timeout)
597 return enable_ocr, ocr_timeout
async def add_files_stream(data: dict[str, Any]) -> AsyncGenerator[str, None]:
    """Copy files, sync, and yield SSE progress events.

    Contended sources emit ``already_ingesting`` and the stream closes
    without a ``done`` event, signalling the client to wait rather than retry.

    NOTE(review): ``data`` is read directly, without ``validate_add_paths`` —
    presumably the route layer validates first; confirm against the caller.
    """
    paths = data.get("paths", [])
    force = bool(data.get("force", False))
    enable_ocr, ocr_timeout = _parse_ocr_params(data)

    # Per-source locks reject concurrent ingestion of the same basename.
    acquired, busy = await _acquire_add_locks(paths)
    try:
        for name in busy:
            log.info("Rejecting /api/add for %s: already ingesting", name)
            yield sse_event(SseEvent.ALREADY_INGESTING, {"source": name})

        if not acquired:
            return

        sse = SseStream()
        task = asyncio.create_task(_run_add(paths, force, enable_ocr, ocr_timeout, sse))
        try:
            async for event in sse.drain(task, "Add files stream"):
                yield event
            # Only emit a terminal event when the task actually finished.
            if not sse.cancel.is_set() and task.done() and not task.cancelled():
                exc = task.exception()
                if exc is not None:
                    yield sse_error(str(exc))
                    return
                summary = task.result()
                yield sse_done(summary.model_dump())
        finally:
            # Client may disconnect mid-stream; make sure the worker stops
            # before the locks are released.
            if not task.done():
                task.cancel()
                with contextlib.suppress(asyncio.CancelledError, Exception):
                    await task
    finally:
        _release_add_locks(acquired)
def _catalog_section(
    featured: tuple[CatalogModel, ...],
    active: str,
    installed: set[str],
) -> ModelCatalogSection:
    """Build a ModelCatalogSection from a featured-catalog tuple.

    A featured row counts as "installed" when at least one quant of its
    ``hf_repo`` has a manifest. Installed refs are full ``hf_repo/filename``
    strings, so membership compares the leading ``hf_repo`` segment; bare
    ``hf_repo`` entries (older clients) are accepted too.
    """
    repos: set[str] = set()
    for ref in installed:
        repos.add(ref.rsplit("/", 1)[0] if ref.endswith(".gguf") else ref)
    entries = [
        ModelCatalogEntry(
            name=model.display_name,
            size_gb=model.size_gb,
            min_ram_gb=model.min_ram_gb,
            description=model.description,
            installed=model.hf_repo in repos,
        )
        for model in featured
    ]
    return ModelCatalogSection(active=active, catalog=entries, installed=sorted(installed))
async def list_models() -> ModelsResponse:
    """Return per-role catalogs (chat, embedding, vision, reranker) with active selections.

    The unfiltered installed set is shared across sections so a single ref
    lights up in every catalog section it legitimately matches.
    """
    from lilbee.catalog import (
        FEATURED_CHAT,
        FEATURED_EMBEDDING,
        FEATURED_RERANK,
        FEATURED_VISION,
    )

    installed = set(get_model_manager().list_installed())
    sections = {
        "chat": _catalog_section(FEATURED_CHAT, cfg.chat_model, installed),
        "embedding": _catalog_section(FEATURED_EMBEDDING, cfg.embedding_model, installed),
        "vision": _catalog_section(FEATURED_VISION, cfg.vision_model, installed),
        "reranker": _catalog_section(FEATURED_RERANK, cfg.reranker_model, installed),
    }
    return ModelsResponse(**sections)
async def _set_model(
    field: Literal["chat_model", "embedding_model", "vision_model", "reranker_model"],
    model: str,
) -> SetModelResponse:
    """Write *model* into config *field*, both in memory and on disk."""
    setattr(cfg, field, model)
    # Mirror the in-memory change into settings so it survives restarts.
    settings.set_value(cfg.data_root, field, model)
    return SetModelResponse(model=model)
def _resolve_via_catalog(model: str, available: set[str]) -> str | None:
    """Map a bare ``hf_repo`` to whichever installed quant of it is in *available*."""
    entry = find_catalog_entry(model)
    if entry is None:
        return None
    prefix = f"{entry.hf_repo}/"
    for ref in available:
        if ref.startswith(prefix):
            return ref
    return None
def _resolve_via_parse(model: str, available: set[str]) -> str | None:
    """Resolve a provider-prefixed ref to its bare provider name in *available*."""
    try:
        parsed = parse_model_ref(model)
    except ValueError:
        return None
    if parsed.name in available:
        return parsed.name
    return None
def _require_model_available(model: str) -> str:
    """Return the installed-and-routable form of *model*, or raise ValueError."""
    err = ValueError(
        f"Model '{model}' is not available. Pull it first or check the name."
    )
    if not model:
        raise err
    available = set(get_services().provider.list_models())
    if model in available:
        return model
    # Fall back to catalog (bare hf_repo) then provider-prefixed resolution.
    resolved = _resolve_via_catalog(model, available) or _resolve_via_parse(model, available)
    if resolved is None:
        raise err
    return resolved
def _build_task_to_field() -> dict[ModelTask, str]:
    """Invert config's ``_MODEL_FIELD_TO_TASK`` so the two maps stay in sync."""
    from lilbee.config import _MODEL_FIELD_TO_TASK

    inverted: dict[ModelTask, str] = {}
    for field, task in _MODEL_FIELD_TO_TASK.items():
        inverted[ModelTask(task)] = field
    return inverted
743_TASK_TO_FIELD: dict[ModelTask, str] = _build_task_to_field()
def _require_model_for_task(model: str, expected: ModelTask, *, allow_empty: bool = False) -> str:
    """Validate *model* is installed locally AND passes the catalog task check.

    With *allow_empty*, a blank string unsets the role. Catalog + task
    validation delegates to ``validate_model_task_assignment`` so the
    handler and config paths share one implementation.
    """
    from lilbee.config import validate_model_task_assignment

    if allow_empty and not model.strip():
        return ""
    field = _TASK_TO_FIELD[expected]
    normalized = _require_model_available(model)
    return validate_model_task_assignment(field, normalized, allow_bypass=False)
async def set_chat_model(model: str) -> SetModelResponse:
    """Switch active chat model. Validates installation and catalog task."""
    checked = _require_model_for_task(model, ModelTask.CHAT)
    return await _set_model("chat_model", checked)
async def set_embedding_model(model: str) -> SetModelResponse:
    """Switch embedding model. Validates installation and catalog task.

    Sets ``reindex_required=True`` when the new model differs from the one
    recorded in the persisted store's meta; the caller is then expected to
    trigger a rebuild (``lilbee rebuild`` or ``POST /api/sync`` with
    ``force_rebuild=true``) before search/ingest will operate again.

    ``initialize_meta_if_legacy`` runs BEFORE the config mutation so a
    pre-upgrade store (chunks but no ``_meta`` row) is pinned to the OLD
    model identity instead of lazily adopting the new cfg and hiding the
    drift this call just introduced.
    """
    checked = _require_model_for_task(model, ModelTask.EMBEDDING)
    store = get_services().store
    store.initialize_meta_if_legacy()
    await _set_model("embedding_model", checked)
    meta = store.get_meta()
    drifted = meta is not None and meta["embedding_model"] != checked
    return SetModelResponse(model=checked, reindex_required=drifted)
async def set_vision_model(model: str) -> SetModelResponse:
    """Switch vision OCR model; empty string unsets it (vision OCR disabled)."""
    checked = _require_model_for_task(model, ModelTask.VISION, allow_empty=True)
    return await _set_model("vision_model", checked)
async def set_reranker_model(model: str) -> SetModelResponse:
    """Switch reranker model; empty string unsets it (reranking disabled)."""
    checked = _require_model_for_task(model, ModelTask.RERANK, allow_empty=True)
    return await _set_model("reranker_model", checked)
802_MIN_CHUNK_SIZE = 64
def _validate_config_updates(updates: dict[str, Any]) -> None:
    """Reject unknown fields, nulls on non-nullable fields, and invalid ranges."""
    for key, value in updates.items():
        if key not in WRITABLE_CONFIG_FIELDS:
            raise ValueError(f"Unknown or read-only config field: {key}")
        nullable = WRITABLE_CONFIG_FIELDS[key]
        if value is None and not nullable:
            raise ValueError(f"Field '{key}' does not accept null")
    chunk_size = updates.get("chunk_size")
    if isinstance(chunk_size, int) and chunk_size < _MIN_CHUNK_SIZE:
        raise ValueError(f"chunk_size must be >= {_MIN_CHUNK_SIZE}")
def _apply_config_updates(updates: dict[str, Any]) -> tuple[dict[str, str], list[str]]:
    """Apply *updates* to the live config, rolling back every field on error.

    Returns ``(to_persist, to_delete)`` for the subsequent disk write.
    """
    snapshot = {key: getattr(cfg, key) for key in updates}
    to_persist: dict[str, str] = {}
    to_delete: list[str] = []
    try:
        for key, value in updates.items():
            setattr(cfg, key, value)
            if value is None:
                to_delete.append(key)
            else:
                # Persist the post-validation value, not the raw input.
                to_persist[key] = str(getattr(cfg, key))
    except Exception:
        # Restore every field so a failed PATCH leaves config untouched.
        for key, old in snapshot.items():
            setattr(cfg, key, old)
        raise
    return to_persist, to_delete
async def update_config(updates: dict[str, Any]) -> ConfigUpdateResponse:
    """Partial update of writable config fields.

    Validate-then-apply with rollback:

    1. Validate keys, null-acceptability, and ranges upfront — no mutations
       yet, so typos and bad input are caught before anything changes.
    2. Apply each update via setattr (pydantic's validate_assignment catches
       type errors); on any failure ALL fields are rolled back from a
       snapshot, so the config is never half-applied.
    3. Persist to disk in batch (one write for sets, one for deletes)
       rather than per-field, avoiding partial writes on crash.

    Per-field setattr-and-save would let a PATCH like
    ``{"chunk_size": 1024, "chunk_overlap": "bad"}`` silently keep the
    chunk_size change while reporting an error; the snapshot/rollback in
    ``_apply_config_updates`` prevents that.
    """
    _validate_config_updates(updates)
    to_persist, to_delete = _apply_config_updates(updates)
    if to_persist:
        settings.update_values(cfg.data_root, to_persist)
    if to_delete:
        settings.delete_values(cfg.data_root, to_delete)
    changed = set(updates)
    # Re-inject provider API keys whenever one of them changed.
    if API_KEY_FIELDS & changed:
        inject_provider_keys()
    return ConfigUpdateResponse(
        updated=list(updates),
        reindex_required=bool(REINDEX_FIELDS & changed),
    )
async def delete_documents(
    names: list[str], *, delete_files: bool = False
) -> DocumentRemoveResponse:
    """Remove documents from the knowledge base by source name."""
    outcome = get_services().store.remove_documents(names, delete_files=delete_files)
    return DocumentRemoveResponse(removed=outcome.removed, not_found=outcome.not_found)
async def list_documents(
    search: str = "",
    limit: int = 50,
    offset: int = 0,
) -> DocumentListResponse:
    """Return indexed documents with metadata, paginated and filterable.

    Pagination and the filename filter are pushed down into LanceDB via
    ``Store.get_sources(search=..., limit=..., offset=...)``, and the total
    comes from ``Store.count_sources(search=...)``, so neither call
    materializes the full SOURCES table per request.
    """
    store = get_services().store
    term = search or None
    page = store.get_sources(search=term, limit=limit, offset=offset)
    total = store.count_sources(search=term)
    docs = [
        DocumentInfo(
            filename=row["filename"],
            chunk_count=row.get("chunk_count", 0),
            ingested_at=row.get("ingested_at", ""),
        )
        for row in page
    ]
    more = bool(page) and (offset + len(page)) < total
    return DocumentListResponse(
        documents=docs, total=total, limit=limit, offset=offset, has_more=more
    )
async def get_config() -> ConfigResponse:
    """Return all user-facing configuration values."""
    public = {
        key: value
        for key, value in cfg.model_dump().items()
        if key in _PUBLIC_CONFIG_FIELDS
    }
    return ConfigResponse(**public)
async def get_source_content(
    source: str, raw: bool = False
) -> SourceContentResponse | tuple[bytes, str]:
    """Return a stored source file.

    JSON with markdown text for text types, or ``(bytes, content_type)``
    when *raw* is True. Binary types return empty markdown so clients know
    to re-request with ``raw=1``.
    """
    from lilbee.wiki.index import parse_title

    if not (source and source.strip()):
        raise ValueError("source must not be empty")
    documents_dir = cfg.documents_dir
    resolved = validate_path_within(documents_dir / source, documents_dir)
    if not resolved.is_file():
        raise FileNotFoundError(source)

    guessed, _ = mimetypes.guess_type(resolved.name)
    content_type = guessed if guessed is not None else "application/octet-stream"

    if raw:
        # Cap raw responses to inline-render-safe categories; anything else
        # degrades to a binary download so attacker-renamed files (e.g.
        # evil.html) can't trick the embedding browser into running script
        # under our origin.
        if _is_safe_for_inline_render(content_type):
            served = content_type
        else:
            served = "application/octet-stream"
        return resolved.read_bytes(), served

    if not content_type.startswith("text/"):
        return SourceContentResponse(markdown="", content_type=content_type, title=None)

    text = resolved.read_text(encoding="utf-8", errors="replace")
    return SourceContentResponse(
        markdown=text, content_type=content_type, title=parse_title(text) or None
    )
@functools.cache
def _compute_config_defaults() -> dict[str, Any]:
    """Materialize Config defaults once per process (cached)."""
    defaults: dict[str, Any] = {}
    for field_name, field_info in Config.model_fields.items():
        # A field is exposed when it is both writable and public, or when
        # it is one of the model-role fields (resettable via PUT).
        exposed = (
            field_name in WRITABLE_CONFIG_FIELDS and field_name in _PUBLIC_CONFIG_FIELDS
        ) or field_name in _MODEL_ROLE_FIELDS
        if not exposed:
            continue
        default = field_info.get_default(call_default_factory=True)
        if default is PydanticUndefined:  # pragma: no cover
            continue
        defaults[field_name] = default
    return defaults
async def get_config_defaults() -> ConfigResponse:
    """Return canonical defaults for every public config field.

    Covers writable fields (resettable via PATCH /api/config) and the
    model-role fields (resettable via PUT /api/models/<role>).

    A deep copy of the cached dict is handed out so callers that mutate
    the response (list-valued fields like ``crawl_exclude_patterns``)
    cannot poison subsequent calls.
    """
    defaults = copy.deepcopy(_compute_config_defaults())
    return ConfigResponse(**defaults)
async def models_show(model: str) -> ModelsShowResponse:
    """Return model metadata/parameters; empty response when unavailable."""
    info = get_services().provider.show_model(model)
    # show_model may yield nothing for unknown models — fall back to {}.
    return ModelsShowResponse(**(info or {}))
def _parse_source(source: str) -> ModelSource:
    """Coerce a raw source string into its ModelSource enum member."""
    parsed = ModelSource(source)
    return parsed
async def models_catalog(
    task: str | None = None,
    search: str = "",
    size: str | None = None,
    installed: bool | None = None,
    featured: bool | None = None,
    sort: str = "featured",
    limit: int = 20,
    offset: int = 0,
) -> ModelsCatalogResponse:
    """Return one page of the model catalog, annotated with installed status."""
    from lilbee.catalog import enrich_catalog, get_catalog

    page = get_catalog(
        task=task,
        search=search,
        size=size,
        installed=installed,
        featured=featured,
        sort=sort,
        limit=limit,
        offset=offset,
    )

    # Cross-reference the registry so each entry carries its install state.
    installed_refs = {m.ref for m in get_services().registry.list_installed()}

    entries: list[CatalogEntryResponse] = []
    for item in enrich_catalog(page, installed_refs):
        entries.append(
            CatalogEntryResponse(
                hf_repo=item.hf_repo,
                gguf_filename=item.gguf_filename,
                task=item.task,
                display_name=item.display_name,
                param_count=item.param_count,
                size_gb=item.size_gb,
                min_ram_gb=item.min_ram_gb,
                description=item.description,
                quality_tier=item.quality_tier,
                featured=item.featured,
                downloads=item.downloads,
                installed=item.installed,
                source=item.source,
            )
        )

    return ModelsCatalogResponse(
        total=page.total,
        limit=page.limit,
        offset=page.offset,
        has_more=page.has_more,
        models=entries,
    )
async def models_installed() -> ModelsInstalledResponse:
    """List installed models together with the source each came from."""
    manager = get_model_manager()

    def _source_of(name: str) -> str:
        # Models with unknown provenance are reported as REMOTE.
        origin = manager.get_source(name)
        return origin.value if origin is not None else ModelSource.REMOTE.value

    return ModelsInstalledResponse(
        models=[
            InstalledModelEntry(name=name, source=_source_of(name))
            for name in manager.list_installed()
        ]
    )
async def models_pull(model: str, *, source: str = "native") -> AsyncGenerator[str, None]:
    """Stream SSE progress events while a model pull runs in a worker thread.

    A cancel event is set on client disconnect so the pull stops.
    """
    manager = get_model_manager()
    src = _parse_source(source)
    sse = SseStream()

    def _enqueue(item: str | None) -> None:
        # Callbacks fire on the worker thread; hop onto the event loop
        # before touching the asyncio queue.
        sse.loop.call_soon_threadsafe(sse.queue.put_nowait, item)

    def _blocking_pull() -> None:
        def _relay_progress(data: dict[str, Any]) -> None:
            if not sse.cancel.is_set():
                _enqueue(sse_event(SseEvent.PROGRESS, data))

        def _relay_bytes(downloaded: int, total: int) -> None:
            if not sse.cancel.is_set():
                _enqueue(sse_event(SseEvent.PROGRESS, {"current": downloaded, "total": total}))

        try:
            manager.pull(model, src, on_progress=_relay_progress, on_bytes=_relay_bytes)
        except Exception as exc:
            _enqueue(sse_error(str(exc)))
        finally:
            # None is the sentinel telling the drain loop the stream ended.
            _enqueue(None)

    task = asyncio.ensure_future(asyncio.to_thread(_blocking_pull))
    async for event in sse.drain(task, "Model pull stream"):
        yield event
async def models_delete(model: str, *, source: str = "native") -> ModelsDeleteResponse:
    """Delete a model and report whether anything was removed.

    ``freed_gb`` is not currently measured and is always reported as 0.0.
    """
    removed = get_model_manager().remove(model, _parse_source(source))
    return ModelsDeleteResponse(deleted=removed, model=model, freed_gb=0.0)
async def crawl_stream(
    url: str, depth: int | None = None, max_pages: int | None = None
) -> AsyncGenerator[str, None]:
    """Run a crawl and stream its progress as SSE events.

    Emits crawl_start, crawl_page, crawl_done events, then a final done
    event with the list of files written. On error emits crawl_error.
    Sets a cancel event on client disconnect so the crawl stops between
    pages.

    On first use, Chromium isn't installed yet. The stream inlines
    setup_start/progress/done events before the crawl begins so the
    Obsidian plugin's Task Center renders a matching 'setup' pill (bb-wq8g).
    """
    sse = SseStream()

    async def _crawl() -> list[Path]:
        from lilbee.crawler import crawl_and_save

        # crawl_and_save runs the Chromium bootstrap itself on first use,
        # relaying setup_* events through the same on_progress callback
        # so the SSE stream carries them before any crawl_* events.
        try:
            return await crawl_and_save(
                url, depth=depth, max_pages=max_pages, on_progress=sse.callback, cancel=sse.cancel
            )
        finally:
            # None is the sentinel telling the drain loop the stream ended.
            sse.queue.put_nowait(None)

    task = asyncio.create_task(_crawl())
    async for event in sse.drain(task, "Crawl stream"):
        yield event

    finished_cleanly = task.done() and not task.cancelled() and not sse.cancel.is_set()
    if finished_cleanly:
        failure = task.exception()
        if failure is not None:
            yield sse_error(str(failure))
        else:
            yield sse_done({"files_written": [str(p) for p in task.result()]})
1140_EXTERNAL_MODELS_TTL = 60
1143class _ExternalModelsCache:
1144 """TTL cache for external model listings (no module-level mutable global)."""
1146 def __init__(self) -> None:
1147 self._time: float = 0.0
1148 self._key: str = ""
1149 self._result: ExternalModelsResponse | None = None
1151 def get(self, key: str) -> ExternalModelsResponse | None:
1152 now = time.monotonic()
1153 if self._result and key == self._key and (now - self._time) < _EXTERNAL_MODELS_TTL:
1154 return self._result
1155 return None
1157 def set(self, key: str, result: ExternalModelsResponse) -> None:
1158 self._time = time.monotonic()
1159 self._key = key
1160 self._result = result
1163_external_cache = _ExternalModelsCache()
async def list_external_models() -> ExternalModelsResponse:
    """Query the provider for available models via its list_models() API.

    Results are memoized briefly; the cache key folds in the base URL and
    API key so changing either invalidates earlier listings. Provider
    failures are reported in-band via the ``error`` field, not raised.
    """
    cache_key = f"{cfg.remote_base_url}:{cfg.llm_api_key or ''}"
    hit = _external_cache.get(cache_key)
    if hit:
        return hit

    try:
        # list_models is a blocking SDK call; keep the event loop free.
        listed = await asyncio.to_thread(get_services().provider.list_models)
        fresh = ExternalModelsResponse(models=listed)
    except Exception as exc:
        log.warning("Failed to list external models: %s", exc)
        return ExternalModelsResponse(models=[], error=str(exc))
    else:
        _external_cache.set(cache_key, fresh)
        return fresh