Coverage for src / lilbee / cli / chat.py: 100%

291 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-16 08:27 +0000

1"""Chat loop, slash-command dispatch, and tab completion.""" 

2 

3from __future__ import annotations 

4 

5import sys 

6from collections.abc import Callable 

7from dataclasses import dataclass 

8from pathlib import Path 

9from typing import TYPE_CHECKING 

10 

11from pydantic import ValidationError 

12from rich.console import Console 

13from rich.table import Table 

14 

15from lilbee import settings 

16from lilbee.cli.helpers import ( 

17 add_paths, 

18 get_version, 

19 perform_reset, 

20 render_status, 

21 stream_response, 

22) 

23from lilbee.config import cfg 

24 

25if TYPE_CHECKING: 

26 from lilbee.query import ChatMessage 

27 

# Literal command prefixes (including the trailing space) that the tab
# completer matches against the text before the cursor to choose a
# completion source; see make_completer().
_ADD_PREFIX = "/add "
_MODEL_PREFIX = "/model "
_VISION_PREFIX = "/vision "
_SET_PREFIX = "/set "

32 

33 

@dataclass(frozen=True)
class _SettingDef:
    """Metadata for an interactive setting."""

    cfg_attr: str  # attribute name read/written on the cfg object
    type: type  # constructor used to coerce the raw /set value (str/int/float)
    nullable: bool  # whether "/set <name> off|none|default" may clear it

42 

# Registry of settings shown by /settings and editable via /set.
# Keys are the user-facing parameter names (also used as settings.toml keys);
# values describe how each maps onto cfg and how its value is parsed.
_SETTINGS_MAP: dict[str, _SettingDef] = {
    "chat_model": _SettingDef("chat_model", str, nullable=False),
    "vision_model": _SettingDef("vision_model", str, nullable=True),
    "embedding_model": _SettingDef("embedding_model", str, nullable=False),
    "top_k": _SettingDef("top_k", int, nullable=False),
    "temperature": _SettingDef("temperature", float, nullable=True),
    "top_p": _SettingDef("top_p", float, nullable=True),
    "top_k_sampling": _SettingDef("top_k_sampling", int, nullable=True),
    "repeat_penalty": _SettingDef("repeat_penalty", float, nullable=True),
    "num_ctx": _SettingDef("num_ctx", int, nullable=True),
    "seed": _SettingDef("seed", int, nullable=True),
    "system_prompt": _SettingDef("system_prompt", str, nullable=False),
}

56 

57 

class QuitChat(Exception):
    """Raised by /quit to exit the chat loop (caught in chat_loop)."""

60 

61 

def _pick_from_catalog(
    catalog: tuple,
    display_fn: Callable,
    con: Console,
    config_attr: str,
    setting_key: str,
    label: str,
) -> None:
    """Shared interactive picker flow for model/vision catalog selection.

    *display_fn* is called with (ram_gb, free_disk_gb) and returns the recommended ModelInfo.
    *config_attr* is the cfg attribute to set (e.g. "chat_model" or "vision_model").
    *setting_key* is the settings.toml key (e.g. "chat_model" or "vision_model").
    *label* is a human-readable label for messages (e.g. "Switched to model").
    """
    from lilbee.models import (
        get_free_disk_gb,
        get_system_ram_gb,
        pull_with_progress,
        validate_disk_and_pull,
    )

    # Render the picker table and find out which entry suits this machine.
    ram_gb = get_system_ram_gb()
    free_disk_gb = get_free_disk_gb(cfg.data_dir)
    recommended = display_fn(ram_gb, free_disk_gb)
    default_idx = list(catalog).index(recommended) + 1
    already_installed = set(list_ollama_models())

    try:
        answer = input(f"Choice [{default_idx}]: ").strip()
    except (EOFError, KeyboardInterrupt):
        return
    if not answer:
        # Empty input keeps the current selection.
        return

    try:
        selection = int(answer)
    except ValueError:
        selection = 0  # sentinel: forces the range check below to fail
    if not 1 <= selection <= len(catalog):
        con.print(f"[red]Enter a number 1-{len(catalog)}.[/red]")
        return

    chosen = catalog[selection - 1]
    if chosen.name not in already_installed:
        # Chat models go through the disk-space check; others pull directly.
        if config_attr == "chat_model":
            validate_disk_and_pull(chosen, free_disk_gb)
        else:
            pull_with_progress(chosen.name)
    setattr(cfg, config_attr, chosen.name)
    settings.set_value(cfg.data_root, setting_key, chosen.name)
    con.print(f"{label} [bold]{chosen.name}[/bold] (saved)")

117 

118 

