Coverage for src / lilbee / code_chunker.py: 100%

65 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-16 08:27 +0000

1"""Tree-sitter based code chunking — splits source files on function/class boundaries.""" 

2 

3import logging 

4from dataclasses import dataclass 

5from pathlib import Path 

6 

7import tree_sitter 

8 

9from lilbee.chunker import chunk_text 

10from lilbee.languages import DEFINITION_TYPES, EXT_TO_LANG 

11 

12log = logging.getLogger(__name__) 

13 

14# Container nodes whose children may also be definitions 

15_CONTAINERS = frozenset({"class_body", "block", "declaration_list", "impl_body"}) 

16 

17 

@dataclass
class CodeChunk:
    """A chunk of source code with line location metadata."""

    # Chunk text (chunk_code prefixes a "# File: ..." header; fallback does not).
    chunk: str
    # 1-based first line of the chunk within the original file.
    line_start: int
    # 1-based last line of the chunk within the original file (inclusive).
    line_end: int
    # 0-based position of this chunk in the file's chunk sequence.
    chunk_index: int

26 

27 

def get_parser(lang_name: str) -> tree_sitter.Parser | None:
    """Return a tree-sitter parser for *lang_name*, or None when unavailable.

    Import failures and unknown languages are treated the same way: logged
    at debug level and reported as None so callers can fall back.
    """
    try:
        # Alias the import so it does not shadow this function's own name.
        from tree_sitter_language_pack import get_parser as _load_parser

        return _load_parser(lang_name)  # type: ignore[arg-type]
    except Exception:
        log.debug("Failed to load tree-sitter language: %s", lang_name)
        return None

37 

38 

def _node_span(node: tree_sitter.Node, source: bytes) -> dict:
    """Return a dict with the node's text and its 1-based line range."""
    raw = source[node.start_byte : node.end_byte]
    # tree-sitter rows are 0-based; the rest of this module uses 1-based lines.
    return {
        "text": raw.decode("utf-8", errors="replace"),
        "line_start": node.start_point.row + 1,
        "line_end": node.end_point.row + 1,
    }

46 

47 

def collect_definitions(
    root: tree_sitter.Node,
    source: bytes,
    def_types: frozenset[str],
) -> list[dict]:
    """Gather definition spans from root's direct children, descending one
    extra level into known container nodes (class bodies, impl blocks, ...)."""
    spans: list[dict] = []
    for node in root.children:
        if node.type in def_types:
            spans.append(_node_span(node, source))
        elif node.type in _CONTAINERS:
            # One level deep only: pick out definitions nested in containers.
            for inner in node.children:
                if inner.type in def_types:
                    spans.append(_node_span(inner, source))
    return spans

61 

62 

def find_line(needle: str, lines: list[str], start: int) -> int:
    """Return the 1-based number of the first line at or after *start*
    (0-based index) containing *needle*.

    Falls back to ``start + 1`` when *needle* is empty or never found,
    so callers always get a usable line number.
    """
    if needle:
        for idx in range(start, len(lines)):
            if needle in lines[idx]:
                return idx + 1
    return start + 1

69 

70 

def _fallback_chunks(text: str) -> list[CodeChunk]:
    """Chunk *text* by token count, estimating each chunk's line range."""
    lines = text.split("\n")
    chunks: list[CodeChunk] = []
    cursor = 0  # 0-based line index to resume searching from

    for idx, piece in enumerate(chunk_text(text)):
        # Anchor on the chunk's first line (capped at 80 chars) to locate it.
        anchor = piece.split("\n")[0][:80]
        begin = find_line(anchor, lines, cursor)
        # Newline count gives the span; clamp to the file's last line.
        end = min(begin + piece.count("\n"), len(lines))
        chunks.append(
            CodeChunk(
                chunk=piece,
                line_start=begin,
                line_end=end,
                chunk_index=idx,
            )
        )
        cursor = begin

    return chunks

93 

94 

def chunk_code(file_path: Path) -> list[CodeChunk]:
    """Chunk a source file on definition boundaries via tree-sitter.

    Falls back to token-based chunking when the extension is unknown, the
    parser or definition types are unavailable, or no definitions are found.
    """
    raw = file_path.read_bytes()
    text = raw.decode("utf-8", errors="replace")

    lang_name = EXT_TO_LANG.get(file_path.suffix.lower())
    if not lang_name:
        return _fallback_chunks(text)

    parser = get_parser(lang_name)
    def_types = DEFINITION_TYPES.get(lang_name, frozenset())
    if not parser or not def_types:
        return _fallback_chunks(text)

    definitions = collect_definitions(parser.parse(raw).root_node, raw, def_types)
    if not definitions:
        return _fallback_chunks(text)

    # Each definition becomes its own chunk, prefixed with the file path
    # for downstream context; line numbers still refer to the original file.
    header = f"# File: {file_path}\n\n"
    chunks: list[CodeChunk] = []
    for idx, span in enumerate(definitions):
        chunks.append(
            CodeChunk(
                chunk=header + span["text"],
                line_start=span["line_start"],
                line_end=span["line_end"],
                chunk_index=idx,
            )
        )
    return chunks

123 

124 

def supported_extensions() -> set[str]:
    """Return the set of file extensions handled by tree-sitter chunking."""
    return {ext for ext in EXT_TO_LANG}