Coverage for src / lilbee / cli / helpers.py: 100%

190 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""Shared helper functions for CLI commands and slash commands.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import json 

7import logging 

8import shutil 

9from collections.abc import Generator 

10from contextlib import contextmanager 

11from dataclasses import dataclass, field 

12from importlib.metadata import version as _pkg_version 

13from pathlib import Path 

14from typing import TYPE_CHECKING 

15 

16from pydantic import BaseModel 

17from rich.console import Console, RenderableType 

18from rich.table import Table 

19 

20from lilbee.cli import theme 

21from lilbee.config import cfg 

22from lilbee.platform import is_ignored_dir 

23from lilbee.security import validate_path_within 

24from lilbee.services import get_services 

25 

26if TYPE_CHECKING: 

27 from lilbee.cli.sync import SyncStatus 

28 from lilbee.store import SearchChunk 

29 

30 

class ResetResult(BaseModel):
    """Summary produced by a full knowledge-base reset.

    Field order is significant: it fixes the key order of the serialized
    (JSON) output, so new fields should be appended rather than inserted.
    """

    command: str = "reset"
    # Number of top-level entries removed from the documents directory.
    deleted_docs: int
    # Number of top-level entries removed from the data directory.
    deleted_data: int
    # Paths that could not be deleted; pydantic copies this mutable default per instance.
    skipped: list[str] = []
    documents_dir: str
    data_dir: str

40 

41 

class StatusConfig(BaseModel):
    """Configuration snapshot embedded in a status response.

    Carries the model bound to each of the four roles (chat, embedding,
    vision, reranker) so the TUI status screen and plugin callers can show
    what is active per role.
    """

    documents_dir: str
    data_dir: str
    chat_model: str
    embedding_model: str
    # Optional roles default to ""; the status renderer displays that as "(disabled)".
    vision_model: str = ""
    reranker_model: str = ""
    # None omits the OCR line from the rendered status; True/False shows enabled/disabled.
    enable_ocr: bool | None = None

57 

58 

class SourceInfo(BaseModel):
    """One indexed document as listed in a status response."""

    filename: str
    # gather_status keeps only the first 12 characters of the stored hash.
    file_hash: str
    chunk_count: int
    # gather_status keeps the first 19 characters of the stored value
    # (presumably an ISO-8601 timestamp trimmed to seconds — confirm against the store).
    ingested_at: str

66 

67 

class StatusResult(BaseModel):
    """Full status response for the knowledge base.

    The same model backs both outputs: pydantic serializes it for JSON
    callers, and ``__rich_console__`` lets ``Console.print(result)`` render
    it for humans.
    """

    command: str = "status"
    config: StatusConfig
    sources: list[SourceInfo]
    total_chunks: int

    def __rich_console__(
        self, console: Console, options: object
    ) -> Generator[RenderableType, None, None]:
        """Yield the configuration summary, then a table of indexed sources.

        Paths, model names and file names are user-controlled, so they are
        passed through ``rich.markup.escape`` before being placed in markup
        strings (table string cells are rendered as markup too). Without
        this, a name such as ``notes [draft].md`` is parsed as a style tag and
        loses text, and a stray closing tag like ``[/x]`` raises MarkupError.
        """
        from rich.markup import escape

        b = theme.LABEL
        conf = self.config
        yield f"[{b}]Documents:[/{b}] {escape(conf.documents_dir)}"
        yield f"[{b}]Database:[/{b}] {escape(conf.data_dir)}"
        yield f"[{b}]Chat model:[/{b}] {escape(conf.chat_model)}"
        yield f"[{b}]Embeddings:[/{b}] {escape(conf.embedding_model)}"
        # An empty model name is displayed as "(disabled)".
        vision = conf.vision_model or "(disabled)"
        reranker = conf.reranker_model or "(disabled)"
        yield f"[{b}]Vision:[/{b}] {escape(vision)}"
        yield f"[{b}]Reranker:[/{b}] {escape(reranker)}"
        # The OCR line is only shown when enable_ocr was provided (not None).
        if conf.enable_ocr is not None:
            ocr_label = "enabled" if conf.enable_ocr else "disabled"
            yield f"[{b}]Vision OCR:[/{b}] {ocr_label}"
        yield ""

        if not self.sources:
            yield (
                "No documents indexed. Drop files into the documents directory "
                "and run 'lilbee sync'."
            )
            return

        table = Table(title="Indexed Documents")
        table.add_column("File", style=theme.ACCENT)
        table.add_column("Hash", style=theme.MUTED, max_width=12)
        table.add_column("Chunks", justify="right")
        table.add_column("Ingested", style=theme.MUTED)
        for s in self.sources:
            table.add_row(escape(s.filename), s.file_hash, str(s.chunk_count), s.ingested_at)
        yield table
        yield f"\n[{b}]{len(self.sources)}[/{b}] documents, [{b}]{self.total_chunks}[/{b}] chunks"

