Coverage for src / lilbee / vision.py: 100%

75 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-16 08:27 +0000

1"""Vision model OCR extraction for scanned PDFs. 

2 

3Rasterizes PDF pages to PNG, sends each to a local vision model 

4via Ollama, and concatenates the extracted text. 

5""" 

6 

7import contextlib 

8import io 

9import logging 

10from collections.abc import Iterator 

11from contextlib import AbstractContextManager 

12from pathlib import Path 

13from typing import Any 

14 

15from lilbee.progress import DetailedProgressCallback, EventType, noop_callback 

16 

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Instruction text sent to the vision model together with each page image.
_OCR_PROMPT = (
    "Extract ALL text from this page as clean markdown. "
    "Preserve table structure using markdown table syntax. "
    "Include all rows, columns, headers, and page text exactly as shown."
)

# Target resolution (dots per inch) used when rasterizing PDF pages;
# it is divided by 72 to obtain the render scale factor.
_RASTER_DPI = 150

26 

27 

def pdf_page_count(path: Path) -> int:
    """Count the pages of the PDF at *path* without rendering any of them.

    The document is opened with pypdfium2 and is always closed again before
    this function returns, including when reading the page count raises.
    """
    import pypdfium2 as pdfium  # lazy: heavy dependency

    document = pdfium.PdfDocument(path)
    # closing() calls document.close() on exit, on success and on error alike.
    with contextlib.closing(document):
        return len(document)

37 

38 

def rasterize_pdf(path: Path) -> Iterator[tuple[int, bytes]]:
    """Yield ``(0-based index, PNG bytes)`` for each page of a PDF.

    Pages are rendered lazily, one per iteration, at a scale of
    ``_RASTER_DPI / 72``. Each page handle is released before its PNG bytes
    are yielded, and the document itself is closed once the generator
    finishes, is closed early, or an error propagates out of it.
    """
    import pypdfium2 as pdfium  # lazy: heavy dependency

    pdf = pdfium.PdfDocument(path)
    try:
        # Render scale is expressed relative to 72 units per inch.
        scale = _RASTER_DPI / 72
        for i in range(len(pdf)):
            page = pdf[i]
            try:
                bitmap = page.render(scale=scale)
                pil_image = bitmap.to_pil()
                buf = io.BytesIO()
                pil_image.save(buf, format="PNG")
            finally:
                # Release the page handle even when rendering or PNG
                # encoding fails, not only on the success path.
                page.close()
            yield (i, buf.getvalue())
    finally:
        pdf.close()

56 

57 

def extract_page_text(png_bytes: bytes, model: str, *, timeout: float | None = None) -> str | None:
    """Run OCR on a single page image through an Ollama vision model.

    The image is sent with the module's OCR prompt as one user message.
    When *timeout* is a positive number a dedicated ``ollama.Client`` with
    that timeout is used; otherwise the module-level ``ollama.chat`` is used.

    Returns the model's reply as a string (empty if the reply has no
    content), or ``None`` when the request fails for any reason. Failures
    are logged and never raised, so one bad page does not abort a document.
    """
    import ollama  # lazy: heavy dependency

    try:
        payload = [{"role": "user", "content": _OCR_PROMPT, "images": [png_bytes]}]
        use_custom_client = timeout is not None and timeout > 0
        chat = ollama.Client(timeout=timeout).chat if use_custom_client else ollama.chat
        reply = chat(model=model, messages=payload)
        content = reply.message.content or ""
        return str(content)
    except Exception as exc:
        # Best-effort: report the failure and let the caller skip this page.
        log.warning("Vision OCR: page skipped (%s: %s)", type(exc).__name__, exc)
        log.debug("Vision OCR traceback for model %s", model, exc_info=True)
        return None

74 

75 

def _make_progress(name: str, total: int, quiet: bool) -> tuple[AbstractContextManager[Any], Any]:
    """Build the optional Rich progress bar for an OCR run.

    Returns a ``(context_manager, task_id)`` pair. In quiet mode the context
    manager is a no-op ``nullcontext`` and the task id is ``None``; otherwise
    it is a transient Rich ``Progress`` with one task labelled
    ``"Vision OCR <name>"`` sized to *total*.
    """
    if not quiet:
        from rich.progress import (  # lazy: heavy dependency
            BarColumn,
            MofNCompleteColumn,
            Progress,
            TextColumn,
            TimeElapsedColumn,
        )

        columns = (
            TextColumn("{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
        )
        bar = Progress(*columns, transient=True)
        task_id = bar.add_task(f"Vision OCR {name}", total=total)
        return bar, task_id

    return contextlib.nullcontext(), None

97 

98 

def extract_pdf_vision(
    path: Path,
    model: str,
    *,
    quiet: bool = False,
    timeout: float | None = None,
    on_progress: DetailedProgressCallback = noop_callback,
) -> list[tuple[int, str]]:
    """OCR every page of a PDF with a vision model and collect the text.

    Each page is rasterized, passed to ``extract_page_text`` and, if the
    reply contains non-whitespace text, recorded under its 1-based page
    number. Before each page is processed *on_progress* is called with
    ``EventType.EXTRACT`` and a dict holding the file name, page number and
    total page count. Unless *quiet* is set, a Rich progress bar is shown
    and a summary is printed to stderr when any page failed.

    Returns a list of ``(page_number, text)`` tuples; pages that failed or
    produced only whitespace are omitted. An empty PDF yields ``[]``.
    """
    page_total = pdf_page_count(path)
    if not page_total:
        return []

    pages: list[tuple[int, str]] = []
    failures = 0
    bar, task_id = _make_progress(path.name, page_total, quiet)
    show_bar = task_id is not None

    with bar:
        for index, png in rasterize_pdf(path):
            page_no = index + 1
            event = {"file": path.name, "page": page_no, "total_pages": page_total}
            on_progress(EventType.EXTRACT, event)
            log.debug("Vision OCR page %d/%d with %s", page_no, page_total, model)

            text = extract_page_text(png, model, timeout=timeout)
            if text is None:
                # None signals a failed request, as opposed to an empty reply.
                failures += 1
            elif text.strip():
                pages.append((page_no, text))

            if show_bar:
                bar.advance(task_id)  # type: ignore[attr-defined]

    if not failures:
        return pages

    log.warning("Vision OCR: %d/%d pages failed", failures, page_total)
    if not quiet:
        from rich.console import Console

        Console(stderr=True).print(
            f"[yellow]Vision OCR: {failures}/{page_total} pages failed, "
            f"{len(pages)}/{page_total} extracted[/yellow]"
        )
    return pages