diff --git a/docs/projection.md b/docs/projection.md index 5c5212026..5a2d93f58 100644 --- a/docs/projection.md +++ b/docs/projection.md @@ -1,6 +1,6 @@ # Performance Projection -Primus includes a performance projection tool that benchmarks transformer layers on a single node and projects training iteration times to multi-node configurations. +Primus includes a performance projection tool that benchmarks transformer layers on a single node and projects performance to multi-node configurations for **training**, **prefill**, and **decode** workloads. - **User-facing entry**: `primus-cli … -- projection performance [options]` - **Implementation entrypoint**: `primus/cli/subcommands/projection.py` @@ -14,15 +14,32 @@ The performance projection tool: 2. **Simulates** pipeline parallelism scheduling (including zero-bubble optimization) 3. **Projects** performance to multi-node configurations by modeling: - Data Parallelism (DP) scaling - - Gradient AllReduce communication overhead + - Gradient AllReduce communication overhead (training only) - Expert Parallelism (EP) All-to-All communication overhead - Inter-node vs intra-node communication differences -This allows you to estimate training performance on larger clusters without actually running on them. +This allows you to estimate training or inference performance on larger clusters without actually running on them. + +### Training vs Prefill vs Decode + +| Aspect | Training (`--mode training`) | Prefill (`--mode prefill`) | Decode (`--mode decode`) | +|--------|------------------------------|----------------------------|--------------------------| +| Compute | Forward + Backward + Wgrad | Forward only | Analytical (tiny GEMMs) | +| Bottleneck | Compute-bound | Compute-bound | **Memory-bandwidth-bound** (typically) | +| Optimizer | Adam step (HBM-bound) | None | None | +| Gradient sync | AllReduce across DP ranks | None | None | +| MoE A2A | Forward + Backward | Forward only | Analytical latency estimate | +| Pipeline sim | 1F1B / Zero-bubble schedule | Skipped | Skipped | +| FSDP | AllGather + ReduceScatter | None | None | +| DP scaling | Reduces microbatches per GPU | Adds independent replicas | Adds independent replicas | +| GPU required | Yes (benchmark) | Yes (benchmark) | No (simulate) / Yes (benchmark) | +| Metrics | Iteration time, tokens/s/GPU | Prefill latency, tokens/s/replica | Time/token, tokens/s, generation time | ## Quick Start -Run a basic performance projection for the minimum required nodes: +### Training Projection (default) + +Run a basic training performance projection for the minimum required nodes: ```bash export NNODES=1 @@ -33,7 +50,7 @@ bash runner/primus-cli direct --script primus/cli/main.py -- \ --config examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml ``` -Project performance to a specific number of nodes: +Project training performance to a specific number of nodes: ```bash export NNODES=1 @@ -45,6 +62,74 @@ bash runner/primus-cli direct --script primus/cli/main.py -- \ --target-nodes 4 ``` +### Prefill Projection + +Project inference prefill performance using the same config (`--mode inference` is an alias for `--mode prefill`): + +```bash +export NNODES=1 +export HSA_NO_SCRATCH_RECLAIM=1 + +bash runner/primus-cli direct --script primus/cli/main.py -- \ + projection performance \ + --config examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml \ + --mode prefill +``` + +Project prefill performance to multiple nodes (scales DP replicas): + +```bash +bash runner/primus-cli direct 
--script primus/cli/main.py -- \ + projection performance \ + --config examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml \ + --mode prefill --target-nodes 4 +``` + +### Decode Projection + +Project autoregressive decode (token generation) latency using the **analytical model** (no GPU required): + +```bash +bash runner/primus-cli direct --script primus/cli/main.py -- \ + projection performance \ + --config examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml \ + --mode decode --profiling-mode simulate +``` + +For **benchmark-enhanced** decode (real GPU measurement + analytical KV cache overlay): + +```bash +export NNODES=1 +export HSA_NO_SCRATCH_RECLAIM=1 + +bash runner/primus-cli direct --script primus/cli/main.py -- \ + projection performance \ + --config examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml \ + --mode decode +``` + +Customize decode parameters: + +```bash +bash runner/primus-cli direct --script primus/cli/main.py -- \ + projection performance \ + --config examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml \ + --mode decode \ + --decode-batch-size 32 \ + --decode-context-length 4096 \ + --num-generated-tokens 256 \ + --target-nodes 4 +``` + +Compare analytical and benchmark-enhanced side by side: + +```bash +bash runner/primus-cli direct --script primus/cli/main.py -- \ + projection performance \ + --config examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml \ + --mode decode --profiling-mode both +``` + ## Command Syntax ```bash @@ -58,6 +143,10 @@ primus-cli [global-options] [mode-args] -- projection performance [option | `--config` | string | Path to the Primus YAML configuration file (required) | | `--target-nodes` | int | Target number of nodes for projection. Defaults to minimum required by parallelism config | | `--hardware-config` | string | Path to YAML file with custom hardware parameters for communication modeling | +| `--mode` | string | `training` (default), `prefill`, `decode`, or `inference` (alias for prefill). Training projects full iteration time; prefill projects forward-only latency; decode projects per-token generation latency (supports both analytical and benchmark-enhanced profiling) | +| `--decode-batch-size` | int | Number of concurrent sequences being decoded (decode mode only). Defaults to micro_batch_size | +| `--decode-context-length` | int | Number of tokens already in the KV cache (decode mode only). Defaults to sequence_length | +| `--num-generated-tokens` | int | Number of tokens to generate (decode mode only). Used for total generation time estimate. Default: 128 | ### Parallelism Overrides @@ -163,6 +252,22 @@ Gradient AllReduce Size = num_params × 4 (FP32 gradients) **What it does**: Splits sequence length across GPUs for long-context training. +**How it interacts with other dimensions**: CP behaves differently for dense and MoE models: + +- **Dense models**: CP is an independent parallelism axis. It directly multiplies the GPU count: + ``` + Total GPUs = TP × PP × CP × DP + ``` + Each CP rank holds a different chunk of the sequence and they communicate via AllGather/AllToAll during attention. + +- **MoE models (with Parallel Folding)**: CP is **folded into** EP — the CP ranks are a subset of the EP ranks. The constraints are `CP ≤ EP` and `EP % CP == 0`. CP does **not** add extra GPUs beyond EP: + ``` + Total GPUs = TP × PP × EP × DP + ``` + Within the EP group, CP determines how many of the EP ranks share context-parallel work on attention. 
The remaining `EP / CP` factor provides inner data-parallel streams for the attention layers. For the MoE FFN layers, all EP ranks participate in expert parallelism as usual. + + For example, with EP=8 and CP=4: the 8 EP ranks form 2 groups of 4 for context-parallel attention, giving 2 inner-DP streams. Traditional (unfolded) parallelism would require `EP × CP = 32` GPUs; with folding it requires only `EP = 8`. + **How it's modeled**: CP affects the GPU topology for communication routing. Currently included in minimum GPU requirements calculation. ## Communication Modeling @@ -183,16 +288,25 @@ Communication times differ significantly based on: #### Minimum Nodes Required -The minimum nodes required is determined by: +The minimum nodes required depends on the model type: + +**Dense models** (no expert parallelism): ``` -Min Nodes = ceil(TP × PP × EP × CP / GPUs_per_node) +Min GPUs = TP × PP × CP +Min Nodes = ceil(Min GPUs / GPUs_per_node) +``` + +**MoE models** (with MoE Parallel Folding, where CP is folded into EP): +``` +Min GPUs = TP × PP × EP (CP ≤ EP, folded in) +Min Nodes = ceil(Min GPUs / GPUs_per_node) ``` #### Scaling Behavior - **DP scaling**: Linear speedup. Doubling DP halves iteration time (minus communication overhead). - **PP scaling**: Happens in multiples of pipeline replicas. With PP=3, you need 3, 6, 9... nodes to increase scaling. -- **EP scaling**: Divides the experts on EP nodes. +- **EP scaling**: Divides the experts across EP ranks. For MoE models, EP also subsumes the CP dimension. ## Example Output @@ -226,12 +340,170 @@ Multinode Scaling Projection Results ==================================================================================================== ``` +## Prefill Mode Details + +When `--mode prefill` (or `--mode inference`) is specified, the projection tool: + +1. **Benchmarks layers the same way** as training (same `--profiling-mode` options work) +2. **Uses only forward pass times** — backward/wgrad times are discarded +3. **Skips pipeline simulation** — inference doesn't use 1F1B scheduling +4. **Skips optimizer and gradient communication** — no backward pass means no gradients +5. **Reports prefill latency** — time for one forward pass through the full model +6. 
**Scales DP as replicas** — more nodes = more independent serving replicas, each with the same latency + +### Prefill Output Example + +``` +==================================================================================================== +Inference Projection Results +==================================================================================================== +📊 Parallelism: TP=1, PP=1, EP=8, CP=1 + +🎯 Target Configuration (4 nodes): + Nodes: 4, GPUs: 32 + TP=1, PP=1, EP=8, CP=1, DP(replicas)=4 + Prefill Latency: 125.456 ms (seq_len=4096, micro_batch=1) + Tokens/s per replica: 32,642 + Tokens/s total (4 replicas): 130,568 + Tokens/s/GPU: 4,080 +==================================================================================================== +``` + +### Key Prefill Metrics + +| Metric | Description | +|--------|-------------| +| Prefill Latency | Time for one forward pass (one batch of sequences) | +| Tokens/s per replica | Throughput of a single model instance | +| Tokens/s total | Aggregate throughput across all DP replicas | +| Tokens/s/GPU | Per-GPU throughput efficiency | + +## Decode Mode Details + +Decode mode supports two profiling approaches: + +| Approach | `--profiling-mode` | GPU Required | What it measures | +|----------|-------------------|--------------|-----------------| +| **Analytical** | `simulate` | No | Roofline model: HBM BW × (weights + KV cache) | +| **Benchmark-enhanced** | `benchmark` (default) | Yes | Real GPU GEMM timing (seq_len=1) + analytical KV cache overlay | +| **Both** | `both` | Yes | Runs both side-by-side for comparison | + +### How Decode Differs + +Autoregressive decode generates tokens **one at a time**. Each decode step: +- Loads all model weights from HBM (weights dominate memory access) +- Reads the KV cache (grows with context length) +- Performs tiny GEMMs of shape `[batch, 1, …]` — far below the compute roofline + +This makes decode fundamentally **memory-bandwidth-bound** for typical batch sizes. + +### Analytical Model (`--profiling-mode simulate`) + +Pure roofline model — no GPU needed. The decode time per token is estimated as: + +``` +time_per_token ≈ max(memory_time, compute_time) + comm_overhead +``` + +Where: +- **memory_time** = (total_weight_bytes + kv_cache_bytes) / HBM_bandwidth +- **compute_time** = total_FLOPs / peak_TFLOPS (usually much smaller) +- **comm_overhead** = TP AllReduce latency + PP P2P latency + MoE A2A latency + +### Benchmark-Enhanced Model (`--profiling-mode benchmark`) + +When a GPU is available, the tool can benchmark the actual transformer layers with `seq_len=1` to capture real kernel timings: + +1. **Benchmarks** each layer with `seq_len=1` and `batch_size=decode_batch_size` +2. **Extracts forward-only time** — captures real GEMM costs, TP AllReduce, kernel launch overhead +3. **Overlays analytical KV cache model** — the benchmark's 1-token self-attention doesn't include KV cache reads, so we add them analytically + +``` +decode_time = benchmarked_layer_forward(seq_len=1) + analytical_kv_cache_read + comm_overhead +``` + +This gives the best of both worlds: real hardware measurements for GEMMs plus a physics-based model for KV cache. 
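+
+To make the roofline arithmetic concrete, the following is a minimal Python sketch of the analytical estimate described above. The function and constant names are illustrative only (they are not the projection tool's internal API), and the hardware numbers assume an MI300X-class GPU:
+
+```python
+# Hedged sketch of the analytical per-token decode estimate (illustration, not the tool's code).
+HBM_BW_BYTES_PER_S = 5300e9   # assumed MI300X HBM bandwidth (~5.3 TB/s)
+PEAK_BF16_FLOPS = 1307e12     # assumed MI300X peak BF16 throughput (~1307 TFLOPS)
+
+
+def decode_time_per_token_ms(weight_bytes, kv_cache_bytes, flops, comm_overhead_ms=0.0):
+    """Roofline estimate: per-token time is bounded by memory or compute, plus comm overhead."""
+    memory_time_ms = (weight_bytes + kv_cache_bytes) / HBM_BW_BYTES_PER_S * 1e3
+    compute_time_ms = flops / PEAK_BF16_FLOPS * 1e3
+    return max(memory_time_ms, compute_time_ms) + comm_overhead_ms
+
+
+# Inputs roughly matching the analytical example output below:
+# ~14450 MB of weights, ~1587.5 MB of KV cache, ~1.2 GFLOPs of compute per token.
+MB = 1024**2
+print(decode_time_per_token_ms(14450 * MB, 1587.5 * MB, 1.2e9))  # ≈ 3.17 ms, memory-bound
+```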
+ +### Decode Output Example (Analytical) + +``` +==================================================================================================== +Decode Projection Results (Analytical Model) +==================================================================================================== +📊 Parallelism: TP=1, PP=1, EP=8, CP=1 + Decode batch size: 32 + Context length (KV cache): 4096 + GPU arch: mi300x + HBM bandwidth: 5300 GB/s + Peak BF16 compute: 1307 TFLOPS + +⏱️ Per-Token Decode Time Breakdown: + Weight loading: 2.8453 ms (14450.0 MB) + KV cache read: 0.3124 ms (1587.5 MB) + Compute: 0.0012 ms (1.20 GFLOPS) + Bottleneck: MEMORY-BOUND (arith intensity=0.07 FLOPs/B, balance=247 FLOPs/B) + ───────────────────────────────────── + Total per token: 3.1577 ms + +🎯 Target Configuration (4 nodes): + ... +==================================================================================================== +``` + +### Decode Output Example (Benchmark-Enhanced) + +``` +==================================================================================================== +Decode Projection Results (Benchmark-Enhanced) +==================================================================================================== +📊 Parallelism: TP=1, PP=1, EP=8, CP=1 + Decode batch size: 32 + Context length (KV cache): 4096 + +⏱️ Per-Token Decode Time Breakdown: + Layer fwd (benchmarked, seq_len=1): 3.0215 ms (includes GEMMs + TP AllReduce) + KV cache read (analytical): 0.3124 ms (1587.5 MB) + ─── Analytical comparison ─── + Weight loading (analytical): 2.8453 ms (14450.0 MB) + Total (analytical): 3.1577 ms + ──────────────────────────── + ───────────────────────────────────── + Total per token: 3.3339 ms + +🎯 Target Configuration (4 nodes): + ... +==================================================================================================== +``` + +### Key Decode Metrics + +| Metric | Description | +|--------|-------------| +| Per-token latency | Time to generate one output token (all sequences in the batch) | +| Tokens/s per replica | Decode throughput of a single model instance | +| Tokens/s total | Aggregate decode throughput across all DP replicas | +| Tokens/s/GPU | Per-GPU decode throughput efficiency | +| Generation time | Estimated wall-clock time to generate N tokens | +| Bottleneck | Whether the decode step is memory-bound or compute-bound (analytical only) | +| KV cache memory | Memory consumed by the KV cache (grows with batch × context length) | + +### Decode-Specific CLI Options + +| Option | Default | Description | +|--------|---------|-------------| +| `--decode-batch-size` | micro_batch_size | Number of sequences decoded concurrently. Larger batches amortize weight loading but increase KV cache memory | +| `--decode-context-length` | sequence_length | Tokens in the KV cache. Longer contexts increase KV cache read time per step | +| `--num-generated-tokens` | 128 | How many tokens to generate. Used to estimate total generation wall-clock time | +| `--profiling-mode` | benchmark | `simulate` for pure analytical (no GPU), `benchmark` for real GPU measurement + KV analytical overlay, `both` for side-by-side comparison | + ## Tips - **Start with 1 node**: Always benchmark on 1 node first to establish baseline performance. - **Understand scaling limits**: DP scaling is limited by global_batch_size / micro_batch_size. If you run out of microbatches, adding more nodes won't help. 
- **Check minimum nodes**: If your config requires multiple nodes (e.g., PP=4 with 8 GPUs/node), projection will automatically reduce PP for benchmarking. - **Pipeline scaling**: With PP > 1, you can only scale in multiples of the pipeline replica size. +- **Prefill mode**: Use `--mode prefill` (or `--mode inference`) to project prefill latency. The same config file works for both training and prefill projection. +- **Decode mode**: Use `--mode decode` to estimate per-token generation latency. With `--profiling-mode simulate` it runs instantly without a GPU; with `--profiling-mode benchmark` (default) it runs real GEMMs with seq_len=1 and overlays analytical KV cache timing. Use `--profiling-mode both` to compare the two approaches. Tune `--decode-batch-size` and `--decode-context-length` to explore the latency/memory trade-off. ## Related Documentation diff --git a/primus/cli/subcommands/projection.py b/primus/cli/subcommands/projection.py index 53fd62e16..b2917a890 100644 --- a/primus/cli/subcommands/projection.py +++ b/primus/cli/subcommands/projection.py @@ -14,9 +14,24 @@ def run(args, overrides): launch_projection_from_cli(args, overrides) elif args.suite == "performance": - from primus.pretrain import setup_backend_path + # Normalise mode: "inference" is an alias for "prefill" + mode = getattr(args, "mode", "training") + if mode == "inference": + args.mode = "prefill" - setup_backend_path(framework="megatron", verbose=True) + profiling_mode = getattr(args, "profiling_mode", "benchmark") + + # Decode + simulate is fully analytical — no backend needed. + # Decode + benchmark/both needs the Megatron backend (runs real layers + # with seq_len=1 to measure decode-step GEMMs on the GPU). + needs_backend = profiling_mode != "simulate" + if mode == "decode" and profiling_mode == "simulate": + needs_backend = False + + if needs_backend: + from primus.pretrain import setup_backend_path + + setup_backend_path(framework="megatron", verbose=True) from primus.core.projection.performance_projection import ( launch_projection_from_cli, @@ -92,6 +107,104 @@ def register_subcommand(subparsers): "If not provided, uses default cluster parameters.\n\n" ), ) + performance.add_argument( + "--profiling-mode", + type=str, + required=False, + default="benchmark", + choices=["benchmark", "simulate", "both"], + help=( + "Profiling mode for layer timing:\n" + " benchmark - Run actual GPU benchmarks (default, requires GPU)\n" + " simulate - Use simulation backends (origami for GEMM,\n" + " analytical model for SDPA). No GPU required.\n" + " both - Run both benchmark and simulation, report side-by-side\n" + ), + ) + performance.add_argument( + "--gemm-backend", + type=str, + required=False, + default=None, + choices=["origami"], + help=( + "GEMM simulation backend (only used when --profiling-mode is 'simulate' or 'both').\n" + " origami - Open-source GEMM performance model (default)\n" + ), + ) + performance.add_argument( + "--gpu-arch", + type=str, + required=False, + default=None, + help=( + "Target GPU architecture for simulation (e.g. 'mi300x', 'gfx942', 'mi355x', 'gfx950').\n" + "If not specified, auto-detected or uses PRIMUS_GPU_ARCH env var.\n" + ), + ) + performance.add_argument( + "--gpu-clock-mhz", + type=int, + required=False, + default=None, + help=( + "Override the GPU compute clock frequency in MHz for simulation.\n" + "If not specified, uses the default from the hardware profile for the\n" + "given --gpu-arch (e.g. 
2100 MHz for MI300X/MI325X).\n" + "Can also be set via the PRIMUS_GPU_CLOCK_MHZ env var.\n" + "Example: --gpu-clock-mhz 1500\n" + ), + ) + performance.add_argument( + "--mode", + type=str, + required=False, + default="training", + choices=["training", "inference", "prefill", "decode"], + help=( + "Projection mode:\n" + " training - Project training iteration time (forward + backward +\n" + " optimizer step + gradient AllReduce). Default.\n" + " inference - Alias for 'prefill'.\n" + " prefill - Project inference prefill latency (forward-only, no\n" + " backward pass, optimizer, or gradient communication).\n" + " decode - Project autoregressive decode latency per token.\n" + " With --profiling-mode simulate: fully analytical (no GPU).\n" + " With --profiling-mode benchmark: benchmarks GEMMs with\n" + " seq_len=1 on GPU, overlays analytical KV cache model.\n" + ), + ) + performance.add_argument( + "--decode-batch-size", + type=int, + required=False, + default=None, + help=( + "Number of sequences being decoded concurrently (decode mode only).\n" + "Defaults to micro_batch_size from the config.\n" + ), + ) + performance.add_argument( + "--decode-context-length", + type=int, + required=False, + default=None, + help=( + "Current context length during decode, i.e. number of previous tokens\n" + "in the KV cache (decode mode only). Affects KV cache read time.\n" + "Defaults to sequence_length from the config.\n" + ), + ) + performance.add_argument( + "--num-generated-tokens", + type=int, + required=False, + default=None, + help=( + "Number of tokens to generate (decode mode only). Used to estimate\n" + "total generation time. Defaults to 128.\n" + ), + ) parser.set_defaults(func=run) diff --git a/primus/core/projection/module_profilers/attention.py b/primus/core/projection/module_profilers/attention.py index 63f8dd6dd..f227a00fa 100644 --- a/primus/core/projection/module_profilers/attention.py +++ b/primus/core/projection/module_profilers/attention.py @@ -19,6 +19,8 @@ def __init__(self, config, sub_profilers=None): self.module = None # Will be set during benchmarking self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) + self._gemm_backend = None # Optional: GEMM simulation backend + self._sdpa_backend = None # Optional: SDPA simulation backend def set_module(self, module): """Set the actual attention module for benchmarking.""" @@ -27,6 +29,18 @@ def set_module(self, module): self._cached_results = None self._cache_key = None + def set_gemm_backend(self, backend): + """Set a GEMM simulation backend for attention linear projections.""" + self._gemm_backend = backend + self._cached_results = None + self._cache_key = None + + def set_sdpa_backend(self, backend): + """Set an SDPA simulation backend for attention computation.""" + self._sdpa_backend = backend + self._cached_results = None + self._cache_key = None + def estimated_num_params(self, rank: Optional[int] = None) -> int: args = self.config.model_config # Group-query & multi-latent attention support. @@ -131,23 +145,201 @@ def _num_query_groups() -> int: return tokens_per_rank * (activation_width + ln_width) * bytes_per_value + def _simulate_mla_gemms(self, batch_tokens: int, dtype: str) -> tuple[float, float]: + """Simulate MLA (Multi-Latent Attention) projection GEMMs. 
+ + MLA uses LoRA-factored Q and compressed KV projections instead of + standard Q/K/V projections: + Forward (6 GEMMs): Q_down, Q_up, KV_down, KV_up, RoPE_proj, O_proj + Backward (12 GEMMs): dgrad + wgrad for each of the 6 projections + """ + args = self.config.model_config + backend = self._gemm_backend + + hidden = args.hidden_size + heads = args.num_attention_heads + q_lora_rank = args.q_lora_rank + kv_lora_rank = args.kv_lora_rank + qk_head_dim = args.qk_head_dim + qk_pos_emb_head_dim = args.qk_pos_emb_head_dim + v_head_dim = args.v_head_dim + + fwd_time = 0.0 + bwd_time = 0.0 + T = batch_tokens + + # ---------- Forward ---------- + if q_lora_rank is not None: + # Q down-proj: [T, hidden] × [hidden, q_lora_rank] + q_down_out = q_lora_rank + r = backend.simulate_gemm(T, q_down_out, hidden, dtype) + fwd_time += r.forward_time_ms + # Q up-proj: [T, q_lora_rank] × [q_lora_rank, heads*(qk_hd+qk_pe_hd)] + q_up_out = heads * (qk_head_dim + qk_pos_emb_head_dim) + r = backend.simulate_gemm(T, q_up_out, q_lora_rank, dtype) + fwd_time += r.forward_time_ms + else: + # Direct Q projection (no LoRA): [T, hidden] × [hidden, heads*(qk_hd+qk_pe_hd)] + q_up_out = heads * (qk_head_dim + qk_pos_emb_head_dim) + r = backend.simulate_gemm(T, q_up_out, hidden, dtype) + fwd_time += r.forward_time_ms + + # KV down-proj: [T, hidden] × [hidden, kv_lora_rank] + kv_down_out = kv_lora_rank + r = backend.simulate_gemm(T, kv_down_out, hidden, dtype) + fwd_time += r.forward_time_ms + # KV up-proj: [T, kv_lora_rank] × [kv_lora_rank, heads*(qk_hd+v_hd)] + kv_up_out = heads * (qk_head_dim + v_head_dim) + r = backend.simulate_gemm(T, kv_up_out, kv_lora_rank, dtype) + fwd_time += r.forward_time_ms + + # RoPE positional embedding projection: [T, hidden] × [hidden, qk_pos_emb_head_dim] + r = backend.simulate_gemm(T, qk_pos_emb_head_dim, hidden, dtype) + fwd_time += r.forward_time_ms + + # Output projection: [T, heads*v_hd] × [heads*v_hd, hidden] + o_in = heads * v_head_dim + r = backend.simulate_gemm(T, hidden, o_in, dtype) + fwd_time += r.forward_time_ms + + # ---------- Backward (dgrad + wgrad for each projection) ---------- + if q_lora_rank is not None: + # Q down-proj dgrad: [T, q_down_out] × [q_down_out, hidden] → [T, hidden] + r = backend.simulate_gemm(T, hidden, q_down_out, dtype) + bwd_time += r.forward_time_ms + # Q down-proj wgrad: [hidden, T] × [T, q_down_out] → [hidden, q_down_out] + r = backend.simulate_gemm(hidden, q_down_out, T, dtype) + bwd_time += r.forward_time_ms + # Q up-proj dgrad: [T, q_up_out] × [q_up_out, q_lora_rank] → [T, q_lora_rank] + r = backend.simulate_gemm(T, q_lora_rank, q_up_out, dtype) + bwd_time += r.forward_time_ms + # Q up-proj wgrad: [q_lora_rank, T] × [T, q_up_out] → [q_lora_rank, q_up_out] + r = backend.simulate_gemm(q_lora_rank, q_up_out, T, dtype) + bwd_time += r.forward_time_ms + else: + # Direct Q dgrad + wgrad + r = backend.simulate_gemm(T, hidden, q_up_out, dtype) + bwd_time += r.forward_time_ms + r = backend.simulate_gemm(hidden, q_up_out, T, dtype) + bwd_time += r.forward_time_ms + + # KV down-proj dgrad + wgrad + r = backend.simulate_gemm(T, hidden, kv_down_out, dtype) + bwd_time += r.forward_time_ms + r = backend.simulate_gemm(hidden, kv_down_out, T, dtype) + bwd_time += r.forward_time_ms + # KV up-proj dgrad + wgrad + r = backend.simulate_gemm(T, kv_lora_rank, kv_up_out, dtype) + bwd_time += r.forward_time_ms + r = backend.simulate_gemm(kv_lora_rank, kv_up_out, T, dtype) + bwd_time += r.forward_time_ms + + # RoPE proj dgrad + wgrad + r = backend.simulate_gemm(T, hidden, 
qk_pos_emb_head_dim, dtype) + bwd_time += r.forward_time_ms + r = backend.simulate_gemm(hidden, qk_pos_emb_head_dim, T, dtype) + bwd_time += r.forward_time_ms + + # O proj dgrad + wgrad + r = backend.simulate_gemm(T, o_in, hidden, dtype) + bwd_time += r.forward_time_ms + r = backend.simulate_gemm(o_in, hidden, T, dtype) + bwd_time += r.forward_time_ms + + return fwd_time, bwd_time + + def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + """Get simulated results from GEMM + SDPA simulation backends.""" + args = self.config.model_config + mp = self.config.model_parallel_config + tp_size = max(1, mp.tensor_model_parallel_size) + cp_size = max(1, mp.context_model_parallel_size) + + batch_tokens = batch_size * seq_len // tp_size // cp_size + slen_per_cp = seq_len // cp_size + + fwd_time = 0.0 + bwd_time = 0.0 + + # 1. Simulate linear projection GEMMs using GEMM backend + if self._gemm_backend is not None: + gemm_dtype = "fp8" if getattr(args, "fp8", None) else "bf16" + + if getattr(args, "multi_latent_attention", False): + # MLA: LoRA-factored Q and compressed KV projections + # 6 forward GEMMs + 12 backward GEMMs + mla_fwd, mla_bwd = self._simulate_mla_gemms(batch_tokens, gemm_dtype) + fwd_time += mla_fwd + bwd_time += mla_bwd + else: + # Standard attention: Q, K, V, O projections + # 4 forward GEMMs + 8 backward GEMMs + num_query_groups = ( + args.num_query_groups + if args.group_query_attention and args.num_query_groups + else args.num_attention_heads + ) + gemm_result = self._gemm_backend.simulate_attention_gemms( + batch_tokens=batch_tokens, + hidden_size=args.hidden_size, + num_attention_heads=args.num_attention_heads, + kv_channels=args.kv_channels, + num_query_groups=num_query_groups, + dtype=gemm_dtype, + ) + fwd_time += gemm_result.forward_time_ms + bwd_time += gemm_result.backward_time_ms + + # 2. Simulate SDPA core computation using SDPA backend + if self._sdpa_backend is not None: + heads_per_rank = max(1, args.num_attention_heads // tp_size) + + if getattr(args, "multi_latent_attention", False): + # MLA: Q·Kᵀ uses qk_head_dim + qk_pos_emb_head_dim (e.g. 192), + # P·V uses v_head_dim (e.g. 128). 
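+                # The asymmetric Q/K vs V head dims mean the SDPA backend models
+                # Q·Kᵀ and P·V with different inner dimensions, hence the separate
+                # head_dim_v passed below.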
+ sdpa_head_dim = args.qk_head_dim + args.qk_pos_emb_head_dim + sdpa_head_dim_v = args.v_head_dim + else: + sdpa_head_dim = args.kv_channels + sdpa_head_dim_v = None # same as head_dim + + sdpa_result = self._sdpa_backend.simulate_sdpa( + batch_size=batch_size, + num_heads=heads_per_rank, + seq_len=slen_per_cp, + head_dim=sdpa_head_dim, + causal=True, + dtype="bf16", + head_dim_v=sdpa_head_dim_v, + ) + fwd_time += sdpa_result.forward_time_ms + bwd_time += sdpa_result.backward_time_ms + + activation_memory = self.estimated_activation_memory(batch_size, seq_len) + return (fwd_time, bwd_time, activation_memory) + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: - # Context parallel / Sequence parallel adjustment - cp_size = self.config.model_parallel_config.context_model_parallel_size - # Effective sequence length per rank if CP is used - slen_per_cp = seq_len // cp_size - - self._cached_results = benchmark_layer( - self.module, - [ - (seq_len, batch_size, self.config.model_config.hidden_size), - ((1, 1, slen_per_cp, seq_len), torch.bool), - ], - ) + if self._gemm_backend is not None or self._sdpa_backend is not None: + # Use simulation mode + self._cached_results = self._get_simulated_results(batch_size, seq_len) + else: + # Use actual GPU benchmarking + # Context parallel / Sequence parallel adjustment + cp_size = self.config.model_parallel_config.context_model_parallel_size + # Effective sequence length per rank if CP is used + slen_per_cp = seq_len // cp_size + + self._cached_results = benchmark_layer( + self.module, + [ + (seq_len, batch_size, self.config.model_config.hidden_size), + ((1, 1, slen_per_cp, seq_len), torch.bool), + ], + ) self._cache_key = cache_key return self._cached_results diff --git a/primus/core/projection/module_profilers/collective_args.py b/primus/core/projection/module_profilers/collective_args.py index f8c03914f..b377e153e 100644 --- a/primus/core/projection/module_profilers/collective_args.py +++ b/primus/core/projection/module_profilers/collective_args.py @@ -51,7 +51,18 @@ class CollectiveArgs: nics_per_node: Optional[int] = 8 # NICs per node (None = gpus_per_node) # All-to-all specific - a2a_peer_lat: float = 0.45 # Per-peer latency overhead for a2a + a2a_peer_lat: float = 0.45 # Per-peer latency overhead for inter-node a2a + a2a_intra_node_peer_lat: float = 28.0 # Per-peer latency overhead for intra-node a2a + # Intra-node overhead is higher (~19-28 us) due to: + # - P2P scatter/gather scheduling overhead + # - RCCL internal synchronization barriers + # - Memory copy and buffer management + # Note: Preflight measurements for EP=8 intra-node A2A show: + # - Linear extrapolation: ~27.4 us per peer + # - 2MB measurement: ~28.1 us per peer + # - After subtracting bandwidth component: ~19.4 us per peer + # Default 28 us matches preflight measurements (middle of range) + # Can be overridden via hardware_config for GPU-specific calibration def get_default_args( diff --git a/primus/core/projection/module_profilers/collective_model.py b/primus/core/projection/module_profilers/collective_model.py index 792dc0e22..1b7358cc4 100644 --- a/primus/core/projection/module_profilers/collective_model.py +++ b/primus/core/projection/module_profilers/collective_model.py @@ -443,7 +443,10 @@ def single_shot_alltoall(args, msg_size, gpus, groups=None, protocol=None): return 0 
intra_node_fanout, inter_node_fanout = get_max_fanout(args) msg_size_per_peer = ceil(msg_size / gpus) - gpus_per_node = min(gpus, args.node_size) + # Account for TP striding: with TP (hp) > 1, each EP rank occupies + # hp GPUs, so only node_size/hp EP ranks fit on a single node. + hp = getattr(args, "hp", 1) + gpus_per_node = min(gpus, args.node_size // max(hp, 1)) nics_per_node = args.nics_per_node if args.nics_per_node else gpus_per_node intra_node_gpus = gpus_per_node - 1 inter_node_gpus = max(0, gpus - gpus_per_node) @@ -487,8 +490,11 @@ def hierarchical_alltoall(args, msg_size, gpus, groups=None, protocol=None): if gpus == 1 or msg_size == 0: return 0 - gpus_per_node = min(gpus, args.node_size) - num_nodes = ceil(gpus / args.node_size) + # Account for TP striding: with TP (hp) > 1, each EP rank occupies + # hp GPUs, so only node_size/hp EP ranks fit on a single node. + hp = getattr(args, "hp", 1) + gpus_per_node = min(gpus, args.node_size // max(hp, 1)) + num_nodes = ceil(gpus / max(gpus_per_node, 1)) nics_per_node = args.nics_per_node if args.nics_per_node else gpus_per_node if num_nodes == 1: @@ -518,6 +524,76 @@ def hierarchical_alltoall(args, msg_size, gpus, groups=None, protocol=None): return t_total +def pxn_alltoall(args, msg_size, gpus, groups=None, protocol=None): + """ + PXN All-to-All - pipelined implementation for DeepEP. + + Based on DeepEP implementation of pipelined PXN-A2A with pipelined + scale-up (intra-node) and scale-out (inter-node) communication. + + Key features: + - Overlaps scale-up and scale-out communication + - Scale-out doesn't start until 4MB is accumulated for dispatch + - Falls back to regular alltoall for single-node configurations + """ + if gpus == 1 or msg_size == 0: + return 0 + + original_msg_size = msg_size + + # Account for TP striding: with TP (hp) > 1, each EP rank occupies + # hp GPUs, so only node_size/hp EP ranks fit on a single node. 
+ hp = getattr(args, "hp", 1) + effective_gpus_per_node = min(gpus, args.node_size // max(hp, 1)) + num_nodes = ceil(gpus / max(effective_gpus_per_node, 1)) + + # If A2A is not crossing node boundaries, fall back to regular alltoall + if num_nodes <= 1: + return min( + single_shot_alltoall(args, msg_size, gpus, groups, protocol), + run_alltoall(args, msg_size, gpus, groups, protocol), + ) + + # PXN - AlltoAll - pipeline implementation + chunk_size = 4194304 # 4MB - DeepEP waits until 4MB is accumulated + + # Scale-out message size (inter-node communication) + scale_out_msg_size = int(original_msg_size * (num_nodes - 1) / num_nodes) + + # Scale-up delay: time to accumulate 4MB before scale-out starts + if scale_out_msg_size < chunk_size: + # If total scale-out msg size is less than 4MB, + # total time = scaleup_delay + scaleout_time + node_lat, _ = node_latency_and_volume_protocol(args, scale_out_msg_size, protocol) + scaleup_delay = node_lat + scale_out_msg_size / args.node_bw * 1.0e-3 + else: + # Scale-out comm doesn't start until 4MB is accumulated + node_lat, _ = node_latency_and_volume_protocol(args, chunk_size, protocol) + scaleup_delay = node_lat + chunk_size / args.node_bw * 1.0e-3 + + # Assume PXN style alltoall with overlapped scale-up and scale-out + node_msg_size = int(original_msg_size * (effective_gpus_per_node - 1) / effective_gpus_per_node) + scale_out_msg_size = int(original_msg_size * (num_nodes - 1) / num_nodes) + + # Calculate latencies with protocol inflation + node_lat, node_msg_size_adj = node_latency_and_volume_protocol(args, node_msg_size, protocol) + pod_lat, scale_out_msg_size_adj = pod_latency_and_volume_protocol(args, scale_out_msg_size, protocol) + + # Scale-up (intra-node) time + node_bw = args.bw_eff * args.node_bw + t_a2a_node = node_lat + node_msg_size_adj / node_bw * 1.0e-3 + + # Scale-out (inter-node) time with scaleup delay + pod_bw = args.bw_eff * args.pod_bw + t_a2a_scale_out = pod_lat + scale_out_msg_size_adj / pod_bw * 1.0e-3 + scaleup_delay + + # Total time is max of scale-up and scale-out (they overlap) + t_total = max(t_a2a_node, t_a2a_scale_out) + t_total += args.kernel_launch_latency + + return t_total + + def single_shot_allgather(args, msg_size, gpus, groups=None, protocol=None): """ Single shot allgather with max fanout and overlap. @@ -629,23 +705,51 @@ def alltoall(args, msg_size, gpus, groups=["ep"]): Select best alltoall algorithm among several options. Tries multiple protocols and algorithms, returns fastest. Applies per-peer latency overhead and minimum latency floor. + + If DeepEP is enabled (moe_enable_deepep=True), uses PXN All-to-All + which pipelines scale-up and scale-out communication. 
""" + # Check if DeepEP is enabled + use_deepep = getattr(args, "moe_enable_deepep", False) or getattr(args, "use_turbo_deepep", False) + min_a2a_time = float("inf") for p in ["simple", "ll", "ll64", "ll128"]: - direct_a2a_time = run_alltoall(args, msg_size, gpus, protocol=p) - single_shot_a2a_time = single_shot_alltoall(args, msg_size, gpus, protocol=p) - hierarchical_a2a_time = hierarchical_alltoall(args, msg_size, gpus, protocol=p) - a2a_time = min(direct_a2a_time, single_shot_a2a_time, hierarchical_a2a_time) + if use_deepep: + # Use PXN All-to-All for DeepEP + pxn_a2a_time = pxn_alltoall(args, msg_size, gpus, protocol=p) + a2a_time = pxn_a2a_time + else: + # Use regular All-to-All algorithms + direct_a2a_time = run_alltoall(args, msg_size, gpus, protocol=p) + single_shot_a2a_time = single_shot_alltoall(args, msg_size, gpus, protocol=p) + hierarchical_a2a_time = hierarchical_alltoall(args, msg_size, gpus, protocol=p) + a2a_time = min(direct_a2a_time, single_shot_a2a_time, hierarchical_a2a_time) + if a2a_time < min_a2a_time: min_a2a_time = a2a_time - # Add per-peer latency overhead for inter-node communication - # This accounts for RDMA QP setup, work request posting, completion polling, etc. - if hasattr(args, "a2a_peer_lat") and args.a2a_peer_lat > 0: - gpus_per_node = args.node_size - inter_node_peers = max(0, gpus - gpus_per_node) - peer_overhead = args.a2a_peer_lat * inter_node_peers - min_a2a_time += peer_overhead + # Add per-peer latency overhead for ALL A2A communication (both intra and inter-node) + # This accounts for: + # - P2P scatter/gather scheduling overhead + # - RCCL internal synchronization barriers + # - Memory copy and buffer management + # - RDMA QP setup, work request posting, completion polling (for inter-node) + # For intra-node A2A, overhead is significant due to synchronization and scheduling + # Measured overhead: ~50 us per peer for intra-node (vs ~0.45 us for inter-node) + gpus_per_node = args.node_size + intra_node_peers = min(gpus - 1, gpus_per_node - 1) # Peers within same node + inter_node_peers = max(0, gpus - gpus_per_node) # Peers on other nodes + + # Intra-node overhead is much higher due to synchronization and scheduling + # Based on preflight measurements: EP=8 intra-node A2A needs ~19-28 us per peer + # Inter-node overhead is lower (~0.45 us per peer) due to RDMA efficiency + intra_node_overhead_per_peer = getattr(args, "a2a_intra_node_peer_lat", 28.0) # Default 28 us + inter_node_overhead_per_peer = getattr(args, "a2a_peer_lat", 0.45) # Default 0.45 us + + peer_overhead = ( + intra_node_overhead_per_peer * intra_node_peers + inter_node_overhead_per_peer * inter_node_peers + ) + min_a2a_time += peer_overhead return min_a2a_time diff --git a/primus/core/projection/module_profilers/dense_mlp.py b/primus/core/projection/module_profilers/dense_mlp.py index d0a9aaadb..91c4ce0f2 100644 --- a/primus/core/projection/module_profilers/dense_mlp.py +++ b/primus/core/projection/module_profilers/dense_mlp.py @@ -19,6 +19,7 @@ def __init__(self, config, sub_profilers=None): self.module = None # Will be set during benchmarking self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) + self._gemm_backend = None # Optional: GEMM simulation backend def set_module(self, module): """Set the actual Dense MLP module for benchmarking.""" @@ -27,6 +28,13 @@ def set_module(self, module): self._cached_results = None self._cache_key = None + def set_gemm_backend(self, backend): + """Set 
a GEMM simulation backend for simulated profiling.""" + self._gemm_backend = backend + # Invalidate cache when backend changes + self._cached_results = None + self._cache_key = None + def estimated_num_params(self, rank: Optional[int] = None) -> int: # For SwiGLU: 3 projections (gate, up, down) # For standard FFN: 2 projections (up, down) @@ -58,14 +66,39 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: # Peak memory is input + intermediate (both needed for backward) return intermediate_memory + activation_memory + output_memory + def _get_simulated_results(self, batch_size: int, seq_len: int) -> Tuple[float, float, int]: + """Get simulated results from the GEMM simulation backend.""" + tp_size = self.config.model_parallel_config.tensor_model_parallel_size + cp_size = self.config.model_parallel_config.context_model_parallel_size + batch_tokens = batch_size * seq_len // tp_size // cp_size + + # FP8-hybrid: MLP projections (gate, up, down) run in FP8 + gemm_dtype = "fp8" if getattr(self.config.model_config, "fp8", None) else "bf16" + sim_result = self._gemm_backend.simulate_mlp_gemms( + batch_tokens=batch_tokens, + hidden_size=self.config.model_config.hidden_size, + ffn_hidden_size=self.config.model_config.ffn_hidden_size, + dtype=gemm_dtype, + swiglu=self.config.model_config.swiglu, + ) + activation_memory = self.estimated_activation_memory(batch_size, seq_len) + return ( + sim_result.forward_time_ms, + sim_result.backward_time_ms, + activation_memory, + ) + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> Tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: - self._cached_results = benchmark_layer( - self.module, - [(seq_len, batch_size, self.config.model_config.hidden_size)], - ) + if self._gemm_backend is not None: + self._cached_results = self._get_simulated_results(batch_size, seq_len) + else: + self._cached_results = benchmark_layer( + self.module, + [(seq_len, batch_size, self.config.model_config.hidden_size)], + ) self._cache_key = cache_key return self._cached_results diff --git a/primus/core/projection/module_profilers/embedding.py b/primus/core/projection/module_profilers/embedding.py index 5640e9de0..cfd2c3a90 100644 --- a/primus/core/projection/module_profilers/embedding.py +++ b/primus/core/projection/module_profilers/embedding.py @@ -19,6 +19,7 @@ def __init__(self, config, sub_profilers=None): self.module = None # Will be set during benchmarking self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) + self._simulation_mode = False # Set to True when simulation backends are active def set_module(self, module): """Set the actual module for benchmarking.""" @@ -27,6 +28,12 @@ def set_module(self, module): self._cached_results = None self._cache_key = None + def set_simulation_mode(self, enabled: bool = True): + """Enable simulation mode (embedding lookup is estimated analytically).""" + self._simulation_mode = enabled + self._cached_results = None + self._cache_key = None + def estimated_num_params(self, rank: Optional[int] = None) -> int: return self.config.model_config.padded_vocab_size * self.config.model_config.hidden_size @@ -40,22 +47,41 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: * 2 ) # bf16 + def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, 
int]: + """Estimate embedding time analytically (lookup is memory-bound, very fast).""" + tp_size = self.config.model_parallel_config.tensor_model_parallel_size + cp_size = self.config.model_parallel_config.context_model_parallel_size + tokens = batch_size * seq_len // tp_size // cp_size + hidden = self.config.model_config.hidden_size + # Embedding lookup: read tokens indices + write output embeddings + # Very fast relative to GEMM layers – use small fixed estimate + output_bytes = tokens * hidden * 2 # bf16 + # Assume ~4 TB/s effective bandwidth (MI300X), 1 read + 1 write pass + bw_bytes_per_ms = 4e9 # ~4 TB/s → bytes/ms + fwd_time = max(0.01, output_bytes / bw_bytes_per_ms) + bwd_time = fwd_time # Gradient scatter is similar cost + activation_memory = self.estimated_activation_memory(batch_size, seq_len) + return (fwd_time, bwd_time, activation_memory) + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: - # Context parallel / Sequence parallel adjustment - cp_size = self.config.model_parallel_config.context_model_parallel_size - # Effective sequence length per rank if CP is used - slen_per_cp = seq_len // cp_size - - self._cached_results = benchmark_layer( - self.module, - [ - ((batch_size, slen_per_cp), torch.int64), - ], - ) + if self._simulation_mode: + self._cached_results = self._get_simulated_results(batch_size, seq_len) + else: + # Context parallel / Sequence parallel adjustment + cp_size = self.config.model_parallel_config.context_model_parallel_size + # Effective sequence length per rank if CP is used + slen_per_cp = seq_len // cp_size + + self._cached_results = benchmark_layer( + self.module, + [ + ((batch_size, slen_per_cp), torch.int64), + ], + ) self._cache_key = cache_key return self._cached_results diff --git a/primus/core/projection/module_profilers/language_model.py b/primus/core/projection/module_profilers/language_model.py index ac15e1fd7..a975ab73b 100644 --- a/primus/core/projection/module_profilers/language_model.py +++ b/primus/core/projection/module_profilers/language_model.py @@ -157,6 +157,33 @@ def __init__(self, config, sub_profilers=None): ep_size=self.config.model_parallel_config.expert_model_parallel_size, num_virtual_pipeline_stages=self.config.model_parallel_config.virtual_pipeline_model_parallel_size, ) + self._gemm_backend = None + self._sdpa_backend = None + + def set_simulation_backends(self, gemm_backend=None, sdpa_backend=None): + """Set simulation backends and propagate to all sub-profilers.""" + self._gemm_backend = gemm_backend + self._sdpa_backend = sdpa_backend + + # Propagate to transformer layer sub-profilers (which further propagate + # to attention, MLP, router sub-profilers). + for key in ("dense_transformer_layer", "moe_transformer_layer"): + if key in self.sub_profilers and self.sub_profilers[key] is not None: + layer_profiler = self.sub_profilers[key] + if hasattr(layer_profiler, "set_simulation_backends"): + layer_profiler.set_simulation_backends(gemm_backend, sdpa_backend) + + # Propagate to embedding (uses simple analytical estimate in sim mode). 
+ if "embedding" in self.sub_profilers and self.sub_profilers["embedding"] is not None: + emb = self.sub_profilers["embedding"] + if hasattr(emb, "set_simulation_mode"): + emb.set_simulation_mode(gemm_backend is not None or sdpa_backend is not None) + + # Propagate GEMM backend to output layer (vocab projection GEMM). + if "output_layer" in self.sub_profilers and self.sub_profilers["output_layer"] is not None: + out = self.sub_profilers["output_layer"] + if gemm_backend is not None and hasattr(out, "set_gemm_backend"): + out.set_gemm_backend(gemm_backend) def get_layers_for_rank( self, @@ -422,70 +449,108 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: return int(total_act) def run_layer_benchmark(self, model, batch_size: int, seq_len: int) -> dict: - """Benchmark transformer layers plus embedding/output layers on this rank.""" + """Benchmark or simulate transformer layers plus embedding/output layers on this rank. - def unwrap_module(module): - """Recursively unwrap DistributedDataParallel / pipeline wrappers.""" - return unwrap_module(module.module) if hasattr(module, "module") else module + Supports two modes: + - **benchmark** (default): Runs actual GPU kernels and measures timing. + Requires *model* to be a real instantiated model (or list of model chunks). + - **simulate**: Uses GEMM and SDPA simulation backends. The *model* + parameter may be ``None`` – no GPU is required. - model_chunks = model if isinstance(model, list) else [model] + The mode is automatically selected based on whether simulation backends + have been set via :meth:`set_simulation_backends`. + """ + is_simulation_mode = self._gemm_backend is not None or self._sdpa_backend is not None + # ----------------------------------------------------------------- + # Unwrap model (only when an actual model is provided) + # ----------------------------------------------------------------- embedding_module = None output_module = None all_layers = [] - for chunk in model_chunks: - unwrapped = unwrap_module(chunk) - - language_model = getattr(unwrapped, "language_model", None) - if language_model is not None: - if hasattr(language_model, "embedding"): - embedding_module = language_model.embedding - if hasattr(language_model, "output_layer"): - output_module = language_model.output_layer - - if hasattr(language_model, "encoder") and hasattr(language_model.encoder, "layers"): - all_layers.extend(language_model.encoder.layers) - elif hasattr(language_model, "decoder") and hasattr(language_model.decoder, "layers"): - all_layers.extend(language_model.decoder.layers) - elif hasattr(language_model, "layers"): - all_layers.extend(language_model.layers) - continue + if model is not None: - if hasattr(unwrapped, "decoder") and hasattr(unwrapped.decoder, "layers"): - all_layers.extend(unwrapped.decoder.layers) - elif hasattr(unwrapped, "layers"): - all_layers.extend(unwrapped.layers) - else: - raise ValueError(f"Cannot find transformer layers in model chunk: {type(unwrapped)}") - if hasattr(unwrapped, "embedding"): - embedding_module = unwrapped.embedding - if hasattr(unwrapped, "output_layer"): - output_module = unwrapped.output_layer + def unwrap_module(module): + """Recursively unwrap DistributedDataParallel / pipeline wrappers.""" + return unwrap_module(module.module) if hasattr(module, "module") else module + + model_chunks = model if isinstance(model, list) else [model] + + for chunk in model_chunks: + unwrapped = unwrap_module(chunk) + + language_model = getattr(unwrapped, "language_model", None) + if 
language_model is not None: + if hasattr(language_model, "embedding"): + embedding_module = language_model.embedding + if hasattr(language_model, "output_layer"): + output_module = language_model.output_layer + + if hasattr(language_model, "encoder") and hasattr(language_model.encoder, "layers"): + all_layers.extend(language_model.encoder.layers) + elif hasattr(language_model, "decoder") and hasattr(language_model.decoder, "layers"): + all_layers.extend(language_model.decoder.layers) + elif hasattr(language_model, "layers"): + all_layers.extend(language_model.layers) + continue + + if hasattr(unwrapped, "decoder") and hasattr(unwrapped.decoder, "layers"): + all_layers.extend(unwrapped.decoder.layers) + elif hasattr(unwrapped, "layers"): + all_layers.extend(unwrapped.layers) + else: + raise ValueError(f"Cannot find transformer layers in model chunk: {type(unwrapped)}") + if hasattr(unwrapped, "embedding"): + embedding_module = unwrapped.embedding + if hasattr(unwrapped, "output_layer"): + output_module = unwrapped.output_layer + elif not is_simulation_mode: + raise ValueError( + "model=None is only allowed when simulation backends are set " + "(call set_simulation_backends first)" + ) is_rank_0 = int(os.getenv("RANK", "0")) == 0 + mode_label = "Simulating" if is_simulation_mode else "Benchmarking" if is_rank_0: - print(f"\n[Primus:Performance Projection] Found {len(all_layers)} transformer layers") + if model is not None: + print(f"\n[Primus:Performance Projection] Found {len(all_layers)} transformer layers") + else: + print(f"\n[Primus:Performance Projection] Pure simulation mode (no model)") print(f"[Primus:Performance Projection] This rank is responsible for layers: {self.layers}") + if is_simulation_mode: + backends = [] + if self._gemm_backend is not None: + backends.append(f"GEMM={self._gemm_backend.name()}") + if self._sdpa_backend is not None: + backends.append(f"SDPA={self._sdpa_backend.name()}") + print(f"[Primus:Performance Projection] Mode: SIMULATION ({', '.join(backends)})") embedding_stats = None output_stats = None - # Benchmark embedding if this rank hosts it. + # ---------------------------------------------------------------------- + # Benchmark / simulate embedding layer (if this rank hosts it) + # ---------------------------------------------------------------------- if 0 in self.layers: - if embedding_module is None: + if model is not None and embedding_module is None and not is_simulation_mode: if is_rank_0: print("[Primus:Performance Projection] WARNING: Embedding module not found on this rank.") else: if is_rank_0: - print("[Primus:Performance Projection] Benchmarking embedding layer...") + print(f"[Primus:Performance Projection] {mode_label} embedding layer...") profiler = self.sub_profilers["embedding"] - module = ( - embedding_module.word_embeddings - if hasattr(embedding_module, "word_embeddings") - else embedding_module - ) - profiler.set_module(module) + if embedding_module is not None: + module = ( + embedding_module.word_embeddings + if hasattr(embedding_module, "word_embeddings") + else embedding_module + ) + profiler.set_module(module) + # In simulation mode without a model, the profiler uses its + # analytical estimate (set_simulation_mode was already called + # by set_simulation_backends). 
emb_forward = profiler.measured_forward_time(batch_size, seq_len) emb_backward = profiler.measured_backward_time(batch_size, seq_len) emb_mem = profiler.measured_activation_memory(batch_size, seq_len) @@ -502,19 +567,25 @@ def unwrap_module(module): "activation_memory_bytes": emb_mem, } - # Benchmark output layer if this rank hosts the final layer. + # ---------------------------------------------------------------------- + # Benchmark / simulate output layer (if this rank hosts the final layer) + # ---------------------------------------------------------------------- last_layer_id = self.config.model_config.num_layers - 1 if last_layer_id in self.layers: - if output_module is None: + if model is not None and output_module is None and not is_simulation_mode: if is_rank_0: print( "[Primus:Performance Projection] WARNING: Output layer module not found on this rank." ) else: if is_rank_0: - print("[Primus:Performance Projection] Benchmarking output layer...") + print(f"[Primus:Performance Projection] {mode_label} output layer...") profiler = self.sub_profilers["output_layer"] - profiler.set_module(output_module) + if output_module is not None: + profiler.set_module(output_module) + # In simulation mode without a model, the output_layer profiler + # uses its GEMM backend (set_gemm_backend was already called by + # set_simulation_backends). out_forward = profiler.measured_forward_time(batch_size, seq_len) out_backward = profiler.measured_backward_time(batch_size, seq_len) out_mem = profiler.measured_activation_memory(batch_size, seq_len) @@ -532,13 +603,14 @@ def unwrap_module(module): } # ============================================================================== - # BENCHMARK LAYER TYPES (one of each type: dense, moe) + # BENCHMARK / SIMULATE LAYER TYPES (one of each type: dense, moe) # ============================================================================== results = {} profiled_types = set() for layer_idx in self.layers: - if layer_idx >= len(all_layers): + # In benchmark mode, guard against out-of-range layer indices. + if model is not None and layer_idx >= len(all_layers): if is_rank_0: print(f"[WARNING] Layer index {layer_idx} exceeds available layers ({len(all_layers)})") continue @@ -549,10 +621,8 @@ def unwrap_module(module): if layer_type in profiled_types: continue - layer_module = all_layers[layer_idx] - if is_rank_0: - print(f"\n[Primus:Performance Projection] Benchmarking Layer {layer_idx} ({layer_type})...") + print(f"\n[Primus:Performance Projection] {mode_label} Layer {layer_idx} ({layer_type})...") # Get the appropriate profiler if is_moe: @@ -560,21 +630,23 @@ def unwrap_module(module): else: layer_profiler = self.sub_profilers["dense_transformer_layer"] - # Set the layer module - layer_profiler.set_layer_module(layer_module) + # Set the layer module only when a real model is available. 
+ if model is not None: + layer_module = all_layers[layer_idx] + layer_profiler.set_layer_module(layer_module) - # Benchmark full layer (uses optimized benchmark_layer with 64 iterations, warm caches) + # Benchmark/simulate full layer forward_time = layer_profiler.measured_forward_time(batch_size, seq_len) backward_time = layer_profiler.measured_backward_time(batch_size, seq_len) activation_memory = layer_profiler.measured_activation_memory(batch_size, seq_len) - # Benchmark Attention + # Benchmark/simulate Attention attn_profiler = layer_profiler.get_sub_profiler("self_attention") attn_forward = attn_profiler.measured_forward_time(batch_size, seq_len) attn_backward = attn_profiler.measured_backward_time(batch_size, seq_len) attn_mem = attn_profiler.measured_activation_memory(batch_size, seq_len) - # Benchmark MLP + # Benchmark/simulate MLP mlp_profiler = layer_profiler.get_sub_profiler("mlp") mlp_forward = mlp_profiler.measured_forward_time(batch_size, seq_len) mlp_backward = mlp_profiler.measured_backward_time(batch_size, seq_len) @@ -601,9 +673,10 @@ def unwrap_module(module): is_rank_0 = int(os.getenv("RANK", "0")) == 0 if is_rank_0: - print(f" Forward time: {forward_time:.2f} ms") - print(f" Backward time: {backward_time:.2f} ms") - print(f" Total: {forward_time + backward_time:.2f} ms") + src = "(simulated)" if is_simulation_mode else "(measured)" + print(f" Forward time: {forward_time:.2f} ms {src}") + print(f" Backward time: {backward_time:.2f} ms {src}") + print(f" Total: {forward_time + backward_time:.2f} ms {src}") print(f" Activation memory: {activation_memory / (1024**2):.2f} MB") print(f" Attention: fwd={attn_forward:.2f} ms, bwd={attn_backward:.2f} ms") print(f" MLP: fwd={mlp_forward:.2f} ms, bwd={mlp_backward:.2f} ms") diff --git a/primus/core/projection/module_profilers/moe_mlp.py b/primus/core/projection/module_profilers/moe_mlp.py index 070731fb6..8a9a06c5f 100644 --- a/primus/core/projection/module_profilers/moe_mlp.py +++ b/primus/core/projection/module_profilers/moe_mlp.py @@ -4,6 +4,7 @@ # See LICENSE for license information. ############################################################################### +import os from typing import Optional from primus.core.projection.base_module_profiler import BaseModuleProfiler @@ -12,6 +13,23 @@ from .utils import benchmark_layer +# Efficiency fractions for non-GEMM MoE overhead estimation. +# These express achievable bandwidth as a fraction of peak HBM bandwidth. +# The actual BW is ``fraction × peak_hbm_bw`` for the target architecture, +# so the model scales automatically across MI300X (5.3 TB/s), MI325X (6.0 +# TB/s), MI355X (8.0 TB/s), etc. +# +# PERMUTE (scatter/gather) — random-access token dispatch/combine. Irregular +# access patterns achieve only ~5-7 % of peak HBM bandwidth. +_PERMUTE_BW_FRACTION = 0.057 +# +# ACTIVATION (SwiGLU / GELU) — sequential element-wise ops that stream over +# contiguous buffers. Typically ~55-60 % of peak HBM bandwidth. +_ACTIVATION_BW_FRACTION = 0.566 +# +# Fallback absolute values used when the backend cannot report HBM bandwidth. 
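+# (For reference, with the 5300 GB/s fallback below these fractions correspond to
+# roughly 0.057 × 5300 ≈ 302 GB/s effective permute bandwidth and
+# 0.566 × 5300 ≈ 3000 GB/s effective activation bandwidth.)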
+_FALLBACK_HBM_BW_GBPS = 5300.0 # MI300X default + class MoEMLPProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): @@ -19,6 +37,7 @@ def __init__(self, config, sub_profilers=None): self.module = None # Will be set during benchmarking self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) + self._gemm_backend = None # Optional: GEMM simulation backend def set_module(self, module): """Set the actual MoE MLP module for benchmarking.""" @@ -27,6 +46,13 @@ def set_module(self, module): self._cached_results = None self._cache_key = None + def set_gemm_backend(self, backend): + """Set a GEMM simulation backend for simulated profiling.""" + self._gemm_backend = backend + # Invalidate cache when backend changes + self._cached_results = None + self._cache_key = None + def estimated_num_params(self, rank: Optional[int] = None) -> int: if self.config.model_config.moe_ffn_hidden_size is not None: moe_ffn = self.config.model_config.moe_ffn_hidden_size @@ -87,14 +113,226 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: return total + def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + """Get simulated results from the GEMM simulation backend for MoE MLP. + + In addition to expert GEMM time, this method estimates several + components of MoE execution that the GEMM simulation alone misses: + + 1. **Router overhead** — gate linear projection + softmax/top-K. + 2. **Token permutation** — dispatch (scatter) and combine (gather) + memory traffic with random-access patterns. + 3. **Activation function** — SwiGLU / GELU element-wise overhead. + + **Grouped GEMM performance model selection**: + When ``enable_primus_turbo`` and ``use_turbo_grouped_mlp`` are both + ``True`` in the training config, the expert GEMMs are modelled using + Origami's *batched* GEMM path (``batch=num_local_experts``). Primus + Turbo's grouped-GEMM kernel achieves near-ideal batched execution, + so the batched model is an accurate proxy. + + Otherwise (legacy ``grouped_gemm`` package), each expert is simulated + independently (``batch=1``) and the result is scaled by the number of + local experts. This more closely reflects the sequential per-expert + execution of the legacy kernel. + """ + tp_size = self.config.model_parallel_config.tensor_model_parallel_size + cp_size = self.config.model_parallel_config.context_model_parallel_size + ep_size = self.config.model_parallel_config.expert_model_parallel_size + + hidden_size = self.config.model_config.hidden_size + batch_tokens = batch_size * seq_len // tp_size // cp_size + topk = self.config.model_config.moe_router_topk + topk_tokens = batch_tokens * topk + + if self.config.model_config.moe_ffn_hidden_size is not None: + moe_ffn = self.config.model_config.moe_ffn_hidden_size + else: + moe_ffn = self.config.model_config.ffn_hidden_size + + num_experts = self.config.model_config.num_experts or 1 + num_local_experts = num_experts // ep_size + tokens_per_expert = topk_tokens // max(num_local_experts, 1) + + # FP8-hybrid: MoE expert MLP projections run in FP8 + gemm_dtype = "fp8" if getattr(self.config.model_config, "fp8", None) else "bf16" + bytes_per_el = 1 if gemm_dtype == "fp8" else 2 + + # ── 1. Routed expert GEMMs ── + M = tokens_per_expert + H = hidden_size + F = moe_ffn + + # Determine grouped-GEMM performance model. 
+ # Primus Turbo's grouped-GEMM kernel achieves near-ideal batched + # execution → model as Origami batched GEMM (batch=num_local_experts). + # Legacy grouped_gemm executes experts more sequentially → model as + # individual GEMM (batch=1) × num_local_experts. + use_turbo = getattr(self.config.model_config, "enable_primus_turbo", False) and getattr( + self.config.model_config, "use_turbo_grouped_mlp", False + ) + + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + if is_rank_0 and num_local_experts > 1: + mode = "Turbo (batched)" if use_turbo else "Legacy (sequential)" + print( + f" [MoE MLP] Grouped-GEMM model: {mode}" + f" ({num_local_experts} local experts, M={M}, H={H}, F={F})" + ) + + if use_turbo: + # ── Turbo model: batched GEMM (all experts in parallel) ── + B = num_local_experts + if self.config.model_config.swiglu: + gate_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) + up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) + down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) + expert_fwd_ms = gate_fwd.forward_time_ms + up_fwd.forward_time_ms + down_fwd.forward_time_ms + gate_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) + gate_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=B) + up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) + up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=B) + down_dg = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) + down_wg = self._gemm_backend.simulate_gemm(F, H, M, gemm_dtype, batch=B) + expert_bwd_ms = ( + gate_dg.forward_time_ms + + gate_wg.forward_time_ms + + up_dg.forward_time_ms + + up_wg.forward_time_ms + + down_dg.forward_time_ms + + down_wg.forward_time_ms + ) + else: + up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) + down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) + expert_fwd_ms = up_fwd.forward_time_ms + down_fwd.forward_time_ms + up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) + up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=B) + down_dg = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) + down_wg = self._gemm_backend.simulate_gemm(F, H, M, gemm_dtype, batch=B) + expert_bwd_ms = ( + up_dg.forward_time_ms + + up_wg.forward_time_ms + + down_dg.forward_time_ms + + down_wg.forward_time_ms + ) + + expert_fwd = expert_fwd_ms + expert_bwd = expert_bwd_ms + else: + # ── Legacy model: individual GEMM × num_local_experts ── + if self.config.model_config.swiglu: + gate_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) + up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) + down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) + expert_fwd_ms = gate_fwd.forward_time_ms + up_fwd.forward_time_ms + down_fwd.forward_time_ms + gate_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) + gate_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=1) + up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) + up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=1) + down_dg = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) + down_wg = self._gemm_backend.simulate_gemm(F, H, M, gemm_dtype, batch=1) + expert_bwd_ms = ( + gate_dg.forward_time_ms + + gate_wg.forward_time_ms + + up_dg.forward_time_ms + + up_wg.forward_time_ms + + down_dg.forward_time_ms + + down_wg.forward_time_ms 
+ ) + else: + up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) + down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) + expert_fwd_ms = up_fwd.forward_time_ms + down_fwd.forward_time_ms + up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) + up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=1) + down_dg = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) + down_wg = self._gemm_backend.simulate_gemm(F, H, M, gemm_dtype, batch=1) + expert_bwd_ms = ( + up_dg.forward_time_ms + + up_wg.forward_time_ms + + down_dg.forward_time_ms + + down_wg.forward_time_ms + ) + + expert_fwd = expert_fwd_ms * num_local_experts + expert_bwd = expert_bwd_ms * num_local_experts + + fwd_time = expert_fwd + bwd_time = expert_bwd + + # ── 2. Router overhead ── + # Gate linear: [batch_tokens, num_experts, hidden_size] + router_gemm = self._gemm_backend.simulate_gemm(batch_tokens, num_experts, hidden_size, gemm_dtype) + router_fwd_ms = router_gemm.forward_time_ms + # Softmax + top-K selection + auxiliary loss overhead (empirical) + topk_overhead_ms = 0.1 + 0.002 * num_experts + router_fwd_ms += topk_overhead_ms + # Backward: dgrad + wgrad for gate linear + router_bwd_ms = 2.0 * router_gemm.forward_time_ms + topk_overhead_ms + + fwd_time += router_fwd_ms + bwd_time += router_bwd_ms + + # ── 3. Token permutation overhead (dispatch + combine) ── + # Dispatch: gather tokens by expert assignment → irregular memory access + # Combine: scatter expert outputs back → weighted reduce + # + # Derive effective BW from the target GPU's peak HBM bandwidth so the + # model adapts automatically to different architectures. + peak_hbm = ( + self._gemm_backend.hbm_bandwidth_gbps + if self._gemm_backend is not None and self._gemm_backend.hbm_bandwidth_gbps is not None + else _FALLBACK_HBM_BW_GBPS + ) + permute_eff_bw_gbps = peak_hbm * _PERMUTE_BW_FRACTION + activation_bw_gbps = peak_hbm * _ACTIVATION_BW_FRACTION + + dispatch_bytes = (batch_tokens + topk_tokens) * hidden_size * bytes_per_el + combine_bytes = (topk_tokens + batch_tokens) * hidden_size * bytes_per_el + permute_fwd_ms = dispatch_bytes / (permute_eff_bw_gbps * 1e6) + permute_bwd_ms = combine_bytes / (permute_eff_bw_gbps * 1e6) + + fwd_time += permute_fwd_ms + bwd_time += permute_bwd_ms + + # ── 4. Activation function overhead (SwiGLU / GELU) ── + if self.config.model_config.swiglu: + act_bytes = 3 * topk_tokens * moe_ffn * bytes_per_el # gate+up read, result write + else: + act_bytes = 2 * topk_tokens * moe_ffn * bytes_per_el # read + write + activation_ms = act_bytes / (activation_bw_gbps * 1e6) + + fwd_time += activation_ms + bwd_time += activation_ms + + # ── 5. 
Shared experts (if any) ── + shared_sz = self.config.model_config.moe_shared_expert_intermediate_size + if shared_sz: + shared_result = self._gemm_backend.simulate_mlp_gemms( + batch_tokens=batch_tokens, + hidden_size=hidden_size, + ffn_hidden_size=shared_sz, + dtype=gemm_dtype, + swiglu=self.config.model_config.swiglu, + ) + fwd_time += shared_result.forward_time_ms + bwd_time += shared_result.backward_time_ms + + activation_memory = self.estimated_activation_memory(batch_size, seq_len) + return (fwd_time, bwd_time, activation_memory) + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: - self._cached_results = benchmark_layer( - self.module, - [(seq_len, batch_size, self.config.model_config.hidden_size)], - ) + if self._gemm_backend is not None: + self._cached_results = self._get_simulated_results(batch_size, seq_len) + else: + self._cached_results = benchmark_layer( + self.module, + [(seq_len, batch_size, self.config.model_config.hidden_size)], + ) self._cache_key = cache_key return self._cached_results diff --git a/primus/core/projection/module_profilers/output_layer.py b/primus/core/projection/module_profilers/output_layer.py index ddd7d5ee4..c2e27555b 100644 --- a/primus/core/projection/module_profilers/output_layer.py +++ b/primus/core/projection/module_profilers/output_layer.py @@ -17,6 +17,7 @@ def __init__(self, config, sub_profilers=None): self.module = None # Will be set during benchmarking self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) + self._gemm_backend = None # Optional: GEMM simulation backend def set_module(self, module): """Set the actual module for benchmarking.""" @@ -25,6 +26,12 @@ def set_module(self, module): self._cached_results = None self._cache_key = None + def set_gemm_backend(self, backend): + """Set a GEMM simulation backend for simulated profiling.""" + self._gemm_backend = backend + self._cached_results = None + self._cache_key = None + def estimated_num_params(self, rank: Optional[int] = None) -> int: return self.config.model_config.padded_vocab_size * self.config.model_config.hidden_size @@ -38,22 +45,62 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: * 2 ) # bf16 + def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + """Simulate output layer using GEMM backend (vocab projection GEMM).""" + tp_size = self.config.model_parallel_config.tensor_model_parallel_size + cp_size = self.config.model_parallel_config.context_model_parallel_size + batch_tokens = batch_size * seq_len // tp_size // cp_size + hidden_size = self.config.model_config.hidden_size + vocab_size = self.config.model_config.padded_vocab_size + + # Output projection GEMM fwd: [batch_tokens, hidden_size] x [hidden_size, vocab_size] + fwd_result = self._gemm_backend.simulate_gemm( + m=batch_tokens, + n=vocab_size, + k=hidden_size, + dtype="bf16", + ) + fwd_time = fwd_result.forward_time_ms + + # Backward: simulate actual dgrad + wgrad GEMMs + # dgrad: [batch_tokens, vocab_size] x [vocab_size, hidden_size] -> [batch_tokens, hidden_size] + dgrad_result = self._gemm_backend.simulate_gemm( + m=batch_tokens, + n=hidden_size, + k=vocab_size, + dtype="bf16", + ) + # wgrad: [hidden_size, batch_tokens] x [batch_tokens, vocab_size] -> [hidden_size, 
vocab_size] + wgrad_result = self._gemm_backend.simulate_gemm( + m=hidden_size, + n=vocab_size, + k=batch_tokens, + dtype="bf16", + ) + bwd_time = dgrad_result.forward_time_ms + wgrad_result.forward_time_ms + + activation_memory = self.estimated_activation_memory(batch_size, seq_len) + return (fwd_time, bwd_time, activation_memory) + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: - # Context parallel / Sequence parallel adjustment - cp_size = self.config.model_parallel_config.context_model_parallel_size - # Effective sequence length per rank if CP is used - slen_per_cp = seq_len // cp_size - - self._cached_results = benchmark_layer( - self.module, - [ - (slen_per_cp, batch_size, self.config.model_config.hidden_size), - ], - ) + if self._gemm_backend is not None: + self._cached_results = self._get_simulated_results(batch_size, seq_len) + else: + # Context parallel / Sequence parallel adjustment + cp_size = self.config.model_parallel_config.context_model_parallel_size + # Effective sequence length per rank if CP is used + slen_per_cp = seq_len // cp_size + + self._cached_results = benchmark_layer( + self.module, + [ + (slen_per_cp, batch_size, self.config.model_config.hidden_size), + ], + ) self._cache_key = cache_key return self._cached_results diff --git a/primus/core/projection/module_profilers/transformer_layer.py b/primus/core/projection/module_profilers/transformer_layer.py index 4fa5ce281..59f593615 100644 --- a/primus/core/projection/module_profilers/transformer_layer.py +++ b/primus/core/projection/module_profilers/transformer_layer.py @@ -4,13 +4,16 @@ # See LICENSE for license information. ############################################################################### +import os from typing import Optional from primus.core.projection.base_module_profiler import BaseModuleProfiler from primus.core.projection.profiler_spec import ModuleProfilerSpec from primus.core.projection.training_config import TrainingConfig +from . import collective_model as cm from .attention import AttentionProfiler +from .collective_args import get_default_args from .dense_mlp import DenseMLPProfiler from .layer_norm import LayerNormProfiler from .moe_mlp import MoEMLPProfiler @@ -18,6 +21,105 @@ from .router import RouterProfiler from .utils import benchmark_layer + +def _estimate_tp_allreduce_time_ms(config, batch_size: int, seq_len: int) -> float: + """ + Estimate TP AllReduce time for a single AllReduce operation (in ms). + + In Megatron-style tensor parallelism each transformer layer performs + 2 AllReduces in forward (after attention row-parallel output projection, + after MLP row-parallel down projection) and 2 in backward. + With sequence parallelism the AllReduce is replaced by + ReduceScatter + AllGather pairs, but the total data volume is equivalent. + + Returns 0.0 when TP <= 1 (no communication needed). 
+ """ + tp = config.model_parallel_config.tensor_model_parallel_size + if tp <= 1: + return 0.0 + + cp = getattr(config.model_parallel_config, "context_parallel_size", 1) or 1 + pp = config.model_parallel_config.pipeline_model_parallel_size + ep = getattr(config.model_parallel_config, "expert_model_parallel_size", 1) or 1 + hidden_size = config.model_config.hidden_size + + # Message size: activations after row-parallel projection + # Shape: [batch_size * seq_len / CP, hidden_size], BF16 (2 bytes) + message_size_bytes = batch_size * seq_len * hidden_size * 2 // cp + + # Setup collective communication args + gpus_per_node = int(os.environ.get("GPUS_PER_NODE", "8")) + num_nodes = int(os.environ.get("NNODES", "1")) + + coll_args = get_default_args( + num_nodes=num_nodes, + gpus_per_node=gpus_per_node, + tp=tp, + pp=pp, + ep=ep, + cp=cp, + ) + + # TP AllReduce is across tp ranks (typically intra-node) + ar_time_us = cm.allreduce(coll_args, message_size_bytes, tp, groups=["tp"]) + return ar_time_us / 1000.0 # Convert microseconds → milliseconds + + +def _estimate_moe_a2a_time_ms(config, batch_size: int, seq_len: int) -> float: + """ + Estimate MoE All-to-All time (dispatch + combine) per layer per direction (in ms). + + Each MoE layer performs two A2A operations per direction: + 1. **Dispatch**: scatter tokens to the EP ranks that own the assigned experts. + 2. **Combine**: gather expert outputs back to the originating ranks. + + In benchmark mode this cost is captured inside the measured layer time. + In simulation mode we must add it explicitly because the layer profiler + only simulates GEMM / SDPA compute and TP AllReduce. + + Returns 0.0 when EP <= 1 (all experts are local, no A2A needed). + """ + ep = getattr(config.model_parallel_config, "expert_model_parallel_size", 1) or 1 + if ep <= 1: + return 0.0 + + tp = config.model_parallel_config.tensor_model_parallel_size + pp = config.model_parallel_config.pipeline_model_parallel_size + cp = getattr(config.model_parallel_config, "context_parallel_size", 1) or 1 + hidden_size = config.model_config.hidden_size + moe_router_topk = getattr(config.model_config, "moe_router_topk", 2) + + # A2A message size: each rank sends/receives all routed tokens + # Shape: [batch_size * seq_len * topk, hidden_size], BF16 (2 bytes) + tokens_per_batch = batch_size * seq_len + dispatch_size_bytes = tokens_per_batch * hidden_size * moe_router_topk * 2 + + # Setup collective communication args + gpus_per_node = int(os.environ.get("GPUS_PER_NODE", "8")) + num_nodes = int(os.environ.get("NNODES", "1")) + + coll_args = get_default_args( + num_nodes=num_nodes, + gpus_per_node=gpus_per_node, + tp=tp, + pp=pp, + ep=ep, + cp=cp, + ) + + # Propagate DeepEP setting if present (affects A2A algorithm selection) + moe_enable_deepep = getattr(config.model_parallel_config, "moe_enable_deepep", False) + use_turbo_deepep = getattr(config.model_parallel_config, "use_turbo_deepep", False) + coll_args.moe_enable_deepep = moe_enable_deepep + coll_args.use_turbo_deepep = use_turbo_deepep + + # Dispatch A2A + Combine A2A (same message size, same time) + a2a_dispatch_us = cm.alltoall(coll_args, dispatch_size_bytes, ep, groups=["ep"]) + a2a_combine_us = cm.alltoall(coll_args, dispatch_size_bytes, ep, groups=["ep"]) + + return (a2a_dispatch_us + a2a_combine_us) / 1000.0 # Convert us → ms + + # Transformer Layer Data Flow # # +----------------+ @@ -69,10 +171,31 @@ def __init__(self, config, sub_profilers=None): self.layer_module = None # Will be set during benchmarking self._cached_results = 
None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) + self._gemm_backend = None # Optional: GEMM simulation backend + self._sdpa_backend = None # Optional: SDPA simulation backend def get_sub_profiler(self, name: str): return self.sub_profilers.get(name) + def set_simulation_backends(self, gemm_backend=None, sdpa_backend=None): + """Set simulation backends and propagate to sub-profilers.""" + self._gemm_backend = gemm_backend + self._sdpa_backend = sdpa_backend + # Propagate to sub-profilers + if "self_attention" in self.sub_profilers: + attn = self.sub_profilers["self_attention"] + if gemm_backend is not None and hasattr(attn, "set_gemm_backend"): + attn.set_gemm_backend(gemm_backend) + if sdpa_backend is not None and hasattr(attn, "set_sdpa_backend"): + attn.set_sdpa_backend(sdpa_backend) + if "mlp" in self.sub_profilers: + mlp = self.sub_profilers["mlp"] + if gemm_backend is not None and hasattr(mlp, "set_gemm_backend"): + mlp.set_gemm_backend(gemm_backend) + # Invalidate cache + self._cached_results = None + self._cache_key = None + def set_layer_module(self, layer_module): """Set the actual transformer layer module for benchmarking.""" self.layer_module = layer_module @@ -99,17 +222,40 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: + self.sub_profilers["residual_add"].estimated_activation_memory(batch_size, seq_len) * 2 ) + def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + """Aggregate simulated results from sub-profilers, including TP AllReduce.""" + attn_fwd = self.sub_profilers["self_attention"].measured_forward_time(batch_size, seq_len) + attn_bwd = self.sub_profilers["self_attention"].measured_backward_time(batch_size, seq_len) + mlp_fwd = self.sub_profilers["mlp"].measured_forward_time(batch_size, seq_len) + mlp_bwd = self.sub_profilers["mlp"].measured_backward_time(batch_size, seq_len) + + # Add TP AllReduce communication overhead (simulation only). + # Each transformer layer has 2 AllReduces per direction: + # - After attention row-parallel output projection + # - After MLP row-parallel down projection + # (With sequence parallelism these become RS+AG pairs with equal volume.) 
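+        # The same per-AllReduce estimate is reused for forward and backward, hence the 2 * tp_ar_ms term added to each direction below.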
+ tp_ar_ms = _estimate_tp_allreduce_time_ms(self.config, batch_size, seq_len) + + fwd_time = attn_fwd + mlp_fwd + 2 * tp_ar_ms + bwd_time = attn_bwd + mlp_bwd + 2 * tp_ar_ms + activation_memory = self.estimated_activation_memory(batch_size, seq_len) + return (fwd_time, bwd_time, activation_memory) + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: - # Get TransformerConfig from the layer module itself (has fp8 setting) - transformer_config = getattr(self.layer_module, "config", None) - self._cached_results = benchmark_layer( - self.layer_module, - [(seq_len, batch_size, self.config.model_config.hidden_size)], - transformer_config=transformer_config, - ) + if self._gemm_backend is not None or self._sdpa_backend is not None: + # Use simulation mode + self._cached_results = self._get_simulated_results(batch_size, seq_len) + else: + # Get TransformerConfig from the layer module itself (has fp8 setting) + transformer_config = getattr(self.layer_module, "config", None) + self._cached_results = benchmark_layer( + self.layer_module, + [(seq_len, batch_size, self.config.model_config.hidden_size)], + transformer_config=transformer_config, + ) self._cache_key = cache_key return self._cached_results @@ -132,10 +278,31 @@ def __init__(self, config, sub_profilers=None): self.layer_module = None # Will be set during benchmarking self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) + self._gemm_backend = None # Optional: GEMM simulation backend + self._sdpa_backend = None # Optional: SDPA simulation backend def get_sub_profiler(self, name: str): return self.sub_profilers.get(name) + def set_simulation_backends(self, gemm_backend=None, sdpa_backend=None): + """Set simulation backends and propagate to sub-profilers.""" + self._gemm_backend = gemm_backend + self._sdpa_backend = sdpa_backend + # Propagate to sub-profilers + if "self_attention" in self.sub_profilers: + attn = self.sub_profilers["self_attention"] + if gemm_backend is not None and hasattr(attn, "set_gemm_backend"): + attn.set_gemm_backend(gemm_backend) + if sdpa_backend is not None and hasattr(attn, "set_sdpa_backend"): + attn.set_sdpa_backend(sdpa_backend) + if "mlp" in self.sub_profilers: + mlp = self.sub_profilers["mlp"] + if gemm_backend is not None and hasattr(mlp, "set_gemm_backend"): + mlp.set_gemm_backend(gemm_backend) + # Invalidate cache + self._cached_results = None + self._cache_key = None + def set_layer_module(self, layer_module): """Set the actual transformer layer module for benchmarking.""" self.layer_module = layer_module @@ -164,17 +331,52 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: + self.sub_profilers["residual_add"].estimated_activation_memory(batch_size, seq_len) * 2 ) + def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + """Aggregate simulated results from sub-profilers. + + Includes TP AllReduce and MoE All-to-All communication overhead that + would be captured in the measured layer time during benchmark mode but + must be added explicitly in simulation mode. 
+ """ + attn_fwd = self.sub_profilers["self_attention"].measured_forward_time(batch_size, seq_len) + attn_bwd = self.sub_profilers["self_attention"].measured_backward_time(batch_size, seq_len) + mlp_fwd = self.sub_profilers["mlp"].measured_forward_time(batch_size, seq_len) + mlp_bwd = self.sub_profilers["mlp"].measured_backward_time(batch_size, seq_len) + + # Add TP AllReduce communication overhead (simulation only). + # Each transformer layer has 2 AllReduces per direction: + # - After attention row-parallel output projection + # - After MLP row-parallel down projection + # (With sequence parallelism these become RS+AG pairs with equal volume.) + tp_ar_ms = _estimate_tp_allreduce_time_ms(self.config, batch_size, seq_len) + + # Add MoE All-to-All communication overhead (simulation only). + # Each MoE layer performs dispatch A2A + combine A2A per direction. + # In benchmark mode this is captured in the measured layer time; + # in simulation mode the layer profiler only computes GEMM / SDPA + # so we must add it here. + moe_a2a_ms = _estimate_moe_a2a_time_ms(self.config, batch_size, seq_len) + + fwd_time = attn_fwd + mlp_fwd + 2 * tp_ar_ms + moe_a2a_ms + bwd_time = attn_bwd + mlp_bwd + 2 * tp_ar_ms + moe_a2a_ms + activation_memory = self.estimated_activation_memory(batch_size, seq_len) + return (fwd_time, bwd_time, activation_memory) + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: - # Get TransformerConfig from the layer module itself (has fp8 setting) - transformer_config = getattr(self.layer_module, "config", None) - self._cached_results = benchmark_layer( - self.layer_module, - [(seq_len, batch_size, self.config.model_config.hidden_size)], - transformer_config=transformer_config, - ) + if self._gemm_backend is not None or self._sdpa_backend is not None: + # Use simulation mode + self._cached_results = self._get_simulated_results(batch_size, seq_len) + else: + # Get TransformerConfig from the layer module itself (has fp8 setting) + transformer_config = getattr(self.layer_module, "config", None) + self._cached_results = benchmark_layer( + self.layer_module, + [(seq_len, batch_size, self.config.model_config.hidden_size)], + transformer_config=transformer_config, + ) self._cache_key = cache_key return self._cached_results diff --git a/primus/core/projection/performance_projection/projection.py b/primus/core/projection/performance_projection/projection.py index da6e4bf2a..7ba62180d 100644 --- a/primus/core/projection/performance_projection/projection.py +++ b/primus/core/projection/performance_projection/projection.py @@ -23,14 +23,367 @@ from primus.core.projection.performance_projection.simulator import ( SchedulerSimulationRunner, ) +from primus.core.projection.simulation_backends.factory import ( + get_gemm_simulation_backend, + get_sdpa_simulation_backend, +) from primus.core.projection.training_config import ( convert_primus_config_to_projection_config, ) -from primus.modules.trainer.megatron.pre_trainer import MegatronPretrainTrainer + +# NOTE: MegatronPretrainTrainer is imported lazily inside _run_layer_benchmark() +# to avoid pulling in the megatron dependency when running in pure simulation mode +# (--profiling-mode simulate). 
_MAX_EXPERT_PARALLEL_SIZE = 8 _BYTES_PER_GB = 1024**3 +# Projection mode constants +MODE_TRAINING = "training" +MODE_INFERENCE = "inference" +MODE_PREFILL = "prefill" +MODE_DECODE = "decode" + + +def _calculate_min_gpus(tp, pp, ep, cp): + """Calculate minimum GPUs required by parallelism config. + + For MoE models (EP > 1), CP is folded into EP via MoE Parallel Folding: + the CP ranks are a subset of the EP ranks, so the minimum GPU count is + TP × PP × EP. Constraints: CP ≤ EP and EP % CP == 0. + + For dense models (EP ≤ 1), CP is an independent parallelism axis, so the + minimum GPU count is TP × PP × CP. + + Note: DP is *not* affected by this folding — EP borrows from the DP + dimension, so DP = world_size / (TP × PP × CP) in both cases. + """ + if ep > 1: + # MoE: CP is folded into EP (MoE Parallel Folding) + return tp * pp * ep + else: + # Dense: CP is an independent axis + return tp * pp * cp + + +# HBM bandwidth (GB/s) by GPU architecture — used for optimizer step estimation +_HBM_BANDWIDTH_GBPS: Dict[str, float] = { + "mi300x": 5300.0, + "gfx942": 5300.0, + "mi325x": 6000.0, + "mi355x": 8000.0, + "gfx950": 8000.0, +} + +# Peak BF16 TFLOPS per GPU — used for compute roofline in decode estimation +_PEAK_TFLOPS_BF16: Dict[str, float] = { + "mi300x": 1307.0, + "gfx942": 1307.0, + "mi325x": 1307.0, + "mi355x": 2611.0, + "gfx950": 2611.0, +} + + +def _get_hw_params(gpu_arch: Optional[str] = None): + """Return (hbm_bandwidth_gb_s, peak_tflops_bf16) for the given GPU arch.""" + arch = (gpu_arch or os.getenv("PRIMUS_GPU_ARCH", "mi300x")).lower().strip() + hbm_bw = _HBM_BANDWIDTH_GBPS.get(arch, 5300.0) + peak_tf = _PEAK_TFLOPS_BF16.get(arch, 1307.0) + return hbm_bw, peak_tf + + +# ========================================================================= +# Analytical Decode Model +# ========================================================================= + + +def _estimate_decode_time_per_token( + training_config, + decode_batch_size: int, + context_length: int, + tp: int = 1, + pp: int = 1, + ep: int = 1, + gpu_arch: Optional[str] = None, + hardware_config: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """ + Analytically estimate the time for **one decode step** (generating one token + per sequence in the batch). + + Decode is memory-bandwidth-bound for small batch sizes: the GPU must load + all model weights once per step, and read the KV cache. The computation + (tiny GEMMs of shape ``[batch, 1, …]``) is far below the compute roofline. + + Model: + time_per_token ≈ (weight_bytes + kv_cache_bytes) / HBM_bandwidth + + TP_allreduce_latency + + Args: + training_config: TrainingConfig with model / parallel config. + decode_batch_size: Number of sequences decoded concurrently. + context_length: Number of tokens already in the KV cache. + tp, pp, ep: Parallelism dimensions. + gpu_arch: GPU architecture string (for HBM BW / peak TFLOPS). + hardware_config: Optional custom HW config dict. + + Returns: + Dict with detailed breakdown (all times in milliseconds). 
+ """ + mc = training_config.model_config + hidden = mc.hidden_size + num_layers = mc.num_layers + num_heads = mc.num_attention_heads + head_dim = mc.kv_channels or (hidden // num_heads) + gqa = mc.group_query_attention + num_kv_heads = mc.num_query_groups if gqa else num_heads + ffn_hidden = mc.ffn_hidden_size or (hidden * 4) + vocab_size = mc.padded_vocab_size or 100352 + moe_pattern = mc.moe_pattern or [0] * num_layers + num_experts = mc.num_experts or 0 + moe_ffn = mc.moe_ffn_hidden_size or ffn_hidden + moe_topk = mc.moe_router_topk or 1 + shared_expert_size = mc.moe_shared_expert_intermediate_size or 0 + + bytes_per_param = 2 # BF16 + + hbm_bw_gbps, peak_tflops = _get_hw_params(gpu_arch) + hbm_bw_bytes_per_ms = hbm_bw_gbps * 1e9 / 1e3 # bytes / ms + peak_flops_per_ms = peak_tflops * 1e12 / 1e3 # FLOPs / ms + + # Per-layer weight bytes (sharded by TP) + # Attention: Q(h→h) + K(h→kv_dim) + V(h→kv_dim) + O(h→h) + kv_dim = num_kv_heads * head_dim + attn_weight_bytes = ( + (hidden * hidden + hidden * kv_dim + hidden * kv_dim + hidden * hidden) # Q # K # V # O + * bytes_per_param + // tp + ) + + # Dense MLP: gate(h→ffn) + up(h→ffn) + down(ffn→h) (SwiGLU has gate+up) + dense_mlp_weight_bytes = 3 * hidden * ffn_hidden * bytes_per_param // tp + + # MoE MLP weights per GPU (num_experts/EP experts, each gate+up+down, divided by expert_TP) + expert_tp = 1 # expert TP typically 1 + if num_experts > 0: + experts_per_gpu = max(num_experts // max(ep, 1), 1) + moe_mlp_weight_bytes = 3 * hidden * moe_ffn * experts_per_gpu * bytes_per_param // expert_tp + # Router weight: (hidden → num_experts) + router_weight_bytes = hidden * num_experts * bytes_per_param + # Shared expert (if any) + shared_expert_weight_bytes = ( + 3 * hidden * shared_expert_size * bytes_per_param // tp if shared_expert_size > 0 else 0 + ) + else: + moe_mlp_weight_bytes = 0 + router_weight_bytes = 0 + shared_expert_weight_bytes = 0 + + # KV cache bytes per layer per decode step + # Read K: batch × num_kv_heads/TP × context_len × head_dim × bytes + # Read V: same + kv_heads_per_gpu = max(num_kv_heads // tp, 1) + kv_cache_per_layer_bytes = ( + 2 * decode_batch_size * kv_heads_per_gpu * context_length * head_dim * bytes_per_param + ) + + # Compute FLOPs per layer per decode step (batch × 1 token) + # Linear projections: 2 * batch * M * N (M=1, N=weight cols) + attn_proj_flops = ( + 2 + * decode_batch_size + * (hidden * hidden + hidden * kv_dim + hidden * kv_dim + hidden * hidden) # Q # K # V # O + ) // tp + dense_mlp_flops = 2 * decode_batch_size * 3 * hidden * ffn_hidden // tp + + # Attention score + V multiply: batch × num_heads/TP × context_len × head_dim × 2 + heads_per_gpu = max(num_heads // tp, 1) + attn_score_flops = 2 * decode_batch_size * heads_per_gpu * context_length * head_dim # Q·K^T + attn_v_flops = 2 * decode_batch_size * heads_per_gpu * context_length * head_dim # score·V + attn_kv_flops = attn_score_flops + attn_v_flops + + # MoE compute: each token routed to topk experts + if num_experts > 0: + moe_mlp_flops = 2 * decode_batch_size * moe_topk * 3 * hidden * moe_ffn // max(ep, 1) + else: + moe_mlp_flops = 0 + + # ---- Aggregate across all layers ---- + num_dense_layers = sum(1 for p in moe_pattern if p == 0) + num_moe_layers = sum(1 for p in moe_pattern if p == 1) + layers_per_pp = num_layers // max(pp, 1) + + total_weight_bytes = ( + attn_weight_bytes * num_layers # attention in every layer + + dense_mlp_weight_bytes * num_dense_layers + + (moe_mlp_weight_bytes + router_weight_bytes + shared_expert_weight_bytes) * 
num_moe_layers + ) // max(pp, 1) + + total_kv_bytes = kv_cache_per_layer_bytes * layers_per_pp + + # Embedding + output layer weights + embedding_weight_bytes = vocab_size * hidden * bytes_per_param // tp # only on first PP stage + output_weight_bytes = vocab_size * hidden * bytes_per_param // tp # only on last PP stage + # Amortise across PP stages + total_weight_bytes += (embedding_weight_bytes + output_weight_bytes) // max(pp, 1) + + total_compute_flops = ( + (attn_proj_flops + attn_kv_flops) * layers_per_pp + + dense_mlp_flops * (num_dense_layers // max(pp, 1)) + + moe_mlp_flops * (num_moe_layers // max(pp, 1)) + ) + + # ---- Roofline: memory time vs compute time ---- + memory_time_ms = (total_weight_bytes + total_kv_bytes) / hbm_bw_bytes_per_ms + compute_time_ms = total_compute_flops / peak_flops_per_ms + + # TP AllReduce latency (2× per layer: after attention + after MLP) + # For decode, messages are tiny (batch × hidden × 2 bytes) — latency-dominated + tp_latency_per_ar_us = 5.0 # ~5 µs typical intra-node AllReduce latency + if hardware_config: + tp_latency_per_ar_us = hardware_config.get("intra_node_latency_us", 1.0) * 3 + tp_allreduce_count = 2 * layers_per_pp # attn + MLP per layer + tp_overhead_ms = (tp_allreduce_count * tp_latency_per_ar_us / 1000) if tp > 1 else 0 + + # PP overhead (serial forward through PP stages, activation P2P) + pp_p2p_latency_us = 5.0 + if hardware_config: + pp_p2p_latency_us = hardware_config.get("intra_node_latency_us", 1.0) * 3 + pp_overhead_ms = ((pp - 1) * pp_p2p_latency_us / 1000) if pp > 1 else 0 + + # EP All-to-All for MoE decode (tiny messages, latency-dominated) + ep_a2a_latency_us = 10.0 + if hardware_config: + ep_a2a_latency_us = hardware_config.get("intra_node_latency_us", 1.0) * 5 + moe_a2a_per_layer_ms = (2 * ep_a2a_latency_us / 1000) if ep > 1 else 0 # dispatch + combine + moe_a2a_total_ms = moe_a2a_per_layer_ms * (num_moe_layers // max(pp, 1)) + + # Total decode time per token + decode_time_ms = max(memory_time_ms, compute_time_ms) + tp_overhead_ms + pp_overhead_ms + moe_a2a_total_ms + + # Arithmetic intensity (FLOPs / byte) — shows how memory-bound we are + total_bytes_accessed = total_weight_bytes + total_kv_bytes + arith_intensity = total_compute_flops / total_bytes_accessed if total_bytes_accessed > 0 else 0 + # Machine balance point: peak_flops / hbm_bw + balance_point = (peak_tflops * 1e12) / (hbm_bw_gbps * 1e9) + + return { + "decode_time_ms": decode_time_ms, + "memory_time_ms": memory_time_ms, + "compute_time_ms": compute_time_ms, + "tp_overhead_ms": tp_overhead_ms, + "pp_overhead_ms": pp_overhead_ms, + "moe_a2a_total_ms": moe_a2a_total_ms, + "total_weight_bytes": total_weight_bytes, + "total_kv_bytes": total_kv_bytes, + "total_weight_mb": total_weight_bytes / (1024 * 1024), + "total_kv_mb": total_kv_bytes / (1024 * 1024), + "total_compute_tflops": total_compute_flops / 1e12, + "arithmetic_intensity": arith_intensity, + "balance_point": balance_point, + "is_memory_bound": arith_intensity < balance_point, + "hbm_bw_gbps": hbm_bw_gbps, + "peak_tflops": peak_tflops, + "decode_batch_size": decode_batch_size, + "context_length": context_length, + "layers_per_pp": layers_per_pp, + "num_dense_layers": num_dense_layers, + "num_moe_layers": num_moe_layers, + } + + +def _estimate_optimizer_step_ms( + training_config, + dp_size: int = 1, + gpu_arch: Optional[str] = None, +) -> float: + """ + Estimate the optimizer step time (Adam/AdamW) for one training iteration. + + The optimizer step is HBM-bandwidth-bound. 
For each parameter, + Adam reads/writes the following (mixed-precision training with BF16 + forward, FP32 master weights): + + Read: FP32 master param (4B) + FP32 grad (4B) + m (4B) + v (4B) = 16 B + Write: FP32 master param (4B) + m (4B) + v (4B) + BF16 param (2B) = 14 B + Total: 30 bytes per parameter + + With distributed_optimizer or FSDP, optimizer state is sharded across + DP ranks, so each GPU only updates ``N_params / dp_size`` parameters. + + Returns: + Optimizer step time in milliseconds. + """ + model_config = training_config.model_config + mp_config = training_config.model_parallel_config + + # --- Count total parameters per GPU (post TP/PP/EP sharding) ---------- + hidden = model_config.hidden_size + ffn_hidden = model_config.ffn_hidden_size or (hidden * 4) + moe_ffn = model_config.moe_ffn_hidden_size or ffn_hidden + num_layers = model_config.num_layers + num_experts = model_config.num_experts or 0 + moe_pattern = model_config.moe_pattern # list of 0/1 + num_moe_layers = sum(1 for p in moe_pattern if p == 1) + num_dense_layers = num_layers - num_moe_layers + + tp = mp_config.tensor_model_parallel_size + pp = mp_config.pipeline_model_parallel_size + ep = getattr(mp_config, "expert_model_parallel_size", 1) or 1 + + # Attention params: Q, K, V, O -> 4 * h * h (per layer, sharded by TP) + attn_params_per_layer = 4 * hidden * hidden // tp + # Dense MLP: gate, up, down -> 3 * h * ffn (per layer, sharded by TP) + dense_mlp_params_per_layer = 3 * hidden * ffn_hidden // tp + # Expert MLP params per expert: 3 * h * moe_ffn (NOT sharded by TP normally) + expert_tp = getattr(mp_config, "expert_tensor_parallel_size", None) or 1 + expert_mlp_params_per_expert = 3 * hidden * moe_ffn // expert_tp + + # Non-expert params across all layers (sharded by TP, PP) + non_expert_params = num_layers * attn_params_per_layer + num_dense_layers * dense_mlp_params_per_layer + # Expert params (sharded by EP, expert_TP, PP) + expert_params = num_moe_layers * num_experts * expert_mlp_params_per_expert // max(ep, 1) + + # Shared experts (if any) + shared_sz = getattr(model_config, "moe_shared_expert_intermediate_size", 0) or 0 + shared_expert_params = 0 + if shared_sz and num_moe_layers > 0: + shared_expert_params = num_moe_layers * 3 * hidden * shared_sz // tp + + total_params_per_gpu = (non_expert_params + expert_params + shared_expert_params) // pp + + # Embedding + output layer params (only on first / last PP rank, amortise) + vocab_size = getattr(model_config, "vocab_size", 0) or 0 + if vocab_size and pp > 0: + embedding_params = vocab_size * hidden // tp + output_params = vocab_size * hidden // tp + # Amortise across PP ranks (only 1 rank holds each) + total_params_per_gpu += (embedding_params + output_params) // pp + + # --- Distributed optimizer / FSDP sharding --- + use_distributed_optimizer = getattr(mp_config, "use_distributed_optimizer", False) + use_fsdp = getattr(mp_config, "use_torch_fsdp2", False) + + if use_distributed_optimizer or use_fsdp: + # Optimizer state is sharded across DP ranks + params_for_optim = total_params_per_gpu // max(dp_size, 1) + else: + params_for_optim = total_params_per_gpu + + # --- Compute time --- + bytes_per_param = 30 # Adam read+write (mixed precision) + total_bytes = params_for_optim * bytes_per_param + + # Look up HBM bandwidth + arch = (gpu_arch or os.getenv("PRIMUS_GPU_ARCH", "mi300x")).lower().strip() + hbm_bw_gbps = _HBM_BANDWIDTH_GBPS.get(arch, 5300.0) + hbm_bw_bytes_per_ms = hbm_bw_gbps * 1e9 / 1e3 # bytes/ms + + optimizer_time_ms = total_bytes / 
hbm_bw_bytes_per_ms + + return optimizer_time_ms + # ============================================================================= # Hardware and Communication Functions (moved from multinode_projection) @@ -75,6 +428,13 @@ def calculate_collective_communication_time( hardware_config=hardware_config, ) + # Check if DeepEP is enabled and pass it to coll_args + mp_config = training_config.model_parallel_config + moe_enable_deepep = getattr(mp_config, "moe_enable_deepep", False) + use_turbo_deepep = getattr(mp_config, "use_turbo_deepep", False) + coll_args.moe_enable_deepep = moe_enable_deepep + coll_args.use_turbo_deepep = use_turbo_deepep + # Model parameters hidden_size = model_config.hidden_size num_layers = model_config.num_layers @@ -109,7 +469,7 @@ def calculate_collective_communication_time( message_info = {} per_layer_info = [] # Store per-layer communication details - # Check if FSDP is enabled (needed to determine gradient sync strategy) + # Get model parallel config (already retrieved above for DeepEP check) mp_config = training_config.model_parallel_config # Note: use_torch_fsdp2 = True means actual FSDP (shards weights, uses all-gather/reduce-scatter) # use_distributed_optimizer = True means ZeRO-1 style (shards optimizer state only, uses all-reduce) @@ -168,9 +528,11 @@ def calculate_collective_communication_time( message_info["moe_ar_no_overlap"] = False # 2. MoE All-to-All (EP group) + # With TP > 1 and sequence parallelism, each GPU holds S/TP tokens. + # The A2A dispatches these S/TP tokens; AG(TP) recovers full S after A2A. if ep > 1 and num_moe_layers > 0: - tokens_per_batch = seq_len * batch_size - dispatch_size = tokens_per_batch * hidden_size * moe_router_topk * 2 # BF16 + tokens_per_gpu = seq_len * batch_size // max(tp, 1) # S/TP with seq parallel + dispatch_size = tokens_per_gpu * hidden_size * moe_router_topk * 2 # BF16 a2a_dispatch = cm.alltoall(coll_args, dispatch_size, ep, groups=["ep"]) a2a_combine = cm.alltoall(coll_args, dispatch_size, ep, groups=["ep"]) @@ -199,6 +561,10 @@ def calculate_collective_communication_time( # FSDP shards weights across DP ranks. Each layer needs: # - Forward: All-gather to reconstruct full weights # - Backward: Reduce-scatter to distribute gradients back to shards + # + # Recompute correction: with recompute_granularity="full", every backward + # layer recomputes its forward pass, requiring a SECOND AllGather per + # layer to re-fetch the sharded weights. # Note: use_fsdp and mp_config already defined above if use_fsdp and dp > 1: @@ -214,22 +580,36 @@ def calculate_collective_communication_time( # All-gather: each rank sends its shard (1/DP), receives full weights # Total data moved = weight_size * (DP-1)/DP per rank - ag_time_per_layer = cm.allgather(coll_args, weight_size_per_layer, dp, groups=["dp"]) + ag_time_per_layer_us = cm.allgather(coll_args, weight_size_per_layer, dp, groups=["dp"]) # Reduce-scatter: each rank sends full gradients, receives its shard - # Gradients are in FP32 for optimizer (4 bytes), but reduce-scatter often uses BF16 grad_size_per_layer = params_per_dense_layer * 2 # BF16 gradients for communication - rs_time_per_layer = cm.reduce_scatter(coll_args, grad_size_per_layer, dp, groups=["dp"]) + rs_time_per_layer_us = cm.reduce_scatter(coll_args, grad_size_per_layer, dp, groups=["dp"]) + + # --- Recompute correction --- + # With recompute_granularity="full", during the backward pass each layer + # re-runs its forward pass. 
This means the weights must be AllGathered + # AGAIN for each recomputed layer (the first AG result was freed after + # the initial forward). The ReduceScatter count is unchanged (1 per + # layer backward). + recompute_gran = getattr(mp_config, "recompute_granularity", None) + recomp_n_layers = getattr(mp_config, "recompute_num_layers", 0) or 0 + ag_multiplier = 1 # default: AG once per layer (forward) + if recompute_gran == "full" and recomp_n_layers > 0: + # Each recomputed layer needs a second AG in backward + recomp_ratio = min(recomp_n_layers, num_layers) / num_layers + ag_multiplier = 1 + recomp_ratio # e.g. 2.0 when all layers recomputed # Calculate total FSDP time for all layers - total_fsdp_ag_fwd = (ag_time_per_layer * num_layers) / 1000 # ms - total_fsdp_rs_bwd = (rs_time_per_layer * num_layers) / 1000 # ms + total_fsdp_ag_fwd = (ag_time_per_layer_us * num_layers * ag_multiplier) / 1000 # ms + total_fsdp_rs_bwd = (rs_time_per_layer_us * num_layers) / 1000 # ms breakdown["fsdp_allgather_fwd"] = total_fsdp_ag_fwd breakdown["fsdp_reducescatter_bwd"] = total_fsdp_rs_bwd message_info["fsdp_weight_size_per_layer_mb"] = weight_size_per_layer / (1024 * 1024) - message_info["fsdp_ag_per_layer_ms"] = ag_time_per_layer / 1000 - message_info["fsdp_rs_per_layer_ms"] = rs_time_per_layer / 1000 + message_info["fsdp_ag_per_layer_ms"] = ag_time_per_layer_us / 1000 + message_info["fsdp_rs_per_layer_ms"] = rs_time_per_layer_us / 1000 + message_info["fsdp_ag_multiplier"] = ag_multiplier message_info["fsdp_enabled"] = True else: breakdown["fsdp_allgather_fwd"] = 0.0 @@ -275,38 +655,23 @@ def calculate_collective_communication_time( else: message_info["gradient_allreduce_overlapped"] = False - # Check if FSDP communication can be overlapped - # In FSDP2, prefetch can overlap all-gather with compute of current layer - # Reduce-scatter can overlap with forward of next microbatch - # However, overlap is NOT 100%: - # - First layer's all-gather cannot overlap (nothing before it) - # - Last layer's reduce-scatter cannot overlap (nothing after it) - # - There's always some exposed communication at boundaries if use_fsdp and dp > 1: - overlap_fsdp = getattr(mp_config, "use_torch_fsdp2", False) # FSDP2 has better overlap + overlap_fsdp = getattr(mp_config, "use_torch_fsdp2", False) if overlap_fsdp: - # Calculate per-layer times - fsdp_ag_per_layer = message_info.get("fsdp_ag_per_layer_ms", 0) - fsdp_rs_per_layer = message_info.get("fsdp_rs_per_layer_ms", 0) - - # Exposed time: first layer's all-gather + last layer's reduce-scatter - # Plus some overhead from imperfect pipelining (~10-20% of remaining) - exposed_ag = fsdp_ag_per_layer # First layer cannot overlap - exposed_rs = fsdp_rs_per_layer # Last layer cannot overlap - remaining_ag = breakdown.get("fsdp_allgather_fwd", 0) - exposed_ag - remaining_rs = breakdown.get("fsdp_reducescatter_bwd", 0) - exposed_rs - - # Assume ~70% overlap efficiency for the rest (conservative for multi-node) - overlap_efficiency = 0.7 - hidden_ag = remaining_ag * overlap_efficiency - hidden_rs = remaining_rs * overlap_efficiency - - total_comm_time -= hidden_ag - total_comm_time -= hidden_rs + total_fsdp_ag = breakdown.get("fsdp_allgather_fwd", 0) + total_fsdp_rs = breakdown.get("fsdp_reducescatter_bwd", 0) + + # Overlap factor applied uniformly to all FSDP + # communication - (AllGather fwd, AllGather recompute, ReduceScatter). 
+ FSDP_OVERLAP = 0.93 + + total_fsdp = total_fsdp_ag + total_fsdp_rs + total_hidden = total_fsdp * FSDP_OVERLAP + total_comm_time -= total_hidden message_info["fsdp_overlapped"] = True - message_info["fsdp_exposed_ms"] = ( - exposed_ag + exposed_rs + (remaining_ag + remaining_rs) * (1 - overlap_efficiency) - ) + message_info["fsdp_overlap"] = FSDP_OVERLAP + message_info["fsdp_overall_overlap"] = FSDP_OVERLAP + message_info["fsdp_exposed_ms"] = total_fsdp - total_hidden else: message_info["fsdp_overlapped"] = False @@ -452,6 +817,99 @@ def extract_single_node_time_from_profiling(profiling_results: dict, training_co return total_time_ms +def extract_single_node_time_inference(profiling_results: dict, training_config) -> float: + """ + Extract total single-node **forward-only** time from profiling results. + + This is the inference counterpart of :func:`extract_single_node_time_from_profiling`. + It uses only the forward pass timings (no backward, no recomputation overhead). + + Args: + profiling_results: Dict with integer keys for layers and "embedding", "output" + training_config: Training configuration containing model config + + Returns: + Total forward-only time in milliseconds for the full model (one pass) + """ + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + + if is_rank_0: + print("[Primus:Inference Projection] Extracting forward-only timing from benchmark results...") + print("-" * 100) + + model_config = training_config.model_config + moe_pattern = model_config.moe_pattern + + num_total_layers = len(moe_pattern) + + profiled_layer_indices = sorted([k for k in profiling_results.keys() if isinstance(k, int)]) + if is_rank_0: + print(f" Profiled layers: {profiled_layer_indices}") + print(f" Full model has {num_total_layers} transformer layers") + + total_time_ms = 0.0 + + # Embedding layer (forward only) + if "embedding" in profiling_results: + emb = profiling_results["embedding"] + emb_time = emb.get("forward_time_ms", 0) + total_time_ms += emb_time + if is_rank_0: + print(f" Embedding (fwd): {emb_time:.2f} ms") + + # Analyze profiled transformer layers — forward only + profiled_dense_fwd = [] + profiled_moe_fwd = [] + + for layer_idx in profiled_layer_indices: + if layer_idx < len(moe_pattern): + layer_data = profiling_results[layer_idx] + fwd_time = layer_data.get("forward_time_ms", 0) + + if moe_pattern[layer_idx] == 0: + profiled_dense_fwd.append(fwd_time) + else: + profiled_moe_fwd.append(fwd_time) + + avg_dense_fwd = sum(profiled_dense_fwd) / len(profiled_dense_fwd) if profiled_dense_fwd else 0 + avg_moe_fwd = sum(profiled_moe_fwd) / len(profiled_moe_fwd) if profiled_moe_fwd else 0 + + num_dense_layers = sum(1 for x in moe_pattern if x == 0) + num_moe_layers = sum(1 for x in moe_pattern if x == 1) + + total_dense_time = avg_dense_fwd * num_dense_layers + total_moe_time = avg_moe_fwd * num_moe_layers + total_transformer_time = total_dense_time + total_moe_time + total_time_ms += total_transformer_time + + if is_rank_0: + if profiled_dense_fwd: + print(f" Dense Layers: {len(profiled_dense_fwd)} profiled → {num_dense_layers} total") + print(f" Avg fwd per layer: {avg_dense_fwd:.2f} ms") + print(f" Total fwd time: {total_dense_time:.2f} ms") + + if profiled_moe_fwd: + print(f" MoE Layers: {len(profiled_moe_fwd)} profiled → {num_moe_layers} total") + print(f" Avg fwd per layer: {avg_moe_fwd:.2f} ms") + print(f" Total fwd time: {total_moe_time:.2f} ms") + + # Output layer (forward only) + if "output" in profiling_results: + out = profiling_results["output"] + out_time = 
out.get("forward_time_ms", 0) + total_time_ms += out_time + if is_rank_0: + print(f" Output Layer (fwd): {out_time:.2f} ms") + + if is_rank_0: + print("-" * 100) + print(f"[Primus:Inference Projection] Extrapolated Forward-Only Time: {total_time_ms:.2f} ms") + print(f" (Based on {len(profiled_layer_indices)} profiled layers → {num_total_layers} total layers)") + print("=" * 100) + + return total_time_ms + + # ============================================================================= # Layer Configuration Functions # ============================================================================= @@ -514,7 +972,10 @@ def _limit_layers_for_projection(module_config): def _rescale_expert_parallelism(module_config): """ - Cap expert_model_parallel_size so that EP * TP * CP <= 8 and adjust num_experts. + Cap expert_model_parallel_size so that EP * TP <= GPUs_per_node and adjust num_experts. + + With MoE Parallel Folding, CP is folded into EP (CP ranks are a subset of + EP ranks), so the minimum GPUs for a MoE config is EP * TP, not EP * TP * CP. """ expert_mp_size = getattr(module_config, "expert_model_parallel_size", None) if expert_mp_size is None or expert_mp_size <= _MAX_EXPERT_PARALLEL_SIZE: @@ -522,13 +983,16 @@ def _rescale_expert_parallelism(module_config): current_cp = getattr(module_config, "context_parallel_size", 1) or 1 if expert_mp_size is None: expert_mp_size = 1 - if expert_mp_size * current_tp * current_cp <= _MAX_EXPERT_PARALLEL_SIZE: + # MoE Parallel Folding: CP is folded into EP, so min GPUs = EP * TP + if expert_mp_size * current_tp <= _MAX_EXPERT_PARALLEL_SIZE: return None num_experts = getattr(module_config, "num_experts", None) current_tp = getattr(module_config, "tensor_model_parallel_size", 1) or 1 current_cp = getattr(module_config, "context_parallel_size", 1) or 1 - total_parallel_product = max(1, current_tp * current_cp) + # MoE Parallel Folding: CP is folded into EP, so only TP contributes to + # the per-EP-rank GPU cost. 
+ total_parallel_product = max(1, current_tp) max_ep_allowed = max(1, _MAX_EXPERT_PARALLEL_SIZE // total_parallel_product) new_expert_mp = min(expert_mp_size, _MAX_EXPERT_PARALLEL_SIZE, max_ep_allowed) @@ -592,8 +1056,9 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): pp = getattr(original_config, "pipeline_model_parallel_size", 1) or 1 ep = getattr(original_config, "expert_model_parallel_size", 1) or 1 cp = getattr(original_config, "context_parallel_size", 1) or 1 + num_experts = getattr(original_config, "num_experts", None) - gpus_required = tp * pp * ep * cp + gpus_required = _calculate_min_gpus(tp, pp, ep, cp) nodes_required = (gpus_required + gpus_per_node - 1) // gpus_per_node # If already fits on 1 node, no adjustment needed @@ -607,14 +1072,17 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): "original_ep": ep, "benchmark_ep": ep, "original_cp": cp, + "original_num_experts": num_experts, + "benchmark_num_experts": num_experts, } # Step 1: Reduce PP to 1 benchmark_pp = 1 - benchmark_gpus_required = tp * benchmark_pp * ep * cp + benchmark_gpus_required = _calculate_min_gpus(tp, benchmark_pp, ep, cp) # Step 2: If still doesn't fit, rescale EP benchmark_ep = ep + benchmark_num_experts = num_experts if benchmark_gpus_required > gpus_per_node: print( f"[Primus:Performance Projection] After reducing PP to 1, " @@ -626,7 +1094,8 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): rescale_info = _rescale_expert_parallelism(original_config) if rescale_info: benchmark_ep = rescale_info["ep_after"] - benchmark_gpus_required = tp * benchmark_pp * benchmark_ep * cp + benchmark_num_experts = rescale_info.get("num_experts_after", num_experts) + benchmark_gpus_required = _calculate_min_gpus(tp, benchmark_pp, benchmark_ep, cp) if benchmark_gpus_required > gpus_per_node: raise ValueError( @@ -666,6 +1135,8 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): "original_ep": ep, "benchmark_ep": benchmark_ep, "original_cp": cp, + "original_num_experts": num_experts, + "benchmark_num_experts": benchmark_num_experts, } @@ -694,7 +1165,7 @@ def _estimate_pp_communication_overhead(training_config, pp_size, hardware_confi # Get hardware setup gpus_per_node = int(os.getenv("GPUS_PER_NODE", "8")) - gpus_required = tp * pp_size * ep * cp + gpus_required = _calculate_min_gpus(tp, pp_size, ep, cp) num_nodes = (gpus_required + gpus_per_node - 1) // gpus_per_node # Get collective model args @@ -717,8 +1188,9 @@ def _estimate_pp_communication_overhead(training_config, pp_size, hardware_confi p2p_size = batch_size * seq_len * hidden_size * 2 # BF16 # Number of microbatches + # DP = world_size / (TP × PP × CP) — EP excluded (borrows from DP via folding) global_batch_size = runtime_config.global_batch_size - data_parallel_size = (num_nodes * gpus_per_node) // (tp * pp_size * ep * cp) + data_parallel_size = (num_nodes * gpus_per_node) // (tp * pp_size * cp) num_microbatches = global_batch_size // (batch_size * data_parallel_size) # P2P time: 2 * (PP-1) sends per microbatch (forward + backward) @@ -733,47 +1205,69 @@ def _estimate_pp_communication_overhead(training_config, pp_size, hardware_confi return total_p2p_time_ms -def _compute_ep_mlp_scale(model_config, benchmark_ep, original_ep): +def _compute_ep_mlp_scale( + model_config, + benchmark_ep, + original_ep, + original_num_experts=None, + benchmark_num_experts=None, +): """ Compute the MLP time scaling factor when EP changes, accounting for - shared experts (EP-independent) vs routed 
experts (EP-dependent). + shared experts (EP-independent) vs routed experts. + + Key insight — per-GPU routed compute is EP-invariant + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + In Megatron MoE each GPU in the EP group processes the **same** + micro-batch. After A2A token redistribution every GPU ends up + computing ``batch_tokens × topk`` total token-expert pairs, + regardless of EP. Therefore: + + * When ``_rescale_expert_parallelism`` adjusts ``num_experts`` + proportionally (preserving ``experts_per_rank``), the profiled / + simulated MLP time already reflects the correct per-rank workload + → **no scaling needed** (returns 1.0). - In Megatron MoE: - - Routed expert compute per GPU ∝ (topk / EP) × moe_ffn_hidden_size - - Shared expert compute is constant regardless of EP + * When EP changes but ``num_experts`` stays fixed (hypothetical — + our rescaling always preserves experts_per_rank), the GEMM shapes + change (fewer, larger GEMMs vs. more, smaller GEMMs) but total + FLOPs remain identical. We conservatively return 1.0 because the + simple ``benchmark_ep / original_ep`` ratio does not capture GEMM- + efficiency differences. - The profiled MLP time at benchmark_ep includes both. When scaling to - original_ep, only the routed portion changes. + Shared expert compute is constant regardless of EP (no A2A needed). + + Args: + model_config: Model configuration. + benchmark_ep: EP used during profiling / simulation. + original_ep: EP of the target deployment. + original_num_experts: Total expert count in the target config. + benchmark_num_experts: Total expert count after rescaling + (may differ from ``original_num_experts`` if + ``_rescale_expert_parallelism`` adjusted it). Returns: float: Scale factor to apply to the profiled MLP time. + 1.0 when experts_per_rank is preserved (the common case). """ - topk = getattr(model_config, "moe_router_topk", 1) or 1 - moe_ffn = getattr(model_config, "moe_ffn_hidden_size", None) - shared_ffn = getattr(model_config, "moe_shared_expert_intermediate_size", None) - # Derive num_shared_experts: explicit attribute, or infer from - # moe_shared_expert_intermediate_size // moe_ffn_hidden_size - num_shared = getattr(model_config, "num_shared_experts", 0) or 0 - if num_shared == 0 and shared_ffn and moe_ffn: - num_shared = shared_ffn // moe_ffn - - if not moe_ffn or num_shared == 0 or not shared_ffn: - # No shared experts — all MLP compute is routed, scales with 1/EP - return benchmark_ep / original_ep - - # FLOPs proportional to tokens × ffn_size - # Routed: (topk / benchmark_ep) tokens per expert-slot, through moe_ffn - # Shared: all tokens (1.0), through shared_ffn - routed_flops = (topk / benchmark_ep) * moe_ffn - shared_flops = num_shared * shared_ffn - total_flops = routed_flops + shared_flops - - routed_fraction = routed_flops / total_flops - shared_fraction = shared_flops / total_flops - - # Routed portion scales by benchmark_ep / original_ep; shared stays constant - scale = shared_fraction + routed_fraction * (benchmark_ep / original_ep) - return scale + if benchmark_ep == original_ep: + return 1.0 + + # Determine whether experts_per_rank was preserved by rescaling. + # When it is, per-GPU routed compute is identical — no scaling. + if original_num_experts is not None and benchmark_num_experts is not None: + orig_epr = original_num_experts / original_ep + bench_epr = benchmark_num_experts / benchmark_ep + if abs(orig_epr - bench_epr) < 0.5: + # experts_per_rank preserved — per-GPU routed compute unchanged. 
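+            # Illustrative check, using the same numbers as the example later
+            # in this file: target E=256 at EP=8 → 32 experts/rank; benchmark
+            # E=128 at EP=4 → 32 experts/rank. experts_per_rank matches, so
+            # the profiled MLP time needs no scaling.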
+ return 1.0 + + # Fallback: per-GPU MoE routed compute is EP-invariant (A2A + # redistributes tokens so each GPU processes batch_tokens × topk). + # The simple benchmark_ep / original_ep ratio is NOT correct because + # total FLOPs are constant; only GEMM shapes differ. We return 1.0 + # rather than an inaccurate heuristic. + return 1.0 def _estimate_ep_communication_overhead( @@ -807,11 +1301,11 @@ def _estimate_ep_communication_overhead( gpus_per_node = int(os.getenv("GPUS_PER_NODE", "8")) # Calculate nodes required for original EP - gpus_required_original = tp * pp * original_ep * cp + gpus_required_original = _calculate_min_gpus(tp, pp, original_ep, cp) num_nodes_original = (gpus_required_original + gpus_per_node - 1) // gpus_per_node # Calculate nodes for benchmark EP (should be 1) - gpus_required_benchmark = tp * pp * benchmark_ep * cp + gpus_required_benchmark = _calculate_min_gpus(tp, pp, benchmark_ep, cp) num_nodes_benchmark = (gpus_required_benchmark + gpus_per_node - 1) // gpus_per_node # Get collective model args for original EP configuration @@ -837,13 +1331,17 @@ def _estimate_ep_communication_overhead( ) # Calculate All-to-All message size for MoE layers + # With TP > 1 and sequence parallelism, each GPU holds S/TP tokens. + # The AlltoAll dispatcher sends these S/TP tokens across the EP group, + # then AllGather(TP) recovers full S tokens after A2A (see workflow in + # MoEAlltoAllTokenDispatcher: step 3 = A2A(EP), step 4 = AG(TP)). hidden_size = model_config.hidden_size batch_size = runtime_config.micro_batch_size seq_len = runtime_config.sequence_length moe_router_topk = getattr(model_config, "moe_router_topk", 2) - tokens_per_batch = seq_len * batch_size - dispatch_size = tokens_per_batch * hidden_size * moe_router_topk * 2 # BF16 + tokens_per_gpu = seq_len * batch_size // max(tp, 1) # S/TP with seq parallel + dispatch_size = tokens_per_gpu * hidden_size * moe_router_topk * 2 # BF16 # Calculate All-to-All time for original EP (dispatch + combine) a2a_dispatch_original = cm.alltoall(coll_args_original, dispatch_size, original_ep, groups=["ep"]) @@ -1177,6 +1675,8 @@ def _report_simulation_results(sim_results, training_config): def _run_layer_benchmark(primus_config, unknown_overrides): + from primus.modules.trainer.megatron.pre_trainer import MegatronPretrainTrainer + module_config = primus_config.get_module_config("pre_trainer") _limit_layers_for_projection(module_config) rescale_info = _rescale_expert_parallelism(module_config) @@ -1209,6 +1709,7 @@ def _run_layer_benchmark(primus_config, unknown_overrides): print("[Primus:Performance Projection] Initializing Megatron...") trainer.init() + print("[Primus:Performance Projection] Setting up model and optimizer...") trainer.setup() @@ -1248,6 +1749,73 @@ def _run_layer_benchmark(primus_config, unknown_overrides): return profiling_results +def _run_layer_simulation(primus_config, args): + """ + Run layer simulation using GEMM + SDPA simulation backends (no GPU required). + + This mirrors :func:`_run_layer_benchmark` but replaces actual GPU kernel + benchmarks with analytical / model-based simulation. It does *not* + instantiate a trainer or model – only the profiler tree is built from the + ``TrainingConfig``. + + Args: + primus_config: Primus configuration (will be mutated – layer counts + are reduced for consistency with the benchmark flow). + args: CLI arguments (``--gemm-backend``, ``--gpu-arch``). + + Returns: + dict: Profiling results in the same format as ``_run_layer_benchmark``. 
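+
+    Example (illustrative sketch — assumes ``primus_config`` and ``args`` are
+    already loaded; the result layout mirrors how callers consume it):
+
+        results = _run_layer_simulation(primus_config, args)
+        total_fwd_ms = sum(
+            d.get("forward_time_ms", 0.0)
+            for d in results.values()
+            if isinstance(d, dict)
+        )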
+ """ + module_config = primus_config.get_module_config("pre_trainer") + _limit_layers_for_projection(module_config) + _rescale_expert_parallelism(module_config) + training_config = convert_primus_config_to_projection_config(primus_config) + + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + + # ---- Create simulation backends ---- + gemm_backend_name = getattr(args, "gemm_backend", None) + gpu_arch = getattr(args, "gpu_arch", None) + gpu_clock_mhz = getattr(args, "gpu_clock_mhz", None) + + gemm_backend = get_gemm_simulation_backend( + backend_name=gemm_backend_name, + gpu_arch=gpu_arch, + gpu_clock_mhz=gpu_clock_mhz, + ) + sdpa_backend = get_sdpa_simulation_backend(gpu_arch=gpu_arch, gpu_clock_mhz=gpu_clock_mhz) + + # ---- Build profiler tree (no model needed) ---- + if is_rank_0: + print("[Primus:Performance Projection] Building simulation profiler...") + model_profiler_spec = get_language_model_profiler_spec(training_config) + model_profiler = build_profiler(model_profiler_spec) + + # Wire simulation backends into the entire profiler hierarchy + model_profiler.set_simulation_backends(gemm_backend, sdpa_backend) + + seq_len = training_config.runtime_config.sequence_length + batch_size = training_config.runtime_config.micro_batch_size + + if is_rank_0: + print("[Primus:Performance Projection] Simulating with:") + print(f" Batch Size: {batch_size}") + print(f" Sequence Length: {seq_len}") + print(f" GEMM backend: {gemm_backend.name()}") + print(f" SDPA backend: {sdpa_backend.name()}") + print("" + "=" * 100) + print("[Primus:Performance Projection] Starting layer simulation...") + print("=" * 100) + + # Run simulation (model=None – no GPU required) + profiling_results = model_profiler.run_layer_benchmark( + model=None, + batch_size=batch_size, + seq_len=seq_len, + ) + return profiling_results + + def _run_pipeline_simulation_megatron_zb(training_config, profiling_results): """ Run pipeline simulation using the actual Megatron zero-bubble scheduler. @@ -1450,9 +2018,10 @@ def _run_multinode_projection( cp = getattr(mp_config, "context_model_parallel_size", 1) gpus_per_node = int(os.getenv("GPUS_PER_NODE", "8")) - # Calculate minimum nodes required by parallelism config - # EP is included in the minimum GPUs calculation (need GPUs to hold experts) - gpus_required = tp * pp * ep * cp + # Calculate minimum nodes required by parallelism config. + # For MoE (EP > 1): CP is folded into EP via MoE Parallel Folding, + # so min GPUs = TP × PP × EP. For dense: min GPUs = TP × PP × CP. + gpus_required = _calculate_min_gpus(tp, pp, ep, cp) min_nodes_required = (gpus_required + gpus_per_node - 1) // gpus_per_node # Validate target >= minimum required @@ -1463,9 +2032,12 @@ def _run_multinode_projection( f"--target-nodes must be >= {min_nodes_required}." ) - # Calculate DP for scaling - EXCLUDES EP (DP scaling is independent of EP) - # EP distributes experts but doesn't affect how many data batches can be processed in parallel - gpus_for_dp = tp * pp * cp # EP excluded for DP calculation + # Calculate DP for scaling. EP is excluded from this divisor because, with + # MoE Parallel Folding, EP borrows from the DP dimension (not from extra + # GPUs). Data-loading DP = world_size / (TP × PP × CP) for both dense and + # MoE models. Within each EP group the CP ranks share context-parallel + # attention work while EP/CP ranks provide inner data-parallel streams. 
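+    # Illustrative example: TP=1, PP=1, CP=1, EP=8 on 4 nodes × 8 GPUs gives
+    # gpus_for_dp = 1 and DP = 32; the 8-way EP groups are then carved out of
+    # those 32 data-parallel ranks rather than requiring extra GPUs.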
+ gpus_for_dp = tp * pp * cp # EP excluded — it borrows from DP total_gpus_target = target_nodes * gpus_per_node dp_target = total_gpus_target // gpus_for_dp @@ -1610,21 +2182,25 @@ def _run_multinode_projection( ) target_microbatches_per_gpu = 1 + # Estimate optimizer step time (once per iteration, after all microbatches) + gpu_arch = getattr(args, "gpu_arch", None) + optimizer_step_ms = _estimate_optimizer_step_ms(training_config, dp_target, gpu_arch) + # Build full iteration time: - # compute (per-microbatch) × num_microbatches + gradient allreduce (once per iter) + # compute (per-microbatch) × num_microbatches + gradient allreduce + optimizer step if time_includes_all_microbatches: - full_iteration_time_ms = projected_time_ms + grad_ar_per_iteration_ms + full_iteration_time_ms = projected_time_ms + grad_ar_per_iteration_ms + optimizer_step_ms time_breakdown_str = f"{full_iteration_time_ms:.3f} ms (from pipeline simulation" if grad_ar_per_iteration_ms > 0: time_breakdown_str += f" + {grad_ar_per_iteration_ms:.1f} ms grad AR" - time_breakdown_str += ")" + time_breakdown_str += f" + {optimizer_step_ms:.1f} ms optimizer)" else: compute_total = projected_time_ms * target_microbatches_per_gpu - full_iteration_time_ms = compute_total + grad_ar_per_iteration_ms + full_iteration_time_ms = compute_total + grad_ar_per_iteration_ms + optimizer_step_ms time_breakdown_str = f"{full_iteration_time_ms:.3f} ms ({target_microbatches_per_gpu} microbatches × {projected_time_ms:.3f} ms" if grad_ar_per_iteration_ms > 0: time_breakdown_str += f" + {grad_ar_per_iteration_ms:.1f} ms grad AR" - time_breakdown_str += ")" + time_breakdown_str += f" + {optimizer_step_ms:.1f} ms optimizer)" # Calculate tokens/s/GPU (tokens processed per second per GPU) tokens_per_iter = global_batch * seq_len @@ -1693,52 +2269,643 @@ def _run_multinode_projection( } -def launch_projection_from_cli(args, overrides): - """ - Entry point for the 'performance_projection' subcommand. +# ============================================================================= +# Inference Projection Functions +# ============================================================================= - Benchmarks Megatron transformer layers and aggregates performance metrics. - If --target-nodes is specified, also runs multinode scaling projection. - If the parallelism configuration requires multiple nodes, automatically reduces - to single-node for benchmarking and estimates performance with PP overhead. +def calculate_inference_communication_time( + training_config, + num_nodes: int, + gpus_per_node: int, + tp: int, + pp: int, + ep: int, + cp: int, + hardware_config: Dict[str, Any] = None, +) -> Tuple[float, Dict[str, float], Dict[str, Any]]: + """ + Calculate collective communication time for **inference** (forward-only). 
- Args: - args: Command-line arguments - overrides: Configuration overrides + Compared to the training version, this skips: + - Gradient AllReduce (no backward pass) + - Backward MoE All-to-All + - FSDP reduce-scatter (no gradient sharding) + + Returns: + (total_comm_time_ms, breakdown_dict, message_info_dict) """ - cfg_path = Path(args.config) - if not cfg_path.exists(): - raise FileNotFoundError(f"[Primus:Performance Projection] Config file '{cfg_path}' not found.") + model_config = training_config.model_config + runtime_config = training_config.runtime_config - # Load Primus configuration - primus_config, unknown_overrides = load_primus_config(args, overrides) - primus_config_original = copy.deepcopy(primus_config) + coll_args = get_default_args( + num_nodes=num_nodes, + gpus_per_node=gpus_per_node, + tp=tp, + pp=pp, + ep=ep, + cp=cp, + hardware_config=hardware_config, + ) - # Check if we need to reduce config for single-node benchmarking - gpus_per_node = int(os.getenv("GPUS_PER_NODE", "8")) + hidden_size = model_config.hidden_size + num_layers = model_config.num_layers + moe_router_topk = model_config.moe_router_topk + moe_pattern = model_config.moe_pattern + batch_size = runtime_config.micro_batch_size + seq_len = runtime_config.sequence_length + num_moe_layers = sum(1 for p in moe_pattern if p == 1) - # Get target nodes from CLI flag (--target-nodes) - target_nodes = getattr(args, "target_nodes", None) + breakdown = {} + message_info = {} - # Store original parallelism before any modifications - module_config = primus_config.get_module_config("pre_trainer") - reduction_info = _calculate_single_node_config(copy.deepcopy(module_config), gpus_per_node) + # No gradient AllReduce for inference + breakdown["gradient_allreduce"] = 0.0 + message_info["gradient_allreduce_size"] = 0 + message_info["gradient_allreduce_size_mb"] = 0.0 - # Calculate minimum nodes required - min_nodes_required = reduction_info["original_nodes_required"] + # MoE All-to-All — forward only (dispatch + combine, no backward) + if ep > 1 and num_moe_layers > 0: + tokens_per_gpu = seq_len * batch_size // max(tp, 1) + dispatch_size = tokens_per_gpu * hidden_size * moe_router_topk * 2 # BF16 - # If target_nodes not specified, default to minimum required - if target_nodes is None: - target_nodes = min_nodes_required + a2a_dispatch = cm.alltoall(coll_args, dispatch_size, ep, groups=["ep"]) + a2a_combine = cm.alltoall(coll_args, dispatch_size, ep, groups=["ep"]) - if reduction_info["adjusted"]: - print("" + "=" * 100) - print("[Primus:Performance Projection] Multi-node configuration detected") - print("=" * 100) - print(f" Original configuration requires {min_nodes_required} nodes minimum:") - print( - f" TP={reduction_info['original_tp']}, PP={reduction_info['original_pp']}, " + total_a2a_fwd = (a2a_dispatch + a2a_combine) * num_moe_layers / 1000 # ms + + breakdown["moe_a2a_fwd"] = total_a2a_fwd + message_info["moe_a2a_size"] = dispatch_size + message_info["moe_a2a_size_mb"] = dispatch_size / (1024 * 1024) + message_info["moe_a2a_per_layer_fwd"] = (a2a_dispatch + a2a_combine) / 1000 + message_info["num_moe_layers"] = num_moe_layers + else: + breakdown["moe_a2a_fwd"] = 0.0 + message_info["moe_a2a_size"] = 0 + message_info["moe_a2a_size_mb"] = 0.0 + message_info["moe_a2a_per_layer_fwd"] = 0.0 + message_info["num_moe_layers"] = 0 + + # No backward A2A + # No FSDP communication (no gradient sharding in inference) + breakdown["fsdp_allgather_fwd"] = 0.0 + breakdown["fsdp_reducescatter_bwd"] = 0.0 + message_info["fsdp_enabled"] = 
False + + message_info["num_layers"] = num_layers + + total_comm_time = sum(breakdown.values()) + return total_comm_time, breakdown, message_info + + +def _run_inference_projection( + training_config, + forward_time_ms: float, + profiling_results, + args, + target_nodes: int, +): + """ + Run inference-mode multinode projection. + + Unlike the training projection, this: + - Uses only forward pass time (no backward, no wgrad) + - Skips optimizer step estimation + - Skips gradient AllReduce + - Reports prefill latency and throughput + + Args: + training_config: Configuration object + forward_time_ms: Measured forward-only time in ms for one pass + profiling_results: Layer profiling results + args: CLI arguments + target_nodes: Target number of nodes for projection + """ + import torch.distributed as dist + + is_rank_0 = not dist.is_initialized() or dist.get_rank() == 0 + + mp_config = training_config.model_parallel_config + + tp = mp_config.tensor_model_parallel_size + pp = mp_config.pipeline_model_parallel_size + ep = getattr(mp_config, "expert_model_parallel_size", 1) + cp = getattr(mp_config, "context_model_parallel_size", 1) + gpus_per_node = int(os.getenv("GPUS_PER_NODE", "8")) + + gpus_required = tp * pp * ep * cp + min_nodes_required = (gpus_required + gpus_per_node - 1) // gpus_per_node + + if target_nodes < min_nodes_required: + raise ValueError( + f"[Primus:Inference Projection] ERROR: Cannot project to {target_nodes} nodes. " + f"Minimum required by parallelism config is {min_nodes_required} nodes." + ) + + total_gpus_target = target_nodes * gpus_per_node + # For inference, DP means we can serve independent requests in parallel + gpus_for_dp = tp * pp * cp + dp_target = total_gpus_target // gpus_for_dp + + if is_rank_0: + print("" + "=" * 100) + print("Inference Parallelism Configuration") + print("=" * 100) + print(f" TP: {tp}, PP: {pp}, EP: {ep}, CP: {cp}") + print(f" GPUs per Node: {gpus_per_node}") + print(f" Minimum GPUs Required: {gpus_required}") + print(f" Minimum Nodes Required: {min_nodes_required}") + print(f" Target Nodes: {target_nodes}") + + # Load hardware config if provided + hardware_config_dict = None + if hasattr(args, "hardware_config") and args.hardware_config: + hardware_config_dict = load_hardware_config(args.hardware_config) + if is_rank_0: + print(f" Using custom hardware config from: {args.hardware_config}") + + # Calculate inference communication times (forward-only) + total_comm_time_ms, breakdown, message_info = calculate_inference_communication_time( + training_config, + target_nodes, + gpus_per_node, + tp, + pp, + ep, + cp, + hardware_config_dict, + ) + + # Inference projected time: forward compute + forward communication + projected_time_ms = forward_time_ms + + # For PP > 1, add estimated PP P2P overhead (forward only — one direction) + if pp > 1: + if hardware_config_dict: + pp_bw = hardware_config_dict.get("intra_node_bandwidth_gbps", 896.0) + else: + pp_bw = 896.0 # Default xGMI bandwidth + hidden = training_config.model_config.hidden_size + seq_len = training_config.runtime_config.sequence_length + micro_batch = training_config.runtime_config.micro_batch_size + activation_size = seq_len * micro_batch * hidden * 2 # BF16 + pp_latency_us = 1.0 + pp_time_per_stage = (activation_size / (pp_bw * 1e9) * 1e6 + pp_latency_us) / 1000 # ms + pp_overhead_ms = pp_time_per_stage * (pp - 1) + projected_time_ms += pp_overhead_ms + if is_rank_0: + print(f" PP forward overhead ({pp - 1} hops): {pp_overhead_ms:.3f} ms") + + # Get runtime config for throughput 
calculation + runtime_config = training_config.runtime_config + seq_len = getattr(runtime_config, "sequence_length", 4096) + micro_batch = getattr(runtime_config, "micro_batch_size", 1) + + # Prefill latency = forward time for one batch on one model replica + prefill_latency_ms = projected_time_ms + + # Tokens per pass (one forward pass processes all tokens in the sequence) + tokens_per_pass = seq_len * micro_batch + + # Throughput per replica + tokens_per_sec_per_replica = tokens_per_pass * 1000 / prefill_latency_ms if prefill_latency_ms > 0 else 0 + + # Number of independent replicas with DP + num_replicas = dp_target + + # Aggregate throughput across all replicas + total_tokens_per_sec = tokens_per_sec_per_replica * num_replicas + tokens_per_sec_per_gpu = total_tokens_per_sec / total_gpus_target if total_gpus_target > 0 else 0 + + if is_rank_0: + print("" + "=" * 100) + print("Inference Projection Results") + print("=" * 100) + print(f"📊 Parallelism: TP={tp}, PP={pp}, EP={ep}, CP={cp}") + + # Communication Breakdown + if total_comm_time_ms > 0: + print("📡 Communication Breakdown (forward-only):") + for op_name, op_time in breakdown.items(): + if op_time > 0: + print(f" {op_name}: {op_time:.3f} ms", end="") + if op_name == "moe_a2a_fwd" and "moe_a2a_size_mb" in message_info: + print( + f" (message: {message_info['moe_a2a_size_mb']:.2f} MB, " + f"{message_info['num_moe_layers']} layers × " + f"{message_info['moe_a2a_per_layer_fwd']:.3f} ms/layer)" + ) + else: + print("") + print(f" Total Communication: {total_comm_time_ms:.3f} ms") + + print(f"🎯 Target Configuration ({target_nodes} nodes):") + print(f" Nodes: {target_nodes}, GPUs: {total_gpus_target}") + print(f" TP={tp}, PP={pp}, EP={ep}, CP={cp}, DP(replicas)={num_replicas}") + print( + f" Prefill Latency: {prefill_latency_ms:.3f} ms " + f"(seq_len={seq_len}, micro_batch={micro_batch})" + ) + print(f" Tokens/s per replica: {tokens_per_sec_per_replica:,.0f}") + print(f" Tokens/s total ({num_replicas} replicas): {total_tokens_per_sec:,.0f}") + print(f" Tokens/s/GPU: {tokens_per_sec_per_gpu:,.0f}") + print("=" * 100) + + return { + "target_nodes": target_nodes, + "target_gpus": total_gpus_target, + "tp": tp, + "pp": pp, + "ep": ep, + "cp": cp, + "dp_replicas": num_replicas, + "prefill_latency_ms": prefill_latency_ms, + "tokens_per_sec_per_replica": tokens_per_sec_per_replica, + "tokens_per_sec_total": total_tokens_per_sec, + "tokens_per_sec_per_gpu": tokens_per_sec_per_gpu, + } + + +def _run_decode_layer_benchmark(primus_config, unknown_overrides, decode_batch_size): + """ + Benchmark transformer layers with seq_len=1 to measure decode-step GEMM + times on the GPU. + + The Megatron trainer is initialized with the **original** config (so that + all Megatron assertions pass normally). Only the ``run_layer_benchmark`` + call uses ``seq_len=1`` and ``batch_size=decode_batch_size`` — these + override the input tensor shapes without touching validated Megatron args. + + The resulting forward_time_ms per layer captures real kernel timings for + the tiny GEMMs characteristic of autoregressive decode. Attention timing + will be unrealistically small (1-token self-attention without KV cache), + so callers should overlay an analytical KV cache model. + + Args: + primus_config: Primus configuration (will be mutated for layer + limiting / EP rescaling only — seq_length is NOT changed). + unknown_overrides: Extra CLI overrides for the trainer. + decode_batch_size: Number of concurrent sequences. 
+ + Returns: + dict: Profiling results in the same format as _run_layer_benchmark. + """ + from primus.modules.trainer.megatron.pre_trainer import MegatronPretrainTrainer + + module_config = primus_config.get_module_config("pre_trainer") + + # NOTE: Do NOT change seq_length or micro_batch_size in the module config — + # Megatron's argument validation has many assertions on seq_length (e.g. + # divisibility by CP, TP, position-embedding limits). We keep the original + # config for trainer init and only pass seq_len=1 to run_layer_benchmark(). + _limit_layers_for_projection(module_config) + rescale_info = _rescale_expert_parallelism(module_config) + training_config = convert_primus_config_to_projection_config(primus_config) + + master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") + master_port = int(os.getenv("MASTER_PORT", "29500")) + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + + is_rank_0 = rank == 0 + if is_rank_0: + print("[Primus:Decode Benchmark] Initializing MegatronPretrainTrainer...") + print( + "[Primus:Decode Benchmark] (trainer uses original seq_length for init; " + "benchmark will use seq_len=1)" + ) + + # Disable overlap/FSDP features for profiling + primus_config.get_module_config("pre_trainer").overlap_grad_reduce = False + primus_config.get_module_config("pre_trainer").overlap_param_gather = False + primus_config.get_module_config("pre_trainer").use_torch_fsdp2 = False + + trainer = MegatronPretrainTrainer( + module_name="pre_trainer", + primus_config=primus_config, + module_rank=rank, + module_world_size=world_size, + module_master_addr=master_addr, + module_master_port=master_port, + extra_args=unknown_overrides, + ) + + if is_rank_0: + print("[Primus:Decode Benchmark] Initializing Megatron...") + trainer.init() + if is_rank_0: + print("[Primus:Decode Benchmark] Setting up model and optimizer...") + trainer.setup() + + if is_rank_0: + print("[Primus:Decode Benchmark] Building model profiler...") + model_profiler_spec = get_language_model_profiler_spec(training_config) + model_profiler = build_profiler(model_profiler_spec) + + # Override seq_len and batch_size ONLY for the benchmark call + seq_len = 1 + batch_size = decode_batch_size + + if is_rank_0: + print("[Primus:Decode Benchmark] Benchmarking with:") + print(f" Rank: {rank}") + print(f" World Size: {world_size}") + print(f" Batch Size: {batch_size} (decode concurrent sequences)") + print(f" Sequence Length: {seq_len} (1 token per decode step)") + if rescale_info: + note = ( + f" NOTE: MoE rescaled -> EP {rescale_info['ep_before']} -> {rescale_info['ep_after']}" + f" (TP={rescale_info['tp']}, CP={rescale_info['cp']})" + ) + print(note) + + if is_rank_0: + print("" + "=" * 100) + print("[Primus:Decode Benchmark] Starting layer benchmarking (seq_len=1)...") + print("=" * 100) + + profiling_results = model_profiler.run_layer_benchmark( + model=trainer.model, + batch_size=batch_size, + seq_len=seq_len, + ) + return profiling_results + + +def _run_decode_projection( + training_config, + args, + target_nodes: int, + profiling_results: Optional[Dict] = None, +): + """ + Run decode-mode projection. + + Supports two modes: + - **Analytical** (profiling_results=None): Fully analytical model based + on HBM bandwidth and model architecture. No GPU needed. + - **Benchmark-enhanced** (profiling_results provided): Uses real GPU + benchmark timings for GEMMs (seq_len=1) and overlays an analytical + KV cache attention model for the portion that benchmarks can't capture. 
+ + Args: + training_config: Configuration object. + args: CLI arguments (decode_batch_size, decode_context_length, etc.). + target_nodes: Target number of nodes for projection. + profiling_results: Optional benchmark results from _run_decode_layer_benchmark. + """ + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + use_benchmark = profiling_results is not None + + mp_config = training_config.model_parallel_config + tp = mp_config.tensor_model_parallel_size + pp = mp_config.pipeline_model_parallel_size + ep = getattr(mp_config, "expert_model_parallel_size", 1) + cp = getattr(mp_config, "context_model_parallel_size", 1) + gpus_per_node = int(os.getenv("GPUS_PER_NODE", "8")) + total_gpus = target_nodes * gpus_per_node + + # DP replicas for decode + gpus_for_dp = tp * pp * cp + dp_replicas = total_gpus // gpus_for_dp + + # Decode-specific parameters from CLI or config defaults + decode_batch = getattr(args, "decode_batch_size", None) + if decode_batch is None: + decode_batch = training_config.runtime_config.micro_batch_size + context_len = getattr(args, "decode_context_length", None) + if context_len is None: + context_len = training_config.runtime_config.sequence_length + num_gen_tokens = getattr(args, "num_generated_tokens", None) or 128 + + gpu_arch = getattr(args, "gpu_arch", None) + hardware_config_dict = None + if hasattr(args, "hardware_config") and args.hardware_config: + hardware_config_dict = load_hardware_config(args.hardware_config) + + # Always run analytical model (used for KV cache estimate and as baseline) + analytical = _estimate_decode_time_per_token( + training_config, + decode_batch_size=decode_batch, + context_length=context_len, + tp=tp, + pp=pp, + ep=ep, + gpu_arch=gpu_arch, + hardware_config=hardware_config_dict, + ) + + # ── Compute decode time ────────────────────────────────────────────── + if use_benchmark: + # Benchmark-enhanced: real GPU GEMM timings + analytical KV cache + # + # The benchmark ran layers with seq_len=1. The forward_time_ms per + # layer captures the real GEMM cost (QKV proj, output proj, MLP) + # plus TP AllReduce — but attention was only 1-token self-attention + # (no KV cache). We add the analytical KV cache read time on top. 
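+        # Composition implemented below (names as used in this function):
+        #   decode_time_ms ≈ benchmark_fwd_ms  (seq_len=1 GEMMs + TP AllReduce, measured)
+        #                  + kv_cache_ms       (KV bytes / HBM bandwidth, analytical)
+        #                  + pp_overhead_ms + moe_a2a_ms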
+ mc = training_config.model_config + num_layers = mc.num_layers + layers_per_pp = num_layers // max(pp, 1) + + # Sum forward times from benchmark (only transformer layers) + benchmark_fwd_ms = 0.0 + layer_count = 0 + for layer_idx, layer_data in profiling_results.items(): + if isinstance(layer_data, dict) and "forward_time_ms" in layer_data: + benchmark_fwd_ms += layer_data["forward_time_ms"] + layer_count += 1 + + # Scale to full model depth (benchmark may run fewer representative layers) + if layer_count > 0 and layer_count < layers_per_pp: + benchmark_fwd_ms = benchmark_fwd_ms * layers_per_pp / layer_count + + # KV cache read time (analytical — benchmark doesn't have a KV cache) + hbm_bw_bytes_per_ms = analytical["hbm_bw_gbps"] * 1e9 / 1e3 + kv_cache_ms = analytical["total_kv_bytes"] / hbm_bw_bytes_per_ms + + # PP overhead (serial stages) + pp_overhead_ms = analytical["pp_overhead_ms"] + + # MoE EP All-to-All (latency-dominated for 1-token messages) + moe_a2a_ms = analytical["moe_a2a_total_ms"] + + decode_time_ms = benchmark_fwd_ms + kv_cache_ms + pp_overhead_ms + moe_a2a_ms + method_label = "Benchmark-Enhanced" + else: + # Pure analytical + decode_time_ms = analytical["decode_time_ms"] + benchmark_fwd_ms = None + kv_cache_ms = analytical["total_kv_bytes"] / (analytical["hbm_bw_gbps"] * 1e9 / 1e3) + pp_overhead_ms = analytical["pp_overhead_ms"] + moe_a2a_ms = analytical["moe_a2a_total_ms"] + method_label = "Analytical Model" + + tokens_per_sec_per_replica = decode_batch * 1000 / decode_time_ms if decode_time_ms > 0 else 0 + total_tokens_per_sec = tokens_per_sec_per_replica * dp_replicas + tokens_per_sec_per_gpu = total_tokens_per_sec / total_gpus if total_gpus > 0 else 0 + + total_generation_time_ms = decode_time_ms * num_gen_tokens + + # KV cache memory per GPU + kv_mb_per_gpu = analytical["total_kv_mb"] + weight_mb_per_gpu = analytical["total_weight_mb"] + + if is_rank_0: + print("\n" + "=" * 100) + print(f"Decode Projection Results ({method_label})") + print("=" * 100) + + print(f"📊 Parallelism: TP={tp}, PP={pp}, EP={ep}, CP={cp}") + print(f" Decode batch size: {decode_batch}") + print(f" Context length (KV cache): {context_len}") + print(f" GPU arch: {(gpu_arch or os.getenv('PRIMUS_GPU_ARCH', 'mi300x')).lower()}") + print(f" HBM bandwidth: {analytical['hbm_bw_gbps']:.0f} GB/s") + print(f" Peak BF16 compute: {analytical['peak_tflops']:.0f} TFLOPS") + + print() + print("⏱️ Per-Token Decode Time Breakdown:") + + if use_benchmark: + # Show benchmarked GEMM time + analytical KV cache overlay + print( + f" Layer fwd (benchmarked, seq_len=1): {benchmark_fwd_ms:.4f} ms " + f"(includes GEMMs + TP AllReduce)" + ) + print( + f" KV cache read (analytical): {kv_cache_ms:.4f} ms " + f"({analytical['total_kv_mb']:.1f} MB)" + ) + # Also show analytical-only for comparison + analytical_weight_ms = analytical["total_weight_bytes"] / (analytical["hbm_bw_gbps"] * 1e9 / 1e3) + analytical_total = analytical["decode_time_ms"] + print(f" ─── Analytical comparison ───") + print( + f" Weight loading (analytical): {analytical_weight_ms:.4f} ms " + f"({analytical['total_weight_mb']:.1f} MB)" + ) + print(f" Total (analytical): {analytical_total:.4f} ms") + print(f" ────────────────────────────") + else: + weight_only_ms = analytical["total_weight_bytes"] / (analytical["hbm_bw_gbps"] * 1e9 / 1e3) + print(f" Weight loading: {weight_only_ms:.4f} ms ({analytical['total_weight_mb']:.1f} MB)") + print(f" KV cache read: {kv_cache_ms:.4f} ms ({analytical['total_kv_mb']:.1f} MB)") + print( + f" Compute: 
{analytical['compute_time_ms']:.4f} ms " + f"({analytical['total_compute_tflops'] * 1000:.2f} GFLOPS)" + ) + if analytical["tp_overhead_ms"] > 0: + print(f" TP AllReduce: {analytical['tp_overhead_ms']:.4f} ms") + + if pp_overhead_ms > 0: + print(f" PP P2P: {pp_overhead_ms:.4f} ms") + if moe_a2a_ms > 0: + print(f" MoE All-to-All: {moe_a2a_ms:.4f} ms") + + if not use_benchmark: + bound = "MEMORY-BOUND" if analytical["is_memory_bound"] else "COMPUTE-BOUND" + print( + f" Bottleneck: {bound} " + f"(arith intensity={analytical['arithmetic_intensity']:.2f} FLOPs/B, " + f"balance={analytical['balance_point']:.0f} FLOPs/B)" + ) + + print(f" ─────────────────────────────────────") + print(f" Total per token: {decode_time_ms:.4f} ms") + + print() + print(f"🎯 Target Configuration ({target_nodes} nodes):") + print(f" Nodes: {target_nodes}, GPUs: {total_gpus}") + print(f" TP={tp}, PP={pp}, EP={ep}, CP={cp}, DP(replicas)={dp_replicas}") + print(f" Per-token latency: {decode_time_ms:.4f} ms") + print(f" Tokens/s per replica: {tokens_per_sec_per_replica:,.0f}") + print(f" Tokens/s total ({dp_replicas} replicas): {total_tokens_per_sec:,.0f}") + print(f" Tokens/s/GPU: {tokens_per_sec_per_gpu:,.0f}") + print() + print( + f" Generation of {num_gen_tokens} tokens: {total_generation_time_ms:.1f} ms " + f"({total_generation_time_ms / 1000:.2f} s)" + ) + + print() + print("💾 Memory Estimate (per GPU):") + print(f" Model weights: {weight_mb_per_gpu:.1f} MB ({weight_mb_per_gpu / 1024:.2f} GB)") + print( + f" KV cache: {kv_mb_per_gpu:.1f} MB ({kv_mb_per_gpu / 1024:.2f} GB) " + f"(batch={decode_batch}, ctx={context_len})" + ) + total_mem_gb = (weight_mb_per_gpu + kv_mb_per_gpu) / 1024 + print(f" Total: {total_mem_gb:.2f} GB") + + print("=" * 100) + + return { + "target_nodes": target_nodes, + "target_gpus": total_gpus, + "tp": tp, + "pp": pp, + "ep": ep, + "cp": cp, + "dp_replicas": dp_replicas, + "decode_batch_size": decode_batch, + "context_length": context_len, + "decode_time_per_token_ms": decode_time_ms, + "tokens_per_sec_per_replica": tokens_per_sec_per_replica, + "tokens_per_sec_total": total_tokens_per_sec, + "tokens_per_sec_per_gpu": tokens_per_sec_per_gpu, + "total_generation_time_ms": total_generation_time_ms, + "weight_mb_per_gpu": weight_mb_per_gpu, + "kv_cache_mb_per_gpu": kv_mb_per_gpu, + "is_memory_bound": analytical["is_memory_bound"], + "method": method_label, + } + + +def launch_projection_from_cli(args, overrides): + """ + Entry point for the 'performance_projection' subcommand. + + Benchmarks Megatron transformer layers and aggregates performance metrics. + + If --target-nodes is specified, also runs multinode scaling projection. + If the parallelism configuration requires multiple nodes, automatically reduces + to single-node for benchmarking and estimates performance with PP overhead. 
+ + Args: + args: Command-line arguments + overrides: Configuration overrides + """ + cfg_path = Path(args.config) + if not cfg_path.exists(): + raise FileNotFoundError(f"[Primus:Performance Projection] Config file '{cfg_path}' not found.") + + # Load Primus configuration + primus_config, unknown_overrides = load_primus_config(args, overrides) + primus_config_original = copy.deepcopy(primus_config) + + # Check if we need to reduce config for single-node benchmarking + gpus_per_node = int(os.getenv("GPUS_PER_NODE", "8")) + + # Get target nodes from CLI flag (--target-nodes) + target_nodes = getattr(args, "target_nodes", None) + + # Store original parallelism before any modifications + module_config = primus_config.get_module_config("pre_trainer") + reduction_info = _calculate_single_node_config(copy.deepcopy(module_config), gpus_per_node) + + # Calculate minimum nodes required + min_nodes_required = reduction_info["original_nodes_required"] + + # If target_nodes not specified, default to minimum required + if target_nodes is None: + target_nodes = min_nodes_required + + if reduction_info["adjusted"]: + print("" + "=" * 100) + print("[Primus:Performance Projection] Multi-node configuration detected") + print("=" * 100) + print(f" Original configuration requires {min_nodes_required} nodes minimum:") + print( + f" TP={reduction_info['original_tp']}, PP={reduction_info['original_pp']}, " f"EP={reduction_info['original_ep']}, CP={reduction_info['original_cp']}" ) print(" Reducing to single-node configuration for benchmarking:") @@ -1767,16 +2934,124 @@ def launch_projection_from_cli(args, overrides): primus_config.get_module_config("pre_trainer").expert_model_parallel_size = reduction_info[ "benchmark_ep" ] + # Also propagate num_experts adjustment so that the profiler sees + # the correct experts_per_rank (e.g. 128/4=32, not 256/4=64). 
+ if reduction_info.get("benchmark_num_experts") is not None: + primus_config.get_module_config("pre_trainer").num_experts = reduction_info[ + "benchmark_num_experts" + ] + + # ========================================================================= + # DECODE MODE — analytical or benchmark-enhanced + # ========================================================================= + _projection_mode = getattr(args, "mode", MODE_TRAINING) + if _projection_mode == MODE_DECODE: + training_config = convert_primus_config_to_projection_config(primus_config_original) + + profiling_mode = getattr(args, "profiling_mode", "benchmark") + decode_profiling_results = None + + if profiling_mode in ("benchmark", "both"): + # Benchmark layers with seq_len=1 to get real GPU GEMM timings + decode_batch = getattr(args, "decode_batch_size", None) + if decode_batch is None: + decode_batch = training_config.runtime_config.micro_batch_size + + decode_profiling_results = _run_decode_layer_benchmark( + copy.deepcopy(primus_config), + unknown_overrides, + decode_batch_size=decode_batch, + ) - profiling_results = _run_layer_benchmark(primus_config, unknown_overrides) + if profiling_mode == "both": + # Run both: show analytical and benchmark-enhanced side by side + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + + # Analytical + if is_rank_0: + print("\n" + "=" * 100) + print("[Primus:Decode] Running ANALYTICAL projection...") + print("=" * 100) + _run_decode_projection(training_config, args, target_nodes, profiling_results=None) + + # Benchmark-enhanced + if is_rank_0: + print("\n" + "=" * 100) + print("[Primus:Decode] Running BENCHMARK-ENHANCED projection...") + print("=" * 100) + _run_decode_projection( + training_config, + args, + target_nodes, + profiling_results=decode_profiling_results, + ) + else: + # Single mode: benchmark or simulate + _run_decode_projection( + training_config, + args, + target_nodes, + profiling_results=decode_profiling_results, + ) + return + + # Determine profiling mode + profiling_mode = getattr(args, "profiling_mode", "benchmark") + + if profiling_mode == "simulate": + # Pure simulation – no GPU / trainer required + profiling_results = _run_layer_simulation(primus_config, args) + elif profiling_mode == "both": + # Run both benchmark and simulation, keep benchmark results for + # downstream pipeline simulation / multinode projection, but print + # a side-by-side comparison. 
+ sim_results = _run_layer_simulation(copy.deepcopy(primus_config), args) + bench_results = _run_layer_benchmark(primus_config, unknown_overrides) + + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + if is_rank_0: + print("\n" + "=" * 100) + print("[Primus:Performance Projection] Benchmark vs Simulation Comparison") + print("=" * 100) + for key in bench_results: + if key in ("embedding", "output"): + continue + bd = bench_results[key] + sd = sim_results.get(key, {}) + if not isinstance(bd, dict): + continue + lt = bd.get("type", key) + b_fwd = bd.get("forward_time_ms", 0) + b_bwd = bd.get("backward_time_ms", 0) + s_fwd = sd.get("forward_time_ms", 0) + s_bwd = sd.get("backward_time_ms", 0) + fwd_err = ((s_fwd - b_fwd) / b_fwd * 100) if b_fwd else 0 + bwd_err = ((s_bwd - b_bwd) / b_bwd * 100) if b_bwd else 0 + print(f" Layer type: {lt}") + print(f" Forward: bench={b_fwd:.2f} ms sim={s_fwd:.2f} ms (err={fwd_err:+.1f}%)") + print(f" Backward: bench={b_bwd:.2f} ms sim={s_bwd:.2f} ms (err={bwd_err:+.1f}%)") + print("=" * 100) + + # Use benchmark results for the rest of the pipeline + profiling_results = bench_results + else: + # Default: actual GPU benchmark + profiling_results = _run_layer_benchmark(primus_config, unknown_overrides) # Use original config for projection calculations training_config = convert_primus_config_to_projection_config(primus_config_original) + # Determine projection mode (training / prefill / decode) + projection_mode = getattr(args, "mode", MODE_TRAINING) + # Normalise: "inference" is an alias for "prefill" + if projection_mode == MODE_INFERENCE: + projection_mode = MODE_PREFILL + # Update data_parallel_size based on target_nodes - # This ensures the pipeline simulation calculates the correct number of microbatches - # NOTE: For MoE models, EP does NOT reduce DP (experts are distributed but tokens are replicated) - # DP = world_size / (TP × PP × CP) [EP is excluded] + # This ensures the pipeline simulation calculates the correct number of microbatches. + # DP = world_size / (TP × PP × CP) for both dense and MoE. With MoE Parallel + # Folding, EP borrows from the DP dimension (CP is folded into EP), so EP + # does not appear in the DP divisor. 
mp_config = training_config.model_parallel_config tp = mp_config.tensor_model_parallel_size pp = mp_config.pipeline_model_parallel_size @@ -1787,7 +3062,7 @@ def launch_projection_from_cli(args, overrides): # The pipeline simulator simulates the target config, so it needs target DP for microbatch calculation target_world_size = target_nodes * gpus_per_node - # For MoE models: DP calculation excludes EP since experts are distributed but data is replicated + # DP = world_size / (TP × PP × CP) — EP excluded (borrows from DP via folding) target_dp = target_world_size // (tp * pp * cp) # Also show benchmark config for reference @@ -1799,12 +3074,85 @@ def launch_projection_from_cli(args, overrides): # Only print from rank 0 is_rank_0 = int(os.getenv("RANK", "0")) == 0 + mode_label = "Prefill" if projection_mode == MODE_PREFILL else "Training" if is_rank_0: - print("[Primus:Performance Projection] Configuration Summary:") + print(f"[Primus:{mode_label} Projection] Configuration Summary:") print( f" Benchmark Config: PP={benchmark_pp}, EP={benchmark_ep}, TP={tp}, CP={cp}, DP={benchmark_dp} (1 node)" ) print(f" Target Config: PP={pp}, EP={ep}, TP={tp}, CP={cp}, DP={target_dp} ({target_nodes} nodes)") + print(f" Mode: {projection_mode}") + + # ========================================================================= + # PREFILL MODE — forward-only projection, no backward/optimizer/gradient + # ========================================================================= + if projection_mode == MODE_PREFILL: + # For inference, EP overhead adjustment is forward-only + if reduction_info["adjusted"] and reduction_info["original_ep"] != reduction_info["benchmark_ep"]: + original_ep = reduction_info["original_ep"] + benchmark_ep = reduction_info["benchmark_ep"] + original_num_experts = reduction_info.get("original_num_experts") + benchmark_num_experts = reduction_info.get("benchmark_num_experts") + + hardware_config_dict = None + if hasattr(args, "hardware_config") and args.hardware_config: + hardware_config_dict = load_hardware_config(args.hardware_config) + + fwd_overhead_per_layer, _ = _estimate_ep_communication_overhead( + training_config, + original_ep, + benchmark_ep, + hardware_config_dict, + ) + ep_mlp_scale = _compute_ep_mlp_scale( + training_config.model_config, + benchmark_ep, + original_ep, + original_num_experts=original_num_experts, + benchmark_num_experts=benchmark_num_experts, + ) + + if is_rank_0: + print("[Primus:Inference Projection] Adjusting profiling for EP (forward-only):") + print(f" EP rescaled: {benchmark_ep} → {original_ep}") + print(f" MLP fwd scale factor: {ep_mlp_scale:.3f}") + if fwd_overhead_per_layer > 0: + print(f" A2A fwd delta: +{fwd_overhead_per_layer:.3f} ms/layer") + + for layer_idx, layer_data in profiling_results.items(): + if isinstance(layer_data, dict) and layer_data.get("type") == "moe": + old_fwd = layer_data.get("forward_time_ms", 0) + mlp_info = layer_data.get("mlp", {}) + mlp_fwd = mlp_info.get("forward_time_ms", 0) + new_mlp_fwd = mlp_fwd * ep_mlp_scale + mlp_delta_fwd = new_mlp_fwd - mlp_fwd + new_fwd = old_fwd + mlp_delta_fwd + fwd_overhead_per_layer + layer_data["forward_time_ms"] = new_fwd + if mlp_info: + mlp_info["forward_time_ms"] = new_mlp_fwd + + # Extract forward-only time + forward_time_ms = extract_single_node_time_inference(profiling_results, training_config) + + # Run inference projection + if target_nodes >= min_nodes_required: + if is_rank_0: + print("" + "=" * 100) + print("[Primus:Inference] Running inference projection") + print("=" * 100) 
+ + _run_inference_projection( + training_config, + forward_time_ms, + profiling_results, + args, + target_nodes, + ) + return + + # ========================================================================= + # TRAINING MODE — full forward + backward + optimizer + gradient AllReduce + # ========================================================================= # Use BENCHMARK DP for pipeline simulation to get consistent baseline # The multinode projection will then scale from this baseline to target @@ -1836,13 +3184,15 @@ def launch_projection_from_cli(args, overrides): if reduction_info["adjusted"] and reduction_info["original_ep"] != reduction_info["benchmark_ep"]: original_ep = reduction_info["original_ep"] benchmark_ep = reduction_info["benchmark_ep"] + original_num_experts = reduction_info.get("original_num_experts") + benchmark_num_experts = reduction_info.get("benchmark_num_experts") # Load hardware config if provided hardware_config_dict = None if hasattr(args, "hardware_config") and args.hardware_config: hardware_config_dict = load_hardware_config(args.hardware_config) - # Calculate EP communication overhead per layer + # Calculate EP communication overhead per layer (A2A delta) fwd_overhead_per_layer, bwd_overhead_per_layer = _estimate_ep_communication_overhead( training_config, original_ep, @@ -1850,68 +3200,70 @@ def launch_projection_from_cli(args, overrides): hardware_config_dict, ) - # EP compute scaling: when EP increases, each GPU handles fewer routed - # expert tokens, but shared expert compute stays constant. - # Use _compute_ep_mlp_scale to get the correct fraction-aware scale. - ep_mlp_scale = _compute_ep_mlp_scale(training_config.model_config, benchmark_ep, original_ep) + # EP compute scaling. Per-GPU routed compute is EP-invariant + # (A2A redistributes tokens so each GPU always processes + # batch_tokens × topk total token-expert pairs). When + # _rescale_expert_parallelism preserves experts_per_rank the + # profiled/simulated MLP time already reflects the correct + # per-rank workload, so ep_mlp_scale == 1.0. 
+ ep_mlp_scale = _compute_ep_mlp_scale( + training_config.model_config, + benchmark_ep, + original_ep, + original_num_experts=original_num_experts, + benchmark_num_experts=benchmark_num_experts, + ) if is_rank_0: print("[Primus:Performance Projection] Adjusting profiling results for EP scaling:") print(f" EP rescaled: {benchmark_ep} → {original_ep}") - print(f" MLP time scale factor: {ep_mlp_scale:.3f}") - # Show shared vs routed breakdown - topk = getattr(training_config.model_config, "moe_router_topk", 1) or 1 - moe_ffn = getattr(training_config.model_config, "moe_ffn_hidden_size", None) - shared_ffn = getattr( - training_config.model_config, - "moe_shared_expert_intermediate_size", - None, - ) - num_shared = getattr(training_config.model_config, "num_shared_experts", 0) or 0 - if moe_ffn and num_shared > 0 and shared_ffn: - routed_flops = (topk / benchmark_ep) * moe_ffn - shared_flops = num_shared * shared_ffn - total_flops = routed_flops + shared_flops + if original_num_experts is not None and benchmark_num_experts is not None: + orig_epr = original_num_experts // original_ep + bench_epr = benchmark_num_experts // benchmark_ep print( - f" Routed fraction: {routed_flops/total_flops:.1%} (topk={topk}, EP={benchmark_ep}, ffn={moe_ffn})" + f" Experts per rank: benchmark={bench_epr} " + f"(E={benchmark_num_experts}, EP={benchmark_ep}), " + f"target={orig_epr} " + f"(E={original_num_experts}, EP={original_ep})" ) - print( - f" Shared fraction: {shared_flops/total_flops:.1%} ({num_shared} shared expert(s), ffn={shared_ffn})" - ) - else: - print(f" No shared experts — full routed scaling ({benchmark_ep}/{original_ep})") + print(f" MLP time scale factor: {ep_mlp_scale:.3f}") if fwd_overhead_per_layer > 0 or bwd_overhead_per_layer > 0: print(f" Adding per-layer All-to-All overhead:") print(f" Forward: +{fwd_overhead_per_layer:.3f} ms/layer") print(f" Backward: +{bwd_overhead_per_layer:.3f} ms/layer") - # Adjust MoE layer times in profiling_results + # Adjust MoE layer times in profiling_results using a DELTA approach: + # new_layer_time = old_layer_time + (mlp_delta) + (a2a_delta) + # This preserves TP AllReduce, A2A(benchmark_ep), LayerNorm, and + # other components that are baked into the profiled layer time. moe_layers_adjusted = 0 for layer_idx, layer_data in profiling_results.items(): if isinstance(layer_data, dict) and layer_data.get("type") == "moe": + old_fwd = layer_data.get("forward_time_ms", 0) + old_bwd = layer_data.get("backward_time_ms", 0) + mlp_info = layer_data.get("mlp", {}) mlp_fwd = mlp_info.get("forward_time_ms", 0) mlp_bwd = mlp_info.get("backward_time_ms", 0) - attn_info = layer_data.get("attention", {}) - attn_fwd = attn_info.get("forward_time_ms", 0) - attn_bwd = attn_info.get("backward_time_ms", 0) - # Scale MLP compute (shared-expert-aware), keep attention unchanged + # Compute MLP delta (usually 0 when scale == 1.0) new_mlp_fwd = mlp_fwd * ep_mlp_scale new_mlp_bwd = mlp_bwd * ep_mlp_scale + mlp_delta_fwd = new_mlp_fwd - mlp_fwd + mlp_delta_bwd = new_mlp_bwd - mlp_bwd - # New layer time = attention + scaled MLP + A2A comm overhead - new_fwd = attn_fwd + new_mlp_fwd + fwd_overhead_per_layer - new_bwd = attn_bwd + new_mlp_bwd + bwd_overhead_per_layer + # Delta approach: start from the full profiled layer time + # (which includes TP AR, A2A(benchmark_ep), norms, etc.) + # and add the MLP compute delta + A2A comm delta. 
+ new_fwd = old_fwd + mlp_delta_fwd + fwd_overhead_per_layer + new_bwd = old_bwd + mlp_delta_bwd + bwd_overhead_per_layer if is_rank_0 and moe_layers_adjusted == 0: - old_fwd = layer_data.get("forward_time_ms", 0) - old_bwd = layer_data.get("backward_time_ms", 0) print(f" MoE layer adjustment (per layer):") print(f" MLP fwd: {mlp_fwd:.2f} → {new_mlp_fwd:.2f} ms (×{ep_mlp_scale:.3f})") print(f" MLP bwd: {mlp_bwd:.2f} → {new_mlp_bwd:.2f} ms (×{ep_mlp_scale:.3f})") - print(f" Attn fwd: {attn_fwd:.2f} ms (unchanged)") - print(f" Attn bwd: {attn_bwd:.2f} ms (unchanged)") + print(f" A2A fwd delta: +{fwd_overhead_per_layer:.3f} ms") + print(f" A2A bwd delta: +{bwd_overhead_per_layer:.3f} ms") print(f" Layer fwd: {old_fwd:.2f} → {new_fwd:.2f} ms") print(f" Layer bwd: {old_bwd:.2f} → {new_bwd:.2f} ms") @@ -2022,9 +3374,16 @@ def launch_projection_from_cli(args, overrides): # Total EP overhead = per-layer overhead * number of MoE layers total_ep_overhead_ms = (fwd_overhead_per_layer + bwd_overhead_per_layer) * num_moe_layers - # EP compute scaling (shared-expert-aware) + # EP compute scaling — per-GPU MoE routed compute is + # EP-invariant (see _compute_ep_mlp_scale docstring). + original_num_experts = reduction_info.get("original_num_experts") + benchmark_num_experts = reduction_info.get("benchmark_num_experts") ep_mlp_scale = _compute_ep_mlp_scale( - training_config.model_config, benchmark_ep_val, original_ep + training_config.model_config, + benchmark_ep_val, + original_ep, + original_num_experts=original_num_experts, + benchmark_num_experts=benchmark_num_experts, ) # Estimate MLP portion of MoE layer time from profiling results mlp_time_reduction = 0.0 diff --git a/primus/core/projection/simulation_backends/__init__.py b/primus/core/projection/simulation_backends/__init__.py new file mode 100644 index 000000000..898900ddc --- /dev/null +++ b/primus/core/projection/simulation_backends/__init__.py @@ -0,0 +1,23 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +from primus.core.projection.simulation_backends.base import ( + GEMMSimulationBackend, + SDPASimulationBackend, + SimulationResult, +) +from primus.core.projection.simulation_backends.factory import ( + get_gemm_simulation_backend, + get_sdpa_simulation_backend, +) + +__all__ = [ + "GEMMSimulationBackend", + "SDPASimulationBackend", + "SimulationResult", + "get_gemm_simulation_backend", + "get_sdpa_simulation_backend", +] diff --git a/primus/core/projection/simulation_backends/base.py b/primus/core/projection/simulation_backends/base.py new file mode 100644 index 000000000..0b4a8fa21 --- /dev/null +++ b/primus/core/projection/simulation_backends/base.py @@ -0,0 +1,315 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +Abstract base classes for GEMM and SDPA simulation backends. + +These backends provide simulated (analytical/model-based) timing for GEMM and +SDPA operations, allowing performance projection without running actual GPU +kernels. 
Two concrete GEMM backends are shipped: + +- **Origami** (open-source, default) – ``origami_backend.py`` + +An SDPA simulation backend is provided in ``sdpa_simulator.py``. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + + +@dataclass +class SimulationResult: + """Result from a simulation backend.""" + + # Predicted time in milliseconds + forward_time_ms: float = 0.0 + backward_time_ms: float = 0.0 + + # Optional: predicted TFLOPS / bandwidth + tflops: Optional[float] = None + bandwidth_gbps: Optional[float] = None + + # Optional: extra metadata from the backend + metadata: Dict[str, Any] = field(default_factory=dict) + + +class GEMMSimulationBackend(ABC): + """Abstract interface for GEMM simulation backends.""" + + @abstractmethod + def name(self) -> str: + """Return human-readable backend name.""" + ... + + @abstractmethod + def is_available(self) -> bool: + """Return True if this backend can be used in the current environment.""" + ... + + @property + def hbm_bandwidth_gbps(self) -> Optional[float]: + """Peak HBM bandwidth in GB/s for the target GPU, or *None* if unknown. + + Concrete backends should override this when the target architecture is + known so that downstream profilers (e.g. MoE non-GEMM overhead) can + derive memory-bandwidth estimates from the actual hardware rather than + relying on hardcoded absolute numbers. + """ + return None + + @abstractmethod + def simulate_gemm( + self, + m: int, + n: int, + k: int, + dtype: str = "bf16", + trans_a: bool = False, + trans_b: bool = False, + batch: int = 1, + ) -> SimulationResult: + """ + Simulate a single GEMM (or batched GEMM) and return predicted timing. + + Args: + m, n, k: Matrix dimensions (C = A @ B, A:[M,K] B:[K,N] C:[M,N]) + dtype: Data type string ("bf16", "fp16", "fp8", "fp32") + trans_a: Whether A is transposed + trans_b: Whether B is transposed + batch: Number of independent GEMMs with the **same** (M, N, K) to + run as a **batched** GEMM. For MoE experts this is used as an + approximation of grouped GEMM under the assumption of uniform + token distribution across experts (so every sub-problem has the + same M = tokens_per_expert). Defaults to 1 (single GEMM). + + NOTE: Origami's ``problem.batch`` models batched GEMM, not true + grouped GEMM (where each sub-problem can have a different M). + This is an acceptable approximation when token distribution is + assumed uniform. If Origami adds native grouped-GEMM support + in the future, this should be updated. + + Returns: + SimulationResult with forward_time_ms populated. + """ + ... + + def simulate_mlp_gemms( + self, + batch_tokens: int, + hidden_size: int, + ffn_hidden_size: int, + dtype: str = "bf16", + swiglu: bool = False, + num_experts: int = 1, + ) -> SimulationResult: + """ + Simulate the GEMM operations in a dense MLP (gate/up/down projections). + + Default implementation calls ``simulate_gemm`` for each projection and + sums the times. Backends may override for better accuracy. + + When ``num_experts > 1``, each GEMM is simulated as a **batched** GEMM + (``batch = num_experts``) with per-expert token count, which serves as + an approximation of grouped GEMM under uniform token distribution. + See ``simulate_gemm`` docstring for caveats. 
+ + Args: + batch_tokens: Number of tokens per expert + (batch_size * seq_len / TP / CP for dense; + topk_tokens / num_local_experts for MoE routed experts) + hidden_size: Model hidden dimension + ffn_hidden_size: FFN intermediate dimension + dtype: Data type string + swiglu: Whether SwiGLU activation is used (3 projections vs 2) + num_experts: Number of local experts. When > 1, each projection + GEMM uses ``batch = num_experts`` as a batched-GEMM + approximation of grouped GEMM. Defaults to 1 (dense MLP). + + Returns: + SimulationResult with forward_time_ms and backward_time_ms. + """ + # Use batched GEMM (batch=num_experts) as approximation of grouped GEMM. + # Valid under uniform token distribution (all experts get the same M). + # TODO: switch to native grouped-GEMM simulation if/when Origami supports it. + b = num_experts + + if swiglu: + # Gate projection fwd: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] + gate_fwd = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) + # Up projection fwd: same shape as gate + up_fwd = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) + # Down projection fwd: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] + down_fwd = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) + + fwd_time = gate_fwd.forward_time_ms + up_fwd.forward_time_ms + down_fwd.forward_time_ms + + # Backward: simulate actual dgrad + wgrad GEMMs per projection + # Gate dgrad: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] + gate_dgrad = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) + # Gate wgrad: [hidden, tokens] x [tokens, ffn] -> [hidden, ffn] + gate_wgrad = self.simulate_gemm(hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b) + # Up dgrad + wgrad: same shapes as gate + up_dgrad = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) + up_wgrad = self.simulate_gemm(hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b) + # Down dgrad: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] + down_dgrad = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) + # Down wgrad: [ffn, tokens] x [tokens, hidden] -> [ffn, hidden] + down_wgrad = self.simulate_gemm(ffn_hidden_size, hidden_size, batch_tokens, dtype, batch=b) + + bwd_time = ( + gate_dgrad.forward_time_ms + + gate_wgrad.forward_time_ms + + up_dgrad.forward_time_ms + + up_wgrad.forward_time_ms + + down_dgrad.forward_time_ms + + down_wgrad.forward_time_ms + ) + else: + # Up projection fwd: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] + up_fwd = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) + # Down projection fwd: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] + down_fwd = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) + + fwd_time = up_fwd.forward_time_ms + down_fwd.forward_time_ms + + # Backward: simulate actual dgrad + wgrad GEMMs per projection + # Up dgrad: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] + up_dgrad = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) + # Up wgrad: [hidden, tokens] x [tokens, ffn] -> [hidden, ffn] + up_wgrad = self.simulate_gemm(hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b) + # Down dgrad: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] + down_dgrad = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) + # Down wgrad: [ffn, tokens] x [tokens, hidden] -> [ffn, hidden] 
+ down_wgrad = self.simulate_gemm(ffn_hidden_size, hidden_size, batch_tokens, dtype, batch=b) + + bwd_time = ( + up_dgrad.forward_time_ms + + up_wgrad.forward_time_ms + + down_dgrad.forward_time_ms + + down_wgrad.forward_time_ms + ) + + return SimulationResult(forward_time_ms=fwd_time, backward_time_ms=bwd_time) + + def simulate_attention_gemms( + self, + batch_tokens: int, + hidden_size: int, + num_attention_heads: int, + kv_channels: int, + num_query_groups: int, + dtype: str = "bf16", + ) -> SimulationResult: + """ + Simulate the linear projection GEMMs in the attention block + (QKV projections + output projection). Does NOT include the SDPA + computation itself – use SDPASimulationBackend for that. + + Default implementation calls ``simulate_gemm`` for Q, K, V, O projections. + + Returns: + SimulationResult with forward_time_ms and backward_time_ms. + """ + fwd_time = 0.0 + bwd_time = 0.0 + + # Q projection fwd: [tokens, hidden] x [hidden, heads*kv_channels] + q_out = num_attention_heads * kv_channels + q_fwd = self.simulate_gemm(batch_tokens, q_out, hidden_size, dtype) + fwd_time += q_fwd.forward_time_ms + + # K projection fwd: [tokens, hidden] x [hidden, num_query_groups*kv_channels] + k_out = num_query_groups * kv_channels + k_fwd = self.simulate_gemm(batch_tokens, k_out, hidden_size, dtype) + fwd_time += k_fwd.forward_time_ms + + # V projection fwd: same shape as K + v_fwd = self.simulate_gemm(batch_tokens, k_out, hidden_size, dtype) + fwd_time += v_fwd.forward_time_ms + + # Output projection fwd: [tokens, heads*kv_channels] x [heads*kv_channels, hidden] + o_fwd = self.simulate_gemm(batch_tokens, hidden_size, q_out, dtype) + fwd_time += o_fwd.forward_time_ms + + # Backward: simulate actual dgrad + wgrad GEMMs per projection + # Q dgrad: [tokens, q_out] x [q_out, hidden] -> [tokens, hidden] + q_dgrad = self.simulate_gemm(batch_tokens, hidden_size, q_out, dtype) + # Q wgrad: [hidden, tokens] x [tokens, q_out] -> [hidden, q_out] + q_wgrad = self.simulate_gemm(hidden_size, q_out, batch_tokens, dtype) + bwd_time += q_dgrad.forward_time_ms + q_wgrad.forward_time_ms + + # K dgrad: [tokens, k_out] x [k_out, hidden] -> [tokens, hidden] + k_dgrad = self.simulate_gemm(batch_tokens, hidden_size, k_out, dtype) + # K wgrad: [hidden, tokens] x [tokens, k_out] -> [hidden, k_out] + k_wgrad = self.simulate_gemm(hidden_size, k_out, batch_tokens, dtype) + bwd_time += k_dgrad.forward_time_ms + k_wgrad.forward_time_ms + + # V dgrad + wgrad: same shapes as K + v_dgrad = self.simulate_gemm(batch_tokens, hidden_size, k_out, dtype) + v_wgrad = self.simulate_gemm(hidden_size, k_out, batch_tokens, dtype) + bwd_time += v_dgrad.forward_time_ms + v_wgrad.forward_time_ms + + # O dgrad: [tokens, hidden] x [hidden, q_out] -> [tokens, q_out] + o_dgrad = self.simulate_gemm(batch_tokens, q_out, hidden_size, dtype) + # O wgrad: [q_out, tokens] x [tokens, hidden] -> [q_out, hidden] + o_wgrad = self.simulate_gemm(q_out, hidden_size, batch_tokens, dtype) + bwd_time += o_dgrad.forward_time_ms + o_wgrad.forward_time_ms + + return SimulationResult(forward_time_ms=fwd_time, backward_time_ms=bwd_time) + + +class SDPASimulationBackend(ABC): + """Abstract interface for Scaled Dot-Product Attention simulation.""" + + @abstractmethod + def name(self) -> str: + """Return human-readable backend name.""" + ... + + @abstractmethod + def is_available(self) -> bool: + """Return True if this backend can be used in the current environment.""" + ... 
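To make the batched-GEMM (grouped-GEMM) approximation used by `simulate_mlp_gemms` concrete, here is a minimal usage sketch. It is illustrative only: the token counts and layer sizes are invented, it assumes the `origami` package is installed so the factory call succeeds, and only the import path and call signatures are taken from this patch.

```python
# Minimal, illustrative sketch of the batched-GEMM (grouped-GEMM) approximation.
# All sizes below are invented; `origami` must be installed for the factory call.
from primus.core.projection.simulation_backends import get_gemm_simulation_backend

backend = get_gemm_simulation_backend(gpu_arch="mi300x")

# Hypothetical routed-expert MLP: 8 experts hosted on this EP rank, uniform routing.
num_local_experts = 8
tokens_per_expert = 3072   # topk_tokens / num_local_experts, per the docstring above

result = backend.simulate_mlp_gemms(
    batch_tokens=tokens_per_expert,
    hidden_size=2048,
    ffn_hidden_size=1408,
    dtype="bf16",
    swiglu=True,
    num_experts=num_local_experts,  # each projection runs as one batched GEMM
)
print(f"MoE MLP fwd {result.forward_time_ms:.3f} ms, bwd {result.backward_time_ms:.3f} ms")
```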
+ + @abstractmethod + def simulate_sdpa( + self, + batch_size: int, + num_heads: int, + seq_len: int, + head_dim: int, + causal: bool = True, + dtype: str = "bf16", + seq_len_kv: Optional[int] = None, + num_heads_kv: Optional[int] = None, + head_dim_v: Optional[int] = None, + ) -> SimulationResult: + """ + Simulate a Scaled Dot-Product Attention operation. + + Args: + batch_size: Batch size + num_heads: Number of query attention heads (per TP rank) + seq_len: Query sequence length (per CP rank) + head_dim: Head dimension used in the Q·Kᵀ dot-product. For + standard MHA/GQA this equals ``kv_channels``; for MLA it + equals ``qk_head_dim + qk_pos_emb_head_dim`` (e.g. 192). + causal: Whether causal masking is used + dtype: Data type string + seq_len_kv: Key/Value sequence length. Defaults to ``seq_len`` + (self-attention). + num_heads_kv: Number of KV heads. Defaults to ``num_heads`` + (MHA). Set lower for GQA / MQA. + head_dim_v: Value head dimension used in the P·V product and + for sizing the output O. Defaults to ``head_dim`` (standard + attention where Q/K/V all share the same dimension). For + MLA this should be set to ``v_head_dim`` (e.g. 128). + + Returns: + SimulationResult with forward_time_ms and backward_time_ms. + """ + ... diff --git a/primus/core/projection/simulation_backends/factory.py b/primus/core/projection/simulation_backends/factory.py new file mode 100644 index 000000000..e050f1f8c --- /dev/null +++ b/primus/core/projection/simulation_backends/factory.py @@ -0,0 +1,100 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +Factory functions for creating simulation backends. + +Backend selection for GEMM: + 1. If ``PRIMUS_GEMM_BACKEND`` is set, use that backend explicitly. + 2. Otherwise, use **origami** (the default, open-source backend). + +SDPA always uses the built-in analytical simulator. +""" + +import os +from typing import Optional + +from primus.core.projection.simulation_backends.base import ( + GEMMSimulationBackend, + SDPASimulationBackend, +) + + +def get_gemm_simulation_backend( + backend_name: Optional[str] = None, + gpu_arch: Optional[str] = None, + gpu_clock_mhz: Optional[int] = None, +) -> GEMMSimulationBackend: + """ + Create and return the GEMM simulation backend (origami). + + Args: + backend_name: Explicit backend name. Currently only "origami" is supported. + If None, defaults to origami. + gpu_arch: GPU architecture override (e.g. "gfx942", "mi300x", "mi325x"). + gpu_clock_mhz: Override the GPU compute clock frequency in MHz. + + Returns: + A GEMMSimulationBackend instance. + + Raises: + RuntimeError: If the backend is not available. + """ + name = backend_name or os.getenv("PRIMUS_GEMM_BACKEND", None) + + if name is not None: + name = name.lower().strip() + + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + + if name is not None and name != "origami": + raise ValueError(f"Unknown GEMM simulation backend: '{name}'. 
" f"Supported backend: 'origami'") + + from primus.core.projection.simulation_backends.origami_backend import ( + OrigamiGEMMBackend, + ) + + backend = OrigamiGEMMBackend(gpu_arch=gpu_arch, gpu_clock_mhz=gpu_clock_mhz) + if not backend.is_available(): + raise RuntimeError( + "Origami GEMM simulation backend is not available.\n" "Install it with: pip install origami" + ) + + if is_rank_0: + print("[Primus:Simulation] Using GEMM backend: origami") + return backend + + +def get_sdpa_simulation_backend( + gpu_arch: Optional[str] = None, + gpu_clock_mhz: Optional[int] = None, +) -> SDPASimulationBackend: + """ + Create and return the SDPA simulation backend. + + Uses the Origami 1-CU tile-level model of the FAv3 (Flash Attention v3) + kernels. Origami must be installed. + + Args: + gpu_arch: GPU architecture override (e.g. "mi300x", "mi355x"). + gpu_clock_mhz: Override the GPU compute clock frequency in MHz. + + Returns: + An SDPASimulationBackend instance. + + Raises: + RuntimeError: If the Origami backend is not available. + """ + from primus.core.projection.simulation_backends.sdpa_simulator import SDPASimulator + + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + if is_rank_0: + print("[Primus:Simulation] Using SDPA backend: sdpa_simulator (FAv3 Origami 1-CU)") + + return SDPASimulator( + gpu_arch=gpu_arch, + gpu_clock_mhz=gpu_clock_mhz, + ) diff --git a/primus/core/projection/simulation_backends/origami_backend.py b/primus/core/projection/simulation_backends/origami_backend.py new file mode 100644 index 000000000..be2bad979 --- /dev/null +++ b/primus/core/projection/simulation_backends/origami_backend.py @@ -0,0 +1,477 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +Origami GEMM simulation backend. + +Origami is an open-source analytical performance model for GEMM kernels on +AMD GPUs (part of the ROCm ecosystem). It predicts kernel execution time +based on matrix dimensions, data type, tile configuration, and hardware +characteristics. + +This is the **default** backend for GEMM simulation in Primus performance +projection. + +Installation: + pip install git+https://github.com/ROCm/rocm-libraries.git#subdirectory=shared/origami/python + +Environment variables: + PRIMUS_GEMM_BACKEND – set to "origami" (or leave unset) to use this backend. + PRIMUS_GPU_ARCH – GPU architecture override (e.g. "gfx942", "gfx950"). + PRIMUS_GPU_DEVICE – GPU device index for hardware detection (default: 0). +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +from primus.core.projection.simulation_backends.base import ( + GEMMSimulationBackend, + SimulationResult, +) + +# --------------------------------------------------------------------------- +# Lazy import – we don't want to fail at module-import time. 
+# --------------------------------------------------------------------------- +_origami = None +_origami_available: Optional[bool] = None + + +def _try_import_origami(): + """Try to import origami and cache the result.""" + global _origami, _origami_available + if _origami_available is not None: + return _origami_available + + try: + import origami # type: ignore[import-untyped] + + _origami = origami + _origami_available = True + except ImportError: + _origami = None + _origami_available = False + + return _origami_available + + +# --------------------------------------------------------------------------- +# Known hardware profiles for GPU-less simulation via get_hardware_for_arch. +# --------------------------------------------------------------------------- +@dataclass +class _HardwareProfile: + """Parameters required by ``origami.get_hardware_for_arch``.""" + + arch_enum_name: str # attribute name on ``origami.architecture_t`` + n_cu: int + lds_capacity: int # bytes + l2_capacity: int # bytes (per XCD) + compute_clock_khz: int + hbm_bandwidth_gbps: float = 5300.0 # peak HBM bandwidth (GB/s) + + +_KNOWN_PROFILES: Dict[str, _HardwareProfile] = { + # MI300X / gfx942: HBM3 ~5.3 TB/s + "mi300x": _HardwareProfile("gfx942", 304, 65536, 4_194_304, 2_100_000, 5300.0), + "gfx942": _HardwareProfile("gfx942", 304, 65536, 4_194_304, 2_100_000, 5300.0), + # MI325X / gfx942 (same die as MI300X, HBM3E upgrade): ~6.0 TB/s + "mi325x": _HardwareProfile("gfx942", 304, 65536, 4_194_304, 2_100_000, 6000.0), + # MI355X / gfx950: HBM3E ~8.0 TB/s + "mi355x": _HardwareProfile("gfx950", 256, 65536, 4_194_304, 2_100_000, 8000.0), + "gfx950": _HardwareProfile("gfx950", 256, 65536, 4_194_304, 2_100_000, 8000.0), + # MI300A + "mi300a": _HardwareProfile("gfx942", 228, 65536, 4_194_304, 2_100_000, 4000.0), +} + +# --------------------------------------------------------------------------- +# Dtype mapping: Primus string → origami datatype string +# (origami.string_to_datatype accepts these short-hand names) +# --------------------------------------------------------------------------- +_DTYPE_MAP: Dict[str, str] = { + "bf16": "bf16", + "fp16": "f16", + "fp32": "f32", + "fp8": "bf8_fnuz", +} + +# --------------------------------------------------------------------------- +# Default candidate tile configurations (macro-tile M×N×K + occupancy). +# A wider search space yields better latency predictions at the cost of +# slightly longer selection time (still << 1 ms per GEMM). +# --------------------------------------------------------------------------- +_DEFAULT_TILE_SIZES: List[Tuple[int, int, int]] = [ + (64, 64, 32), + (64, 64, 64), + (64, 128, 32), + (64, 128, 64), + (128, 64, 32), + (128, 64, 64), + (128, 128, 32), + (128, 128, 64), + (128, 256, 32), + (128, 256, 64), + (256, 128, 32), + (256, 128, 64), + (256, 256, 32), + (256, 256, 64), +] +_DEFAULT_OCCUPANCIES: List[int] = [1, 2, 4] + + +class OrigamiGEMMBackend(GEMMSimulationBackend): + """ + GEMM simulation backend using Origami (open-source). + + Hardware is obtained in one of two ways (in priority order): + + 1. **From the local GPU** – ``origami.get_hardware_for_device(device_idx)`` + is called when a ROCm-capable GPU is present. The device index defaults + to 0 and can be overridden via the ``PRIMUS_GPU_DEVICE`` env var. + 2. **From a known profile** – when no GPU is available *and* a + ``--gpu-arch`` / ``PRIMUS_GPU_ARCH`` is provided, the backend falls back + to ``origami.get_hardware_for_arch`` with hard-coded parameters for + MI300X, MI355X, etc. 
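+
+    Example (illustrative; the target arch and GEMM shape are arbitrary)::
+
+        backend = OrigamiGEMMBackend(gpu_arch="mi325x")  # GPU-less, profile-based
+        res = backend.simulate_gemm(m=8192, n=8192, k=8192, dtype="bf16")
+        print(res.forward_time_ms, res.tflops)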
+ """ + + def __init__( + self, + gpu_arch: Optional[str] = None, + gpu_clock_mhz: Optional[int] = None, + n_cu_override: Optional[int] = None, + ): + """ + Args: + gpu_arch: Target GPU architecture string (e.g. "gfx942", "mi300x"). + If *None*, auto-detected from the current GPU or the + ``PRIMUS_GPU_ARCH`` env var. + gpu_clock_mhz: Override the compute clock frequency in MHz. + If *None*, uses the profile default or the + ``PRIMUS_GPU_CLOCK_MHZ`` env var. + n_cu_override: Override the number of Compute Units used by + Origami's performance model. Set to ``1`` for + per-tile / single-CU simulation (e.g. SDPA tile-level + modelling). If *None*, the profile's default CU + count is used. + """ + self._gpu_arch = gpu_arch or os.getenv("PRIMUS_GPU_ARCH", None) + if self._gpu_arch is not None: + self._gpu_arch = self._gpu_arch.lower().strip() + + # Clock override: CLI > env var > profile default + _env_clock = os.getenv("PRIMUS_GPU_CLOCK_MHZ", None) + self._clock_override_mhz: Optional[int] = gpu_clock_mhz or (int(_env_clock) if _env_clock else None) + + self._n_cu_override = n_cu_override + + # Lazily initialised origami objects – see ``_ensure_initialized``. + self._hardware = None # origami.hardware_t + self._configs = None # list[origami.config_t] + self._clock_ghz: Optional[float] = None + self._initialized = False + self._init_dtype: Optional[str] = None # tracks dtype used for config init + self._fp8_mi_unavailable = False # set True if FP8 MI is 0x0x0 + + # ------------------------------------------------------------------ + # GEMMSimulationBackend interface + # ------------------------------------------------------------------ + + def name(self) -> str: + return "origami" + + def is_available(self) -> bool: + return _try_import_origami() + + @property + def hbm_bandwidth_gbps(self) -> Optional[float]: + """Peak HBM bandwidth for the target architecture (GB/s). + + Resolved from the arch profile (``_KNOWN_PROFILES``) or the + ``PRIMUS_GPU_ARCH`` env var. Returns ``None`` only when no + architecture could be determined. + """ + arch = self._gpu_arch or os.getenv("PRIMUS_GPU_ARCH", "mi300x") + arch = arch.lower().strip() + profile = _KNOWN_PROFILES.get(arch) + return profile.hbm_bandwidth_gbps if profile is not None else None + + def simulate_gemm( + self, + m: int, + n: int, + k: int, + dtype: str = "bf16", + trans_a: bool = False, + trans_b: bool = False, + batch: int = 1, + ) -> SimulationResult: + if not self.is_available(): + raise RuntimeError( + "Origami is not installed. Install with:\n" + " pip install git+https://github.com/ROCm/rocm-libraries.git" + "#subdirectory=shared/origami/python" + ) + + # FP8 fallback strategy + # ~~~~~~~~~~~~~~~~~~~~~~ + # Origami v0.1.0 maps the Primus "fp8" dtype to "bf8_fnuz" (the FNUZ + # variant used by gfx942 / MI300X), but this datatype string is not + # yet recognised by Origami's ``string_to_datatype``. As a result, + # ``_ensure_initialized("fp8")`` sets ``_fp8_mi_unavailable = True`` + # on *all* current architectures. + # + # When the FP8 MI is unavailable we simulate in BF16 and halve the + # time (``÷ 2.0`` below). This is a first-order approximation: + # FP8 doubles the matrix-instruction throughput relative to BF16 on + # both gfx942 and gfx950. 
+ # + # Origami does expose *native* ``bf8`` / ``f8`` MI for gfx950 + # (MI=16×16×128 vs BF16 MI=16×16×32), but its tile-level performance + # model currently does not translate the higher MI throughput into + # proportional time savings — native FP8 predictions are only ~1.03× + # faster than BF16 instead of the expected ~2×. Until Origami's FP8 + # performance model is updated, the BF16 ÷ 2 heuristic is more + # accurate for all supported architectures. + fp8_fallback = False + sim_dtype = dtype + if dtype == "fp8": + self._ensure_initialized("fp8") + if self._fp8_mi_unavailable: + sim_dtype = "bf16" + fp8_fallback = True + self._ensure_initialized("bf16") + # else: origami supports FP8 natively, proceed normally + else: + self._ensure_initialized(dtype) + + # ----- Build origami problem_t ----- + # NOTE: problem.batch models **batched** GEMM (all sub-problems share + # the same M, N, K). For MoE experts this is used as an approximation + # of grouped GEMM under uniform token distribution. Origami does not + # currently expose a native grouped-GEMM model. + problem = _origami.problem_t() + problem.size = _origami.dim3_t(m, n, k) + problem.batch = batch + problem.a_transpose = _origami.transpose_t.T if trans_a else _origami.transpose_t.N + problem.b_transpose = _origami.transpose_t.T if trans_b else _origami.transpose_t.N + + origami_dtype = _origami.string_to_datatype(_DTYPE_MAP.get(sim_dtype, "bf16")) + problem.a_dtype = origami_dtype + problem.b_dtype = origami_dtype + problem.c_dtype = origami_dtype + problem.d_dtype = origami_dtype + problem.mi_dtype = origami_dtype + problem.a_mx_block_size = 0 + problem.b_mx_block_size = 0 + + # ----- Select best config & predict latency (in clock cycles) ----- + try: + result = _origami.select_config(problem, self._hardware, self._configs) + except Exception as e: + raise RuntimeError( + f"Origami select_config failed for " f"(M={m}, N={n}, K={k}, dtype={dtype}): {e}" + ) from e + + latency_cycles = result.latency + time_ms = latency_cycles / (self._clock_ghz * 1e6) + + # FP8 speedup heuristic: we simulated in BF16 because Origami could + # not use a native FP8 matrix instruction. FP8 MFMA throughput is 2× + # that of BF16 on both gfx942 and gfx950, so halving the BF16 time is + # a reasonable first-order approximation for compute-bound GEMMs. + if fp8_fallback: + time_ms /= 2.0 + + # Compute achieved TFLOPS for metadata + flops = 2.0 * m * n * k * batch + time_s = time_ms / 1e3 + tflops = (flops / time_s / 1e12) if time_s > 0 else 0.0 + + return SimulationResult( + forward_time_ms=time_ms, + backward_time_ms=0.0, # Caller computes bwd from fwd + tflops=tflops, + metadata={ + "backend": "origami", + "gpu_arch": self._gpu_arch, + "latency_cycles": latency_cycles, + "clock_ghz": self._clock_ghz, + "dtype": dtype, + "batch": batch, + "fp8_fallback": fp8_fallback, + "best_tile": ( + result.config.mt.m, + result.config.mt.n, + result.config.mt.k, + ), + "best_occupancy": result.config.occupancy, + }, + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _ensure_initialized(self, dtype: str = "bf16") -> None: + """Lazily initialise hardware and candidate configs. + + If called again with a *different* dtype the candidate config list is + rebuilt (the matrix-instruction size changes between BF16 and FP8). + """ + # Hardware only needs to be detected once. 
+ if self._hardware is None: + self._hardware = self._get_hardware() + self._clock_ghz = self._hardware.compute_clock_ghz + + # (Re-)build candidate configs when the dtype changes. + if self._initialized and self._init_dtype == dtype: + return + + origami_str = _DTYPE_MAP.get(dtype, "bf16") + try: + origami_dtype = _origami.string_to_datatype(origami_str) + except (ValueError, RuntimeError): + # Origami doesn't know this dtype string — treat as unavailable + if dtype == "fp8": + self._fp8_mi_unavailable = True + return + + mi = self._hardware.get_recommended_matrix_instruction(origami_dtype) + + # Detect MI=0x0x0 (datatype not supported by this arch) + if mi.m == 0 and mi.n == 0 and mi.k == 0: + if dtype == "fp8": + self._fp8_mi_unavailable = True + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + if is_rank_0: + print( + f"[Primus:Origami] FP8 matrix instruction not available " + f"for this hardware; will use BF16 with 2x speedup factor" + ) + return + + configs: list = [] + for mt_m, mt_n, mt_k in _DEFAULT_TILE_SIZES: + for occ in _DEFAULT_OCCUPANCIES: + cfg = _origami.config_t() + cfg.mt = _origami.dim3_t(mt_m, mt_n, mt_k) + cfg.mi = mi + cfg.occupancy = occ + configs.append(cfg) + self._configs = configs + self._init_dtype = dtype + + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + if is_rank_0: + print( + f"[Primus:Origami] Initialised: " + f"N_CU={self._hardware.N_CU}, " + f"NUM_XCD={self._hardware.NUM_XCD}, " + f"clock={self._clock_ghz} GHz, " + f"MI={mi.m}x{mi.n}x{mi.k}, " + f"dtype={dtype}, " + f"{len(configs)} candidate configs" + ) + + self._initialized = True + + def _get_hardware(self): + """ + Obtain an ``origami.hardware_t`` instance. + + Priority: + 1. If ``--gpu-arch`` was explicitly provided AND we have a known + profile for it, use the profile. This ensures consistent + results regardless of the local GPU (important for simulation + mode targeting a *different* GPU, e.g. simulating MI325X on + MI300X). + 2. Otherwise, try the local GPU. + 3. Fall back to the arch profile if no GPU is available. + """ + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + + # 1. If an explicit arch was requested AND we have a profile, use it. + if self._gpu_arch is not None and self._gpu_arch in _KNOWN_PROFILES: + profile = _KNOWN_PROFILES[self._gpu_arch] + clock_khz = profile.compute_clock_khz + if self._clock_override_mhz is not None: + clock_khz = self._clock_override_mhz * 1000 + n_cu = self._n_cu_override if self._n_cu_override is not None else profile.n_cu + arch_enum = getattr(_origami.architecture_t, profile.arch_enum_name) + hw = _origami.get_hardware_for_arch( + arch_enum, + n_cu, + profile.lds_capacity, + profile.l2_capacity, + clock_khz, + ) + if is_rank_0: + override_tag = "" + if self._clock_override_mhz is not None: + override_tag = " (overridden via --gpu-clock-mhz)" + cu_tag = f" (n_cu_override={n_cu})" if self._n_cu_override is not None else "" + print( + f"[Primus:Origami] Using hardware profile for " + f"'{self._gpu_arch}': N_CU={n_cu}, " + f"clock={clock_khz / 1e6:.1f} GHz{override_tag}{cu_tag}" + ) + return hw + + # 2. Try local GPU + device_idx = int(os.getenv("PRIMUS_GPU_DEVICE", "0")) + try: + hw = _origami.get_hardware_for_device(device_idx) + if is_rank_0: + print( + f"[Primus:Origami] Hardware detected from device {device_idx}: " + f"N_CU={hw.N_CU}, NUM_XCD={hw.NUM_XCD}, " + f"clock={hw.compute_clock_ghz} GHz" + ) + return hw + except Exception: + pass # No GPU – try arch-based profile + + # 3. 
Fall back to known profile + if self._gpu_arch is None: + raise RuntimeError( + "Origami could not detect a GPU and no --gpu-arch / " + "PRIMUS_GPU_ARCH was specified. Either run on a machine with " + "a ROCm GPU or provide a target architecture " + "(e.g. --gpu-arch mi300x)." + ) + + profile = _KNOWN_PROFILES.get(self._gpu_arch) + if profile is None: + supported = ", ".join(sorted(_KNOWN_PROFILES.keys())) + raise RuntimeError( + f"Unknown GPU architecture '{self._gpu_arch}' for origami. " + f"Supported architectures: {supported}" + ) + + clock_khz = profile.compute_clock_khz + if self._clock_override_mhz is not None: + clock_khz = self._clock_override_mhz * 1000 + n_cu = self._n_cu_override if self._n_cu_override is not None else profile.n_cu + arch_enum = getattr(_origami.architecture_t, profile.arch_enum_name) + hw = _origami.get_hardware_for_arch( + arch_enum, + n_cu, + profile.lds_capacity, + profile.l2_capacity, + clock_khz, + ) + if is_rank_0: + override_tag = "" + if self._clock_override_mhz is not None: + override_tag = " (overridden via --gpu-clock-mhz)" + cu_tag = f" (n_cu_override={n_cu})" if self._n_cu_override is not None else "" + print( + f"[Primus:Origami] Using known hardware profile for " + f"'{self._gpu_arch}': N_CU={n_cu}, " + f"clock={clock_khz / 1e6:.1f} GHz{override_tag}{cu_tag}" + ) + return hw diff --git a/primus/core/projection/simulation_backends/sdpa_simulator.py b/primus/core/projection/simulation_backends/sdpa_simulator.py new file mode 100644 index 000000000..cd1629aaa --- /dev/null +++ b/primus/core/projection/simulation_backends/sdpa_simulator.py @@ -0,0 +1,586 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +SDPA simulation backend modelling the **FAv3** (Flash Attention v3) kernels. + +The forward and backward kernel parameters are extracted from FAv3 kernel +configurations: + + Forward: + Config : BF16 ; FMHA FWD ; D128 ; 1TG ; 8W ; 32m×8 ; 64n×1 ; 32×32×16 + • 1 Thread-Group, 8 Wavefronts per workgroup (512 threads) + • Q-tile = 256 rows (32m × 8 wavefronts) + • KV-tile = 64 columns per loop iteration + • 64 MFMAs per loop iteration (QKᵀ + softmax + PV, pipelined) + • Workgroups = ⌈S / 256⌉ × B × H + + Backward: + Config : BF16 ; FMHA BWD ; D128 ; 1TG ; 4W ; 16m×1 ; 64n×4 ; A32 + • 1 Thread-Group, 4 Wavefronts per workgroup (256 threads) + • Q-tile = 16 rows per inner loop step + • KV-tile = 256 columns per workgroup (64n × 4) + • 256 MFMAs per inner-loop iteration (dV, dP, dS, dQ, dK phases) + • Workgroups = ⌈S / 256⌉ × B × H + • Inner-loop iterations = ⌈S / 16⌉ (over Q blocks) + +Tile-level simulation using Origami with a 1-CU backend. +Flash Attention is a fused kernel where sub-operations (QKᵀ, softmax, PV) +execute **sequentially** within each workgroup tile. By running Origami on +a single CU for each per-tile GEMM and then multiplying by the number of +waves (workgroups / N_CU), we capture the additive cost structure and +per-tile wave quantisation effects that a global roofline misses. + +**Origami is required** — this simulator cannot operate without it. + +In the backward pass, the dQ gradient is accumulated across KV-workgroups +using ``buffer_atomic_add_f32`` (72 atomic instructions in the kernel). 
+Each KV-workgroup processes all Q positions and atomically adds its partial +dQ contribution, leading to contention proportional to ⌈S / 256⌉ concurrent +writers per dQ cache line. The atomic overhead is modelled as an additive +cost on top of the compute/memory time. +""" + +from __future__ import annotations + +import math +import os +from dataclasses import dataclass +from typing import Dict, Optional + +from primus.core.projection.simulation_backends.base import ( + SDPASimulationBackend, + SimulationResult, +) + +# ========================================================================= +# FAv3 kernel tile parameters +# ========================================================================= + + +@dataclass(frozen=True) +class _FAv3TileConfig: + """Tile & occupancy parameters extracted from a FAv3 kernel.""" + + q_tile_m: int # Q rows per workgroup + kv_tile_n: int # K/V positions per loop iteration + n_wavefronts: int # Wavefronts per workgroup + mfma_m: int = 32 # MFMA instruction M + mfma_n: int = 32 # MFMA instruction N + mfma_k: int = 16 # MFMA instruction K (BF16 on gfx950) + + +# Forward: 256 Q-rows, 64 KV-cols/iter, 8 wavefronts +_FAV3_FWD = _FAv3TileConfig(q_tile_m=256, kv_tile_n=64, n_wavefronts=8) + +# Backward: 16 Q-rows/inner-iter, 256 KV-cols/workgroup, 4 wavefronts +_FAV3_BWD = _FAv3TileConfig(q_tile_m=16, kv_tile_n=256, n_wavefronts=4) + + +# ========================================================================= +# Backward dQ atomic latencies (for tile-level model) +# ========================================================================= +# dQ is accumulated via buffer_atomic_add_f32 across KV-workgroups. +# Latency estimates for CDNA3 (gfx942) at typical clocks: +_ATOMIC_LATENCY_GLOBAL_NS = 400 # HBM read-modify-write latency per atomic op +_ATOMIC_LATENCY_LOCAL_NS = 40 # L1 / LDS atomic latency per op +_WARP_SIZE = 64 # CDNA wavefront width + + +# ========================================================================= +# GPU hardware specs +# ========================================================================= + + +@dataclass +class GPUHardwareSpec: + """Hardware specification for roofline modelling.""" + + # Peak compute throughput in TFLOPS (tera floating-point ops / sec) + peak_tflops_bf16: float = 1307.0 # MI300X BF16 peak + peak_tflops_fp16: float = 1307.0 + peak_tflops_fp8: float = 2614.0 + + # HBM bandwidth in GB/s + hbm_bandwidth_gbps: float = 5300.0 # MI300X HBM3 + + # Total CUs on the device + n_cu: int = 304 # MI300X + + # Number of XCDs on the device (cross-die atomics are more expensive) + n_xcd: int = 8 # MI300X has 8 XCDs + + +# Pre-defined hardware profiles +_HW_PROFILES: Dict[str, GPUHardwareSpec] = { + "mi300x": GPUHardwareSpec( + peak_tflops_bf16=1307.0, + peak_tflops_fp16=1307.0, + peak_tflops_fp8=2614.0, + hbm_bandwidth_gbps=5300.0, + n_cu=304, + n_xcd=8, + ), + "gfx942": GPUHardwareSpec( # same as MI300X + peak_tflops_bf16=1307.0, + peak_tflops_fp16=1307.0, + peak_tflops_fp8=2614.0, + hbm_bandwidth_gbps=5300.0, + n_cu=304, + n_xcd=8, + ), + "mi325x": GPUHardwareSpec( # gfx942 die, HBM3E (use --gpu-clock-mhz to override clock) + peak_tflops_bf16=1307.0, + peak_tflops_fp16=1307.0, + peak_tflops_fp8=2614.0, + hbm_bandwidth_gbps=6000.0, # HBM3E ~6 TB/s (vs 5.3 on MI300X) + n_cu=304, + n_xcd=8, + ), + "mi355x": GPUHardwareSpec( + peak_tflops_bf16=2384.0, + peak_tflops_fp16=2384.0, + peak_tflops_fp8=4768.0, + hbm_bandwidth_gbps=8000.0, + n_cu=256, + n_xcd=4, + ), + "gfx950": GPUHardwareSpec( # same as MI355X + peak_tflops_bf16=2384.0, 
+ peak_tflops_fp16=2384.0, + peak_tflops_fp8=4768.0, + hbm_bandwidth_gbps=8000.0, + n_cu=256, + n_xcd=4, + ), +} + + +def _get_hardware_spec( + gpu_arch: Optional[str] = None, + gpu_clock_mhz: Optional[int] = None, +) -> GPUHardwareSpec: + """Get hardware spec for the given (or detected) GPU architecture. + + If *gpu_clock_mhz* is provided, the profile's TFLOPS values are scaled + proportionally (compute throughput is linear in clock frequency). + """ + arch = gpu_arch or os.getenv("PRIMUS_GPU_ARCH", "mi300x") + arch = arch.lower().strip() + spec = _HW_PROFILES.get(arch, _HW_PROFILES["mi300x"]) + + # Apply clock override — scale TFLOPS linearly + clock_override = gpu_clock_mhz or (int(v) if (v := os.getenv("PRIMUS_GPU_CLOCK_MHZ")) else None) + if clock_override is not None: + # Derive the profile's implicit clock from a known reference. + _PROFILE_CLOCK_MHZ = { + "mi300x": 2100, + "gfx942": 2100, + "mi325x": 2100, # same gfx942 compute die as MI300X + "mi355x": 2100, + "gfx950": 2100, + "mi300a": 2100, + } + base_clock = _PROFILE_CLOCK_MHZ.get(arch, 2100) + scale = clock_override / base_clock + spec = GPUHardwareSpec( + peak_tflops_bf16=spec.peak_tflops_bf16 * scale, + peak_tflops_fp16=spec.peak_tflops_fp16 * scale, + peak_tflops_fp8=spec.peak_tflops_fp8 * scale, + hbm_bandwidth_gbps=spec.hbm_bandwidth_gbps, # BW doesn't change with clock + n_cu=spec.n_cu, + n_xcd=spec.n_xcd, + ) + return spec + + +# ========================================================================= +# SDPASimulator — FAv3-based analytical model +# ========================================================================= + + +class SDPASimulator(SDPASimulationBackend): + """ + Analytical SDPA simulation modelling the FAv3 kernel structure. + + Uses an Origami GEMM backend with ``n_cu=1`` to simulate per-tile + GEMM execution time. Flash Attention is modelled as a fused kernel + where QKᵀ, softmax, and PV are sequential within each workgroup. + The total time = (per-tile-QKᵀ + per-tile-PV) × num_waves. This + naturally captures wave quantisation and per-tile efficiency without + needing an empirical ``compute_efficiency`` parameter. + + Also models the backward dQ atomic overhead from + ``buffer_atomic_add_f32`` accumulation across KV-workgroups. + + **Origami is required** — instantiation will fail if the Origami + backend is not available. + """ + + def __init__( + self, + gpu_arch: Optional[str] = None, + hardware_spec: Optional[GPUHardwareSpec] = None, + gpu_clock_mhz: Optional[int] = None, + ): + """ + Args: + gpu_arch: GPU architecture string (e.g. "mi300x", "gfx942", + "mi355x", "gfx950"). + hardware_spec: Override hardware spec directly. + gpu_clock_mhz: Override the GPU compute clock frequency in MHz. + If provided, the profile's TFLOPS are scaled proportionally. + + Raises: + RuntimeError: If the Origami backend is not available. + """ + self._hw = hardware_spec or _get_hardware_spec(gpu_arch, gpu_clock_mhz) + + # Create the Origami 1-CU backend for tile-level simulation. + # This is required — SDPA simulation cannot proceed without it. + self._tile_gemm = self._create_tile_gemm_backend(gpu_arch, gpu_clock_mhz) + if self._tile_gemm is None: + raise RuntimeError( + "SDPASimulator requires the Origami backend but it is not " + "available. Please install the 'origami' package or ensure " + "primus.core.projection.simulation_backends.origami_backend " + "is importable." 
+ ) + + def name(self) -> str: + return "sdpa_simulator (FAv3)" + + def is_available(self) -> bool: + return self._tile_gemm is not None and self._tile_gemm.is_available() + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def simulate_sdpa( + self, + batch_size: int, + num_heads: int, + seq_len: int, + head_dim: int, + causal: bool = True, + dtype: str = "bf16", + seq_len_kv: Optional[int] = None, + num_heads_kv: Optional[int] = None, + head_dim_v: Optional[int] = None, + ) -> SimulationResult: + """ + Simulate FAv3 SDPA execution time using Origami 1-CU tile-level + simulation parameterised by the actual FAv3 tile configuration. + + Args: + batch_size: Batch size (B). + num_heads: Number of query heads (H_Q). + seq_len: Query sequence length (S_Q). + head_dim: Head dimension for Q·Kᵀ (D_qk). For standard + MHA/GQA this is ``kv_channels``; for MLA this is + ``qk_head_dim + qk_pos_emb_head_dim`` (e.g. 192). + causal: Whether causal masking is applied. + dtype: Data type ("bf16", "fp16", "fp8", "fp32"). + seq_len_kv: Key/Value sequence length (S_K). Defaults to + ``seq_len`` (self-attention). Set differently for + cross-attention or prefill with separate KV cache length. + num_heads_kv: Number of KV heads. Defaults to ``num_heads`` + (MHA). Set lower for GQA/MQA. + head_dim_v: Value head dimension for P·V (D_v). Defaults to + ``head_dim`` (standard attention). For MLA set to + ``v_head_dim`` (e.g. 128). + """ + B = batch_size + H_Q = num_heads + S_Q = seq_len + S_K = seq_len_kv if seq_len_kv is not None else seq_len + H_K = num_heads_kv if num_heads_kv is not None else num_heads + D_qk = head_dim + D_v = head_dim_v if head_dim_v is not None else head_dim + bpe = self._bytes_per_element(dtype) + + return self._simulate_tile_level( + B, + H_Q, + S_Q, + S_K, + H_K, + D_qk, + D_v, + causal, + dtype, + bpe, + ) + + # ------------------------------------------------------------------ + # Tile-level simulation (primary mode — requires Origami) + # ------------------------------------------------------------------ + + def _create_tile_gemm_backend( + self, + gpu_arch: Optional[str], + gpu_clock_mhz: Optional[int], + ): + """Try to create an Origami backend with 1 CU for per-tile simulation. + + Returns the backend on success, or ``None`` if Origami is not available. + """ + try: + from primus.core.projection.simulation_backends.origami_backend import ( + OrigamiGEMMBackend, + ) + + backend = OrigamiGEMMBackend( + gpu_arch=gpu_arch, + gpu_clock_mhz=gpu_clock_mhz, + n_cu_override=1, + ) + if backend.is_available(): + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + if is_rank_0: + print("[Primus:SDPA] Using Origami 1-CU tile-level simulation " "for Flash Attention") + return backend + except Exception as exc: + # If Origami is not available or fails to initialize, fall back to + # the analytic SDPA model by returning None here. + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + if is_rank_0: + print("[Primus:SDPA] Origami 1-CU tile-level simulation disabled " f"due to error: {exc}") + return None + + def _simulate_tile_level( + self, + B: int, + H_Q: int, + S_Q: int, + S_K: int, + H_K: int, + D_qk: int, + D_v: int, + causal: bool, + dtype: str, + bpe: int, + ) -> SimulationResult: + """ + Tile-level SDPA simulation using Origami on a single CU. 
+ + Flash Attention is a **fused** kernel — within each workgroup, the + sub-operations (QKᵀ, softmax, PV) execute **sequentially** and the + intermediate S/P matrices stay in LDS/registers (never written to + HBM). This means the correct timing model is **additive** across + sub-operations per tile, not a global ``max(compute, memory)`` + roofline. + + Approach: + 1. Simulate each per-tile GEMM using Origami with ``n_cu=1``. + This captures tile-level wave quantisation, LDS traffic, and + pipeline effects that a global FLOP-rate model misses. + 2. Sum the per-tile GEMM times (additive, sequential execution). + 3. Multiply by ``num_waves = ⌈workgroups / N_CU⌉`` to account for + CU-level parallelism across tiles. + 4. Add dQ atomic overhead (backward only). + + Forward per-workgroup (q_tile_m=256 Q rows, sweeps all S_K): + QKᵀ: Q_tile[256, D_qk] × Kᵀ[D_qk, S_K] → S[256, S_K] + PV: P_tile[256, S_K] × V[S_K, D_v] → O[256, D_v] + Workgroups = ⌈S_Q / 256⌉ × B × H_Q + + Backward per-workgroup (kv_tile_n=256 KV cols, sweeps all S_Q): + 5 GEMMs per workgroup (FA backward algorithm): + 1. QKᵀ recompute: Q[S_Q, D_qk] × Kᵀ[D_qk, 256] → S[S_Q, 256] + 2. dP = dO × Vᵀ: dO[S_Q, D_v] × Vᵀ[D_v, 256] → dP[S_Q, 256] + 3. dV = Pᵀ × dO: Pᵀ[256, S_Q] × dO[S_Q, D_v] → dV[256, D_v] + 4. dQ = dS × K: dS[S_Q, 256] × K[256, D_qk] → dQ[S_Q, D_qk] + 5. dK = dSᵀ × Q: dSᵀ[256, S_Q] × Q[S_Q, D_qk] → dK[256, D_qk] + Workgroups = ⌈S_K / 256⌉ × B × H_Q + """ + assert self._tile_gemm is not None + N_CU = self._hw.n_cu + causal_factor = 0.5 if causal else 1.0 + + # ============================================================== + # FORWARD + # ============================================================== + fwd_n_wgs = math.ceil(S_Q / _FAV3_FWD.q_tile_m) * B * H_Q + fwd_waves = math.ceil(fwd_n_wgs / N_CU) + + # Per-workgroup GEMMs on 1 CU (tile sweeps all S_K positions): + # QKᵀ: [q_tile_m, D_qk, S_K] + r_fwd_qk = self._tile_gemm.simulate_gemm( + m=_FAV3_FWD.q_tile_m, + n=S_K, + k=D_qk, + dtype=dtype, + ) + # PV: [q_tile_m, S_K, D_v] + r_fwd_pv = self._tile_gemm.simulate_gemm( + m=_FAV3_FWD.q_tile_m, + n=D_v, + k=S_K, + dtype=dtype, + ) + + fwd_time_ms = (r_fwd_qk.forward_time_ms + r_fwd_pv.forward_time_ms) * fwd_waves + + # ============================================================== + # BACKWARD + # ============================================================== + kv_tile = _FAV3_BWD.kv_tile_n # 256 + bwd_n_wgs = math.ceil(S_K / kv_tile) * B * H_Q + bwd_waves = math.ceil(bwd_n_wgs / N_CU) + + # Per-workgroup GEMMs (5 operations, full Q-sweep on 1 CU): + # 1. QKᵀ recompute: [S_Q, D_qk, kv_tile] + r_bwd_qk = self._tile_gemm.simulate_gemm( + m=S_Q, + n=kv_tile, + k=D_qk, + dtype=dtype, + ) + # 2. dP = dO × Vᵀ: [S_Q, D_v, kv_tile] + r_bwd_dp = self._tile_gemm.simulate_gemm( + m=S_Q, + n=kv_tile, + k=D_v, + dtype=dtype, + ) + # 3. dV = Pᵀ × dO: [kv_tile, S_Q, D_v] + r_bwd_dv = self._tile_gemm.simulate_gemm( + m=kv_tile, + n=D_v, + k=S_Q, + dtype=dtype, + ) + # 4. dQ = dS × K: [S_Q, kv_tile, D_qk] + r_bwd_dq = self._tile_gemm.simulate_gemm( + m=S_Q, + n=D_qk, + k=kv_tile, + dtype=dtype, + ) + # 5. dK = dSᵀ × Q: [kv_tile, S_Q, D_qk] + r_bwd_dk = self._tile_gemm.simulate_gemm( + m=kv_tile, + n=D_qk, + k=S_Q, + dtype=dtype, + ) + + bwd_compute_ms = ( + r_bwd_qk.forward_time_ms + + r_bwd_dp.forward_time_ms + + r_bwd_dv.forward_time_ms + + r_bwd_dq.forward_time_ms + + r_bwd_dk.forward_time_ms + ) * bwd_waves + + # ── Backward dQ atomics (latency-based model) ── + # Each KV-workgroup atomically accumulates dQ via buffer_atomic_add_f32. 
+ # The latency model counts warp-level reduction updates (global and + # local) and multiplies by the per-op latency. + num_k_tiles = math.ceil(kv_tile / kv_tile) # = 1 + warp_updates_global = math.ceil(num_k_tiles * math.ceil(D_qk / _WARP_SIZE)) + total_updates_global = warp_updates_global * bwd_waves + + warp_updates_local = math.ceil(kv_tile * math.ceil(D_qk / _WARP_SIZE)) + total_updates_local = warp_updates_local * bwd_waves + + bwd_atomic_ms = ( + _ATOMIC_LATENCY_GLOBAL_NS * total_updates_global + _ATOMIC_LATENCY_LOCAL_NS * total_updates_local + ) / 1e6 # ns → ms + + bwd_time_ms = bwd_compute_ms + bwd_atomic_ms + + # ============================================================== + # METADATA (FLOPs, bytes — for achieved-TFLOPS reporting) + # ============================================================== + fwd_flops = ( + 2.0 * B * H_Q * S_Q * S_K * D_qk # QKᵀ + + 2.0 * B * H_Q * S_Q * S_K * D_v # PV + + 5.0 * B * H_Q * S_Q * S_K # softmax + ) * causal_factor + + bwd_flops = ( + 2.0 * B * H_Q * S_Q * S_K * D_qk # QKᵀ recomp + + 2.0 * B * H_Q * S_Q * S_K * D_v # dP + + 2.0 * B * H_Q * S_K * S_Q * D_v # dV + + 2.0 * B * H_Q * S_Q * S_K * D_qk # dQ + + 2.0 * B * H_Q * S_K * S_Q * D_qk # dK + + 5.0 * B * H_Q * S_Q * S_K # softmax bwd + ) * causal_factor + + fwd_bytes = ( + B * H_Q * S_Q * D_qk * bpe # Q + + B * H_K * S_K * D_qk * bpe # K + + B * H_K * S_K * D_v * bpe # V + + B * H_Q * S_Q * D_v * bpe # O + + B * H_Q * S_Q * 4 # logsumexp (fp32) + ) + bwd_bytes = ( + B * H_Q * S_Q * D_qk * bpe # Q + + B * H_K * S_K * D_qk * bpe # K + + B * H_K * S_K * D_v * bpe # V + + B * H_Q * S_Q * D_v * bpe # O + + B * H_Q * S_Q * D_v * bpe # dO + + B * H_Q * S_Q * 4 # logsumexp (fp32) + + B * H_K * S_K * D_qk * bpe # dK + + B * H_K * S_K * D_v * bpe # dV + ) + + fwd_achieved_tflops = (fwd_flops / (fwd_time_ms * 1e-3)) / 1e12 if fwd_time_ms > 0 else 0 + + return SimulationResult( + forward_time_ms=fwd_time_ms, + backward_time_ms=bwd_time_ms, + tflops=fwd_achieved_tflops, + bandwidth_gbps=((fwd_bytes / (fwd_time_ms * 1e-3)) / 1e9 if fwd_time_ms > 0 else 0), + metadata={ + "backend": "sdpa_simulator (FAv3 tile-level, Origami 1-CU)", + # Standard metadata keys (for compatibility) + "fwd_compute_bound": True, + "fwd_compute_ms": fwd_time_ms, + "fwd_memory_ms": 0.0, # included in per-tile Origami model + "bwd_bottleneck": "compute+atomic", + "bwd_compute_ms": bwd_compute_ms, + "bwd_memory_ms": 0.0, # included in per-tile Origami model + "bwd_atomic_ms": bwd_atomic_ms, + "fwd_flops": fwd_flops, + "bwd_flops": bwd_flops, + "fwd_bytes": fwd_bytes, + "bwd_bytes": bwd_bytes, + "seq_len_q": S_Q, + "seq_len_kv": S_K, + "num_heads_q": H_Q, + "num_heads_kv": H_K, + "causal": causal, + # Tile-level details + "fwd_waves": fwd_waves, + "fwd_n_workgroups": fwd_n_wgs, + "fwd_qk_per_tile_ms": r_fwd_qk.forward_time_ms, + "fwd_pv_per_tile_ms": r_fwd_pv.forward_time_ms, + "bwd_waves": bwd_waves, + "bwd_n_workgroups": bwd_n_wgs, + "bwd_qk_recomp_per_tile_ms": r_bwd_qk.forward_time_ms, + "bwd_dp_per_tile_ms": r_bwd_dp.forward_time_ms, + "bwd_dv_per_tile_ms": r_bwd_dv.forward_time_ms, + "bwd_dq_per_tile_ms": r_bwd_dq.forward_time_ms, + "bwd_dk_per_tile_ms": r_bwd_dk.forward_time_ms, + "n_cu": N_CU, + # FAv3 tile parameters + "fwd_q_tile_m": _FAV3_FWD.q_tile_m, + "fwd_kv_tile_n": _FAV3_FWD.kv_tile_n, + "bwd_q_tile_m": _FAV3_BWD.q_tile_m, + "bwd_kv_tile_n": _FAV3_BWD.kv_tile_n, + }, + ) + + # ------------------------------------------------------------------ + # Internal helpers + # 
------------------------------------------------------------------ + + def _bytes_per_element(self, dtype: str) -> int: + return {"bf16": 2, "fp16": 2, "fp32": 4, "fp8": 1}.get(dtype, 2) diff --git a/primus/core/projection/training_config.py b/primus/core/projection/training_config.py index 81c964f3d..7e84fbe6b 100644 --- a/primus/core/projection/training_config.py +++ b/primus/core/projection/training_config.py @@ -63,6 +63,12 @@ class ModelConfig: moe_shared_expert_intermediate_size: int = 0 # Misc share_embeddings_and_output_weights: bool = False + # Precision – None means bf16, "hybrid" means FP8-hybrid (linear GEMMs in FP8) + fp8: str = None + + # Primus Turbo flags — used to select the grouped-GEMM performance model + enable_primus_turbo: bool = False + use_turbo_grouped_mlp: bool = False @dataclass
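Because the Origami backend resolves a hardware profile from `gpu_arch` when no GPU is present, the SDPA simulator can also be exercised without a GPU. The sketch below is illustrative only: the shapes are invented, the `origami` package is assumed to be installed, and only the factory function and the `simulate_sdpa` signature come from this patch.

```python
# Minimal, illustrative sketch; shapes are invented and `origami` must be installed.
from primus.core.projection.simulation_backends import get_sdpa_simulation_backend

sdpa = get_sdpa_simulation_backend(gpu_arch="mi300x")

# MLA-style shapes as described in the simulate_sdpa docstring:
# Q.K^T runs at head_dim = qk_head_dim + qk_pos_emb_head_dim = 192,
# P.V runs at head_dim_v = v_head_dim = 128.
res = sdpa.simulate_sdpa(
    batch_size=1,
    num_heads=16,     # query heads per TP rank
    seq_len=4096,     # query length per CP rank
    head_dim=192,
    causal=True,
    dtype="bf16",
    head_dim_v=128,
)
print(f"SDPA fwd {res.forward_time_ms:.3f} ms, bwd {res.backward_time_ms:.3f} ms")
print(f"fwd waves: {res.metadata['fwd_waves']}, dQ atomics: {res.metadata['bwd_atomic_ms']:.3f} ms")
```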