Coverage for src / lilbee / store.py: 100%

100 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-16 08:27 +0000

1"""LanceDB vector store operations.""" 

2 

3import logging 

4from datetime import UTC, datetime 

5 

6import lancedb 

7import pyarrow as pa 

8from typing_extensions import TypedDict 

9 

10from lilbee.config import CHUNKS_TABLE, SOURCES_TABLE, cfg 

11 

# Module-level logger; safe_delete() uses it to report best-effort delete failures.
log = logging.getLogger(__name__)

13 

14 

class SearchChunk(TypedDict):
    """A search result row from LanceDB — chunk fields plus distance.

    Field names mirror the Arrow schema built by ``_chunks_schema``;
    ``_distance`` is the extra column returned by the vector search
    (cosine metric; see ``search()``).
    """

    source: str  # identifier of the file the chunk came from
    content_type: str
    page_start: int
    page_end: int
    line_start: int
    line_end: int
    chunk: str  # the chunk text itself
    chunk_index: int
    vector: list[float]  # embedding; length must equal cfg.embedding_dim
    _distance: float  # similarity distance appended by LanceDB search

28 

29 

def _chunks_schema() -> pa.Schema:
    """Arrow schema for the chunks table; vector length is cfg.embedding_dim."""
    columns = [
        ("source", pa.utf8()),
        ("content_type", pa.utf8()),
        ("page_start", pa.int32()),
        ("page_end", pa.int32()),
        ("line_start", pa.int32()),
        ("line_end", pa.int32()),
        ("chunk", pa.utf8()),
        ("chunk_index", pa.int32()),
        ("vector", pa.list_(pa.float32(), cfg.embedding_dim)),
    ]
    return pa.schema([pa.field(name, dtype) for name, dtype in columns])

44 

45 

def _sources_schema() -> pa.Schema:
    """Arrow schema for the source-file tracking table."""
    columns = {
        "filename": pa.utf8(),
        "file_hash": pa.utf8(),
        "ingested_at": pa.utf8(),
        "chunk_count": pa.int32(),
    }
    # dict preserves insertion order, so the field order matches the original.
    return pa.schema([pa.field(name, dtype) for name, dtype in columns.items()])

55 

56 

def get_db() -> lancedb.DBConnection:
    """Connect to the configured LanceDB directory, creating it if absent."""
    directory = cfg.lancedb_dir
    directory.mkdir(parents=True, exist_ok=True)
    return lancedb.connect(str(directory))

60 

61 

62def _table_names(db: lancedb.DBConnection) -> list[str]: 

63 """Get list of table names, handling the ListTablesResponse object.""" 

64 result = db.list_tables() 

65 return result.tables if hasattr(result, "tables") else list(result) 

66 

67 

def ensure_table(db: lancedb.DBConnection, name: str, schema: pa.Schema) -> lancedb.table.Table:
    """Open table *name*, creating it with *schema* on first use.

    If another writer creates the table between our existence check and the
    create call, lancedb raises ValueError; we fall back to opening it.
    """
    if name not in _table_names(db):
        try:
            return db.create_table(name, schema=schema)
        except ValueError:
            # Lost the creation race — the table exists now; open it below.
            pass
    return db.open_table(name)

75 

76 

def _open_table(name: str) -> lancedb.table.Table | None:
    """Open a table if it exists, otherwise return None."""
    db = get_db()
    return db.open_table(name) if name in _table_names(db) else None

83 

84 

def safe_delete(table: lancedb.table.Table, predicate: str) -> None:
    """Best-effort delete of rows matching *predicate*.

    Failures are logged with a traceback rather than raised, so a failed
    cleanup never aborts the caller.
    """
    try:
        table.delete(predicate)
    except Exception:
        # Deliberately broad: delete is a cleanup step, not a critical path.
        log.warning("Failed to delete rows matching: %s", predicate, exc_info=True)

91 

92 

93def _escape_sql_string(value: str) -> str: 

94 """Escape single quotes for SQL predicates.""" 

95 return value.replace("'", "''") 

96 

97 

def add_chunks(records: list[dict]) -> int:
    """Add chunk records to the store. Returns count added.

    Every record is validated before anything is written: a vector whose
    length differs from cfg.embedding_dim raises ValueError, naming the
    offending source.
    """
    if not records:
        return 0
    expected = cfg.embedding_dim
    for rec in records:
        vec = rec.get("vector", [])
        if len(vec) != expected:
            raise ValueError(
                f"Vector dimension mismatch: expected {cfg.embedding_dim}, got {len(vec)} "
                f"(source={rec.get('source', '?')})"
            )
    table = ensure_table(get_db(), CHUNKS_TABLE, _chunks_schema())
    table.add(records)
    return len(records)

113 

114 

def search(
    query_vector: list[float],
    top_k: int | None = None,
    max_distance: float | None = None,
) -> list[SearchChunk]:
    """Search for similar chunks by vector similarity.

    Results with distance > max_distance are filtered out.
    Pass max_distance=0 to disable filtering.
    """
    # Fall back to configured defaults when the caller passes None.
    k = cfg.top_k if top_k is None else top_k
    cutoff = cfg.max_distance if max_distance is None else max_distance
    table = _open_table(CHUNKS_TABLE)
    if table is None:
        return []
    query = table.search(query_vector).metric("cosine").limit(k)
    hits: list[SearchChunk] = query.to_list()
    if cutoff > 0:
        hits = [hit for hit in hits if hit["_distance"] <= cutoff]
    return hits

136 

137 

def get_chunks_by_source(source: str) -> list[SearchChunk]:
    """Return all chunks for a given source file."""
    table = _open_table(CHUNKS_TABLE)
    if table is None:
        return []
    predicate = f"source = '{_escape_sql_string(source)}'"
    matched: list[SearchChunk] = table.search().where(predicate).to_list()
    return matched

146 

147 

def delete_by_source(source: str) -> None:
    """Delete all chunks from a given source file."""
    table = _open_table(CHUNKS_TABLE)
    if table is None:
        return  # nothing ingested yet; nothing to delete
    safe_delete(table, f"source = '{_escape_sql_string(source)}'")

153 

154 

def get_sources() -> list[dict]:
    """Get all tracked source file records."""
    table = _open_table(SOURCES_TABLE)
    if table is None:
        return []
    rows: list[dict] = table.to_arrow().to_pylist()
    return rows

162 

163 

def upsert_source(filename: str, file_hash: str, chunk_count: int) -> None:
    """Add or update a source file tracking record."""
    table = ensure_table(get_db(), SOURCES_TABLE, _sources_schema())
    # Delete-then-add emulates an upsert keyed on filename.
    safe_delete(table, f"filename = '{_escape_sql_string(filename)}'")
    record = {
        "filename": filename,
        "file_hash": file_hash,
        "ingested_at": datetime.now(UTC).isoformat(),
        "chunk_count": chunk_count,
    }
    table.add([record])

179 

180 

def delete_source(filename: str) -> None:
    """Remove a source file tracking record."""
    table = _open_table(SOURCES_TABLE)
    if table is None:
        return  # table never created; record cannot exist
    safe_delete(table, f"filename = '{_escape_sql_string(filename)}'")

186 

187 

def drop_all() -> None:
    """Drop all tables — used by rebuild."""
    db = get_db()
    # _table_names returns a snapshot list, so dropping while looping is safe.
    for table_name in _table_names(db):
        db.drop_table(table_name)