Coverage for src/lilbee/clustering_embedding.py: 100% (221 statements), coverage.py v7.13.4, created at 2026-04-29 19:16 +0000
1"""Chunk-level mutual-kNN source clusterer — default wiki synthesis backend.
3Pipeline:
51. Load every chunk's (source, chunk_index, text, vector) from LanceDB.
62. Build a float32 matrix ``V`` and L2-normalize rows.
73. Compute a mutual k-nearest-neighbors graph over chunks using blocked
8 similarity. ``k`` auto-scales from corpus size unless
9 ``config.wiki_clusterer_k`` is set.
104. Run asynchronous Label Propagation (Raghavan et al. 2007) over the
11 mutual-kNN graph to obtain chunk-level communities.
125. Aggregate chunk communities into source communities, requiring a
13 source to contribute at least ``min(3, ceil(0.2 * total_chunks))``
14 chunks before it joins a cluster. This keeps a single stray chunk
15 from dragging a whole document into an unrelated cluster.
166. Filter communities that span fewer than ``min_sources`` distinct
17 sources, then label each cluster with TF-IDF scoring over member
18 chunk text against corpus-wide document frequency.
20Mutual-kNN is itself hub-robust: a pathological hub ends up in many
21one-way neighborhoods but can reciprocate at most ``k`` of them, so
22hub-driven bridging across topics is broken at the graph-construction
23step without any post-hoc similarity rescaling.
25Uses numpy for the O(N²) similarity kernel. numpy is declared in
26``pyproject.toml`` as a direct dependency.
28The implementation is blocked in row chunks of ``_BLOCK_SIZE`` to keep
29peak memory bounded regardless of corpus size, so it scales comfortably
30to tens of thousands of chunks on a laptop.
31"""
from __future__ import annotations

import logging
import math
from collections import Counter
from dataclasses import dataclass, field

import numpy as np

from lilbee.clustering import SourceCluster
from lilbee.config import CHUNKS_TABLE, Config
from lilbee.store import Store

log = logging.getLogger(__name__)

# Block size for the similarity kernel. With N=10000 chunks this caps the
# per-block similarity buffer at block * N * 4 bytes ~= 40 MB, regardless
# of the embedding dimension.
_BLOCK_SIZE = 1024

# Label Propagation hard iteration cap. Convergence is typically reached
# in well under 10 passes on real corpora.
_MAX_LPA_ITERATIONS = 30

# Minimum non-zero L2 norm for a row vector to be kept.
_MIN_VECTOR_NORM = 1e-12

# Source-membership thresholds. A source joins a chunk community when it
# contributes at least `min(_MIN_SOURCE_CHUNKS, ceil(total * _MIN_SOURCE_FRACTION))`
# of its chunks. The smaller (more lenient) side wins, so a short document can
# qualify with a chunk or two, while a single stray chunk from a long document
# never pulls the whole source into an unrelated cluster.
_MIN_SOURCE_CHUNKS = 3
_MIN_SOURCE_FRACTION = 0.2
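
# Worked example of that cutoff (numbers follow directly from the constants
# above): a source with 40 chunks corpus-wide needs min(3, ceil(40 * 0.2)) = 3
# of them inside the community, while a source with only 4 chunks needs just
# min(3, ceil(4 * 0.2)) = 1.
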
# TF-IDF labeling knobs.
_LABEL_TOP_TERMS = 3
# Chunks with fewer tokens than this are down-weighted when accumulating
# term frequency so short boilerplate (headings, captions) cannot dominate
# a cluster label. 20 tokens roughly matches the token count of a section
# heading or a two-sentence summary.
_SHORT_CHUNK_TOKEN_CAP = 20

# kNN auto-scaling bounds. Formula: clamp(round(log2(N)+2), _MIN_K, _MAX_K).
_MIN_K = 5
_MAX_K = 20

# Minimum token length for TF-IDF labeling. Shorter tokens are mostly
# articles, prepositions, and single letters — noise that inflates term
# counts without adding topic signal. Three characters keeps useful
# acronyms (api, xml, sql).
_MIN_TF_TOKEN_LEN = 3


def _tokenize_for_tf(text: str) -> list[str]:
    """Lowercase alphanumeric tokens for TF-IDF scoring.

    Deliberately has NO stopword list: common words like "the" or "and"
    get an IDF near zero (they appear in almost every chunk) so TF-IDF
    filters them automatically. A hand-curated English stoplist would
    add maintenance burden and break on non-English corpora for no
    additional quality gain.
    """
    result: list[str] = []
    for raw in text.lower().split():
        word = "".join(ch for ch in raw if ch.isalnum())
        if len(word) >= _MIN_TF_TOKEN_LEN:
            result.append(word)
    return result
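
# For example (illustrative, not an executed doctest):
#     _tokenize_for_tf("Mutual kNN beats raw kNN (2x)!")
#     -> ["mutual", "knn", "beats", "raw", "knn"]
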
@dataclass
class ChunkRecord:
    """Lightweight view of one chunk row used by the clusterer."""

    source: str
    chunk_index: int
    text: str
    tokens: list[str] = field(default_factory=list)


def auto_k(n: int) -> int:
    """Pick a neighborhood size from corpus size: ``clamp(round(log2(N) + 2), _MIN_K, _MAX_K)``."""
    if n <= 1:
        return _MIN_K
    raw = round(math.log2(max(n, 4)) + 2)
    return max(_MIN_K, min(_MAX_K, raw))
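
# A quick sanity check of the formula above (values worked out by hand):
# auto_k(50) -> 8, auto_k(500) -> 11, auto_k(10_000) -> 15, and
# auto_k(5_000_000) clamps to _MAX_K = 20.
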
def _parse_chunk_row(
    row: dict[str, object],
) -> tuple[ChunkRecord, list[float] | tuple[float, ...]] | None:
    """Extract a chunk record + vector from a raw Arrow row, or None on invalid."""
    vector = row.get("vector")
    if not isinstance(vector, (list, tuple)):
        return None
    source = row.get("source")
    if not isinstance(source, str):
        return None
    raw_text = row.get("chunk")
    chunk_text = raw_text if isinstance(raw_text, str) else ""
    raw_index = row.get("chunk_index")
    chunk_index = raw_index if isinstance(raw_index, int) else 0
    record = ChunkRecord(
        source=source,
        chunk_index=chunk_index,
        text=chunk_text,
        tokens=_tokenize_for_tf(chunk_text),
    )
    return record, vector


def _load_chunk_records(
    store: Store,
) -> tuple[list[ChunkRecord], np.ndarray]:
    """Scan the chunks table once and return records plus a float32 matrix.

    Rows with an unparseable vector are skipped. Records are sorted by
    ``(source, chunk_index)`` so downstream cluster IDs are stable
    regardless of LanceDB's row return order. Records are tokenized once
    here so TF-IDF labeling does not re-tokenize. The vector matrix is
    preallocated and populated via numpy row-assignment, which pushes
    the Python-level float cast into numpy's C loop and avoids building
    a transient ``list[list[float]]``.
    """
    table = store.open_table(CHUNKS_TABLE)
    if table is None:
        return [], np.zeros((0, 0), dtype=np.float32)

    parsed = [pair for pair in map(_parse_chunk_row, table.to_arrow().to_pylist()) if pair]
    if not parsed:
        return [], np.zeros((0, 0), dtype=np.float32)

    parsed.sort(key=lambda pair: (pair[0].source, pair[0].chunk_index))
    dim = len(parsed[0][1])
    matrix = np.empty((len(parsed), dim), dtype=np.float32)
    records: list[ChunkRecord] = []
    for row_idx, (record, vector) in enumerate(parsed):
        records.append(record)
        matrix[row_idx] = vector
    return records, matrix


def normalize_rows(matrix: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return (normalized_matrix, keep_mask). Zero-norm rows are dropped."""
    if matrix.size == 0:
        return matrix, np.zeros(0, dtype=bool)
    norms = np.linalg.norm(matrix, axis=1)
    keep = norms > _MIN_VECTOR_NORM
    if not keep.all():
        matrix = matrix[keep]
        norms = norms[keep]
    return matrix / norms[:, None], keep
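
# Illustrative behaviour (not an executed doctest): a [3, 4] row normalizes to
# [0.6, 0.8] and an all-zero row is dropped, so
#     normalize_rows(np.array([[3, 4], [0, 0]], dtype=np.float32))
# returns the single normalized row plus the keep mask [True, False].
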
def mutual_knn(matrix: np.ndarray, k: int) -> dict[int, set[int]]:
    """Build a mutual k-nearest-neighbors graph over L2-normalized rows.

    Computes similarity in row blocks so peak memory stays bounded.
    Self-similarity is masked so each row's neighbors exclude itself.
    Mutuality is enforced by keeping only edges ``(i, j)`` where
    ``j`` is in row ``i``'s top-k AND ``i`` is in row ``j``'s top-k;
    this single rule breaks hub-driven bridging without any extra
    similarity rescaling.
    """
    n = matrix.shape[0]
    if n == 0 or k <= 0:
        return {}

    effective_k = min(k, n - 1)
    if effective_k <= 0:
        return {i: set() for i in range(n)}

    top_neighbors: list[set[int]] = [set() for _ in range(n)]

    for start in range(0, n, _BLOCK_SIZE):
        stop = min(start + _BLOCK_SIZE, n)
        sim_block = matrix[start:stop] @ matrix.T  # (block, n)
        # Mask self-similarity so each row's own index is never returned.
        # For the tail block where stop-start < _BLOCK_SIZE the fancy-index
        # pairing still lines up because both sides are length stop-start.
        block_rows = np.arange(stop - start)
        sim_block[block_rows, np.arange(start, stop)] = -math.inf
        # Partition by largest similarities without allocating a negated
        # copy of the block — pass a negative kth to select the tail.
        neighbor_idx = np.argpartition(sim_block, -effective_k, axis=1)[:, -effective_k:]
        for local_row, global_row in enumerate(range(start, stop)):
            top_neighbors[global_row] = set(neighbor_idx[local_row].tolist())

    mutual: dict[int, set[int]] = {i: set() for i in range(n)}
    for i in range(n):
        for j in top_neighbors[i]:
            if i in top_neighbors[j]:
                mutual[i].add(j)
    return mutual
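
# A tiny sketch of the mutuality rule (illustrative only, not executed here):
# with two near-duplicate vectors and one outlier,
#     vecs, _ = normalize_rows(np.array([[1, 0], [1, 0.1], [0, 1]], dtype=np.float32))
#     mutual_knn(vecs, k=1)
# returns {0: {1}, 1: {0}, 2: set()}: rows 0 and 1 pick each other, while the
# outlier's one-way pick of row 1 is not reciprocated, so it gets no edge.
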
def label_propagation(
    adjacency: dict[int, set[int]],
    order: list[int],
) -> list[int]:
    """Async Label Propagation with deterministic min-label tie-breaking.

    Each node adopts the most common label among its neighbors. Ties are
    broken by smallest label id so the outcome is reproducible across
    runs on the same corpus. The caller is responsible for passing a
    complete ``adjacency`` (one entry per node 0..n-1) and an ``order``
    that covers every node — ``get_clusters`` guarantees both.
    """
    n = len(adjacency)
    labels = list(range(n))
    for _ in range(_MAX_LPA_ITERATIONS):
        changed = False
        for node in order:
            neighbors = adjacency.get(node)
            if not neighbors:
                continue
            counts: Counter[int] = Counter(labels[j] for j in neighbors)
            top_count = max(counts.values())
            best = min(label for label, count in counts.items() if count == top_count)
            if labels[node] != best:
                labels[node] = best
                changed = True
        if not changed:
            break
    return labels


def communities_by_label(labels: list[int]) -> dict[int, list[int]]:
    """Group node indices by their final community label."""
    communities: dict[int, list[int]] = {}
    for node, label in enumerate(labels):
        communities.setdefault(label, []).append(node)
    return communities
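
# Hand-traced illustration (not an executed doctest): on two disconnected
# triangles the propagation settles in two passes and the grouping step
# recovers both communities.
#     adj = {0: {1, 2}, 1: {0, 2}, 2: {0, 1}, 3: {4, 5}, 4: {3, 5}, 5: {3, 4}}
#     labels = label_propagation(adj, order=[0, 1, 2, 3, 4, 5])  # [1, 1, 1, 4, 4, 4]
#     communities_by_label(labels)                               # {1: [0, 1, 2], 4: [3, 4, 5]}
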
def _source_totals(records: list[ChunkRecord]) -> dict[str, int]:
    """Return the total chunk count per source across the whole corpus."""
    totals: dict[str, int] = {}
    for record in records:
        totals[record.source] = totals.get(record.source, 0) + 1
    return totals


def _filter_sources(
    member_indices: list[int],
    records: list[ChunkRecord],
    source_totals: dict[str, int],
) -> frozenset[str]:
    """Apply the source-membership threshold to a community's members."""
    per_source: dict[str, int] = {}
    for idx in member_indices:
        source = records[idx].source
        per_source[source] = per_source.get(source, 0) + 1
    kept: set[str] = set()
    for source, count in per_source.items():
        total = source_totals.get(source, count)
        fractional_cutoff = math.ceil(total * _MIN_SOURCE_FRACTION)
        cutoff = min(_MIN_SOURCE_CHUNKS, fractional_cutoff)
        if count >= cutoff:
            kept.add(source)
    return frozenset(kept)


def _corpus_document_frequency(records: list[ChunkRecord]) -> dict[str, int]:
    """Compute document frequency (number of chunks containing the term) for every term."""
    df: dict[str, int] = {}
    for record in records:
        for term in set(record.tokens):
            df[term] = df.get(term, 0) + 1
    return df


def _label_community(
    member_indices: list[int],
    records: list[ChunkRecord],
    df: dict[str, int],
    total_chunks: int,
    fallback: str,
) -> str:
    """Pick a topic label for a community using sublinear TF-IDF scoring."""
    tf: dict[str, float] = {}
    for idx in member_indices:
        tokens = records[idx].tokens
        if not tokens:
            continue
        weight = min(1.0, len(tokens) / _SHORT_CHUNK_TOKEN_CAP)
        counts: Counter[str] = Counter(tokens)
        for term, count in counts.items():
            tf[term] = tf.get(term, 0.0) + weight * (1.0 + math.log(count))

    scored: list[tuple[float, str]] = []
    for term, term_tf in tf.items():
        # Standard ``log(N / (1 + df))`` smoothing: the +1 keeps the
        # denominator non-zero even for unseen terms, and terms that
        # appear in (nearly) every chunk get an idf at or below zero,
        # so the check below drops them entirely.
        idf = math.log(total_chunks / (1 + df.get(term, 0)))
        if idf <= 0:
            continue
        scored.append((term_tf * idf, term))
    if not scored:
        return fallback

    scored.sort(key=lambda pair: (-pair[0], pair[1]))
    return " ".join(term for _, term in scored[:_LABEL_TOP_TERMS])
def _build_clusters(
    communities: dict[int, list[int]],
    records: list[ChunkRecord],
    source_totals: dict[str, int],
    df: dict[str, int],
    min_sources: int,
) -> tuple[list[SourceCluster], int]:
    """Turn raw chunk communities into published source clusters.

    Returns ``(clusters, noise_chunk_count)`` where ``noise_chunk_count``
    is the number of chunks whose community failed the source filter.
    """
    ordered = sorted(communities.items(), key=lambda pair: (-len(pair[1]), pair[0]))
    total_chunks = len(records)
    clusters: list[SourceCluster] = []
    noise = 0
    for idx, (_, members) in enumerate(ordered):
        kept_sources = _filter_sources(members, records, source_totals)
        if len(kept_sources) < min_sources:
            noise += len(members)
            continue
        cluster_id = f"embedding-{idx}"
        label = _label_community(members, records, df, total_chunks, fallback=cluster_id)
        clusters.append(
            SourceCluster(
                cluster_id=cluster_id,
                label=label,
                sources=kept_sources,
            )
        )
    return clusters, noise


def _warn_if_undersegmented(
    clusters: list[SourceCluster],
    source_totals: dict[str, int],
) -> None:
    """Warn when a single cluster covers more than half the corpus sources."""
    if not clusters or not source_totals:
        return
    total_sources = len(source_totals)
    for cluster in clusters:
        if len(cluster.sources) * 2 > total_sources:
            log.warning(
                "wiki clustering: cluster %r covers %d/%d sources; "
                "consider lowering wiki_clusterer_k or check embedding quality",
                cluster.label,
                len(cluster.sources),
                total_sources,
            )
            break


class EmbeddingClusterer:
    """Chunk-level mutual-kNN clusterer with TF-IDF labels."""

    def __init__(self, config: Config, store: Store) -> None:
        self._config = config
        self._store = store

    def available(self) -> bool:
        """Clusterer is available when the chunks table has any rows.

        ``count_rows()`` is a LanceDB call that can raise on transient
        backend issues (concurrent compaction, schema rewrites). When
        it does, we optimistically report available=True and let
        ``get_clusters`` surface the real error on the next scan — the
        alternative would silently disable wiki synthesis without the
        user seeing why. A WARNING is emitted so the failure is still
        visible at the default log level.
        """
        table = self._store.open_table(CHUNKS_TABLE)
        if table is None:
            return False
        try:
            return bool(table.count_rows())
        except Exception:
            log.warning(
                "count_rows() failed on chunks table; reporting available=True "
                "optimistically and deferring the error to get_clusters",
                exc_info=True,
            )
            return True

    def get_clusters(self, min_sources: int = 3) -> list[SourceCluster]:
        """Return chunk-level communities projected to source clusters."""
        records, matrix = _load_chunk_records(self._store)
        if not records:
            return []

        matrix, keep_mask = normalize_rows(matrix)
        records = [record for record, keep in zip(records, keep_mask, strict=True) if keep]
        if not records:
            return []

        configured_k = self._config.wiki_clusterer_k
        k = configured_k if configured_k > 0 else auto_k(len(records))
        adjacency = mutual_knn(matrix, k)
        if not any(adjacency.values()):
            # WARNING (not INFO) so users see why synthesis produced zero
            # pages at the default log level — matches the other degenerate
            # clustering outcome, ``_warn_if_undersegmented``.
            log.warning(
                "wiki clustering: N=%d k=%d no mutual edges — skipping synthesis",
                len(records),
                k,
            )
            return []
        labels = label_propagation(adjacency, order=list(range(len(records))))
        communities = communities_by_label(labels)

        totals = _source_totals(records)
        df = _corpus_document_frequency(records)
        clusters, noise = _build_clusters(communities, records, totals, df, min_sources)

        log.info(
            "wiki clustering: N=%d k=%d communities=%d kept=%d noise=%d",
            len(records),
            k,
            len(communities),
            len(clusters),
            noise,
        )
        _warn_if_undersegmented(clusters, totals)
        return clusters