Coverage for src / lilbee / providers / sdk_llm_provider.py: 100%

128 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""SDK-agnostic LLM provider implementing the public ``LLMProvider`` Protocol. 

2 

3``SdkLLMProvider`` owns the semantic layer: auth key injection, option 

4translation, model-ref parsing, error wrapping, and lazy one-shot 

5backend initialization (``configure_logging`` + ``inject_provider_keys`` 

6on first use). It speaks to the underlying SDK exclusively through an 

7``LlmSdkBackend``, so swapping SDKs is a one-file adapter change. 

8 

9Zero direct SDK imports live here. The adapter owns SDK-specific 

10concerns like wire-format prefixes (``ollama/``) and OpenAI content-parts 

11schema for image inputs. 

12""" 

13 

14from __future__ import annotations 

15 

16import logging 

17import os 

18from collections.abc import Callable 

19from pathlib import Path 

20from typing import Any 

21 

22from lilbee.config import cfg 

23from lilbee.providers.base import ClosableIterator, LLMProvider, ProviderError 

24from lilbee.providers.model_ref import parse_model_ref, translate_options 

25from lilbee.providers.sdk_backend import ( 

26 PROVIDER_KEYS, 

27 CompletionRequest, 

28 EmbeddingRequest, 

29 LlmSdkBackend, 

30 RerankRequest, 

31) 

32 

33log = logging.getLogger(__name__) 

34 

35 

def inject_provider_keys() -> None:
    """Mirror configured per-provider API keys into ``os.environ``.

    OpenAI-compatible SDKs resolve credentials from provider-specific
    environment variables (``OPENAI_API_KEY``, ``ANTHROPIC_API_KEY``, ...)
    at call time; this bridges lilbee's config system to that convention.
    A variable already present in the environment is left untouched, so
    shell-level overrides always win over config values.
    """
    for key_spec in PROVIDER_KEYS:
        _, cfg_field, env_var, _ = key_spec
        configured = getattr(cfg, cfg_field, "")
        # Only set when config has a value AND the env var is absent/empty.
        if configured and not os.environ.get(env_var):
            os.environ[env_var] = configured

49 

50 

class SdkLLMProvider(LLMProvider):
    """Provider that delegates SDK calls to an ``LlmSdkBackend``.

    Owns the semantic layer — lazy one-shot initialization, model-ref
    parsing, option translation, and wrapping of backend failures into
    ``ProviderError`` — while the backend owns all SDK-specific details
    (wire-format prefixes, content-parts schema, etc.).
    """

    def __init__(
        self,
        backend: LlmSdkBackend,
        *,
        base_url: str = "http://localhost:11434",
        api_key: str = "",
    ) -> None:
        """Wrap *backend*.

        Args:
            backend: The SDK adapter all calls are delegated to.
            base_url: Endpoint for backends that need an explicit API base
                (e.g. a local Ollama server); trailing slashes are stripped
                so URL joins never produce ``//``.
            api_key: Optional key forwarded to the backend per request.
        """
        self._backend = backend
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._initialized = False

    def _ensure_initialized(self) -> None:
        """Apply one-shot backend setup before the first call.

        Runs ``configure_logging(suppress_debug=cfg.json_mode)`` and
        ``inject_provider_keys()`` exactly once, regardless of which
        operation runs first (``chat``, ``embed``, or a catalog query).
        Both steps happen together because the backend's first SDK
        import must see (a) the debug flag applied, and (b) per-provider
        API keys in ``os.environ``.
        """
        if self._initialized:
            return
        try:
            self._backend.configure_logging(suppress_debug=cfg.json_mode)
        except (ImportError, AttributeError):
            # Best-effort: a missing SDK or an older backend without this
            # hook must not block normal use.
            log.debug("backend.configure_logging failed", exc_info=True)
        inject_provider_keys()
        self._initialized = True

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* using ``cfg.embedding_model`` via the backend.

        Raises:
            ProviderError: wrapping any non-``ProviderError`` backend failure.
        """
        self._ensure_initialized()
        ref = parse_model_ref(cfg.embedding_model)
        request = EmbeddingRequest(
            ref=ref,
            inputs=texts,
            api_base=self._base_url if ref.needs_api_base else None,
            api_key=self._api_key or None,
        )
        try:
            result = self._backend.embed(request)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Embedding failed: {exc}", provider=self._backend.provider_name
            ) from exc
        return result.vectors

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        """Chat completion via the configured backend.

        Args:
            messages: Conversation history in role/content dict form; copied
                before being handed to the backend.
            stream: When True, return a token iterator instead of a string.
            options: Provider-agnostic options, translated per model ref.
            model: Override for ``cfg.chat_model``.

        Raises:
            ProviderError: wrapping any non-``ProviderError`` backend failure.
        """
        self._ensure_initialized()
        ref = parse_model_ref(model or cfg.chat_model)
        translated = translate_options(options, ref) if options else {}
        request = CompletionRequest(
            ref=ref,
            messages=list(messages),
            options=translated,
            api_base=self._base_url if ref.needs_api_base else None,
            api_key=self._api_key or None,
        )
        if stream:
            # Returned generator defers all work until first iteration.
            return self._chat_stream(request)
        try:
            result = self._backend.complete(request)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Chat failed: {exc}", provider=self._backend.provider_name
            ) from exc
        return result.content

    def _chat_stream(self, request: CompletionRequest) -> ClosableIterator[str]:
        """Yield content tokens from a streaming completion.

        Exceptions surfaced by the backend at either call time or during
        iteration are re-raised as ``ProviderError`` so callers always
        see a consistent error type.
        """
        try:
            stream = self._backend.complete_stream(request)
            for chunk in stream:
                if chunk.content:
                    yield chunk.content
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Chat failed: {exc}", provider=self._backend.provider_name
            ) from exc

    def list_models(self) -> list[str]:
        """List models from the backend (empty list when unsupported).

        Raises:
            ProviderError: wrapping any unexpected backend failure.
        """
        # Fix: previously skipped _ensure_initialized(), so the first
        # catalog query could reach the SDK before provider API keys were
        # injected — contrary to the documented one-shot-init contract.
        self._ensure_initialized()
        try:
            return self._backend.list_models(base_url=self._base_url, api_key=self._api_key)
        except NotImplementedError:
            # Backend has no catalog; treat as "nothing available".
            return []
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Listing models failed: {exc}", provider=self._backend.provider_name
            ) from exc

    def list_chat_models(self, provider: str) -> list[str]:
        """List frontier chat models known to the backend for *provider*.

        Initializes the backend first so ``cfg.json_mode`` suppression is
        applied before the SDK import inside the backend runs.
        """
        self._ensure_initialized()
        try:
            return self._backend.list_chat_models(provider)
        except NotImplementedError:
            return []
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Listing chat models failed: {exc}", provider=self._backend.provider_name
            ) from exc

    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Pull a model via the backend.

        Args:
            model: Model identifier to download.
            on_progress: Optional progress callback forwarded to the backend.

        Raises:
            ProviderError: when the backend cannot pull (including backends
                that do not support pulling at all).
        """
        # Fix: initialize before the SDK call, matching the documented
        # one-shot-init contract for all first operations.
        self._ensure_initialized()
        try:
            self._backend.pull_model(model, base_url=self._base_url, on_progress=on_progress)
        except NotImplementedError as exc:
            raise ProviderError(
                f"Cannot pull model {model!r}: backend does not support pulling",
                provider=self._backend.provider_name,
            ) from exc
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Cannot pull model {model!r}: {exc}", provider=self._backend.provider_name
            ) from exc

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Return model metadata, or None when unsupported or not found.

        Raises:
            ProviderError: wrapping any unexpected backend failure.
        """
        # Fix: initialize before the SDK call, matching the documented
        # one-shot-init contract for all first operations.
        self._ensure_initialized()
        try:
            return self._backend.show_model(model, base_url=self._base_url)
        except NotImplementedError:
            return None
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Showing model {model!r} failed: {exc}", provider=self._backend.provider_name
            ) from exc

    def get_capabilities(self, model: str) -> list[str]:
        """Return capability tags from ``show_model`` output, or ``[]``.

        Defensive: a non-list ``capabilities`` value degrades to ``[]``
        rather than propagating a malformed payload to callers.
        """
        info = self.show_model(model)
        if info is None:
            return []
        caps = info.get("capabilities", [])
        return caps if isinstance(caps, list) else []

    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Rerank *candidates* against *query* using ``cfg.reranker_model``.

        Returns one relevance score per candidate; short-circuits to ``[]``
        for an empty candidate list without touching the backend.

        Raises:
            ProviderError: wrapping any non-``ProviderError`` backend failure.
        """
        if not candidates:
            return []
        self._ensure_initialized()
        ref = parse_model_ref(cfg.reranker_model)
        request = RerankRequest(
            ref=ref,
            query=query,
            candidates=candidates,
            api_base=self._base_url if ref.needs_api_base else None,
            api_key=self._api_key or None,
        )
        try:
            result = self._backend.rerank(request)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Rerank failed: {exc}", provider=self._backend.provider_name
            ) from exc
        return result.scores

    def supports_rerank(self) -> bool:
        """SDK-backed rerank is available when the underlying SDK is importable."""
        return self._backend.available()

    def available(self) -> bool:
        """Return True when the configured SDK backend can service catalog calls."""
        return self._backend.available()

    def shutdown(self) -> None:
        """SDK-backed providers hold no lilbee-side resources."""

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """No-op: cloud backends have no local model cache to evict."""