-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_manager.py
More file actions
99 lines (76 loc) · 3.12 KB
/
model_manager.py
File metadata and controls
99 lines (76 loc) · 3.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""Model manager for loading and switching between LLM models."""
import yaml
from pathlib import Path
from typing import Any
import config
class ModelManager:
"""Manages model loading and switching."""
def __init__(self, models_file: Path | None = None):
"""Load models from YAML file.
Args:
models_file: Path to models YAML. Defaults to config.MODELS_FILE.
"""
self.models_file = models_file or config.MODELS_FILE
self.override_file = config.MODEL_OVERRIDE_FILE
self._load_models()
def _load_models(self) -> None:
"""Load models from YAML file."""
with open(self.models_file) as f:
data = yaml.safe_load(f)
self.default_model = data.get("default_model", "deepseek/deepseek-v3.2")
self.models = {m["id"]: m for m in data.get("models", [])}
def get_current_model(self) -> str:
"""Get the currently selected model ID.
Returns model override if set, otherwise default model.
"""
if self.override_file.exists():
model_id = self.override_file.read_text().strip()
if model_id and model_id in self.models:
return model_id
return self.default_model
def set_model(self, model_id: str) -> dict[str, Any]:
"""Set the current model override.
Args:
model_id: Model ID to switch to.
Returns:
The model configuration dict.
Raises:
ValueError: If model_id is not found.
"""
if model_id not in self.models:
raise ValueError(f"Unknown model: {model_id}")
self.override_file.write_text(model_id)
return self.models[model_id]
def clear_override(self) -> None:
"""Clear model override, returning to default."""
if self.override_file.exists():
self.override_file.unlink()
def get_model_info(self, model_id: str) -> dict[str, Any] | None:
"""Get model info by ID."""
return self.models.get(model_id)
def list_models(self) -> list[dict[str, Any]]:
"""List all available models.
Returns:
List of model configuration dicts.
"""
return list(self.models.values())
def list_models_formatted(self) -> str:
"""List models in a formatted string for display."""
current = self.get_current_model()
lines = []
# Group by provider
by_provider: dict[str, list] = {}
for model in self.models.values():
provider = model.get("provider", "unknown")
if provider not in by_provider:
by_provider[provider] = []
by_provider[provider].append(model)
for provider, models in by_provider.items():
lines.append(f"\n{provider.upper()}:")
for m in models:
marker = " *" if m["id"] == current else ""
features = ", ".join(m.get("features", []))
lines.append(f" {m['id']}{marker}")
lines.append(f" {m['name']} [{features}]")
lines.append(f"\n* = current model")
return "\n".join(lines)