109 

110 

def _copytree_ignore(directory: str, contents: list[str]) -> set[str]:
    """``shutil.copytree`` ignore hook: skip subdirectories matching ``cfg.ignore_dirs``.

    Only entries that are directories are candidates; plain files are never
    ignored by this hook.
    """
    base = Path(directory)
    ignored: set[str] = set()
    for entry in contents:
        if (base / entry).is_dir() and is_ignored_dir(entry, cfg.ignore_dirs):
            ignored.add(entry)
    return ignored

118 

119 

def get_version() -> str:
    """Look up lilbee's version string from the installed package metadata."""
    package_name = "lilbee"
    return _pkg_version(package_name)

123 

124 

def json_output(data: dict) -> None:
    """Write *data* to stdout as one line of JSON (newline-terminated)."""
    line = json.dumps(data)
    print(line)

128 

129 

def clean_result(result: SearchChunk) -> dict:
    """Serialize a SearchChunk for JSON output.

    The embedding ``vector`` and any None-valued fields are dropped. When the
    chunk's source maps into the configured vault, a ``vault_path`` key is
    added as well.
    """
    payload = result.model_dump(exclude={"vector"}, exclude_none=True)
    if (vault_path := resolve_vault_path(result.source)) is not None:
        payload["vault_path"] = vault_path
    return payload

137 

138 

def resolve_vault_path(source_filename: str) -> str | None:
    """Return *source_filename* as a vault-relative POSIX path, or None if unresolvable.

    Symlinks are resolved on both sides. None is returned when no vault is
    configured, when the source resolves outside ``documents_dir`` (e.g. a
    ``..`` escape), when ``documents_dir`` is not inside the vault, when the
    source is not an existing regular file, or when any filesystem error
    occurs along the way.
    """
    if cfg.vault_base is None:
        return None
    try:
        vault_base = cfg.vault_base.resolve()
        documents_dir = cfg.documents_dir.resolve()
        source_path = (cfg.documents_dir / source_filename).resolve()
        # relative_to raises ValueError when the resolved path is not under the base.
        relative_source = source_path.relative_to(documents_dir)
        relative_docs_dir = documents_dir.relative_to(vault_base)
        # Path.is_file() propagates errors such as PermissionError, so it must
        # stay inside the try to honour the "None if unresolvable" contract.
        if not source_path.is_file():
            return None
    except (OSError, ValueError):
        return None
    return (relative_docs_dir / relative_source).as_posix()

158 

159 

def gather_status() -> StatusResult:
    """Build a typed status snapshot from the store and the current config.

    Shared by the human-readable and JSON status outputs. Sources are sorted
    by filename; hashes are cut to 12 characters and ingestion times to 19.
    """
    rows = get_services().store.get_sources()
    ordered = sorted(rows, key=lambda row: row["filename"])
    total = sum(row["chunk_count"] for row in rows)
    config = StatusConfig(
        documents_dir=str(cfg.documents_dir),
        data_dir=str(cfg.data_dir),
        chat_model=cfg.chat_model,
        embedding_model=cfg.embedding_model,
        vision_model=cfg.vision_model,
        reranker_model=cfg.reranker_model,
        enable_ocr=cfg.enable_ocr,
    )
    infos = [
        SourceInfo(
            filename=row["filename"],
            file_hash=row["file_hash"][:12],
            chunk_count=row["chunk_count"],
            ingested_at=row["ingested_at"][:19],
        )
        for row in ordered
    ]
    return StatusResult(config=config, sources=infos, total_chunks=total)

