-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathcli.py
More file actions
310 lines (263 loc) · 12.8 KB
/
cli.py
File metadata and controls
310 lines (263 loc) · 12.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
# Load variables from a .env file before the project imports below —
# presumably so modules that read environment variables at import time see them.
from dotenv import load_dotenv
load_dotenv()
import sys
import argparse
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from rich.table import Table
from rich.markup import escape
from rich.progress import Progress, BarColumn, TextColumn
from rich import print as rprint
from pathlib import Path
from setup.onboarding import run_onboarding, show_banner, VERSION
from orchestrator import LLMRouter
# Module-level console used by show_availability() and main();
# handle_query_streaming() receives its console as a parameter instead.
console = Console()
# NOTE(review): Text and Panel are re-imported here (already imported above);
# the duplicates are harmless but could be consolidated.
from rich.live import Live
from rich.text import Text
from rich.panel import Panel
def _response_panel(text: str, domain: str) -> Panel:
    """Build the green response panel titled with the specialist domain.

    The body is wrapped in ``Text`` so model output is displayed literally.
    A plain ``str`` renderable is parsed as Rich console markup, which
    swallows bracketed spans such as ``arr[i]`` and raises ``MarkupError``
    on unmatched closing tags such as ``[/INST]``.
    """
    return Panel(
        Text(text),
        title=f"[bold]Response (Specialist: {domain.lower()})[/bold]",
        border_style="green",
    )


def handle_query_streaming(router: LLMRouter, user_input: str, console: Console):
    """Run one query through the router and stream the answer to the terminal.

    Phase 1 shows a spinner while the stream emits ``routing`` / ``loaded``
    events. The first ``token`` or ``done`` event switches to phase 2, a
    ``Live`` panel that grows as tokens arrive. Afterwards a dim metrics
    summary is printed.

    Args:
        router: Router whose ``stream_query(user_input)`` yields event dicts
            carrying a ``"type"`` key (``routing``, ``loaded``, ``token``,
            ``done``).
        user_input: The raw user query.
        console: Console used for the spinner, the live panel and messages.

    Returns:
        The last ``done`` event from the stream, or a placeholder dict with
        zeroed metrics if no ``done`` event arrived or streaming raised.
    """
    domain = "..."
    rewritten_prompt = ""
    load_time = 0.0
    inference_time = 0.0
    context_turns = 0
    accumulated = ""
    done_event = {
        "response": "", "domain": domain,
        "inference_time_seconds": 0.0, "cache_hit": False,
        "context_turns": 0, "rewritten_prompt": "",
        "specialist_load_time": 0.0,
    }
    saw_done = False  # True once a real "done" event replaced the placeholder
    stream = router.stream_query(user_input)
    event = None  # stays None only if the stream yields nothing at all
    try:
        # ── Phase 1: spinner during routing and specialist loading ─────────────
        with console.status(
            "[cyan]Processing... (Router is persistent, specialist loads on-demand)",
            spinner="dots"
        ) as status:
            for event in stream:
                etype = event["type"]
                if etype == "routing":
                    domain = event.get("domain", domain)
                    rewritten_prompt = event.get("rewritten_prompt", "")
                    if event.get("is_multi_agent"):
                        status.update("[cyan]Multi-agent composition — routing to specialists...")
                    else:
                        status.update(
                            f"[cyan]Routed to [bold]{domain.upper()}[/bold] — loading specialist..."
                        )
                elif etype == "loaded":
                    load_time = event["load_time"]
                    cache_str = "HOT ♻" if event["cache_hit"] else f"loaded in {load_time:.2f}s"
                    status.update(f"[cyan]Specialist {cache_str} — generating response...")
                elif etype in ("token", "done"):
                    # Hand off to the Live streaming phase with this event pending.
                    break
        if event is None:
            console.print("[bold red]Error:[/bold red] Stream ended unexpectedly without yields.")
            return done_event

        # ── Phase 2: display tokens live (or full response for multi-agent) ────
        with Live(
            _response_panel("", domain),
            console=console,
            refresh_per_second=15,
            vertical_overflow="visible"
        ) as live:
            # First handle the event that ended phase 1.
            if event["type"] == "token":
                accumulated += event["content"]
                live.update(_response_panel(accumulated, domain))
            elif event["type"] == "done":
                done_event, saw_done = event, True
                domain = event.get("domain", domain)
                live.update(_response_panel(event.get("response", ""), domain))
            # Then keep consuming the rest of the stream.
            for event in stream:
                if event["type"] == "token":
                    accumulated += event["content"]
                    live.update(_response_panel(accumulated, domain))
                elif event["type"] == "done":
                    done_event, saw_done = event, True
                    domain = event.get("domain", domain)
                    live.update(_response_panel(event.get("response", ""), domain))
                    break
    except Exception as e:
        # Escape so brackets in the exception text are not parsed as markup.
        console.print(f"[bold red]Streaming Error:[/bold red] {escape(str(e))}")
        return done_event

    # Read metrics from the final "done" event in one place, so they are
    # reported whether "done" arrived first (multi-agent) or after tokens.
    if saw_done:
        inference_time = done_event.get("inference_time_seconds", 0.0)
        context_turns = done_event.get("context_turns", 0)
        rewritten_prompt = done_event.get("rewritten_prompt", rewritten_prompt)

    console.print(
        f"\n[dim]Router optimized prompt: {escape(rewritten_prompt[:80])}...[/dim]\n"
        f"[dim]Metrics: Router: resident (0s) | "
        f"Specialist Load: {load_time:.2f}s | "
        f"Inference: {inference_time:.2f}s | "
        f"Context: {context_turns} turns[/dim]"
    )
    return done_event