def _set_named_model(
    name: str,
    con: Console,
    config_attr: str,
    setting_key: str,
    label: str,
    *,
    exclude_vision: bool = False,
) -> None:
    """Validate and set a named model directly (no picker)."""
    from lilbee.models import ensure_tag

    name = ensure_tag(name)
    known = list_ollama_models(exclude_vision=exclude_vision)
    # Only reject when we could actually enumerate installed models; an
    # empty list means Ollama was unreachable, so the name is accepted as-is.
    if known and name not in known:
        con.print(f"[red]Unknown model:[/red] {name}")
        con.print(f"Available: {', '.join(sorted(known))}")
        return
    setattr(cfg, config_attr, name)
    settings.set_value(cfg.data_root, setting_key, name)
    con.print(f"{label} [bold]{name}[/bold] (saved)")

140 

141 

def handle_slash_status(args: str, con: Console) -> None:
    """Handle /status: render the status view (documents and config) via render_status."""
    render_status(con)

144 

145 

146def handle_slash_add(args: str, con: Console) -> None: 

147 raw = args.strip() 

148 if raw: 

149 p = Path(raw).expanduser() 

150 if not p.exists(): 

151 con.print(f"[red]Path not found:[/red] {raw}") 

152 return 

153 add_paths([p], con, force=True) 

154 else: 

155 try: 

156 from prompt_toolkit import prompt as pt_prompt 

157 from prompt_toolkit.completion import PathCompleter 

158 

159 raw = pt_prompt("path: ", completer=PathCompleter(expanduser=True)) 

160 except (ImportError, EOFError, KeyboardInterrupt): 

161 return 

162 raw = raw.strip() 

163 if not raw: 

164 return 

165 p = Path(raw).expanduser() 

166 if not p.exists(): 

167 con.print(f"[red]Path not found:[/red] {raw}") 

168 return 

169 add_paths([p], con, force=True) 

170 

171 

def handle_slash_quit(args: str, con: Console) -> None:
    """Handle /quit by raising QuitChat, which chat_loop catches to exit."""
    raise QuitChat

174 

175 

def handle_slash_model(args: str, con: Console) -> None:
    """Handle /model: switch directly to a named chat model, or show a picker."""
    requested = args.strip()
    if requested:
        # Direct switch by name; vision models are excluded from chat use.
        _set_named_model(
            requested, con, "chat_model", "chat_model", "Switched to model", exclude_vision=True
        )
    else:
        con.print(f"[bold]Current model:[/bold] {cfg.chat_model}\n")
        from lilbee.models import MODEL_CATALOG, display_model_picker

        _pick_from_catalog(
            MODEL_CATALOG,
            display_model_picker,
            con,
            config_attr="chat_model",
            setting_key="chat_model",
            label="Switched to model",
        )

195 

196 

def handle_slash_vision(args: str, con: Console) -> None:
    """Handle /vision: disable OCR ("off"), set a named model, or show a picker."""
    requested = args.strip()

    if requested == "off":
        # Explicitly disable vision OCR and persist the choice.
        cfg.vision_model = ""
        settings.set_value(cfg.data_root, "vision_model", "")
        con.print("Vision OCR [bold]disabled[/bold] (saved)")
        return

    if requested:
        _set_named_model(requested, con, "vision_model", "vision_model", "Vision model set to")
        return

    # bare /vision — show status then picker
    con.print(f"[bold]Vision OCR:[/bold] {cfg.vision_model or 'disabled'}\n")

    from lilbee.models import VISION_CATALOG, display_vision_picker

    _pick_from_catalog(
        VISION_CATALOG,
        display_vision_picker,
        con,
        config_attr="vision_model",
        setting_key="vision_model",
        label="Vision model set to",
    )

226 

227 

def handle_slash_version(args: str, con: Console) -> None:
    """Handle /version: print the installed lilbee version."""
    version = get_version()
    con.print(f"lilbee [bold]{version}[/bold]")

230 

231 

def handle_slash_reset(args: str, con: Console) -> None:
    """Handle /reset: delete all documents and data after explicit confirmation."""
    con.print(
        "[bold red]This will delete ALL documents and data.[/bold red]\nType 'yes' to confirm:"
    )
    try:
        reply = con.input("[bold red]> [/bold red]")
    except (EOFError, KeyboardInterrupt):
        reply = ""  # treated the same as any non-"yes" answer below
    if reply.strip().lower() != "yes":
        con.print("Aborted.")
        return
    outcome = perform_reset()
    con.print(
        f"Reset complete: {outcome.deleted_docs} document(s), "
        f"{outcome.deleted_data} data item(s) deleted."
    )

249 

250 

251def _get_model_defaults() -> dict[str, str]: 

252 """Fetch generation parameter defaults from Ollama for the current chat model.""" 

253 _OLLAMA_TO_SETTING = {"top_k": "top_k_sampling"} 

254 try: 

255 import ollama 

256 

257 resp = ollama.show(cfg.chat_model) 

258 defaults: dict[str, str] = {} 

