Coverage for src/lilbee/cli/tui/widgets/model_bar.py: 100% — 253 statements.
Generated by coverage.py v7.13.4 at 2026-04-29 19:16 +0000.
1"""Model status bar — Select dropdowns for chat and embedding models."""
3from __future__ import annotations
5import contextlib
6import logging
7from pathlib import Path
8from typing import ClassVar, NamedTuple
10from textual import on, work
11from textual.app import ComposeResult
12from textual.containers import Horizontal
13from textual.widget import Widget
14from textual.widgets import Select, Static
15from textual.widgets._select import SelectCurrent
17from lilbee.catalog import clean_display_name, display_label_for_ref, extract_quant
18from lilbee.cli.tui import messages as msg
19from lilbee.cli.tui.app import apply_active_model
20from lilbee.cli.tui.pill import pill
21from lilbee.cli.tui.thread_safe import call_from_thread
22from lilbee.config import cfg
23from lilbee.models import ModelTask
24from lilbee.providers.model_ref import OLLAMA_PREFIX, parse_model_ref
25from lilbee.providers.sdk_backend import (
26 OLLAMA_BACKEND_NAME,
27 PROVIDER_KEYS,
28 detect_backend_name,
29)
30from lilbee.services import get_services, reset_services
31from lilbee.store import SearchScope
33log = logging.getLogger(__name__)
# Sentinel Textual's Select reports when nothing is selected; compared by
# identity in ``ModelBar._extract_value``.
_DISABLED = Select.NULL

# Substring identifying a vision-projection (mmproj) file, which must not be
# offered as a selectable model.
_MMPROJ_MARKER = "mmproj"

# DOM id and CSS classes for the cloud-provider warning Static.
_CLOUD_WARNING_ID = "cloud-provider-warning"
_CLOUD_WARNING_CLASS = "cloud-warning"
_CLOUD_WARNING_VISIBLE_CLASS = "-visible"

# Routing-name -> display-label map derived from PROVIDER_KEYS. Any new
# entry added there lights up the warning without further changes here.
_CLOUD_PROVIDER_LABELS: dict[str, str] = {name: label for name, _, _, label in PROVIDER_KEYS}
def _cloud_provider_label(chat_model: str) -> str | None:
    """Return the provider display label for cloud-routed models, else None.

    Local models and Ollama (remote but user-local) return None. Anything
    classified as an API provider by ``parse_model_ref`` returns the
    matching label from PROVIDER_KEYS.
    """
    if chat_model:
        parsed = parse_model_ref(chat_model)
        if parsed.is_api:
            return _CLOUD_PROVIDER_LABELS.get(parsed.provider)
    return None
class ModelOption(NamedTuple):
    """A selectable model with display label and config ref."""

    label: str  # human-readable name shown in the dropdown
    ref: str  # canonical model ref persisted to config and used for routing
def _is_mmproj(name: str) -> bool:
    """Return True if a model name refers to an mmproj projection file."""
    lowered = name.lower()
    return _MMPROJ_MARKER in lowered
def _classify_installed_models() -> tuple[list[ModelOption], list[ModelOption]]:
    """Classify installed models into (chat, embedding) lists.

    Uses registry manifests for native models and the SDK backend's
    backend metadata for remote models. Filters out mmproj files.
    """
    buckets: dict[ModelTask, list[ModelOption]] = {task: [] for task in ModelTask}
    claimed: set[str] = set()

    # Collection order matters: earlier collectors claim refs in ``claimed``
    # so later ones don't duplicate them.
    for collect in (_collect_native_models, _collect_remote_models, _collect_api_models):
        collect(buckets, claimed)

    # ModelBar exposes only chat + embedding Selects. Vision and rerank
    # models are collected to claim their refs in ``claimed`` so later
    # buckets don't duplicate them, but aren't returned to the UI.
    chat = sorted(buckets[ModelTask.CHAT], key=lambda opt: opt.ref)
    embed = sorted(buckets[ModelTask.EMBEDDING], key=lambda opt: opt.ref)
    return chat, embed
def _lookup_bucket(
    buckets: dict[ModelTask, list[ModelOption]], task: str, ref: str
) -> list[ModelOption] | None:
    """Return the bucket for *task*, or None if *task* is not a known ModelTask.

    Unknown tasks from forward-compat manifests are dropped rather than
    silently misclassified into chat.
    """
    try:
        return buckets.get(ModelTask(task))
    except ValueError:
        log.debug("dropping %r with unknown task %r", ref, task)
        return None
def _native_label(hf_repo: str, gguf_filename: str, repo_count: int) -> str:
    """Build the picker label, appending the quant suffix only on collision."""
    base = clean_display_name(hf_repo)
    if repo_count > 1:
        # Several installs from the same repo: disambiguate by quant when known.
        quant = extract_quant(gguf_filename)
        if quant:
            return f"{base} ({quant})"
    return base
