Coverage for src / lilbee / reasoning.py: 100%

103 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""Reasoning token filter — detects <think>...</think> tags in streaming output. 

2 

3Reasoning models (Qwen3, DeepSeek-R1) wrap their thinking process in 

4``<think>...</think>`` tags. This module provides a stateful filter that 

5classifies tokens as reasoning or response content. 

6""" 

7 

8from __future__ import annotations 

9 

10import contextlib 

11import re 

12from collections.abc import Generator, Iterator 

13from dataclasses import dataclass 

14 

15from lilbee.providers.base import ClosableIterator 

16 

_OPEN_TAG = "<think>"
_CLOSE_TAG = "</think>"
# Two alternatives: a complete <think>...</think> block (plus any trailing
# whitespace), OR a trailing <think> that is never closed and runs to the end
# of the string — so truncated/unterminated reasoning is still stripped.
_THINK_BLOCK_RE = re.compile(r"<think>[\s\S]*?</think>\s*|<think>[\s\S]*$")

20 

21 

@dataclass
class StreamToken:
    """One classified chunk of streamed model output."""

    # Raw text of the chunk (may be reasoning or response content).
    content: str
    # True when the text originated inside a <think>...</think> block.
    is_reasoning: bool

28 

29 

class _TagParser:
    """Incremental parser that tracks <think> block state across token boundaries.

    Text is accumulated in ``buf``; each processing step either consumes a slice
    of the buffer (returning a classified :class:`StreamToken`) or reports that
    the buffer ends in a possible partial tag and more input is needed.
    """

    def __init__(self, *, show: bool) -> None:
        self.show = show          # surface reasoning content to the caller?
        self.buf = ""             # unconsumed stream text
        self.in_thinking = False  # currently inside a <think> block?
        self.reasoning_chars = 0  # running total of reasoning characters seen

    def feed(self, token: str) -> list[StreamToken]:
        """Feed a token and return any complete StreamTokens."""
        self.buf += token
        out: list[StreamToken] = []
        while self.buf:
            step = self._consume_reasoning() if self.in_thinking else self._consume_response()
            if step is None:
                # Buffer ends in what may be a partial tag — wait for more data.
                break
            if step.content:
                out.append(step)
        return out

    def flush(self) -> StreamToken | None:
        """Emit whatever remains in the buffer at end of stream."""
        if not self.buf:
            return None
        if not self.in_thinking:
            return StreamToken(content=self.buf, is_reasoning=False)
        # Unterminated <think> block: surface it only when showing reasoning.
        return StreamToken(content=self.buf, is_reasoning=True) if self.show else None

    def _consume_reasoning(self) -> StreamToken | None:
        """Advance through the buffer inside a <think> block; None means wait."""
        idx = self.buf.find(_CLOSE_TAG)
        if idx >= 0:
            chunk = self.buf[:idx]
            self.reasoning_chars += len(chunk)
            self.buf = self.buf[idx + len(_CLOSE_TAG) :]
            self.in_thinking = False
            if chunk and self.show:
                return StreamToken(content=chunk, is_reasoning=True)
            # Empty token still signals progress so feed() keeps looping.
            return StreamToken(content="", is_reasoning=True)
        if _could_be_partial(_CLOSE_TAG, self.buf):
            return None
        # No close tag and no partial suffix: the whole buffer is reasoning.
        chunk, self.buf = self.buf, ""
        self.reasoning_chars += len(chunk)
        return StreamToken(content=chunk if self.show else "", is_reasoning=True)

    def _consume_response(self) -> StreamToken | None:
        """Advance through the buffer outside thinking blocks; None means wait."""
        idx = self.buf.find(_OPEN_TAG)
        if idx >= 0:
            chunk = self.buf[:idx]
            self.buf = self.buf[idx + len(_OPEN_TAG) :]
            self.in_thinking = True
            return StreamToken(content=chunk, is_reasoning=False)
        if _could_be_partial(_OPEN_TAG, self.buf):
            return None
        # No open tag and no partial suffix: the whole buffer is response text.
        chunk, self.buf = self.buf, ""
        return StreamToken(content=chunk, is_reasoning=False)

96 

97 

# Hard ceiling on total reasoning characters before the stream is truncated.
_MAX_REASONING_CHARS = 16_000  # ~4K tokens — safety limit for runaway reasoning

# Drain at most this many chars of post-truncation response content
# before giving up; each token pull triggers another inference step.
_DRAIN_CAP = 512

103 

104 

def filter_reasoning(tokens: Iterator[str], *, show: bool) -> Iterator[StreamToken]:
    """Filter ``<think>...</think>`` tags from a token stream; cap runaway reasoning.

    When *show* is True, thinking content is yielded as
    ``StreamToken(is_reasoning=True)``; when False it is stripped. The upstream
    iterator is always closed on exit so a runaway think loop does not leave
    llama.cpp holding the chat lock.
    """
    parser = _TagParser(show=show)
    try:
        # _stream_until_cap's return value (via `yield from`) reports truncation.
        hit_cap = yield from _stream_until_cap(parser, tokens)
        if hit_cap:
            yield from _drain_after_truncation(parser, tokens)
        tail = parser.flush()
        if tail is not None and tail.content:
            yield tail
    finally:
        _close_iterator(tokens)

122 

123 

def _stream_until_cap(
    parser: _TagParser, tokens: Iterator[str]
) -> Generator[StreamToken, None, bool]:
    """Yield from *tokens* until the reasoning cap fires. Returns True if truncated."""
    for piece in tokens:
        for tok in parser.feed(piece):
            if tok.content:
                yield tok
        if parser.reasoning_chars <= _MAX_REASONING_CHARS:
            continue
        # Cap exceeded: abandon any buffered reasoning and signal truncation.
        parser.in_thinking = False
        parser.buf = ""
        yield StreamToken(content="\n[reasoning truncated]", is_reasoning=True)
        return True
    return False

138 

139 

def _drain_after_truncation(parser: _TagParser, tokens: Iterator[str]) -> Iterator[StreamToken]:
    """After truncation, yield up to ``_DRAIN_CAP`` chars of response content."""
    consumed = 0
    for piece in tokens:
        # Count the token before the cap check; an over-cap token is not fed.
        consumed += len(piece)
        if consumed > _DRAIN_CAP:
            break
        for tok in parser.feed(piece):
            # Only response content survives the drain; reasoning is dropped.
            if tok.content and not tok.is_reasoning:
                yield tok

150 

151 

def _close_iterator(tokens: Iterator[str]) -> None:
    """Best-effort close of *tokens* when it satisfies ClosableIterator."""
    if not isinstance(tokens, ClosableIterator):
        return
    # Deliberately swallow close() errors — cleanup must never mask stream results.
    with contextlib.suppress(Exception):
        tokens.close()

157 

158 

def strip_reasoning(text: str) -> str:
    """Remove ``<think>...</think>`` blocks from a complete (non-streaming) string.

    Also drops a trailing ``<think>`` block that was never closed.
    """
    cleaned = _THINK_BLOCK_RE.sub("", text)
    return cleaned

162 

163 

164def _could_be_partial(tag: str, buf: str) -> bool: 

165 """Check if the end of buf could be the start of the given tag.""" 

166 return any(buf.endswith(tag[:length]) for length in range(1, len(tag)))