288 changes: 280 additions & 8 deletions docs/projection.md

Large diffs are not rendered by default.

117 changes: 115 additions & 2 deletions primus/cli/subcommands/projection.py
@@ -14,9 +14,24 @@ def run(args, overrides):

launch_projection_from_cli(args, overrides)
elif args.suite == "performance":
# Normalise mode: "inference" is an alias for "prefill"
mode = getattr(args, "mode", "training")
if mode == "inference":
args.mode = "prefill"

profiling_mode = getattr(args, "profiling_mode", "benchmark")

# Decode + simulate is fully analytical — no backend needed.
# Decode + benchmark/both needs the Megatron backend (runs real layers
# with seq_len=1 to measure decode-step GEMMs on the GPU).
needs_backend = profiling_mode != "simulate"
if mode == "decode" and profiling_mode == "simulate":
needs_backend = False

if needs_backend:
from primus.pretrain import setup_backend_path

setup_backend_path(framework="megatron", verbose=True)

from primus.core.projection.performance_projection import (
launch_projection_from_cli,
@@ -92,6 +107,104 @@ def register_subcommand(subparsers):
"If not provided, uses default cluster parameters.\n\n"
),
)
performance.add_argument(
"--profiling-mode",
type=str,
required=False,
default="benchmark",
choices=["benchmark", "simulate", "both"],
help=(
"Profiling mode for layer timing:\n"
" benchmark - Run actual GPU benchmarks (default, requires GPU)\n"
" simulate - Use simulation backends (origami for GEMM,\n"
" analytical model for SDPA). No GPU required.\n"
" both - Run both benchmark and simulation, report side-by-side\n"
),
)
performance.add_argument(
"--gemm-backend",
type=str,
required=False,
default=None,
choices=["origami"],
help=(
"GEMM simulation backend (only used when --profiling-mode is 'simulate' or 'both').\n"
" origami - Open-source GEMM performance model (default)\n"
),
)
performance.add_argument(
"--gpu-arch",
type=str,
required=False,
default=None,
help=(
"Target GPU architecture for simulation (e.g. 'mi300x', 'gfx942', 'mi355x', 'gfx950').\n"
"If not specified, auto-detected or uses PRIMUS_GPU_ARCH env var.\n"
),
)
performance.add_argument(
"--gpu-clock-mhz",
type=int,
required=False,
default=None,
help=(
"Override the GPU compute clock frequency in MHz for simulation.\n"
"If not specified, uses the default from the hardware profile for the\n"
"given --gpu-arch (e.g. 2100 MHz for MI300X/MI325X).\n"
"Can also be set via the PRIMUS_GPU_CLOCK_MHZ env var.\n"
"Example: --gpu-clock-mhz 1500\n"
),
)
performance.add_argument(
"--mode",
type=str,
required=False,
default="training",
choices=["training", "inference", "prefill", "decode"],
help=(
"Projection mode:\n"
" training - Project training iteration time (forward + backward +\n"
" optimizer step + gradient AllReduce). Default.\n"
" inference - Alias for 'prefill'.\n"
" prefill - Project inference prefill latency (forward-only, no\n"
" backward pass, optimizer, or gradient communication).\n"
" decode - Project autoregressive decode latency per token.\n"
" With --profiling-mode simulate: fully analytical (no GPU).\n"
" With --profiling-mode benchmark: benchmarks GEMMs with\n"
" seq_len=1 on GPU, overlays analytical KV cache model.\n"
),
)
performance.add_argument(
"--decode-batch-size",
type=int,
required=False,
default=None,
help=(
"Number of sequences being decoded concurrently (decode mode only).\n"
"Defaults to micro_batch_size from the config.\n"
),
)
performance.add_argument(
"--decode-context-length",
type=int,
required=False,
default=None,
help=(
"Current context length during decode, i.e. number of previous tokens\n"
"in the KV cache (decode mode only). Affects KV cache read time.\n"
"Defaults to sequence_length from the config.\n"
),
)
performance.add_argument(
"--num-generated-tokens",
type=int,
required=False,
default=None,
help=(
"Number of tokens to generate (decode mode only). Used to estimate\n"
"total generation time. Defaults to 128.\n"
),
)

parser.set_defaults(func=run)

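For reference, a minimal sketch of the backend-setup decision that run() implements above, written as a standalone helper; the function name is the editor's and not part of the PR.

def needs_megatron_backend(mode: str, profiling_mode: str) -> bool:
    """True when the performance suite must import the real Megatron backend."""
    if mode == "inference":  # alias normalised to "prefill", as in run()
        mode = "prefill"
    # "simulate" is fully analytical (including decode); "benchmark" and "both"
    # run real layers on the GPU and therefore need the backend on the path.
    return profiling_mode != "simulate"

assert needs_megatron_backend("decode", "simulate") is False
assert needs_megatron_backend("decode", "benchmark") is True
assert needs_megatron_backend("training", "both") is True
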
216 changes: 204 additions & 12 deletions primus/core/projection/module_profilers/attention.py
@@ -19,6 +19,8 @@ def __init__(self, config, sub_profilers=None):
self.module = None # Will be set during benchmarking
self._cached_results = None # Cache for (forward_time, backward_time, activation_memory)
self._cache_key = None # Cache key (batch_size, seq_len)
self._gemm_backend = None # Optional: GEMM simulation backend
self._sdpa_backend = None # Optional: SDPA simulation backend

def set_module(self, module):
"""Set the actual attention module for benchmarking."""
@@ -27,6 +29,18 @@ def set_module(self, module):
self._cached_results = None
self._cache_key = None

def set_gemm_backend(self, backend):
"""Set a GEMM simulation backend for attention linear projections."""
self._gemm_backend = backend
self._cached_results = None
self._cache_key = None

def set_sdpa_backend(self, backend):
"""Set an SDPA simulation backend for attention computation."""
self._sdpa_backend = backend
self._cached_results = None
self._cache_key = None

def estimated_num_params(self, rank: Optional[int] = None) -> int:
args = self.config.model_config
# Group-query & multi-latent attention support.
@@ -131,23 +145,201 @@ def _num_query_groups() -> int:

return tokens_per_rank * (activation_width + ln_width) * bytes_per_value

def _simulate_mla_gemms(self, batch_tokens: int, dtype: str) -> tuple[float, float]:
"""Simulate MLA (Multi-Latent Attention) projection GEMMs.

MLA uses LoRA-factored Q and compressed KV projections instead of
standard Q/K/V projections:
Forward (6 GEMMs): Q_down, Q_up, KV_down, KV_up, RoPE_proj, O_proj
Backward (12 GEMMs): dgrad + wgrad for each of the 6 projections
"""
args = self.config.model_config
backend = self._gemm_backend

hidden = args.hidden_size
heads = args.num_attention_heads
q_lora_rank = args.q_lora_rank
kv_lora_rank = args.kv_lora_rank
qk_head_dim = args.qk_head_dim
qk_pos_emb_head_dim = args.qk_pos_emb_head_dim
v_head_dim = args.v_head_dim

fwd_time = 0.0
bwd_time = 0.0
T = batch_tokens

# ---------- Forward ----------
if q_lora_rank is not None:
# Q down-proj: [T, hidden] × [hidden, q_lora_rank]
q_down_out = q_lora_rank
r = backend.simulate_gemm(T, q_down_out, hidden, dtype)
fwd_time += r.forward_time_ms
# Q up-proj: [T, q_lora_rank] × [q_lora_rank, heads*(qk_hd+qk_pe_hd)]
q_up_out = heads * (qk_head_dim + qk_pos_emb_head_dim)
r = backend.simulate_gemm(T, q_up_out, q_lora_rank, dtype)
fwd_time += r.forward_time_ms
else:
# Direct Q projection (no LoRA): [T, hidden] × [hidden, heads*(qk_hd+qk_pe_hd)]
q_up_out = heads * (qk_head_dim + qk_pos_emb_head_dim)
r = backend.simulate_gemm(T, q_up_out, hidden, dtype)
fwd_time += r.forward_time_ms

# KV down-proj: [T, hidden] × [hidden, kv_lora_rank]
kv_down_out = kv_lora_rank
r = backend.simulate_gemm(T, kv_down_out, hidden, dtype)
fwd_time += r.forward_time_ms
# KV up-proj: [T, kv_lora_rank] × [kv_lora_rank, heads*(qk_hd+v_hd)]
kv_up_out = heads * (qk_head_dim + v_head_dim)
r = backend.simulate_gemm(T, kv_up_out, kv_lora_rank, dtype)
fwd_time += r.forward_time_ms

# RoPE positional embedding projection: [T, hidden] × [hidden, qk_pos_emb_head_dim]
r = backend.simulate_gemm(T, qk_pos_emb_head_dim, hidden, dtype)
fwd_time += r.forward_time_ms

# Output projection: [T, heads*v_hd] × [heads*v_hd, hidden]
o_in = heads * v_head_dim
r = backend.simulate_gemm(T, hidden, o_in, dtype)
fwd_time += r.forward_time_ms

# ---------- Backward (dgrad + wgrad for each projection) ----------
if q_lora_rank is not None:
# Q down-proj dgrad: [T, q_down_out] × [q_down_out, hidden] → [T, hidden]
r = backend.simulate_gemm(T, hidden, q_down_out, dtype)
bwd_time += r.forward_time_ms
# Q down-proj wgrad: [hidden, T] × [T, q_down_out] → [hidden, q_down_out]
r = backend.simulate_gemm(hidden, q_down_out, T, dtype)
bwd_time += r.forward_time_ms
# Q up-proj dgrad: [T, q_up_out] × [q_up_out, q_lora_rank] → [T, q_lora_rank]
r = backend.simulate_gemm(T, q_lora_rank, q_up_out, dtype)
bwd_time += r.forward_time_ms
# Q up-proj wgrad: [q_lora_rank, T] × [T, q_up_out] → [q_lora_rank, q_up_out]
r = backend.simulate_gemm(q_lora_rank, q_up_out, T, dtype)
bwd_time += r.forward_time_ms
else:
# Direct Q dgrad + wgrad
r = backend.simulate_gemm(T, hidden, q_up_out, dtype)
bwd_time += r.forward_time_ms
r = backend.simulate_gemm(hidden, q_up_out, T, dtype)
bwd_time += r.forward_time_ms

# KV down-proj dgrad + wgrad
r = backend.simulate_gemm(T, hidden, kv_down_out, dtype)
bwd_time += r.forward_time_ms
r = backend.simulate_gemm(hidden, kv_down_out, T, dtype)
bwd_time += r.forward_time_ms
# KV up-proj dgrad + wgrad
r = backend.simulate_gemm(T, kv_lora_rank, kv_up_out, dtype)
bwd_time += r.forward_time_ms
r = backend.simulate_gemm(kv_lora_rank, kv_up_out, T, dtype)
bwd_time += r.forward_time_ms

# RoPE proj dgrad + wgrad
r = backend.simulate_gemm(T, hidden, qk_pos_emb_head_dim, dtype)
bwd_time += r.forward_time_ms
r = backend.simulate_gemm(hidden, qk_pos_emb_head_dim, T, dtype)
bwd_time += r.forward_time_ms

# O proj dgrad + wgrad
r = backend.simulate_gemm(T, o_in, hidden, dtype)
bwd_time += r.forward_time_ms
r = backend.simulate_gemm(o_in, hidden, T, dtype)
bwd_time += r.forward_time_ms

return fwd_time, bwd_time

def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]:
"""Get simulated results from GEMM + SDPA simulation backends."""
args = self.config.model_config
mp = self.config.model_parallel_config
tp_size = max(1, mp.tensor_model_parallel_size)
cp_size = max(1, mp.context_model_parallel_size)

batch_tokens = batch_size * seq_len // tp_size // cp_size
slen_per_cp = seq_len // cp_size

fwd_time = 0.0
bwd_time = 0.0

# 1. Simulate linear projection GEMMs using GEMM backend
if self._gemm_backend is not None:
gemm_dtype = "fp8" if getattr(args, "fp8", None) else "bf16"

if getattr(args, "multi_latent_attention", False):
# MLA: LoRA-factored Q and compressed KV projections
# 6 forward GEMMs + 12 backward GEMMs
mla_fwd, mla_bwd = self._simulate_mla_gemms(batch_tokens, gemm_dtype)
fwd_time += mla_fwd
bwd_time += mla_bwd
else:
# Standard attention: Q, K, V, O projections
# 4 forward GEMMs + 8 backward GEMMs
num_query_groups = (
args.num_query_groups
if args.group_query_attention and args.num_query_groups
else args.num_attention_heads
)
gemm_result = self._gemm_backend.simulate_attention_gemms(
batch_tokens=batch_tokens,
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
kv_channels=args.kv_channels,
num_query_groups=num_query_groups,
dtype=gemm_dtype,
)
fwd_time += gemm_result.forward_time_ms
bwd_time += gemm_result.backward_time_ms

# 2. Simulate SDPA core computation using SDPA backend
if self._sdpa_backend is not None:
heads_per_rank = max(1, args.num_attention_heads // tp_size)

if getattr(args, "multi_latent_attention", False):
# MLA: Q·Kᵀ uses qk_head_dim + qk_pos_emb_head_dim (e.g. 192),
# P·V uses v_head_dim (e.g. 128).
sdpa_head_dim = args.qk_head_dim + args.qk_pos_emb_head_dim
sdpa_head_dim_v = args.v_head_dim
else:
sdpa_head_dim = args.kv_channels
sdpa_head_dim_v = None # same as head_dim

sdpa_result = self._sdpa_backend.simulate_sdpa(
batch_size=batch_size,
num_heads=heads_per_rank,
seq_len=slen_per_cp,
head_dim=sdpa_head_dim,
causal=True,
dtype="bf16",
head_dim_v=sdpa_head_dim_v,
)
fwd_time += sdpa_result.forward_time_ms
bwd_time += sdpa_result.backward_time_ms

activation_memory = self.estimated_activation_memory(batch_size, seq_len)
return (fwd_time, bwd_time, activation_memory)

def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]:
"""Get or compute benchmark results (cached)."""
cache_key = (batch_size, seq_len)

if self._cached_results is None or self._cache_key != cache_key:
if self._gemm_backend is not None or self._sdpa_backend is not None:
# Use simulation mode
self._cached_results = self._get_simulated_results(batch_size, seq_len)
else:
# Use actual GPU benchmarking
# Context parallel / Sequence parallel adjustment
cp_size = self.config.model_parallel_config.context_model_parallel_size
# Effective sequence length per rank if CP is used
slen_per_cp = seq_len // cp_size

self._cached_results = benchmark_layer(
self.module,
[
(seq_len, batch_size, self.config.model_config.hidden_size),
((1, 1, slen_per_cp, seq_len), torch.bool),
],
)
self._cache_key = cache_key
return self._cached_results

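For readers of the attention-profiler changes above, a hedged sketch of the duck-typed interface that _simulate_mla_gemms() and _get_simulated_results() consume. The class names below are illustrative stand-ins, not the actual origami or analytical-SDPA APIs; only the method names, argument names, and result fields mirror the calls in the diff.

from dataclasses import dataclass


@dataclass
class SimResult:
    forward_time_ms: float
    backward_time_ms: float = 0.0


class DummyGemmBackend:
    """Toy GEMM backend returning fixed times; shows the expected interface only."""

    def simulate_gemm(self, m, n, k, dtype):
        # Per-GEMM estimate consumed by the MLA path (_simulate_mla_gemms).
        return SimResult(forward_time_ms=0.01)

    def simulate_attention_gemms(
        self, *, batch_tokens, hidden_size, num_attention_heads,
        kv_channels, num_query_groups, dtype
    ):
        # Combined Q/K/V/O estimate consumed by the standard-attention path.
        return SimResult(forward_time_ms=0.05, backward_time_ms=0.10)


class DummySdpaBackend:
    """Toy SDPA backend; mirrors the simulate_sdpa() call in _get_simulated_results."""

    def simulate_sdpa(
        self, *, batch_size, num_heads, seq_len, head_dim, causal, dtype, head_dim_v=None
    ):
        return SimResult(forward_time_ms=0.20, backward_time_ms=0.45)


def wire_simulation_backends(profiler):
    """Setting either backend flips _get_benchmark_results() to the simulation path."""
    profiler.set_gemm_backend(DummyGemmBackend())
    profiler.set_sdpa_backend(DummySdpaBackend())
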
13 changes: 12 additions & 1 deletion primus/core/projection/module_profilers/collective_args.py
@@ -51,7 +51,18 @@ class CollectiveArgs:
nics_per_node: Optional[int] = 8 # NICs per node (None = gpus_per_node)

# All-to-all specific
a2a_peer_lat: float = 0.45 # Per-peer latency overhead for inter-node a2a
a2a_intra_node_peer_lat: float = 28.0 # Per-peer latency overhead for intra-node a2a
# Intra-node overhead is higher (~19-28 us) due to:
# - P2P scatter/gather scheduling overhead
# - RCCL internal synchronization barriers
# - Memory copy and buffer management
# Note: Preflight measurements for EP=8 intra-node A2A show:
# - Linear extrapolation: ~27.4 us per peer
# - 2MB measurement: ~28.1 us per peer
# - After subtracting bandwidth component: ~19.4 us per peer
# Default 28 us matches preflight measurements (middle of range)
# Can be overridden via hardware_config for GPU-specific calibration


def get_default_args(
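A rough back-of-envelope illustration of what the new intra-node default implies, assuming the per-peer latency term scales linearly with the number of remote peers; the EP size and the scaling assumption are the editor's, not measurements from the PR.

A2A_INTRA_NODE_PEER_LAT_US = 28.0  # default added above
ep_size = 8
remote_peers = ep_size - 1
latency_floor_us = remote_peers * A2A_INTRA_NODE_PEER_LAT_US
# ~196 us of per-peer overhead for one EP=8 intra-node all-to-all,
# before any bandwidth (message-size) component is added.
print(f"latency floor ~ {latency_floor_us:.0f} us")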