Coverage for src / lilbee / clustering_embedding.py: 100%

221 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""Chunk-level mutual-kNN source clusterer — default wiki synthesis backend. 

2 

3Pipeline: 

4 

51. Load every chunk's (source, chunk_index, text, vector) from LanceDB. 

62. Build a float32 matrix ``V`` and L2-normalize rows. 

73. Compute a mutual k-nearest-neighbors graph over chunks using blocked 

8 similarity. ``k`` auto-scales from corpus size unless 

9 ``config.wiki_clusterer_k`` is set. 

104. Run asynchronous Label Propagation (Raghavan et al. 2007) over the 

11 mutual-kNN graph to obtain chunk-level communities. 

125. Aggregate chunk communities into source communities, requiring a 

13 source to contribute at least ``min(3, ceil(0.2 * total_chunks))`` 

14 chunks before it joins a cluster. This keeps a single stray chunk 

15 from dragging a whole document into an unrelated cluster. 

166. Filter communities that span fewer than ``min_sources`` distinct 

17 sources, then label each cluster with TF-IDF scoring over member 

18 chunk text against corpus-wide document frequency. 

19 

20Mutual-kNN is itself hub-robust: a pathological hub ends up in many 

21one-way neighborhoods but can reciprocate at most ``k`` of them, so 

22hub-driven bridging across topics is broken at the graph-construction 

23step without any post-hoc similarity rescaling. 

24 

25Uses numpy for the O(N²) similarity kernel. numpy is declared in 

26``pyproject.toml`` as a direct dependency. 

27 

28The implementation is blocked in row chunks of ``_BLOCK_SIZE`` to keep 

29peak memory bounded regardless of corpus size, so it scales comfortably 

30to tens of thousands of chunks on a laptop. 

