Coverage for src/lilbee/embedder.py: 100%

81 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""Thin wrapper around LLM provider embeddings API.""" 

2 

3import logging 

4import math 

5 

6from lilbee.config import Config 

7from lilbee.progress import DetailedProgressCallback, EmbedEvent, EventType, noop_callback 

8from lilbee.providers.base import LLMProvider 

9from lilbee.providers.model_ref import ProviderModelRef, parse_model_ref 

10 

11log = logging.getLogger(__name__) 

12 

# Upper bound on total characters sent per provider embed() call; embed_batch
# flushes the pending batch before this limit would be exceeded.
MAX_BATCH_CHARS = 6000

14 

15 

def _name_base(ref: ProviderModelRef) -> str:
    """Return *ref*'s normalized base name: the part before any ``:`` tag,
    lowercased, with spaces replaced by dashes."""
    base, _, _ = ref.name.partition(":")
    return base.lower().replace(" ", "-")

18 

19 

def _remote_sees_model(ref: ProviderModelRef, provider: LLMProvider) -> bool:
    """Return True if *provider*'s model listing contains *ref*'s base name.

    A failing ``list_models()`` call is treated as "not available" and only
    logged at debug level, since this is a best-effort availability probe.
    """
    try:
        remote_models = provider.list_models()
    except Exception:
        log.debug("provider list_models failed during availability check", exc_info=True)
        return False
    wanted = _name_base(ref)
    # Substring match against each remote name, normalized the same way
    # as the local base name.
    for candidate in remote_models:
        if wanted in candidate.lower().replace(" ", "-"):
            return True
    return False

28 

29 

def _native_has_model(model: str) -> bool:
    """Return True if *model* resolves through the native registry path.

    Imported lazily to avoid pulling in the llama.cpp provider module
    unless the native probe is actually needed.
    """
    from lilbee.providers.llama_cpp_provider import resolve_model_path

    try:
        resolve_model_path(model)
    except Exception:
        # Any resolution failure means the model is not locally available.
        return False
    else:
        return True

38 

39 

def is_model_available(model: str, provider: LLMProvider) -> bool:
    """Return True if *model* resolves via *provider* or the native registry.

    Remote-prefixed refs (``ollama/`` and API providers) skip the native
    probe since they resolve through the SDK backend at call time.
    """
    if not model:
        return False
    ref = parse_model_ref(model)
    if _remote_sees_model(ref, provider):
        return True
    # Remote refs never fall back to the native registry probe.
    return False if ref.is_remote else _native_has_model(model)

54 

55 

class Embedder:
    """Embedding wrapper — truncates, batches, and validates vectors."""

    def __init__(self, config: Config, provider: LLMProvider) -> None:
        self._config = config
        self._provider = provider

    def truncate(self, text: str) -> str:
        """Truncate text to stay within the embedding model's context window."""
        limit = self._config.max_embed_chars
        if len(text) > limit:
            log.debug(
                "Truncating chunk from %d to %d chars for embedding",
                len(text),
                limit,
            )
            return text[:limit]
        return text

    def validate_vector(self, vector: list[float]) -> None:
        """Validate embedding vector dimension and values."""
        expected = self._config.embedding_dim
        actual = len(vector)
        if actual != expected:
            raise ValueError(
                f"Embedding dimension mismatch: expected {expected}, "
                f"got {actual}"
            )
        for idx, value in enumerate(vector):
            if math.isnan(value) or math.isinf(value):
                raise ValueError(f"Embedding contains invalid value at index {idx}: {value}")

    def validate_model(self) -> bool:
        """Check if the configured embedding model is available. No side effects."""
        return self.embedding_available()

    def embedding_available(self) -> bool:
        """Return True if the embedding model can be resolved.

        Checks the provider model list and the native registry path
        resolution. Returns True if either finds the model.
        """
        return is_model_available(self._config.embedding_model, self._provider)

    def embed(self, text: str) -> list[float]:
        """Embed a single text string, return vector."""
        vector: list[float] = self._provider.embed([self.truncate(text)])[0]
        self.validate_vector(vector)
        return vector

    def embed_batch(
        self,
        texts: list[str],
        *,
        source: str = "",
        on_progress: DetailedProgressCallback = noop_callback,
    ) -> list[list[float]]:
        """Embed multiple texts with adaptive batching, return list of vectors.

        Fires ``embed`` progress events per batch when *on_progress* is provided.
        """
        if not texts:
            return []
        total = len(texts)
        results: list[list[float]] = []

        def _flush(pending_batch: list[str]) -> None:
            # Embed one batch, then report cumulative progress.
            results.extend(self._provider.embed(pending_batch))
            on_progress(
                EventType.EMBED,
                EmbedEvent(file=source, chunk=len(results), total_chunks=total),
            )

        pending: list[str] = []
        pending_chars = 0
        for raw in texts:
            piece = self.truncate(raw)
            # Flush before the running character total would exceed the cap;
            # a single oversized piece still ships as its own batch.
            if pending and pending_chars + len(piece) > MAX_BATCH_CHARS:
                _flush(pending)
                pending = []
                pending_chars = 0
            pending.append(piece)
            pending_chars += len(piece)
        if pending:
            _flush(pending)
        for vec in results:
            self.validate_vector(vec)
        return results