diff --git a/torchprime/experimental/benchmark/hf_model.py b/torchprime/experimental/benchmark/hf_model.py
new file mode 100644
index 00000000..95776b0c
--- /dev/null
+++ b/torchprime/experimental/benchmark/hf_model.py
@@ -0,0 +1,81 @@
+from typing import Any
+
+import torch
+from transformers.models.llama import modeling_llama
+from transformers.models.qwen3 import modeling_qwen3
+
+
+def get_llama3_model(torch_dtype: torch.dtype):
+    """Returns the Llama3.2 1B model."""
+    config = modeling_llama.LlamaConfig(
+        attention_bias=False,
+        attention_dropout=0.0,
+        bos_token_id=128000,
+        eos_token_id=128001,
+        head_dim=64,
+        hidden_act="silu",
+        hidden_size=2048,
+        initializer_range=0.02,
+        intermediate_size=8192,
+        max_position_embeddings=131072,
+        mlp_bias=False,
+        num_attention_heads=32,
+        num_hidden_layers=16,
+        num_key_value_heads=8,
+        rms_norm_eps=1e-05,
+        rope_scaling={
+            "factor": 32.0,
+            "high_freq_factor": 4.0,
+            "low_freq_factor": 1.0,
+            "original_max_position_embeddings": 8192,
+            "rope_type": "llama3",
+        },
+        rope_theta=500000.0,
+        tie_word_embeddings=True,
+        use_cache=True,
+        vocab_size=128256,
+        _attn_implementation="eager",
+    )
+    model = modeling_llama.LlamaForCausalLM(config).to(torch_dtype)
+    return model
+
+
+def get_qwen3_model(torch_dtype: torch.dtype):
+    """Returns the Qwen3 1.7B model."""
+    config = modeling_qwen3.Qwen3Config(
+        attention_bias=False,
+        attention_dropout=0.0,
+        bos_token_id=151643,
+        eos_token_id=151645,
+        head_dim=128,
+        hidden_act="silu",
+        hidden_size=2048,
+        initializer_range=0.02,
+        intermediate_size=6144,
+        max_position_embeddings=40960,
+        max_window_layers=28,
+        num_attention_heads=16,
+        num_hidden_layers=28,
+        num_key_value_heads=8,
+        rms_norm_eps=1e-06,
+        rope_scaling=None,
+        rope_theta=1000000,
+        sliding_window=None,
+        tie_word_embeddings=True,
+        use_cache=True,
+        use_sliding_window=False,
+        vocab_size=151936,
+        _attn_implementation="eager",
+    )
+    model = modeling_qwen3.Qwen3ForCausalLM(config).to(torch_dtype)
+    return model
+
+
+def get_model(model_name: str, dtype: torch.dtype) -> Any:
+    match model_name:
+        case "llama3.2-1B":
+            return get_llama3_model(dtype)
+        case "qwen3-1.7B":
+            return get_qwen3_model(dtype)
+        case _:
+            raise ValueError(f"Unsupported model: {model_name}")
\ No newline at end of file
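
For reference, a minimal sketch (not part of this change) of how hf_model.get_model can be used on its own. The configs above build randomly initialized weights, so no checkpoint download or HF Hub access is involved:

import torch

from torchprime.experimental.benchmark.hf_model import get_model

# Build the Llama 3.2 1B architecture in bfloat16; weights are randomly
# initialized from the config, not loaded from a checkpoint.
model = get_model("llama3.2-1B", torch.bfloat16)
model.eval()

# One forward pass on CPU with dummy token ids.
input_ids = torch.randint(0, model.config.vocab_size, (1, 16), dtype=torch.long)
with torch.no_grad():
    logits = model(input_ids).logits
print(logits.shape)  # torch.Size([1, 16, 128256])
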
diff --git a/torchprime/experimental/benchmark/hf_models_forward.py b/torchprime/experimental/benchmark/hf_models_forward.py
new file mode 100644
index 00000000..c987ba41
--- /dev/null
+++ b/torchprime/experimental/benchmark/hf_models_forward.py
@@ -0,0 +1,110 @@
+import argparse
+import os
+import time
+
+import numpy as np
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+
+from torchprime.experimental.benchmark.hf_model import get_model
+
+
+def main(args):
+    # --- Configuration ---
+    dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32}
+    torch_dtype = dtype_map[args.dtype]
+
+    # Acquire the XLA device first.
+    device = torch_xla.device()
+
+    # Create the model on CPU first
+    model_cpu = get_model(args.model_name, torch_dtype)
+    config = model_cpu.config
+    model_cpu.eval()  # Set to evaluation mode
+
+    # Move model to the XLA device.
+    model_tpu = model_cpu.to(device)
+
+    # Create dummy input_ids and move to the XLA device.
+    input_ids = torch.randint(
+        0, config.vocab_size, (args.batch_size, args.seq_len), dtype=torch.long
+    )
+    # Move inputs to the XLA device as well.
+    input_ids = input_ids.to(device)
+
+    # Preheat: one forward pass to initialize the runtime and warm the compilation cache.
+    print("Preheating...")
+    preheat_start_time = time.perf_counter()
+    with torch.no_grad():
+        # Assign to a variable to prevent garbage collection before sync.
+        logits = model_tpu(input_ids).logits
+    torch_xla.sync()
+    preheat_end_time = time.perf_counter()
+    preheat_time = preheat_end_time - preheat_start_time
+    print(f"PREHEAT WALL TIME: {preheat_time * 1000:.4f} ms")
+
+    # Initial run (warm-up) to trigger XLA compilation
+    print("Warming up (includes XLA graph compilation)...")
+    warmup_start_time = time.perf_counter()
+    with torch.no_grad():
+        logits = model_tpu(input_ids).logits
+    torch_xla.sync()  # Dispatch the traced graph to the device.
+    xm.wait_device_ops()  # Block until graph compilation and execution are complete.
+    warmup_end_time = time.perf_counter()
+    warmup_time = warmup_end_time - warmup_start_time
+
+    # Subsequent runs for measurement
+    print(f"Starting benchmark for {args.num_runs} runs...")
+    times = []
+    for i in range(args.num_runs):
+        start_time = time.perf_counter()
+        with torch.no_grad():
+            # Assign to a variable to prevent garbage collection before sync.
+            logits = model_tpu(input_ids).logits
+
+        torch_xla.sync()  # Dispatch the traced graph for this step.
+        xm.wait_device_ops()  # Block until the step's computation is complete for accurate timing.
+        end_time = time.perf_counter()
+        times.append(end_time - start_time)
+        print(f"Run {i+1}/{args.num_runs}: {(end_time - start_time) * 1000:.2f} ms")
+
+    # Print final performance results
+    print("\n--- Benchmark Results (Lazy Mode) ---")
+    print(f"Model: {args.model_name}, DType: {args.dtype}")
+    print(f"Batch Size: {args.batch_size}, Sequence Length: {args.seq_len}")
+    print(f"Preheat time: {preheat_time * 1000:.2f} ms")
+    print(f"Warm-up time: {warmup_time * 1000:.2f} ms")
+    print(f"Number of runs: {len(times)}")
+    print(f"Average latency: {np.mean(times) * 1000:.2f} ms")
+    print(f"Median latency: {np.median(times) * 1000:.2f} ms")
+    print(f"P90 latency: {np.percentile(times, 90) * 1000:.2f} ms")
+    print(f"Min latency: {np.min(times) * 1000:.2f} ms")
+    print(f"Max latency: {np.max(times) * 1000:.2f} ms")
+
+    # Wait for the TPU to finish all outstanding work before exiting.
+    xm.wait_device_ops()  # Final sync to ensure all pending operations are done.
+    print("Script finished and exited cleanly.")
+    os._exit(0)  # Exit immediately without interpreter teardown (os._exit rather than sys.exit).
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Benchmark HF models on XLA (Lazy Mode).")
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="llama3.2-1B",
+        choices=["llama3.2-1B", "qwen3-1.7B"],
+        help="Model to benchmark (must match a name supported by hf_model.get_model).",
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="bfloat16",
+        choices=["bfloat16", "float32"],
+        help="Data type for the model.",
+    )
+    parser.add_argument("--batch_size", type=int, default=1, help="Batch size.")
+    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length.")
+    parser.add_argument("--num_runs", type=int, default=10, help="Number of benchmark runs.")
+    main(parser.parse_args())
\ No newline at end of file
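
The lazy-mode script above depends on torch_xla.sync() followed by xm.wait_device_ops() for meaningful step times, since tracing a forward pass does not by itself execute anything on the device. A minimal sketch of that timing pattern, assuming a model and input_ids that already live on the XLA device (the helper name is illustrative, not part of this change):

import time

import torch
import torch_xla
import torch_xla.core.xla_model as xm


def timed_forward(model, input_ids):
    """Time one forward pass under the lazy tensor model (hypothetical helper)."""
    start = time.perf_counter()
    with torch.no_grad():
        logits = model(input_ids).logits  # Traced lazily; nothing executes yet.
    torch_xla.sync()  # Dispatch the accumulated graph to the device.
    xm.wait_device_ops()  # Block until the device has finished executing it.
    return time.perf_counter() - start, logits
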
diff --git a/torchprime/experimental/benchmark/hf_models_forward_eager.py b/torchprime/experimental/benchmark/hf_models_forward_eager.py
new file mode 100644
index 00000000..759ff2ea
--- /dev/null
+++ b/torchprime/experimental/benchmark/hf_models_forward_eager.py
@@ -0,0 +1,110 @@
+import argparse
+import os
+import time
+
+import numpy as np
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+
+from torchprime.experimental.benchmark.hf_model import get_model
+
+
+def main(args):
+    # --- Configuration ---
+    print("Running in PyTorch/XLA experimental eager mode.")
+    torch_xla.experimental.eager_mode(True)
+
+    dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32}
+    torch_dtype = dtype_map[args.dtype]
+
+    # Acquire the XLA device first.
+    device = torch_xla.device()
+
+    # Create the model on CPU first
+    model_cpu = get_model(args.model_name, torch_dtype)
+    config = model_cpu.config
+    model_cpu.eval()  # Set to evaluation mode
+
+    # Move model to the XLA device.
+    model_tpu = model_cpu.to(device)
+
+    # Create dummy input_ids and move to the XLA device.
+    input_ids = torch.randint(
+        0, config.vocab_size, (args.batch_size, args.seq_len), dtype=torch.long
+    )
+    # Move inputs to the XLA device as well.
+    input_ids = input_ids.to(device)
+
+    # Preheat: one forward pass to initialize the runtime and warm the compilation cache.
+    print("Preheating...")
+    preheat_start_time = time.perf_counter()
+    with torch.no_grad():
+        # Assign to a variable to prevent garbage collection before sync.
+        logits = model_tpu(input_ids).logits
+    torch_xla.sync()
+    preheat_end_time = time.perf_counter()
+    preheat_time = preheat_end_time - preheat_start_time
+    print(f"PREHEAT WALL TIME: {preheat_time * 1000:.4f} ms")
+
+    # Initial run (warm-up)
+    print("Warming up...")
+    warmup_start_time = time.perf_counter()
+    with torch.no_grad():
+        logits = model_tpu(input_ids).logits
+    # Block until the operation is complete.
+    xm.wait_device_ops()
+    warmup_end_time = time.perf_counter()
+    warmup_time = warmup_end_time - warmup_start_time
+
+    # Subsequent runs for measurement
+    print(f"Starting benchmark for {args.num_runs} runs...")
+    times = []
+    for i in range(args.num_runs):
+        start_time = time.perf_counter()
+        with torch.no_grad():
+            # Assign to a variable to prevent garbage collection before sync.
+            logits = model_tpu(input_ids).logits
+        xm.wait_device_ops()
+
+        end_time = time.perf_counter()
+        times.append(end_time - start_time)
+        print(f"Run {i+1}/{args.num_runs}: {(end_time - start_time) * 1000:.2f} ms")
+
+    # Print final performance results
+    print("\n--- Benchmark Results (Eager Mode) ---")
+    print(f"Model: {args.model_name}, DType: {args.dtype}")
+    print(f"Batch Size: {args.batch_size}, Sequence Length: {args.seq_len}")
+    print(f"Preheat time: {preheat_time * 1000:.2f} ms")
+    print(f"Warm-up time: {warmup_time * 1000:.2f} ms")
+    print(f"Number of runs: {len(times)}")
+    print(f"Average latency: {np.mean(times) * 1000:.2f} ms")
+    print(f"Median latency: {np.median(times) * 1000:.2f} ms")
+    print(f"P90 latency: {np.percentile(times, 90) * 1000:.2f} ms")
+    print(f"Min latency: {np.min(times) * 1000:.2f} ms")
+    print(f"Max latency: {np.max(times) * 1000:.2f} ms")
+
+    print("Script finished and exited cleanly.")
+    os._exit(0)  # Exit immediately without interpreter teardown (os._exit rather than sys.exit).
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Benchmark HF models on XLA (Eager Mode).")
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="llama3.2-1B",
+        choices=["llama3.2-1B", "qwen3-1.7B"],
+        help="Model to benchmark (must match a name supported by hf_model.get_model).",
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="bfloat16",
+        choices=["bfloat16", "float32"],
+        help="Data type for the model.",
+    )
+    parser.add_argument("--batch_size", type=int, default=1, help="Batch size.")
+    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length.")
+    parser.add_argument("--num_runs", type=int, default=10, help="Number of benchmark runs.")
+    main(parser.parse_args())
\ No newline at end of file
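
As a worked example of interpreting the reported latencies (not part of this change), a single forward pass processes batch_size * seq_len tokens, so prefill throughput follows directly from the measured latency:

def tokens_per_second(batch_size: int, seq_len: int, latency_s: float) -> float:
    """Tokens processed per forward pass divided by its latency (hypothetical helper)."""
    return batch_size * seq_len / latency_s


# Example: the default batch_size=1, seq_len=128 at a 20 ms median latency
# corresponds to 1 * 128 / 0.020 = 6400 tokens/s.
print(tokens_per_second(1, 128, 0.020))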