Coverage for src / lilbee / server / handlers.py: 100%

546 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""Framework-agnostic route handlers for the lilbee HTTP server. 

2 

3Every public function is a plain async callable — no framework imports. 

4Return types are dicts (JSON responses), lists, or async generators of SSE strings. 

5""" 

6 

7from __future__ import annotations 

8 

9import asyncio 

10import contextlib 

11import copy 

12import functools 

13import json 

14import logging 

15import mimetypes 

16import threading 

17import time 

18from collections.abc import AsyncGenerator, Iterator 

19from pathlib import Path 

20from typing import TYPE_CHECKING, Any, Literal, cast 

21 

22from pydantic import BaseModel 

23from pydantic_core import PydanticUndefined 

24 

25from lilbee import settings 

26from lilbee.catalog import find_catalog_entry 

27from lilbee.cli.helpers import clean_result, copy_files, gather_status, get_version 

28from lilbee.config import Config, cfg 

29from lilbee.config_meta import ( 

30 MODEL_ROLE_FIELDS as _MODEL_ROLE_FIELDS, 

31) 

32from lilbee.config_meta import ( 

33 PUBLIC_CONFIG_FIELDS as _PUBLIC_CONFIG_FIELDS, 

34) 

35from lilbee.config_meta import ( 

36 REINDEX_FIELDS, 

37 WRITABLE_CONFIG_FIELDS, 

38) 

39from lilbee.model_manager import ModelSource, get_model_manager 

40from lilbee.models import ModelTask 

41from lilbee.progress import DetailedProgressCallback, EventType, ProgressEvent, SseEvent 

42from lilbee.providers.model_ref import parse_model_ref 

43from lilbee.providers.sdk_backend import API_KEY_FIELDS 

44from lilbee.providers.sdk_llm_provider import inject_provider_keys 

45from lilbee.results import DocumentResult, group 

46from lilbee.security import validate_path_within 

47from lilbee.server.models import ( 

48 AddSummary, 

49 AskResponse, 

50 CatalogEntryResponse, 

51 CleanedChunk, 

52 ConfigResponse, 

53 ConfigUpdateResponse, 

54 DocumentInfo, 

55 DocumentListResponse, 

56 DocumentRemoveResponse, 

57 ExternalModelsResponse, 

58 HealthResponse, 

59 InstalledModelEntry, 

60 ModelsCatalogResponse, 

61 ModelsDeleteResponse, 

62 ModelsInstalledResponse, 

63 ModelsShowResponse, 

64 SetModelResponse, 

65 SourceContentResponse, 

66 StatusResponse, 

67 SyncSummary, 

68) 

69from lilbee.services import get_services 

70 

71if TYPE_CHECKING: 

72 from lilbee.catalog import CatalogModel 

73 from lilbee.ingest import SyncResult 

74 from lilbee.query import ChatMessage 

75 

log = logging.getLogger(__name__)

# Windows mimetypes reads from the registry, which may not define ``.md``
# as ``text/markdown``. Pin the mapping at import time; ``add_type`` is
# idempotent so repeated imports are safe.
mimetypes.add_type("text/markdown", ".md")

# Hard cap on paths accepted by a single add request (enforced in
# ``validate_add_paths``).
MAX_ADD_FILES = 100

84 

85 

class ModelCatalogEntry(BaseModel):
    """A single model in the catalog."""

    # Human-readable display name (see ``_catalog_section``: built from
    # the featured row's ``display_name``, not the raw hf_repo ref).
    name: str
    # Download size in gigabytes.
    size_gb: float
    # Minimum RAM in gigabytes recommended to load the model.
    min_ram_gb: float
    description: str
    # True when at least one quant of the row's ``hf_repo`` is installed
    # locally (see ``_catalog_section``).
    installed: bool

94 

95 

class ModelCatalogSection(BaseModel):
    """A single-role catalog section with active model and installed list."""

    # Currently active model ref for this role.
    active: str
    # Featured catalog rows for the role, flagged with install state.
    catalog: list[ModelCatalogEntry]
    # Sorted installed refs. Note: ``list_models`` passes the same
    # unfiltered installed set to every section.
    installed: list[str]

102 

103 

class ModelsResponse(BaseModel):
    """Response for GET /api/models: one catalog section per role."""

    chat: ModelCatalogSection
    embedding: ModelCatalogSection
    vision: ModelCatalogSection
    reranker: ModelCatalogSection

111 

112 

# ``ModelTask.RERANK.value`` is ``"rerank"`` but the route is ``/api/models/reranker``,
# so this mapping is needed to build correct redirect URLs in 422 responses
# (consumed by ``format_task_mismatch``).
TASK_ENDPOINT_PATH: dict[ModelTask, str] = {
    ModelTask.CHAT: "chat",
    ModelTask.EMBEDDING: "embedding",
    ModelTask.VISION: "vision",
    ModelTask.RERANK: "reranker",
}

121 

122 

def format_task_mismatch(ref: str, entry_task: ModelTask, expected_task: ModelTask) -> str:
    """Build the 422 body when a role slot is assigned a model of the wrong task."""
    redirect = TASK_ENDPOINT_PATH[entry_task]
    sentences = (
        f"Model '{ref}' is a {entry_task} model, not {expected_task}.",
        f"Set it via PUT /api/models/{redirect} instead.",
    )
    return " ".join(sentences)

130 

131 

def sse_event(event: str, data: Any) -> str:
    """Format a single Server-Sent Event string (``event:`` + ``data:`` lines)."""
    body = json.dumps(data)
    return "event: " + event + "\ndata: " + body + "\n\n"

135 

136 

def sse_error(message: str, *, code: str | None = None, detail: str | None = None) -> str:
    """Format an SSE error event with optional structured ``code`` / ``detail``."""
    payload: dict[str, Any] = {"message": message}
    # Preserve key order: message, then code, then detail.
    for key, value in (("code", code), ("detail", detail)):
        if value is not None:
            payload[key] = value
    return sse_event(SseEvent.ERROR, payload)

145 

146 

# Lower-cased substrings that identify a llama.cpp out-of-memory /
# model-load failure inside an otherwise opaque error message.
_OOM_MARKERS = ("failed to load", "free ram", "try a smaller model", "llama_context")


def classify_load_error(message: str) -> tuple[str | None, str]:
    """Return ``(code, user_message)`` for an SSE error event.

    Recognises the llama.cpp OOM diagnostic and maps it to a stable code; any
    other input falls back to the legacy generic shape.
    """
    haystack = message.lower()
    for marker in _OOM_MARKERS:
        if marker in haystack:
            return "model_too_large", "Model too large for available RAM"
    return None, "Internal error"

160 

161 

def sse_done(data: dict[str, Any]) -> str:
    """Format an SSE done event (terminal success marker for a stream)."""
    return sse_event(SseEvent.DONE, data)

165 

166 

def _resolve_generation_options(options: dict[str, Any] | None) -> dict[str, Any] | None:
    """Convert raw options dict to GenerationOptions, or None.

    Empty/missing *options* returns ``None`` so callers can distinguish
    "no overrides" from an explicit options object.

    NOTE(review): the return annotation says ``dict`` but the docstring says
    ``cfg.generation_options(**options)`` builds a GenerationOptions — confirm
    which one is accurate.
    """
    return cfg.generation_options(**options) if options else None

170 

171 

class SseStream:
    """Context object for SSE streaming with cancellation support.

    Bundles the queue, cancel event, and progress callback that every SSE
    endpoint needs. Call :meth:`drain` to yield events until the task
    completes or the client disconnects.
    """

    def __init__(self) -> None:
        # Producer -> consumer channel; ``None`` is the end-of-stream sentinel.
        self.queue: asyncio.Queue[str | None] = asyncio.Queue()
        # threading.Event (not asyncio.Event) so worker threads can poll it.
        self.cancel = threading.Event()
        # Captured so worker threads can schedule puts onto the owning loop.
        self.loop = asyncio.get_running_loop()
        self.callback: DetailedProgressCallback = self._build_callback()

    def _build_callback(self) -> DetailedProgressCallback:
        """Create a progress callback that serializes events into the queue.
        Safe to call from both the event-loop thread and worker threads.
        """
        # Bound locally so the closure does not keep re-reading ``self``.
        loop = self.loop
        queue = self.queue

        def _callback(event_type: EventType, data: ProgressEvent) -> None:
            # Pydantic events are dumped to plain dicts before JSON encoding.
            serialized = data.model_dump() if isinstance(data, BaseModel) else data
            payload = f"event: {event_type}\ndata: {json.dumps(serialized)}\n\n"
            try:
                running = asyncio.get_running_loop()
            except RuntimeError:
                # No running loop: we are on a worker thread.
                running = None
            if running is loop:
                queue.put_nowait(payload)
            else:
                # Cross-thread enqueue must go through the owning loop.
                loop.call_soon_threadsafe(queue.put_nowait, payload)

        return _callback

    async def drain(
        self, task: asyncio.Task[Any] | asyncio.Future[Any], label: str
    ) -> AsyncGenerator[str, None]:
        """Yield SSE strings until a sentinel arrives; cancel *task* on client disconnect.

        Emits a ``heartbeat`` event whenever the producer queue stays
        idle longer than ``cfg.sse_heartbeat_interval`` seconds so
        clients that enforce a stream-idle timeout don't abort.
        """
        last_yielded = time.monotonic()
        try:
            while True:
                try:
                    # Short poll so heartbeat and dead-producer checks run often.
                    item = await asyncio.wait_for(self.queue.get(), timeout=0.1)
                except TimeoutError:
                    now = time.monotonic()
                    heartbeat_interval = cfg.sse_heartbeat_interval
                    if heartbeat_interval > 0 and now - last_yielded >= heartbeat_interval:
                        last_yielded = now
                        yield sse_event(SseEvent.HEARTBEAT, {"ts": time.time()})
                    # Fallback for producers that die without a sentinel.
                    if task.done() and self.queue.empty():
                        break
                    continue
                if item is None:
                    # Sentinel: producer finished (normally or via its finally).
                    break
                last_yielded = time.monotonic()
                yield item
        except (asyncio.CancelledError, GeneratorExit):
            # Client went away: signal the worker and stop the task.
            log.info("%s cancelled by client", label)
            self.cancel.set()
            task.cancel()

238 

239 

async def health() -> HealthResponse:
    """Return service health and version.

    Always reports ``status="ok"``; reaching the handler at all is the signal.
    """
    return HealthResponse(status="ok", version=get_version())

243 

244 

async def status() -> StatusResponse:
    """Return config, sources, and chunk counts.

    Delegates to the CLI's ``gather_status`` and re-validates the dump into
    the server model; ``exclude_none`` drops unset optional fields.
    """
    raw = gather_status()
    return StatusResponse(**raw.model_dump(exclude_none=True))

249 

250 

async def search(q: str, top_k: int = 5, chunk_type: str | None = None) -> list[DocumentResult]:
    """Search and return grouped DocumentResults.

    Raises ValueError for an empty or whitespace-only query; hits whose
    distance exceeds ``cfg.max_distance`` are dropped before grouping.
    """
    if not q or not q.strip():
        raise ValueError("query must not be empty")
    hits = get_services().searcher.search(q, top_k=top_k, chunk_type=chunk_type)
    threshold = cfg.max_distance
    kept = [hit for hit in hits if hit.distance is None or hit.distance <= threshold]
    return group(kept)

258 

259 

async def ask(
    question: str,
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: str | None = None,
) -> AskResponse:
    """One-shot RAG answer. Returns answer and sources.

    Raises ValueError for an empty or whitespace-only question.
    """
    if not question or not question.strip():
        raise ValueError("question must not be empty")
    generation_opts = _resolve_generation_options(options)
    raw = get_services().searcher.ask_raw(
        question, top_k=top_k, options=generation_opts, chunk_type=chunk_type
    )
    cleaned = [CleanedChunk(**clean_result(src)) for src in raw.sources]
    return AskResponse(answer=raw.answer, sources=cleaned)

277 

278 

def _run_llm_stream(
    messages: list[ChatMessage],
    opts: dict[str, Any] | None,
    queue: asyncio.Queue[str | None],
    cancel: threading.Event,
    error_holder: list[str],
) -> None:
    """Stream LLM tokens into a queue from a worker thread.

    Any exception is captured into *error_holder* for the consumer to
    inspect after draining; the ``None`` sentinel is always enqueued so the
    consumer loop terminates.

    NOTE(review): events are enqueued with ``queue.put_nowait`` directly from
    this worker thread, unlike ``SseStream._build_callback`` which routes
    cross-thread puts through ``loop.call_soon_threadsafe``. The 0.1 s polling
    in ``SseStream.drain`` appears to mask any missed wakeup — confirm this
    is intentional.
    """
    from lilbee.reasoning import filter_reasoning

    try:
        provider = get_services().provider
        stream = provider.chat(
            cast("list[dict[str, Any]]", messages),
            stream=True,
            options=opts or None,
            model=cfg.chat_model,
        )
        for st in filter_reasoning(cast(Iterator[str], stream), show=cfg.show_reasoning):
            if cancel.is_set():
                break  # client disconnected; stop consuming provider output
            if st.content:
                # Reasoning and answer tokens go out as distinct SSE events.
                event_type = SseEvent.REASONING if st.is_reasoning else SseEvent.TOKEN
                queue.put_nowait(sse_event(event_type, {"token": st.content}))
    except Exception as exc:
        # Worker thread: report via the holder rather than raising into nowhere.
        error_holder.append(str(exc))
    finally:
        # Sentinel: always terminate the consumer, even on error/cancel.
        queue.put_nowait(None)

307 

308 

async def _stream_rag_response(
    question: str,
    history: list[ChatMessage] | None = None,
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: str | None = None,
) -> AsyncGenerator[str, None]:
    """Shared SSE streaming for ask_stream and chat_stream.

    Event order on success: tokens/reasoning, then ``sources``, then ``done``.
    On LLM failure a single ``error`` event is emitted instead of the
    terminal pair.
    """
    yield ""  # force generator

    rag = get_services().searcher.build_rag_context(
        question, top_k=top_k, history=history, chunk_type=chunk_type
    )
    if rag is None:
        yield sse_error("No relevant documents found.")
        return

    results, messages = rag
    # Explicit request options win; otherwise fall back to config defaults.
    opts = _resolve_generation_options(options) or cfg.generation_options()

    sse = SseStream()
    error_holder: list[str] = []

    # LLM streaming is blocking, so it runs on the default executor while
    # this coroutine drains the queue it fills.
    executor_fut = sse.loop.run_in_executor(
        None, _run_llm_stream, messages, opts, sse.queue, sse.cancel, error_holder
    )
    task = asyncio.ensure_future(executor_fut)
    async for event in sse.drain(task, "RAG stream"):
        yield event

    if error_holder:
        raw = error_holder[0]
        code, user_message = classify_load_error(raw)
        log.warning("Stream error: %s", raw)
        # Raw detail is only exposed when the error was classified.
        yield sse_error(user_message, code=code, detail=raw if code else None)
        sse.cancel.set()
        return

    # Ensure executor thread has finished before yielding final events
    await executor_fut

    yield sse_event(SseEvent.SOURCES, [clean_result(s) for s in results])
    yield sse_done({})

352 

353 

def ask_stream(
    question: str,
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: str | None = None,
) -> AsyncGenerator[str, None]:
    """Yield SSE events: token, sources, done.

    Plain function (not ``async def``): it returns the async generator
    built by ``_stream_rag_response`` unwrapped.
    """
    return _stream_rag_response(question, top_k=top_k, options=options, chunk_type=chunk_type)

362 

363 

async def chat(
    question: str,
    history: list[ChatMessage],
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: str | None = None,
) -> AskResponse:
    """Chat with history. Returns answer and sources.

    Raises:
        ValueError: if *question* is empty or whitespace-only. Mirrors the
            guard in ``ask`` / ``search`` so every query entry point rejects
            blank input consistently instead of forwarding it to the searcher.
    """
    if not question or not question.strip():
        raise ValueError("question must not be empty")
    opts = _resolve_generation_options(options)
    result = get_services().searcher.ask_raw(
        question, top_k=top_k, history=history, options=opts, chunk_type=chunk_type
    )
    return AskResponse(
        answer=result.answer,
        sources=[CleanedChunk(**clean_result(s)) for s in result.sources],
    )

380 

381 

def chat_stream(
    question: str,
    history: list[ChatMessage],
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: str | None = None,
) -> AsyncGenerator[str, None]:
    """Yield SSE events with chat history support.

    Plain function (not ``async def``): it returns the async generator
    built by ``_stream_rag_response`` unwrapped.
    """
    return _stream_rag_response(
        question, history=history, top_k=top_k, options=options, chunk_type=chunk_type
    )

393 

394 

async def _run_sync_with_sentinel(
    sse: SseStream, enable_ocr: bool | None, force_rebuild: bool = False
) -> SyncResult:
    """Run ingest.sync() and guarantee the drain sentinel is enqueued."""
    from lilbee.cli.helpers import temporary_ocr_config
    from lilbee.ingest import sync

    try:
        # Per-request OCR override, restored when the context exits.
        with temporary_ocr_config(enable_ocr):
            return await sync(
                quiet=True,
                on_progress=sse.callback,
                cancel=sse.cancel,
                force_rebuild=force_rebuild,
            )
    finally:
        # Always unblock SseStream.drain, even when sync() raises.
        sse.queue.put_nowait(None)

412 

413 

# Registry lock serializes lock creation and the check-and-acquire step to
# avoid a TOCTOU between locked() and acquire() under concurrent /api/add.
# Maps canonical source basename -> per-source ingest lock.
_INGEST_LOCKS: dict[str, asyncio.Lock] = {}
# Created lazily by ``_get_registry_lock`` so it binds to the running loop.
_INGEST_LOCK_REGISTRY: asyncio.Lock | None = None

418 

419 

420# Types that can carry script even within an "inline-rendered" category. 

421# Keep the deny narrow and explicit. Broadening this set is a security-relevant 

422# change — file an issue with the ``security`` label before adding entries. 

423_RAW_INLINE_RENDER_DENY: frozenset[str] = frozenset( 

424 { 

425 "text/html", 

426 "text/javascript", 

427 "application/javascript", 

428 "application/xhtml+xml", 

429 "text/css", 

430 "image/svg+xml", 

431 } 

432) 

433 

434 

435def _is_safe_for_inline_render(content_type: str) -> bool: 

436 """Whether ``raw=1`` may serve this Content-Type as-is. 

