Coverage for src/lilbee/cli/tui/widgets/model_bar.py: 100%

253 statements  

coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""Model status bar — Select dropdowns for chat and embedding models.""" 

2 

3from __future__ import annotations 

4 

5import contextlib 

6import logging 

7from pathlib import Path 

8from typing import ClassVar, NamedTuple 

9 

10from textual import on, work 

11from textual.app import ComposeResult 

12from textual.containers import Horizontal 

13from textual.widget import Widget 

14from textual.widgets import Select, Static 

15from textual.widgets._select import SelectCurrent 

16 

17from lilbee.catalog import clean_display_name, display_label_for_ref, extract_quant 

18from lilbee.cli.tui import messages as msg 

19from lilbee.cli.tui.app import apply_active_model 

20from lilbee.cli.tui.pill import pill 

21from lilbee.cli.tui.thread_safe import call_from_thread 

22from lilbee.config import cfg 

23from lilbee.models import ModelTask 

24from lilbee.providers.model_ref import OLLAMA_PREFIX, parse_model_ref 

25from lilbee.providers.sdk_backend import ( 

26 OLLAMA_BACKEND_NAME, 

27 PROVIDER_KEYS, 

28 detect_backend_name, 

29) 

30from lilbee.services import get_services, reset_services 

31from lilbee.store import SearchScope 

32 

33log = logging.getLogger(__name__) 

34 

_DISABLED = Select.NULL

_MMPROJ_MARKER = "mmproj"

_CLOUD_WARNING_ID = "cloud-provider-warning"
_CLOUD_WARNING_CLASS = "cloud-warning"
_CLOUD_WARNING_VISIBLE_CLASS = "-visible"

# Routing-name -> display-label map derived from PROVIDER_KEYS. Any new
# entry added there lights up the warning without further changes here.
_CLOUD_PROVIDER_LABELS: dict[str, str] = {name: label for name, _, _, label in PROVIDER_KEYS}



def _cloud_provider_label(chat_model: str) -> str | None:
    """Return the provider display label for cloud-routed models, else None.

    Local models and Ollama (remote but user-local) return None. Anything
    classified as an API provider by ``parse_model_ref`` returns the
    matching label from PROVIDER_KEYS.
    """
    if not chat_model:
        return None
    ref = parse_model_ref(chat_model)
    if not ref.is_api:
        return None
    return _CLOUD_PROVIDER_LABELS.get(ref.provider)
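
# Illustrative only, not part of the module's behaviour contract: assuming
# PROVIDER_KEYS carries an ("anthropic", ..., "Anthropic") entry and that
# parse_model_ref flags such refs as is_api, _cloud_provider_label behaves as:
#   _cloud_provider_label("anthropic/claude-x")  -> "Anthropic"
#   _cloud_provider_label("ollama/llama3")       -> None  (user-local)
#   _cloud_provider_label("")                    -> None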



class ModelOption(NamedTuple):
    """A selectable model with display label and config ref."""

    label: str  # human-readable name for the dropdown
    ref: str  # canonical ref persisted to config


def _is_mmproj(name: str) -> bool:
    """Return True if a model name refers to an mmproj projection file."""
    return _MMPROJ_MARKER in name.lower()
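
# Illustrative only: _is_mmproj is a plain case-insensitive substring check, so
#   _is_mmproj("llava-v1.6-MMPROJ-f16.gguf") -> True
#   _is_mmproj("qwen2.5-7b-instruct.gguf")   -> False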



def _classify_installed_models() -> tuple[list[ModelOption], list[ModelOption]]:
    """Classify installed models into (chat, embedding) lists.

    Uses registry manifests for native models and the SDK backend's
    metadata for remote models. Filters out mmproj files.
    """
    buckets: dict[ModelTask, list[ModelOption]] = {task: [] for task in ModelTask}
    seen: set[str] = set()

    _collect_native_models(buckets, seen)
    _collect_remote_models(buckets, seen)
    _collect_api_models(buckets, seen)

    # ModelBar exposes only chat + embedding Selects. Vision and rerank
    # models are still collected to claim their refs in ``seen`` so later
    # collection passes don't duplicate them, but they aren't returned to
    # the UI.
    return (
        sorted(buckets[ModelTask.CHAT], key=lambda o: o.ref),
        sorted(buckets[ModelTask.EMBEDDING], key=lambda o: o.ref),
    )
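
# Note, derived from the call order above: because ``seen`` is shared across
# the three collectors and checked before each append, the first collector to
# claim a ref string wins — native registry, then remote backend, then API
# discovery.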



def _lookup_bucket(
    buckets: dict[ModelTask, list[ModelOption]], task: str, ref: str
) -> list[ModelOption] | None:
    """Return the bucket for *task*, or None if *task* is not a known ModelTask.

    Unknown tasks from forward-compat manifests are dropped rather than
    silently misclassified into chat.
    """
    try:
        key = ModelTask(task)
    except ValueError:
        log.debug("dropping %r with unknown task %r", ref, task)
        return None
    return buckets.get(key)
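
# Illustrative only, assuming ModelTask has a member whose value is "chat":
#   _lookup_bucket(buckets, "chat", ref)           -> buckets[ModelTask.CHAT]
#   _lookup_bucket(buckets, "speech-to-text", ref) -> None (logged and dropped)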



def _native_label(hf_repo: str, gguf_filename: str, repo_count: int) -> str:
    """Build the picker label, appending the quant suffix only on collision."""
    base = clean_display_name(hf_repo)
    if repo_count <= 1:
        return base
    quant = extract_quant(gguf_filename)
    return f"{base} ({quant})" if quant else base



def _collect_native_models(buckets: dict[ModelTask, list[ModelOption]], seen: set[str]) -> None:
    """Add native registry models to buckets."""
    try:
        from lilbee.registry import ModelRegistry

        registry = ModelRegistry(cfg.models_dir)
        manifests = registry.list_installed()
        repo_counts: dict[str, int] = {}
        for m in manifests:
            repo_counts[m.hf_repo] = repo_counts.get(m.hf_repo, 0) + 1

        for manifest in manifests:
            ref = manifest.ref
            if _is_mmproj(manifest.gguf_filename) or ref in seen:
                continue
            bucket = _lookup_bucket(buckets, manifest.task, ref)
            if bucket is None:
                continue
            seen.add(ref)
            label = _native_label(
                manifest.hf_repo, manifest.gguf_filename, repo_counts[manifest.hf_repo]
            )
            bucket.append(ModelOption(label=label, ref=ref))
    except Exception:
        log.debug("Could not read native model registry", exc_info=True)



def _collect_remote_models(buckets: dict[ModelTask, list[ModelOption]], seen: set[str]) -> None:
    """Add remote (Ollama / OpenAI-compatible) models to buckets, prefixed for routing.

    Skipped entirely when the SDK extra (``lilbee[litellm]``) isn't installed:
    surfacing a model the SDK can't actually serve only sets the user up
    for a runtime traceback when they pick it. (bb-ezwy)
    """
    from lilbee.providers.litellm_sdk import litellm_available

    if not litellm_available():
        return
    try:
        from lilbee.model_manager import classify_remote_models

        base_url = cfg.remote_base_url
        is_ollama = detect_backend_name(base_url) == OLLAMA_BACKEND_NAME
        for model in classify_remote_models(base_url):
            # Skip backend rows with a blank model name so the picker
            # doesn't render an empty " (Ollama)" row.
            if not model.name.strip():
                continue
            ref = f"{OLLAMA_PREFIX}{model.name}" if is_ollama else model.name
            if ref in seen or _is_mmproj(model.name):
                continue
            bucket = _lookup_bucket(buckets, model.task, ref)
            if bucket is None:
                continue
            seen.add(ref)
            label = f"{model.name} ({model.provider})"
            bucket.append(ModelOption(label=label, ref=ref))
    except Exception:
        log.debug("Could not classify remote models", exc_info=True)



def _collect_api_models(buckets: dict[ModelTask, list[ModelOption]], seen: set[str]) -> None:
    """Add frontier API models (OpenAI, Anthropic, Gemini) to the chat bucket.

    Same litellm guard as ``_collect_remote_models``: API providers
    require the SDK to route a turn; without it, surfacing them just
    invites a runtime traceback. (bb-ezwy)
    """
    from lilbee.providers.litellm_sdk import litellm_available

    if not litellm_available():
        return
    try:
        from lilbee.model_manager import discover_api_models

        # API discovery returns only chat-capable refs; revisit if providers
        # expose embedding/vision/rerank.
        for display_name, models in discover_api_models().items():
            for model in models:
                # ``display_name`` is the provider's display form ("Anthropic"),
                # but the ref needs the backend-qualified prefix
                # ("anthropic/model-name") for routing.
                prefix = display_name.lower()
                qualified = f"{prefix}/{model.name}"
                if qualified in seen:
                    continue
                seen.add(qualified)
                label = f"{model.name} ({display_name})"
                buckets[ModelTask.CHAT].append(ModelOption(label=label, ref=qualified))
    except Exception:
        log.debug("Could not discover API models", exc_info=True)



def _sync_select(sel: Select, opts: list[ModelOption], default: str = "") -> None:
    """Populate a Select with *opts*; show *default* with ``(not installed)``
    when it isn't in the list, and pass unparseable defaults through unchanged."""
    if default:
        try:
            ref = parse_model_ref(default)
        except ValueError:
            ref = None
        if ref is not None:
            default = ref.for_openai_prefix()
    if default and not any(o.ref == default for o in opts):
        shown = display_label_for_ref(default) or default
        opts.insert(0, ModelOption(f"{shown} (not installed)", default))
    sel.set_options(opts)
    if default:
        sel.value = default
        _refresh_select_label(sel, opts, default)
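
# Illustrative only (the refs here are hypothetical): with
# opts == [ModelOption("Tiny LLM", "tiny-llm")] and a normalised default of
# "other-model" that is absent from opts, the function prepends
# ModelOption("other-model (not installed)", "other-model") — or the catalog's
# display label, when display_label_for_ref knows the ref — and then selects it.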



def _refresh_select_label(sel: Select, opts: list[ModelOption], value: str) -> None:
    """Push the matching option's label into ``SelectCurrent``.

    Textual's ``Select.set_options`` updates the option list but doesn't
    re-render ``SelectCurrent`` if the existing ``value`` still matches
    an option — the reactive watcher short-circuits on ``old == new``.
    That meant a freshly-labelled option (e.g. ``"<ref> (not installed)"``)
    kept the compose-time bare-ref label on screen. Poke the inner
    widget directly so the visible label matches what tests assert.
    """
    if not value:
        return
    with contextlib.suppress(Exception):
        current = sel.query_one(SelectCurrent)
        for label, ref_value in opts:
            if ref_value == value:
                current.update(label)
                return



_SELECT_IDS = ("#chat-model-select", "#embed-model-select")

_CSS_FILE = Path(__file__).parent / "model_bar.tcss"

# Presentation labels for the scope toggle. Values match ``SearchScope``
# so the widget's ``.value`` feeds directly into ``scope_to_chunk_type``.
_SCOPE_OPTIONS: tuple[tuple[str, str], ...] = (
    ("Both", SearchScope.BOTH.value),
    ("Wiki", SearchScope.WIKI.value),
    ("Raw", SearchScope.RAW.value),
)



class ModelBar(Widget, can_focus=False):
    """Compact bar with Select dropdowns for active model assignments."""

    DEFAULT_CSS: ClassVar[str] = _CSS_FILE.read_text(encoding="utf-8")

    def __init__(self, id: str | None = None) -> None:
        super().__init__(id=id)
        self._populating = True  # Guard against change events during init
        self._scope: SearchScope = SearchScope.BOTH

    @property
    def scope(self) -> SearchScope:
        """Current scope selection; consumed by ChatScreen when building RAG context."""
        return self._scope

    def compose(self) -> ComposeResult:
        chat_label = display_label_for_ref(cfg.chat_model) or cfg.chat_model
        embed_label = display_label_for_ref(cfg.embedding_model) or cfg.embedding_model
        chat_opts = [(chat_label, cfg.chat_model)] if cfg.chat_model else []
        embed_opts = [(embed_label, cfg.embedding_model)] if cfg.embedding_model else []
        with Horizontal():
            yield Static(pill("Chat", "$primary", "$text"), classes="model-bar-pill")
            yield Select[str](
                options=chat_opts,
                prompt="Chat model",
                id="chat-model-select",
                allow_blank=False,
            )
            yield Static(pill("Embed", "$secondary", "$text"), classes="model-bar-pill")
            yield Select[str](
                options=embed_opts,
                prompt="Embed model",
                id="embed-model-select",
                allow_blank=False,
            )
            # Scope picker only appears when the wiki layer is on. With wiki
            # off, ``CHUNKS_TABLE`` contains only raw rows so a wiki/raw/both
            # toggle has nothing to pick between; hiding it keeps the choice
            # from implying a capability the user hasn't opted into.
            if cfg.wiki:
                yield Static(pill("Scope", "$accent", "$text"), classes="model-bar-pill")
                yield Select[str](
                    options=list(_SCOPE_OPTIONS),
                    value=SearchScope.BOTH.value,
                    id="scope-select",
                    allow_blank=False,
                )
            yield Static("", id=_CLOUD_WARNING_ID, classes=_CLOUD_WARNING_CLASS)


    def on_mount(self) -> None:
        chat_sel = self.query_one("#chat-model-select", Select)
        embed_sel = self.query_one("#embed-model-select", Select)

        if cfg.chat_model:
            chat_sel.value = cfg.chat_model
        if cfg.embedding_model:
            embed_sel.value = cfg.embedding_model

        self._watch_overlay_collapse(chat_sel)
        self._watch_overlay_collapse(embed_sel)

        self._refresh_cloud_warning()
        self._scan_models()

    def _watch_overlay_collapse(self, sel: Select) -> None:
        """Force a full screen refresh when a Select overlay collapses."""

        def _on_expanded_change(expanded: bool) -> None:
            if not expanded and self.is_mounted:
                with contextlib.suppress(Exception):
                    self.screen.refresh()

        self.watch(sel, "expanded", _on_expanded_change, init=False)

    def on_unmount(self) -> None:
        """Collapse any open dropdown before tear-down so the SelectOverlay
        does not leak its border cells into the next screen's render."""
        for sel in self.query(Select):
            if sel.expanded:
                sel.expanded = False


    @work(thread=True)
    def _scan_models(self) -> None:
        """Scan installed models in background, then populate dropdowns."""
        chat, embed = _classify_installed_models()
        call_from_thread(self, self._populate, chat, embed)

    def _populate(
        self,
        chat_models: list[ModelOption],
        embed_models: list[ModelOption],
    ) -> None:
        """Populate Select widgets from scanned models (main thread)."""
        self._populating = True

        chat_sel = self.query_one("#chat-model-select", Select)
        embed_sel = self.query_one("#embed-model-select", Select)

        chat_opts = list(chat_models) if chat_models else [ModelOption("(none)", "")]
        embed_opts = list(embed_models) if embed_models else [ModelOption("(none)", "")]

        _sync_select(chat_sel, chat_opts, cfg.chat_model)
        _sync_select(embed_sel, embed_opts, cfg.embedding_model)

        self._populating = False


    @on(Select.Changed, "#chat-model-select")
    def _on_chat_model_changed(self, event: Select.Changed) -> None:
        """Write the new chat model to cfg and settings."""
        chat_sel = self.query_one("#chat-model-select", Select)
        value = self._extract_value(event, chat_sel)
        if value is None or value == cfg.chat_model:
            return
        apply_active_model(self.app, "chat_model", value)
        self._refresh_cloud_warning()
        self._after_model_change()

    def _refresh_cloud_warning(self) -> None:
        """Show a warning if the active chat model routes to a cloud API."""
        warning = self.query_one(f"#{_CLOUD_WARNING_ID}", Static)
        label = _cloud_provider_label(cfg.chat_model)
        if label is None:
            warning.remove_class(_CLOUD_WARNING_VISIBLE_CLASS)
            return
        warning.update(msg.MODEL_BAR_CLOUD_PROVIDER_WARNING.format(provider=label))
        warning.add_class(_CLOUD_WARNING_VISIBLE_CLASS)


    @on(Select.Changed, "#embed-model-select")
    def _on_embed_model_changed(self, event: Select.Changed) -> None:
        """Write the new embedding model to cfg and settings."""
        embed_sel = self.query_one("#embed-model-select", Select)
        value = self._extract_value(event, embed_sel)
        if value is None or value == cfg.embedding_model:
            return
        # Pin a legacy store's identity to the OLD model BEFORE the cfg mutation
        # so the gate in store.search/add_chunks correctly detects drift on the
        # next op. See bb-x1qa.
        get_services().store.initialize_meta_if_legacy()
        apply_active_model(self.app, "embedding_model", value)
        self._after_model_change()

    @on(Select.Changed, "#scope-select")
    def _on_scope_changed(self, event: Select.Changed) -> None:
        """Track scope selection for the next ask_stream call.

        Session-scoped on purpose; not written to settings so each new
        session starts at "both" and the user opts into a narrower pool
        explicitly each time.
        """
        scope_sel = self.query_one("#scope-select", Select)
        value = self._extract_value(event, scope_sel)
        if value is None:
            return
        self._scope = SearchScope(value)


    def _extract_value(self, event: Select.Changed, sel: Select) -> str | None:
        """Extract a non-empty value from a Select.Changed event, or None to skip.

        Drops events whose payload no longer matches the widget's current
        value. Textual posts Select.Changed asynchronously, so a prior
        event carrying an intermediate auto-picked option can still be in
        the queue after ``_populate`` reassigns ``sel.value`` to the
        configured model. The stale-event check is deterministic across
        platforms because it compares two synchronously-set values rather
        than relying on event-loop ordering.
        """
        if self._populating:
            return None
        if event.value is _DISABLED or event.value is None or str(event.value) == "":
            return None
        if str(event.value) != str(sel.value):
            return None
        return str(event.value)
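
    # Illustrative only (model names hypothetical): if _populate auto-picks
    # "model-a", queuing a Changed event, and then assigns the configured
    # "model-b", the queued event arrives with event.value == "model-a" while
    # sel.value == "model-b"; the mismatch check above returns None and the
    # stale write is skipped.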


    def _after_model_change(self) -> None:
        """Shared post-change logic: cancel active stream and reset services safely."""
        from lilbee.cli.tui.screens.chat import ChatScreen

        screen = self.app.screen
        if isinstance(screen, ChatScreen):
            screen._apply_model_change()
        else:
            reset_services()

    def refresh_models(self) -> None:
        """Re-scan models (called after downloads complete)."""
        self._scan_models()

451 self._scan_models()