186 

187 

def render_status(con: Console) -> None:
    """Print the knowledge-base status (paths, models, indexed documents) to *con*."""
    status = gather_status()
    con.print(status)

191 

192 

@dataclass
class CopyResult:
    """Outcome of copying paths into the documents directory.

    Both lists hold destination names; each instance gets its own fresh
    lists through ``default_factory``.
    """

    # Names that were written into the documents directory.
    copied: list[str] = field(default_factory=list)
    # Names left untouched because the destination already existed.
    skipped: list[str] = field(default_factory=list)

199 

200 

def _destination_name(path: Path) -> str:
    """Return the entry name *path* should get inside the documents directory.

    ``Path(".").name`` is ``""`` and ``Path("..").name`` is ``".."``. Neither
    is usable as a destination, so such paths are resolved to the real
    directory name first.

    Raises:
        ValueError: if *path* is a filesystem root, which has no name to copy to.
    """
    name = path.name
    if name in ("", ".."):
        name = path.resolve().name
    if not name:
        raise ValueError(f"Cannot copy filesystem root {path} into the documents directory")
    return name


def copy_files(paths: list[Path], *, force: bool = False) -> CopyResult:
    """Copy *paths* into the documents directory without printing anything.

    Each path lands directly under ``cfg.documents_dir`` under its own name.
    Directories are copied recursively, dropping ignored subdirectories and
    not preserving symlinks. An existing destination is skipped unless *force*
    is True, in which case it is overwritten (directories are merged).

    Returns:
        A CopyResult listing the destination names copied and skipped.

    Raises:
        ValueError: if a path is a filesystem root.
    """
    cfg.documents_dir.mkdir(parents=True, exist_ok=True)
    result = CopyResult()
    for p in paths:
        name = _destination_name(p)
        dest = cfg.documents_dir / name
        validate_path_within(dest, cfg.documents_dir)
        if dest.exists() and not force:
            result.skipped.append(name)
            continue
        if p.is_dir():
            shutil.copytree(p, dest, dirs_exist_ok=True, ignore=_copytree_ignore, symlinks=False)
        else:
            shutil.copy2(p, dest)
        result.copied.append(name)
    return result

217 

218 

def copy_paths(paths: list[Path], con: Console, *, force: bool = False) -> list[str]:
    """Copy *paths* into the documents directory, warning on *con* about skipped names.

    Thin console wrapper over ``copy_files``: every name that already exists
    (and was therefore not overwritten, because *force* is False) produces a
    warning line.

    Returns:
        The names that were copied.
    """
    from rich.markup import escape

    result = copy_files(paths, force=force)
    for name in result.skipped:
        # File names are user-controlled; escape them so e.g. "notes [draft].md"
        # is not parsed as Rich markup.
        con.print(
            f"[{theme.WARNING}]Warning:[/{theme.WARNING}] {escape(name)} already exists in knowledge base "
            f"(use --force to overwrite)"
        )
    return result.copied

228 

229 

def add_paths(
    paths: list[Path],
    con: Console,
    *,
    force: bool = False,
    background: bool = False,
    chat_mode: bool = False,
    sync_status: SyncStatus | None = None,
) -> None:
    """Copy *paths* into the knowledge base and sync (human output).

    When *background* is True (chat ``/add``), sync runs in a background
    thread and this function returns immediately after copying files.
    Otherwise sync runs to completion and its result is printed to *con*.
    In *chat_mode* the copy summary is written with plain ``print`` rather
    than through *con*.
    """
    from lilbee.ingest import sync
    from rich.markup import escape

    copied = copy_paths(paths, con, force=force)
    if chat_mode:
        print(f"Copied {len(copied)} path(s) to {cfg.documents_dir}")
    else:
        # Escape the directory: square brackets in a path would otherwise be parsed as markup.
        con.print(
            f"[{theme.MUTED}]Copied {len(copied)} path(s) to "
            f"{escape(str(cfg.documents_dir))}[/{theme.MUTED}]"
        )

    if background:
        from lilbee.cli.sync import run_sync_background

        run_sync_background(con, chat_mode=chat_mode, sync_status=sync_status)
        return

    result = asyncio.run(sync())
    con.print(result)

