Coverage for src / lilbee / concepts.py: 100%

278 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""Concept graph for LazyGraphRAG-style index-time knowledge extraction. 

2 

3Extracts noun-phrase concepts from chunks via spaCy, builds a PPMI-weighted 

4co-occurrence graph (Church & Hanks 1990), and clusters with Leiden 

5(Traag et al. 2019, graspologic-native). Used to boost search results by 

6concept overlap and expand queries via graph traversal. 

7 

8Requires optional ``graph`` extra: ``pip install lilbee[graph]``. 

9When dependencies are missing, all public functions degrade gracefully. 

10""" 

11 

12from __future__ import annotations 

13 

14import logging 

15import math 

16from collections import Counter 

17from dataclasses import dataclass 

18from typing import Any 

19 

20import pyarrow as pa 

21import pyarrow.compute as pc 

22 

23from lilbee.clustering import SourceCluster 

24from lilbee.config import ( 

25 CHUNK_CONCEPTS_TABLE, 

26 CONCEPT_EDGES_TABLE, 

27 CONCEPT_NODES_TABLE, 

28 Config, 

29) 

30from lilbee.store import Store, escape_sql_string 

31from lilbee.wiki.shared import is_valid_label 

32 

33log = logging.getLogger(__name__) 

34 

35_MIN_LEIDEN_WEIGHT = 0.01 

36 

37 

def concepts_available() -> bool:
    """Check if concept graph dependencies (spacy, graspologic) are installed."""
    # Both optional dependencies are required; a single missing one disables
    # the concept-graph feature entirely.
    for module_name in ("graspologic_native", "spacy"):
        try:
            __import__(module_name)
        except ImportError:
            return False
    return True

47 

48 

@dataclass
class Community:
    """A cluster of related concepts from Leiden partitioning."""

    # Leiden cluster identifier as stored in the concept-nodes table.
    cluster_id: int
    # Number of member concepts (equals len(concepts)).
    size: int
    # The member concept strings.
    concepts: list[str]

56 

57 

def _ensure_spacy_model() -> Any:
    """Load the spaCy NER model; raise ImportError with an install hint if missing."""
    import spacy

    model_name = "en_core_web_sm"
    try:
        pipeline = spacy.load(model_name)
    except OSError as exc:
        # spaCy raises OSError when the model package is absent; convert it to
        # ImportError so callers can treat "no model" like "no dependency".
        message = f"spaCy model '{model_name}' not installed. Run: python -m spacy download {model_name}"
        raise ImportError(message) from exc
    return pipeline

69 

70 

def load_spacy_pipeline() -> Any:
    """Public wrapper around the shared spaCy NER + noun-chunk pipeline.

    Raises ``ImportError`` if spaCy or the ``en_core_web_sm`` model cannot
    be installed.
    """
    # Delegate to the private loader so install-hint logic lives in one place.
    pipeline = _ensure_spacy_model()
    return pipeline

78 

79 

def _filter_noun_chunks(doc: Any, max_concepts: int) -> list[str]:
    """Extract deduplicated, filtered noun chunks from a spaCy doc.

    Applies the same :func:`is_valid_label` gate the wiki entity
    extractor uses, so structural-noise concepts (markdown table
    delimiters, page-number-prefixed tokens, sub-three-char fragments)
    never enter the co-occurrence graph and therefore never become a
    synthesis-page cluster label.

    The gate runs on the lowercased form here while the NER extractor
    gates on the original-cased surface; the two decisions match
    because ``is_valid_label`` is case-agnostic today. Any future
    case-sensitive rule must land in both call sites together.
    """
    accepted: set[str] = set()
    collected: list[str] = []
    for noun_chunk in doc.noun_chunks:
        label = noun_chunk.text.lower().strip()
        # Gate first, then dedupe, preserving first-seen order.
        if not is_valid_label(label):
            continue
        if label in accepted:
            continue
        accepted.add(label)
        collected.append(label)
        # Cap is checked after appending, so exactly max_concepts can be kept.
        if len(collected) >= max_concepts:
            break
    return collected

107 

108 

def _concept_nodes_schema() -> pa.Schema:
    # One row per concept node: label, Leiden cluster assignment, edge degree.
    fields = [
        pa.field("concept", pa.utf8()),
        pa.field("cluster_id", pa.int32()),
        pa.field("degree", pa.int32()),
    ]
    return pa.schema(fields)

117 

118 

def _concept_edges_schema() -> pa.Schema:
    # One row per undirected concept pair with its PPMI weight.
    fields = [
        pa.field("source", pa.utf8()),
        pa.field("target", pa.utf8()),
        pa.field("weight", pa.float32()),
    ]
    return pa.schema(fields)

127 

128 

def _chunk_concepts_schema() -> pa.Schema:
    # One row per (chunk, concept) association.
    fields = [
        pa.field("chunk_source", pa.utf8()),
        pa.field("chunk_index", pa.int32()),
        pa.field("concept", pa.utf8()),
    ]
    return pa.schema(fields)

137 

138 

139def _compute_pmi( 

140 cooccurrences: Counter[tuple[str, str]], 

141 concept_counts: Counter[str], 

142 total_chunks: int, 

143) -> dict[tuple[str, str], float]: 

144 """Compute PPMI (Positive PMI) weights for concept co-occurrence pairs. 

