Coverage for src / lilbee / vision.py: 100%

118 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""Vision model OCR extraction for scanned PDFs. 

2 

3Rasterizes PDF pages to PNG, sends each to a local vision model 

4via the configured LLM provider, and concatenates the extracted text. 

5""" 

6 

7import concurrent.futures 

8import contextlib 

9import logging 

10import sys 

11from collections.abc import Iterator 

12from contextlib import AbstractContextManager 

13from pathlib import Path 

14from typing import Any, NamedTuple, cast 

15 

16from lilbee.config import cfg 

17from lilbee.progress import ( 

18 DetailedProgressCallback, 

19 EventType, 

20 ExtractEvent, 

21 noop_callback, 

22 shared_progress, 

23) 

24from lilbee.services import get_services 

25 

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

27 

28 

class PageText(NamedTuple):
    """Extracted text for a single PDF page."""

    # Page number; extract_pdf_vision fills this with a 1-based value.
    page: int
    # Text returned by the vision model for that page.
    text: str

34 

35 

# Instruction text sent to the vision model together with each page image.
OCR_PROMPT = (
    "Extract ALL text from this page as clean markdown. "
    "Preserve table structure using markdown table syntax. "
    "Include all rows, columns, headers, and page text exactly as shown."
)

# DPI value handed to PdfPageIterator whenever a PDF is opened.
_RASTER_DPI = 150

43 

44 

class _SharedTask:
    """Updates the batch task's description with per-page vision progress.

    Acts as a context manager with an ``advance`` method so it can be used
    in place of a standalone progress bar: instead of moving a bar, each
    call rewrites the description of an existing batch task to
    ``Vision OCR <name> (<done>/<total>)``.
    """

    def __init__(self, progress: Any, batch_task: Any, name: str, total: int) -> None:
        """Remember the parent progress object, its task, and the page total."""
        self._progress = progress
        self._batch_task = batch_task
        self._name = name
        self._total = total
        self._current = 0

    def _show(self, done: int) -> None:
        """Write ``done`` out of the total into the batch task's description."""
        self._progress.update(
            self._batch_task,
            description=f"Vision OCR {self._name} ({done}/{self._total})",
        )

    def __enter__(self) -> "_SharedTask":
        """Display the initial ``0/<total>`` description and return self."""
        self._show(0)
        return self

    def __exit__(self, *_: Any) -> None:
        """Do nothing on exit."""
        # batch loop updates the description after each file completes
        return None

    def advance(self, _task_id: Any) -> None:
        """Count one more finished page and refresh the description.

        ``_task_id`` is accepted and ignored; updates always target the
        batch task supplied at construction.
        """
        self._current += 1
        self._show(self._current)

70 

71 

def pdf_page_count(path: Path) -> int:
    """Return the number of pages in a PDF without rasterizing.

    Args:
        path: Location of the PDF file.

    Returns:
        The length reported by the page iterator opened on ``path``.
    """
    from kreuzberg import PdfPageIterator  # lazy: heavy dependency

    # Open the iterator as a context manager, as rasterize_pdf does, so it
    # is released once the count has been read instead of being left open.
    with PdfPageIterator(str(path), dpi=_RASTER_DPI) as pages:
        return len(pages)

78 

79 

def rasterize_pdf(path: Path) -> Iterator[tuple[int, bytes]]:
    """Yield (0-based index, PNG bytes) for each page of a PDF.

    This is a generator: the PDF is opened when iteration starts and the
    page iterator's context is exited when the generator finishes or is
    closed.

    Args:
        path: Location of the PDF file.
    """
    from kreuzberg import PdfPageIterator  # lazy: heavy dependency

    with PdfPageIterator(str(path), dpi=_RASTER_DPI) as pages:
        yield from pages

86 

87 

