Coverage for src/lilbee/providers/litellm_sdk.py: 100%
230 statements
1"""litellm implementation of the ``LlmSdkBackend`` Protocol.
3This is the ONLY file in lilbee that imports ``litellm``. When migrating
4to a different SDK (e.g. ``liter-llm``), add a sibling module alongside
5this one and flip the single import in ``providers/factory.py``.
7All knowledge of the litellm wire format (``ollama/`` prefix, OpenAI
8content-parts schema for images) lives here. The semantic layer in
9``sdk_llm_provider`` never touches SDK-specific conventions.
10"""

from __future__ import annotations

import base64
import json
import logging
from collections.abc import Callable, Iterator
from typing import Any

import httpx

from lilbee.config import DEFAULT_HTTP_TIMEOUT
from lilbee.providers.base import ProviderError
from lilbee.providers.model_ref import OLLAMA_PREFIX, ProviderModelRef
from lilbee.providers.sdk_backend import (
    CompletionRequest,
    CompletionResult,
    EmbeddingRequest,
    EmbeddingResult,
    RerankRequest,
    RerankResult,
    StreamChunk,
    detect_backend_name,
)

log = logging.getLogger(__name__)

_PROVIDER_NAME = "litellm"
_OLLAMA_URL_PATTERNS = ("localhost:11434", "127.0.0.1:11434", "ollama")


def _is_ollama(base_url: str) -> bool:
    """Return True if *base_url* looks like an Ollama instance."""
    url_lower = base_url.lower()
    return any(p in url_lower for p in _OLLAMA_URL_PATTERNS)

def litellm_available() -> bool:
    """Return True if ``litellm`` can be imported."""
    try:
        import litellm  # noqa: F401
    except ImportError:
        return False
    return True


_LITELLM_MISSING_MSG = (
    "Remote and API models need the lilbee[litellm] extra. "
    "Reinstall with: uv tool install --prerelease=allow 'lilbee[litellm]'"
)


def _require_litellm() -> Any:
    """Import ``litellm`` or raise a user-facing ProviderError with install steps."""
    try:
        import litellm
    except ImportError as exc:
        raise ProviderError(_LITELLM_MISSING_MSG, provider=_PROVIDER_NAME) from exc
    return litellm


def _cache_ollama_defaults(model: str, params_text: str) -> None:
    """Parse Ollama parameters and store in the model defaults cache."""
    from lilbee.model_defaults import parse_kv_parameters, set_defaults

    defaults = parse_kv_parameters(params_text)
    set_defaults(model, defaults)

def _route_model(ref: ProviderModelRef, api_base: str | None) -> str:
    """Format *ref* for litellm using the OpenAI ``provider/model`` convention."""
    if ref.is_api:
        return ref.for_openai_prefix()
    if api_base and _is_ollama(api_base):
        return f"{OLLAMA_PREFIX}{ref.name}"
    return ref.name

def _format_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert messages with inline image bytes into OpenAI content parts.

    litellm routes to OpenAI-compatible endpoints that expect the
    ``{"type": "image_url", "image_url": {...}}`` content-parts schema
    for multimodal input. Messages without ``images`` pass through
    untouched.
    """
    formatted: list[dict[str, Any]] = []
    for msg in messages:
        if "images" in msg:
            content_parts: list[dict[str, Any]] = [{"type": "text", "text": msg.get("content", "")}]
            for img in msg["images"]:
                if isinstance(img, bytes):
                    b64 = base64.b64encode(img).decode()
                    content_parts.append(
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{b64}"},
                        }
                    )
            formatted.append({"role": msg["role"], "content": content_parts})
        else:
            formatted.append(msg)
    return formatted

class LitellmSdkBackend:
    """``LlmSdkBackend`` adapter backed by the ``litellm`` SDK."""

    @property
    def provider_name(self) -> str:
        """Stable identifier used when wrapping errors in ``ProviderError``."""
        return _PROVIDER_NAME

    def active_backend_name(self, base_url: str) -> str:
        """Return the display name of the backend ``base_url`` points at."""
        return detect_backend_name(base_url)

    def available(self) -> bool:
        """Return True if the underlying SDK is installed."""
        return litellm_available()

    def configure_logging(self, *, suppress_debug: bool) -> None:
        """Apply litellm's debug-info suppression toggle when requested."""
        if not suppress_debug:
            return
        try:
            import litellm

            litellm.suppress_debug_info = True
        except ImportError:
            pass

    def complete(self, request: CompletionRequest) -> CompletionResult:
        """Run a single-shot completion through ``litellm.completion``."""
        litellm = _require_litellm()
        kwargs = self._completion_kwargs(request, stream=False)
        try:
            response = litellm.completion(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Chat failed: {exc}", provider=_PROVIDER_NAME) from exc
        choices = getattr(response, "choices", None) or []
        message = choices[0].message if choices else None
        content = getattr(message, "content", "") if message is not None else ""
        finish_reason = getattr(choices[0], "finish_reason", None) if choices else None
        return CompletionResult(
            content=content or "",
            finish_reason=finish_reason,
            model=getattr(response, "model", None),
        )

    def complete_stream(self, request: CompletionRequest) -> Iterator[StreamChunk]:
        """Stream a completion through ``litellm.completion(stream=True)``."""
        litellm = _require_litellm()
        kwargs = self._completion_kwargs(request, stream=True)
        try:
            response = litellm.completion(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Chat failed: {exc}", provider=_PROVIDER_NAME) from exc
        return self._stream_chunks(response)

    @staticmethod
    def _stream_chunks(response: Any) -> Iterator[StreamChunk]:
        """Yield ``StreamChunk`` values from a litellm streaming response.

        Exceptions raised mid-iteration are wrapped in ``ProviderError``
        so the semantic layer sees a consistent error type regardless of
        where the SDK failed.
        """
        try:
            for chunk in response:
                choices = getattr(chunk, "choices", None) or []
                if not choices:
                    continue
                delta = getattr(choices[0], "delta", None)
                content = getattr(delta, "content", "") if delta is not None else ""
                finish_reason = getattr(choices[0], "finish_reason", None)
                if content or finish_reason:
                    yield StreamChunk(content=content or "", finish_reason=finish_reason)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(f"Chat failed: {exc}", provider=_PROVIDER_NAME) from exc

    @staticmethod
    def _completion_kwargs(request: CompletionRequest, *, stream: bool) -> dict[str, Any]:
        """Translate a ``CompletionRequest`` into litellm kwargs."""
        kwargs: dict[str, Any] = {
            "model": _route_model(request.ref, request.api_base),
            "messages": _format_messages(request.messages),
            "stream": stream,
        }
        if request.api_base:
            kwargs["api_base"] = request.api_base
        if request.api_key:
            kwargs["api_key"] = request.api_key
        if request.options:
            kwargs.update(request.options)
        return kwargs

    def embed(self, request: EmbeddingRequest) -> EmbeddingResult:
        """Embed inputs through ``litellm.embedding``."""
        litellm = _require_litellm()
        kwargs: dict[str, Any] = {
            "model": _route_model(request.ref, request.api_base),
            "input": request.inputs,
        }
        if request.api_base:
            kwargs["api_base"] = request.api_base
        if request.api_key:
            kwargs["api_key"] = request.api_key
        try:
            response = litellm.embedding(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Embedding failed: {exc}", provider=_PROVIDER_NAME) from exc
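        # litellm may hand back either a plain dict or a response object,
        # so both access styles are normalized below.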
        data = response["data"] if isinstance(response, dict) else response.data
        vectors = [item["embedding"] for item in data]
        if isinstance(response, dict):
            model = response.get("model")
        else:
            model = getattr(response, "model", None)
        return EmbeddingResult(vectors=vectors, model=model)

    def rerank(self, request: RerankRequest) -> RerankResult:
        """Rerank documents via ``litellm.rerank`` (Cohere, Voyage, Jina, Together, HF TEI).

        The SDK returns results sorted by relevance; we restore input
        order via each result's ``index`` so scores line up with the
        caller's ``candidates`` list.
        """
        if not request.candidates:
            return RerankResult(scores=[])
        litellm = _require_litellm()
        kwargs: dict[str, Any] = {
            "model": _route_model(request.ref, request.api_base),
            "query": request.query,
            "documents": request.candidates,
        }
        if request.api_base:
            kwargs["api_base"] = request.api_base
        if request.api_key:
            kwargs["api_key"] = request.api_key
        try:
            response = litellm.rerank(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Rerank failed: {exc}", provider=_PROVIDER_NAME) from exc
        results = response["results"] if isinstance(response, dict) else response.results
        scores = [0.0] * len(request.candidates)
        for item in results:
            idx = item["index"] if isinstance(item, dict) else item.index
            score = item["relevance_score"] if isinstance(item, dict) else item.relevance_score
            scores[idx] = float(score)
        if isinstance(response, dict):
            model = response.get("model")
        else:
            model = getattr(response, "model", None)
        return RerankResult(scores=scores, model=model)

    def list_models(self, *, base_url: str, api_key: str) -> list[str]:
        """List models from Ollama or an OpenAI-compatible server."""
        clean_base = base_url.rstrip("/")
        if _is_ollama(clean_base):
            return self._list_ollama_models(clean_base)
        return self._list_openai_models(clean_base, api_key)

    def list_chat_models(self, provider: str) -> list[str]:
        """Return chat-mode model ids from litellm's static catalog.

        Empty list when litellm is not installed or the provider has no
        chat-mode entries. Callers that care about the litellm debug
        banner should invoke ``configure_logging`` first (the semantic
        layer's ``SdkLLMProvider`` does this in ``list_chat_models``).
        """
        try:
            import litellm
        except ImportError:
            return []
        models = litellm.models_by_provider.get(provider, set())
        chat_models: list[str] = []
        for model_name in sorted(models):
            info = litellm.model_cost.get(model_name, {})
            if info.get("mode") != "chat":
                continue
            chat_models.append(model_name)
        return chat_models

    @staticmethod
    def _list_ollama_models(base_url: str) -> list[str]:
        """List models via the Ollama ``/api/tags`` endpoint."""
        try:
            resp = httpx.get(f"{base_url}/api/tags", timeout=DEFAULT_HTTP_TIMEOUT)
            resp.raise_for_status()
            data = resp.json()
            return [m["name"] for m in data.get("models", [])]
        except httpx.HTTPError as exc:
            raise ProviderError(f"Cannot list models: {exc}", provider=_PROVIDER_NAME) from exc

    @staticmethod
    def _list_openai_models(base_url: str, api_key: str) -> list[str]:
        """List models via an OpenAI-compatible ``/v1/models`` endpoint."""
        headers: dict[str, str] = {}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        try:
            resp = httpx.get(f"{base_url}/v1/models", headers=headers, timeout=DEFAULT_HTTP_TIMEOUT)
            resp.raise_for_status()
            data = resp.json()
            return [m["id"] for m in data.get("data", [])]
        except httpx.HTTPError:
            log.debug("Failed to list models via /v1/models", exc_info=True)
            return []

    def pull_model(
        self,
        model: str,
        *,
        base_url: str,
        on_progress: Callable[..., Any] | None = None,
    ) -> None:
        """Pull a model via the Ollama ``/api/pull`` endpoint."""
        clean_base = base_url.rstrip("/")
        try:
            with (
                # Streaming Ollama /api/pull; unbounded read is intentional
                # since model downloads can exceed any wall-clock timeout.
                httpx.Client(timeout=None) as client,  # noqa: S113
                client.stream(
                    "POST",
                    f"{clean_base}/api/pull",
                    json={"name": model, "stream": True},
                ) as resp,
            ):
                resp.raise_for_status()
                for line in resp.iter_lines():
                    if not line:
                        continue
                    event = json.loads(line)
                    if on_progress:
                        on_progress(event)
                    if event.get("status") == "success":
                        break
        except httpx.HTTPError as exc:
            raise ProviderError(
                f"Cannot pull model {model!r}: {exc}", provider=_PROVIDER_NAME
            ) from exc

    def show_model(self, model: str, *, base_url: str) -> dict[str, Any] | None:
        """Get model info via the Ollama ``/api/show`` endpoint.

        Parses and caches per-model generation defaults from the
        ``parameters`` field. Also extracts the ``capabilities`` list
        (newer Ollama versions) so callers can check for vision support.
        """
        clean_base = base_url.rstrip("/")
        # Ollama's API uses bare model names; the routing-layer prefix has
        # to come off before the request goes out.
        ollama_name = model[len(OLLAMA_PREFIX) :] if model.startswith(OLLAMA_PREFIX) else model
        try:
            resp = httpx.post(
                f"{clean_base}/api/show",
                json={"name": ollama_name},
                timeout=DEFAULT_HTTP_TIMEOUT,
            )
            resp.raise_for_status()
            data = resp.json()
        except httpx.HTTPError:
            return None

        result: dict[str, Any] = {}

        params = data.get("parameters", "")
        if isinstance(params, str) and params:
            _cache_ollama_defaults(model, params)
            result["parameters"] = params
        elif params:
            _cache_ollama_defaults(model, str(params))
            result["parameters"] = str(params)

        capabilities = data.get("capabilities")
        if isinstance(capabilities, list):
            result["capabilities"] = capabilities

        return result or None