Coverage for src / lilbee / models.py: 100%
171 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-16 08:27 +0000
1"""RAM detection, model selection, interactive picker, and auto-install for chat models."""
3import logging
4import os
5import shutil
6import sys
7from dataclasses import dataclass
8from pathlib import Path
10import ollama
11from rich.console import Console
12from rich.progress import BarColumn, DownloadColumn, Progress, SpinnerColumn, TextColumn
13from rich.table import Table
15from lilbee import settings
16from lilbee.config import cfg
18log = logging.getLogger(__name__)
20# Extra headroom required beyond model size (GB)
21_DISK_HEADROOM_GB = 2
23OLLAMA_MODELS_URL = "https://ollama.com/library"
def ensure_tag(name: str) -> str:
    """Ensure a model name has an explicit tag (e.g. ``llama3`` → ``llama3:latest``)."""
    # Empty names and names that already carry a tag pass through untouched.
    if name and ":" not in name:
        return f"{name}:latest"
    return name
@dataclass(frozen=True)
class ModelInfo:
    """A curated chat model with metadata for the picker UI.

    Frozen so instances are immutable and safely comparable as catalog entries.
    """

    name: str  # Ollama model identifier, e.g. "qwen3:8b"
    size_gb: float  # approximate download size in GB (also used for disk checks)
    min_ram_gb: float  # minimum system RAM (GB) needed to run this model
    description: str  # one-line blurb shown in the picker table
# Curated chat models, ordered smallest to largest. pick_default_model()
# relies on this ascending order to return the largest model that fits in RAM.
MODEL_CATALOG: tuple[ModelInfo, ...] = (
    ModelInfo("qwen3:1.7b", 1.1, 4, "Tiny — fast on any machine"),
    ModelInfo("qwen3:4b", 2.5, 8, "Small — good balance for 8 GB RAM"),
    ModelInfo("mistral:7b", 4.4, 8, "Small — Mistral's fast 7B, 32K context"),
    ModelInfo("qwen3:8b", 5.0, 8, "Medium — strong general-purpose"),
    ModelInfo("qwen3-coder:30b", 18.0, 32, "Extra large — best quality, needs 32 GB RAM"),
)


# Curated vision/OCR models for scanned-PDF extraction. The first entry is the
# recommended default (see pick_default_vision_model()).
VISION_CATALOG: tuple[ModelInfo, ...] = (
    ModelInfo(
        "maternion/LightOnOCR-2:latest", 1.5, 4, "Best quality/speed — clean markdown OCR output"
    ),
    ModelInfo("deepseek-ocr:latest", 6.7, 8, "Excellent accuracy — plain text, no markdown"),
    ModelInfo("minicpm-v:latest", 5.5, 8, "Good — some transcription errors, slower"),
    ModelInfo("glm-ocr:latest", 2.2, 4, "Good accuracy — surprisingly slow despite small size"),
)
def get_system_ram_gb() -> float:
    """Return total system RAM in GB. Falls back to 8.0 if detection fails.

    On Windows this calls ``GlobalMemoryStatusEx`` via ctypes; elsewhere it
    uses ``os.sysconf`` page counts.
    """
    try:
        if sys.platform == "win32":
            import ctypes

            class _MEMORYSTATUSEX(ctypes.Structure):
                # Mirrors the Win32 MEMORYSTATUSEX struct layout.
                _fields_ = [
                    ("dwLength", ctypes.c_ulong),
                    ("dwMemoryLoad", ctypes.c_ulong),
                    ("ullTotalPhys", ctypes.c_ulonglong),
                    ("ullAvailPhys", ctypes.c_ulonglong),
                    ("ullTotalPageFile", ctypes.c_ulonglong),
                    ("ullAvailPageFile", ctypes.c_ulonglong),
                    ("ullTotalVirtual", ctypes.c_ulonglong),
                    ("ullAvailVirtual", ctypes.c_ulonglong),
                    ("ullAvailExtendedVirtual", ctypes.c_ulonglong),
                ]

            stat = _MEMORYSTATUSEX()
            stat.dwLength = ctypes.sizeof(stat)
            # BUGFIX: GlobalMemoryStatusEx returns 0 (FALSE) on failure. The
            # previous code ignored the result and could report garbage from
            # an unfilled struct; raise into the fallback handler instead.
            ok = ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(stat))  # type: ignore[attr-defined]
            if not ok:
                raise OSError("GlobalMemoryStatusEx failed")
            return stat.ullTotalPhys / (1024**3)
        else:
            pages = os.sysconf("SC_PHYS_PAGES")
            page_size = os.sysconf("SC_PAGE_SIZE")
            return (pages * page_size) / (1024**3)
    except (OSError, AttributeError, ValueError):
        log.debug("RAM detection failed, falling back to 8.0 GB")
        return 8.0