31""" 

32 

33from __future__ import annotations 

34 

35import logging 

36import math 

37from collections import Counter 

38from dataclasses import dataclass, field 

39 

40import numpy as np 

41 

42from lilbee.clustering import SourceCluster 

43from lilbee.config import CHUNKS_TABLE, Config 

44from lilbee.store import Store 

45 

log = logging.getLogger(__name__)

# Block size for the similarity kernel. With N=10000 and D=768 this caps
# peak float32 memory at block * N * 4 bytes ~= 40 MB.
_BLOCK_SIZE = 1024

# Label Propagation hard iteration cap. Convergence is typically reached
# in well under 10 passes on real corpora.
_MAX_LPA_ITERATIONS = 30

# Minimum non-zero L2 norm for a row vector to be kept; anything at or
# below this is treated as a zero vector and dropped before clustering.
_MIN_VECTOR_NORM = 1e-12

# Source-membership thresholds. A source joins a chunk community when it
# contributes at least `min(_MIN_SOURCE_CHUNKS, ceil(total * _MIN_SOURCE_FRACTION))`
# of its chunks. min() picks the LOWER (more lenient) cutoff: long documents
# are capped at _MIN_SOURCE_CHUNKS so membership stays attainable, while short
# documents only need the fractional cutoff. A single stray chunk therefore
# cannot pull a source into an unrelated cluster once the source has at least
# six chunks (ceil(0.2 * 6) = 2 > 1); sources with five or fewer chunks can
# still join on one chunk.
_MIN_SOURCE_CHUNKS = 3
_MIN_SOURCE_FRACTION = 0.2

# TF-IDF labeling knobs.
_LABEL_TOP_TERMS = 3
# Chunks with fewer tokens than this are down-weighted when accumulating
# term frequency so short boilerplate (headings, captions) cannot dominate
# a cluster label. 20 tokens roughly matches the token count of a section
# heading or a two-sentence summary.
_SHORT_CHUNK_TOKEN_CAP = 20

# kNN auto-scaling bounds. Formula: clamp(round(log2(N)+2), _MIN_K, _MAX_K).
_MIN_K = 5
_MAX_K = 20

# Minimum token length for TF-IDF labeling. Shorter tokens are mostly
# articles, prepositions, and single letters — noise that inflates term
# counts without adding topic signal. Three characters keeps useful
# acronyms (api, xml, sql).
_MIN_TF_TOKEN_LEN = 3

84 

def _tokenize_for_tf(text: str) -> list[str]:
    """Lowercase alphanumeric tokens for TF-IDF scoring.

    Deliberately has NO stopword list: ubiquitous words like "the" or
    "and" appear in nearly every chunk, so their IDF approaches zero and
    TF-IDF suppresses them on its own. A hand-curated English stoplist
    would add maintenance burden and break on non-English corpora for no
    additional quality.
    """
    return [
        cleaned
        for fragment in text.lower().split()
        if len(cleaned := "".join(filter(str.isalnum, fragment))) >= _MIN_TF_TOKEN_LEN
    ]

100 

101 

@dataclass
class ChunkRecord:
    """Lightweight view of one chunk row used by the clusterer."""

    # Identifier of the originating document (the LanceDB "source" column).
    source: str
    # Position of this chunk within its source document.
    chunk_index: int
    # Raw chunk text; loaders substitute "" when the stored value is not a str.
    text: str
    # Pre-computed TF tokens (see _tokenize_for_tf), filled once at load
    # time so TF-IDF labeling never re-tokenizes.
    tokens: list[str] = field(default_factory=list)

110 

111 

def auto_k(n: int) -> int:
    """Pick a neighborhood size from corpus size via ``clamp(log2(N)+2)``."""
    if n <= 1:
        return _MIN_K
    # Floor the argument at 4 so tiny corpora don't produce a negative log.
    unclamped = round(2 + math.log2(max(4, n)))
    if unclamped < _MIN_K:
        return _MIN_K
    return _MAX_K if unclamped > _MAX_K else unclamped

118 

119 

def _parse_chunk_row(
    row: dict[str, object],
) -> tuple[ChunkRecord, list[float] | tuple[float, ...]] | None:
    """Extract a chunk record + vector from a raw Arrow row, or None on invalid."""
    vector = row.get("vector")
    source = row.get("source")
    # A row is only usable when it has both a sequence vector and a str source.
    if not isinstance(vector, (list, tuple)) or not isinstance(source, str):
        return None
    text_value = row.get("chunk")
    if not isinstance(text_value, str):
        text_value = ""
    index_value = row.get("chunk_index")
    if not isinstance(index_value, int):
        index_value = 0
    return (
        ChunkRecord(
            source=source,
            chunk_index=index_value,
            text=text_value,
            tokens=_tokenize_for_tf(text_value),
        ),
        vector,
    )

141 

142 

def _load_chunk_records(
    store: Store,
) -> tuple[list[ChunkRecord], np.ndarray]:
    """Scan the chunks table once and return records plus a float32 matrix.

    Rows with an unparseable vector are skipped. Records are sorted by
    ``(source, chunk_index)`` so downstream cluster IDs are stable
    regardless of LanceDB's row return order. Records are tokenized once
    here (inside ``_parse_chunk_row``) so TF-IDF labeling does not
    re-tokenize. The vector matrix is preallocated and filled via numpy
    row-assignment, which keeps the float cast inside numpy's C loop and
    avoids a transient ``list[list[float]]``.
    """
    table = store.open_table(CHUNKS_TABLE)
    if table is None:
        return [], np.zeros((0, 0), dtype=np.float32)

    raw_rows = table.to_arrow().to_pylist()
    pairs = [parsed for parsed in map(_parse_chunk_row, raw_rows) if parsed]
    if not pairs:
        return [], np.zeros((0, 0), dtype=np.float32)

    pairs.sort(key=lambda item: (item[0].source, item[0].chunk_index))
    dim = len(pairs[0][1])
    matrix = np.empty((len(pairs), dim), dtype=np.float32)
    records = [record for record, _ in pairs]
    for row_index, (_, vector) in enumerate(pairs):
        matrix[row_index] = vector
    return records, matrix

172 

173 

def normalize_rows(matrix: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return (normalized_matrix, keep_mask). Zero-norm rows are dropped."""
    if matrix.size == 0:
        return matrix, np.zeros(0, dtype=bool)
    lengths = np.linalg.norm(matrix, axis=1)
    mask = lengths > _MIN_VECTOR_NORM
    if mask.all():
        return matrix / lengths[:, None], mask
    # Drop degenerate (effectively zero) rows before dividing.
    return matrix[mask] / lengths[mask][:, None], mask

184 

185 

