Coverage for src / lilbee / cli / helpers.py: 100%
190 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-29 19:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-29 19:16 +0000
1"""Shared helper functions for CLI commands and slash commands."""
3from __future__ import annotations
5import asyncio
6import json
7import logging
8import shutil
9from collections.abc import Generator
10from contextlib import contextmanager
11from dataclasses import dataclass, field
12from importlib.metadata import version as _pkg_version
13from pathlib import Path
14from typing import TYPE_CHECKING
16from pydantic import BaseModel
17from rich.console import Console, RenderableType
18from rich.table import Table
20from lilbee.cli import theme
21from lilbee.config import cfg
22from lilbee.platform import is_ignored_dir
23from lilbee.security import validate_path_within
24from lilbee.services import get_services
26if TYPE_CHECKING:
27 from lilbee.cli.sync import SyncStatus
28 from lilbee.store import SearchChunk
class ResetResult(BaseModel):
    """Summary returned after wiping the documents and data directories."""

    # Discriminator for JSON output consumers.
    command: str = "reset"
    # Number of top-level entries removed from the documents directory.
    deleted_docs: int
    # Number of top-level entries removed from the data directory.
    deleted_data: int
    # String paths of entries that could not be removed.
    skipped: list[str] = []
    # Absolute/configured locations that were cleared, as strings.
    documents_dir: str
    data_dir: str
class StatusConfig(BaseModel):
    """Active configuration reported by the ``status`` command.

    Carries every role-bound model (chat, embedding, vision, reranker) so
    that the TUI status screen and plugin callers can show which model is
    active for each role.
    """

    documents_dir: str
    data_dir: str
    chat_model: str
    embedding_model: str
    # An empty string means no model is configured for the role.
    vision_model: str = ""
    reranker_model: str = ""
    # None means "not reported"; the human-readable view omits the OCR line.
    enable_ocr: bool | None = None
class SourceInfo(BaseModel):
    """One indexed document as listed in a status response."""

    filename: str
    # Truncated by gather_status for display purposes.
    file_hash: str
    chunk_count: int
    # Timestamp string, truncated by gather_status.
    ingested_at: str
class StatusResult(BaseModel):
    """Complete knowledge-base status: config, indexed sources, and totals."""

    command: str = "status"
    config: StatusConfig
    sources: list[SourceInfo]
    total_chunks: int

    def __rich_console__(
        self, console: Console, options: object
    ) -> Generator[RenderableType, None, None]:
        """Render the status for Rich: labelled config lines, then a table of sources."""
        lbl = theme.LABEL
        conf_section = self.config

        header_rows = [
            ("Documents:", conf_section.documents_dir),
            ("Database:", conf_section.data_dir),
            ("Chat model:", conf_section.chat_model),
            ("Embeddings:", conf_section.embedding_model),
            # Empty model names are shown as "(disabled)".
            ("Vision:", conf_section.vision_model or "(disabled)"),
            ("Reranker:", conf_section.reranker_model or "(disabled)"),
        ]
        if conf_section.enable_ocr is not None:
            header_rows.append(
                ("Vision OCR:", "enabled" if conf_section.enable_ocr else "disabled")
            )
        for caption, value in header_rows:
            yield f"[{lbl}]{caption}[/{lbl}] {value}"
        yield ""

        if not self.sources:
            yield (
                "No documents indexed. Drop files into the documents directory "
                "and run 'lilbee sync'."
            )
            return

        table = Table(title="Indexed Documents")
        column_specs = (
            ("File", {"style": theme.ACCENT}),
            ("Hash", {"style": theme.MUTED, "max_width": 12}),
            ("Chunks", {"justify": "right"}),
            ("Ingested", {"style": theme.MUTED}),
        )
        for heading, column_options in column_specs:
            table.add_column(heading, **column_options)
        for src in self.sources:
            table.add_row(src.filename, src.file_hash, str(src.chunk_count), src.ingested_at)
        yield table

        doc_count = len(self.sources)
        yield f"\n[{lbl}]{doc_count}[/{lbl}] documents, [{lbl}]{self.total_chunks}[/{lbl}] chunks"
def _copytree_ignore(directory: str, contents: list[str]) -> set[str]:
    """``shutil.copytree`` ignore hook: skip sub-directories matched by ``cfg.ignore_dirs``.

    Only entries that are directories are considered. Plain files are
    never ignored by this hook.
    """
    parent = Path(directory)
    ignored: set[str] = set()
    for entry in contents:
        if not (parent / entry).is_dir():
            continue
        if is_ignored_dir(entry, cfg.ignore_dirs):
            ignored.add(entry)
    return ignored
def get_version() -> str:
    """Look up the version string of the installed ``lilbee`` distribution."""
    installed = _pkg_version("lilbee")
    return installed
def json_output(data: dict) -> None:
    """Serialize *data* with ``json.dumps`` and write it to stdout on one line."""
    serialized = json.dumps(data)
    print(serialized)
