Coverage for src / lilbee / store.py: 100%
463 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-29 19:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-29 19:16 +0000
1"""LanceDB vector store operations."""
3from __future__ import annotations
5import logging
6import math
7import threading
8from dataclasses import dataclass
9from datetime import UTC, datetime, timedelta
10from enum import StrEnum
11from pathlib import Path
12from typing import TYPE_CHECKING, TypedDict
14import numpy as np
15import pyarrow as pa
16import pyarrow.compute as pc
17from pydantic import BaseModel, ConfigDict, Field, field_validator
19if TYPE_CHECKING:
20 import lancedb
21 import lancedb.table
23from lilbee.config import (
24 CHUNKS_TABLE,
25 CITATIONS_TABLE,
26 META_TABLE,
27 SOURCES_TABLE,
28 Config,
29 cfg,
30)
31from lilbee.lock import write_lock
32from lilbee.security import validate_path_within
34log = logging.getLogger(__name__)
37def install_lancedb_thread_error_suppressor() -> None:
38 """Install a ``threading.excepthook`` that swallows lancedb shutdown noise.
39 lancedb has no ``close()`` API and its internal event loop thread crashes
40 during Python interpreter teardown. The exception is harmless (the process
41 is exiting anyway) but pollutes CLI/TUI output. This is opt-in so importing
42 ``lilbee.store`` has no hidden side effects; call it once from the CLI/TUI
43 bootstrap.
44 """
45 original = threading.excepthook
47 def _hook(args: threading.ExceptHookArgs) -> None:
48 if args.thread and "LanceDB" in args.thread.name:
49 return
50 original(args)
52 threading.excepthook = _hook
55# How often readers re-check the manifest for new versions from other processes.
56# Zero means strong consistency (every read checks); higher values reduce disk I/O
57# on slow media (HDD) at the cost of serving slightly stale data.
58READ_CONSISTENCY_INTERVAL = timedelta(seconds=5)
60# Values for the ``chunk_type`` column. Everything goes in as raw except wiki
61# pages written by the wiki producer; callers filter with ``Store.search(chunk_type=...)``.
62CHUNK_TYPE_RAW = "raw"
63CHUNK_TYPE_WIKI = "wiki"
66class SearchScope(StrEnum):
67 """What the user wants to search over.
69 Values are used as-is on CLI flags, MCP params, and HTTP query strings.
70 ``BOTH`` resolves to a ``None`` ``chunk_type`` (no filter); the two
71 others map 1:1 to the chunks-table values.
72 """
74 RAW = CHUNK_TYPE_RAW
75 WIKI = CHUNK_TYPE_WIKI
76 BOTH = "both"
79def scope_to_chunk_type(scope: SearchScope | str | None) -> str | None:
80 """Translate a user-facing scope into a ``Store.search`` ``chunk_type`` arg.
82 ``None``/``"both"`` → no filter. ``"raw"`` / ``"wiki"`` → the matching
83 chunks-table value. Raises ``ValueError`` on any other string.
84 """
85 if scope is None:
86 return None
87 normalized = SearchScope(scope)
88 if normalized is SearchScope.BOTH:
89 return None
90 return normalized.value
93class SearchChunk(BaseModel):
94 """A search result from LanceDB.
95 Hybrid results have ``relevance_score`` set (higher = better).
96 Vector-only results have ``distance`` set (lower = better).
97 """
99 model_config = ConfigDict(populate_by_name=True)
101 source: str
102 content_type: str
103 chunk_type: str = CHUNK_TYPE_RAW
105 @field_validator("chunk_type", mode="before")
106 @classmethod
107 def _coerce_none_chunk_type(cls, v: str | None) -> str:
108 """LanceDB rows from before the chunk_type column was added return None."""
109 return v if v is not None else CHUNK_TYPE_RAW
111 page_start: int
112 page_end: int
113 line_start: int
114 line_end: int
115 chunk: str
116 chunk_index: int
117 vector: list[float] = Field(repr=False)
118 distance: float | None = Field(None, alias="_distance")
119 relevance_score: float | None = Field(None, alias="_relevance_score")
122def cosine_sim(a: list[float], b: list[float]) -> float:
123 """Cosine similarity between two vectors."""
124 dot = sum(x * y for x, y in zip(a, b, strict=True))
125 norm_a = math.sqrt(sum(x * x for x in a))
126 norm_b = math.sqrt(sum(x * x for x in b))
127 if norm_a == 0 or norm_b == 0:
128 return 0.0
129 return dot / (norm_a * norm_b)
132def mmr_rerank(
133 query_vector: list[float],
134 results: list[SearchChunk],
135 top_k: int,
136 mmr_lambda: float | None = None,
137) -> list[SearchChunk]:
138 """Maximal Marginal Relevance — select diverse results.
139 Algorithm: Carbonell & Goldstein 1998,
140 "The Use of MMR, Diversity-Based Reranking for Reordering Documents
141 and Producing Summaries."
143 ``mmr_lambda`` controls the relevance/diversity tradeoff:
144 0.0 = maximum diversity, 1.0 = pure relevance.
145 Defaults to ``cfg.mmr_lambda`` (0.5).
147 Complexity: O(top_k · N · D) time, O(N · D) space for N candidates
148 of dimension D. Each outer iteration updates a running max-redundancy
149 vector via one matmul rather than recomputing pairs pairwise.
150 Candidate vectors run through numpy in ``float32``, which can pick a
151 different candidate than the pure-Python ``float64`` loop on
152 ties within ~1e-7; distinct in principle, unobservable in practice
153 since sub-float32 differences are below retrieval signal.
154 """
155 if mmr_lambda is None:
156 mmr_lambda = cfg.mmr_lambda
157 if len(results) <= top_k:
158 return results
160 candidate_vecs = np.asarray([r.vector for r in results], dtype=np.float32)
161 query = np.asarray(query_vector, dtype=np.float32)
162 # L2-normalize once so cosine becomes a plain dot product.
163 cand_norms = np.linalg.norm(candidate_vecs, axis=1, keepdims=True)
164 cand_norms[cand_norms == 0] = 1.0
165 cand_unit = candidate_vecs / cand_norms
166 query_norm = float(np.linalg.norm(query)) or 1.0
167 query_unit = query / query_norm
169 relevance = cand_unit @ query_unit # shape (N,)
171 n = len(results)
172 max_redundancy = np.zeros(n, dtype=np.float32)
173 available = np.ones(n, dtype=bool)
174 selected: list[SearchChunk] = []
176 for picks in range(top_k):
177 redundancy_term = max_redundancy if picks > 0 else np.zeros(n, dtype=np.float32)
178 score = mmr_lambda * relevance - (1.0 - mmr_lambda) * redundancy_term
179 # Mask already-picked candidates so argmax skips them.
180 score = np.where(available, score, -np.inf)
181 best = int(np.argmax(score))
182 selected.append(results[best])
183 available[best] = False
184 # Update running max redundancy against the newly-selected vector.
185 similarity = cand_unit @ cand_unit[best]
186 max_redundancy = np.maximum(max_redundancy, similarity)
188 return selected
191class SourceRecord(TypedDict):
192 """A tracked source document record."""
194 filename: str
195 file_hash: str
196 ingested_at: str
197 chunk_count: int
198 source_type: str
201class CitationRecord(TypedDict):
202 """A citation linking a wiki chunk to a specific source location."""
204 wiki_source: str
205 wiki_chunk_index: int
206 citation_key: str
207 claim_type: str
208 source_filename: str
209 source_hash: str
210 page_start: int
211 page_end: int
212 line_start: int
213 line_end: int
214 excerpt: str
215 created_at: str
218class StoreMeta(TypedDict):
219 """Single-row store metadata recording the embedding model used to build the store.
221 Compatibility is checked before every read and write. When ``cfg.embedding_model``
222 or ``cfg.embedding_dim`` drifts from the persisted row, the store refuses to serve
223 until ``lilbee rebuild`` (CLI) or ``POST /api/sync {"force_rebuild": true}`` (HTTP)
224 rewrites the chunks under the new model.
225 """
227 embedding_model: str
228 embedding_dim: int
229 schema_version: int
230 updated_at: str
233# ``schema_version`` is an integer for forward-compat. Bump only if we ever need to
234# add or rename a meta column without forcing every store to drop_all.
235META_SCHEMA_VERSION = 1
237# Always-true predicate used to clear the single-row ``_meta`` table before re-insert.
238# Lance's ``Table.delete`` requires a SQL where clause; this matches every row without
239# coupling the deletion to any specific column's value domain.
240_META_DELETE_ALL_PREDICATE = "schema_version IS NOT NULL"
243class EmbeddingModelMismatchError(RuntimeError):
244 """Raised when stored vectors were built with a different embedding model than ``cfg``.
246 Carries a user-facing message naming both the persisted and the configured model and
247 pointing at the two recovery paths (``lilbee rebuild`` and ``POST /api/sync`` with
248 ``force_rebuild=true``).
249 """
252def _embedding_mismatch_message(
253 persisted_model: str,
254 persisted_dim: int,
255 current_model: str,
256 current_dim: int,
257) -> str:
258 return (
259 f"The vector store was built with embedding model '{persisted_model}' "
260 f"(dim {persisted_dim}), but lilbee is now configured to use "
261 f"'{current_model}' (dim {current_dim}). Search and ingest are disabled "
262 "until the store is rebuilt under the new model. "
263 'Run `lilbee rebuild` or POST /api/sync with `{"force_rebuild": true}`.'
264 )
267def _meta_schema() -> pa.Schema:
268 return pa.schema(
269 [
270 pa.field("embedding_model", pa.utf8()),
271 pa.field("embedding_dim", pa.int32()),
272 pa.field("schema_version", pa.int32()),
273 pa.field("updated_at", pa.utf8()),
274 ]
275 )
278def _sources_schema() -> pa.Schema:
279 return pa.schema(
280 [
281 pa.field("filename", pa.utf8()),
282 pa.field("file_hash", pa.utf8()),
283 pa.field("ingested_at", pa.utf8()),
284 pa.field("chunk_count", pa.int32()),
285 pa.field("source_type", pa.utf8()),
286 ]
287 )
290def _citations_schema() -> pa.Schema:
291 return pa.schema(
292 [
293 pa.field("wiki_source", pa.utf8()),
294 pa.field("wiki_chunk_index", pa.int32()),
295 pa.field("citation_key", pa.utf8()),
296 pa.field("claim_type", pa.utf8()),
297 pa.field("source_filename", pa.utf8()),
298 pa.field("source_hash", pa.utf8()),
299 pa.field("page_start", pa.int32()),
300 pa.field("page_end", pa.int32()),
301 pa.field("line_start", pa.int32()),
302 pa.field("line_end", pa.int32()),
303 pa.field("excerpt", pa.utf8()),
304 pa.field("created_at", pa.utf8()),
305 ]
306 )
309def _table_names(db: lancedb.DBConnection) -> list[str]:
310 """Get list of table names, handling the ListTablesResponse object."""
311 result = db.list_tables()
312 try:
313 return result.tables # type: ignore[no-any-return, union-attr]
314 except AttributeError:
315 return list(result) # type: ignore[arg-type]
318def ensure_table(db: lancedb.DBConnection, name: str, schema: pa.Schema) -> lancedb.table.Table:
319 if name in _table_names(db):
320 return db.open_table(name)
321 try:
322 return db.create_table(name, schema=schema)
323 except ValueError:
324 return db.open_table(name)
327def _safe_delete_unlocked(table: lancedb.table.Table, predicate: str) -> None:
328 """Delete rows matching predicate, logging on failure. Caller must hold write lock."""
329 try:
330 table.delete(predicate)
331 except Exception:
332 log.warning("Failed to delete rows matching: %s", predicate, exc_info=True)
335def safe_delete(table: lancedb.table.Table, predicate: str) -> None:
336 """Delete rows matching predicate, logging on failure."""
337 with write_lock():
338 _safe_delete_unlocked(table, predicate)
341def escape_sql_string(value: str) -> str:
342 """Escape single quotes for SQL predicates."""
343 return value.replace("\\", "\\\\").replace("'", "''")
346def _chunk_type_predicate(chunk_type: str) -> str:
347 """SQL predicate that matches ``chunk_type`` while tolerating NULL rows.
349 Rows written before ``chunk_type`` was populated land as NULL. They
350 are semantically raw, so a ``'raw'`` filter still includes them; a
351 ``'wiki'`` filter excludes them.
352 """
353 escaped = escape_sql_string(chunk_type)
354 if chunk_type == CHUNK_TYPE_RAW:
355 return f"(chunk_type = '{escaped}' OR chunk_type IS NULL)"
356 return f"chunk_type = '{escaped}'"
359def _hybrid_search(
360 table: lancedb.table.Table,
361 query_text: str,
362 query_vector: list[float],
363 top_k: int,
364 chunk_type: str | None = None,
365) -> list[SearchChunk]:
366 """Run hybrid (vector + FTS) search with RRF reranking.
368 When ``chunk_type`` is set, the predicate is pushed into the query so
369 the limit applies *after* the type filter. Post-filtering would
370 silently starve wiki-only queries whose matches live past the top-K
371 hybrid window.
372 """
373 from lancedb.rerankers import RRFReranker
375 query = (
376 table.search(query_type="hybrid")
377 .vector(query_vector)
378 .text(query_text)
379 .rerank(RRFReranker())
380 )
381 if chunk_type:
382 query = query.where(_chunk_type_predicate(chunk_type))
383 rows = query.limit(top_k).to_list()
384 return [SearchChunk(**r) for r in rows]
387@dataclass
388class RemoveResult:
389 """Result of a remove_documents operation."""
391 removed: list[str]
392 not_found: list[str]
395_MAX_THRESHOLD = 1.0
396_MAX_FILTER_ITERATIONS = 20 # safety cap to prevent runaway loops
399def _get_distance(chunk: SearchChunk) -> float:
400 """Extract distance as a sortable float (inf for None)."""
401 return chunk.distance if chunk.distance is not None else float("inf")
404def _count_within_threshold(sorted_results: list[SearchChunk], threshold: float) -> int:
405 """Count results whose distance is within the given threshold."""
406 for i, r in enumerate(sorted_results):
407 if _get_distance(r) > threshold:
408 return i
409 return len(sorted_results)
412def _has_fts_index(table: lancedb.table.Table) -> bool:
413 """Return True when an FTS index on the chunk column already exists."""
414 try:
415 for idx in table.list_indices():
416 if idx.index_type == "FTS" and "chunk" in idx.columns:
417 return True
418 except Exception:
419 return False
420 return False
423def _sources_search_filter(search: str | None) -> str | None:
424 """Case-insensitive filename WHERE clause, or ``None`` for empty *search*."""
425 if not search:
426 return None
427 escaped = escape_sql_string(search.lower())
428 return f"LOWER(filename) LIKE '%{escaped}%'"
431class Store:
432 """LanceDB vector store — wraps all DB operations with config-driven defaults."""
434 def __init__(self, config: Config) -> None:
435 self._config = config
436 self._fts_ready: bool = False
437 self._db: lancedb.DBConnection | None = None
438 # Cache of {filename: ingested_at} rebuilt only when sources
439 # mutate; callers (temporal filter) hit it per-query.
440 self._source_ingested_cache: dict[str, str] | None = None
442 def _invalidate_source_cache(self) -> None:
443 """Drop the cached {filename: ingested_at} map."""
444 self._source_ingested_cache = None
446 def source_ingested_at_map(self) -> dict[str, str]:
447 """Return {filename: ingested_at} for every source, cached until mutation.
449 Best-effort: a reader racing a concurrent invalidation can store a
450 pre-mutation snapshot. The only consumer (temporal query filter)
451 treats a missing/stale entry as "do not filter," so staleness
452 degrades ranking precision, never correctness.
453 """
454 if self._source_ingested_cache is not None:
455 return self._source_ingested_cache
456 mapping = {s["filename"]: s.get("ingested_at", "") for s in self.get_sources()}
457 self._source_ingested_cache = mapping
458 return mapping
460 def _chunks_schema(self) -> pa.Schema:
461 return pa.schema(
462 [
463 pa.field("source", pa.utf8()),
464 pa.field("content_type", pa.utf8()),
465 pa.field("chunk_type", pa.utf8()),
466 pa.field("page_start", pa.int32()),
467 pa.field("page_end", pa.int32()),
468 pa.field("line_start", pa.int32()),
469 pa.field("line_end", pa.int32()),
470 pa.field("chunk", pa.utf8()),
471 pa.field("chunk_index", pa.int32()),
472 pa.field("vector", pa.list_(pa.float32(), self._config.embedding_dim)),
473 ]
474 )
476 def get_meta(self) -> StoreMeta | None:
477 """Return the persisted store metadata row, or ``None`` if unset."""
478 table = self.open_table(META_TABLE)
479 if table is None:
480 return None
481 rows = table.search().limit(1).to_list()
482 if not rows:
483 return None
484 row = rows[0]
485 return StoreMeta(
486 embedding_model=row["embedding_model"],
487 embedding_dim=int(row["embedding_dim"]),
488 schema_version=int(row["schema_version"]),
489 updated_at=row["updated_at"],
490 )
492 def _write_meta_unlocked(self, *, embedding_model: str, embedding_dim: int) -> None:
493 """Overwrite the single ``_meta`` row with the supplied identity.
495 Caller must hold ``write_lock()``. Args are passed explicitly rather than
496 re-read from ``self._config`` so the caller can snapshot cfg at a coherent
497 instant and not race with a concurrent ``set_embedding_model``.
498 """
499 db = self.get_db()
500 table = ensure_table(db, META_TABLE, _meta_schema())
501 _safe_delete_unlocked(table, _META_DELETE_ALL_PREDICATE)
502 table.add(
503 [
504 {
505 "embedding_model": embedding_model,
506 "embedding_dim": embedding_dim,
507 "schema_version": META_SCHEMA_VERSION,
508 "updated_at": datetime.now(UTC).isoformat(),
509 }
510 ]
511 )
513 def _has_chunks(self) -> bool:
514 """Return True when the chunks table exists and has at least one row."""
515 chunks = self.open_table(CHUNKS_TABLE)
516 return chunks is not None and chunks.count_rows() > 0
518 def initialize_meta_if_legacy(self) -> bool:
519 """Pin a legacy store's identity to the current cfg if not already set.
521 Returns ``True`` when a meta row was just written. No-op when meta already
522 exists or no chunks are present. This is the path that converts a
523 pre-upgrade store (chunks present, no ``_meta``) into a gated store. It
524 snapshots cfg under the write lock to keep the recorded identity coherent
525 with what the gate is comparing against.
526 """
527 if self.get_meta() is not None:
528 return False
529 if not self._has_chunks():
530 return False
531 embedding_model = self._config.embedding_model
532 embedding_dim = self._config.embedding_dim
533 with write_lock():
534 # Re-check under the lock so two callers do not both warn-and-write.
535 if self.get_meta() is not None:
536 return False
537 log.warning(
538 "Legacy store has chunks but no _meta row. Initializing _meta from "
539 "the current configuration (embedding_model=%s, embedding_dim=%d). "
540 "If you changed embedding_model before upgrading, run `lilbee rebuild` "
541 "to ensure the store is consistent.",
542 embedding_model,
543 embedding_dim,
544 )
545 self._write_meta_unlocked(embedding_model=embedding_model, embedding_dim=embedding_dim)
546 return True
548 def _ensure_embedding_compat(self) -> None:
549 """Raise when the persisted embedding identity drifts from cfg.
551 Pure check, no side effects. Migration of legacy stores (chunks present,
552 no ``_meta``) is the caller's responsibility via ``initialize_meta_if_legacy``
553 so this method is safe to call from inside an existing ``write_lock()``
554 (no recursive lock attempt). cfg fields are snapshotted at entry so the
555 comparison is coherent even if another thread mutates them mid-call.
556 """
557 current_model = self._config.embedding_model
558 current_dim = self._config.embedding_dim
559 meta = self.get_meta()
560 if meta is None:
561 return
562 if meta["embedding_model"] == current_model and meta["embedding_dim"] == current_dim:
563 return
564 raise EmbeddingModelMismatchError(
565 _embedding_mismatch_message(
566 persisted_model=meta["embedding_model"],
567 persisted_dim=meta["embedding_dim"],
568 current_model=current_model,
569 current_dim=current_dim,
570 )
571 )
573 def get_db(self) -> lancedb.DBConnection:
574 if self._db is None:
575 import lancedb as _lancedb
577 self._config.lancedb_dir.mkdir(parents=True, exist_ok=True)
578 self._db = _lancedb.connect(
579 str(self._config.lancedb_dir),
580 read_consistency_interval=READ_CONSISTENCY_INTERVAL,
581 )
582 return self._db
584 def open_table(self, name: str) -> lancedb.table.Table | None:
585 """Open a table if it exists, otherwise return None."""
586 db = self.get_db()
587 if name not in _table_names(db):
588 return None
589 return db.open_table(name)
591 def ensure_fts_index(self) -> None:
592 """Create the chunks FTS index, or run ``optimize()`` once it exists.
594 ``optimize()`` folds newly added rows into the FTS index and also
595 runs LanceDB's default compaction + version pruning (default prune
596 window: 7 days). Work scales with recent deltas rather than total
597 chunk count, so large corpora no longer pay the full
598 ``create_fts_index(replace=True)`` rebuild cost on every sync.
599 """
600 with write_lock():
601 table = self.open_table(CHUNKS_TABLE)
602 if table is None:
603 return
604 try:
605 if _has_fts_index(table):
606 table.optimize()
607 log.debug("FTS index optimized on '%s'", CHUNKS_TABLE)
608 else:
609 table.create_fts_index("chunk", replace=False)
610 log.debug("FTS index created on '%s'", CHUNKS_TABLE)
611 self._fts_ready = True
612 except Exception:
613 log.debug("FTS index ensure failed (empty table?)", exc_info=True)
615 def add_chunks(self, records: list[dict]) -> int:
616 """Add chunk records to the store. Returns count added.
618 Raises ``EmbeddingModelMismatchError`` if the persisted ``_meta`` row was
619 written under a different embedding model than the current ``cfg``. On the
620 first write to a fresh store, ``_meta`` is initialized from the current cfg.
622 The gate runs inside the write lock and uses a single cfg snapshot so a
623 concurrent ``set_embedding_model`` cannot slip a write in past a stale
624 compatibility check.
625 """
626 with write_lock():
627 embedding_model = self._config.embedding_model
628 embedding_dim = self._config.embedding_dim
629 self._ensure_embedding_compat()
630 self._fts_ready = False
631 if not records:
632 return 0
633 for rec in records:
634 vec = rec.get("vector", [])
635 if len(vec) != embedding_dim:
636 raise ValueError(
637 f"Vector dimension mismatch: expected {embedding_dim}, "
638 f"got {len(vec)} (source={rec.get('source', '?')})"
639 )
640 db = self.get_db()
641 table = ensure_table(db, CHUNKS_TABLE, self._chunks_schema())
642 table.add(records)
643 if self.get_meta() is None:
644 self._write_meta_unlocked(
645 embedding_model=embedding_model, embedding_dim=embedding_dim
646 )
647 return len(records)
649 def bm25_probe(self, query_text: str, top_k: int = 5) -> list[SearchChunk]:
650 """Quick BM25-only search for confidence checking. Returns up to top_k results."""
651 table = self.open_table(CHUNKS_TABLE)
652 if table is None:
653 return []
654 if not self._fts_ready:
655 self.ensure_fts_index()
656 if not self._fts_ready:
657 return [] # pragma: no cover
658 try:
659 rows = table.search(query_text, query_type="fts").limit(top_k).to_list()
660 return [SearchChunk(**r) for r in rows]
661 except Exception:
662 log.debug("BM25 probe failed", exc_info=True)
663 return []
665 def search(
666 self,
667 query_vector: list[float],
668 top_k: int | None = None,
669 max_distance: float | None = None,
670 query_text: str | None = None,
671 chunk_type: str | None = None,
672 ) -> list[SearchChunk]:
673 """Search for similar chunks. Hybrid when FTS available, else vector-only.
675 Results with distance > max_distance are filtered out (vector-only path).
676 Pass max_distance=0 to disable filtering.
677 When *chunk_type* is set, only chunks of that type ("raw" or "wiki") are returned.
679 Raises ``EmbeddingModelMismatchError`` if the persisted ``_meta`` row was
680 written under a different embedding model than the current ``cfg``.
681 """
682 if top_k is None:
683 top_k = self._config.top_k
684 if max_distance is None:
685 max_distance = self._config.max_distance
686 table = self.open_table(CHUNKS_TABLE)
687 if table is None:
688 return []
689 self.initialize_meta_if_legacy()
690 self._ensure_embedding_compat()
692 if query_text and not self._fts_ready:
693 self.ensure_fts_index()
695 if query_text and self._fts_ready:
696 try:
697 return _hybrid_search(table, query_text, query_vector, top_k, chunk_type)
698 except Exception:
699 log.debug("Hybrid search failed, falling back to vector-only", exc_info=True)
701 candidate_k = top_k * self._config.candidate_multiplier
702 query = table.search(query_vector).metric("cosine").limit(candidate_k)
703 if chunk_type:
704 query = query.where(_chunk_type_predicate(chunk_type))
705 rows = query.to_list()
706 log.debug(
707 "Vector search: query=%r, candidates=%d, max_distance=%.2f",
708 query_text or "vector-only",
709 len(rows),
710 max_distance,
711 )
712 if rows:
713 distances = [r.get("distance", 0) for r in rows[:5]]
714 log.debug("Top 5 distances: %s", distances)
715 results = [SearchChunk(**r) for r in rows]
716 if max_distance > 0:
717 before = len(results)
718 if self._config.adaptive_threshold:
719 results = self._adaptive_filter(results, top_k, max_distance)
720 log.debug(
721 "After adaptive filter: %d/%d results, threshold=%.2f",
722 len(results),
723 before,
724 max_distance,
725 )
726 else:
727 results = self._fixed_filter(results, max_distance)
728 log.debug(
729 "After fixed filter: %d/%d results, threshold=%.2f",
730 len(results),
731 before,
732 max_distance,
733 )
734 if len(results) > top_k:
735 results = mmr_rerank(query_vector, results, top_k, self._config.mmr_lambda)
736 return results
738 def _adaptive_filter(
739 self, results: list[SearchChunk], top_k: int, initial_threshold: float
740 ) -> list[SearchChunk]:
741 """Widen cosine distance threshold when too few results.
742 Inspired by grantflow's (grantflow-ai/grantflow) adaptive retrieval
743 pattern which widens thresholds on recursive retry. Step size and
744 cap are configurable via ``self._config.adaptive_threshold_step``.
746 Pre-sorts results by distance for a single-pass cutoff search.
747 Step size is ``self._config.adaptive_threshold_step`` (default 0.2).
748 """
749 cap = max(initial_threshold, _MAX_THRESHOLD)
750 step = self._config.adaptive_threshold_step
752 sorted_results = sorted(results, key=_get_distance)
754 threshold = initial_threshold
755 for _ in range(_MAX_FILTER_ITERATIONS):
756 if threshold > cap:
757 break
758 cutoff = _count_within_threshold(sorted_results, threshold)
759 if cutoff >= top_k:
760 return sorted_results[:cutoff]
761 threshold += step
762 # Final pass at cap
763 cutoff = _count_within_threshold(sorted_results, cap)
764 return sorted_results[:cutoff]
766 def _fixed_filter(self, results: list[SearchChunk], threshold: float) -> list[SearchChunk]:
767 """Simple fixed threshold filter - keep only results within distance threshold."""
768 return [r for r in results if _get_distance(r) <= threshold]
770 def get_chunks_by_source(self, source: str) -> list[SearchChunk]:
771 """Return every chunk whose ``source`` equals *source*."""
772 table = self.open_table(CHUNKS_TABLE)
773 if table is None:
774 return []
775 escaped = escape_sql_string(source)
776 try:
777 rows = table.search().where(f"source = '{escaped}'").limit(None).to_list()
778 except Exception:
779 # FTS-enabled tables return a query builder that cannot
780 # handle .where() on arbitrary columns; fall through to a
781 # pyarrow.compute filter on the Arrow table so the source
782 # match runs in C++ without materializing non-matching rows.
783 log.debug("get_chunks_by_source search() failed, using Arrow fallback", exc_info=True)
784 arrow_tbl = table.to_arrow()
785 filtered = arrow_tbl.filter(pc.equal(arrow_tbl["source"], source))
786 rows = filtered.to_pylist()
787 return [SearchChunk(**r) for r in rows]
789 def delete_by_source(self, source: str) -> None:
790 """Delete all chunks from a given source file."""
791 with write_lock():
792 table = self.open_table(CHUNKS_TABLE)
793 if table is not None:
794 _safe_delete_unlocked(table, f"source = '{escape_sql_string(source)}'")
795 self._invalidate_source_cache()
797 def get_sources(
798 self,
799 *,
800 search: str | None = None,
801 limit: int | None = None,
802 offset: int = 0,
803 ) -> list[SourceRecord]:
804 """Return source records, filtered by *search* and sliced by offset/limit."""
805 table = self.open_table(SOURCES_TABLE)
806 if table is None:
807 return []
808 query = table.search()
809 where = _sources_search_filter(search)
810 if where is not None:
811 query = query.where(where)
812 if offset:
813 query = query.offset(offset)
814 query = query.limit(limit)
815 result: list[SourceRecord] = query.to_list() # type: ignore[assignment]
816 return result
818 def count_sources(self, *, search: str | None = None) -> int:
819 """Count tracked sources matching *search* without materializing rows."""
820 table = self.open_table(SOURCES_TABLE)
821 if table is None:
822 return 0
823 where = _sources_search_filter(search)
824 count: int = table.count_rows() if where is None else table.count_rows(filter=where)
825 return count
827 def upsert_source(
828 self,
829 filename: str,
830 file_hash: str,
831 chunk_count: int,
832 source_type: str = "document",
833 ) -> None:
834 """Add or update a source file tracking record."""
835 with write_lock():
836 db = self.get_db()
837 table = ensure_table(db, SOURCES_TABLE, _sources_schema())
838 _safe_delete_unlocked(table, f"filename = '{escape_sql_string(filename)}'")
839 table.add(
840 [
841 {
842 "filename": filename,
843 "file_hash": file_hash,
844 "ingested_at": datetime.now(UTC).isoformat(),
845 "chunk_count": chunk_count,
846 "source_type": source_type,
847 }
848 ]
849 )
850 self._invalidate_source_cache()
852 def delete_source(self, filename: str) -> None:
853 """Remove a source file tracking record."""
854 with write_lock():
855 table = self.open_table(SOURCES_TABLE)
856 if table is not None:
857 _safe_delete_unlocked(table, f"filename = '{escape_sql_string(filename)}'")
858 self._invalidate_source_cache()
860 def remove_documents(
861 self,
862 names: list[str],
863 *,
864 delete_files: bool = False,
865 documents_dir: Path | None = None,
866 ) -> RemoveResult:
867 """Remove documents from the knowledge base by source name.
868 Looks up known sources, deletes chunks and source records for each.
869 If *delete_files* is True, resolves the path and verifies it is
870 contained within *documents_dir* before unlinking (path traversal guard).
872 Returns a RemoveResult with removed and not_found lists.
873 """
874 if documents_dir is None:
875 documents_dir = self._config.documents_dir
877 known = {s["filename"] for s in self.get_sources()}
878 removed: list[str] = []
879 not_found: list[str] = []
881 for name in names:
882 if name not in known:
883 not_found.append(name)
884 continue
885 self.delete_by_source(name)
886 self.delete_source(name)
887 removed.append(name)
888 if delete_files:
889 try:
890 path = validate_path_within(documents_dir / name, documents_dir)
891 except ValueError:
892 log.warning("Path traversal blocked: %s escapes %s", name, documents_dir)
893 continue
894 if path.exists():
895 path.unlink()
897 return RemoveResult(removed=removed, not_found=not_found)
899 def clear_table(self, name: str, predicate: str) -> None:
900 """Delete rows matching *predicate* from *name*. Acquires write lock."""
901 with write_lock():
902 table = self.open_table(name)
903 if table is not None:
904 _safe_delete_unlocked(table, predicate)
906 def add_citations(self, records: list[CitationRecord]) -> int:
907 """Add citation records to the store. Returns count added."""
908 if not records:
909 return 0
910 with write_lock():
911 db = self.get_db()
912 table = ensure_table(db, CITATIONS_TABLE, _citations_schema())
913 table.add(records)
914 return len(records)
916 def get_citations_for_wiki(self, wiki_source: str) -> list[CitationRecord]:
917 """Get all citations for a wiki page."""
918 table = self.open_table(CITATIONS_TABLE)
919 if table is None:
920 return []
921 escaped = escape_sql_string(wiki_source)
922 rows: list[CitationRecord] = table.search().where(f"wiki_source = '{escaped}'").to_list()
923 return rows
925 def get_citations_for_source(self, source_filename: str) -> list[CitationRecord]:
926 """Get all citations that reference a source document (reverse lookup)."""
927 table = self.open_table(CITATIONS_TABLE)
928 if table is None:
929 return []
930 escaped = escape_sql_string(source_filename)
931 rows: list[CitationRecord] = (
932 table.search().where(f"source_filename = '{escaped}'").to_list()
933 )
934 return rows
936 def delete_citations_for_wiki(self, wiki_source: str) -> None:
937 """Delete all citations for a wiki page (used before regeneration)."""
938 self.clear_table(
939 CITATIONS_TABLE,
940 f"wiki_source = '{escape_sql_string(wiki_source)}'",
941 )
943 def close(self) -> None:
944 """Release the database connection and reset state."""
945 self._db = None
946 self._fts_ready = False
948 def drop_all(self) -> None:
949 """Drop all tables -- used by rebuild."""
950 with write_lock():
951 self._fts_ready = False
952 db = self.get_db()
953 for name in _table_names(db):
954 db.drop_table(name)
955 self._invalidate_source_cache()