68 changes: 68 additions & 0 deletions solo_server/commands/distill.py
@@ -0,0 +1,68 @@
import typer
import subprocess
from huggingface_hub import HfApi

app = typer.Typer()

def run_command(command):
    """
    Runs a shell command and prints its output. Re-raises on failure so
    callers can fall back or abort instead of reporting false success.
    """
    try:
        result = subprocess.run(command, check=True, capture_output=True, text=True)
        typer.echo(result.stdout)
    except subprocess.CalledProcessError as e:
        typer.echo(f"❌ Command failed: {e.stderr}", err=True)
        raise


@app.command()
def distill(
teacher_model: str,
student_model: str,
dataset_name: str,
output_dir: str = "./distilled_model",
num_train_epochs: int = 3,
per_device_train_batch_size: int = 1,
gradient_accumulation_steps: int = 8,
learning_rate: float = 2e-5,
alpha: float = 0.5,
temperature: float = 2.0,
use_flash_attention: bool = True,
bf16: bool = True,
):
"""
Runs model distillation using DistillKit.
"""
typer.echo(f"🚀 Starting distillation with teacher: {teacher_model}, student: {student_model}")

command = [
"python", "distill.py",
"--teacher-model", teacher_model,
"--student-model", student_model,
"--dataset-name", dataset_name,
"--output-dir", output_dir,
"--num-train-epochs", str(num_train_epochs),
"--per-device-train-batch-size", str(per_device_train_batch_size),
"--gradient-accumulation-steps", str(gradient_accumulation_steps),
"--learning-rate", str(learning_rate),
"--alpha", str(alpha),
"--temperature", str(temperature),
"--use-flash-attention", str(use_flash_attention),
"--bf16", str(bf16)
]

    try:
        run_command(command)
    except subprocess.CalledProcessError:
        raise typer.Exit(code=1)
    typer.echo(f"✅ Distillation complete! Model saved at {output_dir}")

    # A directory must be pushed with upload_folder (upload_file only handles
    # single files), and a model belongs in a "model" repo rather than a Space.
    api = HfApi()
    api.upload_folder(
        folder_path=output_dir,
        path_in_repo="distilled_model",
        repo_id="solo-ai/distilled-model",
        repo_type="model",
    )
    typer.echo("✅ Model uploaded to Hugging Face!")

if __name__ == "__main__":
app()
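For context on the `alpha` and `temperature` parameters above: in the standard knowledge-distillation objective, the student is trained on a blend of soft targets (teacher logits softened by `temperature`) and hard labels, mixed by `alpha`. A minimal PyTorch sketch of that loss, assuming DistillKit follows the usual Hinton-style formulation (`distillation_loss` is an illustrative name, not DistillKit's API):

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=2.0):
    # Soft-target term: KL divergence between temperature-scaled distributions;
    # the T^2 factor keeps gradient magnitudes comparable across temperatures.
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
    # Hard-target term: ordinary cross-entropy against the ground-truth labels
    hard = F.cross_entropy(student_logits, labels)
    # alpha = 1.0 would be pure distillation; alpha = 0.0 pure cross-entropy
    return alpha * soft + (1.0 - alpha) * hard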
55 changes: 55 additions & 0 deletions solo_server/commands/merge.py
@@ -0,0 +1,55 @@
import typer
import subprocess
from huggingface_hub import HfApi

app = typer.Typer()

def run_command(command):
    """
    Runs a shell command and prints its output. Re-raises on failure so
    callers can fall back or abort instead of reporting false success.
    """
    try:
        result = subprocess.run(command, check=True, capture_output=True, text=True)
        typer.echo(result.stdout)
    except subprocess.CalledProcessError as e:
        typer.echo(f"❌ Command failed: {e.stderr}", err=True)
        raise


@app.command()
def merge(
config_path: str,
output_dir: str = "./merged_model",
cuda: bool = False,
allow_crimes: bool = False,
upload: bool = False,
hf_repo: str = "",
):
"""
Merges models using MergeKit, supporting SLERP and gradient-based merging.
"""
typer.echo(f"🚀 Running MergeKit with config: {config_path}")

command = ["mergekit-yaml", config_path, output_dir]
if cuda:
command.append("--cuda")
if allow_crimes:
command.append("--allow-crimes")

    try:
        run_command(command)
    except subprocess.CalledProcessError:
        raise typer.Exit(code=1)
    typer.echo(f"✅ Merging complete! Model saved at {output_dir}")

# Upload to Hugging Face
if upload:
if not hf_repo:
typer.echo("❌ Hugging Face repository required for upload.", err=True)
raise typer.Exit(code=1)

typer.echo(f"📤 Uploading model to Hugging Face: {hf_repo}")
api = HfApi()
api.upload_folder(folder_path=output_dir, repo_id=hf_repo, repo_type="model")
typer.echo(f"✅ Model uploaded to https://huggingface.co/{hf_repo}")


if __name__ == "__main__":
app()
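For reference, `mergekit-yaml` consumes a YAML config along the lines of the sketch below; the model names, layer ranges, and interpolation factor `t` are placeholders, and the schema follows MergeKit's documented SLERP example. It is written and invoked from Python here so the example also shows how the new command would be driven:

from pathlib import Path
from typer.testing import CliRunner
from solo_server.commands.merge import app

# Hypothetical SLERP config -- substitute your own models and layer counts
Path("slerp.yaml").write_text("""\
slices:
  - sources:
      - model: mistralai/Mistral-7B-v0.1
        layer_range: [0, 32]
      - model: HuggingFaceH4/zephyr-7b-beta
        layer_range: [0, 32]
merge_method: slerp
base_model: mistralai/Mistral-7B-v0.1
parameters:
  t: 0.5  # 0 = first model, 1 = second model
dtype: bfloat16
""")

runner = CliRunner()
# Equivalent to: merge slerp.yaml --output-dir ./merged_model
result = runner.invoke(app, ["slerp.yaml", "--output-dir", "./merged_model"])
print(result.output)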
100 changes: 100 additions & 0 deletions solo_server/commands/quantize.py
@@ -0,0 +1,100 @@
import typer
import subprocess
import os
from datasets import load_dataset
from gptqmodel import GPTQModel, QuantizeConfig

app = typer.Typer()

def run_command(command):
    """
    Runs a shell command and prints its output. Re-raises on failure so
    callers can fall back or abort instead of reporting false success.
    """
    try:
        result = subprocess.run(command, check=True, capture_output=True, text=True)
        typer.echo(result.stdout)
    except subprocess.CalledProcessError as e:
        typer.echo(f"❌ Command failed: {e.stderr}", err=True)
        raise


@app.command()
def install_gptqmodel():
"""
Installs GPTQModel. Tries official PyPI first, then from source if necessary.
"""
typer.echo("📦 Installing GPTQModel...")

    command = ["pip", "install", "-v", "--no-build-isolation", "gptqmodel"]
    try:
        run_command(command)
        typer.echo("✅ GPTQModel installed successfully!")
    except subprocess.CalledProcessError:
        typer.echo("⚠️ Standard installation failed. Trying source installation...")

        command = [
            "git", "clone", "https://github.com/ModelCloud/GPTQModel.git", "GPTQModel"
        ]
        run_command(command)

        # Build from the cloned source, then restore the working directory
        cwd = os.getcwd()
        os.chdir("GPTQModel")
        command = ["pip", "install", "-v", "--no-build-isolation", "."]
        run_command(command)
        os.chdir(cwd)

        typer.echo("✅ GPTQModel installed from source!")


@app.command()
def quantize(
model_id: str = typer.Argument(..., help="Hugging Face model ID to quantize"),
output_dir: str = typer.Option("./quantized_model", help="Directory to save quantized model"),
dataset: str = typer.Option("allenai/c4", help="Dataset for calibration"),
dataset_file: str = typer.Option("en/c4-train.00001-of-01024.json.gz", help="Calibration file"),
num_samples: int = typer.Option(1024, help="Number of calibration samples"),
bits: int = typer.Option(4, help="Quantization bits"),
group_size: int = typer.Option(128, help="Quantization group size"),
batch_size: int = typer.Option(2, help="Batch size for quantization"),
):
"""
Quantizes a model using GPTQModel and saves the output.
"""
typer.echo(f"🔧 Quantizing model: {model_id} with {bits}-bit precision...")

    # Load the text field of the calibration samples
    calibration_dataset = load_dataset(
        dataset, data_files=dataset_file, split="train"
    ).select(range(num_samples))["text"]

# Define quantization config
quant_config = QuantizeConfig(bits=bits, group_size=group_size)

# Load model and perform quantization
model = GPTQModel.load(model_id, quant_config)
model.quantize(calibration_dataset, batch_size=batch_size)

# Save quantized model
model.save(output_dir)
typer.echo(f"✅ Quantized model saved to {output_dir}")

# Test post-quantization inference
test_prompt = "Uncovering deep insights begins with"
result = model.generate(test_prompt)[0]
typer.echo(f"📜 Sample Output: {model.tokenizer.decode(result)}")


@app.command()
def serve_quantized_model(
model_path: str = typer.Argument(..., help="Path to the quantized model"),
host: str = typer.Option("0.0.0.0", help="Host IP address"),
port: int = typer.Option(12345, help="Port for serving model"),
):
"""
Serves a quantized model using OpenAI-compatible API.
"""
typer.echo(f"🚀 Serving quantized model at {host}:{port}")

model = GPTQModel.load(model_path)
model.serve(host=host, port=str(port))
typer.echo(f"✅ Model API is running at http://{host}:{port}")


if __name__ == "__main__":
app()
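A quick usage sketch for the quantize command (the model ID is a placeholder; the flag values mirror the defaults). Once `serve-quantized-model` is running, an OpenAI-compatible client would typically be pointed at `http://<host>:<port>/v1`, though the exact route depends on GPTQModel's server implementation:

from typer.testing import CliRunner
from solo_server.commands.quantize import app

runner = CliRunner()
# Equivalent to: quantize meta-llama/Llama-3.2-1B --bits 4 --group-size 128
result = runner.invoke(app, [
    "quantize", "meta-llama/Llama-3.2-1B",
    "--output-dir", "./quantized_model",
    "--bits", "4",
    "--group-size", "128",
])
print(result.output)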
77 changes: 77 additions & 0 deletions solo_server/commands/watcher.py
@@ -0,0 +1,77 @@
import typer
import subprocess
import torchvision.models as models
import weightwatcher as ww
import pandas as pd

app = typer.Typer()

def run_command(command):
    """
    Runs a shell command and prints its output. Re-raises on failure so
    callers can fall back or abort instead of reporting false success.
    """
    try:
        result = subprocess.run(command, check=True, capture_output=True, text=True)
        typer.echo(result.stdout)
    except subprocess.CalledProcessError as e:
        typer.echo(f"❌ Command failed: {e.stderr}", err=True)
        raise


@app.command()
def install_weightwatcher():
"""
Installs WeightWatcher. Tries official PyPI first, then TestPyPI if it fails.
"""
typer.echo("📦 Installing WeightWatcher...")

    command = ["pip", "install", "weightwatcher"]
    try:
        run_command(command)
        typer.echo("✅ WeightWatcher installed successfully!")
    except subprocess.CalledProcessError:
        typer.echo("⚠️ Standard installation failed. Trying TestPyPI...")
        command = [
            "python3", "-m", "pip", "install",
            "--index-url", "https://test.pypi.org/simple/",
            "--extra-index-url", "https://pypi.org/simple",
            "weightwatcher"
        ]
        run_command(command)
        typer.echo("✅ WeightWatcher installed from TestPyPI!")


@app.command()
def analyze(
model_name: str = typer.Argument(..., help="Torchvision model name (e.g., vgg19_bn, resnet50)"),
save_results: bool = typer.Option(False, help="Save analysis results as CSV")
):
"""
Analyzes a model using WeightWatcher and prints generalization metrics.
"""
typer.echo(f"🔍 Analyzing model: {model_name}")

    # Load the pretrained model from torchvision (weights="DEFAULT" replaces
    # the pretrained=True flag, deprecated since torchvision 0.13)
    try:
        model = getattr(models, model_name)(weights="DEFAULT")
    except AttributeError:
        typer.echo(f"❌ Error: Model '{model_name}' not found in torchvision.models", err=True)
        raise typer.Exit(code=1)

# Run WeightWatcher analysis
watcher = ww.WeightWatcher(model=model)
details = watcher.analyze()
summary = watcher.get_summary(details)

typer.echo("📊 Model Analysis Summary:")
typer.echo(summary)

if save_results:
details_df = pd.DataFrame(details)
csv_filename = f"{model_name}_analysis.csv"
details_df.to_csv(csv_filename, index=False)
typer.echo(f"✅ Analysis results saved to {csv_filename}")


if __name__ == "__main__":
app()
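A usage sketch for the analyzer (WeightWatcher's own guidance is that layer power-law exponents `alpha` roughly in the 2-6 range indicate well-trained layers, with smaller generally better):

from typer.testing import CliRunner
from solo_server.commands.watcher import app

runner = CliRunner()
# Equivalent to: analyze resnet50 --save-results
result = runner.invoke(app, ["analyze", "resnet50", "--save-results"])
print(result.output)  # summary includes the mean power-law exponent alpha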