def _png_to_data_url(png_bytes: bytes) -> str:
    """Convert raw PNG bytes to a base64 data URL for OpenAI-compatible messages.

    Args:
        png_bytes: Image bytes to embed; they are encoded as-is.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    import base64

    b64 = base64.b64encode(png_bytes).decode("ascii")
    return f"data:image/png;base64,{b64}"

94 

95 

def build_vision_messages(prompt: str, png_bytes: bytes) -> list[dict]:
    """Build OpenAI-compatible messages with image content for vision models.

    Uses the multipart content format expected by llama.cpp's mtmd
    pipeline.

    Args:
        prompt: Instruction text placed after the image.
        png_bytes: Page image, embedded as a base64 data URL.

    Returns:
        A one-element list holding a single ``user`` message whose content
        is the image part followed by the text part.
    """
    return [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": _png_to_data_url(png_bytes)}},
                {"type": "text", "text": prompt},
            ],
        }
    ]

110 

111 

def extract_page_text(png_bytes: bytes, model: str, *, timeout: float | None = None) -> str | None:
    """Send a page image to a vision model and return extracted text.

    If the provider exposes a ``vision_ocr`` method (subprocess-isolated),
    that path is preferred. Otherwise falls back to ``provider.chat``.

    The *timeout* parameter (seconds) caps wall-clock time for the
    ``provider.chat`` call using ``concurrent.futures``. ``None`` or ``0``
    means no limit. It is not applied on the ``vision_ocr`` path.

    Args:
        png_bytes: PNG image of the page.
        model: Model identifier forwarded to the provider.
        timeout: Optional limit in seconds for the chat fallback.

    Returns:
        The provider's result, or ``None`` when any exception (including a
        timeout) occurred; the failure is logged and not re-raised.
    """
    try:
        provider = get_services().provider

        # vision_ocr is optional — only llama-cpp provider implements it
        if hasattr(provider, "vision_ocr"):
            result: str = provider.vision_ocr(png_bytes, model, OCR_PROMPT)  # type: ignore[attr-defined]
            return result

        messages = build_vision_messages(OCR_PROMPT, png_bytes)

        if timeout and timeout > 0:
            from concurrent.futures import ThreadPoolExecutor

            # Do not use the executor as a context manager here: leaving a
            # ``with`` block calls shutdown(wait=True), which would block
            # until the provider call returns and defeat the timeout.
            pool = ThreadPoolExecutor(max_workers=1)
            try:
                future = pool.submit(provider.chat, messages, stream=False, model=model)
                return cast(str, future.result(timeout=timeout))
            finally:
                # On timeout the worker thread cannot be interrupted; it is
                # left to finish in the background while we return promptly.
                pool.shutdown(wait=False, cancel_futures=True)

        return cast(str, provider.chat(messages, stream=False, model=model))
    except Exception as exc:
        # Best effort by design: a failed page is reported and skipped.
        log.warning("Vision OCR: page skipped (%s: %s)", type(exc).__name__, exc)
        log.debug("Vision OCR traceback for model %s", model, exc_info=True)
        return None

141 

142 

def _make_progress(name: str, total: int, quiet: bool) -> tuple[AbstractContextManager[Any], Any]:
    """Return (context_manager, task_id | None) for optional Rich progress.

    Three cases, checked in order:

    * ``quiet`` is true: a null context and ``None`` as the task id.
    * ``shared_progress`` holds a value: it is unpacked as
      ``(progress, batch_task)`` and wrapped in a ``_SharedTask`` so the
      existing batch task's description is updated instead of a new bar.
    * otherwise: a new transient Rich ``Progress`` with one task of
      ``total`` steps.

    Args:
        name: Label shown in the progress description.
        total: Number of pages expected.
        quiet: When true, no progress display is created.
    """
    if quiet:
        return contextlib.nullcontext(), None

    parent = shared_progress.get(None)
    if parent is not None:
        progress, batch_task = parent
        return _SharedTask(progress, batch_task, name, total), batch_task

    from rich.console import Console
    from rich.progress import (  # lazy: heavy dependency
        BarColumn,
        MofNCompleteColumn,
        Progress,
        TextColumn,
        TimeElapsedColumn,
    )

    progress = Progress(
        TextColumn("{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        transient=True,
        # Prefer the interpreter's original stderr; fall back to the current
        # sys.stderr when sys.__stderr__ is None.
        console=Console(file=sys.__stderr__ or sys.stderr),
    )
    task = progress.add_task(f"Vision OCR {name}", total=total)
    return progress, task

172 

173 

def _record_page(
    fut: concurrent.futures.Future[tuple[int, str | None]],
    extracted: dict[int, str | None],
    on_progress: DetailedProgressCallback,
    progress_ctx: AbstractContextManager[Any],
    progress_task: Any,
    path: Path,
    total: int,
) -> None:
    """Drain one completed page future into ``extracted`` and fire progress.

    Args:
        fut: Future resolving to ``(0-based page index, text or None)``.
        extracted: Mapping from page index to text, mutated in place.
        on_progress: Callback that receives an ``EXTRACT`` event for the page.
        progress_ctx: Progress object; its ``advance`` method is called when
            ``progress_task`` is not ``None``.
        progress_task: Task id passed to ``advance``, or ``None`` to skip it.
        path: PDF being processed; only its file name is reported.
        total: Total page count, used in the log line and the event.
    """
    i, text = fut.result()
    extracted[i] = text
    log.info(
        "Vision OCR completed page %d/%d of %s (%d chars)",
        i + 1,
        total,
        path.name,
        len(text) if text else 0,  # None and "" both count as 0 chars
    )
    on_progress(
        EventType.EXTRACT,
        ExtractEvent(file=path.name, page=i + 1, total_pages=total),
    )
    if progress_task is not None:
        progress_ctx.advance(progress_task)  # type: ignore[attr-defined]

199 

200 

def extract_pdf_vision(
    path: Path,
    model: str,
    *,
    quiet: bool = False,
    timeout: float | None = None,
    on_progress: DetailedProgressCallback = noop_callback,
) -> list[PageText]:
    """Extract text from a PDF using vision model OCR.

    Returns a list of (1-based page number, text) tuples for pages that
    produced non-empty text. Fires ``extract`` progress events per page.

    Args:
        path: PDF file to process.
        model: Vision model identifier passed to ``extract_page_text``.
        quiet: Suppress the progress display and the failure summary that
            is printed to stderr.
        timeout: Per-page limit in seconds, forwarded to
            ``extract_page_text``.
        on_progress: Callback invoked once per completed page.

    Returns:
        ``PageText`` entries in ascending page order. Pages whose OCR failed
        (``None``) or returned only whitespace are left out.
    """
    total = pdf_page_count(path)
    if total == 0:
        return []

    concurrency = max(1, cfg.vision_concurrency)
    extracted: dict[int, str | None] = {}
    progress_ctx, progress_task = _make_progress(path.name, total, quiet)

    def _extract(i: int, png: bytes) -> tuple[int, str | None]:
        log.debug("Vision OCR page %d/%d with %s", i + 1, total, model)
        return i, extract_page_text(png, model, timeout=timeout)

    # Inflight futures are bounded to roughly ``concurrency * 2`` so raster
    # bytes never accumulate beyond what the pool can drain. Fully eager
    # submission would buffer every page's PNG in memory before any OCR
    # returned.
    max_inflight = concurrency * 2
    with progress_ctx, concurrent.futures.ThreadPoolExecutor(max_workers=concurrency) as pool:
        inflight: set[concurrent.futures.Future[tuple[int, str | None]]] = set()
        for i, png in rasterize_pdf(path):
            if len(inflight) >= max_inflight:
                # Block until at least one page finishes, record every page
                # that is done, and keep only the still-pending futures.
                done, inflight = concurrent.futures.wait(
                    inflight, return_when=concurrent.futures.FIRST_COMPLETED
                )
                for fut in done:
                    _record_page(
                        fut, extracted, on_progress, progress_ctx, progress_task, path, total
                    )
            inflight.add(pool.submit(_extract, i, png))
        # All pages submitted: drain whatever is still running.
        for fut in concurrent.futures.as_completed(inflight):
            _record_page(fut, extracted, on_progress, progress_ctx, progress_task, path, total)

    result: list[PageText] = []
    failed = 0
    # Pages complete out of order, so walk the indices sorted.
    for i in sorted(extracted):
        text = extracted[i]
        if text is None:
            failed += 1
        elif text.strip():
            result.append(PageText(i + 1, text))

    if failed:
        log.warning("Vision OCR: %d/%d pages failed", failed, total)
        if not quiet:
            from rich.console import Console

            Console(stderr=True).print(
                f"[yellow]Vision OCR: {failed}/{total} pages failed, "
                f"{len(result)}/{total} extracted[/yellow]"
            )

    return result