81 changes: 81 additions & 0 deletions torchprime/experimental/benchmark/hf_model.py
@@ -0,0 +1,81 @@
import torch
from transformers.models.llama import modeling_llama
from transformers.models.qwen3 import modeling_qwen3


def get_llama3_model(torch_dtype: torch.dtype):
"""Returns the Llama3.2 1B model."""
config = modeling_llama.LlamaConfig(
attention_bias=False,
attention_dropout=0.0,
bos_token_id=128000,
eos_token_id=128001,
head_dim=64,
hidden_act="silu",
hidden_size=2048,
initializer_range=0.02,
intermediate_size=8192,
max_position_embeddings=131072,
mlp_bias=False,
num_attention_heads=32,
num_hidden_layers=16,
num_key_value_heads=8,
rms_norm_eps=1e-05,
rope_scaling={
"factor": 32.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
},
rope_theta=500000.0,
tie_word_embeddings=True,
use_cache=True,
vocab_size=128256,
_attn_implementation="eager",
)
model = modeling_llama.LlamaForCausalLM(config).to(torch_dtype)
return model


def get_qwen3_model(torch_dtype: torch.dtype):
"""Returns the Qwen3 1.7B model."""
config = modeling_qwen3.Qwen3Config(
attention_bias=False,
attention_dropout=0.0,
bos_token_id=151643,
eos_token_id=151645,
head_dim=128,
hidden_act="silu",
hidden_size=2048,
initializer_range=0.02,
intermediate_size=6144,
max_position_embeddings=40960,
max_window_layers=28,
num_attention_heads=16,
num_hidden_layers=28,
num_key_value_heads=8,
rms_norm_eps=1e-06,
rope_scaling=None,
rope_theta=1000000,
sliding_window=None,
tie_word_embeddings=True,
use_cache=True,
use_sliding_window=False,
vocab_size=151936,
_attn_implementation="eager",
)
model = modeling_qwen3.Qwen3ForCausalLM(config).to(torch_dtype)
return model


def get_model(model_name: str, dtype: torch.dtype) -> torch.nn.Module:
match model_name:
case "llama3.2-1B":
return get_llama3_model(dtype)
case "qwen3-1.7B":
return get_qwen3_model(dtype)
case _:
raise ValueError(f"Unsupported model: {model_name}")
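

# Illustrative sketch (not used by the benchmark scripts): build one of the models
# above and count its parameters; the "llama3.2-1B" config corresponds to roughly
# 1.2B parameters with tied embeddings.
def _example_param_count(model_name: str = "llama3.2-1B") -> int:
  model = get_model(model_name, torch.bfloat16)
  return sum(p.numel() for p in model.parameters())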
110 changes: 110 additions & 0 deletions torchprime/experimental/benchmark/hf_models_forward.py
@@ -0,0 +1,110 @@
import argparse
import os
import time

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm

from torchprime.experimental.benchmark.hf_model import get_model


def main(args):
# --- Configuration ---
dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32}
torch_dtype = dtype_map[args.dtype]

# Acquire the XLA device.
device = torch_xla.device()

# Create the model on CPU first
model_cpu = get_model(args.model_name, torch_dtype)
config = model_cpu.config
model_cpu.eval() # Set to evaluation mode

# Move model to the XLA device.
model_tpu = model_cpu.to(device)

# Create dummy input_ids.
input_ids = torch.randint(
0, config.vocab_size, (args.batch_size, args.seq_len), dtype=torch.long
)
# Move the inputs to the XLA device.
input_ids = input_ids.to(device)

# Preheat: the first forward pass traces the graph and triggers XLA compilation.
print("Preheating...")
preheat_start_time = time.perf_counter()
with torch.no_grad():
# Keep a reference to the output so it is not discarded before the sync.
logits = model_tpu(input_ids).logits
# torch_xla.sync() cuts the lazy trace and dispatches the graph; it does not
# block on device execution (xm.wait_device_ops() below does).
torch_xla.sync()
preheat_end_time = time.perf_counter()
preheat_time = preheat_end_time - preheat_start_time
print(f"PREHEAT WALL TIME: {preheat_time*1000:.4f} ms")

# Initial run (warm-up) to trigger XLA compilation
print("Warming up (includes XLA graph compilation)...")
warmup_start_time = time.perf_counter()
with torch.no_grad():
logits = model_tpu(input_ids).logits
torch_xla.sync() # Dispatch the traced graph; in lazy mode nothing executes until this point.
xm.wait_device_ops() # Block until graph compilation and execution are complete.
warmup_end_time = time.perf_counter()
warmup_time = warmup_end_time - warmup_start_time

# Subsequent runs for measurement
print(f"Starting benchmark for {args.num_runs} runs...")
times = []
for i in range(args.num_runs):
start_time = time.perf_counter()
with torch.no_grad():
# Keep a reference to the output so it is not discarded before the sync.
logits = model_tpu(input_ids).logits

torch_xla.sync() # Dispatch the traced graph so the step actually executes.
xm.wait_device_ops() # Block until the step's computation is complete for accurate timing.
end_time = time.perf_counter()
times.append(end_time - start_time)
print(f"Run {i+1}/{args.num_runs}: {(end_time - start_time) * 1000:.2f} ms")

# Print final performance results
print("\n--- Benchmark Results (Lazy Mode) ---")
print(f"Model: {args.model_name}, DType: {args.dtype}")
print(f"Batch Size: {args.batch_size}, Sequence Length: {args.seq_len}")
print(f"Preheat time: {preheat_time * 1000:.2f} ms")
print(f"Warm-up time: {warmup_time * 1000:.2f} ms")
print(f"Number of runs: {len(times)}")
print(f"Average latency: {np.mean(times) * 1000:.2f} ms")
print(f"Median latency: {np.median(times) * 1000:.2f} ms")
print(f"P90 latency: {np.percentile(times, 90) * 1000:.2f} ms")
print(f"Min latency: {np.min(times) * 1000:.2f} ms")
print(f"Max latency: {np.max(times) * 1000:.2f} ms")

# Final sync to make sure all pending device operations are done before exit.
xm.wait_device_ops()
print("Script finished and exited cleanly.")
# Use os._exit() to skip interpreter teardown and avoid hanging on exit.
os._exit(0)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark HF models on XLA (Lazy Mode).")
parser.add_argument(
"--model_name",
type=str,
default="llama3.2-1B",
choices=["llama3.2-1B", "qwen3-1.7B"],
help="Model to benchmark (must match a config file name).",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["bfloat16", "float32"],
help="Data type for the model.",
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size.")
parser.add_argument("--seq_len", type=int, default=128, help="Sequence length.")
parser.add_argument("--num_runs", type=int, default=10, help="Number of benchmark runs.")
main(parser.parse_args())
111 changes: 111 additions & 0 deletions torchprime/experimental/benchmark/hf_models_forward_eager.py
@@ -0,0 +1,111 @@
import argparse
import os
import time

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm

from torchprime.experimental.benchmark.hf_model import get_model


def main(args):
# --- Configuration ---
print("Running in PyTorch/XLA experimental eager mode.")
torch_xla.experimental.eager_mode(True)
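# In eager mode, ops are compiled and dispatched as they are issued rather than
# traced into a lazy graph; xm.wait_device_ops() below still blocks until device
# execution finishes.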

dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32}
torch_dtype = dtype_map[args.dtype]

# Acquire the XLA device.
device = torch_xla.device()

# Create the model on CPU first
model_cpu = get_model(args.model_name, torch_dtype)
config = model_cpu.config
model_cpu.eval() # Set to evaluation mode

# Move model to the XLA device.
model_tpu = model_cpu.to(device)

# Create dummy input_ids.
input_ids = torch.randint(
0, config.vocab_size, (args.batch_size, args.seq_len), dtype=torch.long
)
# Move the inputs to the XLA device.
input_ids = input_ids.to(device)

# Preheat: the first forward pass populates the per-op compilation caches.
print("Preheating...")
preheat_start_time = time.perf_counter()
with torch.no_grad():
# Keep a reference to the output so it is not discarded before the sync.
logits = model_tpu(input_ids).logits
torch_xla.sync()
preheat_end_time = time.perf_counter()
preheat_time = preheat_end_time - preheat_start_time
print(f"PREHEAT WALL TIME: {preheat_time*1000:.4f} ms")

# Initial run (warm-up)
print("Warming up...")
warmup_start_time = time.perf_counter()
with torch.no_grad():
logits = model_tpu(input_ids).logits
# Block until the operation is complete.
xm.wait_device_ops()
warmup_end_time = time.perf_counter()
warmup_time = warmup_end_time - warmup_start_time

# Subsequent runs for measurement
print(f"Starting benchmark for {args.num_runs} runs...")
times = []
for i in range(args.num_runs):
start_time = time.perf_counter()
with torch.no_grad():
# Assign to a variable to prevent garbage collection before sync.
logits = model_tpu(input_ids).logits
xm.wait_device_ops()

end_time = time.perf_counter()
times.append(end_time - start_time)
print(f"Run {i+1}/{args.num_runs}: {(end_time - start_time) * 1000:.2f} ms")

# Print final performance results
print("\n--- Benchmark Results (Eager Mode) ---")
print(f"Model: {args.model_name}, DType: {args.dtype}")
print(f"Batch Size: {args.batch_size}, Sequence Length: {args.seq_len}")
print(f"Preheat time: {preheat_time * 1000:.2f} ms")
print(f"Warm-up time: {warmup_time * 1000:.2f} ms")
print(f"Number of runs: {len(times)}")
print(f"Average latency: {np.mean(times) * 1000:.2f} ms")
print(f"Median latency: {np.median(times) * 1000:.2f} ms")
print(f"P90 latency: {np.percentile(times, 90) * 1000:.2f} ms")
print(f"Min latency: {np.min(times) * 1000:.2f} ms")
print(f"Max latency: {np.max(times) * 1000:.2f} ms")

print("Script finished and exited cleanly.")
os._exit(0) # <-- Use os._exit() instead of sys.exit()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark HF models on XLA (Eager Mode).")
parser.add_argument(
"--model_name",
type=str,
default="llama3.2-1B",
choices=["llama3.2-1B", "qwen3-1.7B"],
help="Model to benchmark (must match a config file name).",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["bfloat16", "float32"],
help="Data type for the model.",
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size.")
parser.add_argument("--seq_len", type=int, default=128, help="Sequence length.")
parser.add_argument("--num_runs", type=int, default=10, help="Number of benchmark runs.")
main(parser.parse_args())