261 

262 

def _clear_dir(base_dir: Path, skipped: list[str]) -> int:
    """Delete every top-level entry in *base_dir*.

    Entries that fail to delete with an OSError are logged, and their paths
    are appended to *skipped*. A missing *base_dir* counts as empty.

    Returns:
        The number of entries actually deleted.
    """
    log = logging.getLogger(__name__)
    if not base_dir.exists():
        return 0
    deleted = 0
    # Snapshot the listing first so removing entries doesn't disturb iteration.
    for item in list(base_dir.iterdir()):
        validate_path_within(item, base_dir)
        try:
            # is_dir() follows symlinks, but shutil.rmtree refuses symlinks
            # (raising OSError). Remove the link itself and leave its target alone.
            if item.is_dir() and not item.is_symlink():
                shutil.rmtree(item)
            else:
                item.unlink()
        except OSError as exc:
            log.warning("Could not delete %s: %s", item, exc)
            skipped.append(str(item))
            continue
        deleted += 1
    return deleted

282 

283 

def perform_reset() -> ResetResult:
    """Wipe the documents and data directories and report what was removed.

    Entries that could not be deleted are collected into ``skipped`` instead
    of aborting the reset.
    """
    skipped: list[str] = []
    docs_dir, data_dir = cfg.documents_dir, cfg.data_dir
    return ResetResult(
        deleted_docs=_clear_dir(docs_dir, skipped),
        deleted_data=_clear_dir(data_dir, skipped),
        skipped=skipped,
        documents_dir=str(docs_dir),
        data_dir=str(data_dir),
    )

297 

298 

def sync_result_to_json(result: object) -> dict:
    """Wrap a SyncResult in the ``{"command": "sync", ...}`` JSON envelope.

    Raises:
        TypeError: if *result* is not a SyncResult.
    """
    from lilbee.ingest import SyncResult

    if isinstance(result, SyncResult):
        payload = result.model_dump()
        return {"command": "sync", **payload}
    raise TypeError(f"Expected SyncResult, got {type(result).__name__}")

306 

307 

def auto_sync(con: Console, *, background: bool = False) -> None:
    """Sync the documents directory before answering queries.

    With *background* True (chat/REPL), the sync is handed to a background
    thread and this returns at once. Otherwise (the default, used by
    ``lilbee ask``), the sync runs to completion. A RuntimeError is reported
    on *con* and turned into ``SystemExit(1)``, and a one-line summary is
    printed only when any file was added, updated, removed or failed.
    """
    if not background:
        from lilbee.ingest import sync

        try:
            result = asyncio.run(sync())
        except RuntimeError as exc:
            con.print(f"[{theme.ERROR}]Error:[/{theme.ERROR}] {exc}")
            raise SystemExit(1) from None
        counts = {
            "added": len(result.added),
            "updated": len(result.updated),
            "removed": len(result.removed),
            "failed": len(result.failed),
        }
        if any(counts.values()):
            summary = ", ".join(f"{n} {label}" for label, n in counts.items())
            con.print(f"[{theme.MUTED}]Synced: {summary}[/{theme.MUTED}]")
        return

    from lilbee.cli.sync import run_sync_background

    run_sync_background(con)

335 

336 

@contextmanager
def temporary_ocr_config(
    enable_ocr: bool | None = None,
    ocr_timeout: float | None = None,
) -> Generator[None, None, None]:
    """Override OCR settings on ``cfg`` for the duration of the ``with`` block.

    Arguments left as None keep their current value. Both settings are
    restored on exit, even if the block raises.
    """
    saved = (cfg.enable_ocr, cfg.ocr_timeout)
    overrides = {"enable_ocr": enable_ocr, "ocr_timeout": ocr_timeout}
    try:
        for attr, value in overrides.items():
            if value is not None:
                setattr(cfg, attr, value)
        yield
    finally:
        cfg.enable_ocr, cfg.ocr_timeout = saved