def _collect_native_models(buckets: dict[ModelTask, list[ModelOption]], seen: set[str]) -> None:
    """Add native registry models to buckets."""
    try:
        from lilbee.registry import ModelRegistry

        manifests = ModelRegistry(cfg.models_dir).list_installed()

        # Count installs per repo so labels only grow a quant suffix on collision.
        per_repo: dict[str, int] = {}
        for manifest in manifests:
            per_repo[manifest.hf_repo] = per_repo.get(manifest.hf_repo, 0) + 1

        for manifest in manifests:
            ref = manifest.ref
            if ref in seen or _is_mmproj(manifest.gguf_filename):
                continue
            bucket = _lookup_bucket(buckets, manifest.task, ref)
            if bucket is None:
                continue
            seen.add(ref)
            display = _native_label(
                manifest.hf_repo, manifest.gguf_filename, per_repo[manifest.hf_repo]
            )
            bucket.append(ModelOption(label=display, ref=ref))
    except Exception:
        log.debug("Could not read native model registry", exc_info=True)
def _collect_remote_models(buckets: dict[ModelTask, list[ModelOption]], seen: set[str]) -> None:
    """Add remote (Ollama / OpenAI-compatible) models to buckets, prefixed for routing.

    Skipped entirely when the SDK extra (``lilbee[litellm]``) isn't installed:
    surfacing a model the SDK can't actually serve only sets the user up
    for a runtime traceback when they pick it. (bb-ezwy)
    """
    from lilbee.providers.litellm_sdk import litellm_available

    if not litellm_available():
        return
    try:
        from lilbee.model_manager import classify_remote_models

        base_url = cfg.remote_base_url
        # Ollama backends need the routing prefix on every ref.
        ref_prefix = (
            OLLAMA_PREFIX if detect_backend_name(base_url) == OLLAMA_BACKEND_NAME else ""
        )
        for model in classify_remote_models(base_url):
            name = model.name
            # Skip backend rows with a blank model name so the picker
            # doesn't render an empty " (Ollama)" row.
            if not name.strip():
                continue
            ref = f"{ref_prefix}{name}"
            if ref in seen or _is_mmproj(name):
                continue
            bucket = _lookup_bucket(buckets, model.task, ref)
            if bucket is None:
                continue
            seen.add(ref)
            bucket.append(ModelOption(label=f"{name} ({model.provider})", ref=ref))
    except Exception:
        log.debug("Could not classify remote models", exc_info=True)
def _collect_api_models(buckets: dict[ModelTask, list[ModelOption]], seen: set[str]) -> None:
    """Add frontier API models (OpenAI, Anthropic, Gemini) to chat bucket.

    Same litellm guard as ``_collect_remote_models``: API providers
    require the SDK to route a turn; without it, surfacing them just
    invites a runtime traceback. (bb-ezwy)
    """
    from lilbee.providers.litellm_sdk import litellm_available

    if not litellm_available():
        return
    try:
        from lilbee.model_manager import discover_api_models

        # API discovery returns only chat-capable refs; revisit if providers
        # expose embedding/vision/rerank.
        chat_bucket = buckets[ModelTask.CHAT]
        for display_name, models in discover_api_models().items():
            # The discovery key is the display name ("Anthropic"), but the ref
            # needs the backend-qualified prefix ("anthropic/model-name") for routing.
            route_prefix = display_name.lower()
            for model in models:
                qualified = f"{route_prefix}/{model.name}"
                if qualified in seen:
                    continue
                seen.add(qualified)
                chat_bucket.append(
                    ModelOption(label=f"{model.name} ({display_name})", ref=qualified)
                )
    except Exception:
        log.debug("Could not discover API models", exc_info=True)
def _sync_select(sel: Select, opts: list[ModelOption], default: str = "") -> None:
    """Populate a Select with *opts*; show *default* with ``(not installed)``
    when it isn't in the list, and pass unparseable defaults through unchanged."""
    if default:
        # Canonicalise the configured ref when it parses; leave it untouched
        # (and still selectable) when it doesn't.
        try:
            parsed = parse_model_ref(default)
        except ValueError:
            pass
        else:
            default = parsed.for_openai_prefix()
    if default and all(opt.ref != default for opt in opts):
        shown = display_label_for_ref(default) or default
        opts.insert(0, ModelOption(f"{shown} (not installed)", default))
    sel.set_options(opts)
    if default:
        sel.value = default
        _refresh_select_label(sel, opts, default)
