Coverage for src / lilbee / reranker.py: 100%
68 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-29 19:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-29 19:16 +0000
"""Cross-encoder reranking for search results.

Optional precision pass that scores each (query, chunk) pair through the
active provider's ``rerank`` method. Only active when
``cfg.reranker_model`` is set.

Core technique: Nogueira & Cho 2019, "Passage Re-ranking with BERT"
(https://arxiv.org/abs/1901.04085).

Position-aware blending: derived from learning-to-rank literature
(Burges et al. 2005). Top positions trust hybrid fusion more, lower
positions trust the reranker more.
"""
15from __future__ import annotations
17import logging
18from typing import NamedTuple
20from lilbee.config import Config
21from lilbee.store import SearchChunk
23log = logging.getLogger(__name__)
class ScoredChunk(NamedTuple):
    """Pair a search chunk with the blended score used to order it."""

    # Blended (fusion + reranker) score for this chunk.
    score: float
    # The search result the score belongs to.
    chunk: SearchChunk
33_TOP_POSITION_CUTOFF = 3
34_MID_POSITION_CUTOFF = 10
36_BLEND_SCHEDULE = {
37 "top": (0.70, 0.30),
38 "mid": (0.50, 0.50),
39 "bottom": (0.30, 0.70),
40}
43def _normalize_scores(scores: list[float]) -> list[float]:
44 """Min-max normalize raw cross-encoder scores to [0, 1]."""
45 min_score = min(scores)
46 max_score = max(scores)
47 score_range = max_score - min_score
48 if score_range > 0:
49 return [(s - min_score) / score_range for s in scores]
50 return [0.5] * len(scores)
53def _blend_scores(to_rerank: list[SearchChunk], norm_scores: list[float]) -> list[ScoredChunk]:
54 """Blend fusion scores with reranker scores using position-aware weights."""
55 blended: list[ScoredChunk] = []
56 for i, (chunk, rerank_score) in enumerate(zip(to_rerank, norm_scores, strict=True)):
57 fusion_score = chunk.relevance_score or (1.0 - (chunk.distance or 0.5))
58 fusion_norm = max(0.0, min(1.0, fusion_score))
60 if i < _TOP_POSITION_CUTOFF:
61 fw, rw = _BLEND_SCHEDULE["top"]
62 elif i < _MID_POSITION_CUTOFF:
63 fw, rw = _BLEND_SCHEDULE["mid"]
64 else:
65 fw, rw = _BLEND_SCHEDULE["bottom"]
67 final_score = fw * fusion_norm + rw * rerank_score
68 blended.append(ScoredChunk(final_score, chunk))
69 return blended
72def _pin_original_top(
73 blended: list[ScoredChunk],
74 to_rerank: list[SearchChunk],
75 skip_threshold: float,
76) -> list[ScoredChunk]:
77 """Pin the original top result if its relevance exceeds the skip threshold."""
78 top_score = to_rerank[0].relevance_score or 0 if to_rerank else 0
79 blended_sorted = sorted(blended, key=lambda x: x.score, reverse=True)
80 if top_score >= skip_threshold:
81 original_top = to_rerank[0]
82 if blended_sorted[0].chunk is not original_top:
83 blended_sorted = [ScoredChunk(999.0, original_top)] + [
84 ScoredChunk(s, c) for s, c in blended_sorted if c is not original_top
85 ]
86 return blended_sorted
class Reranker:
    """Cross-encoder reranker with position-aware blending.

    Scoring is delegated to the active provider's ``rerank``; this class
    handles result blending and the BM25-protection pin (Nogueira & Cho
    2019, https://arxiv.org/abs/1901.04085).
    """

    def __init__(self, config: Config) -> None:
        """Keep a reference to the configuration consulted on each call."""
        self._config = config

    def rerank(
        self,
        query: str,
        results: list[SearchChunk],
        candidates: int | None = None,
    ) -> list[SearchChunk]:
        """Rerank search results through the provider's ``rerank`` method.

        Args:
            query: The search query the results were retrieved for.
            results: Search results in their current order.
            candidates: How many leading results to rerank; ``None`` uses
                ``rerank_candidates`` from the configuration.

        Returns:
            The reranked leading results followed by the untouched
            remainder. ``results`` itself is returned unchanged when no
            reranker model is configured, there is nothing to rerank, or
            scoring failed.
        """
        cfg = self._config
        if not cfg.reranker_model:
            return results

        limit = cfg.rerank_candidates if candidates is None else candidates
        head, tail = results[:limit], results[limit:]
        if not head:
            return results

        raw_scores = _score_candidates(query, head)
        if raw_scores is None:
            return results

        ranked = _pin_original_top(
            _blend_scores(head, _normalize_scores(raw_scores)),
            head,
            cfg.expansion_skip_threshold,
        )
        return [chunk for _score, chunk in ranked] + tail
def _score_candidates(query: str, to_rerank: list[SearchChunk]) -> list[float] | None:
    """Call the active provider's rerank; return None on error after logging.

    Args:
        query: The search query.
        to_rerank: Candidates whose ``chunk`` text is sent to the provider.

    Returns:
        The provider's scores, or ``None`` when the provider raised or
        returned a number of scores different from the number of
        candidates (both cases are logged as warnings).
    """
    # circular: services -> reranker via Searcher; deferred so test-time
    # monkeypatching of ``lilbee.services.get_services`` stays effective.
    from lilbee.services import get_services

    try:
        provider = get_services().provider
        scores = provider.rerank(query, [c.chunk for c in to_rerank])
        # A short or long score list cannot be paired with the candidates;
        # treat it as a reranker failure rather than crashing downstream.
        if scores is not None and len(scores) != len(to_rerank):
            raise ValueError(
                f"provider returned {len(scores)} scores for {len(to_rerank)} candidates"
            )
    except Exception as exc:
        log.warning("Reranker failed; skipping rerank pass: %s", exc, exc_info=True)
        return None
    return scores