Coverage for src / lilbee / vision.py: 100%
118 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-29 19:16 +0000
1"""Vision model OCR extraction for scanned PDFs.
3Rasterizes PDF pages to PNG, sends each to a local vision model
4via the configured LLM provider, and concatenates the extracted text.
5"""
7import concurrent.futures
8import contextlib
9import logging
10import sys
11from collections.abc import Iterator
12from contextlib import AbstractContextManager
13from pathlib import Path
14from typing import Any, NamedTuple, cast
16from lilbee.config import cfg
17from lilbee.progress import (
18 DetailedProgressCallback,
19 EventType,
20 ExtractEvent,
21 noop_callback,
22 shared_progress,
23)
24from lilbee.services import get_services
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
class PageText(NamedTuple):
    """OCR output for one PDF page.

    ``extract_pdf_vision`` builds these with a 1-based page number.
    """

    page: int  # 1-based page number (as produced by extract_pdf_vision)
    text: str  # text the vision model returned for that page
# Instruction sent to the vision model together with every page image.
OCR_PROMPT = " ".join(
    (
        "Extract ALL text from this page as clean markdown.",
        "Preserve table structure using markdown table syntax.",
        "Include all rows, columns, headers, and page text exactly as shown.",
    )
)

# DPI passed to kreuzberg's PdfPageIterator when rasterizing pages (and when
# opening the PDF just to count its pages).
_RASTER_DPI = 150
class _SharedTask:
    """Progress adapter that reports vision OCR pages on an existing batch task.

    Used when a parent progress display is already active. It does not create
    its own bar. Instead, it rewrites the batch task's description to
    ``Vision OCR <name> (<done>/<total>)``, once on entry and again after each
    page.
    """

    def __init__(self, progress: Any, batch_task: Any, name: str, total: int) -> None:
        self._progress = progress
        self._batch_task = batch_task
        self._name = name
        self._total = total
        self._current = 0

    def _describe(self, done: int) -> None:
        """Show *done* out of the total page count on the batch task."""
        label = f"Vision OCR {self._name} ({done}/{self._total})"
        self._progress.update(self._batch_task, description=label)

    def __enter__(self) -> "_SharedTask":
        self._describe(0)
        return self

    def __exit__(self, *_: Any) -> None:
        # Nothing to tear down: the batch loop rewrites the description once
        # the whole file has been processed. Exceptions are not suppressed.
        return None

    def advance(self, _task_id: Any) -> None:
        """Count one finished page; the task id argument is ignored."""
        self._current += 1
        self._describe(self._current)
def pdf_page_count(path: Path) -> int:
    """Return the number of pages in a PDF without rasterizing.

    The ``PdfPageIterator`` is closed through its context-manager protocol,
    as ``rasterize_pdf`` does. Without this, releasing whatever it holds open
    would depend on garbage collection.
    """
    from kreuzberg import PdfPageIterator  # lazy: heavy dependency

    pages = PdfPageIterator(str(path), dpi=_RASTER_DPI)
    # Take len() of the constructed object itself (as before), not of the
    # value bound by ``__enter__``, so the count is computed exactly as it
    # was originally; ``with`` only adds deterministic cleanup.
    with pages:
        return len(pages)
def rasterize_pdf(path: Path) -> Iterator[tuple[int, bytes]]:
    """Lazily render each page of *path* to PNG.

    Yields ``(page_index, png_bytes)`` pairs with 0-based indices. The page
    iterator is entered only when consumption starts and is exited when this
    generator finishes or is closed.
    """
    from kreuzberg import PdfPageIterator  # lazy: heavy dependency

    renderer = PdfPageIterator(str(path), dpi=_RASTER_DPI)
    with renderer as page_stream:
        yield from page_stream
def _png_to_data_url(png_bytes: bytes) -> str:
    """Encode PNG bytes as a base64 ``data:`` URL for OpenAI-style image parts."""
    import base64

    prefix = "data:image/png;base64,"
    return prefix + base64.b64encode(png_bytes).decode("ascii")
