Coverage for src/lilbee/providers/litellm_sdk.py: 100%

230 statements  

coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""litellm implementation of the ``LlmSdkBackend`` Protocol. 

2 

3This is the ONLY file in lilbee that imports ``litellm``. When migrating 

4to a different SDK (e.g. ``liter-llm``), add a sibling module alongside 

5this one and flip the single import in ``providers/factory.py``. 

6 

7All knowledge of the litellm wire format (``ollama/`` prefix, OpenAI 

8content-parts schema for images) lives here. The semantic layer in 

9``sdk_llm_provider`` never touches SDK-specific conventions. 

10""" 

from __future__ import annotations

import base64
import json
import logging
from collections.abc import Callable, Iterator
from typing import Any

import httpx

from lilbee.config import DEFAULT_HTTP_TIMEOUT
from lilbee.providers.base import ProviderError
from lilbee.providers.model_ref import OLLAMA_PREFIX, ProviderModelRef
from lilbee.providers.sdk_backend import (
    CompletionRequest,
    CompletionResult,
    EmbeddingRequest,
    EmbeddingResult,
    RerankRequest,
    RerankResult,
    StreamChunk,
    detect_backend_name,
)

log = logging.getLogger(__name__)

_PROVIDER_NAME = "litellm"
_OLLAMA_URL_PATTERNS = ("localhost:11434", "127.0.0.1:11434", "ollama")


def _is_ollama(base_url: str) -> bool:
    """Return True if *base_url* looks like an Ollama instance."""
    url_lower = base_url.lower()
    return any(p in url_lower for p in _OLLAMA_URL_PATTERNS)


def litellm_available() -> bool:
    """Return True if ``litellm`` can be imported."""
    try:
        import litellm  # noqa: F401
    except ImportError:
        return False
    return True


_LITELLM_MISSING_MSG = (
    "Remote and API models need the lilbee[litellm] extra. "
    "Reinstall with: uv tool install --prerelease=allow 'lilbee[litellm]'"
)


def _require_litellm() -> Any:
    """Import ``litellm`` or raise a user-facing ProviderError with install steps."""
    try:
        import litellm
    except ImportError as exc:
        raise ProviderError(_LITELLM_MISSING_MSG, provider=_PROVIDER_NAME) from exc
    return litellm


def _cache_ollama_defaults(model: str, params_text: str) -> None:
    """Parse Ollama parameters and store in the model defaults cache."""
    from lilbee.model_defaults import parse_kv_parameters, set_defaults

    defaults = parse_kv_parameters(params_text)
    set_defaults(model, defaults)
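
# Illustrative only: the ``parameters`` field returned by Ollama's /api/show (and
# which parse_kv_parameters is assumed to handle) is a plain key/value text block,
# typically something like
#
#     temperature                    0.7
#     num_ctx                        4096
#     stop                           "<|im_end|>"
#
# so the cached defaults would come out roughly as {"temperature": 0.7, "num_ctx": 4096, ...}.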


def _route_model(ref: ProviderModelRef, api_base: str | None) -> str:
    """Format *ref* for litellm using the OpenAI ``provider/model`` convention."""
    if ref.is_api:
        return ref.for_openai_prefix()
    if api_base and _is_ollama(api_base):
        return f"{OLLAMA_PREFIX}{ref.name}"
    return ref.name
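
# Illustrative only (the ProviderModelRef values are hypothetical): assuming
# OLLAMA_PREFIX == "ollama/" and a non-API ref named "llama3",
#
#     _route_model(ref, "http://localhost:11434")     -> "ollama/llama3"
#     _route_model(ref, "https://vllm.internal:8000") -> "llama3"
#
# while an API ref is delegated to ref.for_openai_prefix(), e.g. "openai/gpt-4o".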


def _format_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert messages with inline image bytes into OpenAI content parts.

    litellm routes to OpenAI-compatible endpoints that expect the
    ``{"type": "image_url", "image_url": {...}}`` content-parts schema
    for multimodal input. Messages without ``images`` pass through
    untouched.
    """
    formatted: list[dict[str, Any]] = []
    for msg in messages:
        if "images" in msg:
            content_parts: list[dict[str, Any]] = [{"type": "text", "text": msg.get("content", "")}]
            for img in msg["images"]:
                if isinstance(img, bytes):
                    b64 = base64.b64encode(img).decode()
                    content_parts.append(
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{b64}"},
                        }
                    )
            formatted.append({"role": msg["role"], "content": content_parts})
        else:
            formatted.append(msg)
    return formatted
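
# Illustrative only (hypothetical payload): a message carrying raw PNG bytes such as
#
#     {"role": "user", "content": "What is in this image?", "images": [b"\x89PNG..."]}
#
# is rewritten into the OpenAI content-parts shape
#
#     {"role": "user", "content": [
#         {"type": "text", "text": "What is in this image?"},
#         {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
#     ]}
#
# while messages without an "images" key pass through unchanged.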


class LitellmSdkBackend:
    """``LlmSdkBackend`` adapter backed by the ``litellm`` SDK."""

    @property
    def provider_name(self) -> str:
        """Stable identifier used when wrapping errors in ``ProviderError``."""
        return _PROVIDER_NAME

    def active_backend_name(self, base_url: str) -> str:
        """Return the display name of the backend ``base_url`` points at."""
        return detect_backend_name(base_url)

    def available(self) -> bool:
        """Return True if the underlying SDK is installed."""
        return litellm_available()

    def configure_logging(self, *, suppress_debug: bool) -> None:
        """Apply litellm's debug-info suppression toggle when requested."""
        if not suppress_debug:
            return
        try:
            import litellm

            litellm.suppress_debug_info = True
        except ImportError:
            pass

    def complete(self, request: CompletionRequest) -> CompletionResult:
        """Run a single-shot completion through ``litellm.completion``."""
        litellm = _require_litellm()
        kwargs = self._completion_kwargs(request, stream=False)
        try:
            response = litellm.completion(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Chat failed: {exc}", provider=_PROVIDER_NAME) from exc
        choices = getattr(response, "choices", None) or []
        message = choices[0].message if choices else None
        content = getattr(message, "content", "") if message is not None else ""
        finish_reason = getattr(choices[0], "finish_reason", None) if choices else None
        return CompletionResult(
            content=content or "",
            finish_reason=finish_reason,
            model=getattr(response, "model", None),
        )

    def complete_stream(self, request: CompletionRequest) -> Iterator[StreamChunk]:
        """Stream a completion through ``litellm.completion(stream=True)``."""
        litellm = _require_litellm()
        kwargs = self._completion_kwargs(request, stream=True)
        try:
            response = litellm.completion(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Chat failed: {exc}", provider=_PROVIDER_NAME) from exc
        return self._stream_chunks(response)

    @staticmethod
    def _stream_chunks(response: Any) -> Iterator[StreamChunk]:
        """Yield ``StreamChunk`` values from a litellm streaming response.

        Exceptions raised mid-iteration are wrapped in ``ProviderError``
        so the semantic layer sees a consistent error type regardless of
        where the SDK failed.
        """
        try:
            for chunk in response:
                choices = getattr(chunk, "choices", None) or []
                if not choices:
                    continue
                delta = getattr(choices[0], "delta", None)
                content = getattr(delta, "content", "") if delta is not None else ""
                finish_reason = getattr(choices[0], "finish_reason", None)
                if content or finish_reason:
                    yield StreamChunk(content=content or "", finish_reason=finish_reason)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(f"Chat failed: {exc}", provider=_PROVIDER_NAME) from exc

    @staticmethod
    def _completion_kwargs(request: CompletionRequest, *, stream: bool) -> dict[str, Any]:
        """Translate a ``CompletionRequest`` into litellm kwargs."""
        kwargs: dict[str, Any] = {
            "model": _route_model(request.ref, request.api_base),
            "messages": _format_messages(request.messages),
            "stream": stream,
        }
        if request.api_base:
            kwargs["api_base"] = request.api_base
        if request.api_key:
            kwargs["api_key"] = request.api_key
        if request.options:
            kwargs.update(request.options)
        return kwargs
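
    # Illustrative only (the model name and URL are hypothetical): for a request
    # routed at a local Ollama server, the resulting kwargs look roughly like
    #
    #     {
    #         "model": "ollama/llama3",
    #         "messages": [...],
    #         "stream": False,
    #         "api_base": "http://localhost:11434",
    #     }
    #
    # with any entries from request.options merged in last, so callers can
    # override generation settings such as temperature.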

    def embed(self, request: EmbeddingRequest) -> EmbeddingResult:
        """Embed inputs through ``litellm.embedding``."""
        litellm = _require_litellm()
        kwargs: dict[str, Any] = {
            "model": _route_model(request.ref, request.api_base),
            "input": request.inputs,
        }
        if request.api_base:
            kwargs["api_base"] = request.api_base
        if request.api_key:
            kwargs["api_key"] = request.api_key
        try:
            response = litellm.embedding(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Embedding failed: {exc}", provider=_PROVIDER_NAME) from exc
        data = response["data"] if isinstance(response, dict) else response.data
        vectors = [item["embedding"] for item in data]
        if isinstance(response, dict):
            model = response.get("model")
        else:
            model = getattr(response, "model", None)
        return EmbeddingResult(vectors=vectors, model=model)

    def rerank(self, request: RerankRequest) -> RerankResult:
        """Rerank documents via ``litellm.rerank`` (Cohere, Voyage, Jina, Together, HF TEI).

        The SDK returns results sorted by relevance; we restore input
        order via each result's ``index`` so scores line up with the
        caller's ``candidates`` list.
        """
        if not request.candidates:
            return RerankResult(scores=[])
        litellm = _require_litellm()
        kwargs: dict[str, Any] = {
            "model": _route_model(request.ref, request.api_base),
            "query": request.query,
            "documents": request.candidates,
        }
        if request.api_base:
            kwargs["api_base"] = request.api_base
        if request.api_key:
            kwargs["api_key"] = request.api_key
        try:
            response = litellm.rerank(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Rerank failed: {exc}", provider=_PROVIDER_NAME) from exc
        results = response["results"] if isinstance(response, dict) else response.results
        scores = [0.0] * len(request.candidates)
        for item in results:
            idx = item["index"] if isinstance(item, dict) else item.index
            score = item["relevance_score"] if isinstance(item, dict) else item.relevance_score
            scores[idx] = float(score)
        if isinstance(response, dict):
            model = response.get("model")
        else:
            model = getattr(response, "model", None)
        return RerankResult(scores=scores, model=model)
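
    # Illustrative only (hypothetical response): given candidates ["a", "b", "c"],
    # a relevance-sorted SDK response such as
    #
    #     results = [{"index": 2, "relevance_score": 0.91}, {"index": 0, "relevance_score": 0.40}]
    #
    # produces scores == [0.40, 0.0, 0.91]: positions follow the caller's candidate
    # order, and any candidate the backend omitted keeps the 0.0 default.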

    def list_models(self, *, base_url: str, api_key: str) -> list[str]:
        """List models from Ollama or an OpenAI-compatible server."""
        clean_base = base_url.rstrip("/")
        if _is_ollama(clean_base):
            return self._list_ollama_models(clean_base)
        return self._list_openai_models(clean_base, api_key)

    def list_chat_models(self, provider: str) -> list[str]:
        """Return chat-mode model ids from litellm's static catalog.

        Empty list when litellm is not installed or the provider has no
        chat-mode entries. Callers that care about the litellm debug
        banner should invoke ``configure_logging`` first (the semantic
        layer's ``SdkLLMProvider`` does this in ``list_chat_models``).
        """
        try:
            import litellm
        except ImportError:
            return []
        models = litellm.models_by_provider.get(provider, set())
        chat_models: list[str] = []
        for model_name in sorted(models):
            info = litellm.model_cost.get(model_name, {})
            if info.get("mode") != "chat":
                continue
            chat_models.append(model_name)
        return chat_models
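
    # Illustrative only (the ids below are examples, not litellm's actual catalog):
    # litellm.models_by_provider maps a provider key to its model ids, and
    # litellm.model_cost holds per-model metadata including "mode", roughly
    #
    #     models_by_provider["groq"] -> {"groq/llama-3.1-8b-instant", ...}
    #     model_cost["groq/llama-3.1-8b-instant"] -> {"mode": "chat", ...}
    #
    # so only ids whose metadata says mode == "chat" survive the filter above.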

    @staticmethod
    def _list_ollama_models(base_url: str) -> list[str]:
        """List models via the Ollama ``/api/tags`` endpoint."""
        try:
            resp = httpx.get(f"{base_url}/api/tags", timeout=DEFAULT_HTTP_TIMEOUT)
            resp.raise_for_status()
            data = resp.json()
            return [m["name"] for m in data.get("models", [])]
        except httpx.HTTPError as exc:
            raise ProviderError(f"Cannot list models: {exc}", provider=_PROVIDER_NAME) from exc
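
    # Illustrative only: Ollama's /api/tags response is a JSON object whose "models"
    # entries each carry a "name" tag, roughly
    #
    #     {"models": [{"name": "llama3:latest", ...}, {"name": "nomic-embed-text:latest", ...}]}
    #
    # which is why the comprehension above extracts m["name"].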

    @staticmethod
    def _list_openai_models(base_url: str, api_key: str) -> list[str]:
        """List models via an OpenAI-compatible ``/v1/models`` endpoint."""
        headers: dict[str, str] = {}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        try:
            resp = httpx.get(f"{base_url}/v1/models", headers=headers, timeout=DEFAULT_HTTP_TIMEOUT)
            resp.raise_for_status()
            data = resp.json()
            return [m["id"] for m in data.get("data", [])]
        except httpx.HTTPError:
            log.debug("Failed to list models via /v1/models", exc_info=True)
            return []

    def pull_model(
        self,
        model: str,
        *,
        base_url: str,
        on_progress: Callable[..., Any] | None = None,
    ) -> None:
        """Pull a model via the Ollama ``/api/pull`` endpoint."""
        clean_base = base_url.rstrip("/")
        try:
            with (
                # Streaming Ollama /api/pull; unbounded read is intentional
                # since model downloads can exceed any wall-clock timeout.
                httpx.Client(timeout=None) as client,  # noqa: S113
                client.stream(
                    "POST",
                    f"{clean_base}/api/pull",
                    json={"name": model, "stream": True},
                ) as resp,
            ):
                resp.raise_for_status()
                for line in resp.iter_lines():
                    if not line:
                        continue
                    event = json.loads(line)
                    if on_progress:
                        on_progress(event)
                    if event.get("status") == "success":
                        break
        except httpx.HTTPError as exc:
            raise ProviderError(
                f"Cannot pull model {model!r}: {exc}", provider=_PROVIDER_NAME
            ) from exc
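
    # Illustrative only: /api/pull streams newline-delimited JSON progress events,
    # roughly of the form
    #
    #     {"status": "pulling manifest"}
    #     {"status": "pulling <digest>", "digest": "...", "total": 123456, "completed": 4096}
    #     {"status": "success"}
    #
    # each of which is forwarded to on_progress before the loop stops at "success".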

    def show_model(self, model: str, *, base_url: str) -> dict[str, Any] | None:
        """Get model info via the Ollama ``/api/show`` endpoint.

        Parses and caches per-model generation defaults from the
        ``parameters`` field. Also extracts the ``capabilities`` list
        (newer Ollama versions) so callers can check for vision support.
        """
        clean_base = base_url.rstrip("/")
        # Ollama's API uses bare model names; the routing-layer prefix has
        # to come off before the request goes out.
        ollama_name = model[len(OLLAMA_PREFIX) :] if model.startswith(OLLAMA_PREFIX) else model
        try:
            resp = httpx.post(
                f"{clean_base}/api/show",
                json={"name": ollama_name},
                timeout=DEFAULT_HTTP_TIMEOUT,
            )
            resp.raise_for_status()
            data = resp.json()
        except httpx.HTTPError:
            return None

        result: dict[str, Any] = {}

        params = data.get("parameters", "")
        if isinstance(params, str) and params:
            _cache_ollama_defaults(model, params)
            result["parameters"] = params
        elif params:
            _cache_ollama_defaults(model, str(params))
            result["parameters"] = str(params)

        capabilities = data.get("capabilities")
        if isinstance(capabilities, list):
            result["capabilities"] = capabilities

        return result or None
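
    # Illustrative only: the fields consumed above come from an /api/show payload
    # shaped roughly like
    #
    #     {
    #         "parameters": "temperature 0.7\nnum_ctx 4096",
    #         "capabilities": ["completion", "vision"],
    #         ...
    #     }
    #
    # so a caller checking for vision support would look for "vision" in
    # result["capabilities"].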