def get_free_disk_gb(path: Path) -> float:
    """Return free disk space in GB for the filesystem containing *path*."""
    # Walk up to the nearest existing ancestor; the filesystem root always
    # exists, so this loop terminates.
    probe = path
    while not probe.exists():
        probe = probe.parent
    return shutil.disk_usage(probe).free / (1024**3)
def pick_default_model(ram_gb: float) -> ModelInfo:
    """Choose the largest catalog model that fits in *ram_gb*."""
    # MODEL_CATALOG is ordered smallest-to-largest, so the last fitting entry
    # is the biggest model this machine can run. Fall back to the smallest.
    fitting = [m for m in MODEL_CATALOG if m.min_ram_gb <= ram_gb]
    return fitting[-1] if fitting else MODEL_CATALOG[0]
def _model_download_size_gb(model: str) -> float:
    """Estimated download size in GB for *model*.

    Unknown models fall back to the qwen3:8b catalog entry's size. The
    previous bare ``next(...)`` raised StopIteration if that entry were ever
    removed from the catalog; a literal default guards against that.
    """
    catalog_sizes = {m.name: m.size_gb for m in MODEL_CATALOG}
    fallback = next((m.size_gb for m in MODEL_CATALOG if m.name == "qwen3:8b"), 5.0)
    return catalog_sizes.get(model, fallback)
def display_model_picker(ram_gb: float, free_disk_gb: float) -> ModelInfo:
    """Show a Rich table of catalog models on stderr and return the recommended model."""
    console = Console(stderr=True)
    recommended = pick_default_model(ram_gb)

    table = Table(title="Available Models", show_lines=False)
    table.add_column("#", justify="right", style="bold")
    table.add_column("Model", style="cyan")
    table.add_column("Size", justify="right")
    table.add_column("Description")

    for position, entry in enumerate(MODEL_CATALOG, 1):
        number = str(position)
        label = entry.name
        size_text = f"{entry.size_gb:.1f} GB"
        blurb = entry.description

        if entry == recommended:
            # Star + bold highlight the pick suited to this machine.
            label = f"[bold]{label} ★[/bold]"
            blurb = f"[bold]{blurb}[/bold]"
            number = f"[bold]{number}[/bold]"

        if free_disk_gb < entry.size_gb + _DISK_HEADROOM_GB:
            # Not enough free disk (incl. headroom): warn by painting the size red.
            size_text = f"[red]{entry.size_gb:.1f} GB[/red]"

        table.add_row(number, label, size_text, blurb)

    console.print()
    console.print("[bold]No chat model found.[/bold] Pick one to download:\n")
    console.print(table)
    console.print(f"\n System: {ram_gb:.0f} GB RAM, {free_disk_gb:.1f} GB free disk")
    console.print(" \u2605 = recommended for your system")
    console.print(f" Browse more models at {OLLAMA_MODELS_URL}\n")

    return recommended
def pick_default_vision_model() -> ModelInfo:
    """Return the recommended vision model (first catalog entry, best quality)."""
    first, *_ = VISION_CATALOG
    return first
def display_vision_picker(ram_gb: float, free_disk_gb: float) -> ModelInfo:
    """Show a Rich table of vision models on stderr and return the recommended model."""
    console = Console(stderr=True)
    recommended = pick_default_vision_model()

    table = Table(title="Vision OCR Models", show_lines=False)
    table.add_column("#", justify="right", style="bold")
    table.add_column("Model", style="cyan")
    table.add_column("Size", justify="right")
    table.add_column("Description")

    for position, entry in enumerate(VISION_CATALOG, 1):
        number = str(position)
        label = entry.name
        size_text = f"{entry.size_gb:.1f} GB"
        blurb = entry.description

        if entry == recommended:
            # Star + bold highlight the recommended default.
            label = f"[bold]{label} \u2605[/bold]"
            blurb = f"[bold]{blurb}[/bold]"
            number = f"[bold]{number}[/bold]"

        if free_disk_gb < entry.size_gb + _DISK_HEADROOM_GB:
            # Not enough free disk (incl. headroom): warn by painting the size red.
            size_text = f"[red]{entry.size_gb:.1f} GB[/red]"

        table.add_row(number, label, size_text, blurb)

    console.print()
    console.print("[bold]Select a vision OCR model for scanned PDF extraction:[/bold]\n")
    console.print(table)
    console.print(f"\n System: {ram_gb:.0f} GB RAM, {free_disk_gb:.1f} GB free disk")
    console.print(" \u2605 = recommended for your system")
    console.print(f" Browse more models at {OLLAMA_MODELS_URL}\n")

    return recommended
