Coverage for src/lilbee/cli/chat.py: 100%
291 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-16 08:27 +0000
1"""Chat loop, slash-command dispatch, and tab completion."""
3from __future__ import annotations
5import sys
6from collections.abc import Callable
7from dataclasses import dataclass
8from pathlib import Path
9from typing import TYPE_CHECKING
11from pydantic import ValidationError
12from rich.console import Console
13from rich.table import Table
15from lilbee import settings
16from lilbee.cli.helpers import (
17 add_paths,
18 get_version,
19 perform_reset,
20 render_status,
21 stream_response,
22)
23from lilbee.config import cfg
25if TYPE_CHECKING:
26 from lilbee.query import ChatMessage
# Command prefixes the tab completer matches to decide which argument
# completions to offer (path, model name, vision model, or setting name).
_ADD_PREFIX = "/add "
_MODEL_PREFIX = "/model "
_VISION_PREFIX = "/vision "
_SET_PREFIX = "/set "
@dataclass(frozen=True)
class _SettingDef:
    """Metadata for an interactive setting."""

    # Name of the attribute on ``cfg`` holding this setting's current value.
    cfg_attr: str
    # Type used to parse the raw string given to ``/set <param> <value>``.
    type: type
    # Whether the setting may be cleared with ``/set <param> off|none|default``.
    nullable: bool
# Settings exposed via /set and /settings.  Keys are the user-facing parameter
# names; values describe how each one is parsed and stored on ``cfg``.
_SETTINGS_MAP: dict[str, _SettingDef] = {
    "chat_model": _SettingDef("chat_model", str, nullable=False),
    "vision_model": _SettingDef("vision_model", str, nullable=True),
    "embedding_model": _SettingDef("embedding_model", str, nullable=False),
    "top_k": _SettingDef("top_k", int, nullable=False),
    "temperature": _SettingDef("temperature", float, nullable=True),
    "top_p": _SettingDef("top_p", float, nullable=True),
    "top_k_sampling": _SettingDef("top_k_sampling", int, nullable=True),
    "repeat_penalty": _SettingDef("repeat_penalty", float, nullable=True),
    "num_ctx": _SettingDef("num_ctx", int, nullable=True),
    "seed": _SettingDef("seed", int, nullable=True),
    "system_prompt": _SettingDef("system_prompt", str, nullable=False),
}
class QuitChat(Exception):
    """Raised by /quit to exit the chat loop.

    Caught by ``chat_loop``, which breaks out of the REPL cleanly.
    """
def _pick_from_catalog(
    catalog: tuple,
    display_fn: Callable,
    con: Console,
    config_attr: str,
    setting_key: str,
    label: str,
) -> None:
    """Run the interactive numbered picker over a model *catalog*.

    *display_fn* receives (ram_gb, free_disk_gb), renders the catalog, and
    returns the recommended entry.  *config_attr* and *setting_key* name the
    ``cfg`` attribute and settings.toml key the choice is persisted under;
    *label* prefixes the confirmation message.
    """
    from lilbee.models import (
        get_free_disk_gb,
        get_system_ram_gb,
        pull_with_progress,
        validate_disk_and_pull,
    )

    ram_gb = get_system_ram_gb()
    free_disk_gb = get_free_disk_gb(cfg.data_dir)
    recommended = display_fn(ram_gb, free_disk_gb)
    default_idx = list(catalog).index(recommended) + 1
    installed = set(list_ollama_models())

    try:
        answer = input(f"Choice [{default_idx}]: ").strip()
    except (EOFError, KeyboardInterrupt):
        return
    if not answer:
        return

    try:
        selection = int(answer)
    except ValueError:
        selection = 0  # force the range check below to reject it
    if not (1 <= selection <= len(catalog)):
        con.print(f"[red]Enter a number 1-{len(catalog)}.[/red]")
        return

    chosen = catalog[selection - 1]
    if chosen.name not in installed:
        # Chat models get a disk-space check before pulling; other catalogs
        # (vision) are pulled directly.
        if config_attr == "chat_model":
            validate_disk_and_pull(chosen, free_disk_gb)
        else:
            pull_with_progress(chosen.name)
    setattr(cfg, config_attr, chosen.name)
    settings.set_value(cfg.data_root, setting_key, chosen.name)
    con.print(f"{label} [bold]{chosen.name}[/bold] (saved)")
