Coverage for src / lilbee / cli / tui / screens / catalog_utils.py: 100%

67 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""Catalog data types, row builders, and formatting helpers.""" 

2 

3from __future__ import annotations 

4 

5import re 

6from dataclasses import dataclass 

7 

8from lilbee.catalog import PARAM_COUNT_RE, CatalogModel, ModelFamily, ModelVariant, extract_quant 

9from lilbee.model_manager import RemoteModel 

10 

11 

@dataclass
class TableRow:
    """A row in the catalog grid or list view with source metadata.

    ``name`` is the human-readable display label (e.g. "Qwen3 0.6B").
    ``ref`` is the canonical identifier used for config persistence:
    ``hf_repo`` for catalog rows, ``hf_repo/filename`` for installed
    native models, and the provider's ref shape for remote/API rows.

    The display columns (``params``, ``size``, ``quant``, ``downloads``) hold
    preformatted strings; the row builders in this module put ``"--"`` in them
    when a value is unknown. ``sort_downloads`` and ``sort_size`` carry the
    numeric values that the "Downloads" and "Size" sort keys compare instead
    of the display strings.
    """

    name: str
    task: str
    params: str  # display label such as "8B", or "--"
    size: str  # formatted size such as "4.7 GB", or "--"
    quant: str  # quantization label, or "--"
    downloads: str  # formatted download count such as "1.2M", or "--"
    featured: bool
    installed: bool
    sort_downloads: int  # raw download count used for sorting (0 when unknown)
    sort_size: float  # size in GB used for sorting (0.0 when unknown)
    ref: str = ""
    backend: str = ""  # "native" for catalog/variant rows; lower-cased provider name for remote rows
    variant: ModelVariant | None = None  # populated by variant_to_row
    family: ModelFamily | None = None  # populated by variant_to_row
    catalog_model: CatalogModel | None = None  # populated by catalog_to_row
    remote_model: RemoteModel | None = None  # populated by remote_to_row

38 

39 

def parse_param_label(name: str) -> str:
    """Extract the parameter-count label from a model name (e.g. '8B', '0.6B').

    Searches ``name`` with the module-level ``PARAM_COUNT_RE`` and returns the
    first capture group upper-cased, or ``"--"`` when the pattern does not
    match anywhere in the name.
    """
    # PARAM_COUNT_RE is already imported at module level; no local re-import.
    match = PARAM_COUNT_RE.search(name)
    return match.group(1).upper() if match else "--"

46 

47 

48def _format_downloads(n: int) -> str: 

49 if n >= 1_000_000: 

50 return f"{n / 1_000_000:.1f}M" 

51 if n >= 1_000: 

52 return f"{n / 1_000:.0f}K" 

53 return str(n) 

54 

55 

56def _format_size_mb(size_mb: int) -> str: 

57 """Format size in MB to a human-readable string.""" 

58 if size_mb == 0: 

59 return "--" 

60 if size_mb >= 1024: 

61 return f"{size_mb / 1024:.1f} GB" 

62 return f"{size_mb} MB" 

63 

64 

def format_size_gb(size_gb: float) -> str:
    """Render a size expressed in gigabytes, e.g. ``2.5`` -> ``"2.5 GB"``.

    Zero and negative sizes mean "unknown" and are rendered as ``"--"``.
    """
    return "--" if size_gb <= 0 else f"{size_gb:.1f} GB"

70 

71 

def _is_param_count(label: str) -> bool:
    """Report whether the whole of ``label`` is a parameter count like '8B' or '0.6B'."""
    whole_match = PARAM_COUNT_RE.fullmatch(label)
    return whole_match is not None

75 

76 

77def variant_to_row(v: ModelVariant, f: ModelFamily, installed: bool) -> TableRow: 

78 """Convert a ModelVariant + family to a TableRow.""" 

79 # Avoid duplicating the param count when the family name already ends with it. 

80 if v.param_count and not f.name.endswith(v.param_count): 

81 label = f"{f.name} {v.param_count}" 

82 else: 

83 label = f.name 

84 params = v.param_count if _is_param_count(v.param_count) else "--" 

85 return TableRow( 

86 name=label, 

87 task=f.task, 

88 params=params, 

89 size=_format_size_mb(v.size_mb), 

90 quant=v.quant or "--", 

91 downloads="--", 

92 featured=True, 

93 installed=installed, 

94 sort_downloads=0, 

95 sort_size=v.size_mb / 1024, 

96 ref=v.hf_repo, 

97 backend="native", 

98 variant=v, 

99 family=f, 

100 ) 

101 

102 

def catalog_to_row(m: CatalogModel, installed: bool) -> TableRow:
    """Build the grid/list row for a catalog entry.

    The quant label comes from the GGUF filename and the params label from the
    display name. Any display column whose source value is missing (no quant
    found, or a download count that is not positive) falls back to ``"--"``.
    """
    quant_label = extract_quant(m.gguf_filename) or "--"
    downloads_label = "--"
    if m.downloads > 0:
        downloads_label = _format_downloads(m.downloads)
    return TableRow(
        name=m.display_name,
        task=m.task,
        params=parse_param_label(m.display_name),
        size=format_size_gb(m.size_gb),
        quant=quant_label,
        downloads=downloads_label,
        featured=m.featured,
        installed=installed,
        sort_downloads=m.downloads,
        sort_size=m.size_gb,
        ref=m.ref,
        backend="native",
        catalog_model=m,
    )

121 

122 

def remote_to_row(rm: RemoteModel) -> TableRow:
    """Build the row for a model exposed by a remote/API provider.

    Remote rows are always marked installed and never featured. Their size,
    quant and download columns are rendered as "--", and their ref is the
    model name. The backend is the provider name lower-cased.
    """
    placeholder = "--"
    return TableRow(
        name=rm.name,
        ref=rm.name,
        task=rm.task,
        backend=rm.provider.lower(),
        params=rm.parameter_size or placeholder,
        size=placeholder,
        quant=placeholder,
        downloads=placeholder,
        featured=False,
        installed=True,
        sort_downloads=0,
        sort_size=0.0,
        remote_model=rm,
    )

140 

141 

# Column header -> key function applied to a TableRow when sorting by that
# column. Name and Backend compare case-insensitively; Size and Downloads use
# the numeric sort_* fields rather than the formatted display strings.
SORT_KEYS = dict(
    Name=lambda row: row.name.lower(),
    Task=lambda row: row.task,
    Backend=lambda row: row.backend.lower(),
    Params=lambda row: _param_sort_value(row.params),
    Size=lambda row: row.sort_size,
    Quant=lambda row: row.quant,
    Downloads=lambda row: row.sort_downloads,
)

152 

153 

154def _param_sort_value(params: str) -> float: 

155 """Convert param label to sortable float (e.g. '8B' -> 8.0).""" 

156 match = re.search(r"(\d+\.?\d*)", params) 

157 return float(match.group(1)) if match else 0.0 

158 

159 

def matches_search(row: TableRow, search: str) -> bool:
    """Check whether ``search`` occurs in any searchable column of ``row``.

    The name, task, params, quant and backend columns are searched. Both sides
    go through ``_normalize_for_search``, so matching ignores case and treats
    hyphens and underscores like spaces. An empty query matches every row.
    """
    if search:
        needle = _normalize_for_search(search)
        haystacks = (row.name, row.task, row.params, row.quant, row.backend)
        for haystack in haystacks:
            if needle in _normalize_for_search(haystack):
                return True
        return False
    return True

169 

170 

171def _normalize_for_search(value: str) -> str: 

172 return value.lower().replace("-", " ").replace("_", " ")