Coverage for src / lilbee / embedder.py: 100%
72 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-16 08:27 +0000
1"""Thin wrapper around Ollama embeddings API."""
3import logging
4import math
5import time
6from typing import Any
8import ollama
10from lilbee.config import cfg
11from lilbee.models import pull_with_progress
12from lilbee.progress import DetailedProgressCallback, EventType, noop_callback
14log = logging.getLogger(__name__)
16MAX_BATCH_CHARS = 6000
19def _call_with_retry(fn: Any, *args: Any, **kwargs: Any) -> Any:
20 """Retry fn up to 3 times with exponential backoff on connection errors."""
21 delays = [1, 2, 4]
22 last_err: Exception | None = None
23 for attempt, delay in enumerate(delays):
24 try:
25 return fn(*args, **kwargs) # type: ignore[operator]
26 except (ConnectionError, OSError) as exc:
27 last_err = exc
28 log.warning("Ollama call failed (attempt %d/3): %s", attempt + 1, exc)
29 time.sleep(delay)
30 raise last_err # type: ignore[misc]
def truncate(text: str) -> str:
    """Clip *text* so it fits in the embedding model's context window."""
    limit = cfg.max_embed_chars
    if len(text) > limit:
        log.debug("Truncating chunk from %d to %d chars for embedding", len(text), limit)
        text = text[:limit]
    return text
def validate_vector(vector: list[float]) -> None:
    """Check that *vector* matches the configured dimension and is all-finite.

    Raises:
        ValueError: On a dimension mismatch, or if any element is NaN/inf.
    """
    expected = cfg.embedding_dim
    actual = len(vector)
    if actual != expected:
        raise ValueError(
            f"Embedding dimension mismatch: expected {expected}, got {actual}"
        )
    for idx, value in enumerate(vector):
        # isfinite() is False exactly when the value is NaN or +/-inf.
        if not math.isfinite(value):
            raise ValueError(f"Embedding contains invalid value at index {idx}: {value}")
def validate_model() -> None:
    """Ensure the configured embedding model is available, pulling if needed.

    Raises:
        RuntimeError: If the Ollama server cannot be reached.
    """
    target = cfg.embedding_model
    # The whole lookup-and-pull stays inside one try so that connection
    # failures during the pull are also surfaced as RuntimeError.
    try:
        installed = {m.model for m in ollama.list().models if m.model}
        # A bare name like "foo" should also match an installed "foo:latest".
        untagged = {name.split(":")[0] for name in installed}
        if target not in installed and target not in untagged:
            log.info("Pulling embedding model '%s' from Ollama...", target)
            pull_with_progress(target)
    except (ConnectionError, OSError) as exc:
        raise RuntimeError(f"Cannot connect to Ollama: {exc}. Is Ollama running?") from exc
def embed(text: str) -> list[float]:
    """Embed a single text string, return vector."""
    payload = truncate(text)
    response = _call_with_retry(ollama.embed, model=cfg.embedding_model, input=payload)
    vector: list[float] = response.embeddings[0]
    validate_vector(vector)
    return vector
def embed_batch(
    texts: list[str],
    *,
    source: str = "",
    on_progress: DetailedProgressCallback = noop_callback,
) -> list[list[float]]:
    """Embed multiple texts with adaptive batching, return list of vectors.

    Texts are truncated to the model's context window, then grouped into
    batches of at most ``MAX_BATCH_CHARS`` total characters. A single text
    longer than the cap still gets its own one-element batch.

    Fires ``embed`` progress events per batch when *on_progress* is provided.

    Args:
        texts: Chunks to embed; order of the returned vectors matches.
        source: Label (e.g. filename) forwarded in progress events.
        on_progress: Callback invoked after each batch completes.

    Returns:
        One validated embedding vector per input text, in input order.
    """
    if not texts:
        return []
    total_chunks = len(texts)
    vectors: list[list[float]] = []

    def _flush(batch: list[str]) -> None:
        # Embed one accumulated batch and report cumulative progress.
        # (Extracted: this logic was previously duplicated for the mid-loop
        # flush and the final flush.)
        response = _call_with_retry(ollama.embed, model=cfg.embedding_model, input=batch)
        vectors.extend(response.embeddings)
        on_progress(
            EventType.EMBED,
            {"file": source, "chunk": len(vectors), "total_chunks": total_chunks},
        )

    batch: list[str] = []
    batch_chars = 0
    for text in texts:
        truncated = truncate(text)
        chunk_len = len(truncated)
        # Flush before this chunk would push the batch over the cap; the
        # `batch` guard ensures an oversized single chunk is still sent.
        if batch and batch_chars + chunk_len > MAX_BATCH_CHARS:
            _flush(batch)
            batch = []
            batch_chars = 0
        batch.append(truncated)
        batch_chars += chunk_len
    if batch:
        _flush(batch)
    for vec in vectors:
        validate_vector(vec)
    return vectors