def mutual_knn(matrix: np.ndarray, k: int) -> dict[int, set[int]]:
    """Build a mutual k-nearest-neighbors graph over L2-normalized rows.

    Similarity is computed in row blocks of ``_BLOCK_SIZE`` so peak
    memory stays bounded. Self-similarity is masked to ``-inf`` so a row
    never appears in its own neighborhood. An edge ``(i, j)`` survives
    only when ``j`` is in ``i``'s top-k AND ``i`` is in ``j``'s top-k —
    this single reciprocity rule breaks hub-driven bridging without any
    extra similarity rescaling.
    """
    n = matrix.shape[0]
    if n == 0 or k <= 0:
        return {}

    effective_k = min(k, n - 1)
    if effective_k <= 0:
        return {i: set() for i in range(n)}

    # Per-row top-k neighbor sets, accumulated block by block in row order.
    top: list[set[int]] = []
    for start in range(0, n, _BLOCK_SIZE):
        stop = min(start + _BLOCK_SIZE, n)
        sims = matrix[start:stop] @ matrix.T  # (block, n)
        # Mask the diagonal entries for this block; both index arrays have
        # length stop-start, so the pairing also holds for the tail block.
        local = np.arange(stop - start)
        sims[local, local + start] = -math.inf
        # argpartition with a negative kth selects the largest entries in
        # the tail without negating (copying) the block.
        picked = np.argpartition(sims, -effective_k, axis=1)[:, -effective_k:]
        top.extend(set(neighbor_row.tolist()) for neighbor_row in picked)

    # Keep only reciprocated edges.
    return {i: {j for j in top[i] if i in top[j]} for i in range(n)}

226 

227 

def label_propagation(
    adjacency: dict[int, set[int]],
    order: list[int],
) -> list[int]:
    """Async Label Propagation with deterministic min-label tie-breaking.

    Each node adopts the most common label among its neighbors; ties go
    to the smallest label id so the result is reproducible across runs
    on the same corpus. The caller must supply a complete ``adjacency``
    (one entry per node 0..n-1) and an ``order`` covering every node —
    ``get_clusters`` guarantees both.
    """
    labels = list(range(len(adjacency)))
    for _ in range(_MAX_LPA_ITERATIONS):
        stable = True
        for node in order:
            neighborhood = adjacency.get(node)
            if not neighborhood:
                continue
            tally: Counter[int] = Counter(labels[neighbor] for neighbor in neighborhood)
            peak = max(tally.values())
            winner = min(candidate for candidate, votes in tally.items() if votes == peak)
            if winner != labels[node]:
                labels[node] = winner
                stable = False
        if stable:
            break
    return labels

257 

258 

def communities_by_label(labels: list[int]) -> dict[int, list[int]]:
    """Group node indices by their final community label."""
    grouped: dict[int, list[int]] = {}
    for node_index, community_label in enumerate(labels):
        if community_label in grouped:
            grouped[community_label].append(node_index)
        else:
            grouped[community_label] = [node_index]
    return grouped

265 

266 

267def _source_totals(records: list[ChunkRecord]) -> dict[str, int]: 

268 """Return the total chunk count per source across the whole corpus.""" 

269 totals: dict[str, int] = {} 

270 for record in records: 

271 totals[record.source] = totals.get(record.source, 0) + 1 

272 return totals 

273 

274 

def _filter_sources(
    member_indices: list[int],
    records: list[ChunkRecord],
    source_totals: dict[str, int],
) -> frozenset[str]:
    """Apply the source-membership threshold to a community's members."""
    contributed: Counter[str] = Counter(records[idx].source for idx in member_indices)
    survivors: set[str] = set()
    for source, count in contributed.items():
        # Fall back to the contributed count if the source is somehow
        # missing from the corpus totals.
        total = source_totals.get(source, count)
        threshold = min(_MIN_SOURCE_CHUNKS, math.ceil(total * _MIN_SOURCE_FRACTION))
        if count >= threshold:
            survivors.add(source)
    return frozenset(survivors)

293 

294 

295def _corpus_document_frequency(records: list[ChunkRecord]) -> dict[str, int]: 

296 """Compute document frequency (chunk count containing term) for every term.""" 

297 df: dict[str, int] = {} 

298 for record in records: 

299 for term in set(record.tokens): 

300 df[term] = df.get(term, 0) + 1 

301 return df 

302 

303 

def _label_community(
    member_indices: list[int],
    records: list[ChunkRecord],
    df: dict[str, int],
    total_chunks: int,
    fallback: str,
) -> str:
    """Pick a topic label for a community using sublinear TF-IDF scoring."""
    # Accumulate sublinear term frequency, down-weighting short chunks
    # (headings, captions) by len/_SHORT_CHUNK_TOKEN_CAP so boilerplate
    # cannot dominate the label.
    accumulated: dict[str, float] = {}
    for member in member_indices:
        tokens = records[member].tokens
        if not tokens:
            continue
        damp = min(1.0, len(tokens) / _SHORT_CHUNK_TOKEN_CAP)
        for term, occurrences in Counter(tokens).items():
            accumulated[term] = accumulated.get(term, 0.0) + damp * (1.0 + math.log(occurrences))

    # Standard ``log(N / (1 + df))`` smoothing: the +1 keeps the
    # denominator non-zero for unseen terms, and terms present in every
    # chunk get a non-positive idf and are filtered out entirely.
    candidates = [
        (weight * idf, term)
        for term, weight in accumulated.items()
        if (idf := math.log(total_chunks / (1 + df.get(term, 0)))) > 0
    ]
    if not candidates:
        return fallback

    candidates.sort(key=lambda entry: (-entry[0], entry[1]))
    return " ".join(term for _, term in candidates[:_LABEL_TOP_TERMS])

