Coverage for src/lilbee/store.py: 100%
100 statements
« prev ^ index » next — coverage.py v7.13.4, created at 2026-03-16 08:27 +0000
1"""LanceDB vector store operations."""
3import logging
4from datetime import UTC, datetime
6import lancedb
7import pyarrow as pa
8from typing_extensions import TypedDict
10from lilbee.config import CHUNKS_TABLE, SOURCES_TABLE, cfg
12log = logging.getLogger(__name__)
class SearchChunk(TypedDict):
    """A search result row from LanceDB — chunk fields plus distance."""

    # Source filename this chunk was extracted from.
    source: str
    content_type: str
    # Page / line span of the chunk within the source document.
    page_start: int
    page_end: int
    line_start: int
    line_end: int
    # The chunk text itself, and its position in the source's chunk sequence.
    chunk: str
    chunk_index: int
    # Embedding vector; length is cfg.embedding_dim (enforced in add_chunks).
    vector: list[float]
    # Distance reported by the LanceDB vector search (cosine metric).
    _distance: float
def _chunks_schema() -> pa.Schema:
    """Arrow schema for the chunks table (mirrors the SearchChunk fields)."""
    field_specs = [
        ("source", pa.utf8()),
        ("content_type", pa.utf8()),
        ("page_start", pa.int32()),
        ("page_end", pa.int32()),
        ("line_start", pa.int32()),
        ("line_end", pa.int32()),
        ("chunk", pa.utf8()),
        ("chunk_index", pa.int32()),
        # Fixed-size list pins every stored vector to the configured dimension.
        ("vector", pa.list_(pa.float32(), cfg.embedding_dim)),
    ]
    return pa.schema([pa.field(name, dtype) for name, dtype in field_specs])
def _sources_schema() -> pa.Schema:
    """Arrow schema for the source-file tracking table."""
    return pa.schema(
        [
            pa.field(name, dtype)
            for name, dtype in (
                ("filename", pa.utf8()),
                ("file_hash", pa.utf8()),
                # ISO-8601 string timestamp (see upsert_source).
                ("ingested_at", pa.utf8()),
                ("chunk_count", pa.int32()),
            )
        ]
    )
def get_db() -> lancedb.DBConnection:
    """Connect to the configured LanceDB directory, creating it if absent."""
    db_dir = cfg.lancedb_dir
    db_dir.mkdir(parents=True, exist_ok=True)
    return lancedb.connect(str(db_dir))
62def _table_names(db: lancedb.DBConnection) -> list[str]:
63 """Get list of table names, handling the ListTablesResponse object."""
64 result = db.list_tables()
65 return result.tables if hasattr(result, "tables") else list(result)
def ensure_table(db: lancedb.DBConnection, name: str, schema: pa.Schema) -> lancedb.table.Table:
    """Open table *name*, creating it with *schema* on first use.

    A concurrent writer may create the table between our existence check and
    create_table; lancedb then raises ValueError and we fall back to opening it.
    """
    if name not in _table_names(db):
        try:
            return db.create_table(name, schema=schema)
        except ValueError:
            # Lost the creation race — the table now exists; open it below.
            pass
    return db.open_table(name)
def _open_table(name: str) -> lancedb.table.Table | None:
    """Open a table if it exists, otherwise return None."""
    db = get_db()
    return db.open_table(name) if name in _table_names(db) else None
def safe_delete(table: lancedb.table.Table, predicate: str) -> None:
    """Delete rows matching *predicate*; a failure is logged, never raised.

    Deliberately broad catch: deletion here is best-effort cleanup and must
    not abort the caller's operation.
    """
    try:
        table.delete(predicate)
    except Exception:
        log.warning("Failed to delete rows matching: %s", predicate, exc_info=True)
93def _escape_sql_string(value: str) -> str:
94 """Escape single quotes for SQL predicates."""
95 return value.replace("'", "''")
def add_chunks(records: list[dict]) -> int:
    """Add chunk records to the store. Returns count added.

    Raises ValueError if any record's vector length differs from the
    configured embedding dimension; validation happens before any write,
    so a bad batch stores nothing.
    """
    if not records:
        return 0
    for rec in records:
        vec = rec.get("vector", [])
        if len(vec) == cfg.embedding_dim:
            continue
        raise ValueError(
            f"Vector dimension mismatch: expected {cfg.embedding_dim}, got {len(vec)} "
            f"(source={rec.get('source', '?')})"
        )
    table = ensure_table(get_db(), CHUNKS_TABLE, _chunks_schema())
    table.add(records)
    return len(records)
def search(
    query_vector: list[float],
    top_k: int | None = None,
    max_distance: float | None = None,
) -> list[SearchChunk]:
    """Search for similar chunks by vector similarity.

    Results with distance > max_distance are filtered out.
    Pass max_distance=0 to disable filtering.
    """
    # Fall back to configured defaults when the caller passes None.
    limit = cfg.top_k if top_k is None else top_k
    cutoff = cfg.max_distance if max_distance is None else max_distance
    table = _open_table(CHUNKS_TABLE)
    if table is None:
        # No chunks table yet — nothing has been ingested.
        return []
    query = table.search(query_vector).metric("cosine").limit(limit)
    hits: list[SearchChunk] = query.to_list()
    if cutoff > 0:
        hits = [hit for hit in hits if hit["_distance"] <= cutoff]
    return hits
def get_chunks_by_source(source: str) -> list[SearchChunk]:
    """Return all chunks for a given source file."""
    table = _open_table(CHUNKS_TABLE)
    if table is None:
        return []
    # Quote-escape the filename before embedding it in the SQL predicate.
    escaped = _escape_sql_string(source)
    matched: list[SearchChunk] = table.search().where(f"source = '{escaped}'").to_list()
    return matched
def delete_by_source(source: str) -> None:
    """Delete all chunks from a given source file."""
    table = _open_table(CHUNKS_TABLE)
    if table is None:
        # Table never created — nothing to delete.
        return
    safe_delete(table, f"source = '{_escape_sql_string(source)}'")
def get_sources() -> list[dict]:
    """Get all tracked source file records."""
    table = _open_table(SOURCES_TABLE)
    if table is None:
        return []
    # Materialize the whole (small) tracking table as plain dicts.
    records: list[dict] = table.to_arrow().to_pylist()
    return records
def upsert_source(filename: str, file_hash: str, chunk_count: int) -> None:
    """Add or update a source file tracking record.

    Implemented as delete-then-add keyed on filename, since LanceDB has no
    native upsert here; safe_delete makes the delete step best-effort.
    """
    table = ensure_table(get_db(), SOURCES_TABLE, _sources_schema())
    safe_delete(table, f"filename = '{_escape_sql_string(filename)}'")
    record = {
        "filename": filename,
        "file_hash": file_hash,
        "ingested_at": datetime.now(UTC).isoformat(),
        "chunk_count": chunk_count,
    }
    table.add([record])
def delete_source(filename: str) -> None:
    """Remove a source file tracking record."""
    table = _open_table(SOURCES_TABLE)
    if table is None:
        # Tracking table never created — nothing to remove.
        return
    safe_delete(table, f"filename = '{_escape_sql_string(filename)}'")
def drop_all() -> None:
    """Drop all tables — used by rebuild."""
    db = get_db()
    # _table_names returns a snapshot list, so dropping while iterating is safe.
    for table_name in _table_names(db):
        db.drop_table(table_name)