Coverage for src/lilbee/store.py: 100%

463 statements

coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""LanceDB vector store operations.""" 

2 

3from __future__ import annotations 

4 

5import logging 

6import math 

7import threading 

8from dataclasses import dataclass 

9from datetime import UTC, datetime, timedelta 

10from enum import StrEnum 

11from pathlib import Path 

12from typing import TYPE_CHECKING, TypedDict 

13 

14import numpy as np 

15import pyarrow as pa 

16import pyarrow.compute as pc 

17from pydantic import BaseModel, ConfigDict, Field, field_validator 

18 

19if TYPE_CHECKING: 

20 import lancedb 

21 import lancedb.table 

22 

23from lilbee.config import ( 

24 CHUNKS_TABLE, 

25 CITATIONS_TABLE, 

26 META_TABLE, 

27 SOURCES_TABLE, 

28 Config, 

29 cfg, 

30) 

31from lilbee.lock import write_lock 

32from lilbee.security import validate_path_within 

33 

34log = logging.getLogger(__name__) 

35 

36 

37def install_lancedb_thread_error_suppressor() -> None: 

38 """Install a ``threading.excepthook`` that swallows lancedb shutdown noise. 

39 lancedb has no ``close()`` API and its internal event loop thread crashes 

40 during Python interpreter teardown. The exception is harmless (the process 

41 is exiting anyway) but pollutes CLI/TUI output. This is opt-in so importing 

42 ``lilbee.store`` has no hidden side effects; call it once from the CLI/TUI 

43 bootstrap. 

44 """ 

45 original = threading.excepthook 

46 

47 def _hook(args: threading.ExceptHookArgs) -> None: 

48 if args.thread and "LanceDB" in args.thread.name: 

49 return 

50 original(args) 

51 

52 threading.excepthook = _hook 
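
# Opt-in example: a CLI entry point (hypothetical layout) would install the
# hook once, before any Store work starts:
#
#     def main() -> None:
#         install_lancedb_thread_error_suppressor()
#         ...  # build the Store, parse args, run the command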


# How often readers re-check the manifest for new versions from other processes.
# Zero means strong consistency (every read checks); higher values reduce disk I/O
# on slow media (HDD) at the cost of serving slightly stale data.
READ_CONSISTENCY_INTERVAL = timedelta(seconds=5)

# Values for the ``chunk_type`` column. Everything goes in as raw except wiki
# pages written by the wiki producer; callers filter with ``Store.search(chunk_type=...)``.
CHUNK_TYPE_RAW = "raw"
CHUNK_TYPE_WIKI = "wiki"


class SearchScope(StrEnum):
    """What the user wants to search over.

    Values are used as-is on CLI flags, MCP params, and HTTP query strings.
    ``BOTH`` resolves to a ``None`` ``chunk_type`` (no filter); the two
    others map 1:1 to the chunks-table values.
    """

    RAW = CHUNK_TYPE_RAW
    WIKI = CHUNK_TYPE_WIKI
    BOTH = "both"


def scope_to_chunk_type(scope: SearchScope | str | None) -> str | None:
    """Translate a user-facing scope into a ``Store.search`` ``chunk_type`` arg.

    ``None``/``"both"`` → no filter. ``"raw"`` / ``"wiki"`` → the matching
    chunks-table value. Raises ``ValueError`` on any other string.
    """
    if scope is None:
        return None
    normalized = SearchScope(scope)
    if normalized is SearchScope.BOTH:
        return None
    return normalized.value
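
# Behavior at a glance (doctest-style; follows directly from the code above):
#
#     >>> scope_to_chunk_type(None) is None
#     True
#     >>> scope_to_chunk_type(SearchScope.BOTH) is None
#     True
#     >>> scope_to_chunk_type("wiki")
#     'wiki'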


class SearchChunk(BaseModel):
    """A search result from LanceDB.

    Hybrid results have ``relevance_score`` set (higher = better).
    Vector-only results have ``distance`` set (lower = better).
    """

    model_config = ConfigDict(populate_by_name=True)

    source: str
    content_type: str
    chunk_type: str = CHUNK_TYPE_RAW

    @field_validator("chunk_type", mode="before")
    @classmethod
    def _coerce_none_chunk_type(cls, v: str | None) -> str:
        """LanceDB rows from before the chunk_type column was added return None."""
        return v if v is not None else CHUNK_TYPE_RAW

    page_start: int
    page_end: int
    line_start: int
    line_end: int
    chunk: str
    chunk_index: int
    vector: list[float] = Field(repr=False)
    distance: float | None = Field(None, alias="_distance")
    relevance_score: float | None = Field(None, alias="_relevance_score")


def cosine_sim(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b, strict=True))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)
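
# Quick sanity check (doctest-style):
#
#     >>> cosine_sim([1.0, 0.0], [1.0, 0.0])
#     1.0
#     >>> cosine_sim([1.0, 0.0], [0.0, 1.0])
#     0.0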


def mmr_rerank(
    query_vector: list[float],
    results: list[SearchChunk],
    top_k: int,
    mmr_lambda: float | None = None,
) -> list[SearchChunk]:
    """Maximal Marginal Relevance — select diverse results.

    Algorithm: Carbonell & Goldstein 1998,
    "The Use of MMR, Diversity-Based Reranking for Reordering Documents
    and Producing Summaries."

    ``mmr_lambda`` controls the relevance/diversity tradeoff:
    0.0 = maximum diversity, 1.0 = pure relevance.
    Defaults to ``cfg.mmr_lambda`` (0.5).

    Complexity: O(top_k · N · D) time, O(N · D) space for N candidates
    of dimension D. Each outer iteration updates a running max-redundancy
    vector via one matrix-vector product rather than recomputing pairwise
    similarities from scratch. Candidate vectors run through numpy in
    ``float32``, which can pick a different candidate than a pure-Python
    ``float64`` loop on ties within ~1e-7; distinct in principle, but
    unobservable in practice, since sub-float32 differences sit well below
    retrieval signal.
    """
    if mmr_lambda is None:
        mmr_lambda = cfg.mmr_lambda
    if len(results) <= top_k:
        return results

    candidate_vecs = np.asarray([r.vector for r in results], dtype=np.float32)
    query = np.asarray(query_vector, dtype=np.float32)
    # L2-normalize once so cosine becomes a plain dot product.
    cand_norms = np.linalg.norm(candidate_vecs, axis=1, keepdims=True)
    cand_norms[cand_norms == 0] = 1.0
    cand_unit = candidate_vecs / cand_norms
    query_norm = float(np.linalg.norm(query)) or 1.0
    query_unit = query / query_norm

    relevance = cand_unit @ query_unit  # shape (N,)

    n = len(results)
    max_redundancy = np.zeros(n, dtype=np.float32)
    available = np.ones(n, dtype=bool)
    selected: list[SearchChunk] = []

    for picks in range(top_k):
        redundancy_term = max_redundancy if picks > 0 else np.zeros(n, dtype=np.float32)
        score = mmr_lambda * relevance - (1.0 - mmr_lambda) * redundancy_term
        # Mask already-picked candidates so argmax skips them.
        score = np.where(available, score, -np.inf)
        best = int(np.argmax(score))
        selected.append(results[best])
        available[best] = False
        # Update running max redundancy against the newly-selected vector.
        similarity = cand_unit @ cand_unit[best]
        max_redundancy = np.maximum(max_redundancy, similarity)

    return selected
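
# Illustrative call (hypothetical 2-D vectors; real chunks carry full metadata).
# With mmr_lambda=0.5 the second pick is penalized for similarity to the first,
# so a less redundant candidate can beat a near-duplicate of the first pick
# despite lower raw relevance:
#
#     picked = mmr_rerank([1.0, 0.0], candidates, top_k=2)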


class SourceRecord(TypedDict):
    """A tracked source document record."""

    filename: str
    file_hash: str
    ingested_at: str
    chunk_count: int
    source_type: str


class CitationRecord(TypedDict):
    """A citation linking a wiki chunk to a specific source location."""

    wiki_source: str
    wiki_chunk_index: int
    citation_key: str
    claim_type: str
    source_filename: str
    source_hash: str
    page_start: int
    page_end: int
    line_start: int
    line_end: int
    excerpt: str
    created_at: str


class StoreMeta(TypedDict):
    """Single-row store metadata recording the embedding model used to build the store.

    Compatibility is checked before every read and write. When ``cfg.embedding_model``
    or ``cfg.embedding_dim`` drifts from the persisted row, the store refuses to serve
    until ``lilbee rebuild`` (CLI) or ``POST /api/sync {"force_rebuild": true}`` (HTTP)
    rewrites the chunks under the new model.
    """

    embedding_model: str
    embedding_dim: int
    schema_version: int
    updated_at: str


# ``schema_version`` is an integer for forward-compat. Bump only if we ever need to
# add or rename a meta column without forcing every store to drop_all.
META_SCHEMA_VERSION = 1

# Always-true predicate used to clear the single-row ``_meta`` table before re-insert.
# Lance's ``Table.delete`` requires a SQL where clause; this matches every row without
# coupling the deletion to any specific column's value domain.
_META_DELETE_ALL_PREDICATE = "schema_version IS NOT NULL"


class EmbeddingModelMismatchError(RuntimeError):
    """Raised when stored vectors were built with a different embedding model than ``cfg``.

    Carries a user-facing message naming both the persisted and the configured model and
    pointing at the two recovery paths (``lilbee rebuild`` and ``POST /api/sync`` with
    ``force_rebuild=true``).
    """


def _embedding_mismatch_message(
    persisted_model: str,
    persisted_dim: int,
    current_model: str,
    current_dim: int,
) -> str:
    return (
        f"The vector store was built with embedding model '{persisted_model}' "
        f"(dim {persisted_dim}), but lilbee is now configured to use "
        f"'{current_model}' (dim {current_dim}). Search and ingest are disabled "
        "until the store is rebuilt under the new model. "
        'Run `lilbee rebuild` or POST /api/sync with `{"force_rebuild": true}`.'
    )


def _meta_schema() -> pa.Schema:
    return pa.schema(
        [
            pa.field("embedding_model", pa.utf8()),
            pa.field("embedding_dim", pa.int32()),
            pa.field("schema_version", pa.int32()),
            pa.field("updated_at", pa.utf8()),
        ]
    )


def _sources_schema() -> pa.Schema:
    return pa.schema(
        [
            pa.field("filename", pa.utf8()),
            pa.field("file_hash", pa.utf8()),
            pa.field("ingested_at", pa.utf8()),
            pa.field("chunk_count", pa.int32()),
            pa.field("source_type", pa.utf8()),
        ]
    )


def _citations_schema() -> pa.Schema:
    return pa.schema(
        [
            pa.field("wiki_source", pa.utf8()),
            pa.field("wiki_chunk_index", pa.int32()),
            pa.field("citation_key", pa.utf8()),
            pa.field("claim_type", pa.utf8()),
            pa.field("source_filename", pa.utf8()),
            pa.field("source_hash", pa.utf8()),
            pa.field("page_start", pa.int32()),
            pa.field("page_end", pa.int32()),
            pa.field("line_start", pa.int32()),
            pa.field("line_end", pa.int32()),
            pa.field("excerpt", pa.utf8()),
            pa.field("created_at", pa.utf8()),
        ]
    )


def _table_names(db: lancedb.DBConnection) -> list[str]:
    """Get list of table names, handling the ListTablesResponse object."""
    result = db.list_tables()
    try:
        return result.tables  # type: ignore[no-any-return, union-attr]
    except AttributeError:
        return list(result)  # type: ignore[arg-type]


def ensure_table(db: lancedb.DBConnection, name: str, schema: pa.Schema) -> lancedb.table.Table:
    if name in _table_names(db):
        return db.open_table(name)
    try:
        return db.create_table(name, schema=schema)
    except ValueError:
        return db.open_table(name)


def _safe_delete_unlocked(table: lancedb.table.Table, predicate: str) -> None:
    """Delete rows matching predicate, logging on failure. Caller must hold write lock."""
    try:
        table.delete(predicate)
    except Exception:
        log.warning("Failed to delete rows matching: %s", predicate, exc_info=True)


def safe_delete(table: lancedb.table.Table, predicate: str) -> None:
    """Delete rows matching predicate, logging on failure."""
    with write_lock():
        _safe_delete_unlocked(table, predicate)


def escape_sql_string(value: str) -> str:
    """Escape backslashes and single quotes for SQL predicates."""
    return value.replace("\\", "\\\\").replace("'", "''")
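
# For example (doctest-style; follows directly from the two replacements):
#
#     >>> escape_sql_string("O'Brien")
#     "O''Brien"
#     >>> escape_sql_string("a\\b")
#     'a\\\\b'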


def _chunk_type_predicate(chunk_type: str) -> str:
    """SQL predicate that matches ``chunk_type`` while tolerating NULL rows.

    Rows written before ``chunk_type`` was populated land as NULL. They
    are semantically raw, so a ``'raw'`` filter still includes them; a
    ``'wiki'`` filter excludes them.
    """
    escaped = escape_sql_string(chunk_type)
    if chunk_type == CHUNK_TYPE_RAW:
        return f"(chunk_type = '{escaped}' OR chunk_type IS NULL)"
    return f"chunk_type = '{escaped}'"
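
# Concretely:
#
#     _chunk_type_predicate("raw")   # "(chunk_type = 'raw' OR chunk_type IS NULL)"
#     _chunk_type_predicate("wiki")  # "chunk_type = 'wiki'"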


def _hybrid_search(
    table: lancedb.table.Table,
    query_text: str,
    query_vector: list[float],
    top_k: int,
    chunk_type: str | None = None,
) -> list[SearchChunk]:
    """Run hybrid (vector + FTS) search with RRF reranking.

    When ``chunk_type`` is set, the predicate is pushed into the query so
    the limit applies *after* the type filter. Post-filtering would
    silently starve wiki-only queries whose matches live past the top-K
    hybrid window.
    """
    from lancedb.rerankers import RRFReranker

    query = (
        table.search(query_type="hybrid")
        .vector(query_vector)
        .text(query_text)
        .rerank(RRFReranker())
    )
    if chunk_type:
        query = query.where(_chunk_type_predicate(chunk_type))
    rows = query.limit(top_k).to_list()
    return [SearchChunk(**r) for r in rows]


@dataclass
class RemoveResult:
    """Result of a remove_documents operation."""

    removed: list[str]
    not_found: list[str]


_MAX_THRESHOLD = 1.0
_MAX_FILTER_ITERATIONS = 20  # safety cap to prevent runaway loops


def _get_distance(chunk: SearchChunk) -> float:
    """Extract distance as a sortable float (inf for None)."""
    return chunk.distance if chunk.distance is not None else float("inf")


def _count_within_threshold(sorted_results: list[SearchChunk], threshold: float) -> int:
    """Count results whose distance is within the given threshold."""
    for i, r in enumerate(sorted_results):
        if _get_distance(r) > threshold:
            return i
    return len(sorted_results)
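
# For distances sorted as [0.1, 0.3, 0.7], a threshold of 0.5 yields 2: the
# loop returns the first index whose distance exceeds the threshold, which
# equals the count of results within it.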


def _has_fts_index(table: lancedb.table.Table) -> bool:
    """Return True when an FTS index on the chunk column already exists."""
    try:
        for idx in table.list_indices():
            if idx.index_type == "FTS" and "chunk" in idx.columns:
                return True
    except Exception:
        return False
    return False


def _sources_search_filter(search: str | None) -> str | None:
    """Case-insensitive filename WHERE clause, or ``None`` for empty *search*."""
    if not search:
        return None
    escaped = escape_sql_string(search.lower())
    return f"LOWER(filename) LIKE '%{escaped}%'"
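
# e.g. _sources_search_filter("Report") -> "LOWER(filename) LIKE '%report%'",
# while _sources_search_filter("") and _sources_search_filter(None) -> None.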


class Store:
    """LanceDB vector store — wraps all DB operations with config-driven defaults."""

    def __init__(self, config: Config) -> None:
        self._config = config
        self._fts_ready: bool = False
        self._db: lancedb.DBConnection | None = None
        # Cache of {filename: ingested_at} rebuilt only when sources
        # mutate; callers (temporal filter) hit it per-query.
        self._source_ingested_cache: dict[str, str] | None = None

    def _invalidate_source_cache(self) -> None:
        """Drop the cached {filename: ingested_at} map."""
        self._source_ingested_cache = None

    def source_ingested_at_map(self) -> dict[str, str]:
        """Return {filename: ingested_at} for every source, cached until mutation.

        Best-effort: a reader racing a concurrent invalidation can store a
        pre-mutation snapshot. The only consumer (temporal query filter)
        treats a missing/stale entry as "do not filter," so staleness
        degrades ranking precision, never correctness.
        """
        if self._source_ingested_cache is not None:
            return self._source_ingested_cache
        mapping = {s["filename"]: s.get("ingested_at", "") for s in self.get_sources()}
        self._source_ingested_cache = mapping
        return mapping

    def _chunks_schema(self) -> pa.Schema:
        return pa.schema(
            [
                pa.field("source", pa.utf8()),
                pa.field("content_type", pa.utf8()),
                pa.field("chunk_type", pa.utf8()),
                pa.field("page_start", pa.int32()),
                pa.field("page_end", pa.int32()),
                pa.field("line_start", pa.int32()),
                pa.field("line_end", pa.int32()),
                pa.field("chunk", pa.utf8()),
                pa.field("chunk_index", pa.int32()),
                pa.field("vector", pa.list_(pa.float32(), self._config.embedding_dim)),
            ]
        )

    def get_meta(self) -> StoreMeta | None:
        """Return the persisted store metadata row, or ``None`` if unset."""
        table = self.open_table(META_TABLE)
        if table is None:
            return None
        rows = table.search().limit(1).to_list()
        if not rows:
            return None
        row = rows[0]
        return StoreMeta(
            embedding_model=row["embedding_model"],
            embedding_dim=int(row["embedding_dim"]),
            schema_version=int(row["schema_version"]),
            updated_at=row["updated_at"],
        )

    def _write_meta_unlocked(self, *, embedding_model: str, embedding_dim: int) -> None:
        """Overwrite the single ``_meta`` row with the supplied identity.

        Caller must hold ``write_lock()``. Args are passed explicitly rather than
        re-read from ``self._config`` so the caller can snapshot cfg at a coherent
        instant and not race with a concurrent ``set_embedding_model``.
        """
        db = self.get_db()
        table = ensure_table(db, META_TABLE, _meta_schema())
        _safe_delete_unlocked(table, _META_DELETE_ALL_PREDICATE)
        table.add(
            [
                {
                    "embedding_model": embedding_model,
                    "embedding_dim": embedding_dim,
                    "schema_version": META_SCHEMA_VERSION,
                    "updated_at": datetime.now(UTC).isoformat(),
                }
            ]
        )

    def _has_chunks(self) -> bool:
        """Return True when the chunks table exists and has at least one row."""
        chunks = self.open_table(CHUNKS_TABLE)
        return chunks is not None and chunks.count_rows() > 0

    def initialize_meta_if_legacy(self) -> bool:
        """Pin a legacy store's identity to the current cfg if not already set.

        Returns ``True`` when a meta row was just written. No-op when meta already
        exists or no chunks are present. This is the path that converts a
        pre-upgrade store (chunks present, no ``_meta``) into a gated store. It
        snapshots cfg under the write lock to keep the recorded identity coherent
        with what the gate is comparing against.
        """
        if self.get_meta() is not None:
            return False
        if not self._has_chunks():
            return False
        embedding_model = self._config.embedding_model
        embedding_dim = self._config.embedding_dim
        with write_lock():
            # Re-check under the lock so two callers do not both warn-and-write.
            if self.get_meta() is not None:
                return False
            log.warning(
                "Legacy store has chunks but no _meta row. Initializing _meta from "
                "the current configuration (embedding_model=%s, embedding_dim=%d). "
                "If you changed embedding_model before upgrading, run `lilbee rebuild` "
                "to ensure the store is consistent.",
                embedding_model,
                embedding_dim,
            )
            self._write_meta_unlocked(embedding_model=embedding_model, embedding_dim=embedding_dim)
            return True

    def _ensure_embedding_compat(self) -> None:
        """Raise when the persisted embedding identity drifts from cfg.

        Pure check, no side effects. Migration of legacy stores (chunks present,
        no ``_meta``) is the caller's responsibility via ``initialize_meta_if_legacy``
        so this method is safe to call from inside an existing ``write_lock()``
        (no recursive lock attempt). cfg fields are snapshotted at entry so the
        comparison is coherent even if another thread mutates them mid-call.
        """
        current_model = self._config.embedding_model
        current_dim = self._config.embedding_dim
        meta = self.get_meta()
        if meta is None:
            return
        if meta["embedding_model"] == current_model and meta["embedding_dim"] == current_dim:
            return
        raise EmbeddingModelMismatchError(
            _embedding_mismatch_message(
                persisted_model=meta["embedding_model"],
                persisted_dim=meta["embedding_dim"],
                current_model=current_model,
                current_dim=current_dim,
            )
        )

    def get_db(self) -> lancedb.DBConnection:
        if self._db is None:
            import lancedb as _lancedb

            self._config.lancedb_dir.mkdir(parents=True, exist_ok=True)
            self._db = _lancedb.connect(
                str(self._config.lancedb_dir),
                read_consistency_interval=READ_CONSISTENCY_INTERVAL,
            )
        return self._db

    def open_table(self, name: str) -> lancedb.table.Table | None:
        """Open a table if it exists, otherwise return None."""
        db = self.get_db()
        if name not in _table_names(db):
            return None
        return db.open_table(name)

    def ensure_fts_index(self) -> None:
        """Create the chunks FTS index, or run ``optimize()`` once it exists.

        ``optimize()`` folds newly added rows into the FTS index and also
        runs LanceDB's default compaction + version pruning (default prune
        window: 7 days). Work scales with recent deltas rather than total
        chunk count, so large corpora no longer pay the full
        ``create_fts_index(replace=True)`` rebuild cost on every sync.
        """
        with write_lock():
            table = self.open_table(CHUNKS_TABLE)
            if table is None:
                return
            try:
                if _has_fts_index(table):
                    table.optimize()
                    log.debug("FTS index optimized on '%s'", CHUNKS_TABLE)
                else:
                    table.create_fts_index("chunk", replace=False)
                    log.debug("FTS index created on '%s'", CHUNKS_TABLE)
                self._fts_ready = True
            except Exception:
                log.debug("FTS index ensure failed (empty table?)", exc_info=True)

    def add_chunks(self, records: list[dict]) -> int:
        """Add chunk records to the store. Returns count added.

        Raises ``EmbeddingModelMismatchError`` if the persisted ``_meta`` row was
        written under a different embedding model than the current ``cfg``. On the
        first write to a fresh store, ``_meta`` is initialized from the current cfg.

        The gate runs inside the write lock and uses a single cfg snapshot so a
        concurrent ``set_embedding_model`` cannot slip a write in past a stale
        compatibility check.
        """
        with write_lock():
            embedding_model = self._config.embedding_model
            embedding_dim = self._config.embedding_dim
            self._ensure_embedding_compat()
            self._fts_ready = False
            if not records:
                return 0
            for rec in records:
                vec = rec.get("vector", [])
                if len(vec) != embedding_dim:
                    raise ValueError(
                        f"Vector dimension mismatch: expected {embedding_dim}, "
                        f"got {len(vec)} (source={rec.get('source', '?')})"
                    )
            db = self.get_db()
            table = ensure_table(db, CHUNKS_TABLE, self._chunks_schema())
            table.add(records)
            if self.get_meta() is None:
                self._write_meta_unlocked(
                    embedding_model=embedding_model, embedding_dim=embedding_dim
                )
            return len(records)
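
    # A chunk record mirrors _chunks_schema; the vector length must equal
    # cfg.embedding_dim. Hypothetical example:
    #
    #     {
    #         "source": "notes.pdf", "content_type": "application/pdf",
    #         "chunk_type": CHUNK_TYPE_RAW, "page_start": 1, "page_end": 1,
    #         "line_start": 1, "line_end": 12, "chunk": "...", "chunk_index": 0,
    #         "vector": [0.0] * cfg.embedding_dim,
    #     }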

    def bm25_probe(self, query_text: str, top_k: int = 5) -> list[SearchChunk]:
        """Quick BM25-only search for confidence checking. Returns up to top_k results."""
        table = self.open_table(CHUNKS_TABLE)
        if table is None:
            return []
        if not self._fts_ready:
            self.ensure_fts_index()
            if not self._fts_ready:
                return []  # pragma: no cover
        try:
            rows = table.search(query_text, query_type="fts").limit(top_k).to_list()
            return [SearchChunk(**r) for r in rows]
        except Exception:
            log.debug("BM25 probe failed", exc_info=True)
            return []

665 def search( 

666 self, 

667 query_vector: list[float], 

668 top_k: int | None = None, 

669 max_distance: float | None = None, 

670 query_text: str | None = None, 

671 chunk_type: str | None = None, 

672 ) -> list[SearchChunk]: 

673 """Search for similar chunks. Hybrid when FTS available, else vector-only. 

674 

675 Results with distance > max_distance are filtered out (vector-only path). 

676 Pass max_distance=0 to disable filtering. 

677 When *chunk_type* is set, only chunks of that type ("raw" or "wiki") are returned. 

678 

679 Raises ``EmbeddingModelMismatchError`` if the persisted ``_meta`` row was 

680 written under a different embedding model than the current ``cfg``. 

681 """ 

682 if top_k is None: 

683 top_k = self._config.top_k 

684 if max_distance is None: 

685 max_distance = self._config.max_distance 

686 table = self.open_table(CHUNKS_TABLE) 

687 if table is None: 

688 return [] 

689 self.initialize_meta_if_legacy() 

690 self._ensure_embedding_compat() 

691 

692 if query_text and not self._fts_ready: 

693 self.ensure_fts_index() 

694 

695 if query_text and self._fts_ready: 

696 try: 

697 return _hybrid_search(table, query_text, query_vector, top_k, chunk_type) 

698 except Exception: 

699 log.debug("Hybrid search failed, falling back to vector-only", exc_info=True) 

700 

701 candidate_k = top_k * self._config.candidate_multiplier 

702 query = table.search(query_vector).metric("cosine").limit(candidate_k) 

703 if chunk_type: 

704 query = query.where(_chunk_type_predicate(chunk_type)) 

705 rows = query.to_list() 

706 log.debug( 

707 "Vector search: query=%r, candidates=%d, max_distance=%.2f", 

708 query_text or "vector-only", 

709 len(rows), 

710 max_distance, 

711 ) 

712 if rows: 

713 distances = [r.get("distance", 0) for r in rows[:5]] 

714 log.debug("Top 5 distances: %s", distances) 

715 results = [SearchChunk(**r) for r in rows] 

716 if max_distance > 0: 

717 before = len(results) 

718 if self._config.adaptive_threshold: 

719 results = self._adaptive_filter(results, top_k, max_distance) 

720 log.debug( 

721 "After adaptive filter: %d/%d results, threshold=%.2f", 

722 len(results), 

723 before, 

724 max_distance, 

725 ) 

726 else: 

727 results = self._fixed_filter(results, max_distance) 

728 log.debug( 

729 "After fixed filter: %d/%d results, threshold=%.2f", 

730 len(results), 

731 before, 

732 max_distance, 

733 ) 

734 if len(results) > top_k: 

735 results = mmr_rerank(query_vector, results, top_k, self._config.mmr_lambda) 

736 return results 
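
    # Sketch of a typical call; the embedding step lives outside this module
    # and ``embed_query`` is a hypothetical stand-in:
    #
    #     store = Store(cfg)
    #     vec = embed_query("how do I rotate the API key?")
    #     hits = store.search(vec, query_text="rotate the API key",
    #                         chunk_type=scope_to_chunk_type(SearchScope.WIKI))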

    def _adaptive_filter(
        self, results: list[SearchChunk], top_k: int, initial_threshold: float
    ) -> list[SearchChunk]:
        """Widen the cosine distance threshold when too few results pass.

        Inspired by grantflow's (grantflow-ai/grantflow) adaptive retrieval
        pattern, which widens thresholds on recursive retry. The step size is
        ``self._config.adaptive_threshold_step`` (default 0.2); the cap is the
        larger of *initial_threshold* and ``_MAX_THRESHOLD``.

        Pre-sorts results by distance for a single-pass cutoff search.
        """
        cap = max(initial_threshold, _MAX_THRESHOLD)
        step = self._config.adaptive_threshold_step

        sorted_results = sorted(results, key=_get_distance)

        threshold = initial_threshold
        for _ in range(_MAX_FILTER_ITERATIONS):
            if threshold > cap:
                break
            cutoff = _count_within_threshold(sorted_results, threshold)
            if cutoff >= top_k:
                return sorted_results[:cutoff]
            threshold += step
        # Final pass at cap
        cutoff = _count_within_threshold(sorted_results, cap)
        return sorted_results[:cutoff]
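
    # Worked example (hypothetical numbers): with initial_threshold=0.4,
    # step=0.2, top_k=5 and sorted distances [0.30, 0.35, 0.55, 0.60, 0.62],
    # the 0.4 pass keeps 2 results, the 0.6 pass keeps 4, and the 0.8 pass
    # keeps all 5 and returns them.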

    def _fixed_filter(self, results: list[SearchChunk], threshold: float) -> list[SearchChunk]:
        """Fixed-threshold filter: keep only results within the distance threshold."""
        return [r for r in results if _get_distance(r) <= threshold]

    def get_chunks_by_source(self, source: str) -> list[SearchChunk]:
        """Return every chunk whose ``source`` equals *source*."""
        table = self.open_table(CHUNKS_TABLE)
        if table is None:
            return []
        escaped = escape_sql_string(source)
        try:
            rows = table.search().where(f"source = '{escaped}'").limit(None).to_list()
        except Exception:
            # FTS-enabled tables return a query builder that cannot
            # handle .where() on arbitrary columns; fall through to a
            # pyarrow.compute filter on the Arrow table so the source
            # match runs in C++ without materializing non-matching rows.
            log.debug("get_chunks_by_source search() failed, using Arrow fallback", exc_info=True)
            arrow_tbl = table.to_arrow()
            filtered = arrow_tbl.filter(pc.equal(arrow_tbl["source"], source))
            rows = filtered.to_pylist()
        return [SearchChunk(**r) for r in rows]

    def delete_by_source(self, source: str) -> None:
        """Delete all chunks from a given source file."""
        with write_lock():
            table = self.open_table(CHUNKS_TABLE)
            if table is not None:
                _safe_delete_unlocked(table, f"source = '{escape_sql_string(source)}'")
            self._invalidate_source_cache()

    def get_sources(
        self,
        *,
        search: str | None = None,
        limit: int | None = None,
        offset: int = 0,
    ) -> list[SourceRecord]:
        """Return source records, filtered by *search* and sliced by offset/limit."""
        table = self.open_table(SOURCES_TABLE)
        if table is None:
            return []
        query = table.search()
        where = _sources_search_filter(search)
        if where is not None:
            query = query.where(where)
        if offset:
            query = query.offset(offset)
        query = query.limit(limit)
        result: list[SourceRecord] = query.to_list()  # type: ignore[assignment]
        return result

    def count_sources(self, *, search: str | None = None) -> int:
        """Count tracked sources matching *search* without materializing rows."""
        table = self.open_table(SOURCES_TABLE)
        if table is None:
            return 0
        where = _sources_search_filter(search)
        count: int = table.count_rows() if where is None else table.count_rows(filter=where)
        return count

    def upsert_source(
        self,
        filename: str,
        file_hash: str,
        chunk_count: int,
        source_type: str = "document",
    ) -> None:
        """Add or update a source file tracking record."""
        with write_lock():
            db = self.get_db()
            table = ensure_table(db, SOURCES_TABLE, _sources_schema())
            _safe_delete_unlocked(table, f"filename = '{escape_sql_string(filename)}'")
            table.add(
                [
                    {
                        "filename": filename,
                        "file_hash": file_hash,
                        "ingested_at": datetime.now(UTC).isoformat(),
                        "chunk_count": chunk_count,
                        "source_type": source_type,
                    }
                ]
            )
            self._invalidate_source_cache()

    def delete_source(self, filename: str) -> None:
        """Remove a source file tracking record."""
        with write_lock():
            table = self.open_table(SOURCES_TABLE)
            if table is not None:
                _safe_delete_unlocked(table, f"filename = '{escape_sql_string(filename)}'")
            self._invalidate_source_cache()

    def remove_documents(
        self,
        names: list[str],
        *,
        delete_files: bool = False,
        documents_dir: Path | None = None,
    ) -> RemoveResult:
        """Remove documents from the knowledge base by source name.

        Looks up known sources, deletes chunks and source records for each.
        If *delete_files* is True, resolves the path and verifies it is
        contained within *documents_dir* before unlinking (path traversal guard).

        Returns a RemoveResult with removed and not_found lists.
        """
        if documents_dir is None:
            documents_dir = self._config.documents_dir

        known = {s["filename"] for s in self.get_sources()}
        removed: list[str] = []
        not_found: list[str] = []

        for name in names:
            if name not in known:
                not_found.append(name)
                continue
            self.delete_by_source(name)
            self.delete_source(name)
            removed.append(name)
            if delete_files:
                try:
                    path = validate_path_within(documents_dir / name, documents_dir)
                except ValueError:
                    log.warning("Path traversal blocked: %s escapes %s", name, documents_dir)
                    continue
                if path.exists():
                    path.unlink()

        return RemoveResult(removed=removed, not_found=not_found)
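
    # e.g. store.remove_documents(["old.pdf", "ghost.md"]) ->
    # RemoveResult(removed=["old.pdf"], not_found=["ghost.md"]) when only
    # "old.pdf" is tracked. File deletion never follows a path that resolves
    # outside documents_dir.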

    def clear_table(self, name: str, predicate: str) -> None:
        """Delete rows matching *predicate* from *name*. Acquires write lock."""
        with write_lock():
            table = self.open_table(name)
            if table is not None:
                _safe_delete_unlocked(table, predicate)

    def add_citations(self, records: list[CitationRecord]) -> int:
        """Add citation records to the store. Returns count added."""
        if not records:
            return 0
        with write_lock():
            db = self.get_db()
            table = ensure_table(db, CITATIONS_TABLE, _citations_schema())
            table.add(records)
            return len(records)

    def get_citations_for_wiki(self, wiki_source: str) -> list[CitationRecord]:
        """Get all citations for a wiki page."""
        table = self.open_table(CITATIONS_TABLE)
        if table is None:
            return []
        escaped = escape_sql_string(wiki_source)
        rows: list[CitationRecord] = table.search().where(f"wiki_source = '{escaped}'").to_list()
        return rows

    def get_citations_for_source(self, source_filename: str) -> list[CitationRecord]:
        """Get all citations that reference a source document (reverse lookup)."""
        table = self.open_table(CITATIONS_TABLE)
        if table is None:
            return []
        escaped = escape_sql_string(source_filename)
        rows: list[CitationRecord] = (
            table.search().where(f"source_filename = '{escaped}'").to_list()
        )
        return rows

    def delete_citations_for_wiki(self, wiki_source: str) -> None:
        """Delete all citations for a wiki page (used before regeneration)."""
        self.clear_table(
            CITATIONS_TABLE,
            f"wiki_source = '{escape_sql_string(wiki_source)}'",
        )

    def close(self) -> None:
        """Release the database connection and reset state."""
        self._db = None
        self._fts_ready = False

    def drop_all(self) -> None:
        """Drop all tables -- used by rebuild."""
        with write_lock():
            self._fts_ready = False
            db = self.get_db()
            for name in _table_names(db):
                db.drop_table(name)
            self._invalidate_source_cache()