Coverage for src/lilbee/providers/worker_process.py: 100% (257 statements)

1"""Subprocess worker for llama-cpp embedding and vision operations.
3Isolates llama-cpp native code (which holds the GIL) in a separate
4process so the parent event loop stays responsive. Communication
5uses multiprocessing queues with spawn start method (fork-unsafe
6with threads).
7"""
from __future__ import annotations

import contextlib
import logging
import multiprocessing
import multiprocessing.queues
import os
import queue
import time
from dataclasses import dataclass, field
from multiprocessing import get_context
from typing import Any, TypeVar

from lilbee.config import cfg

log = logging.getLogger(__name__)

_ResponseT = TypeVar("_ResponseT", "EmbedResponse", "VisionResponse")

_EMBED_TIMEOUT_S = 30.0
_JOIN_TIMEOUT_S = 5.0
_RESTART_DELAY_S = 0.1
# Substituted when ``cfg.ocr_timeout == 0`` ("no limit"). The round-trip
# wait loop needs a finite deadline; one day is effectively unlimited.
_NO_CAP_TIMEOUT_S = 86_400.0


@dataclass(frozen=True)
class EmbedRequest:
    """Request to embed a list of texts."""

    texts: list[str]
    model: str
    request_id: int = 0


@dataclass(frozen=True)
class EmbedResponse:
    """Response with embedding vectors or an error."""

    vectors: list[list[float]] = field(default_factory=list)
    error: str = ""
    request_id: int = 0


@dataclass(frozen=True)
class VisionRequest:
    """Request to run vision OCR on a PNG image."""

    png_bytes: bytes
    model: str
    prompt: str = ""
    request_id: int = 0


@dataclass(frozen=True)
class VisionResponse:
    """Response with extracted text or an error."""

    text: str = ""
    error: str = ""
    request_id: int = 0


@dataclass(frozen=True)
class LoadModelRequest:
    """Request to (re)load a specific model in the worker."""

    model: str
    model_type: str = "embed"  # "embed" or "vision"


@dataclass(frozen=True)
class ShutdownRequest:
    """Sentinel to tell the worker to exit."""


@dataclass(frozen=True)
class ConfigSnapshot:
    """Minimal config fields needed by the child process."""

    models_dir: str
    embedding_model: str
    embedding_dim: int
    num_ctx: int | None
    gpu_memory_fraction: float
    vision_model: str


_WorkerRequest = EmbedRequest | VisionRequest | LoadModelRequest | ShutdownRequest
_WorkerResponse = EmbedResponse | VisionResponse


def config_snapshot_from_cfg() -> ConfigSnapshot:
    """Build a ConfigSnapshot from the current cfg singleton."""
    return ConfigSnapshot(
        models_dir=str(cfg.models_dir),
        embedding_model=cfg.embedding_model,
        embedding_dim=cfg.embedding_dim,
        num_ctx=cfg.num_ctx,
        gpu_memory_fraction=cfg.gpu_memory_fraction,
        vision_model=cfg.vision_model,
    )


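# Usage sketch (illustrative, not part of the module): snapshot config in
# the parent before constructing a worker, so the spawned child receives a
# plain picklable dataclass rather than the live ``cfg`` singleton.
#
#     snapshot = config_snapshot_from_cfg()
#     worker = WorkerProcess(config=snapshot)

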
class WorkerProcess:
    """Manages a child process for embedding and vision inference.

    The child loads llama-cpp models independently, avoiding GIL
    contention and stdout corruption in the parent process.
    """

    def __init__(self, config: ConfigSnapshot | None = None) -> None:
        self._config = config
        self._process: Any = None
        self._ctx = get_context("spawn")
        self._request_queue: multiprocessing.Queue[_WorkerRequest] | None = None
        self._response_queue: multiprocessing.Queue[_WorkerResponse] | None = None
        self._next_id = 0
        self._started = False

    def _ensure_config(self) -> ConfigSnapshot:
        if self._config is None:
            self._config = config_snapshot_from_cfg()
        return self._config

    def start(self) -> None:
        """Launch the child process. No-op if already running."""
        if self._started and self.is_alive():
            return
        config = self._ensure_config()
        self._request_queue = self._ctx.Queue()
        self._response_queue = self._ctx.Queue()
        self._process = self._ctx.Process(
            target=_worker_main,
            args=(self._request_queue, self._response_queue, config),
            daemon=True,
        )
        self._process.start()
        self._started = True
        log.info("Worker process started (pid=%s)", self._process.pid)

    def stop(self) -> None:
        """Send shutdown request, join, terminate if needed."""
        if self._process is None:
            self._started = False
            return
        with contextlib.suppress(OSError, ValueError, AttributeError):
            if self._request_queue is not None:
                self._request_queue.put(ShutdownRequest())
        self._process.join(timeout=_JOIN_TIMEOUT_S)
        if self._process.is_alive():
            log.warning("Worker did not exit gracefully, terminating")
            self._process.terminate()
            self._process.join(timeout=2)
        self._process = None
        self._started = False
        log.info("Worker process stopped")

    def restart(self) -> None:
        """Stop and restart the worker (e.g. after model change)."""
        self.stop()
        time.sleep(_RESTART_DELAY_S)
        self.start()

    def is_alive(self) -> bool:
        """Return True if the child process is running."""
        return self._process is not None and self._process.is_alive()

    def _next_request_id(self) -> int:
        self._next_id += 1
        return self._next_id

    def _ensure_started(self) -> None:
        """Lazy start: launch worker on first request."""
        if not self._started or not self.is_alive():
            self.start()

    def embed(self, texts: list[str], model: str = "") -> list[list[float]]:
        """Send an embed request and wait for the response.

        Auto-starts the worker if not running. Retries once on crash.
        """
        self._ensure_started()
        req = EmbedRequest(texts=texts, model=model, request_id=self._next_request_id())
        resp = self._round_trip(req, EmbedResponse, _EMBED_TIMEOUT_S, label="embed")
        return resp.vectors

    def vision_ocr(self, png_bytes: bytes, model: str, prompt: str = "") -> str:
        """Run vision OCR in the worker, honouring ``cfg.ocr_timeout``.

        Auto-starts the worker and retries once on crash.
        ``cfg.ocr_timeout == 0`` means no cap; substituted with
        ``_NO_CAP_TIMEOUT_S`` for the round-trip wait loop.
        """
        self._ensure_started()
        req = VisionRequest(
            png_bytes=png_bytes,
            model=model,
            prompt=prompt,
            request_id=self._next_request_id(),
        )
        timeout = cfg.ocr_timeout if cfg.ocr_timeout > 0 else _NO_CAP_TIMEOUT_S
        resp = self._round_trip(req, VisionResponse, timeout, label="vision OCR")
        return resp.text

    def _round_trip(
        self,
        req: _WorkerRequest,
        response_type: type[_ResponseT],
        timeout: float,
        *,
        label: str,
    ) -> _ResponseT:
        """Send a request, wait for the typed response, restart once on crash."""
        resp = self._put_and_get(req, timeout)
        if resp is None:
            log.warning("Worker crashed during %s, restarting and retrying", label)
            self.restart()
            resp = self._put_and_get(req, timeout)
            if resp is None:
                raise RuntimeError("Worker crashed again after restart")
        if not isinstance(resp, response_type):
            raise RuntimeError(f"Unexpected response type: {type(resp).__name__}")
        if resp.error:
            raise RuntimeError(resp.error)
        return resp

    def _put_and_get(self, req: _WorkerRequest, timeout: float) -> _WorkerResponse | None:
        """Enqueue a request and block for the next response. None if worker died."""
        if self._request_queue is None:
            raise RuntimeError("Worker not started")
        self._request_queue.put(req)
        return self._get_response(timeout=timeout)

    def _get_response(self, timeout: float) -> _WorkerResponse | None:
        """Read from response queue. Return None if worker died."""
        if self._response_queue is None:
            raise RuntimeError("Worker not started")
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if not self.is_alive():
                return None
            try:
                return self._response_queue.get(timeout=min(1.0, timeout))
            except queue.Empty:
                continue
        raise TimeoutError(f"Worker did not respond within {timeout}s")

    def load_model(self, model: str, model_type: str = "embed") -> None:
        """Tell the worker to (re)load a model."""
        self._ensure_started()
        if self._request_queue is None:
            raise RuntimeError("Worker not started")
        self._request_queue.put(LoadModelRequest(model=model, model_type=model_type))


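# Lifecycle sketch (illustrative, not part of the module): embed() and
# vision_ocr() lazily start the child, so an explicit start() is optional;
# stop() matters only for clean shutdown. The PNG bytes are hypothetical.
#
#     worker = WorkerProcess()
#     vectors = worker.embed(["hello", "world"])     # auto-starts the child
#     text = worker.vision_ocr(png_bytes, model="")  # falls back to the snapshot's vision_model
#     worker.stop()

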
def _redirect_stdio() -> None:  # pragma: no cover
    """Redirect stdout/stderr to /dev/null for the worker subprocess.

    Suppresses llama-cpp's C-level prints that would corrupt the parent
    TUI. Queues use pipes, not stdout.

    Covered by ``test_redirect_stdio_points_stdout_stderr_to_devnull`` in
    a subprocess (closing fds 1/2 in-process would deadlock pytest-xdist).
    Subprocess execution doesn't report back to coverage.py, so the body
    carries ``# pragma: no cover``.
    """
    import sys

    devnull_fd = os.open(os.devnull, os.O_RDWR)
    os.dup2(devnull_fd, 1)  # stdout
    os.dup2(devnull_fd, 2)  # stderr
    os.close(devnull_fd)
    sys.stdout = open(os.devnull, "w")  # noqa: SIM115
    sys.stderr = open(os.devnull, "w")  # noqa: SIM115


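# Test sketch (hypothetical, mirroring the test named in the docstring):
# run _redirect_stdio in a spawned child and report whether fd 1 now
# resolves to /dev/null by comparing device and inode numbers. POSIX-only;
# the helper name ``_probe`` is illustrative.
#
#     def _probe(q: "multiprocessing.Queue[bool]") -> None:
#         _redirect_stdio()
#         dev, fd1 = os.stat(os.devnull), os.fstat(1)
#         q.put((fd1.st_dev, fd1.st_ino) == (dev.st_dev, dev.st_ino))

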
def _configure_worker_logging() -> None:
    """Route the worker's Python logs to ``$LILBEE_DATA/worker.log``."""
    data_dir = os.environ.get("LILBEE_DATA") or ""
    if not data_dir:
        return
    log_path = os.path.join(data_dir, "worker.log")
    handler = logging.FileHandler(log_path)
    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
    root = logging.getLogger()
    root.addHandler(handler)
    root.setLevel(logging.INFO)


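# Usage note (illustrative): worker logging is opt-in via the environment,
# which matters because stdout and stderr are already redirected to
# /dev/null. The spawn start method inherits the parent's environment, so
# exporting the variable before launch is enough. The path is hypothetical:
#
#     export LILBEE_DATA=/tmp/lilbee-debug
#     # ... the worker then appends to /tmp/lilbee-debug/worker.log

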
def _worker_main(
    req_q: multiprocessing.Queue[_WorkerRequest],
    resp_q: multiprocessing.Queue[_WorkerResponse],
    config: ConfigSnapshot,
) -> None:
    """Child process entry point. Loads models lazily, processes requests."""
    _redirect_stdio()
    _configure_worker_logging()
    _apply_config_snapshot(config)
    log.info("Worker subprocess online (pid=%s)", os.getpid())

    embed_llm: Any = None
    vision_llm: Any = None
    current_embed_model = ""
    current_vision_model = ""

    try:
        while True:
            try:
                request = req_q.get()
            except (EOFError, OSError):
                break

            if isinstance(request, ShutdownRequest):
                _close_model(embed_llm)
                _close_model(vision_llm)
                break

            if isinstance(request, LoadModelRequest):
                if request.model_type == "embed":
                    _close_model(embed_llm)
                    embed_llm = None
                    current_embed_model = ""
                else:
                    _close_model(vision_llm)
                    vision_llm = None
                    current_vision_model = ""
                continue

            if isinstance(request, EmbedRequest):
                model_name = request.model or config.embedding_model
                if embed_llm is None or model_name != current_embed_model:
                    _close_model(embed_llm)
                    embed_llm = _load_embed_model(model_name)
                    current_embed_model = model_name
                embed_resp = _handle_embed(embed_llm, request)
                resp_q.put(embed_resp)
                continue

            if isinstance(request, VisionRequest):
                model_name = request.model or config.vision_model
                if not model_name:
                    resp_q.put(
                        VisionResponse(
                            request_id=request.request_id,
                            text="",
                            error="No vision model configured; vision OCR is disabled.",
                        )
                    )
                    continue
                if vision_llm is None or model_name != current_vision_model:
                    _close_model(vision_llm)
                    vision_llm = _load_vision_model(model_name)
                    current_vision_model = model_name
                vision_resp = _handle_vision(vision_llm, request)
                resp_q.put(vision_resp)
                continue
    finally:
        with contextlib.suppress(Exception):
            req_q.close()
        with contextlib.suppress(Exception):
            req_q.join_thread()
        with contextlib.suppress(Exception):
            resp_q.close()
        with contextlib.suppress(Exception):
            resp_q.join_thread()


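# Dispatch note: requests are matched by ``isinstance`` in arrival order,
# and every Embed/Vision request yields exactly one response, keeping the
# parent's one-put-one-get round trip in lockstep. A request type missing
# from the chain above would be dropped without a response, so the parent
# would eventually hit its round-trip timeout.

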
def _close_model(model: Any) -> None:
    """Safely close a llama-cpp model instance."""
    if model is not None:
        with contextlib.suppress(Exception):
            model.close()


def _apply_config_snapshot(config: ConfigSnapshot) -> None:
    """Apply parent-process config to the child's cfg singleton.

    Called exactly once at child startup so later load paths can use
    ``cfg.models_dir`` etc. without per-request mutation. Per-request
    mutation would violate the "no mutable module-level globals in
    request paths" rule in CLAUDE.md.
    """
    from pathlib import Path

    cfg.models_dir = Path(config.models_dir)
    cfg.embedding_model = config.embedding_model
    cfg.num_ctx = config.num_ctx


def _load_embed_model(model_name: str) -> Any:
    """Load an embedding model in the child process."""
    from lilbee.providers.llama_cpp_provider import load_llama, resolve_model_path
    from lilbee.providers.model_cache import MODE_EMBED

    return load_llama(resolve_model_path(model_name), mode=MODE_EMBED)


def _load_vision_model(model_name: str) -> Any:
    """Load a vision model in the child process via the mtmd backend."""
    from lilbee.providers.llama_cpp_provider import resolve_model_path
    from lilbee.providers.mtmd_backend import load_vision_llama

    return load_vision_llama(resolve_model_path(model_name))


def _handle_embed(llm: Any, request: EmbedRequest) -> EmbedResponse:
    """Process a single embed request, returning a response with vectors or an error."""
    try:
        from lilbee.providers.llama_cpp_provider import embed_one

        vectors = [embed_one(llm, text) for text in request.texts]
        return EmbedResponse(vectors=vectors, request_id=request.request_id)
    except Exception as exc:
        return EmbedResponse(error=str(exc), request_id=request.request_id)


def _handle_vision(llm: Any, request: VisionRequest) -> VisionResponse:
    """Process a single vision OCR request.

    Generation is bounded by the model's ``n_ctx`` (llama.cpp stops when
    the remaining context is exhausted) and by EOT, which the GGUF's own
    chat template now fires correctly. No extra per-request cap.
    """
    try:
        from lilbee.vision import OCR_PROMPT, build_vision_messages

        prompt = request.prompt or OCR_PROMPT
        messages = build_vision_messages(prompt, request.png_bytes)
        start = time.monotonic()
        response = llm.create_chat_completion(messages=messages, stream=False)
        text: str = response["choices"][0]["message"]["content"] or ""
        usage = response.get("usage", {}) or {}
        log.info(
            "vision_ocr request_id=%s wall=%.1fs prompt_tokens=%s completion_tokens=%s chars=%d",
            request.request_id,
            time.monotonic() - start,
            usage.get("prompt_tokens"),
            usage.get("completion_tokens"),
            len(text),
        )
        return VisionResponse(text=text, request_id=request.request_id)
    except Exception as exc:
        return VisionResponse(error=str(exc), request_id=request.request_id)


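# Error-path sketch (illustrative): worker-side exceptions never cross the
# queue as exceptions. _handle_embed/_handle_vision stringify them into the
# ``error`` field, and WorkerProcess._round_trip re-raises that string as a
# RuntimeError in the parent:
#
#     try:
#         worker.embed(["some text"])
#     except RuntimeError as exc:
#         log.error("embed failed in worker: %s", exc)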