def show_availability(router: LLMRouter):
    """Print a table of the router model and each specialist's local status.

    The router row is always reported as loaded because the router object
    (and its model) is created before this dashboard is shown. Each
    specialist row reports "Ready" when ``router.loader.get_local_models()``
    maps its model id to a truthy value, otherwise "Download Required".
    """
    from loader.airllm_loader import GGUF_REGISTRY

    local_models = router.loader.get_local_models()

    def gguf_name(model_id):
        # Index 1 of a registry entry is shown in the "Model (GGUF)" column;
        # ids missing from the registry fall back to the id itself.
        return GGUF_REGISTRY.get(model_id, ("", model_id))[1]

    dashboard = Table(title="Model Availability Dashboard", border_style="cyan")
    for header, options in (
        ("Type", {"style": "bold"}),
        ("Domain", {"style": "magenta"}),
        ("Model (GGUF)", {"style": "white"}),
        ("Status", {"justify": "center"}),
    ):
        dashboard.add_column(header, **options)

    # The persistent router is constructed before this function runs.
    dashboard.add_row(
        "Router",
        "N/A",
        gguf_name(router.config["router"]["model_id"]),
        "[bold green]Loaded (persistent)[/bold green]",
    )

    for name, spec in router.config["specialists"].items():
        model_id = spec["model_id"]
        if local_models.get(model_id):
            badge = "[green]Ready[/green]"
        else:
            badge = "[yellow]Download Required[/yellow]"
        dashboard.add_row("Specialist", name.capitalize(), gguf_name(model_id), badge)

    console.print(dashboard)
    console.print("[dim italic white]* Download will happen automatically on first query for each domain.[/dim italic white]\n")
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the CLI parser for ``--reconfigure``, ``--preload`` and ``--web``."""
    parser = argparse.ArgumentParser(description="MELLM - LLM Router CLI")
    parser.add_argument(
        "--reconfigure",
        action="store_true",
        help="Re-run the setup wizard and exit"
    )
    parser.add_argument("--preload", type=str, help="Preload a specific domain model or 'all'")
    parser.add_argument(
        "--web",
        action="store_true",
        help="Launch the web UI instead of the CLI"
    )
    return parser


def _run_web(router: LLMRouter) -> None:
    """Serve the web UI on port 8000 using the already-loaded router (blocks)."""
    import threading
    import webbrowser

    import uvicorn
    from api import app, set_router

    set_router(router)  # inject the already-loaded router into the API
    console.print("\n[bold cyan]Web UI starting...[/bold cyan]")
    console.print("[dim]Open http://localhost:8000 in your browser[/dim]\n")
    # Delay opening the browser so the server has a moment to start listening.
    threading.Timer(1.5, lambda: webbrowser.open("http://localhost:8000")).start()
    # NOTE(review): host 0.0.0.0 binds every interface, so the UI is reachable
    # from other machines, not only localhost — confirm this is intended.
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="warning")


def _preload(router: LLMRouter, choice: str) -> int:
    """Load then unload the specialist(s) named by *choice*; return an exit code.

    *choice* is a specialist domain name or ``"all"`` (case-insensitive).
    Returns 1 after printing an error if *choice* is not recognised, else 0.
    """
    specialists = router.config["specialists"]
    key = choice.lower()
    if key == "all":
        target_domains = list(specialists.keys())
    elif key in specialists:
        target_domains = [key]
    else:
        console.print(
            f"[bold red]Error:[/bold red] Invalid domain '{escape(choice)}'. "
            f"Valid domains: {', '.join(specialists.keys())}, all"
        )
        return 1
    for domain in target_domains:
        m_id = specialists[domain]["model_id"]
        console.print(f"[bold blue]Preloading {domain}...[/bold blue]")
        router.loader.get(m_id)
        router.loader.unload(m_id)
        console.print(f"[bold green]Done: {domain} cached.[/bold green]")
    console.print("\n[bold green]Preloading complete.[/bold green]")
    return 0


def _run_repl(router: LLMRouter) -> None:
    """Interactive query loop; returns on exit/quit, Ctrl+C or end of input."""
    console.print("[green]System initialized. Router is persistent. Type 'exit' or 'quit' to stop.[/green]")
    total = 0  # successful queries this session
    hits = 0   # of those, queries served by an already-loaded specialist
    while True:
        try:
            user_input = console.input("\n[bold yellow]Query:[/bold yellow] ")
            command = user_input.strip().lower()
            if command in ("exit", "quit"):
                break
            if command == "clear":
                router.conversation_history.clear()
                console.print("[yellow]Context cleared. Starting fresh conversation.[/yellow]")
                continue
            if not command:
                continue

            result = handle_query_streaming(router, user_input, console)
            if "error" in result:
                console.print(f"[bold red]Pipeline Error:[/bold red] {escape(str(result['error']))}")
                continue

            domain = result.get("domain") or "..."
            cache_hit = result.get("cache_hit", False)
            total += 1
            if cache_hit:
                hits += 1

            hit_rate = hits / total * 100
            # Rough estimate carried over from the original: ~1s per router
            # load avoided after the first query.
            router_time_saved = (total - 1) * 1.0
            hot_label = "[green](HOT ♻)[/green]" if cache_hit else "[yellow](freshly loaded)[/yellow]"

            # ── Efficiency panel ──────────────────────────────────────────
            streak_display = " → ".join(router.domain_streak[-5:]) if router.domain_streak else "none"
            console.print(Panel(
                f"[bold cyan]Session Efficiency[/bold cyan]\n"
                f"  Queries this session : [white]{total}[/white]\n"
                f"  Specialist cache hits: [green]{hits}/{total} ({hit_rate:.0f}%)[/green]\n"
                f"  Router loads saved   : [green]{total - 1}[/green] "
                f"(~[yellow]{router_time_saved:.1f}s[/yellow] saved)\n"
                f"  Active specialist    : [cyan]{domain.upper()}[/cyan] {hot_label}\n"
                f"  Context turns active : [cyan]{len(router.conversation_history)}/{router.max_history}[/cyan]\n"
                f"  Domain streak        : [cyan]{streak_display}[/cyan]",
                title=f"⚡ Efficiency ({VERSION})",
                border_style="dim"
            ))
        except (KeyboardInterrupt, EOFError):
            # EOFError comes from Ctrl+D or exhausted piped stdin; without this
            # branch the generic handler below would re-prompt forever.
            console.print("\n[yellow]Interrupted.[/yellow]")
            break
        except Exception as e:
            console.print(f"[bold red]Error:[/bold red] {escape(str(e))}")


def main():
    """CLI entry point: banner, argument parsing, onboarding, then one mode.

    Modes: ``--reconfigure`` re-runs onboarding and exits; ``--web`` serves
    the web UI; ``--preload`` caches specialist models and exits with status
    0/1; otherwise an interactive query loop runs. The router, once created,
    is shut down on every exit path.
    """
    # Show big ASCII banner first — always
    show_banner()
    # Parse arguments before any interactive onboarding so --help and invalid
    # flags are handled without first running the setup wizard.
    args = _build_arg_parser().parse_args()
    if args.reconfigure:
        run_onboarding(skip_banner=True)
        return
    # First run check — if no user_config.yaml, run onboarding
    if not Path("user_config.yaml").exists():
        console.print("[yellow]No user configuration found. Starting setup wizard...[/yellow]\n")
        run_onboarding(skip_banner=True)
    # Load config — prefer user_config.yaml over config.yaml
    config_path = "user_config.yaml" if Path("user_config.yaml").exists() else "config.yaml"

    router = None
    try:
        try:
            # Initialize router — router model loads here and stays resident
            with console.status("[cyan]Loading router model (persistent)...[/cyan]"):
                router = LLMRouter(config_path=config_path)
            if args.web:
                _run_web(router)
                return
            if args.preload:
                sys.exit(_preload(router, args.preload))
            # Show availability dashboard before starting normal mode
            show_availability(router)
        except Exception as e:
            console.print(f"[bold red]Initialization Error:[/bold red] {escape(str(e))}")
            sys.exit(1)
        _run_repl(router)
    finally:
        # Single shutdown point for every path (REPL exit, web server stop,
        # preload exit, failures after the router was created).
        if router is not None:
            router.shutdown()


if __name__ == "__main__":
    main()