def prompt_model_choice(ram_gb: float) -> ModelInfo:
    """Prompt the user to pick a model by number. Returns the chosen ModelInfo."""
    free_disk_gb = get_free_disk_gb(cfg.data_dir)
    recommended = display_model_picker(ram_gb, free_disk_gb)
    # 1-based index of the recommendation, offered as the default choice.
    default_idx = MODEL_CATALOG.index(recommended) + 1
    retry_msg = f"Enter a number 1-{len(MODEL_CATALOG)}.\n"

    while True:
        try:
            raw = input(f"Choice [{default_idx}]: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-C / Ctrl-D: accept the recommendation rather than crash.
            return recommended

        if not raw:
            return recommended

        try:
            choice = int(raw)
        except ValueError:
            sys.stderr.write(retry_msg)
            continue

        if 1 <= choice <= len(MODEL_CATALOG):
            return MODEL_CATALOG[choice - 1]

        sys.stderr.write(retry_msg)
def validate_disk_and_pull(model_info: ModelInfo, free_gb: float) -> None:
    """Check disk space, pull the model, and persist the choice.

    Raises RuntimeError when free disk is below the model size plus headroom.
    """
    needed_gb = model_info.size_gb + _DISK_HEADROOM_GB
    if free_gb < needed_gb:
        raise RuntimeError(
            f"Not enough disk space to download '{model_info.name}': "
            f"need {needed_gb:.1f} GB, have {free_gb:.1f} GB free. "
            f"Free up space or manually pull a smaller model with 'ollama pull <model>'."
        )

    pull_with_progress(model_info.name)
    # Persist the choice so it becomes the default on future runs.
    cfg.chat_model = model_info.name
    settings.set_value(cfg.data_root, "chat_model", model_info.name)
def pull_with_progress(model: str) -> None:
    """Pull an Ollama model, showing a Rich progress bar on stderr."""
    columns = (
        SpinnerColumn(),
        TextColumn("{task.description}"),
        BarColumn(),
        DownloadColumn(),
        TextColumn("{task.percentage:>3.0f}%"),
    )
    with Progress(*columns, transient=True) as progress:
        task_id = progress.add_task(f"Downloading model '{model}'...", total=None)
        for event in ollama.pull(model, stream=True):
            # Size fields may be missing early in the stream; treat as zero
            # and keep the bar indeterminate until a real total arrives.
            size_total = event.total or 0
            size_done = event.completed or 0
            if size_total > 0:
                progress.update(task_id, total=size_total, completed=size_done)
    sys.stderr.write(f"Model '{model}' ready.\n")
def ensure_chat_model() -> None:
    """If Ollama has no chat models installed, pick and pull one.

    Interactive (TTY): show catalog picker with descriptions and sizes.
    Non-interactive (CI/pipes): auto-pick recommended model silently.
    Persists the chosen model in config.toml so it becomes the default.
    """
    try:
        models = ollama.list()
    except (ConnectionError, OSError) as exc:
        raise RuntimeError(f"Cannot connect to Ollama: {exc}. Is Ollama running?") from exc

    # Anything whose base name differs from the embedding model counts as a
    # chat model; one installed chat model means there is nothing to do.
    embed_base = cfg.embedding_model.split(":")[0]
    for installed in models.models:
        if installed.model and installed.model.split(":")[0] != embed_base:
            return

    ram_gb = get_system_ram_gb()
    free_gb = get_free_disk_gb(cfg.data_dir)

    if sys.stdin.isatty():
        chosen = prompt_model_choice(ram_gb)
    else:
        chosen = pick_default_model(ram_gb)
        sys.stderr.write(
            f"No chat model found. Auto-installing '{chosen.name}' "
            f"(detected {ram_gb:.0f} GB RAM)...\n"
        )

    validate_disk_and_pull(chosen, free_gb)