Coverage for src/lilbee/query.py: 100%

102 statements  

coverage.py v7.13.4, created at 2026-03-16 08:27 +0000

"""RAG query pipeline — embed question, search, generate answer with citations."""

from __future__ import annotations

from collections.abc import Generator
from typing import Any

import ollama
from pydantic import BaseModel
from typing_extensions import TypedDict

from lilbee import embedder, store
from lilbee.config import cfg
from lilbee.store import SearchChunk


class ChatMessage(TypedDict):
    """A single chat message with role and content."""

    role: str
    content: str


_CONTEXT_TEMPLATE = """Context:
{context}

Question: {question}"""
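# Illustrative only: with one retrieved chunk, the template above renders to a
# prompt shaped like the following (chunk text and question are invented for
# this sketch):
#
#     Context:
#     [1] lilbee ingests PDFs, code, and Markdown into a vector store.
#
#     Question: What file types can lilbee ingest?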

def format_source(result: SearchChunk) -> str:
    """Format a search result as a source citation line."""
    source = result["source"]
    content_type = result["content_type"]

    if content_type == "pdf":
        ps, pe = result["page_start"], result["page_end"]
        pages = f"page {ps}" if ps == pe else f"pages {ps}-{pe}"
        return f"{source}, {pages}"

    if content_type == "code":
        ls, le = result["line_start"], result["line_end"]
        lines = f"line {ls}" if ls == le else f"lines {ls}-{le}"
        return f"{source}, {lines}"

    return f"{source}"
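# Illustrative sketch (field values invented): format_source turns a chunk's
# provenance fields into a human-readable citation. Assuming a SearchChunk is
# dict-like with these keys, a doctest-style run would look like:
#
#     >>> format_source({"source": "guide.pdf", "content_type": "pdf",
#     ...                "page_start": 3, "page_end": 5})
#     'guide.pdf, pages 3-5'
#     >>> format_source({"source": "app.py", "content_type": "code",
#     ...                "line_start": 10, "line_end": 10})
#     'app.py, line 10'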

def deduplicate_sources(results: list[SearchChunk], max_citations: int = 5) -> list[str]:
    """Merge results from same source into deduplicated citation lines."""
    seen: set[str] = set()
    citations: list[str] = []
    for r in results:
        line = format_source(r)
        if line not in seen:
            seen.add(line)
            citations.append(line)
            if len(citations) >= max_citations:
                break
    return citations
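# Illustrative sketch (inputs invented): repeated hits from the same pages
# collapse to a single citation line, capped at max_citations.
#
#     >>> hit = {"source": "guide.pdf", "content_type": "pdf",
#     ...        "page_start": 3, "page_end": 3}
#     >>> deduplicate_sources([hit, dict(hit), dict(hit)])
#     ['guide.pdf, page 3']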

def sort_by_relevance(results: list[SearchChunk]) -> list[SearchChunk]:
    """Sort search results by distance (lower = more relevant)."""
    return sorted(results, key=lambda r: r.get("_distance", float("inf")))
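# Note: results missing a "_distance" key sort last (the distance defaults to
# infinity), so an unscored chunk never outranks a scored one. Sketch with
# invented distances:
#
#     >>> [r["_distance"] for r in sort_by_relevance(
#     ...     [{"_distance": 0.9}, {"_distance": 0.2}])]
#     [0.2, 0.9]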

def build_context(results: list[SearchChunk]) -> str:
    """Build context block from search results."""
    parts: list[str] = []
    for i, r in enumerate(results, 1):
        parts.append(f"[{i}] {r['chunk']}")
    return "\n\n".join(parts)
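# Illustrative sketch (chunk text invented): chunks are numbered so the model
# can refer back to them, and separated by blank lines.
#
#     >>> build_context([{"chunk": "First chunk."}, {"chunk": "Second chunk."}])
#     '[1] First chunk.\n\n[2] Second chunk.'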

def search_context(question: str, top_k: int = 0) -> list[SearchChunk]:
    """Embed question and return top-K matching chunks."""
    if top_k == 0:
        top_k = cfg.top_k
    query_vec = embedder.embed(question)
    return store.search(query_vec, top_k=top_k)
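# Usage sketch (requires an ingested store and a running embedder; the
# question is invented): passing top_k=0 falls back to cfg.top_k, and the
# store is assumed to return at most top_k chunks.
#
#     >>> chunks = search_context("How do I configure lilbee?", top_k=3)
#     >>> len(chunks) <= 3
#     True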

class AskResult(BaseModel):
    """Structured result from ask_raw — answer text + raw search results."""

    answer: str
    sources: list[SearchChunk]

def ask_raw(
    question: str,
    top_k: int = 0,
    history: list[ChatMessage] | None = None,
    options: dict[str, Any] | None = None,
) -> AskResult:
    """One-shot question returning structured answer + raw sources."""
    results = search_context(question, top_k=top_k)
    if not results:
        return AskResult(
            answer="No relevant documents found. Try ingesting some documents first.",
            sources=[],
        )

    results = sort_by_relevance(results)
    context = build_context(results)
    prompt = _CONTEXT_TEMPLATE.format(context=context, question=question)

    messages: list[ChatMessage] = [{"role": "system", "content": cfg.system_prompt}]
    if history:
        messages.extend(history)
    messages.append({"role": "user", "content": prompt})

    opts = options if options is not None else cfg.generation_options()
    try:
        response = ollama.chat(model=cfg.chat_model, messages=messages, options=opts or None)
    except ollama.ResponseError as exc:
        raise RuntimeError(
            f"Model '{cfg.chat_model}' not found in Ollama. Run: ollama pull {cfg.chat_model}"
        ) from exc
    return AskResult(answer=response.message.content or "", sources=results)
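# Usage sketch (requires ingested documents and a running Ollama; the
# question is invented): callers who want structured access to the retrieved
# chunks use ask_raw instead of ask.
#
#     >>> result = ask_raw("What does the ingest command do?")
#     >>> result.answer                             # generated text
#     >>> [s["source"] for s in result.sources]     # provenance of each chunk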

def ask(
    question: str,
    top_k: int = 0,
    history: list[ChatMessage] | None = None,
    options: dict[str, Any] | None = None,
) -> str:
    """One-shot question: returns full answer with source citations."""
    result = ask_raw(question, top_k=top_k, history=history, options=options)
    if not result.sources:
        return result.answer
    citations = deduplicate_sources(result.sources)
    return f"{result.answer}\n\nSources:\n" + "\n".join(citations)
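# Usage sketch (question and answer text invented): the return value is the
# answer followed by a "Sources:" block, one citation per line.
#
#     >>> print(ask("How are PDFs chunked?"))
#     <generated answer>
#
#     Sources:
#     guide.pdf, pages 3-5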

def ask_stream(
    question: str,
    top_k: int = 0,
    history: list[ChatMessage] | None = None,
    options: dict[str, Any] | None = None,
) -> Generator[str, None, None]:
    """Streaming question: yields answer tokens, then source citations."""
    results = search_context(question, top_k=top_k)
    if not results:
        yield "No relevant documents found. Try ingesting some documents first."
        return

    results = sort_by_relevance(results)
    context = build_context(results)
    prompt = _CONTEXT_TEMPLATE.format(context=context, question=question)

    messages: list[ChatMessage] = [{"role": "system", "content": cfg.system_prompt}]
    if history:
        messages.extend(history)
    messages.append({"role": "user", "content": prompt})

    opts = options if options is not None else cfg.generation_options()
    try:
        stream = ollama.chat(
            model=cfg.chat_model,
            messages=messages,
            stream=True,
            options=opts or None,
        )
    except ollama.ResponseError as exc:
        raise RuntimeError(
            f"Model '{cfg.chat_model}' not found in Ollama. Run: ollama pull {cfg.chat_model}"
        ) from exc

    try:
        for chunk in stream:
            token = chunk.message.content
            if token:
                yield token
    except ollama.ResponseError as exc:
        raise RuntimeError(
            f"Model '{cfg.chat_model}' not found in Ollama. Run: ollama pull {cfg.chat_model}"
        ) from exc
    except (ConnectionError, OSError) as exc:
        yield f"\n\n[Connection lost: {exc}]"

    citations = deduplicate_sources(results)
    yield "\n\nSources:\n" + "\n".join(citations)
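
# Usage sketch (question invented): stream tokens to the terminal as they
# arrive; the citation block is yielded last.
#
#     >>> for token in ask_stream("What does lilbee index?"):
#     ...     print(token, end="", flush=True)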