def build_vision_messages(prompt: str, png_bytes: bytes) -> list[dict]:
    """Wrap a prompt and a page image in a single OpenAI-style user message.

    The message content is a multipart list, image part first and text part
    second, which is the format llama.cpp's mtmd pipeline expects.
    """
    image_part = {"type": "image_url", "image_url": {"url": _png_to_data_url(png_bytes)}}
    text_part = {"type": "text", "text": prompt}
    user_message = {"role": "user", "content": [image_part, text_part]}
    return [user_message]
def extract_page_text(png_bytes: bytes, model: str, *, timeout: float | None = None) -> str | None:
    """Send a page image to a vision model and return extracted text.

    If the provider exposes a ``vision_ocr`` method (subprocess-isolated),
    that path is preferred. Otherwise falls back to ``provider.chat``.

    *timeout* (seconds) caps how long this function waits for the
    ``provider.chat`` fallback. ``None``, ``0`` or a negative value means no
    limit. It is not applied to the ``vision_ocr`` path. A timed-out chat
    call is abandoned, not cancelled: a running future cannot be cancelled,
    so its worker thread keeps going in the background until the provider
    returns.

    Returns:
        The extracted text, or ``None`` if the provider call raised or timed
        out. Errors are logged rather than propagated, so one bad page does
        not abort the whole document.
    """
    try:
        provider = get_services().provider

        # vision_ocr is optional — only llama-cpp provider implements it
        if hasattr(provider, "vision_ocr"):
            result: str = provider.vision_ocr(png_bytes, model, OCR_PROMPT)  # type: ignore[attr-defined]
            return result

        messages = build_vision_messages(OCR_PROMPT, png_bytes)

        if timeout and timeout > 0:
            # Manage the pool by hand. ``with ThreadPoolExecutor()`` calls
            # ``shutdown(wait=True)`` on exit, which would block until the
            # timed-out ``chat`` call actually returned, so the timeout would
            # never shorten the wait.
            pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            try:
                future = pool.submit(provider.chat, messages, stream=False, model=model)
                return cast(str, future.result(timeout=timeout))
            finally:
                pool.shutdown(wait=False)

        return cast(str, provider.chat(messages, stream=False, model=model))
    except Exception as exc:
        # Includes concurrent.futures.TimeoutError from the bounded wait above.
        log.warning("Vision OCR: page skipped (%s: %s)", type(exc).__name__, exc)
        log.debug("Vision OCR traceback for model %s", model, exc_info=True)
        return None
