1"""Thin wrapper around LLM provider embeddings API."""
3import logging
4import math
6from lilbee.config import Config
7from lilbee.progress import DetailedProgressCallback, EmbedEvent, EventType, noop_callback
8from lilbee.providers.base import LLMProvider
9from lilbee.providers.model_ref import ProviderModelRef, parse_model_ref
11log = logging.getLogger(__name__)
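
# Character budget per provider.embed() call; embed_batch packs chunks up to
# this limit before flushing (a size heuristic, not a provider-mandated cap).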
MAX_BATCH_CHARS = 6000


def _name_base(ref: ProviderModelRef) -> str:
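    # Normalize for fuzzy matching: drop any ":tag" suffix, lowercase, and
    # dash-join spaces, e.g. "Nomic Embed:latest" -> "nomic-embed".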
    return ref.name.split(":")[0].lower().replace(" ", "-")


def _remote_sees_model(ref: ProviderModelRef, provider: LLMProvider) -> bool:
    try:
        available = provider.list_models()
    except Exception:
        log.debug("provider list_models failed during availability check", exc_info=True)
        return False
    base = _name_base(ref)
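    # Substring match, so base "nomic-embed" also matches "nomic-embed-text:latest".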
    return any(base in m.lower().replace(" ", "-") for m in available)


def _native_has_model(model: str) -> bool:
    from lilbee.providers.llama_cpp_provider import resolve_model_path
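
    # Imported locally, so the llama.cpp backend is only loaded when probed;
    # any resolution failure is treated as "model not installed".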
    try:
        resolve_model_path(model)
    except Exception:
        return False
    return True


def is_model_available(model: str, provider: LLMProvider) -> bool:
    """Return True if *model* resolves via *provider* or the native registry.

    Remote-prefixed refs (``ollama/`` and API providers) skip the native
    probe since they resolve through the SDK backend at call time.
    """
    if not model:
        return False
    ref = parse_model_ref(model)
    if _remote_sees_model(ref, provider):
        return True
    if ref.is_remote:
        return False
    return _native_has_model(model)
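
# Usage sketch (hypothetical model string; ``provider`` is any LLMProvider):
#
#     if not is_model_available("ollama/nomic-embed-text", provider):
#         raise RuntimeError("embedding model not available")
#
# Note the asymmetry: a remote ref the provider cannot list fails fast, while
# a bare ref falls back to the native registry probe.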


class Embedder:
    """Embedding wrapper: truncates, batches, and validates vectors."""

    def __init__(self, config: Config, provider: LLMProvider) -> None:
        self._config = config
        self._provider = provider

    def truncate(self, text: str) -> str:
        """Truncate text to stay within the embedding model's context window."""
        if len(text) <= self._config.max_embed_chars:
            return text
        log.debug(
            "Truncating chunk from %d to %d chars for embedding",
            len(text),
            self._config.max_embed_chars,
        )
        return text[: self._config.max_embed_chars]

    def validate_vector(self, vector: list[float]) -> None:
        """Validate embedding vector dimension and values."""
        if len(vector) != self._config.embedding_dim:
            raise ValueError(
                f"Embedding dimension mismatch: expected {self._config.embedding_dim}, "
                f"got {len(vector)}"
            )
        for i, v in enumerate(vector):
            if math.isnan(v) or math.isinf(v):
                raise ValueError(f"Embedding contains invalid value at index {i}: {v}")
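
    # Illustrative failure: with embedding_dim == 768, a 384-length vector
    # raises ValueError("Embedding dimension mismatch: expected 768, got 384");
    # NaN/inf components raise at the first offending index.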

    def validate_model(self) -> bool:
        """Check if the configured embedding model is available. No side effects."""
        return self.embedding_available()

    def embedding_available(self) -> bool:
        """Return True if the embedding model can be resolved.

        Checks the provider model list and the native registry path
        resolution. Returns True if either finds the model.
        """
        return is_model_available(self._config.embedding_model, self._provider)

    def embed(self, text: str) -> list[float]:
        """Embed a single text string, return vector."""
        vectors = self._provider.embed([self.truncate(text)])
        result: list[float] = vectors[0]
        self.validate_vector(result)
        return result
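
    # Usage sketch (assumes a constructed Config and LLMProvider):
    #
    #     embedder = Embedder(config, provider)
    #     vec = embedder.embed("some chunk text")
    #     assert len(vec) == config.embedding_dim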

    def embed_batch(
        self,
        texts: list[str],
        *,
        source: str = "",
        on_progress: DetailedProgressCallback = noop_callback,
    ) -> list[list[float]]:
        """Embed multiple texts with adaptive batching, return list of vectors.

        Fires ``embed`` progress events per batch when *on_progress* is provided.
        """
        if not texts:
            return []
        total_chunks = len(texts)
        vectors: list[list[float]] = []
        batch: list[str] = []
        batch_chars = 0
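        # Greedy adaptive batching: pack truncated chunks until the next one
        # would push the batch past MAX_BATCH_CHARS, then flush and report.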
        for text in texts:
            truncated = self.truncate(text)
            chunk_len = len(truncated)
            if batch and batch_chars + chunk_len > MAX_BATCH_CHARS:
                vectors.extend(self._provider.embed(batch))
                on_progress(
                    EventType.EMBED,
                    EmbedEvent(file=source, chunk=len(vectors), total_chunks=total_chunks),
                )
                batch = []
                batch_chars = 0
            batch.append(truncated)
            batch_chars += chunk_len
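        # Flush the final partial batch, then fire one last progress event.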
        if batch:
            vectors.extend(self._provider.embed(batch))
            on_progress(
                EventType.EMBED,
                EmbedEvent(file=source, chunk=len(vectors), total_chunks=total_chunks),
            )
        for vec in vectors:
            self.validate_vector(vec)
        return vectors
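
# Usage sketch (hypothetical callback; EmbedEvent fields mirror the calls above):
#
#     def report(event_type, event):
#         if event_type is EventType.EMBED:
#             print(f"{event.file}: {event.chunk}/{event.total_chunks} chunks embedded")
#
#     vectors = embedder.embed_batch(chunks, source="docs/readme.md", on_progress=report)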