def clean_result(result: SearchChunk) -> dict:
    """Dump a SearchChunk to a JSON-ready dict without its vector or None fields.

    When the chunk's source can be mapped into the vault, the dict also
    gains a ``vault_path`` key.
    """
    payload = result.model_dump(exclude={"vector"}, exclude_none=True)
    if (location := resolve_vault_path(result.source)) is None:
        return payload
    payload["vault_path"] = location
    return payload
def resolve_vault_path(source_filename: str) -> str | None:
    """Map *source_filename* (relative to ``documents_dir``) to a vault-relative POSIX path.

    Returns None in these cases:
    - no vault base is configured;
    - symlink resolution fails;
    - the resolved file escapes ``documents_dir`` (e.g. via ``..``);
    - ``documents_dir`` does not live inside the vault;
    - the target is not an existing regular file.
    """
    if cfg.vault_base is None:
        return None
    try:
        vault_root = cfg.vault_base.resolve()
        docs_root = cfg.documents_dir.resolve()
        candidate = (cfg.documents_dir / source_filename).resolve()
        # relative_to raises ValueError if candidate is outside docs_root.
        inside_docs = candidate.relative_to(docs_root)
        docs_in_vault = docs_root.relative_to(vault_root)
    except (OSError, ValueError):
        return None
    return (docs_in_vault / inside_docs).as_posix() if candidate.is_file() else None
def gather_status() -> StatusResult:
    """Snapshot the current config and indexed sources as a StatusResult.

    The same model backs both the human-readable and the JSON status output.
    """
    raw_sources = get_services().store.get_sources()
    by_name = sorted(raw_sources, key=lambda row: row["filename"])
    chunk_total = sum(row["chunk_count"] for row in raw_sources)

    config_section = StatusConfig(
        documents_dir=str(cfg.documents_dir),
        data_dir=str(cfg.data_dir),
        chat_model=cfg.chat_model,
        embedding_model=cfg.embedding_model,
        vision_model=cfg.vision_model,
        reranker_model=cfg.reranker_model,
        enable_ocr=cfg.enable_ocr,
    )
    listed = [
        SourceInfo(
            filename=row["filename"],
            # Shortened hash for display.
            file_hash=row["file_hash"][:12],
            chunk_count=row["chunk_count"],
            # First 19 chars; presumably date + time to the second of an
            # ISO timestamp. TODO confirm the stored format.
            ingested_at=row["ingested_at"][:19],
        )
        for row in by_name
    ]
    return StatusResult(config=config_section, sources=listed, total_chunks=chunk_total)
def render_status(con: Console) -> None:
    """Write the human-readable status report to *con*."""
    status = gather_status()
    con.print(status)
@dataclass
class CopyResult:
    """Outcome of copy_files: names that were copied vs. names left alone."""

    # Names (basename of each input path) that were copied into documents_dir.
    copied: list[str] = field(default_factory=list)
    # Names that already existed at the destination and were not overwritten.
    skipped: list[str] = field(default_factory=list)
def copy_files(paths: list[Path], *, force: bool = False) -> CopyResult:
    """Copy each of *paths* into the documents directory, keyed by its basename.

    Existing destinations are skipped unless *force* is set. Directories are
    copied recursively, with ignored sub-directories filtered out and
    symlinks followed. Nothing is printed; the caller decides how to report
    the returned CopyResult.
    """
    target_root = cfg.documents_dir
    target_root.mkdir(parents=True, exist_ok=True)
    outcome = CopyResult()
    for src in paths:
        name = src.name
        target = target_root / name
        validate_path_within(target, target_root)
        if target.exists() and not force:
            outcome.skipped.append(name)
            continue
        if not src.is_dir():
            shutil.copy2(src, target)
        else:
            shutil.copytree(
                src, target, dirs_exist_ok=True, ignore=_copytree_ignore, symlinks=False
            )
        outcome.copied.append(name)
    return outcome
def copy_paths(paths: list[Path], con: Console, *, force: bool = False) -> list[str]:
    """Copy *paths* into the documents directory and warn on *con* about skipped names.

    Returns the names that were actually copied.
    """
    outcome = copy_files(paths, force=force)
    warn = theme.WARNING
    for existing in outcome.skipped:
        con.print(
            f"[{warn}]Warning:[/{warn}] {existing} already exists in knowledge base "
            "(use --force to overwrite)"
        )
    return outcome.copied
