Coverage for src/lilbee/providers/worker_process.py: 100%

257 statements  

coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""Subprocess worker for llama-cpp embedding and vision operations. 

2 

3Isolates llama-cpp native code (which holds the GIL) in a separate 

4process so the parent event loop stays responsive. Communication 

5uses multiprocessing queues with spawn start method (fork-unsafe 

6with threads). 

7""" 

from __future__ import annotations

import contextlib
import logging
import multiprocessing
import multiprocessing.queues
import os
import queue
import time
from dataclasses import dataclass, field
from multiprocessing import get_context
from typing import Any, TypeVar

from lilbee.config import cfg

log = logging.getLogger(__name__)

_ResponseT = TypeVar("_ResponseT", "EmbedResponse", "VisionResponse")

_EMBED_TIMEOUT_S = 30.0
_JOIN_TIMEOUT_S = 5.0
_RESTART_DELAY_S = 0.1
# Substituted when ``cfg.ocr_timeout == 0`` ("no limit"). The round-trip
# wait loop needs a finite deadline; one day is effectively unlimited.
_NO_CAP_TIMEOUT_S = 86_400.0


@dataclass(frozen=True)
class EmbedRequest:
    """Request to embed a list of texts."""

    texts: list[str]
    model: str
    request_id: int = 0


@dataclass(frozen=True)
class EmbedResponse:
    """Response with embedding vectors or an error."""

    vectors: list[list[float]] = field(default_factory=list)
    error: str = ""
    request_id: int = 0


@dataclass(frozen=True)
class VisionRequest:
    """Request to run vision OCR on a PNG image."""

    png_bytes: bytes
    model: str
    prompt: str = ""
    request_id: int = 0


@dataclass(frozen=True)
class VisionResponse:
    """Response with extracted text or an error."""

    text: str = ""
    error: str = ""
    request_id: int = 0


@dataclass(frozen=True)
class LoadModelRequest:
    """Request to (re)load a specific model in the worker."""

    model: str
    model_type: str = "embed"  # "embed" or "vision"


@dataclass(frozen=True)
class ShutdownRequest:
    """Sentinel to tell the worker to exit."""


@dataclass(frozen=True)
class ConfigSnapshot:
    """Minimal config fields needed by the child process."""

    models_dir: str
    embedding_model: str
    embedding_dim: int
    num_ctx: int | None
    gpu_memory_fraction: float
    vision_model: str


_WorkerRequest = EmbedRequest | VisionRequest | LoadModelRequest | ShutdownRequest
_WorkerResponse = EmbedResponse | VisionResponse


def config_snapshot_from_cfg() -> ConfigSnapshot:
    """Build a ConfigSnapshot from the current cfg singleton."""
    return ConfigSnapshot(
        models_dir=str(cfg.models_dir),
        embedding_model=cfg.embedding_model,
        embedding_dim=cfg.embedding_dim,
        num_ctx=cfg.num_ctx,
        gpu_memory_fraction=cfg.gpu_memory_fraction,
        vision_model=cfg.vision_model,
    )


class WorkerProcess:
    """Manages a child process for embedding and vision inference.

    The child loads llama-cpp models independently, avoiding GIL
    contention and stdout corruption in the parent process.
    """

    def __init__(self, config: ConfigSnapshot | None = None) -> None:
        self._config = config
        self._process: Any = None
        self._ctx = get_context("spawn")
        self._request_queue: multiprocessing.Queue[_WorkerRequest] | None = None
        self._response_queue: multiprocessing.Queue[_WorkerResponse] | None = None
        self._next_id = 0
        self._started = False

    def _ensure_config(self) -> ConfigSnapshot:
        if self._config is None:
            self._config = config_snapshot_from_cfg()
        return self._config

    def start(self) -> None:
        """Launch the child process. No-op if already running."""
        if self._started and self.is_alive():
            return
        config = self._ensure_config()
        self._request_queue = self._ctx.Queue()
        self._response_queue = self._ctx.Queue()
        self._process = self._ctx.Process(
            target=_worker_main,
            args=(self._request_queue, self._response_queue, config),
            daemon=True,
        )
        self._process.start()
        self._started = True
        log.info("Worker process started (pid=%s)", self._process.pid)

    def stop(self) -> None:
        """Send shutdown request, join, terminate if needed."""
        if self._process is None:
            self._started = False
            return
        with contextlib.suppress(OSError, ValueError, AttributeError):
            if self._request_queue is not None:
                self._request_queue.put(ShutdownRequest())
            self._process.join(timeout=_JOIN_TIMEOUT_S)
            if self._process.is_alive():
                log.warning("Worker did not exit gracefully, terminating")
                self._process.terminate()
                self._process.join(timeout=2)
        self._process = None
        self._started = False
        log.info("Worker process stopped")

    def restart(self) -> None:
        """Stop and restart the worker (e.g. after model change)."""
        self.stop()
        time.sleep(_RESTART_DELAY_S)
        self.start()

    def is_alive(self) -> bool:
        """Return True if the child process is running."""
        return self._process is not None and self._process.is_alive()

    def _next_request_id(self) -> int:
        self._next_id += 1
        return self._next_id

    def _ensure_started(self) -> None:
        """Lazy start: launch worker on first request."""
        if not self._started or not self.is_alive():
            self.start()

    def embed(self, texts: list[str], model: str = "") -> list[list[float]]:
        """Send an embed request and wait for the response.

        Auto-starts the worker if not running. Retries once on crash.
        """
        self._ensure_started()
        req = EmbedRequest(texts=texts, model=model, request_id=self._next_request_id())
        resp = self._round_trip(req, EmbedResponse, _EMBED_TIMEOUT_S, label="embed")
        return resp.vectors

    def vision_ocr(self, png_bytes: bytes, model: str, prompt: str = "") -> str:
        """Run vision OCR in the worker, honouring ``cfg.ocr_timeout``.

        Auto-starts the worker and retries once on crash. ``cfg.ocr_timeout
        == 0`` means no cap; substituted with ``_NO_CAP_TIMEOUT_S`` for
        the round-trip wait loop.
        """
        self._ensure_started()
        req = VisionRequest(
            png_bytes=png_bytes,
            model=model,
            prompt=prompt,
            request_id=self._next_request_id(),
        )
        timeout = cfg.ocr_timeout if cfg.ocr_timeout > 0 else _NO_CAP_TIMEOUT_S
        resp = self._round_trip(req, VisionResponse, timeout, label="vision OCR")
        return resp.text

    def _round_trip(
        self,
        req: _WorkerRequest,
        response_type: type[_ResponseT],
        timeout: float,
        *,
        label: str,
    ) -> _ResponseT:
        """Send a request, wait for the typed response, restart once on crash."""
        resp = self._put_and_get(req, timeout)
        if resp is None:
            log.warning("Worker crashed during %s, restarting and retrying", label)
            self.restart()
            resp = self._put_and_get(req, timeout)
            if resp is None:
                raise RuntimeError("Worker crashed again after restart")
        if not isinstance(resp, response_type):
            raise RuntimeError(f"Unexpected response type: {type(resp).__name__}")
        if resp.error:
            raise RuntimeError(resp.error)
        return resp

    def _put_and_get(self, req: _WorkerRequest, timeout: float) -> _WorkerResponse | None:
        """Enqueue a request and block for the next response. None if worker died."""
        if self._request_queue is None:
            raise RuntimeError("Worker not started")
        self._request_queue.put(req)
        return self._get_response(timeout=timeout)

    def _get_response(self, timeout: float) -> _WorkerResponse | None:
        """Read from response queue. Return None if worker died."""
        if self._response_queue is None:
            raise RuntimeError("Worker not started")
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if not self.is_alive():
                return None
            try:
                return self._response_queue.get(timeout=min(1.0, timeout))
            except queue.Empty:
                continue
        raise TimeoutError(f"Worker did not respond within {timeout}s")

    def load_model(self, model: str, model_type: str = "embed") -> None:
        """Tell the worker to (re)load a model."""
        self._ensure_started()
        if self._request_queue is None:
            raise RuntimeError("Worker not started")
        self._request_queue.put(LoadModelRequest(model=model, model_type=model_type))


def _redirect_stdio() -> None:  # pragma: no cover
    """Redirect stdout/stderr to /dev/null for the worker subprocess.

    Suppresses llama-cpp's C-level prints that would corrupt the parent TUI.
    Queues use pipes, not stdout.

    Covered by ``test_redirect_stdio_points_stdout_stderr_to_devnull`` in
    a subprocess (closing fds 1/2 in-process would deadlock pytest-xdist).
    Subprocess execution doesn't report back to coverage.py, so the body
    carries ``# pragma: no cover``.
    """
    import os
    import sys

    devnull_fd = os.open(os.devnull, os.O_RDWR)
    os.dup2(devnull_fd, 1)  # stdout
    os.dup2(devnull_fd, 2)  # stderr
    os.close(devnull_fd)
    sys.stdout = open(os.devnull, "w")  # noqa: SIM115
    sys.stderr = open(os.devnull, "w")  # noqa: SIM115


def _configure_worker_logging() -> None:
    """Route the worker's Python logs to ``$LILBEE_DATA/worker.log``."""
    data_dir = os.environ.get("LILBEE_DATA") or ""
    if not data_dir:
        return
    log_path = os.path.join(data_dir, "worker.log")
    handler = logging.FileHandler(log_path)
    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
    root = logging.getLogger()
    root.addHandler(handler)
    root.setLevel(logging.INFO)


def _worker_main(
    req_q: multiprocessing.Queue[_WorkerRequest],
    resp_q: multiprocessing.Queue[_WorkerResponse],
    config: ConfigSnapshot,
) -> None:
    """Child process entry point. Loads models lazily, processes requests."""
    _redirect_stdio()
    _configure_worker_logging()
    _apply_config_snapshot(config)
    log.info("Worker subprocess online (pid=%s)", os.getpid())

    embed_llm: Any = None
    vision_llm: Any = None
    current_embed_model = ""
    current_vision_model = ""

    try:
        while True:
            try:
                request = req_q.get()
            except (EOFError, OSError):
                break

            if isinstance(request, ShutdownRequest):
                _close_model(embed_llm)
                _close_model(vision_llm)
                break

            if isinstance(request, LoadModelRequest):
                if request.model_type == "embed":
                    _close_model(embed_llm)
                    embed_llm = None
                    current_embed_model = ""
                else:
                    _close_model(vision_llm)
                    vision_llm = None
                    current_vision_model = ""
                continue

            if isinstance(request, EmbedRequest):
                model_name = request.model or config.embedding_model
                if embed_llm is None or model_name != current_embed_model:
                    _close_model(embed_llm)
                    embed_llm = _load_embed_model(model_name)
                    current_embed_model = model_name
                embed_resp = _handle_embed(embed_llm, request)
                resp_q.put(embed_resp)
                continue

            if isinstance(request, VisionRequest):
                model_name = request.model or config.vision_model
                if not model_name:
                    resp_q.put(
                        VisionResponse(
                            request_id=request.request_id,
                            text="",
                            error="No vision model configured; vision OCR is disabled.",
                        )
                    )
                    continue
                if vision_llm is None or model_name != current_vision_model:
                    _close_model(vision_llm)
                    vision_llm = _load_vision_model(model_name)
                    current_vision_model = model_name
                vision_resp = _handle_vision(vision_llm, request)
                resp_q.put(vision_resp)
                continue
    finally:
        with contextlib.suppress(Exception):
            req_q.close()
        with contextlib.suppress(Exception):
            req_q.join_thread()
        with contextlib.suppress(Exception):
            resp_q.close()
        with contextlib.suppress(Exception):
            resp_q.join_thread()


def _close_model(model: Any) -> None:
    """Safely close a llama-cpp model instance."""
    if model is not None:
        with contextlib.suppress(Exception):
            model.close()


def _apply_config_snapshot(config: ConfigSnapshot) -> None:
    """Apply parent-process config to the child's cfg singleton.

    Called exactly once at child startup so later load paths can use
    ``cfg.models_dir`` etc. without per-request mutation. Per-request
    mutation would violate the "no mutable module-level globals in
    request paths" rule in CLAUDE.md.
    """
    from pathlib import Path

    cfg.models_dir = Path(config.models_dir)
    cfg.embedding_model = config.embedding_model
    cfg.num_ctx = config.num_ctx


def _load_embed_model(model_name: str) -> Any:
    """Load an embedding model in the child process."""
    from lilbee.providers.llama_cpp_provider import load_llama, resolve_model_path
    from lilbee.providers.model_cache import MODE_EMBED

    return load_llama(resolve_model_path(model_name), mode=MODE_EMBED)


def _load_vision_model(model_name: str) -> Any:
    """Load a vision model in the child process via the mtmd backend."""
    from lilbee.providers.llama_cpp_provider import resolve_model_path
    from lilbee.providers.mtmd_backend import load_vision_llama

    return load_vision_llama(resolve_model_path(model_name))


def _handle_embed(llm: Any, request: EmbedRequest) -> EmbedResponse:
    """Process a single embed request, returning response with vectors or error."""
    try:
        from lilbee.providers.llama_cpp_provider import embed_one

        vectors = [embed_one(llm, text) for text in request.texts]
        return EmbedResponse(vectors=vectors, request_id=request.request_id)
    except Exception as exc:
        return EmbedResponse(error=str(exc), request_id=request.request_id)


def _handle_vision(llm: Any, request: VisionRequest) -> VisionResponse:
    """Process a single vision OCR request.

    Generation is bounded by the model's ``n_ctx`` (llama.cpp stops when
    the remaining context is exhausted) and by EOT, which the GGUF's own
    chat template now fires correctly. No extra per-request cap.
    """
    try:
        from lilbee.vision import OCR_PROMPT, build_vision_messages

        prompt = request.prompt or OCR_PROMPT
        messages = build_vision_messages(prompt, request.png_bytes)
        start = time.monotonic()
        response = llm.create_chat_completion(messages=messages, stream=False)
        text: str = response["choices"][0]["message"]["content"] or ""
        usage = response.get("usage", {}) or {}
        log.info(
            "vision_ocr request_id=%s wall=%.1fs prompt_tokens=%s completion_tokens=%s chars=%d",
            request.request_id,
            time.monotonic() - start,
            usage.get("prompt_tokens"),
            usage.get("completion_tokens"),
            len(text),
        )
        return VisionResponse(text=text, request_id=request.request_id)
    except Exception as exc:
        return VisionResponse(error=str(exc), request_id=request.request_id)
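
Usage note (not part of the module above): a minimal parent-side sketch of the public API, assuming the file is importable as lilbee.providers.worker_process and that a page.png file exists; the filename and the print statements are purely illustrative.

# Sketch of parent-side usage, under the assumptions stated above.
from lilbee.providers.worker_process import WorkerProcess, config_snapshot_from_cfg

worker = WorkerProcess(config_snapshot_from_cfg())  # snapshot optional; built lazily if omitted

try:
    # embed() auto-starts the spawn child on first use and retries once if it crashes.
    vectors = worker.embed(["hello", "world"])
    print(len(vectors), "vectors")

    # vision_ocr() honours cfg.ocr_timeout; an empty model falls back to cfg.vision_model.
    with open("page.png", "rb") as fh:
        text = worker.vision_ocr(fh.read(), model="")
    print(text[:200])
except (RuntimeError, TimeoutError) as exc:
    # RuntimeError: worker-reported error or repeated crash; TimeoutError: no response in time.
    print("worker failure:", exc)
finally:
    worker.stop()  # graceful shutdown; terminates after _JOIN_TIMEOUT_S if the child hangs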