259 for line in (resp.parameters or "").splitlines(): 

260 parts = line.split() 

261 if len(parts) >= 2: 

262 key = _OLLAMA_TO_SETTING.get(parts[0], parts[0]) 

263 if key in _SETTINGS_MAP: 

264 defaults[key] = parts[1] 

265 return defaults 

266 except (ollama.ResponseError, ConnectionError, OSError): 

267 return {} 

268 

269 

270def _format_setting_value(value: object, model_default: str | None = None) -> str: 

271 """Format a setting value for display.""" 

272 if value is None or value == "": 

273 if model_default is not None: 

274 return f"[dim](model default: {model_default})[/dim]" 

275 return "[dim](not set)[/dim]" 

276 if isinstance(value, str) and len(value) > 60: 

277 return f"[dim]({len(value)} chars)[/dim]" 

278 return str(value) 

279 

280 

def handle_slash_settings(args: str, con: Console) -> None:
    """Handle /settings: render every interactive setting and its current value."""
    model_defaults = _get_model_defaults()
    grid = Table(show_header=False, box=None, padding=(0, 2))
    for key, spec in _SETTINGS_MAP.items():
        current = getattr(cfg, spec.cfg_attr)
        rendered = _format_setting_value(current, model_defaults.get(key))
        grid.add_row(f"[bold]{key}[/bold]", rendered)
    con.print(grid)

291 

292 

def handle_slash_set(args: str, con: Console) -> None:
    """Handle /set: show or change one generation setting, persisting changes.

    Forms: "/set <param>" shows the current value; "/set <param> <value>"
    parses and stores it; "/set <param> off|none|default" clears a nullable one.
    """
    tokens = args.strip().split(None, 1)
    if not tokens:
        con.print("Usage: /set <param> [value] — try /settings to see all params")
        return

    name = tokens[0]
    spec = _SETTINGS_MAP.get(name)
    if spec is None:
        con.print(f"[red]Unknown setting:[/red] {name}")
        con.print(f"Available: {', '.join(_SETTINGS_MAP)}")
        return

    if len(tokens) == 1:
        # No value supplied: just display the current one.
        con.print(f"{name} = {_format_setting_value(getattr(cfg, spec.cfg_attr))}")
        return

    raw = tokens[1]

    # The clear keywords reset a nullable setting back to "unset".
    if raw.lower() in ("off", "none", "default"):
        if not spec.nullable:
            con.print(f"[red]{name} cannot be cleared[/red]")
            return
        setattr(cfg, spec.cfg_attr, "" if spec.type is str else None)
        settings.delete_value(cfg.data_root, name)
        con.print(f"{name} cleared (saved)")
        return

    try:
        coerced = spec.type(raw)
    except (ValueError, TypeError):
        con.print(f"[red]Invalid {spec.type.__name__}:[/red] {raw}")
        return

    # cfg performs pydantic validation on assignment; surface its message.
    try:
        setattr(cfg, spec.cfg_attr, coerced)
    except ValidationError as err:
        con.print(f"[red]{name}: {err.errors()[0]['msg']}[/red]")
        return

    settings.set_value(cfg.data_root, name, str(coerced))
    con.print(f"{name} = {coerced} (saved)")

337 

338 

def handle_slash_help(args: str, con: Console) -> None:
    """Handle /help: print the slash-command reference, one line per command."""
    lines = (
        "[bold]Slash commands:[/bold]",
        " /status — show indexed documents and config",
        " /add [path] — add a file or directory (tab-completes without args)",
        " /model [name] — show or switch chat model",
        " /vision [name|off] — show or switch vision OCR model",
        " /settings — show all generation settings",
        " /set <param> [value] — change a setting (tab-completes param names)",
        " /version — show lilbee version",
        " /reset — delete all documents and data",
        " /help — show this help",
        " /quit — exit chat",
    )
    for line in lines:
        con.print(line)

351 

352 

# Dispatch table mapping a command word (without the leading "/") to its
# handler; each handler takes (args, console). Used by dispatch_slash and
# by the tab completer for "/..." completion.
_SLASH_COMMANDS: dict[str, Callable[[str, Console], None]] = {
    "status": handle_slash_status,
    "add": handle_slash_add,
    "model": handle_slash_model,
    "vision": handle_slash_vision,
    "settings": handle_slash_settings,
    "set": handle_slash_set,
    "version": handle_slash_version,
    "reset": handle_slash_reset,
    "help": handle_slash_help,
    "quit": handle_slash_quit,
}

365 

366 