337 

338 

def _build_clusters(
    communities: dict[int, list[int]],
    records: list[ChunkRecord],
    source_totals: dict[str, int],
    df: dict[str, int],
    min_sources: int,
) -> tuple[list[SourceCluster], int]:
    """Turn raw chunk communities into published source clusters.

    Returns ``(clusters, noise_chunk_count)`` where ``noise_chunk_count``
    is the number of chunks whose community failed the source filter.
    """
    # Largest community first; ties broken by label id for stable IDs.
    by_size = sorted(communities.items(), key=lambda item: (-len(item[1]), item[0]))
    corpus_size = len(records)
    published: list[SourceCluster] = []
    noise_chunks = 0
    for position, (_, members) in enumerate(by_size):
        sources = _filter_sources(members, records, source_totals)
        if len(sources) < min_sources:
            noise_chunks += len(members)
            continue
        cluster_id = f"embedding-{position}"
        published.append(
            SourceCluster(
                cluster_id=cluster_id,
                label=_label_community(
                    members, records, df, corpus_size, fallback=cluster_id
                ),
                sources=sources,
            )
        )
    return published, noise_chunks

370 

371 

def _warn_if_undersegmented(
    clusters: list[SourceCluster],
    source_totals: dict[str, int],
) -> None:
    """Warn when a single cluster covers more than half the corpus sources."""
    if not clusters or not source_totals:
        return
    total_sources = len(source_totals)
    # Only the first oversized cluster is reported — one warning suffices.
    oversized = next(
        (cluster for cluster in clusters if len(cluster.sources) * 2 > total_sources),
        None,
    )
    if oversized is None:
        return
    log.warning(
        "wiki clustering: cluster %r covers %d/%d sources; "
        "consider lowering wiki_clusterer_k or check embedding quality",
        oversized.label,
        len(oversized.sources),
        total_sources,
    )

390 

391 

class EmbeddingClusterer:
    """Chunk-level mutual-kNN clusterer with TF-IDF labels."""

    def __init__(self, config: Config, store: Store) -> None:
        self._config = config
        self._store = store

    def available(self) -> bool:
        """Clusterer is available when the chunks table has any rows.

        ``count_rows()`` is a LanceDB call that can raise on transient
        backend issues (concurrent compaction, schema rewrites). In that
        case we optimistically report available=True and let
        ``get_clusters`` surface the real error on its next scan —
        silently disabling wiki synthesis would hide the failure from
        the user. A WARNING keeps the failure visible at the default
        log level.
        """
        table = self._store.open_table(CHUNKS_TABLE)
        if table is None:
            return False
        try:
            row_count = table.count_rows()
        except Exception:
            log.warning(
                "count_rows() failed on chunks table; reporting available=True "
                "optimistically and deferring the error to get_clusters",
                exc_info=True,
            )
            return True
        return bool(row_count)

    def get_clusters(self, min_sources: int = 3) -> list[SourceCluster]:
        """Return chunk-level communities projected to source clusters."""
        records, matrix = _load_chunk_records(self._store)
        if not records:
            return []

        matrix, keep_mask = normalize_rows(matrix)
        records = [
            record for record, kept in zip(records, keep_mask, strict=True) if kept
        ]
        if not records:
            return []

        chunk_count = len(records)
        k = self._config.wiki_clusterer_k
        if k <= 0:
            k = auto_k(chunk_count)
        adjacency = mutual_knn(matrix, k)
        if not any(adjacency.values()):
            # WARNING (not INFO) so users see why synthesis produced zero
            # pages at the default log level — matches the other degenerate
            # clustering outcome, ``_warn_if_undersegmented``.
            log.warning(
                "wiki clustering: N=%d k=%d no mutual edges — skipping synthesis",
                chunk_count,
                k,
            )
            return []

        node_order = list(range(chunk_count))
        communities = communities_by_label(
            label_propagation(adjacency, order=node_order)
        )

        totals = _source_totals(records)
        df = _corpus_document_frequency(records)
        clusters, noise = _build_clusters(
            communities, records, totals, df, min_sources
        )

        log.info(
            "wiki clustering: N=%d k=%d communities=%d kept=%d noise=%d",
            chunk_count,
            k,
            len(communities),
            len(clusters),
            noise,
        )
        _warn_if_undersegmented(clusters, totals)
        return clusters