def add_paths(
    paths: list[Path],
    con: Console,
    *,
    force: bool = False,
    background: bool = False,
    chat_mode: bool = False,
    sync_status: SyncStatus | None = None,
) -> None:
    """Copy *paths* into the knowledge base, then sync, reporting in human form.

    With *background* set (chat ``/add``), the sync is handed to a background
    runner and this function returns right after copying. Otherwise the sync
    runs to completion and its result is printed to *con*. In *chat_mode* the
    copy summary goes to plain stdout instead of the Rich console.
    """
    from lilbee.ingest import sync

    copied = copy_paths(paths, con, force=force)
    summary = f"Copied {len(copied)} path(s) to {cfg.documents_dir}"
    if chat_mode:
        print(summary)
    else:
        con.print(f"[{theme.MUTED}]{summary}[/{theme.MUTED}]")

    if not background:
        con.print(asyncio.run(sync()))
        return

    from lilbee.cli.sync import run_sync_background

    run_sync_background(con, chat_mode=chat_mode, sync_status=sync_status)
def _clear_dir(base_dir: Path, skipped: list[str]) -> int:
    """Delete every entry directly inside *base_dir*; the directory itself is kept.

    Entries that cannot be removed are logged as warnings, appended to
    *skipped* (as strings), and not counted.

    Symlinks, including symlinks to directories, are removed as links. Their
    targets are never recursed into. ``shutil.rmtree`` refuses to operate on
    a symlink and raises OSError, so such entries used to be reported as
    undeletable and survived a reset.

    Returns:
        The number of entries deleted. Returns 0 when *base_dir* does not exist.

    Raises:
        Whatever ``validate_path_within`` raises for an entry it rejects.
    """
    log = logging.getLogger(__name__)
    if not base_dir.exists():
        return 0
    deleted = 0
    # Snapshot the listing so deletions don't disturb iteration.
    for item in list(base_dir.iterdir()):
        validate_path_within(item, base_dir)
        try:
            # Path.is_dir() follows symlinks; only recurse into real directories.
            if item.is_dir() and not item.is_symlink():
                shutil.rmtree(item)
            else:
                item.unlink()
        except OSError as exc:
            log.warning("Could not delete %s: %s", item, exc)
            skipped.append(str(item))
            continue
        deleted += 1
    return deleted
def perform_reset() -> ResetResult:
    """Wipe the documents and data directories and report what was removed."""
    undeletable: list[str] = []
    docs_removed = _clear_dir(cfg.documents_dir, undeletable)
    data_removed = _clear_dir(cfg.data_dir, undeletable)
    summary = ResetResult(
        deleted_docs=docs_removed,
        deleted_data=data_removed,
        skipped=undeletable,
        documents_dir=str(cfg.documents_dir),
        data_dir=str(cfg.data_dir),
    )
    return summary
def sync_result_to_json(result: object) -> dict:
    """Wrap a SyncResult in the ``{"command": "sync", ...}`` JSON envelope.

    Raises:
        TypeError: if *result* is not a SyncResult.
    """
    from lilbee.ingest import SyncResult

    if isinstance(result, SyncResult):
        envelope: dict = {"command": "sync"}
        envelope.update(result.model_dump())
        return envelope
    raise TypeError(f"Expected SyncResult, got {type(result).__name__}")
def auto_sync(con: Console, *, background: bool = False) -> None:
    """Sync documents ahead of a query.

    With *background* set (chat/REPL), the sync is delegated to a background
    runner and this returns immediately. Otherwise (the default, used by
    ``lilbee ask``) the sync blocks until done. A RuntimeError from the sync
    is printed and converted to ``SystemExit(1)``. A one-line summary is
    printed only when something changed or failed.
    """
    if background:
        from lilbee.cli.sync import run_sync_background

        run_sync_background(con)
        return

    from lilbee.ingest import sync

    try:
        result = asyncio.run(sync())
    except RuntimeError as exc:
        con.print(f"[{theme.ERROR}]Error:[/{theme.ERROR}] {exc}")
        raise SystemExit(1) from None

    counts = {
        bucket: len(getattr(result, bucket))
        for bucket in ("added", "updated", "removed", "failed")
    }
    # Lengths are non-negative, so "any non-zero" equals "sum is non-zero".
    if not any(counts.values()):
        return
    con.print(
        f"[{theme.MUTED}]Synced: {counts['added']} added, "
        f"{counts['updated']} updated, "
        f"{counts['removed']} removed, "
        f"{counts['failed']} failed[/{theme.MUTED}]"
    )
@contextmanager
def temporary_ocr_config(
    enable_ocr: bool | None = None,
    ocr_timeout: float | None = None,
) -> Generator[None, None, None]:
    """Within the ``with`` block, apply any non-None OCR overrides to ``cfg``.

    The previous ``enable_ocr`` and ``ocr_timeout`` values are always
    restored on exit, even if the block raises.
    """
    overrides = {"enable_ocr": enable_ocr, "ocr_timeout": ocr_timeout}
    snapshot = {key: getattr(cfg, key) for key in overrides}
    try:
        for key, value in overrides.items():
            if value is not None:
                setattr(cfg, key, value)
        yield
    finally:
        for key, value in snapshot.items():
            setattr(cfg, key, value)