def _set_named_model(
    name: str,
    con: Console,
    config_attr: str,
    setting_key: str,
    label: str,
    *,
    exclude_vision: bool = False,
) -> None:
    """Validate and set a named model directly (no picker)."""
    from lilbee.models import ensure_tag

    name = ensure_tag(name)
    known = list_ollama_models(exclude_vision=exclude_vision)
    # An empty ``known`` means Ollama was unreachable — skip validation then.
    if not known or name in known:
        setattr(cfg, config_attr, name)
        settings.set_value(cfg.data_root, setting_key, name)
        con.print(f"{label} [bold]{name}[/bold] (saved)")
        return
    con.print(f"[red]Unknown model:[/red] {name}")
    con.print(f"Available: {', '.join(sorted(known))}")
def handle_slash_status(args: str, con: Console) -> None:
    """/status — render the indexed-documents and configuration summary."""
    render_status(con)
def handle_slash_add(args: str, con: Console) -> None:
    """/add [path] — index a file or directory.

    With no argument, prompts interactively with filesystem tab completion;
    the prompt aborts silently if prompt_toolkit is unavailable or the input
    is cancelled.  The path is validated once, regardless of how it arrived
    (the original duplicated the validate-and-add logic in both branches).
    """
    raw = args.strip()
    if not raw:
        try:
            from prompt_toolkit import prompt as pt_prompt
            from prompt_toolkit.completion import PathCompleter

            raw = pt_prompt("path: ", completer=PathCompleter(expanduser=True))
        except (ImportError, EOFError, KeyboardInterrupt):
            return
        raw = raw.strip()
        if not raw:
            return
    # Shared validation for both the inline argument and the prompted path.
    p = Path(raw).expanduser()
    if not p.exists():
        con.print(f"[red]Path not found:[/red] {raw}")
        return
    add_paths([p], con, force=True)
def handle_slash_quit(args: str, con: Console) -> None:
    """/quit — raise :class:`QuitChat` to terminate the chat loop."""
    raise QuitChat
def handle_slash_model(args: str, con: Console) -> None:
    """/model [name] — switch the chat model, or show the picker when bare."""
    requested = args.strip()
    if requested:
        _set_named_model(
            requested,
            con,
            "chat_model",
            "chat_model",
            "Switched to model",
            exclude_vision=True,
        )
        return

    # No name given: show the current model, then the catalog picker.
    con.print(f"[bold]Current model:[/bold] {cfg.chat_model}\n")
    from lilbee.models import MODEL_CATALOG, display_model_picker

    _pick_from_catalog(
        MODEL_CATALOG,
        display_model_picker,
        con,
        config_attr="chat_model",
        setting_key="chat_model",
        label="Switched to model",
    )
def handle_slash_vision(args: str, con: Console) -> None:
    """/vision [name|off] — manage the vision OCR model."""
    requested = args.strip()

    if requested == "off":
        # Explicitly disable vision OCR and persist the empty value.
        cfg.vision_model = ""
        settings.set_value(cfg.data_root, "vision_model", "")
        con.print("Vision OCR [bold]disabled[/bold] (saved)")
    elif requested:
        _set_named_model(requested, con, "vision_model", "vision_model", "Vision model set to")
    else:
        # Bare /vision: report current state, then offer the catalog picker.
        status = cfg.vision_model if cfg.vision_model else "disabled"
        con.print(f"[bold]Vision OCR:[/bold] {status}\n")
        from lilbee.models import VISION_CATALOG, display_vision_picker

        _pick_from_catalog(
            VISION_CATALOG,
            display_vision_picker,
            con,
            config_attr="vision_model",
            setting_key="vision_model",
            label="Vision model set to",
        )
def handle_slash_version(args: str, con: Console) -> None:
    """/version — print the installed lilbee version."""
    con.print(f"lilbee [bold]{get_version()}[/bold]")
def handle_slash_reset(args: str, con: Console) -> None:
    """/reset — delete all documents and data after a 'yes' confirmation."""
    con.print(
        "[bold red]This will delete ALL documents and data.[/bold red]\nType 'yes' to confirm:"
    )
    confirmed = False
    try:
        reply = con.input("[bold red]> [/bold red]")
        confirmed = reply.strip().lower() == "yes"
    except (EOFError, KeyboardInterrupt):
        pass  # treat cancelled input as a refusal
    if not confirmed:
        con.print("Aborted.")
        return
    outcome = perform_reset()
    con.print(
        f"Reset complete: {outcome.deleted_docs} document(s), "
        f"{outcome.deleted_data} data item(s) deleted."
    )
