Coverage for src / lilbee / embedder.py: 100%

72 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-16 08:27 +0000

"""Thin wrapper around Ollama embeddings API."""

import logging
import math
import time
from typing import Any

import ollama

from lilbee.config import cfg
from lilbee.models import pull_with_progress
from lilbee.progress import DetailedProgressCallback, EventType, noop_callback

log = logging.getLogger(__name__)

# Character budget per batched embed request; embed_batch flushes its current
# batch before adding a chunk that would exceed this (a single oversized chunk
# is still sent alone — truncate() bounds it to cfg.max_embed_chars).
MAX_BATCH_CHARS = 6000

17 

18 

19def _call_with_retry(fn: Any, *args: Any, **kwargs: Any) -> Any: 

20 """Retry fn up to 3 times with exponential backoff on connection errors.""" 

21 delays = [1, 2, 4] 

22 last_err: Exception | None = None 

23 for attempt, delay in enumerate(delays): 

24 try: 

25 return fn(*args, **kwargs) # type: ignore[operator] 

26 except (ConnectionError, OSError) as exc: 

27 last_err = exc 

28 log.warning("Ollama call failed (attempt %d/3): %s", attempt + 1, exc) 

29 time.sleep(delay) 

30 raise last_err # type: ignore[misc] 

31 

32 

def truncate(text: str) -> str:
    """Truncate text to stay within the embedding model's context window."""
    limit = cfg.max_embed_chars
    if len(text) > limit:
        log.debug("Truncating chunk from %d to %d chars for embedding", len(text), limit)
        text = text[:limit]
    return text

39 

40 

def validate_vector(vector: list[float]) -> None:
    """Validate embedding vector dimension and values.

    Raises ValueError if the vector's length differs from the configured
    embedding dimension, or if any component is NaN or infinite.
    """
    expected = cfg.embedding_dim
    actual = len(vector)
    if actual != expected:
        raise ValueError(
            f"Embedding dimension mismatch: expected {expected}, got {actual}"
        )
    for i, v in enumerate(vector):
        # Reject NaN and +/-inf in one check.
        if not math.isfinite(v):
            raise ValueError(f"Embedding contains invalid value at index {i}: {v}")

50 

51 

def validate_model() -> None:
    """Ensure the configured embedding model is available, pulling if needed.

    Raises RuntimeError (chained from the underlying error) when the Ollama
    server cannot be reached.
    """
    try:
        listed = ollama.list().models
        full_names = {entry.model for entry in listed if entry.model}
        # Accept a match with or without an explicit tag (e.g. ":latest").
        untagged = {name.split(":")[0] for name in full_names}
        if cfg.embedding_model not in (full_names | untagged):
            log.info("Pulling embedding model '%s' from Ollama...", cfg.embedding_model)
            pull_with_progress(cfg.embedding_model)
    except (ConnectionError, OSError) as exc:
        raise RuntimeError(f"Cannot connect to Ollama: {exc}. Is Ollama running?") from exc

64 

65 

def embed(text: str) -> list[float]:
    """Embed a single text string, return vector.

    The text is truncated to the model's context budget first; the returned
    vector is validated for dimension and finite values before being returned.
    """
    payload = truncate(text)
    response = _call_with_retry(ollama.embed, model=cfg.embedding_model, input=payload)
    vector: list[float] = response.embeddings[0]
    validate_vector(vector)
    return vector

72 

73 

def embed_batch(
    texts: list[str],
    *,
    source: str = "",
    on_progress: DetailedProgressCallback = noop_callback,
) -> list[list[float]]:
    """Embed multiple texts with adaptive batching, return list of vectors.

    Chunks are truncated, then accumulated into batches of at most
    MAX_BATCH_CHARS characters; each batch is one Ollama call. Fires ``embed``
    progress events per batch when *on_progress* is provided.

    Args:
        texts: Chunk texts to embed; an empty list returns [] immediately.
        source: Label (typically a file path) included in progress payloads.
        on_progress: Callback receiving (EventType.EMBED, payload) per batch.

    Returns:
        One embedding vector per input text, in input order.

    Raises:
        ValueError: if any returned vector fails validate_vector().
    """
    if not texts:
        return []
    total_chunks = len(texts)
    vectors: list[list[float]] = []

    def _flush(batch: list[str]) -> None:
        # Embed one accumulated batch and report cumulative progress.
        response = _call_with_retry(ollama.embed, model=cfg.embedding_model, input=batch)
        vectors.extend(response.embeddings)
        on_progress(
            EventType.EMBED,
            {"file": source, "chunk": len(vectors), "total_chunks": total_chunks},
        )

    batch: list[str] = []
    batch_chars = 0
    for text in texts:
        truncated = truncate(text)
        chunk_len = len(truncated)
        # Flush first if adding this chunk would push the batch past the cap;
        # an oversized single chunk still goes through as its own batch.
        if batch and batch_chars + chunk_len > MAX_BATCH_CHARS:
            _flush(batch)
            batch = []
            batch_chars = 0
        batch.append(truncated)
        batch_chars += chunk_len
    if batch:
        _flush(batch)
    # Validate after all calls succeed, mirroring embed()'s post-call check.
    for vec in vectors:
        validate_vector(vec)
    return vectors