145 PPMI = max(0, log2(P(a,b) / (P(a) * P(b)))). 

146 Based on Church & Hanks 1990, "Word Association Norms, Mutual Information, 

147 and Lexicography." Negative values are clamped to zero to discard 

148 anti-correlated pairs. 

149 """ 

150 pmi: dict[tuple[str, str], float] = {} 

151 for (a, b), count in cooccurrences.items(): 

152 p_a = concept_counts[a] / total_chunks 

153 p_b = concept_counts[b] / total_chunks 

154 if p_a == 0 or p_b == 0: 

155 continue 

156 p_ab = count / total_chunks 

157 pmi[(a, b)] = max(0.0, math.log2(p_ab / (p_a * p_b))) 

158 return pmi 

159 

160 

def _leiden_partition(
    edge_rows: list[dict[str, Any]],
) -> tuple[dict[str, int], dict[str, int]]:
    """Run Leiden clustering on edge rows. Returns (partition, degree_map).

    Uses graspologic-native's Rust implementation (Traag et al. 2019,
    "From Louvain to Leiden: guaranteeing well-connected communities").
    """
    from graspologic_native import leiden

    # Clamp every weight to a small positive floor before handing to Leiden.
    weighted_edges: list[tuple[str, str, float]] = []
    for row in edge_rows:
        weight = row["weight"]
        if weight < _MIN_LEIDEN_WEIGHT:
            weight = _MIN_LEIDEN_WEIGHT
        weighted_edges.append((row["source"], row["target"], weight))
    _modularity, partition = leiden(edges=weighted_edges)  # type: ignore[call-arg]

    # Degree = number of incident edges per node (unweighted count).
    degrees: Counter[str] = Counter()
    for row in edge_rows:
        degrees[row["source"]] += 1
        degrees[row["target"]] += 1
    return partition, dict(degrees)

180 

181 

class ConceptGraph:
    """Concept graph -- extracts, stores, and queries concept relationships."""

    def __init__(self, config: Config, store: Store) -> None:
        self._config = config
        self._store = store
        # Lazily-loaded spaCy pipeline; stays None until first successful load.
        self._nlp: Any = None
        # Latched True after a failed load so we warn once and skip retries.
        self._nlp_unavailable: bool = False

    def _ensure_nlp(self) -> Any | None:
        """Lazy-load and cache the spaCy model. Returns None if unavailable."""
        if self._nlp_unavailable:
            return None
        if self._nlp is None:
            try:
                self._nlp = _ensure_spacy_model()
            except ImportError:
                log.warning("Concept graph disabled: spaCy model unavailable")
                self._nlp_unavailable = True
                return None
        return self._nlp

    def extract_concepts(self, text: str, max_concepts: int | None = None) -> list[str]:
        """Extract noun-phrase concepts from text via spaCy.

        Returns an empty list for blank input or when spaCy is unavailable.
        ``max_concepts`` defaults to the configured per-chunk cap.
        """
        if max_concepts is None:
            max_concepts = self._config.concept_max_per_chunk
        if not text.strip():
            return []
        nlp = self._ensure_nlp()
        if nlp is None:
            return []
        doc = nlp(text)
        return _filter_noun_chunks(doc, max_concepts)

    def extract_concepts_batch(self, texts: list[str]) -> list[list[str]]:
        """Batch-extract concepts from multiple texts.

        Uses ``nlp.pipe`` for batch processing; returns one (possibly empty)
        concept list per input text, preserving order.
        """
        if not texts:
            return []
        nlp = self._ensure_nlp()
        if nlp is None:
            # Degrade gracefully: one empty list per input.
            return [[] for _ in texts]
        max_concepts = self._config.concept_max_per_chunk
        return [_filter_noun_chunks(doc, max_concepts) for doc in nlp.pipe(texts)]

    def build_from_chunks(
        self, chunk_ids: list[tuple[str, int]], concept_lists: list[list[str]]
    ) -> None:
        """Build co-occurrence graph from chunk concepts, compute PMI, store tables.

        ``chunk_ids`` and ``concept_lists`` must be parallel (zip strict=True).
        Writes the nodes, edges, and chunk-concept tables under the write lock.
        """
        from lilbee.lock import write_lock
        from lilbee.store import ensure_table

        if not chunk_ids:
            return

        cooccurrences: Counter[tuple[str, str]] = Counter()
        concept_counts: Counter[str] = Counter()
        chunk_concept_records: list[dict[str, Any]] = []

        for (source, idx), concepts in zip(chunk_ids, concept_lists, strict=True):
            for c in concepts:
                concept_counts[c] += 1
                chunk_concept_records.append(
                    {"chunk_source": source, "chunk_index": idx, "concept": c}
                )
            # Unordered within-chunk pairs, canonicalized as (min, max) so the
            # same pair maps to one counter key regardless of order.
            for i, a in enumerate(concepts):
                for b in concepts[i + 1 :]:
                    pair = (min(a, b), max(a, b))
                    cooccurrences[pair] += 1

        pmi_weights = _compute_pmi(cooccurrences, concept_counts, len(chunk_ids))

        edge_records = [
            {"source": a, "target": b, "weight": w} for (a, b), w in pmi_weights.items()
        ]

        # cluster_id 0 is a placeholder until rebuild_clusters() runs Leiden.
        node_records = [
            {"concept": c, "cluster_id": 0, "degree": count} for c, count in concept_counts.items()
        ]

        with write_lock():
            db = self._store.get_db()
            # Always create tables so get_graph() returns True even when
            # concept extraction yields no results for the current corpus.
            nodes_tbl = ensure_table(db, CONCEPT_NODES_TABLE, _concept_nodes_schema())
            edges_tbl = ensure_table(db, CONCEPT_EDGES_TABLE, _concept_edges_schema())
            cc_tbl = ensure_table(db, CHUNK_CONCEPTS_TABLE, _chunk_concepts_schema())
            if node_records:
                nodes_tbl.add(node_records)
            if edge_records:
                edges_tbl.add(edge_records)
            if chunk_concept_records:
                cc_tbl.add(chunk_concept_records)

    def boost_results(self, results: list[Any], query_concepts: list[str]) -> list[Any]:
        """Boost search results whose chunks overlap with query concepts.

        All results are returned in their original order; only overlapping
        ones are copied and re-scored (higher relevance_score, or lower
        distance clamped to the configured floor).
        """
        if not query_concepts or not results:
            return results
        query_set = set(query_concepts)
        boosted: list[Any] = []
        for r in results:
            chunk_concepts = set(self.get_chunk_concepts(r.source, r.chunk_index))
            overlap = len(query_set & chunk_concepts)
            if overlap > 0:
                # Boost proportional to the fraction of query concepts matched.
                boost = (overlap / len(query_set)) * self._config.concept_boost_weight
                # Copy before mutating so the caller's objects stay untouched.
                r = r.model_copy()
                if r.relevance_score is not None:
                    r.relevance_score = r.relevance_score + boost
                elif r.distance is not None:
                    # Smaller distance = better; never go below the floor.
                    r.distance = max(self._config.concept_boost_floor, r.distance - boost)
            boosted.append(r)
        return boosted

    def get_chunk_concepts(self, source: str, chunk_index: int) -> list[str]:
        """Get concepts associated with a specific chunk.

        Best-effort: returns [] when the table is missing or the query fails.
        """
        table = self._store.open_table(CHUNK_CONCEPTS_TABLE)
        if table is None:
            return []
        escaped = escape_sql_string(source)
        try:
            rows = (
                table.search()
                .where(f"chunk_source = '{escaped}' AND chunk_index = {chunk_index}")
                .to_list()
            )
            return [r["concept"] for r in rows]
        except Exception:
            return []

    def expand_query(self, query: str) -> list[str]:
        """Expand a query with related concepts from the graph.

        Returns only the *new* related concepts (the query's own extracted
        concepts are excluded), deduplicated in discovery order.
        """
        concepts = self.extract_concepts(query)
        if not concepts:
            return []
        related: list[str] = []
        seen = set(concepts)
        for concept in concepts:
            for neighbor in self.get_related_concepts(concept):
                if neighbor not in seen:
                    related.append(neighbor)
                    seen.add(neighbor)
        return related

    def get_related_concepts(self, concept: str, depth: int = 1) -> list[str]:
        """Find concepts related to *concept* via graph edges, up to *depth* hops.

        One batched query per depth level — O(depth) DB round-trips,
        independent of frontier size.
        """
        table = self._store.open_table(CONCEPT_EDGES_TABLE)
        if table is None:
            return []
        visited: set[str] = {concept}
        frontier: list[str] = [concept]
        for _ in range(depth):
            if not frontier:
                break
            escaped_list = ", ".join(f"'{escape_sql_string(n)}'" for n in frontier)
            try:
                rows = (
                    table.search()
                    .where(f"source IN ({escaped_list}) OR target IN ({escaped_list})")
                    .to_list()
                )
            except Exception:
                # Best-effort: stop expanding on query failure rather than raise.
                log.debug(
                    "concept expand batch failed at frontier size %d",
                    len(frontier),
                    exc_info=True,
                )
                break
            next_frontier: list[str] = []
            for row in rows:
                # Edges are undirected for traversal: follow both endpoints.
                for endpoint in (row["source"], row["target"]):
                    if endpoint not in visited:
                        visited.add(endpoint)
                        next_frontier.append(endpoint)
            frontier = next_frontier
        # Exclude the seed concept itself.
        return [c for c in visited if c != concept]

    def top_communities(self, k: int = 10) -> list[Community]:
        """Return the *k* largest concept communities.

        Uses ``pyarrow.compute.value_counts`` to pick the top-k
        cluster_ids in columnar memory, then materializes only those
        clusters' members. Peak Python memory scales with members of
        the top *k* clusters, not the total node count.
        """
        table = self._store.open_table(CONCEPT_NODES_TABLE)
        if table is None:
            return []
        arrow_tbl = table.to_arrow()
        if arrow_tbl.num_rows == 0:
            return []
        # value_counts yields {"values": cluster_id, "counts": n} entries.
        counts = pc.value_counts(arrow_tbl["cluster_id"]).to_pylist()
        top = sorted(counts, key=lambda entry: entry["counts"], reverse=True)[:k]
        top_ids = [entry["values"] for entry in top if entry["values"] is not None]
        if not top_ids:
            return []
        # Materialize only rows belonging to the selected clusters.
        member_rows = arrow_tbl.filter(
            pc.is_in(arrow_tbl["cluster_id"], value_set=pa.array(top_ids))
        ).to_pylist()
        by_cluster: dict[int, list[str]] = {}
        for row in member_rows:
            by_cluster.setdefault(row["cluster_id"], []).append(row["concept"])
        return [
            Community(
                cluster_id=cid,
                size=len(by_cluster.get(cid, [])),
                concepts=by_cluster.get(cid, []),
            )
            for cid in top_ids
            if by_cluster.get(cid)
        ]

    def rebuild_clusters(self) -> None:
        """Re-run Leiden clustering on the existing edge table.

        Replaces all rows in the concept-nodes table with fresh cluster
        assignments and degrees; no-op when there are no edges.
        """
        from lilbee.lock import write_lock
        from lilbee.store import ensure_table

        edges_table = self._store.open_table(CONCEPT_EDGES_TABLE)
        if edges_table is None:
            return
        edge_rows = edges_table.to_arrow().to_pylist()
        if not edge_rows:
            return

        partition, degree_map = _leiden_partition(edge_rows)

        node_records = [
            {
                "concept": node,
                "cluster_id": cluster_id,
                # Nodes Leiden kept but with no counted edge default to 0.
                "degree": degree_map.get(node, 0),
            }
            for node, cluster_id in partition.items()
        ]

        # Tautological predicate deletes every row (full table rewrite).
        self._store.clear_table(CONCEPT_NODES_TABLE, "concept IS NOT NULL")
        if node_records:
            with write_lock():
                db = self._store.get_db()
                nodes_table = ensure_table(db, CONCEPT_NODES_TABLE, _concept_nodes_schema())
                nodes_table.add(node_records)

    def get_cluster_sources(self, min_sources: int = 3) -> dict[int, set[str]]:
        """Return clusters that span at least *min_sources* distinct sources.

        Joins concept_nodes (concept -> cluster_id) with chunk_concepts
        (concept -> chunk_source) to find which document sources each
        cluster touches.
        """
        nodes_table = self._store.open_table(CONCEPT_NODES_TABLE)
        cc_table = self._store.open_table(CHUNK_CONCEPTS_TABLE)
        if nodes_table is None or cc_table is None:
            return {}

        node_rows = nodes_table.to_arrow().to_pylist()
        concept_to_cluster: dict[str, int] = {r["concept"]: r["cluster_id"] for r in node_rows}

        cc_rows = cc_table.to_arrow().to_pylist()
        cluster_sources: dict[int, set[str]] = {}
        for row in cc_rows:
            cid = concept_to_cluster.get(row["concept"])
            if cid is None:
                # Chunk concept not present in the node table; skip.
                continue
            cluster_sources.setdefault(cid, set()).add(row["chunk_source"])

        return {
            cid: sources for cid, sources in cluster_sources.items() if len(sources) >= min_sources
        }

    def get_cluster_label(self, cluster_id: int) -> str:
        """Return a human-readable label for *cluster_id* (highest-degree concept).

        Falls back to ``cluster-<id>`` when the table or rows are missing
        or the query fails.
        """
        table = self._store.open_table(CONCEPT_NODES_TABLE)
        if table is None:
            return f"cluster-{cluster_id}"
        try:
            rows = table.search().where(f"cluster_id = {int(cluster_id)}").to_list()
        except Exception:
            log.debug("get_cluster_label query failed", exc_info=True)
            return f"cluster-{cluster_id}"
        if not rows:
            return f"cluster-{cluster_id}"
        best = max(rows, key=lambda r: r["degree"])
        return str(best["concept"])

    def get_graph(self) -> bool:
        """Check whether a concept graph exists in the store.

        Always False when the feature is disabled in config.
        """
        if not self._config.concept_graph:
            return False
        return self._store.open_table(CONCEPT_NODES_TABLE) is not None

    def reset_nlp_cache(self) -> None:
        """Clear the spaCy model cache. For testing only."""
        self._nlp = None
        self._nlp_unavailable = False

477 

478 

class ConceptGraphClusterer:
    """SourceClusterer backed by the concept graph (requires ``[graph]`` extra).

    Wraps :class:`ConceptGraph` so the wiki synthesis layer can consume
    concept-based clusters through the generic ``SourceClusterer`` protocol
    without importing ``ConceptGraph`` directly. Leaves ``ConceptGraph``
    unchanged.
    """

    def __init__(self, config: Config, store: Store) -> None:
        self._graph = ConceptGraph(config, store)

    def available(self) -> bool:
        """Concept-graph clustering needs both dependencies and a built graph."""
        # Both operands are already bool; no extra bool() wrapper needed.
        return concepts_available() and self._graph.get_graph()

    def get_clusters(self, min_sources: int = 3) -> list[SourceCluster]:
        """Expose concept clusters as generic :class:`SourceCluster` values.

        ``get_cluster_sources`` already filters to clusters spanning at
        least *min_sources* distinct sources, so no re-filtering is
        needed here (the previous ``len(sources) >= min_sources`` guard
        was dead code).
        """
        cluster_sources = self._graph.get_cluster_sources(min_sources=min_sources)
        return [
            SourceCluster(
                cluster_id=f"concept-{cid}",
                label=self._graph.get_cluster_label(cid),
                sources=frozenset(sources),
            )
            for cid, sources in cluster_sources.items()
        ]