From ebe342cd7a353854dc6b478ba09dcf9f45fbf062 Mon Sep 17 00:00:00 2001
From: ddiddi
Date: Thu, 20 Feb 2025 15:55:41 -0800
Subject: [PATCH 1/3] add distill and merge commands using DistillKit and MergeKit

---
 solo_server/commands/distill.py | 68 ++++++++++++++++++++++++++++++++
 solo_server/commands/merge.py   | 70 +++++++++++++++++++++++++++++++++
 2 files changed, 138 insertions(+)
 create mode 100644 solo_server/commands/distill.py
 create mode 100644 solo_server/commands/merge.py

diff --git a/solo_server/commands/distill.py b/solo_server/commands/distill.py
new file mode 100644
index 0000000..5c6c7bb
--- /dev/null
+++ b/solo_server/commands/distill.py
@@ -0,0 +1,68 @@
+import typer
+import subprocess
+from huggingface_hub import HfApi
+
+app = typer.Typer()
+
+def run_command(command):
+    """
+    Runs a shell command and prints its output; re-raises on failure.
+    """
+    try:
+        result = subprocess.run(command, check=True, capture_output=True, text=True)
+        typer.echo(result.stdout)
+    except subprocess.CalledProcessError as e:
+        typer.echo(f"❌ Command failed: {e.stderr}", err=True)
+        raise
+
+
+@app.command()
+def distill(
+    teacher_model: str,
+    student_model: str,
+    dataset_name: str,
+    output_dir: str = "./distilled_model",
+    num_train_epochs: int = 3,
+    per_device_train_batch_size: int = 1,
+    gradient_accumulation_steps: int = 8,
+    learning_rate: float = 2e-5,
+    alpha: float = 0.5,
+    temperature: float = 2.0,
+    use_flash_attention: bool = True,
+    bf16: bool = True,
+):
+    """
+    Runs model distillation using DistillKit.
+    """
+    typer.echo(f"🚀 Starting distillation with teacher: {teacher_model}, student: {student_model}")
+
+    command = [
+        "python", "distill.py",
+        "--teacher-model", teacher_model,
+        "--student-model", student_model,
+        "--dataset-name", dataset_name,
+        "--output-dir", output_dir,
+        "--num-train-epochs", str(num_train_epochs),
+        "--per-device-train-batch-size", str(per_device_train_batch_size),
+        "--gradient-accumulation-steps", str(gradient_accumulation_steps),
+        "--learning-rate", str(learning_rate),
+        "--alpha", str(alpha),
+        "--temperature", str(temperature),
+        "--use-flash-attention", str(use_flash_attention),
+        "--bf16", str(bf16)
+    ]
+
+    run_command(command)
+    typer.echo(f"✅ Distillation complete! Model saved at {output_dir}")
+
+    api = HfApi()
+    api.upload_folder(
+        folder_path=output_dir,
+        path_in_repo="distilled_model",
+        repo_id="solo-ai/distilled-model",
+        repo_type="model",
+    )
+    typer.echo("✅ Model uploaded to Hugging Face!")
+
+if __name__ == "__main__":
+    app()
diff --git a/solo_server/commands/merge.py b/solo_server/commands/merge.py
new file mode 100644
index 0000000..ec00bd1
--- /dev/null
+++ b/solo_server/commands/merge.py
@@ -0,0 +1,70 @@
+import typer
+import subprocess
+from huggingface_hub import HfApi
+
+app = typer.Typer()
+
+def run_command(command):
+    """
+    Runs a shell command and prints its output; re-raises on failure.
+    """
+    try:
+        result = subprocess.run(command, check=True, capture_output=True, text=True)
+        typer.echo(result.stdout)
+    except subprocess.CalledProcessError as e:
+        typer.echo(f"❌ Command failed: {e.stderr}", err=True)
+        raise
+
+
+@app.command()
+def merge(
+    config_path: str,
+    output_dir: str = "./merged_model",
+    cuda: bool = False,
+    allow_crimes: bool = False,
+    upload: bool = False,
+    hf_repo: str = "",
+):
+    """
+    Merges models using MergeKit, supporting SLERP and gradient-based merging.
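+
+    Example config.yaml (an illustrative SLERP merge; the model names and
+    layer ranges are placeholders to adapt):
+
+        merge_method: slerp
+        base_model: mistralai/Mistral-7B-v0.1
+        slices:
+          - sources:
+              - model: mistralai/Mistral-7B-v0.1
+                layer_range: [0, 32]
+              - model: teknium/OpenHermes-2.5-Mistral-7B
+                layer_range: [0, 32]
+        parameters:
+          t: 0.5
+        dtype: bfloat16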
+    """
+    typer.echo(f"🚀 Running MergeKit with config: {config_path}")
+
+    command = ["mergekit-yaml", config_path, output_dir]
+    if cuda:
+        command.append("--cuda")
+    if allow_crimes:
+        command.append("--allow-crimes")
+
+    run_command(command)
+    typer.echo(f"✅ Merging complete! Model saved at {output_dir}")
+
+    # Upload to Hugging Face
+    if upload:
+        if not hf_repo:
+            typer.echo("❌ Hugging Face repository required for upload.", err=True)
+            raise typer.Exit(code=1)
+
+        typer.echo(f"📤 Uploading model to Hugging Face: {hf_repo}")
+        api = HfApi()
+        api.upload_folder(folder_path=output_dir, repo_id=hf_repo, repo_type="model")
+        typer.echo(f"✅ Model uploaded to https://huggingface.co/{hf_repo}")
+
+
+if __name__ == "__main__":
+    app()

From c0eb2a1cffd155af07f651a1e6e8fc5cb3dd5eb8 Mon Sep 17 00:00:00 2001
From: ddiddi
Date: Thu, 20 Feb 2025 16:22:24 -0800
Subject: [PATCH 2/3] add WeightWatcher analysis command

---
 solo_server/commands/watcher.py | 83 +++++++++++++++++++++++++++++++++
 1 file changed, 83 insertions(+)
 create mode 100644 solo_server/commands/watcher.py

diff --git a/solo_server/commands/watcher.py b/solo_server/commands/watcher.py
new file mode 100644
index 0000000..9dc2a4d
--- /dev/null
+++ b/solo_server/commands/watcher.py
@@ -0,0 +1,83 @@
+import typer
+import subprocess
+import torchvision.models as models
+import weightwatcher as ww
+import pandas as pd
+
+app = typer.Typer()
+
+def run_command(command):
+    """
+    Runs a shell command and prints its output; re-raises on failure.
+    """
+    try:
+        result = subprocess.run(command, check=True, capture_output=True, text=True)
+        typer.echo(result.stdout)
+    except subprocess.CalledProcessError as e:
+        typer.echo(f"❌ Command failed: {e.stderr}", err=True)
+        raise
+
+
+@app.command()
+def install_weightwatcher():
+    """
+    Installs WeightWatcher. Tries official PyPI first, then TestPyPI if that fails.
+    """
+    typer.echo("📦 Installing WeightWatcher...")
+
+    command = ["pip", "install", "weightwatcher"]
+    try:
+        run_command(command)
+        typer.echo("✅ WeightWatcher installed successfully!")
+    except subprocess.CalledProcessError:
+        typer.echo("⚠️ Standard installation failed. Trying TestPyPI...")
+        command = [
+            "python3", "-m", "pip", "install",
+            "--index-url", "https://test.pypi.org/simple/",
+            "--extra-index-url", "https://pypi.org/simple",
+            "weightwatcher"
+        ]
+        run_command(command)
+        typer.echo("✅ WeightWatcher installed from TestPyPI!")
+
+
+@app.command()
+def analyze(
+    model_name: str = typer.Argument(..., help="Torchvision model name (e.g., vgg19_bn, resnet50)"),
+    save_results: bool = typer.Option(False, help="Save analysis results as CSV")
+):
+    """
+    Analyzes a model using WeightWatcher and prints generalization metrics.
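+
+    Per the WeightWatcher docs, the headline metric is alpha, the power-law
+    exponent fitted to each layer's eigenvalue spectrum; a mean alpha roughly
+    in the 2-6 range is the heuristic band reported for well-trained layers.
+
+    Example (CLI name illustrative): `solo analyze resnet50 --save-results`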
+    """
+    typer.echo(f"🔍 Analyzing model: {model_name}")
+
+    # Load the pretrained model from torchvision
+    try:
+        model = getattr(models, model_name)(weights="DEFAULT")
+    except AttributeError:
+        typer.echo(f"❌ Error: Model '{model_name}' not found in torchvision.models", err=True)
+        raise typer.Exit(code=1)
+
+    # Run WeightWatcher analysis
+    watcher = ww.WeightWatcher(model=model)
+    details = watcher.analyze()
+    summary = watcher.get_summary(details)
+
+    typer.echo("📊 Model Analysis Summary:")
+    typer.echo(summary)
+
+    if save_results:
+        details_df = pd.DataFrame(details)
+        csv_filename = f"{model_name}_analysis.csv"
+        details_df.to_csv(csv_filename, index=False)
+        typer.echo(f"✅ Analysis results saved to {csv_filename}")
+
+
+if __name__ == "__main__":
+    app()

From 49cf7d86c603b1f1d7a3052baa7e2045533faadf Mon Sep 17 00:00:00 2001
From: ddiddi
Date: Thu, 20 Feb 2025 16:36:46 -0800
Subject: [PATCH 3/3] add quantization commands

---
 solo_server/commands/quantize.py | 113 +++++++++++++++++++++++++++++++
 1 file changed, 113 insertions(+)
 create mode 100644 solo_server/commands/quantize.py

diff --git a/solo_server/commands/quantize.py b/solo_server/commands/quantize.py
new file mode 100644
index 0000000..7dcd0d2
--- /dev/null
+++ b/solo_server/commands/quantize.py
@@ -0,0 +1,113 @@
+import typer
+import subprocess
+import os
+from datasets import load_dataset
+from gptqmodel import GPTQModel, QuantizeConfig
+
+app = typer.Typer()
+
+def run_command(command):
+    """
+    Runs a shell command and prints its output; re-raises on failure.
+    """
+    try:
+        result = subprocess.run(command, check=True, capture_output=True, text=True)
+        typer.echo(result.stdout)
+    except subprocess.CalledProcessError as e:
+        typer.echo(f"❌ Command failed: {e.stderr}", err=True)
+        raise
+
+
+@app.command()
+def install_gptqmodel():
+    """
+    Installs GPTQModel. Tries official PyPI first, then from source if necessary.
+    """
+    typer.echo("📦 Installing GPTQModel...")
+
+    command = ["pip", "install", "-v", "--no-build-isolation", "gptqmodel"]
+    try:
+        run_command(command)
+        typer.echo("✅ GPTQModel installed successfully!")
+    except subprocess.CalledProcessError:
+        typer.echo("⚠️ Standard installation failed. Trying source installation...")
+
+        command = [
+            "git", "clone", "https://github.com/ModelCloud/GPTQModel.git", "GPTQModel"
+        ]
+        run_command(command)
+
+        os.chdir("GPTQModel")
+
+        command = ["pip", "install", "-v", "--no-build-isolation", "."]
+        run_command(command)
+
+        typer.echo("✅ GPTQModel installed from source!")
+
+
+@app.command()
+def quantize(
+    model_id: str = typer.Argument(..., help="Hugging Face model ID to quantize"),
+    output_dir: str = typer.Option("./quantized_model", help="Directory to save quantized model"),
+    dataset: str = typer.Option("allenai/c4", help="Dataset for calibration"),
+    dataset_file: str = typer.Option("en/c4-train.00001-of-01024.json.gz", help="Calibration file"),
+    num_samples: int = typer.Option(1024, help="Number of calibration samples"),
+    bits: int = typer.Option(4, help="Quantization bits"),
+    group_size: int = typer.Option(128, help="Quantization group size"),
+    batch_size: int = typer.Option(2, help="Batch size for quantization"),
+):
+    """
+    Quantizes a model using GPTQModel and saves the output.
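+
+    Example invocation (model ID illustrative):
+        solo quantize meta-llama/Llama-3.2-1B --bits 4 --group-size 128
+
+    4-bit with group size 128 is the common GPTQ setting; a smaller group
+    size usually improves accuracy at the cost of a slightly larger model.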
+    """
+    typer.echo(f"🔧 Quantizing model: {model_id} with {bits}-bit precision...")
+
+    # Load calibration dataset
+    calibration_dataset = load_dataset(dataset, data_files=dataset_file, split="train").select(range(num_samples))["text"]
+
+    # Define quantization config
+    quant_config = QuantizeConfig(bits=bits, group_size=group_size)
+
+    # Load model and perform quantization
+    model = GPTQModel.load(model_id, quant_config)
+    model.quantize(calibration_dataset, batch_size=batch_size)
+
+    # Save quantized model
+    model.save(output_dir)
+    typer.echo(f"✅ Quantized model saved to {output_dir}")
+
+    # Test post-quantization inference
+    test_prompt = "Uncovering deep insights begins with"
+    result = model.generate(test_prompt)[0]
+    typer.echo(f"📜 Sample Output: {model.tokenizer.decode(result)}")
+
+
+@app.command()
+def serve_quantized_model(
+    model_path: str = typer.Argument(..., help="Path to the quantized model"),
+    host: str = typer.Option("0.0.0.0", help="Host IP address"),
+    port: int = typer.Option(12345, help="Port for serving model"),
+):
+    """
+    Serves a quantized model via an OpenAI-compatible API.
+    """
+    typer.echo(f"🚀 Serving quantized model at {host}:{port}")
+
+    model = GPTQModel.load(model_path)
+    model.serve(host=host, port=str(port))
+    typer.echo(f"✅ Model API is running at http://{host}:{port}")
+
+
+if __name__ == "__main__":
+    app()
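+
+# Example request once the server is up (OpenAI-compatible endpoint; the
+# exact route depends on the GPTQModel version, so treat it as illustrative):
+#   curl http://localhost:12345/v1/completions \
+#     -H "Content-Type: application/json" \
+#     -d '{"prompt": "Hello", "max_tokens": 32}'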