def _refresh_select_label(sel: Select, opts: list[ModelOption], value: str) -> None:
    """Push the matching option's label into ``SelectCurrent``.

    Textual's ``Select.set_options`` updates the option list but doesn't
    re-render ``SelectCurrent`` if the existing ``value`` still matches
    an option — the reactive watcher short-circuits on ``old == new``.
    That meant a freshly-labelled option (e.g. ``"<ref> (not installed)"``)
    kept the compose-time bare-ref label on screen. Poke the inner
    widget directly so the visible label matches what tests assert.
    """
    if not value:
        return
    with contextlib.suppress(Exception):
        current = sel.query_one(SelectCurrent)
        matches = [label for label, ref_value in opts if ref_value == value]
        if matches:
            current.update(matches[0])
# Query selectors for the two model dropdowns, in (chat, embed) order.
# NOTE(review): not referenced elsewhere in this module's visible code —
# confirm external/test use before removing.
_SELECT_IDS = ("#chat-model-select", "#embed-model-select")

# Stylesheet shipped alongside this module; read once at class-definition
# time into ``ModelBar.DEFAULT_CSS``.
_CSS_FILE = Path(__file__).parent / "model_bar.tcss"

# Presentation labels for the scope toggle. Values match ``SearchScope``
# so the widget's ``.value`` feeds directly into ``scope_to_chunk_type``.
_SCOPE_OPTIONS: tuple[tuple[str, str], ...] = (
    ("Both", SearchScope.BOTH.value),
    ("Wiki", SearchScope.WIKI.value),
    ("Raw", SearchScope.RAW.value),
)
class ModelBar(Widget, can_focus=False):
    """Compact bar with Select dropdowns for active model assignments."""

    # Loaded at class-definition time from the sibling .tcss file.
    DEFAULT_CSS: ClassVar[str] = _CSS_FILE.read_text(encoding="utf-8")

    def __init__(self, id: str | None = None) -> None:
        super().__init__(id=id)
        self._populating = True  # Guard against change events during init
        self._scope: SearchScope = SearchScope.BOTH

    @property
    def scope(self) -> SearchScope:
        """Current scope selection; consumed by ChatScreen when building RAG context."""
        return self._scope

    def compose(self) -> ComposeResult:
        # Seed each Select with the currently-configured model so something
        # sensible shows before the background scan in on_mount completes.
        chat_label = display_label_for_ref(cfg.chat_model) or cfg.chat_model
        embed_label = display_label_for_ref(cfg.embedding_model) or cfg.embedding_model
        chat_opts = [(chat_label, cfg.chat_model)] if cfg.chat_model else []
        embed_opts = [(embed_label, cfg.embedding_model)] if cfg.embedding_model else []
        with Horizontal():
            yield Static(pill("Chat", "$primary", "$text"), classes="model-bar-pill")
            yield Select[str](
                options=chat_opts,
                prompt="Chat model",
                id="chat-model-select",
                allow_blank=False,
            )
            yield Static(pill("Embed", "$secondary", "$text"), classes="model-bar-pill")
            yield Select[str](
                options=embed_opts,
                prompt="Embed model",
                id="embed-model-select",
                allow_blank=False,
            )
            # Scope picker only appears when the wiki layer is on. With wiki
            # off, ``CHUNKS_TABLE`` contains only raw rows so a wiki/raw/both
            # toggle has nothing to pick between; hiding it keeps the choice
            # from implying a capability the user hasn't opted into.
            if cfg.wiki:
                yield Static(pill("Scope", "$accent", "$text"), classes="model-bar-pill")
                yield Select[str](
                    options=list(_SCOPE_OPTIONS),
                    value=SearchScope.BOTH.value,
                    id="scope-select",
                    allow_blank=False,
                )
        # Hidden until a cloud-routed chat model is active; toggled via the
        # ``-visible`` class in _refresh_cloud_warning.
        yield Static("", id=_CLOUD_WARNING_ID, classes=_CLOUD_WARNING_CLASS)

    def on_mount(self) -> None:
        chat_sel = self.query_one("#chat-model-select", Select)
        embed_sel = self.query_one("#embed-model-select", Select)

        if cfg.chat_model:
            chat_sel.value = cfg.chat_model
        if cfg.embedding_model:
            embed_sel.value = cfg.embedding_model

        self._watch_overlay_collapse(chat_sel)
        self._watch_overlay_collapse(embed_sel)

        self._refresh_cloud_warning()
        # Kicks off the thread worker; _populate runs later on the main thread.
        self._scan_models()

    def _watch_overlay_collapse(self, sel: Select) -> None:
        """Force a full screen refresh when a Select overlay collapses."""

        def _on_expanded_change(expanded: bool) -> None:
            # Only refresh on collapse, and only while still mounted —
            # the watcher can fire during tear-down.
            if not expanded and self.is_mounted:
                with contextlib.suppress(Exception):
                    self.screen.refresh()

        self.watch(sel, "expanded", _on_expanded_change, init=False)

    def on_unmount(self) -> None:
        """Collapse any open dropdown before tear-down so the SelectOverlay
        does not leak its border cells into the next screen's render."""
        for sel in self.query(Select):
            if sel.expanded:
                sel.expanded = False

    @work(thread=True)
    def _scan_models(self) -> None:
        """Scan installed models in background, then populate dropdowns."""
        chat, embed = _classify_installed_models()
        # Widget mutation must happen on the main thread.
        call_from_thread(self, self._populate, chat, embed)

    def _populate(
        self,
        chat_models: list[ModelOption],
        embed_models: list[ModelOption],
    ) -> None:
        """Populate Select widgets from scanned models (main thread)."""
        # Suppress Select.Changed handlers while options/values are reassigned.
        self._populating = True

        chat_sel = self.query_one("#chat-model-select", Select)
        embed_sel = self.query_one("#embed-model-select", Select)

        # allow_blank=False Selects need at least one option.
        chat_opts = list(chat_models) if chat_models else [ModelOption("(none)", "")]
        embed_opts = list(embed_models) if embed_models else [ModelOption("(none)", "")]

        _sync_select(chat_sel, chat_opts, cfg.chat_model)
        _sync_select(embed_sel, embed_opts, cfg.embedding_model)

        self._populating = False

    @on(Select.Changed, "#chat-model-select")
    def _on_chat_model_changed(self, event: Select.Changed) -> None:
        """Write the new chat model to cfg and settings."""
        chat_sel = self.query_one("#chat-model-select", Select)
        value = self._extract_value(event, chat_sel)
        if value is None or value == cfg.chat_model:
            return
        apply_active_model(self.app, "chat_model", value)
        self._refresh_cloud_warning()
        self._after_model_change()

    def _refresh_cloud_warning(self) -> None:
        """Show a warning if the active chat model routes to a cloud API."""
        warning = self.query_one(f"#{_CLOUD_WARNING_ID}", Static)
        label = _cloud_provider_label(cfg.chat_model)
        if label is None:
            warning.remove_class(_CLOUD_WARNING_VISIBLE_CLASS)
            return
        warning.update(msg.MODEL_BAR_CLOUD_PROVIDER_WARNING.format(provider=label))
        warning.add_class(_CLOUD_WARNING_VISIBLE_CLASS)

    @on(Select.Changed, "#embed-model-select")
    def _on_embed_model_changed(self, event: Select.Changed) -> None:
        """Write the new embedding model to cfg and settings."""
        embed_sel = self.query_one("#embed-model-select", Select)
        value = self._extract_value(event, embed_sel)
        if value is None or value == cfg.embedding_model:
            return
        # Pin a legacy store's identity to the OLD model BEFORE the cfg mutation
        # so the gate in store.search/add_chunks correctly detects drift on the
        # next op. See bb-x1qa.
        get_services().store.initialize_meta_if_legacy()
        apply_active_model(self.app, "embedding_model", value)
        self._after_model_change()

    @on(Select.Changed, "#scope-select")
    def _on_scope_changed(self, event: Select.Changed) -> None:
        """Track scope selection for the next ask_stream call.

        Session-scoped on purpose; not written to settings so each new
        session starts at "both" and the user opts into a narrower pool
        explicitly each time.
        """
        scope_sel = self.query_one("#scope-select", Select)
        value = self._extract_value(event, scope_sel)
        if value is None:
            return
        self._scope = SearchScope(value)

    def _extract_value(self, event: Select.Changed, sel: Select) -> str | None:
        """Extract a non-empty value from a Select.Changed event, or None to skip.

        Drops events whose payload no longer matches the widget's current
        value. Textual posts Select.Changed asynchronously, so a prior
        event carrying an intermediate auto-picked option can still be in
        the queue after ``_populate`` reassigns ``sel.value`` to the
        configured model. The stale-event check is deterministic across
        platforms because it compares two synchronously-set values rather
        than relying on event-loop ordering.
        """
        if self._populating:
            return None
        if event.value is _DISABLED or event.value is None or str(event.value) == "":
            return None
        # Stale-event check: payload must still match the widget's value.
        if str(event.value) != str(sel.value):
            return None
        return str(event.value)

    def _after_model_change(self) -> None:
        """Shared post-change logic: cancel active stream and reset services safely."""
        # Local import avoids a circular dependency with the chat screen module.
        from lilbee.cli.tui.screens.chat import ChatScreen

        screen = self.app.screen
        if isinstance(screen, ChatScreen):
            screen._apply_model_change()
        else:
            reset_services()

    def refresh_models(self) -> None:
        """Re-scan models (called after downloads complete)."""
        self._scan_models()