251def _get_model_defaults() -> dict[str, str]:
252 """Fetch generation parameter defaults from Ollama for the current chat model."""
253 _OLLAMA_TO_SETTING = {"top_k": "top_k_sampling"}
254 try:
255 import ollama
257 resp = ollama.show(cfg.chat_model)
258 defaults: dict[str, str] = {}
259 for line in (resp.parameters or "").splitlines():
260 parts = line.split()
261 if len(parts) >= 2:
262 key = _OLLAMA_TO_SETTING.get(parts[0], parts[0])
263 if key in _SETTINGS_MAP:
264 defaults[key] = parts[1]
265 return defaults
266 except (ollama.ResponseError, ConnectionError, OSError):
267 return {}
270def _format_setting_value(value: object, model_default: str | None = None) -> str:
271 """Format a setting value for display."""
272 if value is None or value == "":
273 if model_default is not None:
274 return f"[dim](model default: {model_default})[/dim]"
275 return "[dim](not set)[/dim]"
276 if isinstance(value, str) and len(value) > 60:
277 return f"[dim]({len(value)} chars)[/dim]"
278 return str(value)
def handle_slash_settings(args: str, con: Console) -> None:
    """/settings — print a table of every tunable setting and its value."""
    model_defaults = _get_model_defaults()
    table = Table(show_header=False, box=None, padding=(0, 2))
    for key, spec in _SETTINGS_MAP.items():
        current = getattr(cfg, spec.cfg_attr)
        rendered = _format_setting_value(current, model_defaults.get(key))
        table.add_row(f"[bold]{key}[/bold]", rendered)
    con.print(table)
def handle_slash_set(args: str, con: Console) -> None:
    """/set <param> [value] — inspect, change, or clear a generation setting."""
    tokens = args.strip().split(None, 1)
    if not tokens:
        con.print("Usage: /set <param> [value] — try /settings to see all params")
        return

    key = tokens[0]
    spec = _SETTINGS_MAP.get(key)
    if spec is None:
        con.print(f"[red]Unknown setting:[/red] {key}")
        con.print(f"Available: {', '.join(_SETTINGS_MAP)}")
        return

    # Bare "/set <param>" just reports the current value.
    if len(tokens) == 1:
        con.print(f"{key} = {_format_setting_value(getattr(cfg, spec.cfg_attr))}")
        return

    raw = tokens[1]

    # "off"/"none"/"default" clear a nullable setting back to its default.
    if raw.lower() in ("off", "none", "default"):
        if not spec.nullable:
            con.print(f"[red]{key} cannot be cleared[/red]")
            return
        setattr(cfg, spec.cfg_attr, "" if spec.type is str else None)
        settings.delete_value(cfg.data_root, key)
        con.print(f"{key} cleared (saved)")
        return

    try:
        coerced = spec.type(raw)
    except (ValueError, TypeError):
        con.print(f"[red]Invalid {spec.type.__name__}:[/red] {raw}")
        return

    try:
        # cfg validates on assignment (pydantic); surface the first message.
        setattr(cfg, spec.cfg_attr, coerced)
    except ValidationError as exc:
        con.print(f"[red]{key}: {exc.errors()[0]['msg']}[/red]")
        return

    settings.set_value(cfg.data_root, key, str(coerced))
    con.print(f"{key} = {coerced} (saved)")
def handle_slash_help(args: str, con: Console) -> None:
    """/help — list every slash command with a one-line description."""
    help_lines = (
        "[bold]Slash commands:[/bold]",
        " /status — show indexed documents and config",
        " /add [path] — add a file or directory (tab-completes without args)",
        " /model [name] — show or switch chat model",
        " /vision [name|off] — show or switch vision OCR model",
        " /settings — show all generation settings",
        " /set <param> [value] — change a setting (tab-completes param names)",
        " /version — show lilbee version",
        " /reset — delete all documents and data",
        " /help — show this help",
        " /quit — exit chat",
    )
    for line in help_lines:
        con.print(line)
# Dispatch table mapping the command word (without the leading "/") to its
# handler.  Every handler takes (args, console); /quit signals via QuitChat.
_SLASH_COMMANDS: dict[str, Callable[[str, Console], None]] = {
    "status": handle_slash_status,
    "add": handle_slash_add,
    "model": handle_slash_model,
    "vision": handle_slash_vision,
    "settings": handle_slash_settings,
    "set": handle_slash_set,
    "version": handle_slash_version,
    "reset": handle_slash_reset,
    "help": handle_slash_help,
    "quit": handle_slash_quit,
}
def dispatch_slash(raw_input: str, con: Console) -> bool:
    """Dispatch *raw_input* as a ``/command`` if it is one.

    Returns True when the input was consumed as a slash command (including an
    unknown one), False when it should be treated as a normal chat message.
    """
    text = raw_input.strip()
    if not text.startswith("/"):
        return False
    pieces = text[1:].split(None, 1)
    command = pieces[0].lower() if pieces else ""
    argument = pieces[1] if len(pieces) > 1 else ""
    handler = _SLASH_COMMANDS.get(command)
    if handler is not None:
        handler(argument, con)
    else:
        con.print(f"[red]Unknown command:[/red] /{command} — try /help")
    return True
