1"""Concept graph for LazyGraphRAG-style index-time knowledge extraction.
3Extracts noun-phrase concepts from chunks via spaCy, builds a PPMI-weighted
4co-occurrence graph (Church & Hanks 1990), and clusters with Leiden
5(Traag et al. 2019, graspologic-native). Used to boost search results by
6concept overlap and expand queries via graph traversal.
8Requires optional ``graph`` extra: ``pip install lilbee[graph]``.
9When dependencies are missing, all public functions degrade gracefully.
10"""

from __future__ import annotations

import logging
import math
from collections import Counter
from dataclasses import dataclass
from typing import Any

import pyarrow as pa
import pyarrow.compute as pc

from lilbee.clustering import SourceCluster
from lilbee.config import (
    CHUNK_CONCEPTS_TABLE,
    CONCEPT_EDGES_TABLE,
    CONCEPT_NODES_TABLE,
    Config,
)
from lilbee.store import Store, escape_sql_string
from lilbee.wiki.shared import is_valid_label

log = logging.getLogger(__name__)

_MIN_LEIDEN_WEIGHT = 0.01


def concepts_available() -> bool:
    """Check if concept graph dependencies (spacy, graspologic) are installed."""
    try:
        import graspologic_native  # noqa: F401
        import spacy  # noqa: F401

        return True
    except ImportError:
        return False
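
# Call-site sketch: gate optional concept-graph features on this check rather
# than importing spacy/graspologic directly (the names below are illustrative,
# not part of this module):
#
#     if not concepts_available():
#         log.info("concept graph disabled; install lilbee[graph]")
#         return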


@dataclass
class Community:
    """A cluster of related concepts from Leiden partitioning."""

    cluster_id: int
    size: int
    concepts: list[str]


def _ensure_spacy_model() -> Any:
    """Load the spaCy NER model; raise ImportError with an install hint if missing."""
    import spacy

    model_name = "en_core_web_sm"
    try:
        return spacy.load(model_name)
    except OSError as exc:
        raise ImportError(
            f"spaCy model '{model_name}' not installed. Run: python -m spacy download {model_name}"
        ) from exc


def load_spacy_pipeline() -> Any:
    """Public wrapper around the shared spaCy NER + noun-chunk pipeline.

    Raises ``ImportError`` if spaCy or the ``en_core_web_sm`` model is
    not installed.
    """
    return _ensure_spacy_model()


def _filter_noun_chunks(doc: Any, max_concepts: int) -> list[str]:
    """Extract deduplicated, filtered noun chunks from a spaCy doc.

    Applies the same :func:`is_valid_label` gate the wiki entity
    extractor uses, so structural-noise concepts (markdown table
    delimiters, page-number-prefixed tokens, sub-three-char fragments)
    never enter the co-occurrence graph and therefore never become a
    synthesis-page cluster label.

    The gate runs on the lowercased form here while the NER extractor
    gates on the original-cased surface; the two decisions match
    because ``is_valid_label`` is case-agnostic today. Any future
    case-sensitive rule must land in both call sites together.
    """
    seen: set[str] = set()
    concepts: list[str] = []
    for chunk in doc.noun_chunks:
        concept = chunk.text.lower().strip()
        if not is_valid_label(concept):
            continue
        if concept in seen:
            continue
        seen.add(concept)
        concepts.append(concept)
        if len(concepts) >= max_concepts:
            break
    return concepts
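
# Illustrative behavior (hypothetical input; assumes the is_valid_label
# semantics described in the docstring above): a doc parsed from
# "The neural network uses backpropagation. The neural network converges."
# might yield ["the neural network", "backpropagation"]: the repeated chunk
# is deduped, and fragments shorter than three characters or matching
# structural-noise patterns never make it into the returned list.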


def _concept_nodes_schema() -> pa.Schema:
    return pa.schema(
        [
            pa.field("concept", pa.utf8()),
            pa.field("cluster_id", pa.int32()),
            pa.field("degree", pa.int32()),
        ]
    )


def _concept_edges_schema() -> pa.Schema:
    return pa.schema(
        [
            pa.field("source", pa.utf8()),
            pa.field("target", pa.utf8()),
            pa.field("weight", pa.float32()),
        ]
    )


def _chunk_concepts_schema() -> pa.Schema:
    return pa.schema(
        [
            pa.field("chunk_source", pa.utf8()),
            pa.field("chunk_index", pa.int32()),
            pa.field("concept", pa.utf8()),
        ]
    )
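
# Example rows for the three tables (hypothetical values, shown only to make
# the data model concrete):
#
#   concept_nodes:   {"concept": "neural network", "cluster_id": 3, "degree": 12}
#   concept_edges:   {"source": "backpropagation", "target": "neural network",
#                     "weight": 2.4}
#   chunk_concepts:  {"chunk_source": "docs/training.md", "chunk_index": 0,
#                     "concept": "neural network"}
#
# Edges store each undirected pair once, with (source, target) in sorted
# order (see build_from_chunks below).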


def _compute_pmi(
    cooccurrences: Counter[tuple[str, str]],
    concept_counts: Counter[str],
    total_chunks: int,
) -> dict[tuple[str, str], float]:
    """Compute PPMI (Positive PMI) weights for concept co-occurrence pairs.

    PPMI = max(0, log2(P(a,b) / (P(a) * P(b)))).

    Based on Church & Hanks 1990, "Word Association Norms, Mutual Information,
    and Lexicography." Negative values are clamped to zero to discard
    anti-correlated pairs.
    """
    pmi: dict[tuple[str, str], float] = {}
    for (a, b), count in cooccurrences.items():
        p_a = concept_counts[a] / total_chunks
        p_b = concept_counts[b] / total_chunks
        if p_a == 0 or p_b == 0:
            continue
        p_ab = count / total_chunks
        pmi[(a, b)] = max(0.0, math.log2(p_ab / (p_a * p_b)))
    return pmi
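
# A worked example of the PPMI formula (hypothetical counts): with 100 chunks
# total, concept a in 10 chunks, concept b in 5, and the pair co-occurring
# in 4:
#
#     P(a) = 0.10, P(b) = 0.05, P(a, b) = 0.04
#     PPMI = max(0, log2(0.04 / (0.10 * 0.05))) = log2(8.0) = 3.0
#
# A pair that co-occurs no more often than chance (P(a, b) = P(a) * P(b))
# scores log2(1) = 0, and anti-correlated pairs, which would score negative,
# are clamped to that same zero floor.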


def _leiden_partition(
    edge_rows: list[dict[str, Any]],
) -> tuple[dict[str, int], dict[str, int]]:
    """Run Leiden clustering on edge rows. Returns (partition, degree_map).

    Uses graspologic-native's Rust implementation (Traag et al. 2019,
    "From Louvain to Leiden: guaranteeing well-connected communities").
    """
    from graspologic_native import leiden

    edges: list[tuple[str, str, float]] = [
        (row["source"], row["target"], max(_MIN_LEIDEN_WEIGHT, row["weight"])) for row in edge_rows
    ]
    _modularity, partition = leiden(edges=edges)  # type: ignore[call-arg]

    degree_map: Counter[str] = Counter()
    for row in edge_rows:
        degree_map[row["source"]] += 1
        degree_map[row["target"]] += 1
    return partition, dict(degree_map)
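
# Input/output shape, with hypothetical values: edge_rows like
#
#     [{"source": "neural network", "target": "backpropagation", "weight": 2.4},
#      {"source": "neural network", "target": "gradient descent", "weight": 1.1}]
#
# yield a partition mapping each node to an integer community id, e.g.
# {"neural network": 0, "backpropagation": 0, "gradient descent": 0}, plus a
# degree map counting incident edges, e.g. {"neural network": 2, ...}.
# Zero-PPMI edges are clamped to _MIN_LEIDEN_WEIGHT so Leiden always sees a
# positive weight.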


class ConceptGraph:
    """Concept graph -- extracts, stores, and queries concept relationships."""

    def __init__(self, config: Config, store: Store) -> None:
        self._config = config
        self._store = store
        self._nlp: Any = None
        self._nlp_unavailable: bool = False

    def _ensure_nlp(self) -> Any | None:
        """Lazy-load and cache the spaCy model. Returns None if unavailable."""
        if self._nlp_unavailable:
            return None
        if self._nlp is None:
            try:
                self._nlp = _ensure_spacy_model()
            except ImportError:
                log.warning("Concept graph disabled: spaCy model unavailable")
                self._nlp_unavailable = True
                return None
        return self._nlp

    def extract_concepts(self, text: str, max_concepts: int | None = None) -> list[str]:
        """Extract noun-phrase concepts from text via spaCy."""
        if max_concepts is None:
            max_concepts = self._config.concept_max_per_chunk
        if not text.strip():
            return []
        nlp = self._ensure_nlp()
        if nlp is None:
            return []
        doc = nlp(text)
        return _filter_noun_chunks(doc, max_concepts)

    def extract_concepts_batch(self, texts: list[str]) -> list[list[str]]:
        """Batch-extract concepts from multiple texts."""
        if not texts:
            return []
        nlp = self._ensure_nlp()
        if nlp is None:
            return [[] for _ in texts]
        max_concepts = self._config.concept_max_per_chunk
        return [_filter_noun_chunks(doc, max_concepts) for doc in nlp.pipe(texts)]
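
    # Typical indexing-time usage (sketch; `graph`, `texts`, and `chunk_ids`
    # are caller-side names, not part of this module):
    #
    #     graph = ConceptGraph(config, store)
    #     concept_lists = graph.extract_concepts_batch(texts)
    #     graph.build_from_chunks(chunk_ids, concept_lists)
    #
    # nlp.pipe() streams all documents through spaCy in one batched pass,
    # which is markedly cheaper than calling extract_concepts() per chunk.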

    def build_from_chunks(
        self, chunk_ids: list[tuple[str, int]], concept_lists: list[list[str]]
    ) -> None:
        """Build co-occurrence graph from chunk concepts, compute PMI, store tables."""
        from lilbee.lock import write_lock
        from lilbee.store import ensure_table

        if not chunk_ids:
            return

        cooccurrences: Counter[tuple[str, str]] = Counter()
        concept_counts: Counter[str] = Counter()
        chunk_concept_records: list[dict[str, Any]] = []

        for (source, idx), concepts in zip(chunk_ids, concept_lists, strict=True):
            for c in concepts:
                concept_counts[c] += 1
                chunk_concept_records.append(
                    {"chunk_source": source, "chunk_index": idx, "concept": c}
                )
            for i, a in enumerate(concepts):
                for b in concepts[i + 1 :]:
                    pair = (min(a, b), max(a, b))
                    cooccurrences[pair] += 1

        pmi_weights = _compute_pmi(cooccurrences, concept_counts, len(chunk_ids))

        edge_records = [
            {"source": a, "target": b, "weight": w} for (a, b), w in pmi_weights.items()
        ]

        node_records = [
            {"concept": c, "cluster_id": 0, "degree": count} for c, count in concept_counts.items()
        ]

        with write_lock():
            db = self._store.get_db()
            # Always create tables so get_graph() returns True even when
            # concept extraction yields no results for the current corpus.
            nodes_tbl = ensure_table(db, CONCEPT_NODES_TABLE, _concept_nodes_schema())
            edges_tbl = ensure_table(db, CONCEPT_EDGES_TABLE, _concept_edges_schema())
            cc_tbl = ensure_table(db, CHUNK_CONCEPTS_TABLE, _chunk_concepts_schema())
            if node_records:
                nodes_tbl.add(node_records)
            if edge_records:
                edges_tbl.add(edge_records)
            if chunk_concept_records:
                cc_tbl.add(chunk_concept_records)
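
    # Co-occurrence counting on a tiny hypothetical input: for one chunk with
    # concepts ["b", "a", "c"], the inner loops emit the canonical pairs
    # ("a", "b"), ("b", "c"), ("a", "c"); min()/max() sort each pair so the
    # same undirected edge is never counted under both orderings.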

    def boost_results(self, results: list[Any], query_concepts: list[str]) -> list[Any]:
        """Boost search results whose chunks overlap with query concepts."""
        if not query_concepts or not results:
            return results
        query_set = set(query_concepts)
        boosted: list[Any] = []
        for r in results:
            chunk_concepts = set(self.get_chunk_concepts(r.source, r.chunk_index))
            overlap = len(query_set & chunk_concepts)
            if overlap > 0:
                boost = (overlap / len(query_set)) * self._config.concept_boost_weight
                r = r.model_copy()
                if r.relevance_score is not None:
                    r.relevance_score = r.relevance_score + boost
                elif r.distance is not None:
                    r.distance = max(self._config.concept_boost_floor, r.distance - boost)
            boosted.append(r)
        return boosted
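
    # Worked example of the boost arithmetic (hypothetical config values):
    # with query_concepts = {"a", "b", "c"}, a chunk sharing {"a", "b"}, and
    # concept_boost_weight = 0.3:
    #
    #     boost = (2 / 3) * 0.3 = 0.2
    #
    # A relevance_score of 0.5 becomes 0.7; a distance of 0.25 becomes
    # max(concept_boost_floor, 0.05), since lower distance means closer.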

    def get_chunk_concepts(self, source: str, chunk_index: int) -> list[str]:
        """Get concepts associated with a specific chunk."""
        table = self._store.open_table(CHUNK_CONCEPTS_TABLE)
        if table is None:
            return []
        escaped = escape_sql_string(source)
        try:
            rows = (
                table.search()
                .where(f"chunk_source = '{escaped}' AND chunk_index = {chunk_index}")
                .to_list()
            )
            return [r["concept"] for r in rows]
        except Exception:
            log.debug("get_chunk_concepts query failed", exc_info=True)
            return []

    def expand_query(self, query: str) -> list[str]:
        """Expand a query with related concepts from the graph."""
        concepts = self.extract_concepts(query)
        if not concepts:
            return []
        related: list[str] = []
        seen = set(concepts)
        for concept in concepts:
            for neighbor in self.get_related_concepts(concept):
                if neighbor not in seen:
                    related.append(neighbor)
                    seen.add(neighbor)
        return related

    def get_related_concepts(self, concept: str, depth: int = 1) -> list[str]:
        """Find concepts related to *concept* via graph edges, up to *depth* hops.

        One batched query per depth level; O(depth) DB round-trips,
        independent of frontier size.
        """
        table = self._store.open_table(CONCEPT_EDGES_TABLE)
        if table is None:
            return []
        visited: set[str] = {concept}
        frontier: list[str] = [concept]
        for _ in range(depth):
            if not frontier:
                break
            escaped_list = ", ".join(f"'{escape_sql_string(n)}'" for n in frontier)
            try:
                rows = (
                    table.search()
                    .where(f"source IN ({escaped_list}) OR target IN ({escaped_list})")
                    .to_list()
                )
            except Exception:
                log.debug(
                    "concept expand batch failed at frontier size %d",
                    len(frontier),
                    exc_info=True,
                )
                break
            next_frontier: list[str] = []
            for row in rows:
                for endpoint in (row["source"], row["target"]):
                    if endpoint not in visited:
                        visited.add(endpoint)
                        next_frontier.append(endpoint)
            frontier = next_frontier
        return [c for c in visited if c != concept]
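
    # Traversal sketch on a hypothetical chain graph a--b, b--c, c--d:
    #
    #     get_related_concepts("a", depth=1)  ->  ["b"]
    #     get_related_concepts("a", depth=2)  ->  ["b", "c"]  (order not guaranteed)
    #
    # Each level issues one IN-list query covering the whole frontier, so a
    # hub node with thousands of neighbors still costs a single round-trip
    # per hop.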

    def top_communities(self, k: int = 10) -> list[Community]:
        """Return the *k* largest concept communities.

        Uses ``pyarrow.compute.value_counts`` to pick the top-k
        cluster_ids in columnar memory, then materializes only those
        clusters' members. Peak Python memory scales with members of
        the top *k* clusters, not the total node count.
        """
        table = self._store.open_table(CONCEPT_NODES_TABLE)
        if table is None:
            return []
        arrow_tbl = table.to_arrow()
        if arrow_tbl.num_rows == 0:
            return []
        counts = pc.value_counts(arrow_tbl["cluster_id"]).to_pylist()
        top = sorted(counts, key=lambda entry: entry["counts"], reverse=True)[:k]
        top_ids = [entry["values"] for entry in top if entry["values"] is not None]
        if not top_ids:
            return []
        member_rows = arrow_tbl.filter(
            pc.is_in(arrow_tbl["cluster_id"], value_set=pa.array(top_ids))
        ).to_pylist()
        by_cluster: dict[int, list[str]] = {}
        for row in member_rows:
            by_cluster.setdefault(row["cluster_id"], []).append(row["concept"])
        return [
            Community(
                cluster_id=cid,
                size=len(by_cluster.get(cid, [])),
                concepts=by_cluster.get(cid, []),
            )
            for cid in top_ids
            if by_cluster.get(cid)
        ]
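
    # pc.value_counts returns a struct array; to_pylist() yields dicts with
    # "values" (the cluster_id) and "counts" keys, e.g. (hypothetical):
    #
    #     [{"values": 3, "counts": 41}, {"values": 0, "counts": 17}, ...]
    #
    # Sorting those small dicts by "counts" picks the largest clusters before
    # any per-row Python objects are materialized.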

    def rebuild_clusters(self) -> None:
        """Re-run Leiden clustering on the existing edge table."""
        from lilbee.lock import write_lock
        from lilbee.store import ensure_table

        edges_table = self._store.open_table(CONCEPT_EDGES_TABLE)
        if edges_table is None:
            return
        edge_rows = edges_table.to_arrow().to_pylist()
        if not edge_rows:
            return

        partition, degree_map = _leiden_partition(edge_rows)

        node_records = [
            {
                "concept": node,
                "cluster_id": cluster_id,
                "degree": degree_map.get(node, 0),
            }
            for node, cluster_id in partition.items()
        ]

        self._store.clear_table(CONCEPT_NODES_TABLE, "concept IS NOT NULL")
        if node_records:
            with write_lock():
                db = self._store.get_db()
                nodes_table = ensure_table(db, CONCEPT_NODES_TABLE, _concept_nodes_schema())
                nodes_table.add(node_records)

    def get_cluster_sources(self, min_sources: int = 3) -> dict[int, set[str]]:
        """Return clusters that span at least *min_sources* distinct sources.

        Joins concept_nodes (concept -> cluster_id) with chunk_concepts
        (concept -> chunk_source) to find which document sources each
        cluster touches.
        """
        nodes_table = self._store.open_table(CONCEPT_NODES_TABLE)
        cc_table = self._store.open_table(CHUNK_CONCEPTS_TABLE)
        if nodes_table is None or cc_table is None:
            return {}

        node_rows = nodes_table.to_arrow().to_pylist()
        concept_to_cluster: dict[str, int] = {r["concept"]: r["cluster_id"] for r in node_rows}

        cc_rows = cc_table.to_arrow().to_pylist()
        cluster_sources: dict[int, set[str]] = {}
        for row in cc_rows:
            cid = concept_to_cluster.get(row["concept"])
            if cid is None:
                continue
            cluster_sources.setdefault(cid, set()).add(row["chunk_source"])

        return {
            cid: sources for cid, sources in cluster_sources.items() if len(sources) >= min_sources
        }

    def get_cluster_label(self, cluster_id: int) -> str:
        """Return a human-readable label for *cluster_id* (highest-degree concept)."""
        table = self._store.open_table(CONCEPT_NODES_TABLE)
        if table is None:
            return f"cluster-{cluster_id}"
        try:
            rows = table.search().where(f"cluster_id = {int(cluster_id)}").to_list()
        except Exception:
            log.debug("get_cluster_label query failed", exc_info=True)
            return f"cluster-{cluster_id}"
        if not rows:
            return f"cluster-{cluster_id}"
        best = max(rows, key=lambda r: r["degree"])
        return str(best["concept"])

    def get_graph(self) -> bool:
        """Check whether a concept graph exists in the store."""
        if not self._config.concept_graph:
            return False
        return self._store.open_table(CONCEPT_NODES_TABLE) is not None

    def reset_nlp_cache(self) -> None:
        """Clear the spaCy model cache. For testing only."""
        self._nlp = None
        self._nlp_unavailable = False


class ConceptGraphClusterer:
    """SourceClusterer backed by the concept graph (requires ``[graph]`` extra).

    Wraps :class:`ConceptGraph` so the wiki synthesis layer can consume
    concept-based clusters through the generic ``SourceClusterer`` protocol
    without importing ``ConceptGraph`` directly. Leaves ``ConceptGraph``
    unchanged.
    """

    def __init__(self, config: Config, store: Store) -> None:
        self._graph = ConceptGraph(config, store)

    def available(self) -> bool:
        """Concept-graph clustering needs both dependencies and a built graph."""
        return bool(concepts_available() and self._graph.get_graph())

    def get_clusters(self, min_sources: int = 3) -> list[SourceCluster]:
        """Expose concept clusters as generic :class:`SourceCluster` values."""
        cluster_sources = self._graph.get_cluster_sources(min_sources=min_sources)
        return [
            SourceCluster(
                cluster_id=f"concept-{cid}",
                label=self._graph.get_cluster_label(cid),
                sources=frozenset(sources),
            )
            for cid, sources in cluster_sources.items()
            if len(sources) >= min_sources
        ]
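
# End-to-end usage sketch (caller-side names such as `config`, `store`, and
# `clusterer` are illustrative, not exported here):
#
#     clusterer = ConceptGraphClusterer(config, store)
#     if clusterer.available():
#         for cluster in clusterer.get_clusters(min_sources=3):
#             print(cluster.cluster_id, cluster.label, len(cluster.sources))
#
# available() is False both when the [graph] extra is missing and when no
# graph has been built yet, so callers need no separate import guard.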