def _make_progress(name: str, total: int, quiet: bool) -> tuple[AbstractContextManager[Any], Any]:
    """Choose how per-page OCR progress is displayed.

    Returns a ``(context_manager, task_id)`` pair:

    * *quiet*: a ``nullcontext`` and ``None`` (nothing to advance);
    * a shared batch display is active (``shared_progress`` is set): a
      ``_SharedTask`` wrapping the batch task, plus that task id;
    * otherwise: a new transient Rich ``Progress`` bar on the process's
      original stderr (falling back to ``sys.stderr``), plus its task id.
    """
    if quiet:
        return contextlib.nullcontext(), None

    shared = shared_progress.get(None)
    if shared is not None:
        parent_bar, batch_task = shared
        return _SharedTask(parent_bar, batch_task, name, total), batch_task

    from rich.console import Console
    from rich.progress import (  # lazy: heavy dependency
        BarColumn,
        MofNCompleteColumn,
        Progress,
        TextColumn,
        TimeElapsedColumn,
    )

    columns = (
        TextColumn("{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
    )
    stream = sys.__stderr__ or sys.stderr
    bar = Progress(*columns, transient=True, console=Console(file=stream))
    task_id = bar.add_task(f"Vision OCR {name}", total=total)
    return bar, task_id
def _record_page(
    fut: concurrent.futures.Future[tuple[int, str | None]],
    extracted: dict[int, str | None],
    on_progress: DetailedProgressCallback,
    progress_ctx: AbstractContextManager[Any],
    progress_task: Any,
    path: Path,
    total: int,
) -> None:
    """Store one finished page result and report progress for it.

    Writes the page text (``None`` when OCR failed) into *extracted* under its
    0-based index. It then logs the result, emits an ``EXTRACT`` event with
    the 1-based page number, and advances the progress task if there is one.
    """
    index, text = fut.result()
    extracted[index] = text

    page_no = index + 1
    char_count = len(text) if text else 0
    log.info(
        "Vision OCR completed page %d/%d of %s (%d chars)",
        page_no,
        total,
        path.name,
        char_count,
    )
    event = ExtractEvent(file=path.name, page=page_no, total_pages=total)
    on_progress(EventType.EXTRACT, event)

    if progress_task is None:
        return
    progress_ctx.advance(progress_task)  # type: ignore[attr-defined]
def extract_pdf_vision(
    path: Path,
    model: str,
    *,
    quiet: bool = False,
    timeout: float | None = None,
    on_progress: DetailedProgressCallback = noop_callback,
) -> list[PageText]:
    """Extract text from a PDF using vision model OCR.

    Pages are rasterized lazily and passed to ``extract_page_text`` on a
    thread pool of ``max(1, cfg.vision_concurrency)`` workers. One ``EXTRACT``
    progress event fires per page as it completes, in completion order rather
    than page order.

    Args:
        path: PDF to process.
        model: Vision model name forwarded to the provider.
        quiet: Disable Rich progress output and the stderr failure summary.
        timeout: Per-page limit in seconds forwarded to ``extract_page_text``;
            ``None`` or ``0`` means no limit.
        on_progress: Receives one ``ExtractEvent`` per completed page.

    Returns:
        ``PageText`` entries (1-based page number, text), sorted by page, for
        pages that produced non-blank text. Failed pages are counted and
        reported via logging (and stderr unless *quiet*), not returned.
    """
    total = pdf_page_count(path)
    if total == 0:
        return []

    concurrency = max(1, cfg.vision_concurrency)
    extracted: dict[int, str | None] = {}
    progress_ctx, progress_task = _make_progress(path.name, total, quiet)

    def _extract(i: int, png: bytes) -> tuple[int, str | None]:
        log.debug("Vision OCR page %d/%d with %s", i + 1, total, model)
        return i, extract_page_text(png, model, timeout=timeout)

    def _record(fut: concurrent.futures.Future[tuple[int, str | None]]) -> None:
        _record_page(fut, extracted, on_progress, progress_ctx, progress_task, path, total)

    # Inflight futures are bounded to roughly ``concurrency * 2`` so raster
    # bytes never accumulate beyond what the pool can drain. Fully eager
    # submission would buffer every page's PNG in memory before any OCR
    # returned.
    max_inflight = concurrency * 2
    # ``closing`` makes sure the rasterizing generator, and the page iterator
    # it keeps open, is closed even if an error aborts the loop mid-document.
    # Otherwise it would linger until garbage collection (or longer, if a
    # traceback keeps this frame alive). Exit order is pool -> pages ->
    # progress, so all inflight OCR drains before the generator is closed.
    with (
        progress_ctx,
        contextlib.closing(rasterize_pdf(path)) as pages,
        concurrent.futures.ThreadPoolExecutor(max_workers=concurrency) as pool,
    ):
        inflight: set[concurrent.futures.Future[tuple[int, str | None]]] = set()
        for i, png in pages:
            if len(inflight) >= max_inflight:
                done, inflight = concurrent.futures.wait(
                    inflight, return_when=concurrent.futures.FIRST_COMPLETED
                )
                for fut in done:
                    _record(fut)
            inflight.add(pool.submit(_extract, i, png))
        for fut in concurrent.futures.as_completed(inflight):
            _record(fut)

    result: list[PageText] = []
    failed = 0
    for i in sorted(extracted):
        text = extracted[i]
        if text is None:
            failed += 1
        elif text.strip():
            result.append(PageText(i + 1, text))

    if failed:
        log.warning("Vision OCR: %d/%d pages failed", failed, total)
        if not quiet:
            from rich.console import Console

            Console(stderr=True).print(
                f"[yellow]Vision OCR: {failed}/{total} pages failed, "
                f"{len(result)}/{total} extracted[/yellow]"
            )

    return result