def list_ollama_models(*, exclude_vision: bool = False) -> list[str]:
    """Return installed Ollama model names with explicit tags, excluding embedding models.

    When *exclude_vision* is True, also filters out known vision catalog models.
    Returns an empty list when the ollama package is missing or the server is
    unreachable or returns an error (degrading gracefully like the rest of
    the chat UI, and matching _get_model_defaults's error handling).
    """
    try:
        import ollama
    except ImportError:
        return []
    try:
        # ResponseError was previously uncaught here, crashing the picker
        # and tab completer on a server-side error.
        listing = ollama.list().models
    except (ollama.ResponseError, ConnectionError, OSError):
        return []
    embed_base = cfg.embedding_model.split(":")[0]
    models = [m.model for m in listing if m.model and m.model.split(":")[0] != embed_base]
    if exclude_vision:
        from lilbee.models import VISION_CATALOG

        vision_names = {m.name for m in VISION_CATALOG}
        models = [m for m in models if m not in vision_names]
    return models
def make_completer():  # type: ignore[no-untyped-def]
    """Build a completer class that inherits from prompt_toolkit.completion.Completer."""
    from prompt_toolkit.completion import Completer, Completion, PathCompleter
    from prompt_toolkit.document import Document

    class LilbeeCompleter(Completer):
        def get_completions(self, document, complete_event):  # type: ignore[no-untyped-def,override]
            text = document.text_before_cursor

            if text.startswith(_ADD_PREFIX):
                # Delegate to the filesystem completer on the argument part.
                arg = text[len(_ADD_PREFIX):]
                nested = Document(arg, len(arg))
                yield from PathCompleter(expanduser=True).get_completions(nested, complete_event)
                return

            if text.startswith(_MODEL_PREFIX):
                arg = text[len(_MODEL_PREFIX):]
                yield from (
                    Completion(name, start_position=-len(arg))
                    for name in list_ollama_models(exclude_vision=True)
                    if name.startswith(arg)
                )
                return

            if text.startswith(_SET_PREFIX):
                arg = text[len(_SET_PREFIX):]
                yield from (
                    Completion(param, start_position=-len(arg))
                    for param in _SETTINGS_MAP
                    if param.startswith(arg)
                )
                return

            if text.startswith(_VISION_PREFIX):
                from lilbee.models import VISION_CATALOG

                arg = text[len(_VISION_PREFIX):]
                if "off".startswith(arg):
                    yield Completion("off", start_position=-len(arg))
                for entry in VISION_CATALOG:
                    if entry.name.startswith(arg):
                        yield Completion(entry.name, start_position=-len(arg))
                return

            if text.startswith("/"):
                # Completing the command itself: replace the whole "/..." text
                # since the completion string includes the slash.
                partial = text[1:]
                for cmd in _SLASH_COMMANDS:
                    if cmd.startswith(partial):
                        yield Completion(f"/{cmd}", start_position=-len(text))

    return LilbeeCompleter()
def chat_loop(con: Console) -> None:
    """Interactive REPL: read input, dispatch slash commands, stream answers."""
    con.print("[bold]lilbee chat[/bold] — type /help for commands\n")
    history: list[ChatMessage] = []

    # Prefer prompt_toolkit (tab completion) on a TTY; fall back to con.input.
    prompt_fn: Callable[[], str] | None = None
    if sys.stdin.isatty():
        try:
            from prompt_toolkit import PromptSession

            session: PromptSession[str] = PromptSession(completer=make_completer())
            prompt_fn = lambda: session.prompt("> ")  # noqa: E731
        except ImportError:
            pass

    while True:
        try:
            if prompt_fn is not None:
                question = prompt_fn()
            else:
                question = con.input("[bold green]> [/bold green]")
        except (EOFError, KeyboardInterrupt):
            break
        if not question.strip():
            continue
        try:
            handled = dispatch_slash(question, con)
        except QuitChat:
            break
        if handled:
            continue
        stream_response(question, history, con)