437 

438 Trusted categories (``text/*``, ``image/*``, ``application/pdf``) pass 

439 through, with named exceptions for types that embed executable script. 

440 Everything else degrades to ``application/octet-stream`` so an attacker- 

441 renamed file (e.g. ``evil.html``) cannot trick a browser into rendering 

442 it inline within the plugin origin. 

443 """ 

444 if content_type in _RAW_INLINE_RENDER_DENY: 

445 return False 

446 if content_type == "application/pdf": 

447 return True 

448 return content_type.startswith("text/") or content_type.startswith("image/") 

449 

450 

def _get_registry_lock() -> asyncio.Lock:
    """Return the registry lock, creating it on the running loop if needed."""
    global _INGEST_LOCK_REGISTRY
    # Lazy creation: an asyncio.Lock must be created with a loop available,
    # which is not the case at import time.
    if _INGEST_LOCK_REGISTRY is None:
        _INGEST_LOCK_REGISTRY = asyncio.Lock()
    return _INGEST_LOCK_REGISTRY

457 

458 

def _reset_ingest_locks() -> None:
    """Test hook: clear per-source locks and the registry lock."""
    global _INGEST_LOCK_REGISTRY
    _INGEST_LOCKS.clear()
    # Next _get_registry_lock() call recreates it on the then-current loop.
    _INGEST_LOCK_REGISTRY = None

464 

465 

async def _try_acquire_source(name: str) -> asyncio.Lock | None:
    """Acquire the lock for ``name`` or return ``None`` if already held.

    The whole get-or-create + check + acquire sequence runs under the
    registry lock so two concurrent callers cannot both observe the
    per-source lock as free.
    """
    async with _get_registry_lock():
        lock = _INGEST_LOCKS.get(name)
        if lock is None:
            lock = asyncio.Lock()
            _INGEST_LOCKS[name] = lock
        if lock.locked():
            return None
        # Cannot block here: locked() was False and the registry lock is held.
        await lock.acquire()
        return lock

477 

478 

479def _canonical_source_name(p_str: str) -> str: 

480 """Match the basename ``copy_files`` writes under ``cfg.documents_dir``.""" 

481 return Path(p_str).name 

482 

483 

async def sync_stream(
    *, enable_ocr: bool | None = None, force_rebuild: bool = False
) -> AsyncGenerator[str, None]:
    """Trigger sync, yield SSE progress events, then done event.

    When ``force_rebuild`` is true, the underlying sync drops every table and
    re-ingests from ``cfg.documents_dir`` (the REST equivalent of ``lilbee rebuild``).
    """
    sse = SseStream()
    task = asyncio.create_task(_run_sync_with_sentinel(sse, enable_ocr, force_rebuild))
    async for event in sse.drain(task, "Sync stream"):
        yield event
    # Terminal events are skipped when the client cancelled mid-stream.
    if not sse.cancel.is_set() and task.done() and not task.cancelled():
        exc = task.exception()
        if exc is not None:
            yield sse_error(str(exc))
            return
        yield sse_done(task.result().model_dump())

502 

503 

async def _run_add(
    paths: list[str],
    force: bool,
    enable_ocr: bool | None,
    ocr_timeout: float | None,
    sse: SseStream,
) -> AddSummary:
    """Copy files and sync, returning the summary for the final done event."""
    from lilbee.cli.helpers import temporary_ocr_config
    from lilbee.ingest import sync

    try:
        # Partition inputs: missing paths are reported, existing ones copied.
        errors: list[str] = []
        valid: list[Path] = []
        for p_str in paths:
            p = Path(p_str)
            if not p.exists():
                errors.append(p_str)
            else:
                valid.append(p)

        copy_result = copy_files(valid, force=force)

        # Client disconnected during the copy: skip the (expensive) sync.
        if sse.cancel.is_set():
            return AddSummary(copied=copy_result.copied, skipped=copy_result.skipped, errors=errors)

        # Per-request OCR overrides, restored when the context exits.
        with temporary_ocr_config(enable_ocr, ocr_timeout):
            sync_result = await sync(quiet=True, on_progress=sse.callback, cancel=sse.cancel)

        return AddSummary(
            copied=copy_result.copied,
            skipped=copy_result.skipped,
            errors=errors,
            sync=SyncSummary(**sync_result.model_dump()),
        )
    finally:
        # Always unblock SseStream.drain, even on error/cancel.
        sse.queue.put_nowait(None)

541 

542 

async def _acquire_add_locks(
    paths: list[str],
) -> tuple[list[tuple[str, asyncio.Lock]], list[str]]:
    """Return ``(acquired, busy)`` partitioning of ``paths`` by lock state.

    Paths are deduplicated by canonical basename, so a request that lists
    the same file twice takes (and must later release) only one lock.
    """
    acquired: list[tuple[str, asyncio.Lock]] = []
    busy: list[str] = []
    seen: set[str] = set()
    for p_str in paths:
        name = _canonical_source_name(p_str)
        if name in seen:
            continue
        seen.add(name)
        lock = await _try_acquire_source(name)
        if lock is None:
            # Another request is already ingesting this source.
            busy.append(name)
        else:
            acquired.append((name, lock))
    return acquired, busy

561 

562 

563def _release_add_locks(acquired: list[tuple[str, asyncio.Lock]]) -> None: 

564 """Release every lock in ``acquired``. Safe to call multiple times.""" 

565 while acquired: 

566 _, lock = acquired.pop() 

567 if lock.locked(): 

568 lock.release() 

569 

570 

def validate_add_paths(
    data: dict[str, Any],
) -> tuple[list[str], bool, bool | None, float | None]:
    """Validate add-files input. Raises ValueError on bad input.

    Returns ``(paths, force, enable_ocr, ocr_timeout)``.
    """
    paths = data.get("paths")
    if not isinstance(paths, list) or not paths:
        raise ValueError("'paths' must be a non-empty list of strings")
    if len(paths) > MAX_ADD_FILES:
        raise ValueError(f"Too many files: {len(paths)} exceeds limit of {MAX_ADD_FILES}")

    # Only the basename is copied into documents_dir, so the traversal guard
    # validates the destination path, not the caller-supplied source path.
    for p_str in paths:
        validate_path_within(cfg.documents_dir / Path(p_str).name, cfg.documents_dir)

    force = bool(data.get("force", False))
    enable_ocr, ocr_timeout = _parse_ocr_params(data)
    return paths, force, enable_ocr, ocr_timeout

587 

588 

589def _parse_ocr_params(data: dict[str, Any]) -> tuple[bool | None, float | None]: 

590 """Extract and coerce OCR parameters from a request dict.""" 

591 enable_ocr = data.get("enable_ocr") 

592 ocr_timeout = data.get("ocr_timeout") 

593 if enable_ocr is not None: 

594 enable_ocr = bool(enable_ocr) 

595 if ocr_timeout is not None: 

596 ocr_timeout = float(ocr_timeout) 

597 return enable_ocr, ocr_timeout 

598 

599 

async def add_files_stream(data: dict[str, Any]) -> AsyncGenerator[str, None]:
    """Copy files, sync, and yield SSE progress events.

    Contended sources emit ``already_ingesting`` and the stream closes
    without a ``done`` event, signalling the client to wait rather than retry.

    NOTE(review): *data* is re-parsed here without ``validate_add_paths`` —
    presumably the route layer validates before streaming; confirm.
    """
    paths = data.get("paths", [])
    force = bool(data.get("force", False))
    enable_ocr, ocr_timeout = _parse_ocr_params(data)

    acquired, busy = await _acquire_add_locks(paths)
    try:
        for name in busy:
            log.info("Rejecting /api/add for %s: already ingesting", name)
            yield sse_event(SseEvent.ALREADY_INGESTING, {"source": name})

        # Every requested source is contended: close without a done event.
        if not acquired:
            return

        sse = SseStream()
        task = asyncio.create_task(_run_add(paths, force, enable_ocr, ocr_timeout, sse))
        try:
            async for event in sse.drain(task, "Add files stream"):
                yield event
            if not sse.cancel.is_set() and task.done() and not task.cancelled():
                exc = task.exception()
                if exc is not None:
                    yield sse_error(str(exc))
                    return
                summary = task.result()
                yield sse_done(summary.model_dump())
        finally:
            # Client disconnect path: stop the producer before releasing locks.
            if not task.done():
                task.cancel()
                with contextlib.suppress(asyncio.CancelledError, Exception):
                    await task
    finally:
        _release_add_locks(acquired)

638 

639 

def _catalog_section(
    featured: tuple[CatalogModel, ...],
    active: str,
    installed: set[str],
) -> ModelCatalogSection:
    """Build a ModelCatalogSection from a featured-catalog tuple.

    A featured row is "installed" when at least one quant of its
    ``hf_repo`` has a manifest. Installed refs are full
    ``hf_repo/filename`` strings, so the membership test compares the
    leading ``hf_repo`` segment. Bare ``hf_repo`` entries are accepted
    too (e.g. older clients that report just the repo).
    """
    # "org/repo/file.gguf" -> "org/repo"; non-.gguf refs are used as-is.
    installed_repos = {ref.rsplit("/", 1)[0] if ref.endswith(".gguf") else ref for ref in installed}
    return ModelCatalogSection(
        active=active,
        catalog=[
            ModelCatalogEntry(
                name=m.display_name,
                size_gb=m.size_gb,
                min_ram_gb=m.min_ram_gb,
                description=m.description,
                installed=m.hf_repo in installed_repos,
            )
            for m in featured
        ],
        installed=sorted(installed),
    )

668 

669 

async def list_models() -> ModelsResponse:
    """Return per-role catalogs (chat, embedding, vision, reranker) with active selections.

    Uses the unfiltered installed set so a single ref lights up in every
    catalog section it legitimately matches.
    """
    # Deferred import, matching the local-import pattern of other handlers.
    from lilbee.catalog import (
        FEATURED_CHAT,
        FEATURED_EMBEDDING,
        FEATURED_RERANK,
        FEATURED_VISION,
    )

    installed = set(get_model_manager().list_installed())

    return ModelsResponse(
        chat=_catalog_section(FEATURED_CHAT, cfg.chat_model, installed),
        embedding=_catalog_section(FEATURED_EMBEDDING, cfg.embedding_model, installed),
        vision=_catalog_section(FEATURED_VISION, cfg.vision_model, installed),
        reranker=_catalog_section(FEATURED_RERANK, cfg.reranker_model, installed),
    )

691 

692 

async def _set_model(
    field: Literal["chat_model", "embedding_model", "vision_model", "reranker_model"],
    model: str,
) -> SetModelResponse:
    """Shared helper for switching a model field.

    Updates the in-memory config AND persists the value under
    ``cfg.data_root`` so the selection survives restarts.
    """
    setattr(cfg, field, model)
    settings.set_value(cfg.data_root, field, model)
    return SetModelResponse(model=model)

701 

702 

def _resolve_via_catalog(model: str, available: set[str]) -> str | None:
    """Resolve a bare ``hf_repo`` to whichever quant of it is in *available*."""
    entry = find_catalog_entry(model)
    if entry is None:
        return None
    prefix = f"{entry.hf_repo}/"
    for ref in available:
        if ref.startswith(prefix):
            return ref
    return None

709 

710 

def _resolve_via_parse(model: str, available: set[str]) -> str | None:
    """Resolve a provider-prefixed ref to its bare provider name in *available*."""
    try:
        parsed = parse_model_ref(model)
    except ValueError:
        # Not a parseable ref: nothing to resolve.
        return None
    bare_name = parsed.name
    if bare_name in available:
        return bare_name
    return None

718 

719 

def _require_model_available(model: str) -> str:
    """Return the installed-and-routable form of *model*, or raise.

    Resolution order: exact match in the provider's model list, then a
    catalog lookup (bare ``hf_repo`` -> installed quant), then a parsed
    provider-prefixed ref. Raises ValueError when nothing matches.
    """
    not_available = ValueError(
        f"Model '{model}' is not available. Pull it first or check the name."
    )
    # Empty ref short-circuits before the (potentially slow) provider call.
    if not model:
        raise not_available
    available = set(get_services().provider.list_models())
    if model in available:
        return model
    hit = _resolve_via_catalog(model, available) or _resolve_via_parse(model, available)
    if hit is None:
        raise not_available
    return hit

734 

735 

def _build_task_to_field() -> dict[ModelTask, str]:
    """Invert config's ``_MODEL_FIELD_TO_TASK`` so the two maps stay in sync."""
    from lilbee.config import _MODEL_FIELD_TO_TASK

    return {ModelTask(task): field for field, task in _MODEL_FIELD_TO_TASK.items()}


# Task -> config field name, built once at import time from config's own map.
_TASK_TO_FIELD: dict[ModelTask, str] = _build_task_to_field()

744 

745 

def _require_model_for_task(model: str, expected: ModelTask, *, allow_empty: bool = False) -> str:
    """Validate *model* is installed locally AND passes the catalog task check.

    Empty string unsets the role when *allow_empty* is True. Catalog +
    task validation delegates to ``validate_model_task_assignment`` so
    the handler and config paths share a single implementation.
    """
    from lilbee.config import validate_model_task_assignment

    if allow_empty and not model.strip():
        return ""
    normalized = _require_model_available(model)
    return validate_model_task_assignment(_TASK_TO_FIELD[expected], normalized, allow_bypass=False)

759 

760 

async def set_chat_model(model: str) -> SetModelResponse:
    """Switch active chat model. Validates installation and catalog task.

    Raises ValueError when the model is not available; task-check failures
    propagate from ``validate_model_task_assignment``.
    """
    normalized = _require_model_for_task(model, ModelTask.CHAT)
    return await _set_model("chat_model", normalized)

765 

766 

async def set_embedding_model(model: str) -> SetModelResponse:
    """Switch embedding model. Validates installation and catalog task.

    Returns ``reindex_required=True`` when the new model differs from the
    embedding model that built the persisted vector store. The caller is
    expected to trigger a rebuild (``lilbee rebuild`` or ``POST /api/sync``
    with ``force_rebuild=true``). Search and ingest will refuse to operate
    until that happens.

    Pins a legacy store's identity to the OLD cfg before mutating it. Without
    this step, a pre-upgrade store with chunks but no ``_meta`` row would have
    its meta lazy-initialized from the NEW cfg on the next read, hiding the
    drift the caller just introduced.
    """
    normalized = _require_model_for_task(model, ModelTask.EMBEDDING)
    store = get_services().store
    # Must happen BEFORE _set_model mutates cfg (see docstring).
    store.initialize_meta_if_legacy()
    await _set_model("embedding_model", normalized)
    meta = store.get_meta()
    # Drift check: does the store's recorded model still match the new one?
    reindex_required = meta is not None and meta["embedding_model"] != normalized
    return SetModelResponse(model=normalized, reindex_required=reindex_required)

788 

789 

async def set_vision_model(model: str) -> SetModelResponse:
    """Switch vision OCR model. Empty string unsets it (vision OCR disabled)."""
    normalized = _require_model_for_task(model, ModelTask.VISION, allow_empty=True)
    return await _set_model("vision_model", normalized)

794 

795 

async def set_reranker_model(model: str) -> SetModelResponse:
    """Switch reranker model. Empty string unsets it (reranking disabled)."""
    normalized = _require_model_for_task(model, ModelTask.RERANK, allow_empty=True)
    return await _set_model("reranker_model", normalized)

800 

801 

# Lower bound accepted for the ``chunk_size`` config field
# (enforced in ``_validate_config_updates``).
_MIN_CHUNK_SIZE = 64

803 

804 

def _validate_config_updates(updates: dict[str, Any]) -> None:
    """Reject unknown fields, null values on non-nullable fields, and invalid ranges.

    Pure validation — nothing is mutated here; type checks happen later in
    ``_apply_config_updates`` via pydantic assignment validation.
    """
    for key, value in updates.items():
        if key not in WRITABLE_CONFIG_FIELDS:
            raise ValueError(f"Unknown or read-only config field: {key}")
        nullable = WRITABLE_CONFIG_FIELDS[key]
        if value is None and not nullable:
            raise ValueError(f"Field '{key}' does not accept null")
    chunk_val = updates.get("chunk_size")
    if isinstance(chunk_val, int) and chunk_val < _MIN_CHUNK_SIZE:
        raise ValueError(f"chunk_size must be >= {_MIN_CHUNK_SIZE}")

815 

816 

def _apply_config_updates(updates: dict[str, Any]) -> tuple[dict[str, str], list[str]]:
    """Apply updates to the in-memory config, rolling back on error.
    Returns (fields_to_persist, fields_to_delete) for disk write.
    """
    # Snapshot current values so a failed setattr can restore ALL fields.
    snapshot = {k: getattr(cfg, k) for k in updates}
    to_persist: dict[str, str] = {}
    to_delete: list[str] = []
    try:
        for key, value in updates.items():
            if value is None:
                # Null clears the override: unset in memory, delete on disk.
                setattr(cfg, key, None)
                to_delete.append(key)
            else:
                setattr(cfg, key, value)
                # Read back via getattr: assignment validation may coerce.
                to_persist[key] = str(getattr(cfg, key))
    except Exception:
        # Roll back every field, not just the failing one (see update_config).
        for k, v in snapshot.items():
            setattr(cfg, k, v)
        raise
    return to_persist, to_delete

837 

838 

async def update_config(updates: dict[str, Any]) -> ConfigUpdateResponse:
    """Partial update of writable config fields.
    Algorithm: validate-then-apply with rollback.

    1. Validate all keys and null-acceptability upfront (no mutations yet).
       This catches typos and bad input before anything changes.
    2. Snapshot current values, then apply each update via setattr (pydantic's
       validate_assignment catches type errors). If any field fails type
       validation, roll back ALL fields from the snapshot so the config
       stays consistent — no half-applied updates.
    3. Persist to disk in batch (one file write for sets, one for deletes)
       rather than per-field, avoiding partial writes on crash.

    Why not just setattr-and-save per field? A multi-field PATCH like
    {"chunk_size": 1024, "chunk_overlap": "bad"} would leave chunk_size
    changed but chunk_overlap unchanged — the caller gets an error but
    the config is silently modified. The snapshot/rollback prevents that.
    """
    _validate_config_updates(updates)
    to_persist, to_delete = _apply_config_updates(updates)
    if to_persist:
        settings.update_values(cfg.data_root, to_persist)
    if to_delete:
        settings.delete_values(cfg.data_root, to_delete)
    # Changing any provider API key re-injects keys into the SDK backend.
    if API_KEY_FIELDS & set(updates):
        inject_provider_keys()
    # Touching an index-affecting field obligates the caller to rebuild.
    reindex_required = bool(REINDEX_FIELDS & set(updates))
    return ConfigUpdateResponse(updated=list(updates), reindex_required=reindex_required)

867 

868 

async def delete_documents(
    names: list[str], *, delete_files: bool = False
) -> DocumentRemoveResponse:
    """Remove documents from the knowledge base by source name.

    *delete_files* is forwarded to ``Store.remove_documents`` (presumably
    removing the on-disk source files as well as the index rows — confirm
    against the store implementation).
    """
    result = get_services().store.remove_documents(names, delete_files=delete_files)
    return DocumentRemoveResponse(removed=result.removed, not_found=result.not_found)

875 

876 

async def list_documents(
    search: str = "",
    limit: int = 50,
    offset: int = 0,
) -> DocumentListResponse:
    """Return indexed documents with metadata, paginated and filterable.

    Both the filename filter and the pagination window are pushed down
    into LanceDB (``Store.get_sources(search=..., limit=..., offset=...)``
    and ``Store.count_sources(search=...)``) so neither call materializes
    the full SOURCES table per request.
    """
    store = get_services().store
    term = search or None
    rows = store.get_sources(search=term, limit=limit, offset=offset)
    total = store.count_sources(search=term)
    docs: list[DocumentInfo] = []
    for row in rows:
        docs.append(
            DocumentInfo(
                filename=row["filename"],
                chunk_count=row.get("chunk_count", 0),
                ingested_at=row.get("ingested_at", ""),
            )
        )
    return DocumentListResponse(
        documents=docs,
        total=total,
        limit=limit,
        offset=offset,
        # More pages remain only if this page was non-empty and didn't
        # reach the total count.
        has_more=bool(rows) and offset + len(rows) < total,
    )

907 

908 

async def get_config() -> ConfigResponse:
    """Return all user-facing configuration values."""
    public = {
        key: value
        for key, value in cfg.model_dump().items()
        if key in _PUBLIC_CONFIG_FIELDS
    }
    return ConfigResponse(**public)

914 

915 

async def get_source_content(
    source: str, raw: bool = False
) -> SourceContentResponse | tuple[bytes, str]:
    """Return a stored source file.

    JSON with markdown text for text content types, or
    ``(bytes, content_type)`` when *raw* is True. Binary types return
    empty markdown so clients know to re-request with ``raw=1``.

    Raises ValueError for an empty source name and FileNotFoundError when
    the resolved path is not a regular file.
    """
    from lilbee.wiki.index import parse_title

    if not source or not source.strip():
        raise ValueError("source must not be empty")
    documents_dir = cfg.documents_dir
    # Reject path traversal: the resolved path must stay inside documents_dir.
    path = validate_path_within(documents_dir / source, documents_dir)
    if not path.is_file():
        raise FileNotFoundError(source)

    guessed, _ = mimetypes.guess_type(path.name)
    content_type = guessed if guessed is not None else "application/octet-stream"

    if raw:
        # Cap raw responses to inline-render-safe categories; anything else
        # degrades to a binary download so attacker-renamed files (e.g.
        # evil.html) can't trick the embedding browser into running script
        # under our origin.
        if _is_safe_for_inline_render(content_type):
            served = content_type
        else:
            served = "application/octet-stream"
        return path.read_bytes(), served

    if not content_type.startswith("text/"):
        return SourceContentResponse(markdown="", content_type=content_type, title=None)

    text = path.read_text(encoding="utf-8", errors="replace")
    return SourceContentResponse(
        markdown=text,
        content_type=content_type,
        title=parse_title(text) or None,
    )

952 

953 

@functools.cache
def _compute_config_defaults() -> dict[str, Any]:
    """Materialize Config defaults once per process.

    Includes writable public fields and model-role fields; skips any field
    without a resolvable default.
    """
    out: dict[str, Any] = {}
    for field_name, field_info in Config.model_fields.items():
        writable_public = (
            field_name in WRITABLE_CONFIG_FIELDS and field_name in _PUBLIC_CONFIG_FIELDS
        )
        # Keep only writable public fields or model-role fields.
        if not (writable_public or field_name in _MODEL_ROLE_FIELDS):
            continue
        default = field_info.get_default(call_default_factory=True)
        if default is PydanticUndefined:  # pragma: no cover
            continue
        out[field_name] = default
    return out

967 

968 

async def get_config_defaults() -> ConfigResponse:
    """Return canonical defaults for every public config field.

    Covers the writable fields (resettable via PATCH /api/config) and the
    model-role fields (resettable via PUT /api/models/<role>).

    The cached defaults dict is deep-copied before use so callers that
    mutate the response (list-valued fields like ``crawl_exclude_patterns``)
    cannot poison subsequent calls.
    """
    defaults = copy.deepcopy(_compute_config_defaults())
    return ConfigResponse(**defaults)

980 

981 

async def models_show(model: str) -> ModelsShowResponse:
    """Return model metadata/parameters. Returns empty model if unavailable."""
    info = get_services().provider.show_model(model)
    if not info:
        # Provider had nothing for this model — respond with an empty record.
        info = {}
    return ModelsShowResponse(**info)

987 

988 

def _parse_source(source: str) -> ModelSource:
    """Map a source string onto the ModelSource enum.

    Propagates the enum's ValueError for unrecognized values.
    """
    return ModelSource(source)

992 

993 

async def models_catalog(
    task: str | None = None,
    search: str = "",
    size: str | None = None,
    installed: bool | None = None,
    featured: bool | None = None,
    sort: str = "featured",
    limit: int = 20,
    offset: int = 0,
) -> ModelsCatalogResponse:
    """Return paginated model catalog with installed status."""
    from lilbee.catalog import enrich_catalog, get_catalog

    page = get_catalog(
        task=task,
        search=search,
        size=size,
        installed=installed,
        featured=featured,
        sort=sort,
        limit=limit,
        offset=offset,
    )

    # Installed refs come from the local registry; enrich_catalog flips
    # each entry's installed flag against this set.
    installed_refs = {m.ref for m in get_services().registry.list_installed()}
    entries = [
        CatalogEntryResponse(
            hf_repo=entry.hf_repo,
            gguf_filename=entry.gguf_filename,
            task=entry.task,
            display_name=entry.display_name,
            param_count=entry.param_count,
            size_gb=entry.size_gb,
            min_ram_gb=entry.min_ram_gb,
            description=entry.description,
            quality_tier=entry.quality_tier,
            featured=entry.featured,
            downloads=entry.downloads,
            installed=entry.installed,
            source=entry.source,
        )
        for entry in enrich_catalog(page, installed_refs)
    ]

    return ModelsCatalogResponse(
        total=page.total,
        limit=page.limit,
        offset=page.offset,
        has_more=page.has_more,
        models=entries,
    )

1046 

1047 

async def models_installed() -> ModelsInstalledResponse:
    """Return list of installed models with their source."""
    manager = get_model_manager()
    entries = []
    for name in manager.list_installed():
        found = manager.get_source(name)
        # Unknown provenance defaults to REMOTE.
        src = ModelSource.REMOTE if found is None else found
        entries.append(InstalledModelEntry(name=name, source=src.value))
    return ModelsInstalledResponse(models=entries)

1058 

1059 

async def models_pull(model: str, *, source: str = "native") -> AsyncGenerator[str, None]:
    """Yield SSE progress events while pulling a model in real time.
    Sets a cancel event on client disconnect so the pull stops.

    The blocking ``manager.pull`` runs on a worker thread via
    ``asyncio.to_thread``; its callbacks fire on that thread, so every
    queue interaction is marshalled back to the event loop with
    ``call_soon_threadsafe``. A ``None`` sentinel enqueued in ``finally``
    tells ``sse.drain`` that the stream has ended.
    """
    manager = get_model_manager()
    src = _parse_source(source)
    sse = SseStream()

    def _pull_blocking() -> None:
        # Runs on a worker thread — never touch the queue directly here;
        # always marshal through sse.loop.call_soon_threadsafe.
        def _on_progress(data: dict[str, Any]) -> None:
            # Drop events once the client has disconnected.
            if sse.cancel.is_set():
                return
            payload = sse_event(SseEvent.PROGRESS, data)
            sse.loop.call_soon_threadsafe(sse.queue.put_nowait, payload)

        def _on_bytes(downloaded: int, total: int) -> None:
            if sse.cancel.is_set():
                return
            payload = sse_event(SseEvent.PROGRESS, {"current": downloaded, "total": total})
            sse.loop.call_soon_threadsafe(sse.queue.put_nowait, payload)

        try:
            manager.pull(model, src, on_progress=_on_progress, on_bytes=_on_bytes)
        except Exception as exc:
            # Surface pull failures as an SSE error event instead of
            # letting the worker thread die silently.
            sse.loop.call_soon_threadsafe(sse.queue.put_nowait, sse_error(str(exc)))
        finally:
            # Sentinel: unblocks sse.drain whether the pull succeeded,
            # failed, or was cancelled.
            sse.loop.call_soon_threadsafe(sse.queue.put_nowait, None)

    task = asyncio.ensure_future(asyncio.to_thread(_pull_blocking))
    async for event in sse.drain(task, "Model pull stream"):
        yield event

1091 

1092 

async def models_delete(model: str, *, source: str = "native") -> ModelsDeleteResponse:
    """Delete a model. Returns deletion status, model name, and freed space."""
    was_deleted = get_model_manager().remove(model, _parse_source(source))
    return ModelsDeleteResponse(deleted=was_deleted, model=model, freed_gb=0.0)

1099 

1100 

async def crawl_stream(
    url: str, depth: int | None = None, max_pages: int | None = None
) -> AsyncGenerator[str, None]:
    """Stream crawl progress as SSE events.
    Emits crawl_start, crawl_page, crawl_done events, then a final done event
    with the list of files written. On error emits crawl_error.
    Sets a cancel event on client disconnect so the crawl stops between pages.

    On first use, Chromium isn't installed yet. The stream inlines
    setup_start/progress/done events before the crawl begins so the
    Obsidian plugin's Task Center renders a matching 'setup' pill (bb-wq8g).
    """
    sse = SseStream()

    async def _run_crawl() -> list[Path]:
        from lilbee.crawler import crawl_and_save

        # crawl_and_save runs the Chromium bootstrap itself on first use,
        # relaying setup_* events through the same on_progress callback
        # so the SSE stream carries them before any crawl_* events.
        try:
            return await crawl_and_save(
                url, depth=depth, max_pages=max_pages, on_progress=sse.callback, cancel=sse.cancel
            )
        finally:
            # Sentinel: unblocks sse.drain even when the crawl raises.
            sse.queue.put_nowait(None)

    task = asyncio.create_task(_run_crawl())
    async for event in sse.drain(task, "Crawl stream"):
        yield event
    # After the drain ends, surface the crawl's outcome — unless the client
    # disconnected (cancel set) or the task itself was cancelled.
    if not sse.cancel.is_set() and task.done() and not task.cancelled():
        exc = task.exception()
        if exc is not None:
            # Terminal error event instead of raising out of the generator.
            yield sse_error(str(exc))
            return
        paths = task.result()
        yield sse_done({"files_written": [str(p) for p in paths]})

1138 

1139 

# Seconds a cached external-model listing stays fresh before re-querying
# the provider.
_EXTERNAL_MODELS_TTL = 60

1141 

1142 

class _ExternalModelsCache:
    """TTL cache for external model listings (no module-level mutable global).

    Holds at most one entry — the listing for the most recent
    ``(base_url, api_key)`` key. Setting a new key overwrites the old entry.
    """

    def __init__(self) -> None:
        # _time is a monotonic timestamp of the last set(); _key is the
        # cache key the stored result belongs to; None means "empty".
        self._time: float = 0.0
        self._key: str = ""
        self._result: ExternalModelsResponse | None = None

    def get(self, key: str) -> ExternalModelsResponse | None:
        """Return the cached listing for *key* if still fresh, else None."""
        # Explicit `is not None`: the empty-cache sentinel is None, and a
        # response object's truthiness should never decide a cache hit.
        if self._result is None or key != self._key:
            return None
        if (time.monotonic() - self._time) >= _EXTERNAL_MODELS_TTL:
            return None
        return self._result

    def set(self, key: str, result: ExternalModelsResponse) -> None:
        """Store *result* under *key*, restarting the TTL clock."""
        self._time = time.monotonic()
        self._key = key
        self._result = result

1161 

1162 

# Process-wide singleton backing list_external_models().
_external_cache = _ExternalModelsCache()

1164 

1165 

async def list_external_models() -> ExternalModelsResponse:
    """Query the provider for available models via its list_models() API."""
    # A changed base URL or API key must never serve a stale listing, so
    # both feed the cache key.
    cache_key = f"{cfg.remote_base_url}:{cfg.llm_api_key or ''}"
    hit = _external_cache.get(cache_key)
    if hit:
        return hit

    try:
        listed = await asyncio.to_thread(get_services().provider.list_models)
        response = ExternalModelsResponse(models=listed)
        _external_cache.set(cache_key, response)
        return response
    except Exception as exc:
        # Degrade to an empty listing with the error attached rather than
        # failing the request.
        log.warning("Failed to list external models: %s", exc)
        return ExternalModelsResponse(models=[], error=str(exc))