def dispatch_slash(raw_input: str, con: Console) -> bool:
    """Try to dispatch *raw_input* as a ``/command``. Returns True if handled."""
    text = raw_input.strip()
    if not text.startswith("/"):
        return False
    # Split the command word from the rest; split(None, 1) tolerates any
    # whitespace between them.
    pieces = text[1:].split(None, 1)
    cmd = pieces[0].lower() if pieces else ""
    args = pieces[1] if len(pieces) > 1 else ""
    if handler := _SLASH_COMMANDS.get(cmd):
        handler(args, con)
    else:
        con.print(f"[red]Unknown command:[/red] /{cmd} — try /help")
    return True

381 

382 

def list_ollama_models(*, exclude_vision: bool = False) -> list[str]:
    """Return installed Ollama model names with explicit tags, excluding embedding models.

    When *exclude_vision* is True, also filters out known vision catalog models.
    Returns an empty list when the Ollama client library is missing or the
    server is unreachable (callers such as the tab completer rely on this
    never raising).
    """
    # Import separately: previously a missing ollama package raised an
    # uncaught ImportError, contradicting the graceful-empty contract.
    try:
        import ollama
    except ImportError:
        return []
    try:
        # Drop the embedding model (compared by base name, tag ignored).
        embed_base = cfg.embedding_model.split(":")[0]
        models = [
            m.model for m in ollama.list().models if m.model and m.model.split(":")[0] != embed_base
        ]
        if exclude_vision:
            from lilbee.models import VISION_CATALOG

            vision_names = {m.name for m in VISION_CATALOG}
            models = [m for m in models if m not in vision_names]
        return models
    except (ConnectionError, OSError):
        return []

403 

404 

def make_completer():  # type: ignore[no-untyped-def]
    """Build a completer class that inherits from prompt_toolkit.completion.Completer.

    The returned completer inspects the text before the cursor and picks one
    completion source: filesystem paths after "/add ", installed chat models
    after "/model ", setting names after "/set ", vision models (plus "off")
    after "/vision ", and command names after a bare "/".
    """
    from prompt_toolkit.completion import Completer, Completion, PathCompleter
    from prompt_toolkit.document import Document

    class LilbeeCompleter(Completer):
        def get_completions(self, document, complete_event):  # type: ignore[no-untyped-def,override]
            text = document.text_before_cursor
            if text.startswith(_ADD_PREFIX):
                # Delegate to PathCompleter over a sub-document containing
                # only the text after "/add ".
                sub_text = text[len(_ADD_PREFIX) :]
                sub_doc = Document(sub_text, len(sub_text))
                yield from PathCompleter(expanduser=True).get_completions(sub_doc, complete_event)
            elif text.startswith(_MODEL_PREFIX):
                # Complete installed chat models (vision models excluded).
                prefix = text[len(_MODEL_PREFIX) :]
                for name in list_ollama_models(exclude_vision=True):
                    if name.startswith(prefix):
                        yield Completion(name, start_position=-len(prefix))
            elif text.startswith(_SET_PREFIX):
                # Complete setting names known to /set.
                prefix = text[len(_SET_PREFIX) :]
                for name in _SETTINGS_MAP:
                    if name.startswith(prefix):
                        yield Completion(name, start_position=-len(prefix))
            elif text.startswith(_VISION_PREFIX):
                from lilbee.models import VISION_CATALOG

                # Offer "off" plus the vision catalog entries.
                prefix = text[len(_VISION_PREFIX) :]
                if "off".startswith(prefix):
                    yield Completion("off", start_position=-len(prefix))
                for model in VISION_CATALOG:
                    if model.name.startswith(prefix):
                        yield Completion(model.name, start_position=-len(prefix))
            elif text.startswith("/"):
                # Bare "/..." — complete command names; the completion text
                # includes the leading slash, so it replaces the whole input
                # (hence -len(text), not -len(prefix)).
                prefix = text[1:]
                for cmd in _SLASH_COMMANDS:
                    if cmd.startswith(prefix):
                        yield Completion(f"/{cmd}", start_position=-len(text))

    return LilbeeCompleter()

443 

444 

def chat_loop(con: Console) -> None:
    """Interactive REPL with slash-command support."""
    con.print("[bold]lilbee chat[/bold] — type /help for commands\n")
    history: list[ChatMessage] = []

    # Prefer a prompt_toolkit session (tab completion) when attached to a TTY
    # and the library is importable; otherwise fall back to con.input.
    fancy_prompt: Callable[[], str] | None = None
    if sys.stdin.isatty():
        try:
            from prompt_toolkit import PromptSession

            session: PromptSession[str] = PromptSession(completer=make_completer())

            def fancy_prompt() -> str:
                return session.prompt("> ")

        except ImportError:
            pass

    while True:
        try:
            if fancy_prompt is None:
                question = con.input("[bold green]> [/bold green]")
            else:
                question = fancy_prompt()
        except (EOFError, KeyboardInterrupt):
            break
        if not question.strip():
            continue
        try:
            handled = dispatch_slash(question, con)
        except QuitChat:
            break
        if handled:
            continue
        stream_response(question, history, con)