Coverage for src/lilbee/reasoning.py: 100% (103 statements)
1"""Reasoning token filter — detects <think>...</think> tags in streaming output.
3Reasoning models (Qwen3, DeepSeek-R1) wrap their thinking process in
4``<think>...</think>`` tags. This module provides a stateful filter that
5classifies tokens as reasoning or response content.
6"""

from __future__ import annotations

import contextlib
import re
from collections.abc import Generator, Iterator
from dataclasses import dataclass

from lilbee.providers.base import ClosableIterator

_OPEN_TAG = "<think>"
_CLOSE_TAG = "</think>"
_THINK_BLOCK_RE = re.compile(r"<think>[\s\S]*?</think>\s*|<think>[\s\S]*$")
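# A note on the two alternatives above: the first matches a closed block plus
# any trailing whitespace; the second matches an unterminated block running to
# the end of the string, so output cut off mid-thought is still stripped.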


@dataclass
class StreamToken:
    """A classified token from the stream."""

    content: str
    is_reasoning: bool


class _TagParser:
    """Stateful parser that tracks whether we're inside a thinking block."""

    def __init__(self, *, show: bool) -> None:
        self.show = show
        self.buf = ""
        self.in_thinking = False
        self.reasoning_chars = 0

    def feed(self, token: str) -> list[StreamToken]:
        """Feed a token and return any complete StreamTokens."""
        self.buf += token
        result: list[StreamToken] = []
        while self.buf:
            emitted = self._process_thinking() if self.in_thinking else self._process_normal()
            if emitted is None:
                break  # waiting for more data (partial tag)
            if emitted.content:
                result.append(emitted)
        return result
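
    # Illustrative sketch (doctest-style, not executed by this module): a tag
    # split across tokens is buffered until it can be classified.
    #
    #     >>> p = _TagParser(show=False)
    #     >>> p.feed("abc<thi")   # "<thi" may start "<think>", so hold everything
    #     []
    #     >>> p.feed("nk>plan")   # tag completes; "abc" was response text
    #     [StreamToken(content='abc', is_reasoning=False)]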

    def flush(self) -> StreamToken | None:
        """Flush remaining buffer at end of stream."""
        if not self.buf:
            return None
        if self.in_thinking:
            return StreamToken(content=self.buf, is_reasoning=True) if self.show else None
        return StreamToken(content=self.buf, is_reasoning=False)
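
    # Illustrative sketch (not executed): an unterminated "<thi" left in the
    # buffer at end of stream is flushed as plain response text, since the
    # tag never completed.
    #
    #     >>> p = _TagParser(show=False)
    #     >>> p.feed("bye <thi")
    #     []
    #     >>> p.flush()
    #     StreamToken(content='bye <thi', is_reasoning=False)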

    def _process_thinking(self) -> StreamToken | None:
        """Process buffer while inside a <think> block. Returns None if waiting."""
        close_idx = self.buf.find(_CLOSE_TAG)
        if close_idx == -1:
            if _could_be_partial(_CLOSE_TAG, self.buf):
                return None  # wait for more
            content = self.buf
            self.reasoning_chars += len(content)
            self.buf = ""
            return (
                StreamToken(content=content, is_reasoning=True)
                if self.show
                else StreamToken(content="", is_reasoning=True)
            )

        thinking_content = self.buf[:close_idx]
        self.reasoning_chars += len(thinking_content)
        self.buf = self.buf[close_idx + len(_CLOSE_TAG) :]
        self.in_thinking = False
        if thinking_content and self.show:
            return StreamToken(content=thinking_content, is_reasoning=True)
        return StreamToken(content="", is_reasoning=True)

    def _process_normal(self) -> StreamToken | None:
        """Process buffer while outside thinking blocks. Returns None if waiting."""
        open_idx = self.buf.find(_OPEN_TAG)
        if open_idx == -1:
            if _could_be_partial(_OPEN_TAG, self.buf):
                return None  # wait for more
            content = self.buf
            self.buf = ""
            return StreamToken(content=content, is_reasoning=False)

        before = self.buf[:open_idx]
        self.buf = self.buf[open_idx + len(_OPEN_TAG) :]
        self.in_thinking = True
        return StreamToken(content=before, is_reasoning=False)


_MAX_REASONING_CHARS = 16_000  # ~4K tokens — safety limit for runaway reasoning

# Drain at most this many chars of post-truncation response content
# before giving up; each token pull triggers another inference step.
_DRAIN_CAP = 512


def filter_reasoning(tokens: Iterator[str], *, show: bool) -> Iterator[StreamToken]:
    """Filter ``<think>...</think>`` tags from a token stream; cap runaway reasoning.

    When *show* is True, yields thinking content as ``StreamToken(is_reasoning=True)``;
    when False, strips it. Always closes the upstream iterator on exit so a
    runaway think loop doesn't leave llama.cpp holding the chat lock.
    """
    parser = _TagParser(show=show)
    try:
        truncated = yield from _stream_until_cap(parser, tokens)
        if truncated:
            yield from _drain_after_truncation(parser, tokens)
        final = parser.flush()
        if final and final.content:
            yield final
    finally:
        _close_iterator(tokens)
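
# Illustrative usage (hypothetical token stream, not part of this module):
#
#     >>> tokens = iter(["<think>", "plan", "</think>", "Hello!"])
#     >>> [t.content for t in filter_reasoning(tokens, show=False)]
#     ['Hello!']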


def _stream_until_cap(
    parser: _TagParser, tokens: Iterator[str]
) -> Generator[StreamToken, None, bool]:
    """Yield from *tokens* until the reasoning cap fires. Returns True if truncated."""
    for token in tokens:
        for st in parser.feed(token):
            if st.content:
                yield st
        if parser.reasoning_chars > _MAX_REASONING_CHARS:
            parser.in_thinking = False
            parser.buf = ""
            yield StreamToken(content="\n[reasoning truncated]", is_reasoning=True)
            return True
    return False


def _drain_after_truncation(parser: _TagParser, tokens: Iterator[str]) -> Iterator[StreamToken]:
    """After truncation, yield up to ``_DRAIN_CAP`` chars of response content."""
    drain_chars = 0
    for token in tokens:
        drain_chars += len(token)
        if drain_chars > _DRAIN_CAP:
            return
        for st in parser.feed(token):
            if st.content and not st.is_reasoning:
                yield st


def _close_iterator(tokens: Iterator[str]) -> None:
    """Close *tokens* if it satisfies the ClosableIterator protocol."""
    if isinstance(tokens, ClosableIterator):
        with contextlib.suppress(Exception):
            tokens.close()


def strip_reasoning(text: str) -> str:
    """Remove ``<think>...</think>`` blocks from a complete (non-streaming) string."""
    return _THINK_BLOCK_RE.sub("", text)
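
# Illustrative behavior (doctest-style, not executed by this module):
#
#     >>> strip_reasoning("<think>hm</think>Final answer.")
#     'Final answer.'
#     >>> strip_reasoning("Partial <think>never closed")
#     'Partial '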


def _could_be_partial(tag: str, buf: str) -> bool:
    """Check if the end of buf could be the start of the given tag."""
    return any(buf.endswith(tag[:length]) for length in range(1, len(tag)))
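
# Illustrative sketch (not executed): only proper prefixes of the tag count,
# so a buffer that already ends with the full tag is never "partial" (the
# caller's find() would have located it).
#
#     >>> _could_be_partial("</think>", "step one </thi")
#     True
#     >>> _could_be_partial("</think>", "step one.")
#     False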