Coverage for src/lilbee/model_defaults.py: 100%

63 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-04-29 19:16 +0000

1"""Per-model default generation settings. 

2 

3Parses and caches generation parameters from key-value parameter text 

4or GGUF file metadata so that model-specific defaults (temperature, num_ctx, 

5etc.) are applied automatically when switching models. 

6""" 

7 

8from __future__ import annotations 

9 

10import contextlib 

11import logging 

12from dataclasses import dataclass, fields 

13from typing import Any 

14 

log = logging.getLogger(__name__)

# Parameter names accepted from ``key value`` text, each mapped to the type
# its raw string value is converted with.
_KNOWN_PARAM_TYPES: dict[str, type] = {
    "temperature": float,
    "top_p": float,
    "top_k": int,
    "repeat_penalty": float,
    "num_ctx": int,
    "max_tokens": int,
}

# GGUF metadata keys of the form ``general.<field>`` mapped to the
# ModelDefaults field each one populates (insertion order is preserved).
_GGUF_KEY_MAP: dict[str, str] = {
    f"general.{field_name}": field_name
    for field_name in ("temperature", "top_p", "top_k", "repeat_penalty")
}

34 

35 

@dataclass(frozen=True)
class ModelDefaults:
    """Frozen snapshot of a model's default generation parameters.

    Every field defaults to ``None``. The parsers in this module pass only
    the values they successfully converted to the constructor, so a ``None``
    field means the source text or metadata supplied no usable value for it.
    """

    # NOTE: keep these annotations spelled literally as ``int | None`` /
    # ``float | None``; ``_field_type`` may derive a field's conversion type
    # from the annotation text.
    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    repeat_penalty: float | None = None
    # Set from a ``num_ctx`` parameter line or a GGUF ``context_length`` value.
    num_ctx: int | None = None
    max_tokens: int | None = None

46 

47 

class _DefaultsCache:
    """In-memory store of ModelDefaults keyed by model name.

    Holding the mapping on an instance avoids a module-level mutable global.
    """

    def __init__(self) -> None:
        # model name -> cached defaults for that model
        self._data: dict[str, ModelDefaults] = {}

    def get(self, model_name: str) -> ModelDefaults | None:
        """Return the cached defaults for *model_name*, or ``None`` if absent."""
        cached = self._data.get(model_name)
        return cached

    def set(self, model_name: str, defaults: ModelDefaults) -> None:
        """Cache *defaults* under *model_name*, replacing any earlier entry."""
        self._data[model_name] = defaults

    def clear(self) -> None:
        """Drop every cached entry."""
        self._data.clear()

62 

63 

# One shared cache instance. The module-level public names are bound methods
# of it, so existing ``get_defaults`` / ``set_defaults`` / ``clear_cache``
# call sites keep working and all operate on the same state.
_defaults_cache = _DefaultsCache()
get_defaults, set_defaults, clear_cache = (
    _defaults_cache.get,
    _defaults_cache.set,
    _defaults_cache.clear,
)

70 

71 

def parse_kv_parameters(text: str) -> ModelDefaults:
    """Parse multiline ``key value`` parameter text into a ModelDefaults.

    Each line is stripped and split on its first run of whitespace into a
    key and a raw value. Lines are skipped when they are blank or have no
    value, when the key is not in ``_KNOWN_PARAM_TYPES`` (such as ``stop``),
    or when the value fails type conversion (logged at debug level). If a
    key repeats, the last successfully converted value wins.

    Example input::

        temperature 0.7
        top_p 0.9
        num_ctx 4096
        stop <|im_end|>
    """
    parsed: dict[str, Any] = {}
    for raw_line in text.splitlines():
        # A blank line yields [] and a bare key yields [key]; both are skipped.
        tokens = raw_line.strip().split(None, 1)
        if len(tokens) < 2:
            continue
        key, raw_value = tokens
        convert = _KNOWN_PARAM_TYPES.get(key)
        if convert is None:
            continue
        try:
            parsed[key] = convert(raw_value)
        except (ValueError, TypeError):
            log.debug("Skipping unparseable param %s=%r", key, raw_value)
    return ModelDefaults(**parsed)

99 

100 

def read_gguf_defaults(metadata: dict[str, str]) -> ModelDefaults:
    """Build a ModelDefaults from a GGUF metadata dict.

    Sampling keys listed in ``_GGUF_KEY_MAP`` (e.g. ``general.temperature``)
    are converted to the matching field's base type; values that fail
    conversion are logged at debug level and skipped. ``num_ctx`` comes from
    a plain ``context_length`` key. The caller is expected to have already
    resolved the architecture-prefixed GGUF key into that name. An
    unparseable ``context_length`` is ignored without logging.
    """
    found: dict[str, Any] = {}
    for gguf_key, field_name in _GGUF_KEY_MAP.items():
        if gguf_key not in metadata:
            continue
        raw = metadata[gguf_key]
        try:
            found[field_name] = _field_type(field_name)(raw)
        except (ValueError, TypeError):
            log.debug("Skipping unparseable GGUF key %s=%r", gguf_key, raw)

    if "context_length" in metadata:
        with contextlib.suppress(ValueError, TypeError):
            found["num_ctx"] = int(metadata["context_length"])

    return ModelDefaults(**found)

119 

120 

def _field_type(field_name: str) -> type:
    """Return the base type (``int`` or ``float``) used to convert a ModelDefaults field.

    ``_KNOWN_PARAM_TYPES`` already records the conversion type of every
    ModelDefaults field, so it is treated as the authoritative source. This
    guarantees GGUF parsing and ``key value`` parsing never disagree on a
    field's type. A field missing from that table falls back to a heuristic
    on the dataclass annotation text, and an unknown name defaults to
    ``float``.
    """
    known = _KNOWN_PARAM_TYPES.get(field_name)
    if known is not None:
        return known
    # Fallback for a field added to ModelDefaults but not to
    # _KNOWN_PARAM_TYPES. Under ``from __future__ import annotations``
    # ``f.type`` is a string such as "int | None"; the substring test below is
    # only a heuristic, so new fields should be registered in the table.
    for f in fields(ModelDefaults):  # pragma: no cover - every current field is in the table
        if f.name == field_name:
            return int if "int" in str(f.type) else float
    return float  # pragma: no cover