From 954890d3e2f0833f1a37e1bd0886f9c4660d0c28 Mon Sep 17 00:00:00 2001 From: Anshu Raina Date: Fri, 13 Feb 2026 23:16:16 +0000 Subject: [PATCH 01/12] Add simulation backends and improve performance projection - Add Origami GEMM simulation backend with MI300X/MI325X hardware profiles - Add FAv3 analytical SDPA simulator with roofline model, atomic overhead modeling for backward pass, and GQA/MQA support - Add simulation mode (--profiling-mode simulate) for GPU-free performance projection using Origami for GEMMs and analytical model for Flash Attention - Wire simulation backends into all module profilers (attention, dense MLP, MoE MLP, embedding, output layer, transformer layer) - Add --gemm-backend and --gpu-arch CLI arguments for simulation control - Fix FSDP communication model to double AllGather count when recompute_granularity='full' is enabled --- primus/cli/subcommands/projection.py | 43 +++- .../projection/module_profilers/attention.py | 91 ++++++-- .../projection/module_profilers/dense_mlp.py | 35 ++- .../projection/module_profilers/embedding.py | 48 ++++- .../module_profilers/language_model.py | 202 ++++++++++++------ .../projection/module_profilers/moe_mlp.py | 65 +++++- .../module_profilers/output_layer.py | 53 ++++- .../module_profilers/transformer_layer.py | 100 +++++++-- .../performance_projection/projection.py | 147 ++++++++++++- 9 files changed, 655 insertions(+), 129 deletions(-) diff --git a/primus/cli/subcommands/projection.py b/primus/cli/subcommands/projection.py index 53fd62e16..6fce8bece 100644 --- a/primus/cli/subcommands/projection.py +++ b/primus/cli/subcommands/projection.py @@ -14,9 +14,13 @@ def run(args, overrides): launch_projection_from_cli(args, overrides) elif args.suite == "performance": - from primus.pretrain import setup_backend_path + profiling_mode = getattr(args, "profiling_mode", "benchmark") - setup_backend_path(framework="megatron", verbose=True) + if profiling_mode != "simulate": + # Benchmark or "both" modes need the 
Megatron backend + from primus.pretrain import setup_backend_path + + setup_backend_path(framework="megatron", verbose=True) from primus.core.projection.performance_projection import ( launch_projection_from_cli, @@ -92,6 +96,41 @@ def register_subcommand(subparsers): "If not provided, uses default cluster parameters.\n\n" ), ) + performance.add_argument( + "--profiling-mode", + type=str, + required=False, + default="benchmark", + choices=["benchmark", "simulate", "both"], + help=( + "Profiling mode for layer timing:\n" + " benchmark - Run actual GPU benchmarks (default, requires GPU)\n" + " simulate - Use simulation backends (origami for GEMM,\n" + " analytical model for SDPA). No GPU required.\n" + " both - Run both benchmark and simulation, report side-by-side\n" + ), + ) + performance.add_argument( + "--gemm-backend", + type=str, + required=False, + default=None, + choices=["origami"], + help=( + "GEMM simulation backend (only used when --profiling-mode is 'simulate' or 'both').\n" + " origami - Open-source GEMM performance model (default)\n" + ), + ) + performance.add_argument( + "--gpu-arch", + type=str, + required=False, + default=None, + help=( + "Target GPU architecture for simulation (e.g. 
'mi300x', 'gfx942', 'mi355x', 'gfx950').\n" + "If not specified, auto-detected or uses PRIMUS_GPU_ARCH env var.\n" + ), + ) parser.set_defaults(func=run) diff --git a/primus/core/projection/module_profilers/attention.py b/primus/core/projection/module_profilers/attention.py index 63f8dd6dd..22828dfb9 100644 --- a/primus/core/projection/module_profilers/attention.py +++ b/primus/core/projection/module_profilers/attention.py @@ -19,6 +19,8 @@ def __init__(self, config, sub_profilers=None): self.module = None # Will be set during benchmarking self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) + self._gemm_backend = None # Optional: GEMM simulation backend + self._sdpa_backend = None # Optional: SDPA simulation backend def set_module(self, module): """Set the actual attention module for benchmarking.""" @@ -27,6 +29,18 @@ def set_module(self, module): self._cached_results = None self._cache_key = None + def set_gemm_backend(self, backend): + """Set a GEMM simulation backend for attention linear projections.""" + self._gemm_backend = backend + self._cached_results = None + self._cache_key = None + + def set_sdpa_backend(self, backend): + """Set an SDPA simulation backend for attention computation.""" + self._sdpa_backend = backend + self._cached_results = None + self._cache_key = None + def estimated_num_params(self, rank: Optional[int] = None) -> int: args = self.config.model_config # Group-query & multi-latent attention support. 
@@ -131,23 +145,76 @@ def _num_query_groups() -> int: return tokens_per_rank * (activation_width + ln_width) * bytes_per_value + def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + """Get simulated results from GEMM + SDPA simulation backends.""" + args = self.config.model_config + mp = self.config.model_parallel_config + tp_size = max(1, mp.tensor_model_parallel_size) + cp_size = max(1, mp.context_model_parallel_size) + + batch_tokens = batch_size * seq_len // tp_size // cp_size + slen_per_cp = seq_len // cp_size + + fwd_time = 0.0 + bwd_time = 0.0 + + # 1. Simulate linear projection GEMMs (Q, K, V, O) using GEMM backend + if self._gemm_backend is not None: + num_query_groups = ( + args.num_query_groups + if args.group_query_attention and args.num_query_groups + else args.num_attention_heads + ) + gemm_result = self._gemm_backend.simulate_attention_gemms( + batch_tokens=batch_tokens, + hidden_size=args.hidden_size, + num_attention_heads=args.num_attention_heads, + kv_channels=args.kv_channels, + num_query_groups=num_query_groups, + dtype="bf16", + ) + fwd_time += gemm_result.forward_time_ms + bwd_time += gemm_result.backward_time_ms + + # 2. 
Simulate SDPA core computation using SDPA backend + if self._sdpa_backend is not None: + heads_per_rank = max(1, args.num_attention_heads // tp_size) + sdpa_result = self._sdpa_backend.simulate_sdpa( + batch_size=batch_size, + num_heads=heads_per_rank, + seq_len=slen_per_cp, + head_dim=args.kv_channels, + causal=True, + dtype="bf16", + ) + fwd_time += sdpa_result.forward_time_ms + bwd_time += sdpa_result.backward_time_ms + + activation_memory = self.estimated_activation_memory(batch_size, seq_len) + return (fwd_time, bwd_time, activation_memory) + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: - # Context parallel / Sequence parallel adjustment - cp_size = self.config.model_parallel_config.context_model_parallel_size - # Effective sequence length per rank if CP is used - slen_per_cp = seq_len // cp_size - - self._cached_results = benchmark_layer( - self.module, - [ - (seq_len, batch_size, self.config.model_config.hidden_size), - ((1, 1, slen_per_cp, seq_len), torch.bool), - ], - ) + if self._gemm_backend is not None or self._sdpa_backend is not None: + # Use simulation mode + self._cached_results = self._get_simulated_results(batch_size, seq_len) + else: + # Use actual GPU benchmarking + # Context parallel / Sequence parallel adjustment + cp_size = self.config.model_parallel_config.context_model_parallel_size + # Effective sequence length per rank if CP is used + slen_per_cp = seq_len // cp_size + + self._cached_results = benchmark_layer( + self.module, + [ + (seq_len, batch_size, self.config.model_config.hidden_size), + ((1, 1, slen_per_cp, seq_len), torch.bool), + ], + ) self._cache_key = cache_key return self._cached_results diff --git a/primus/core/projection/module_profilers/dense_mlp.py b/primus/core/projection/module_profilers/dense_mlp.py index d0a9aaadb..5008bd5a2 
100644 --- a/primus/core/projection/module_profilers/dense_mlp.py +++ b/primus/core/projection/module_profilers/dense_mlp.py @@ -19,6 +19,7 @@ def __init__(self, config, sub_profilers=None): self.module = None # Will be set during benchmarking self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) + self._gemm_backend = None # Optional: GEMM simulation backend def set_module(self, module): """Set the actual Dense MLP module for benchmarking.""" @@ -27,6 +28,13 @@ def set_module(self, module): self._cached_results = None self._cache_key = None + def set_gemm_backend(self, backend): + """Set a GEMM simulation backend for simulated profiling.""" + self._gemm_backend = backend + # Invalidate cache when backend changes + self._cached_results = None + self._cache_key = None + def estimated_num_params(self, rank: Optional[int] = None) -> int: # For SwiGLU: 3 projections (gate, up, down) # For standard FFN: 2 projections (up, down) @@ -58,14 +66,33 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: # Peak memory is input + intermediate (both needed for backward) return intermediate_memory + activation_memory + output_memory + def _get_simulated_results(self, batch_size: int, seq_len: int) -> Tuple[float, float, int]: + """Get simulated results from the GEMM simulation backend.""" + tp_size = self.config.model_parallel_config.tensor_model_parallel_size + cp_size = self.config.model_parallel_config.context_model_parallel_size + batch_tokens = batch_size * seq_len // tp_size // cp_size + + sim_result = self._gemm_backend.simulate_mlp_gemms( + batch_tokens=batch_tokens, + hidden_size=self.config.model_config.hidden_size, + ffn_hidden_size=self.config.model_config.ffn_hidden_size, + dtype="bf16", + swiglu=self.config.model_config.swiglu, + ) + activation_memory = self.estimated_activation_memory(batch_size, seq_len) + return (sim_result.forward_time_ms, 
sim_result.backward_time_ms, activation_memory) + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> Tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: - self._cached_results = benchmark_layer( - self.module, - [(seq_len, batch_size, self.config.model_config.hidden_size)], - ) + if self._gemm_backend is not None: + self._cached_results = self._get_simulated_results(batch_size, seq_len) + else: + self._cached_results = benchmark_layer( + self.module, + [(seq_len, batch_size, self.config.model_config.hidden_size)], + ) self._cache_key = cache_key return self._cached_results diff --git a/primus/core/projection/module_profilers/embedding.py b/primus/core/projection/module_profilers/embedding.py index 5640e9de0..cfd2c3a90 100644 --- a/primus/core/projection/module_profilers/embedding.py +++ b/primus/core/projection/module_profilers/embedding.py @@ -19,6 +19,7 @@ def __init__(self, config, sub_profilers=None): self.module = None # Will be set during benchmarking self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) + self._simulation_mode = False # Set to True when simulation backends are active def set_module(self, module): """Set the actual module for benchmarking.""" @@ -27,6 +28,12 @@ def set_module(self, module): self._cached_results = None self._cache_key = None + def set_simulation_mode(self, enabled: bool = True): + """Enable simulation mode (embedding lookup is estimated analytically).""" + self._simulation_mode = enabled + self._cached_results = None + self._cache_key = None + def estimated_num_params(self, rank: Optional[int] = None) -> int: return self.config.model_config.padded_vocab_size * self.config.model_config.hidden_size @@ -40,22 +47,41 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: * 2 ) # 
bf16 + def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + """Estimate embedding time analytically (lookup is memory-bound, very fast).""" + tp_size = self.config.model_parallel_config.tensor_model_parallel_size + cp_size = self.config.model_parallel_config.context_model_parallel_size + tokens = batch_size * seq_len // tp_size // cp_size + hidden = self.config.model_config.hidden_size + # Embedding lookup: read tokens indices + write output embeddings + # Very fast relative to GEMM layers – use small fixed estimate + output_bytes = tokens * hidden * 2 # bf16 + # Assume ~4 TB/s effective bandwidth (MI300X), 1 read + 1 write pass + bw_bytes_per_ms = 4e9 # ~4 TB/s → bytes/ms + fwd_time = max(0.01, output_bytes / bw_bytes_per_ms) + bwd_time = fwd_time # Gradient scatter is similar cost + activation_memory = self.estimated_activation_memory(batch_size, seq_len) + return (fwd_time, bwd_time, activation_memory) + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: - # Context parallel / Sequence parallel adjustment - cp_size = self.config.model_parallel_config.context_model_parallel_size - # Effective sequence length per rank if CP is used - slen_per_cp = seq_len // cp_size - - self._cached_results = benchmark_layer( - self.module, - [ - ((batch_size, slen_per_cp), torch.int64), - ], - ) + if self._simulation_mode: + self._cached_results = self._get_simulated_results(batch_size, seq_len) + else: + # Context parallel / Sequence parallel adjustment + cp_size = self.config.model_parallel_config.context_model_parallel_size + # Effective sequence length per rank if CP is used + slen_per_cp = seq_len // cp_size + + self._cached_results = benchmark_layer( + self.module, + [ + ((batch_size, slen_per_cp), torch.int64), + ], + ) self._cache_key = 
cache_key return self._cached_results diff --git a/primus/core/projection/module_profilers/language_model.py b/primus/core/projection/module_profilers/language_model.py index ac15e1fd7..815366522 100644 --- a/primus/core/projection/module_profilers/language_model.py +++ b/primus/core/projection/module_profilers/language_model.py @@ -157,6 +157,33 @@ def __init__(self, config, sub_profilers=None): ep_size=self.config.model_parallel_config.expert_model_parallel_size, num_virtual_pipeline_stages=self.config.model_parallel_config.virtual_pipeline_model_parallel_size, ) + self._gemm_backend = None + self._sdpa_backend = None + + def set_simulation_backends(self, gemm_backend=None, sdpa_backend=None): + """Set simulation backends and propagate to all sub-profilers.""" + self._gemm_backend = gemm_backend + self._sdpa_backend = sdpa_backend + + # Propagate to transformer layer sub-profilers (which further propagate + # to attention, MLP, router sub-profilers). + for key in ("dense_transformer_layer", "moe_transformer_layer"): + if key in self.sub_profilers and self.sub_profilers[key] is not None: + layer_profiler = self.sub_profilers[key] + if hasattr(layer_profiler, "set_simulation_backends"): + layer_profiler.set_simulation_backends(gemm_backend, sdpa_backend) + + # Propagate to embedding (uses simple analytical estimate in sim mode). + if "embedding" in self.sub_profilers and self.sub_profilers["embedding"] is not None: + emb = self.sub_profilers["embedding"] + if hasattr(emb, "set_simulation_mode"): + emb.set_simulation_mode(gemm_backend is not None or sdpa_backend is not None) + + # Propagate GEMM backend to output layer (vocab projection GEMM). 
+ if "output_layer" in self.sub_profilers and self.sub_profilers["output_layer"] is not None: + out = self.sub_profilers["output_layer"] + if gemm_backend is not None and hasattr(out, "set_gemm_backend"): + out.set_gemm_backend(gemm_backend) def get_layers_for_rank( self, @@ -422,70 +449,111 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: return int(total_act) def run_layer_benchmark(self, model, batch_size: int, seq_len: int) -> dict: - """Benchmark transformer layers plus embedding/output layers on this rank.""" + """Benchmark or simulate transformer layers plus embedding/output layers on this rank. - def unwrap_module(module): - """Recursively unwrap DistributedDataParallel / pipeline wrappers.""" - return unwrap_module(module.module) if hasattr(module, "module") else module + Supports two modes: + - **benchmark** (default): Runs actual GPU kernels and measures timing. + Requires *model* to be a real instantiated model (or list of model chunks). + - **simulate**: Uses GEMM and SDPA simulation backends. The *model* + parameter may be ``None`` – no GPU is required. - model_chunks = model if isinstance(model, list) else [model] + The mode is automatically selected based on whether simulation backends + have been set via :meth:`set_simulation_backends`. 
+ """ + is_simulation_mode = self._gemm_backend is not None or self._sdpa_backend is not None + # ----------------------------------------------------------------- + # Unwrap model (only when an actual model is provided) + # ----------------------------------------------------------------- embedding_module = None output_module = None all_layers = [] - for chunk in model_chunks: - unwrapped = unwrap_module(chunk) - - language_model = getattr(unwrapped, "language_model", None) - if language_model is not None: - if hasattr(language_model, "embedding"): - embedding_module = language_model.embedding - if hasattr(language_model, "output_layer"): - output_module = language_model.output_layer - - if hasattr(language_model, "encoder") and hasattr(language_model.encoder, "layers"): - all_layers.extend(language_model.encoder.layers) - elif hasattr(language_model, "decoder") and hasattr(language_model.decoder, "layers"): - all_layers.extend(language_model.decoder.layers) - elif hasattr(language_model, "layers"): - all_layers.extend(language_model.layers) - continue - - if hasattr(unwrapped, "decoder") and hasattr(unwrapped.decoder, "layers"): - all_layers.extend(unwrapped.decoder.layers) - elif hasattr(unwrapped, "layers"): - all_layers.extend(unwrapped.layers) - else: - raise ValueError(f"Cannot find transformer layers in model chunk: {type(unwrapped)}") - if hasattr(unwrapped, "embedding"): - embedding_module = unwrapped.embedding - if hasattr(unwrapped, "output_layer"): - output_module = unwrapped.output_layer + if model is not None: + def unwrap_module(module): + """Recursively unwrap DistributedDataParallel / pipeline wrappers.""" + return unwrap_module(module.module) if hasattr(module, "module") else module + + model_chunks = model if isinstance(model, list) else [model] + + for chunk in model_chunks: + unwrapped = unwrap_module(chunk) + + language_model = getattr(unwrapped, "language_model", None) + if language_model is not None: + if hasattr(language_model, 
"embedding"): + embedding_module = language_model.embedding + if hasattr(language_model, "output_layer"): + output_module = language_model.output_layer + + if hasattr(language_model, "encoder") and hasattr(language_model.encoder, "layers"): + all_layers.extend(language_model.encoder.layers) + elif hasattr(language_model, "decoder") and hasattr(language_model.decoder, "layers"): + all_layers.extend(language_model.decoder.layers) + elif hasattr(language_model, "layers"): + all_layers.extend(language_model.layers) + continue + + if hasattr(unwrapped, "decoder") and hasattr(unwrapped.decoder, "layers"): + all_layers.extend(unwrapped.decoder.layers) + elif hasattr(unwrapped, "layers"): + all_layers.extend(unwrapped.layers) + else: + raise ValueError( + f"Cannot find transformer layers in model chunk: {type(unwrapped)}" + ) + if hasattr(unwrapped, "embedding"): + embedding_module = unwrapped.embedding + if hasattr(unwrapped, "output_layer"): + output_module = unwrapped.output_layer + elif not is_simulation_mode: + raise ValueError( + "model=None is only allowed when simulation backends are set " + "(call set_simulation_backends first)" + ) is_rank_0 = int(os.getenv("RANK", "0")) == 0 + mode_label = "Simulating" if is_simulation_mode else "Benchmarking" if is_rank_0: - print(f"\n[Primus:Performance Projection] Found {len(all_layers)} transformer layers") + if model is not None: + print(f"\n[Primus:Performance Projection] Found {len(all_layers)} transformer layers") + else: + print(f"\n[Primus:Performance Projection] Pure simulation mode (no model)") print(f"[Primus:Performance Projection] This rank is responsible for layers: {self.layers}") + if is_simulation_mode: + backends = [] + if self._gemm_backend is not None: + backends.append(f"GEMM={self._gemm_backend.name()}") + if self._sdpa_backend is not None: + backends.append(f"SDPA={self._sdpa_backend.name()}") + print(f"[Primus:Performance Projection] Mode: SIMULATION ({', '.join(backends)})") embedding_stats = None 
output_stats = None - # Benchmark embedding if this rank hosts it. + # ---------------------------------------------------------------------- + # Benchmark / simulate embedding layer (if this rank hosts it) + # ---------------------------------------------------------------------- if 0 in self.layers: - if embedding_module is None: + if model is not None and embedding_module is None and not is_simulation_mode: if is_rank_0: - print("[Primus:Performance Projection] WARNING: Embedding module not found on this rank.") + print( + "[Primus:Performance Projection] WARNING: Embedding module not found on this rank." + ) else: if is_rank_0: - print("[Primus:Performance Projection] Benchmarking embedding layer...") + print(f"[Primus:Performance Projection] {mode_label} embedding layer...") profiler = self.sub_profilers["embedding"] - module = ( - embedding_module.word_embeddings - if hasattr(embedding_module, "word_embeddings") - else embedding_module - ) - profiler.set_module(module) + if embedding_module is not None: + module = ( + embedding_module.word_embeddings + if hasattr(embedding_module, "word_embeddings") + else embedding_module + ) + profiler.set_module(module) + # In simulation mode without a model, the profiler uses its + # analytical estimate (set_simulation_mode was already called + # by set_simulation_backends). emb_forward = profiler.measured_forward_time(batch_size, seq_len) emb_backward = profiler.measured_backward_time(batch_size, seq_len) emb_mem = profiler.measured_activation_memory(batch_size, seq_len) @@ -502,19 +570,25 @@ def unwrap_module(module): "activation_memory_bytes": emb_mem, } - # Benchmark output layer if this rank hosts the final layer. 
+ # ---------------------------------------------------------------------- + # Benchmark / simulate output layer (if this rank hosts the final layer) + # ---------------------------------------------------------------------- last_layer_id = self.config.model_config.num_layers - 1 if last_layer_id in self.layers: - if output_module is None: + if model is not None and output_module is None and not is_simulation_mode: if is_rank_0: print( "[Primus:Performance Projection] WARNING: Output layer module not found on this rank." ) else: if is_rank_0: - print("[Primus:Performance Projection] Benchmarking output layer...") + print(f"[Primus:Performance Projection] {mode_label} output layer...") profiler = self.sub_profilers["output_layer"] - profiler.set_module(output_module) + if output_module is not None: + profiler.set_module(output_module) + # In simulation mode without a model, the output_layer profiler + # uses its GEMM backend (set_gemm_backend was already called by + # set_simulation_backends). out_forward = profiler.measured_forward_time(batch_size, seq_len) out_backward = profiler.measured_backward_time(batch_size, seq_len) out_mem = profiler.measured_activation_memory(batch_size, seq_len) @@ -532,15 +606,18 @@ def unwrap_module(module): } # ============================================================================== - # BENCHMARK LAYER TYPES (one of each type: dense, moe) + # BENCHMARK / SIMULATE LAYER TYPES (one of each type: dense, moe) # ============================================================================== results = {} profiled_types = set() for layer_idx in self.layers: - if layer_idx >= len(all_layers): + # In benchmark mode, guard against out-of-range layer indices. 
+ if model is not None and layer_idx >= len(all_layers): if is_rank_0: - print(f"[WARNING] Layer index {layer_idx} exceeds available layers ({len(all_layers)})") + print( + f"[WARNING] Layer index {layer_idx} exceeds available layers ({len(all_layers)})" + ) continue is_moe = self.config.model_config.moe_pattern[layer_idx] @@ -549,10 +626,8 @@ def unwrap_module(module): if layer_type in profiled_types: continue - layer_module = all_layers[layer_idx] - if is_rank_0: - print(f"\n[Primus:Performance Projection] Benchmarking Layer {layer_idx} ({layer_type})...") + print(f"\n[Primus:Performance Projection] {mode_label} Layer {layer_idx} ({layer_type})...") # Get the appropriate profiler if is_moe: @@ -560,21 +635,23 @@ def unwrap_module(module): else: layer_profiler = self.sub_profilers["dense_transformer_layer"] - # Set the layer module - layer_profiler.set_layer_module(layer_module) + # Set the layer module only when a real model is available. + if model is not None: + layer_module = all_layers[layer_idx] + layer_profiler.set_layer_module(layer_module) - # Benchmark full layer (uses optimized benchmark_layer with 64 iterations, warm caches) + # Benchmark/simulate full layer forward_time = layer_profiler.measured_forward_time(batch_size, seq_len) backward_time = layer_profiler.measured_backward_time(batch_size, seq_len) activation_memory = layer_profiler.measured_activation_memory(batch_size, seq_len) - # Benchmark Attention + # Benchmark/simulate Attention attn_profiler = layer_profiler.get_sub_profiler("self_attention") attn_forward = attn_profiler.measured_forward_time(batch_size, seq_len) attn_backward = attn_profiler.measured_backward_time(batch_size, seq_len) attn_mem = attn_profiler.measured_activation_memory(batch_size, seq_len) - # Benchmark MLP + # Benchmark/simulate MLP mlp_profiler = layer_profiler.get_sub_profiler("mlp") mlp_forward = mlp_profiler.measured_forward_time(batch_size, seq_len) mlp_backward = mlp_profiler.measured_backward_time(batch_size, 
seq_len) @@ -601,9 +678,10 @@ def unwrap_module(module): is_rank_0 = int(os.getenv("RANK", "0")) == 0 if is_rank_0: - print(f" Forward time: {forward_time:.2f} ms") - print(f" Backward time: {backward_time:.2f} ms") - print(f" Total: {forward_time + backward_time:.2f} ms") + src = "(simulated)" if is_simulation_mode else "(measured)" + print(f" Forward time: {forward_time:.2f} ms {src}") + print(f" Backward time: {backward_time:.2f} ms {src}") + print(f" Total: {forward_time + backward_time:.2f} ms {src}") print(f" Activation memory: {activation_memory / (1024**2):.2f} MB") print(f" Attention: fwd={attn_forward:.2f} ms, bwd={attn_backward:.2f} ms") print(f" MLP: fwd={mlp_forward:.2f} ms, bwd={mlp_backward:.2f} ms") diff --git a/primus/core/projection/module_profilers/moe_mlp.py b/primus/core/projection/module_profilers/moe_mlp.py index 070731fb6..9fec31e54 100644 --- a/primus/core/projection/module_profilers/moe_mlp.py +++ b/primus/core/projection/module_profilers/moe_mlp.py @@ -19,6 +19,7 @@ def __init__(self, config, sub_profilers=None): self.module = None # Will be set during benchmarking self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) + self._gemm_backend = None # Optional: GEMM simulation backend def set_module(self, module): """Set the actual MoE MLP module for benchmarking.""" @@ -27,6 +28,13 @@ def set_module(self, module): self._cached_results = None self._cache_key = None + def set_gemm_backend(self, backend): + """Set a GEMM simulation backend for simulated profiling.""" + self._gemm_backend = backend + # Invalidate cache when backend changes + self._cached_results = None + self._cache_key = None + def estimated_num_params(self, rank: Optional[int] = None) -> int: if self.config.model_config.moe_ffn_hidden_size is not None: moe_ffn = self.config.model_config.moe_ffn_hidden_size @@ -87,14 +95,63 @@ def estimated_activation_memory(self, batch_size: int, 
seq_len: int) -> int: return total + def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + """Get simulated results from the GEMM simulation backend for MoE MLP.""" + tp_size = self.config.model_parallel_config.tensor_model_parallel_size + cp_size = self.config.model_parallel_config.context_model_parallel_size + ep_size = self.config.model_parallel_config.expert_model_parallel_size + + batch_tokens = batch_size * seq_len // tp_size // cp_size + topk_tokens = batch_tokens * self.config.model_config.moe_router_topk + + if self.config.model_config.moe_ffn_hidden_size is not None: + moe_ffn = self.config.model_config.moe_ffn_hidden_size + else: + moe_ffn = self.config.model_config.ffn_hidden_size + + # Simulate routed expert MLP GEMMs (topk tokens through experts / EP) + # Each expert processes topk_tokens / num_local_experts tokens on average + num_local_experts = (self.config.model_config.num_experts or 1) // ep_size + tokens_per_expert = topk_tokens // max(num_local_experts, 1) + + sim_result = self._gemm_backend.simulate_mlp_gemms( + batch_tokens=tokens_per_expert, + hidden_size=self.config.model_config.hidden_size, + ffn_hidden_size=moe_ffn, + dtype="bf16", + swiglu=self.config.model_config.swiglu, + ) + # Scale by number of local experts (they run sequentially or in grouped GEMM) + fwd_time = sim_result.forward_time_ms * num_local_experts + bwd_time = sim_result.backward_time_ms * num_local_experts + + # Shared experts (if any) + shared_sz = self.config.model_config.moe_shared_expert_intermediate_size + if shared_sz: + shared_result = self._gemm_backend.simulate_mlp_gemms( + batch_tokens=batch_tokens, + hidden_size=self.config.model_config.hidden_size, + ffn_hidden_size=shared_sz, + dtype="bf16", + swiglu=self.config.model_config.swiglu, + ) + fwd_time += shared_result.forward_time_ms + bwd_time += shared_result.backward_time_ms + + activation_memory = self.estimated_activation_memory(batch_size, seq_len) + return 
(fwd_time, bwd_time, activation_memory) + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: - self._cached_results = benchmark_layer( - self.module, - [(seq_len, batch_size, self.config.model_config.hidden_size)], - ) + if self._gemm_backend is not None: + self._cached_results = self._get_simulated_results(batch_size, seq_len) + else: + self._cached_results = benchmark_layer( + self.module, + [(seq_len, batch_size, self.config.model_config.hidden_size)], + ) self._cache_key = cache_key return self._cached_results diff --git a/primus/core/projection/module_profilers/output_layer.py b/primus/core/projection/module_profilers/output_layer.py index ddd7d5ee4..90575882c 100644 --- a/primus/core/projection/module_profilers/output_layer.py +++ b/primus/core/projection/module_profilers/output_layer.py @@ -17,6 +17,7 @@ def __init__(self, config, sub_profilers=None): self.module = None # Will be set during benchmarking self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) + self._gemm_backend = None # Optional: GEMM simulation backend def set_module(self, module): """Set the actual module for benchmarking.""" @@ -25,6 +26,12 @@ def set_module(self, module): self._cached_results = None self._cache_key = None + def set_gemm_backend(self, backend): + """Set a GEMM simulation backend for simulated profiling.""" + self._gemm_backend = backend + self._cached_results = None + self._cache_key = None + def estimated_num_params(self, rank: Optional[int] = None) -> int: return self.config.model_config.padded_vocab_size * self.config.model_config.hidden_size @@ -38,22 +45,46 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: * 2 ) # bf16 + def _get_simulated_results(self, 
batch_size: int, seq_len: int) -> tuple[float, float, int]: + """Simulate output layer using GEMM backend (vocab projection GEMM).""" + tp_size = self.config.model_parallel_config.tensor_model_parallel_size + cp_size = self.config.model_parallel_config.context_model_parallel_size + batch_tokens = batch_size * seq_len // tp_size // cp_size + hidden_size = self.config.model_config.hidden_size + vocab_size = self.config.model_config.padded_vocab_size + + # Output projection GEMM: [batch_tokens, hidden_size] x [hidden_size, vocab_size] + sim_result = self._gemm_backend.simulate_gemm( + m=batch_tokens, + n=vocab_size, + k=hidden_size, + dtype="bf16", + ) + fwd_time = sim_result.forward_time_ms + bwd_time = fwd_time * 2.0 # dgrad + wgrad + + activation_memory = self.estimated_activation_memory(batch_size, seq_len) + return (fwd_time, bwd_time, activation_memory) + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: - # Context parallel / Sequence parallel adjustment - cp_size = self.config.model_parallel_config.context_model_parallel_size - # Effective sequence length per rank if CP is used - slen_per_cp = seq_len // cp_size - - self._cached_results = benchmark_layer( - self.module, - [ - (slen_per_cp, batch_size, self.config.model_config.hidden_size), - ], - ) + if self._gemm_backend is not None: + self._cached_results = self._get_simulated_results(batch_size, seq_len) + else: + # Context parallel / Sequence parallel adjustment + cp_size = self.config.model_parallel_config.context_model_parallel_size + # Effective sequence length per rank if CP is used + slen_per_cp = seq_len // cp_size + + self._cached_results = benchmark_layer( + self.module, + [ + (slen_per_cp, batch_size, self.config.model_config.hidden_size), + ], + ) self._cache_key = cache_key return self._cached_results diff 
def set_simulation_backends(self, gemm_backend=None, sdpa_backend=None):
    """Attach simulation backends, fan them out to sub-profilers, and drop caches.

    A ``None`` backend leaves the corresponding sub-profiler untouched; the
    cached benchmark results are always invalidated so the next query reflects
    the new backend configuration.
    """
    self._gemm_backend = gemm_backend
    self._sdpa_backend = sdpa_backend

    # Propagate to the attention sub-profiler (both backend kinds apply).
    attn = self.sub_profilers.get("self_attention")
    if attn is not None:
        if gemm_backend is not None and hasattr(attn, "set_gemm_backend"):
            attn.set_gemm_backend(gemm_backend)
        if sdpa_backend is not None and hasattr(attn, "set_sdpa_backend"):
            attn.set_sdpa_backend(sdpa_backend)

    # The MLP sub-profiler only consumes the GEMM backend.
    mlp = self.sub_profilers.get("mlp")
    if mlp is not None and gemm_backend is not None and hasattr(mlp, "set_gemm_backend"):
        mlp.set_gemm_backend(gemm_backend)

    # Invalidate any previously computed timing results.
    self._cached_results = None
    self._cache_key = None
"""Aggregate simulated results from sub-profilers.""" + attn_fwd = self.sub_profilers["self_attention"].measured_forward_time(batch_size, seq_len) + attn_bwd = self.sub_profilers["self_attention"].measured_backward_time(batch_size, seq_len) + mlp_fwd = self.sub_profilers["mlp"].measured_forward_time(batch_size, seq_len) + mlp_bwd = self.sub_profilers["mlp"].measured_backward_time(batch_size, seq_len) + fwd_time = attn_fwd + mlp_fwd + bwd_time = attn_bwd + mlp_bwd + activation_memory = self.estimated_activation_memory(batch_size, seq_len) + return (fwd_time, bwd_time, activation_memory) + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: - # Get TransformerConfig from the layer module itself (has fp8 setting) - transformer_config = getattr(self.layer_module, "config", None) - self._cached_results = benchmark_layer( - self.layer_module, - [(seq_len, batch_size, self.config.model_config.hidden_size)], - transformer_config=transformer_config, - ) + if self._gemm_backend is not None or self._sdpa_backend is not None: + # Use simulation mode + self._cached_results = self._get_simulated_results(batch_size, seq_len) + else: + # Get TransformerConfig from the layer module itself (has fp8 setting) + transformer_config = getattr(self.layer_module, "config", None) + self._cached_results = benchmark_layer( + self.layer_module, + [(seq_len, batch_size, self.config.model_config.hidden_size)], + transformer_config=transformer_config, + ) self._cache_key = cache_key return self._cached_results @@ -132,10 +168,31 @@ def __init__(self, config, sub_profilers=None): self.layer_module = None # Will be set during benchmarking self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) + self._gemm_backend = None # 
Optional: GEMM simulation backend + self._sdpa_backend = None # Optional: SDPA simulation backend def get_sub_profiler(self, name: str): return self.sub_profilers.get(name) + def set_simulation_backends(self, gemm_backend=None, sdpa_backend=None): + """Set simulation backends and propagate to sub-profilers.""" + self._gemm_backend = gemm_backend + self._sdpa_backend = sdpa_backend + # Propagate to sub-profilers + if "self_attention" in self.sub_profilers: + attn = self.sub_profilers["self_attention"] + if gemm_backend is not None and hasattr(attn, "set_gemm_backend"): + attn.set_gemm_backend(gemm_backend) + if sdpa_backend is not None and hasattr(attn, "set_sdpa_backend"): + attn.set_sdpa_backend(sdpa_backend) + if "mlp" in self.sub_profilers: + mlp = self.sub_profilers["mlp"] + if gemm_backend is not None and hasattr(mlp, "set_gemm_backend"): + mlp.set_gemm_backend(gemm_backend) + # Invalidate cache + self._cached_results = None + self._cache_key = None + def set_layer_module(self, layer_module): """Set the actual transformer layer module for benchmarking.""" self.layer_module = layer_module @@ -164,17 +221,32 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: + self.sub_profilers["residual_add"].estimated_activation_memory(batch_size, seq_len) * 2 ) + def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + """Aggregate simulated results from sub-profilers.""" + attn_fwd = self.sub_profilers["self_attention"].measured_forward_time(batch_size, seq_len) + attn_bwd = self.sub_profilers["self_attention"].measured_backward_time(batch_size, seq_len) + mlp_fwd = self.sub_profilers["mlp"].measured_forward_time(batch_size, seq_len) + mlp_bwd = self.sub_profilers["mlp"].measured_backward_time(batch_size, seq_len) + fwd_time = attn_fwd + mlp_fwd + bwd_time = attn_bwd + mlp_bwd + activation_memory = self.estimated_activation_memory(batch_size, seq_len) + return (fwd_time, bwd_time, activation_memory) + def 
_get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: - # Get TransformerConfig from the layer module itself (has fp8 setting) - transformer_config = getattr(self.layer_module, "config", None) - self._cached_results = benchmark_layer( - self.layer_module, - [(seq_len, batch_size, self.config.model_config.hidden_size)], - transformer_config=transformer_config, - ) + if self._gemm_backend is not None or self._sdpa_backend is not None: + # Use simulation mode + self._cached_results = self._get_simulated_results(batch_size, seq_len) + else: + # Get TransformerConfig from the layer module itself (has fp8 setting) + transformer_config = getattr(self.layer_module, "config", None) + self._cached_results = benchmark_layer( + self.layer_module, + [(seq_len, batch_size, self.config.model_config.hidden_size)], + transformer_config=transformer_config, + ) self._cache_key = cache_key return self._cached_results diff --git a/primus/core/projection/performance_projection/projection.py b/primus/core/projection/performance_projection/projection.py index da6e4bf2a..4285c9bbd 100644 --- a/primus/core/projection/performance_projection/projection.py +++ b/primus/core/projection/performance_projection/projection.py @@ -23,10 +23,16 @@ from primus.core.projection.performance_projection.simulator import ( SchedulerSimulationRunner, ) +from primus.core.projection.simulation_backends.factory import ( + get_gemm_simulation_backend, + get_sdpa_simulation_backend, +) from primus.core.projection.training_config import ( convert_primus_config_to_projection_config, ) -from primus.modules.trainer.megatron.pre_trainer import MegatronPretrainTrainer +# NOTE: MegatronPretrainTrainer is imported lazily inside _run_layer_benchmark() +# to avoid pulling in the megatron dependency when running in pure simulation mode +# 
(--profiling-mode simulate). _MAX_EXPERT_PARALLEL_SIZE = 8 _BYTES_PER_GB = 1024**3 @@ -199,6 +205,10 @@ def calculate_collective_communication_time( # FSDP shards weights across DP ranks. Each layer needs: # - Forward: All-gather to reconstruct full weights # - Backward: Reduce-scatter to distribute gradients back to shards + # + # Recompute correction: with recompute_granularity="full", every backward + # layer recomputes its forward pass, requiring a SECOND AllGather per + # layer to re-fetch the sharded weights. # Note: use_fsdp and mp_config already defined above if use_fsdp and dp > 1: @@ -214,22 +224,36 @@ def calculate_collective_communication_time( # All-gather: each rank sends its shard (1/DP), receives full weights # Total data moved = weight_size * (DP-1)/DP per rank - ag_time_per_layer = cm.allgather(coll_args, weight_size_per_layer, dp, groups=["dp"]) + ag_time_per_layer_us = cm.allgather(coll_args, weight_size_per_layer, dp, groups=["dp"]) # Reduce-scatter: each rank sends full gradients, receives its shard - # Gradients are in FP32 for optimizer (4 bytes), but reduce-scatter often uses BF16 grad_size_per_layer = params_per_dense_layer * 2 # BF16 gradients for communication - rs_time_per_layer = cm.reduce_scatter(coll_args, grad_size_per_layer, dp, groups=["dp"]) + rs_time_per_layer_us = cm.reduce_scatter(coll_args, grad_size_per_layer, dp, groups=["dp"]) + + # --- Recompute correction --- + # With recompute_granularity="full", during the backward pass each layer + # re-runs its forward pass. This means the weights must be AllGathered + # AGAIN for each recomputed layer (the first AG result was freed after + # the initial forward). The ReduceScatter count is unchanged (1 per + # layer backward). 
def _run_layer_simulation(primus_config, args):
    """
    Run layer simulation using GEMM + SDPA simulation backends (no GPU required).

    This mirrors :func:`_run_layer_benchmark` but replaces actual GPU kernel
    benchmarks with analytical / model-based simulation. It does *not*
    instantiate a trainer or model – only the profiler tree is built from the
    ``TrainingConfig``.

    Args:
        primus_config: Primus configuration (will be mutated – layer counts
            are reduced for consistency with the benchmark flow).
        args: CLI arguments (``--gemm-backend``, ``--gpu-arch``).

    Returns:
        dict: Profiling results in the same format as ``_run_layer_benchmark``.
    """
    # Mirror the benchmark flow's config preprocessing (mutates primus_config).
    module_config = primus_config.get_module_config("pre_trainer")
    _limit_layers_for_projection(module_config)
    _rescale_expert_parallelism(module_config)
    training_config = convert_primus_config_to_projection_config(primus_config)

    # Console output is gated on rank 0 (RANK env var, defaulting to 0).
    is_rank_0 = int(os.getenv("RANK", "0")) == 0

    # ---- Create simulation backends ----
    gemm_backend_name = getattr(args, "gemm_backend", None)
    gpu_arch = getattr(args, "gpu_arch", None)

    gemm_backend = get_gemm_simulation_backend(backend_name=gemm_backend_name, gpu_arch=gpu_arch)
    sdpa_backend = get_sdpa_simulation_backend(gpu_arch=gpu_arch)

    # ---- Build profiler tree (no model needed) ----
    if is_rank_0:
        print("[Primus:Performance Projection] Building simulation profiler...")
    model_profiler_spec = get_language_model_profiler_spec(training_config)
    model_profiler = build_profiler(model_profiler_spec)

    # Wire simulation backends into the entire profiler hierarchy
    model_profiler.set_simulation_backends(gemm_backend, sdpa_backend)

    seq_len = training_config.runtime_config.sequence_length
    batch_size = training_config.runtime_config.micro_batch_size

    if is_rank_0:
        print("[Primus:Performance Projection] Simulating with:")
        print(f"  Batch Size: {batch_size}")
        print(f"  Sequence Length: {seq_len}")
        print(f"  GEMM backend: {gemm_backend.name()}")
        print(f"  SDPA backend: {sdpa_backend.name()}")
        # NOTE(review): '"" +' is a no-op here — possibly a leading "\n" was
        # intended (as used in the 'both'-mode comparison banner); confirm.
        print("" + "=" * 100)
        print("[Primus:Performance Projection] Starting layer simulation...")
        print("=" * 100)

    # Run simulation (model=None – no GPU required)
    profiling_results = model_profiler.run_layer_benchmark(
        model=None,
        batch_size=batch_size,
        seq_len=seq_len,
    )
    return profiling_results
+ sim_results = _run_layer_simulation(copy.deepcopy(primus_config), args) + bench_results = _run_layer_benchmark(primus_config, unknown_overrides) + + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + if is_rank_0: + print("\n" + "=" * 100) + print("[Primus:Performance Projection] Benchmark vs Simulation Comparison") + print("=" * 100) + for key in bench_results: + if key in ("embedding", "output"): + continue + bd = bench_results[key] + sd = sim_results.get(key, {}) + if not isinstance(bd, dict): + continue + lt = bd.get("type", key) + b_fwd = bd.get("forward_time_ms", 0) + b_bwd = bd.get("backward_time_ms", 0) + s_fwd = sd.get("forward_time_ms", 0) + s_bwd = sd.get("backward_time_ms", 0) + fwd_err = ((s_fwd - b_fwd) / b_fwd * 100) if b_fwd else 0 + bwd_err = ((s_bwd - b_bwd) / b_bwd * 100) if b_bwd else 0 + print(f" Layer type: {lt}") + print(f" Forward: bench={b_fwd:.2f} ms sim={s_fwd:.2f} ms (err={fwd_err:+.1f}%)") + print(f" Backward: bench={b_bwd:.2f} ms sim={s_bwd:.2f} ms (err={bwd_err:+.1f}%)") + print("=" * 100) + + # Use benchmark results for the rest of the pipeline + profiling_results = bench_results + else: + # Default: actual GPU benchmark + profiling_results = _run_layer_benchmark(primus_config, unknown_overrides) # Use original config for projection calculations training_config = convert_primus_config_to_projection_config(primus_config_original) From e2bf1a99378161f329ecfd963ba5ca58a60d141a Mon Sep 17 00:00:00 2001 From: Anshu Raina Date: Fri, 13 Feb 2026 17:56:46 -0800 Subject: [PATCH 02/12] feat(projection): Add FSDP overlap model, PXN All-to-All for DeepEP, and intra-node A2A overhead - Add per-phase FSDP overlap model (90% forward AG, 24% backward AG, 34% RS) - Add PXN All-to-All algorithm for DeepEP with pipelined scale-up/scale-out - Enable PXN automatically when DeepEP is detected (moe_enable_deepep or use_turbo_deepep) - Adjust intra-node A2A overhead to 28 us per peer (from preflight measurements) - Separate intra-node and inter-node A2A 
overhead modeling --- .../module_profilers/collective_args.py | 13 +- .../module_profilers/collective_model.py | 117 ++++++++++++++++-- .../performance_projection/projection.py | 90 +++++++++----- 3 files changed, 178 insertions(+), 42 deletions(-) diff --git a/primus/core/projection/module_profilers/collective_args.py b/primus/core/projection/module_profilers/collective_args.py index f8c03914f..b377e153e 100644 --- a/primus/core/projection/module_profilers/collective_args.py +++ b/primus/core/projection/module_profilers/collective_args.py @@ -51,7 +51,18 @@ class CollectiveArgs: nics_per_node: Optional[int] = 8 # NICs per node (None = gpus_per_node) # All-to-all specific - a2a_peer_lat: float = 0.45 # Per-peer latency overhead for a2a + a2a_peer_lat: float = 0.45 # Per-peer latency overhead for inter-node a2a + a2a_intra_node_peer_lat: float = 28.0 # Per-peer latency overhead for intra-node a2a + # Intra-node overhead is higher (~19-28 us) due to: + # - P2P scatter/gather scheduling overhead + # - RCCL internal synchronization barriers + # - Memory copy and buffer management + # Note: Preflight measurements for EP=8 intra-node A2A show: + # - Linear extrapolation: ~27.4 us per peer + # - 2MB measurement: ~28.1 us per peer + # - After subtracting bandwidth component: ~19.4 us per peer + # Default 28 us matches preflight measurements (middle of range) + # Can be overridden via hardware_config for GPU-specific calibration def get_default_args( diff --git a/primus/core/projection/module_profilers/collective_model.py b/primus/core/projection/module_profilers/collective_model.py index 792dc0e22..f8a929d4e 100644 --- a/primus/core/projection/module_profilers/collective_model.py +++ b/primus/core/projection/module_profilers/collective_model.py @@ -518,6 +518,74 @@ def hierarchical_alltoall(args, msg_size, gpus, groups=None, protocol=None): return t_total +def pxn_alltoall(args, msg_size, gpus, groups=None, protocol=None): + """ + PXN All-to-All - pipelined implementation 
for DeepEP. + + Based on DeepEP implementation of pipelined PXN-A2A with pipelined + scale-up (intra-node) and scale-out (inter-node) communication. + + Key features: + - Overlaps scale-up and scale-out communication + - Scale-out doesn't start until 4MB is accumulated for dispatch + - Falls back to regular alltoall for single-node configurations + """ + if gpus == 1 or msg_size == 0: + return 0 + + original_msg_size = msg_size + + # Nodes participating in the exchange + num_nodes = ceil(gpus / args.node_size) + + # If A2A is not crossing node boundaries, fall back to regular alltoall + if num_nodes <= 1: + return min( + single_shot_alltoall(args, msg_size, gpus, groups, protocol), + run_alltoall(args, msg_size, gpus, groups, protocol) + ) + + # PXN - AlltoAll - pipeline implementation + chunk_size = 4194304 # 4MB - DeepEP waits until 4MB is accumulated + + # Scale-out message size (inter-node communication) + scale_out_msg_size = int(original_msg_size * (num_nodes - 1) / num_nodes) + + # Scale-up delay: time to accumulate 4MB before scale-out starts + scaleup_delay = 0.0 + if scale_out_msg_size < chunk_size: + # If total scale-out msg size is less than 4MB, + # total time = scaleup_delay + scaleout_time + node_lat, _ = node_latency_and_volume_protocol(args, scale_out_msg_size, protocol) + scaleup_delay = node_lat + scale_out_msg_size / args.node_bw * 1.0e-3 + else: + # Scale-out comm doesn't start until 4MB is accumulated + node_lat, _ = node_latency_and_volume_protocol(args, chunk_size, protocol) + scaleup_delay = node_lat + chunk_size / args.node_bw * 1.0e-3 + + # Assume PXN style alltoall with overlapped scale-up and scale-out + node_msg_size = int(original_msg_size * (args.node_size - 1) / args.node_size) + scale_out_msg_size = int(original_msg_size * (num_nodes - 1) / num_nodes) + + # Calculate latencies with protocol inflation + node_lat, node_msg_size_adj = node_latency_and_volume_protocol(args, node_msg_size, protocol) + pod_lat, scale_out_msg_size_adj = 
def alltoall(args, msg_size, gpus, groups=None):
    """
    Select best alltoall algorithm among several options.
    Tries multiple protocols and algorithms, returns fastest.
    Applies per-peer latency overhead and minimum latency floor.

    If DeepEP is enabled (moe_enable_deepep=True), uses PXN All-to-All
    which pipelines scale-up and scale-out communication.

    Args:
        args: Collective-args object (bandwidths, latencies, node_size, ...).
        msg_size: Message size in bytes.
        gpus: Number of ranks participating in the exchange.
        groups: Process-group labels; defaults to ["ep"]. (Kept for interface
            compatibility — not forwarded to the per-protocol models here.)

    Returns:
        Estimated all-to-all time, including per-peer latency overhead.
    """
    # Fix: replace the mutable default argument (previously groups=["ep"])
    # with the standard None-sentinel idiom; the effective default is unchanged.
    if groups is None:
        groups = ["ep"]

    # Check if DeepEP is enabled
    use_deepep = getattr(args, "moe_enable_deepep", False) or getattr(args, "use_turbo_deepep", False)

    min_a2a_time = float("inf")
    for p in ["simple", "ll", "ll64", "ll128"]:
        if use_deepep:
            # DeepEP: PXN All-to-All pipelines scale-up and scale-out
            a2a_time = pxn_alltoall(args, msg_size, gpus, protocol=p)
        else:
            # Regular All-to-All: take the fastest of the candidate algorithms
            a2a_time = min(
                run_alltoall(args, msg_size, gpus, protocol=p),
                single_shot_alltoall(args, msg_size, gpus, protocol=p),
                hierarchical_alltoall(args, msg_size, gpus, protocol=p),
            )
        if a2a_time < min_a2a_time:
            min_a2a_time = a2a_time

    # Add per-peer latency overhead for ALL A2A communication (both intra and inter-node)
    # This accounts for:
    # - P2P scatter/gather scheduling overhead
    # - RCCL internal synchronization barriers
    # - Memory copy and buffer management
    # - RDMA QP setup, work request posting, completion polling (for inter-node)
    # Preflight measurements: EP=8 intra-node A2A needs ~19-28 us per peer
    # (default 28 us), vs ~0.45 us per inter-node peer thanks to RDMA efficiency.
    # (Fix: an earlier comment claimed ~50 us per intra-node peer, contradicting
    # both the 28 us default below and the measurements cited here.)
    gpus_per_node = args.node_size
    intra_node_peers = min(gpus - 1, gpus_per_node - 1)  # Peers within same node
    inter_node_peers = max(0, gpus - gpus_per_node)  # Peers on other nodes

    intra_node_overhead_per_peer = getattr(args, "a2a_intra_node_peer_lat", 28.0)  # Default 28 us
    inter_node_overhead_per_peer = getattr(args, "a2a_peer_lat", 0.45)  # Default 0.45 us

    peer_overhead = (
        intra_node_overhead_per_peer * intra_node_peers
        + inter_node_overhead_per_peer * inter_node_peers
    )
    min_a2a_time += peer_overhead

    return min_a2a_time
training_config.model_parallel_config + moe_enable_deepep = getattr(mp_config, "moe_enable_deepep", False) + use_turbo_deepep = getattr(mp_config, "use_turbo_deepep", False) + coll_args.moe_enable_deepep = moe_enable_deepep + coll_args.use_turbo_deepep = use_turbo_deepep + # Model parameters hidden_size = model_config.hidden_size num_layers = model_config.num_layers @@ -115,7 +122,7 @@ def calculate_collective_communication_time( message_info = {} per_layer_info = [] # Store per-layer communication details - # Check if FSDP is enabled (needed to determine gradient sync strategy) + # Get model parallel config (already retrieved above for DeepEP check) mp_config = training_config.model_parallel_config # Note: use_torch_fsdp2 = True means actual FSDP (shards weights, uses all-gather/reduce-scatter) # use_distributed_optimizer = True means ZeRO-1 style (shards optimizer state only, uses all-reduce) @@ -299,38 +306,61 @@ def calculate_collective_communication_time( else: message_info["gradient_allreduce_overlapped"] = False - # Check if FSDP communication can be overlapped - # In FSDP2, prefetch can overlap all-gather with compute of current layer - # Reduce-scatter can overlap with forward of next microbatch - # However, overlap is NOT 100%: - # - First layer's all-gather cannot overlap (nothing before it) - # - Last layer's reduce-scatter cannot overlap (nothing after it) - # - There's always some exposed communication at boundaries + # FSDP overlap model (calibrated against LLaMA3-70B MI300X trace) + # --------------------------------------------------------------- + # FSDP2 prefetches next layer's AllGather while the current layer + # computes, and ReduceScatter runs after backward completes. 
+ # Overlap differs significantly between forward and backward: + # + # Forward AG: ~90-95% overlap — prefetch hides AG behind compute + # Backward AG: ~24% overlap — eager prefetch finishes long before + # the compute stream is ready (dependency + # chain through previous layer's backward) + # RS: ~34% overlap — inherently post-compute, only partially + # overlaps with next layer's recompute AG + # + # These per-phase percentages are largely model-independent for FSDP2 + # with full recompute; the overall overlap is ~50% for 70B-class models + # and ~64% without recompute. if use_fsdp and dp > 1: - overlap_fsdp = getattr(mp_config, "use_torch_fsdp2", False) # FSDP2 has better overlap + overlap_fsdp = getattr(mp_config, "use_torch_fsdp2", False) if overlap_fsdp: - # Calculate per-layer times - fsdp_ag_per_layer = message_info.get("fsdp_ag_per_layer_ms", 0) - fsdp_rs_per_layer = message_info.get("fsdp_rs_per_layer_ms", 0) - - # Exposed time: first layer's all-gather + last layer's reduce-scatter - # Plus some overhead from imperfect pipelining (~10-20% of remaining) - exposed_ag = fsdp_ag_per_layer # First layer cannot overlap - exposed_rs = fsdp_rs_per_layer # Last layer cannot overlap - remaining_ag = breakdown.get("fsdp_allgather_fwd", 0) - exposed_ag - remaining_rs = breakdown.get("fsdp_reducescatter_bwd", 0) - exposed_rs - - # Assume ~70% overlap efficiency for the rest (conservative for multi-node) - overlap_efficiency = 0.7 - hidden_ag = remaining_ag * overlap_efficiency - hidden_rs = remaining_rs * overlap_efficiency - - total_comm_time -= hidden_ag - total_comm_time -= hidden_rs + recompute_gran = getattr(mp_config, "recompute_granularity", None) + recomp_n = getattr(mp_config, "recompute_num_layers", 0) or 0 + has_recompute = recompute_gran == "full" and recomp_n > 0 + + total_fsdp_ag = breakdown.get("fsdp_allgather_fwd", 0) + total_fsdp_rs = breakdown.get("fsdp_reducescatter_bwd", 0) + + if has_recompute: + # With full recompute the AG total already 
includes the 2× + # multiplier. Split into forward AG and backward (recomp) AG. + recomp_ratio = min(recomp_n, num_layers) / num_layers + ag_multiplier_val = message_info.get("fsdp_ag_multiplier", 1 + recomp_ratio) + fwd_ag_total = total_fsdp_ag / ag_multiplier_val + bwd_ag_total = total_fsdp_ag - fwd_ag_total + else: + fwd_ag_total = total_fsdp_ag + bwd_ag_total = 0.0 + + # Per-phase overlap percentages (from trace calibration) + FWD_AG_OVERLAP = 0.90 # forward AG hidden behind compute + BWD_AG_OVERLAP = 0.24 # backward recompute AG (structural limit) + RS_OVERLAP = 0.34 # ReduceScatter (structural limit) + + hidden_fwd_ag = fwd_ag_total * FWD_AG_OVERLAP + hidden_bwd_ag = bwd_ag_total * BWD_AG_OVERLAP + hidden_rs = total_fsdp_rs * RS_OVERLAP + + total_hidden = hidden_fwd_ag + hidden_bwd_ag + hidden_rs + total_comm_time -= total_hidden message_info["fsdp_overlapped"] = True - message_info["fsdp_exposed_ms"] = ( - exposed_ag + exposed_rs + (remaining_ag + remaining_rs) * (1 - overlap_efficiency) - ) + message_info["fsdp_fwd_ag_overlap"] = FWD_AG_OVERLAP + message_info["fsdp_bwd_ag_overlap"] = BWD_AG_OVERLAP + message_info["fsdp_rs_overlap"] = RS_OVERLAP + total_fsdp = total_fsdp_ag + total_fsdp_rs + message_info["fsdp_overall_overlap"] = total_hidden / total_fsdp if total_fsdp > 0 else 0 + message_info["fsdp_exposed_ms"] = total_fsdp - total_hidden else: message_info["fsdp_overlapped"] = False From 7668b4e57c948522be2783334cf04bd50e9529dc Mon Sep 17 00:00:00 2001 From: Anshu Raina Date: Tue, 17 Feb 2026 09:46:56 -0800 Subject: [PATCH 03/12] feat(projection): add simulation backends, FP8 support, optimizer step estimation, and --gpu-clock-mhz override --- primus/cli/subcommands/projection.py | 17 +- .../projection/module_profilers/attention.py | 38 +- .../projection/module_profilers/dense_mlp.py | 34 +- .../projection/module_profilers/moe_mlp.py | 40 +- .../performance_projection/projection.py | 614 +++++++++++++----- .../simulation_backends/__init__.py | 23 + 
.../projection/simulation_backends/base.py | 213 ++++++ .../projection/simulation_backends/factory.py | 118 ++++ .../simulation_backends/origami_backend.py | 437 +++++++++++++ .../simulation_backends/sdpa_simulator.py | 543 ++++++++++++++++ primus/core/projection/training_config.py | 30 +- 11 files changed, 1925 insertions(+), 182 deletions(-) create mode 100644 primus/core/projection/simulation_backends/__init__.py create mode 100644 primus/core/projection/simulation_backends/base.py create mode 100644 primus/core/projection/simulation_backends/factory.py create mode 100644 primus/core/projection/simulation_backends/origami_backend.py create mode 100644 primus/core/projection/simulation_backends/sdpa_simulator.py diff --git a/primus/cli/subcommands/projection.py b/primus/cli/subcommands/projection.py index 6fce8bece..821bb4a0c 100644 --- a/primus/cli/subcommands/projection.py +++ b/primus/cli/subcommands/projection.py @@ -64,7 +64,9 @@ def register_subcommand(subparsers): suite_parsers = parser.add_subparsers(dest="suite", required=True) # ---------- memory ---------- - memory = suite_parsers.add_parser("memory", help="Memory projection only (per-GPU memory analysis).") + memory = suite_parsers.add_parser( + "memory", help="Memory projection only (per-GPU memory analysis)." + ) from primus.core.launcher.parser import add_pretrain_parser add_pretrain_parser(memory) @@ -131,6 +133,19 @@ def register_subcommand(subparsers): "If not specified, auto-detected or uses PRIMUS_GPU_ARCH env var.\n" ), ) + performance.add_argument( + "--gpu-clock-mhz", + type=int, + required=False, + default=None, + help=( + "Override the GPU compute clock frequency in MHz for simulation.\n" + "If not specified, uses the default from the hardware profile for the\n" + "given --gpu-arch (e.g. 
2100 MHz for MI300X, 1200 MHz for MI325X).\n" + "Can also be set via the PRIMUS_GPU_CLOCK_MHZ env var.\n" + "Example: --gpu-clock-mhz 1500\n" + ), + ) parser.set_defaults(func=run) diff --git a/primus/core/projection/module_profilers/attention.py b/primus/core/projection/module_profilers/attention.py index 22828dfb9..400139aff 100644 --- a/primus/core/projection/module_profilers/attention.py +++ b/primus/core/projection/module_profilers/attention.py @@ -17,7 +17,9 @@ class AttentionProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): super().__init__(config, sub_profilers) self.module = None # Will be set during benchmarking - self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) + self._cached_results = ( + None # Cache for (forward_time, backward_time, activation_memory) + ) self._cache_key = None # Cache key (batch_size, seq_len) self._gemm_backend = None # Optional: GEMM simulation backend self._sdpa_backend = None # Optional: SDPA simulation backend @@ -52,7 +54,9 @@ def estimated_num_params(self, rank: Optional[int] = None) -> int: ) # Projection ratio: (kv_channels * n_heads) / hidden_size - query_proj_to_hidden = (args.kv_channels * args.num_attention_heads) / args.hidden_size + query_proj_to_hidden = ( + args.kv_channels * args.num_attention_heads + ) / args.hidden_size if args.multi_latent_attention: # q_term: either dense or LoRA factored Q with RoPE/Q-norm @@ -65,14 +69,19 @@ def estimated_num_params(self, rank: Optional[int] = None) -> int: else: q_term = args.q_lora_rank * ( args.hidden_size - + args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim) + + args.num_attention_heads + * (args.qk_head_dim + args.qk_pos_emb_head_dim) + 1 ) attn = ( q_term # kv lora + rope + kv norm + args.kv_lora_rank - * (args.hidden_size + args.num_attention_heads * (args.qk_head_dim + args.v_head_dim) + 1) + * ( + args.hidden_size + + args.num_attention_heads * (args.qk_head_dim + 
args.v_head_dim) + + 1 + ) # pos emb + args.hidden_size * args.qk_pos_emb_head_dim # out proj @@ -85,7 +94,10 @@ def estimated_num_params(self, rank: Optional[int] = None) -> int: 2 * args.hidden_size * args.hidden_size - * ((1 + (num_query_groups / args.num_attention_heads)) * query_proj_to_hidden) + * ( + (1 + (num_query_groups / args.num_attention_heads)) + * query_proj_to_hidden + ) ) def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: @@ -130,7 +142,9 @@ def _num_query_groups() -> int: kv_projection_size = args.kv_channels * _num_query_groups() # Need to retain Q, K, V as well as the projected context/output. - activation_width = query_projection_size + 2 * kv_projection_size + args.hidden_size + activation_width = ( + query_projection_size + 2 * kv_projection_size + args.hidden_size + ) if args.qk_layernorm: ln_width += kv_projection_size * 2 @@ -145,7 +159,9 @@ def _num_query_groups() -> int: return tokens_per_rank * (activation_width + ln_width) * bytes_per_value - def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + def _get_simulated_results( + self, batch_size: int, seq_len: int + ) -> tuple[float, float, int]: """Get simulated results from GEMM + SDPA simulation backends.""" args = self.config.model_config mp = self.config.model_parallel_config @@ -165,13 +181,15 @@ def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, if args.group_query_attention and args.num_query_groups else args.num_attention_heads ) + # FP8-hybrid: linear projections (QKV, O) run in FP8 + gemm_dtype = "fp8" if getattr(args, "fp8", None) else "bf16" gemm_result = self._gemm_backend.simulate_attention_gemms( batch_tokens=batch_tokens, hidden_size=args.hidden_size, num_attention_heads=args.num_attention_heads, kv_channels=args.kv_channels, num_query_groups=num_query_groups, - dtype="bf16", + dtype=gemm_dtype, ) fwd_time += gemm_result.forward_time_ms bwd_time += gemm_result.backward_time_ms 
@@ -193,7 +211,9 @@ def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, activation_memory = self.estimated_activation_memory(batch_size, seq_len) return (fwd_time, bwd_time, activation_memory) - def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + def _get_benchmark_results( + self, batch_size: int, seq_len: int + ) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) diff --git a/primus/core/projection/module_profilers/dense_mlp.py b/primus/core/projection/module_profilers/dense_mlp.py index 5008bd5a2..83d1db655 100644 --- a/primus/core/projection/module_profilers/dense_mlp.py +++ b/primus/core/projection/module_profilers/dense_mlp.py @@ -17,7 +17,9 @@ class DenseMLPProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): super().__init__(config, sub_profilers) self.module = None # Will be set during benchmarking - self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) + self._cached_results = ( + None # Cache for (forward_time, backward_time, activation_memory) + ) self._cache_key = None # Cache key (batch_size, seq_len) self._gemm_backend = None # Optional: GEMM simulation backend @@ -55,34 +57,50 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: # Memory after first projection(s) if self.config.model_config.swiglu: # Need to store both gate and up projections for backward - intermediate_memory = 2 * num_tokens * self.config.model_config.ffn_hidden_size * 2 # bf16 + intermediate_memory = ( + 2 * num_tokens * self.config.model_config.ffn_hidden_size * 2 + ) # bf16 else: - intermediate_memory = num_tokens * self.config.model_config.ffn_hidden_size * 2 # bf16 + intermediate_memory = ( + num_tokens * self.config.model_config.ffn_hidden_size * 2 + ) # bf16 # After activation - activation_memory = num_tokens * self.config.model_config.ffn_hidden_size * 2 # 
bf16 + activation_memory = ( + num_tokens * self.config.model_config.ffn_hidden_size * 2 + ) # bf16 output_memory = num_tokens * self.config.model_config.hidden_size * 2 # bf16 # Peak memory is input + intermediate (both needed for backward) return intermediate_memory + activation_memory + output_memory - def _get_simulated_results(self, batch_size: int, seq_len: int) -> Tuple[float, float, int]: + def _get_simulated_results( + self, batch_size: int, seq_len: int + ) -> Tuple[float, float, int]: """Get simulated results from the GEMM simulation backend.""" tp_size = self.config.model_parallel_config.tensor_model_parallel_size cp_size = self.config.model_parallel_config.context_model_parallel_size batch_tokens = batch_size * seq_len // tp_size // cp_size + # FP8-hybrid: MLP projections (gate, up, down) run in FP8 + gemm_dtype = "fp8" if getattr(self.config.model_config, "fp8", None) else "bf16" sim_result = self._gemm_backend.simulate_mlp_gemms( batch_tokens=batch_tokens, hidden_size=self.config.model_config.hidden_size, ffn_hidden_size=self.config.model_config.ffn_hidden_size, - dtype="bf16", + dtype=gemm_dtype, swiglu=self.config.model_config.swiglu, ) activation_memory = self.estimated_activation_memory(batch_size, seq_len) - return (sim_result.forward_time_ms, sim_result.backward_time_ms, activation_memory) + return ( + sim_result.forward_time_ms, + sim_result.backward_time_ms, + activation_memory, + ) - def _get_benchmark_results(self, batch_size: int, seq_len: int) -> Tuple[float, float, int]: + def _get_benchmark_results( + self, batch_size: int, seq_len: int + ) -> Tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: diff --git a/primus/core/projection/module_profilers/moe_mlp.py b/primus/core/projection/module_profilers/moe_mlp.py index 9fec31e54..f51fa1269 100644 --- a/primus/core/projection/module_profilers/moe_mlp.py +++ 
b/primus/core/projection/module_profilers/moe_mlp.py @@ -17,7 +17,9 @@ class MoEMLPProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): super().__init__(config, sub_profilers) self.module = None # Will be set during benchmarking - self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) + self._cached_results = ( + None # Cache for (forward_time, backward_time, activation_memory) + ) self._cache_key = None # Cache key (batch_size, seq_len) self._gemm_backend = None # Optional: GEMM simulation backend @@ -44,16 +46,26 @@ def estimated_num_params(self, rank: Optional[int] = None) -> int: # For SwiGLU: 3 projections per expert (gate, up, down) # For standard FFN: 2 projections per expert (up, down) num_ffn_projections = 3 if self.config.model_config.swiglu else 2 - per_expert_params = num_ffn_projections * self.config.model_config.hidden_size * moe_ffn - ep = 1 if rank is None else self.config.model_parallel_config.expert_model_parallel_size + per_expert_params = ( + num_ffn_projections * self.config.model_config.hidden_size * moe_ffn + ) + ep = ( + 1 + if rank is None + else self.config.model_parallel_config.expert_model_parallel_size + ) - all_experts_params = self.config.model_config.num_experts * per_expert_params // ep + all_experts_params = ( + self.config.model_config.num_experts * per_expert_params // ep + ) # Shared experts (if any) shared_sz = 0 if self.config.model_config.moe_shared_expert_intermediate_size is not None: shared_sz = self.config.model_config.moe_shared_expert_intermediate_size - shared_params = num_ffn_projections * self.config.model_config.hidden_size * shared_sz + shared_params = ( + num_ffn_projections * self.config.model_config.hidden_size * shared_sz + ) return all_experts_params + shared_params @@ -90,12 +102,16 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: # After activation activation_memory = num_tokens * moe_ffn * 2 # bf16 - output_memory = 
num_tokens * self.config.model_config.hidden_size * 2 # bf16 + output_memory = ( + num_tokens * self.config.model_config.hidden_size * 2 + ) # bf16 total += intermediate_memory + activation_memory + output_memory return total - def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + def _get_simulated_results( + self, batch_size: int, seq_len: int + ) -> tuple[float, float, int]: """Get simulated results from the GEMM simulation backend for MoE MLP.""" tp_size = self.config.model_parallel_config.tensor_model_parallel_size cp_size = self.config.model_parallel_config.context_model_parallel_size @@ -114,11 +130,13 @@ def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, num_local_experts = (self.config.model_config.num_experts or 1) // ep_size tokens_per_expert = topk_tokens // max(num_local_experts, 1) + # FP8-hybrid: MoE expert MLP projections run in FP8 + gemm_dtype = "fp8" if getattr(self.config.model_config, "fp8", None) else "bf16" sim_result = self._gemm_backend.simulate_mlp_gemms( batch_tokens=tokens_per_expert, hidden_size=self.config.model_config.hidden_size, ffn_hidden_size=moe_ffn, - dtype="bf16", + dtype=gemm_dtype, swiglu=self.config.model_config.swiglu, ) # Scale by number of local experts (they run sequentially or in grouped GEMM) @@ -132,7 +150,7 @@ def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, batch_tokens=batch_tokens, hidden_size=self.config.model_config.hidden_size, ffn_hidden_size=shared_sz, - dtype="bf16", + dtype=gemm_dtype, swiglu=self.config.model_config.swiglu, ) fwd_time += shared_result.forward_time_ms @@ -141,7 +159,9 @@ def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, activation_memory = self.estimated_activation_memory(batch_size, seq_len) return (fwd_time, bwd_time, activation_memory) - def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + def _get_benchmark_results( 
+ self, batch_size: int, seq_len: int + ) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: diff --git a/primus/core/projection/performance_projection/projection.py b/primus/core/projection/performance_projection/projection.py index b08c8c36e..320ba6c78 100644 --- a/primus/core/projection/performance_projection/projection.py +++ b/primus/core/projection/performance_projection/projection.py @@ -30,6 +30,7 @@ from primus.core.projection.training_config import ( convert_primus_config_to_projection_config, ) + # NOTE: MegatronPretrainTrainer is imported lazily inside _run_layer_benchmark() # to avoid pulling in the megatron dependency when running in pure simulation mode # (--profiling-mode simulate). @@ -37,6 +38,114 @@ _MAX_EXPERT_PARALLEL_SIZE = 8 _BYTES_PER_GB = 1024**3 +# HBM bandwidth (GB/s) by GPU architecture — used for optimizer step estimation +_HBM_BANDWIDTH_GBPS: Dict[str, float] = { + "mi300x": 5300.0, + "gfx942": 5300.0, + "mi325x": 6000.0, + "mi355x": 8000.0, + "gfx950": 8000.0, +} + + +def _estimate_optimizer_step_ms( + training_config, + dp_size: int = 1, + gpu_arch: Optional[str] = None, +) -> float: + """ + Estimate the optimizer step time (Adam/AdamW) for one training iteration. + + The optimizer step is HBM-bandwidth-bound. For each parameter, + Adam reads/writes the following (mixed-precision training with BF16 + forward, FP32 master weights): + + Read: FP32 master param (4B) + FP32 grad (4B) + m (4B) + v (4B) = 16 B + Write: FP32 master param (4B) + m (4B) + v (4B) + BF16 param (2B) = 14 B + Total: 30 bytes per parameter + + With distributed_optimizer or FSDP, optimizer state is sharded across + DP ranks, so each GPU only updates ``N_params / dp_size`` parameters. + + Returns: + Optimizer step time in milliseconds. 
+ """ + model_config = training_config.model_config + mp_config = training_config.model_parallel_config + + # --- Count total parameters per GPU (post TP/PP/EP sharding) ---------- + hidden = model_config.hidden_size + ffn_hidden = model_config.ffn_hidden_size or (hidden * 4) + moe_ffn = model_config.moe_ffn_hidden_size or ffn_hidden + num_layers = model_config.num_layers + num_experts = model_config.num_experts or 0 + moe_pattern = model_config.moe_pattern # list of 0/1 + num_moe_layers = sum(1 for p in moe_pattern if p == 1) + num_dense_layers = num_layers - num_moe_layers + + tp = mp_config.tensor_model_parallel_size + pp = mp_config.pipeline_model_parallel_size + ep = getattr(mp_config, "expert_model_parallel_size", 1) or 1 + + # Attention params: Q, K, V, O -> 4 * h * h (per layer, sharded by TP) + attn_params_per_layer = 4 * hidden * hidden // tp + # Dense MLP: gate, up, down -> 3 * h * ffn (per layer, sharded by TP) + dense_mlp_params_per_layer = 3 * hidden * ffn_hidden // tp + # Expert MLP params per expert: 3 * h * moe_ffn (NOT sharded by TP normally) + expert_tp = getattr(mp_config, "expert_tensor_parallel_size", None) or 1 + expert_mlp_params_per_expert = 3 * hidden * moe_ffn // expert_tp + + # Non-expert params across all layers (sharded by TP, PP) + non_expert_params = ( + num_layers * attn_params_per_layer + + num_dense_layers * dense_mlp_params_per_layer + ) + # Expert params (sharded by EP, expert_TP, PP) + expert_params = ( + num_moe_layers * num_experts * expert_mlp_params_per_expert // max(ep, 1) + ) + + # Shared experts (if any) + shared_sz = getattr(model_config, "moe_shared_expert_intermediate_size", 0) or 0 + shared_expert_params = 0 + if shared_sz and num_moe_layers > 0: + shared_expert_params = num_moe_layers * 3 * hidden * shared_sz // tp + + total_params_per_gpu = ( + non_expert_params + expert_params + shared_expert_params + ) // pp + + # Embedding + output layer params (only on first / last PP rank, amortise) + vocab_size = 
getattr(model_config, "vocab_size", 0) or 0 + if vocab_size and pp > 0: + embedding_params = vocab_size * hidden // tp + output_params = vocab_size * hidden // tp + # Amortise across PP ranks (only 1 rank holds each) + total_params_per_gpu += (embedding_params + output_params) // pp + + # --- Distributed optimizer / FSDP sharding --- + use_distributed_optimizer = getattr(mp_config, "use_distributed_optimizer", False) + use_fsdp = getattr(mp_config, "use_torch_fsdp2", False) + + if use_distributed_optimizer or use_fsdp: + # Optimizer state is sharded across DP ranks + params_for_optim = total_params_per_gpu // max(dp_size, 1) + else: + params_for_optim = total_params_per_gpu + + # --- Compute time --- + bytes_per_param = 30 # Adam read+write (mixed precision) + total_bytes = params_for_optim * bytes_per_param + + # Look up HBM bandwidth + arch = (gpu_arch or os.getenv("PRIMUS_GPU_ARCH", "mi300x")).lower().strip() + hbm_bw_gbps = _HBM_BANDWIDTH_GBPS.get(arch, 5300.0) + hbm_bw_bytes_per_ms = hbm_bw_gbps * 1e9 / 1e3 # bytes/ms + + optimizer_time_ms = total_bytes / hbm_bw_bytes_per_ms + + return optimizer_time_ms + # ============================================================================= # Hardware and Communication Functions (moved from multinode_projection) @@ -148,12 +257,16 @@ def calculate_collective_communication_time( bw_eff = getattr(coll_args, "bw_eff", 0.91) inter_bw = pod_bw * bw_eff # GB/s per link msg_scale = (dp_replicas - 1) / dp_replicas - expert_ar_time_ms = 2 * expert_grad_size * msg_scale / (inter_bw * 1e9) * 1e3 + expert_ar_time_ms = ( + 2 * expert_grad_size * msg_scale / (inter_bw * 1e9) * 1e3 + ) # Non-expert gradient allreduce: across full DP group non_expert_per_rank = non_expert_params // (tp * pp) non_expert_grad_size = non_expert_per_rank * 4 # FP32 - non_expert_ar_time = cm.allreduce(coll_args, non_expert_grad_size, dp, groups=["dp"]) + non_expert_ar_time = cm.allreduce( + coll_args, non_expert_grad_size, dp, groups=["dp"] + ) 
non_expert_ar_ms = non_expert_ar_time / 1000 total_ar_ms = expert_ar_time_ms + non_expert_ar_ms @@ -223,19 +336,29 @@ def calculate_collective_communication_time( # Dense layer: ~12 * hidden^2 params (qkv_proj, o_proj, mlp up/down/gate) # MoE layer: similar attention + num_experts * expert_params ffn_hidden = model_config.ffn_hidden_size or hidden_size * 4 - params_per_dense_layer = hidden_size * hidden_size * 4 + hidden_size * ffn_hidden * 3 # attn + MLP - params_per_dense_layer = params_per_dense_layer // tp # Divide by TP (params are TP-sharded) + params_per_dense_layer = ( + hidden_size * hidden_size * 4 + hidden_size * ffn_hidden * 3 + ) # attn + MLP + params_per_dense_layer = ( + params_per_dense_layer // tp + ) # Divide by TP (params are TP-sharded) # Weight size in bytes (BF16 = 2 bytes) weight_size_per_layer = params_per_dense_layer * 2 # All-gather: each rank sends its shard (1/DP), receives full weights # Total data moved = weight_size * (DP-1)/DP per rank - ag_time_per_layer_us = cm.allgather(coll_args, weight_size_per_layer, dp, groups=["dp"]) + ag_time_per_layer_us = cm.allgather( + coll_args, weight_size_per_layer, dp, groups=["dp"] + ) # Reduce-scatter: each rank sends full gradients, receives its shard - grad_size_per_layer = params_per_dense_layer * 2 # BF16 gradients for communication - rs_time_per_layer_us = cm.reduce_scatter(coll_args, grad_size_per_layer, dp, groups=["dp"]) + grad_size_per_layer = ( + params_per_dense_layer * 2 + ) # BF16 gradients for communication + rs_time_per_layer_us = cm.reduce_scatter( + coll_args, grad_size_per_layer, dp, groups=["dp"] + ) # --- Recompute correction --- # With recompute_granularity="full", during the backward pass each layer @@ -252,12 +375,16 @@ def calculate_collective_communication_time( ag_multiplier = 1 + recomp_ratio # e.g. 
2.0 when all layers recomputed # Calculate total FSDP time for all layers - total_fsdp_ag_fwd = (ag_time_per_layer_us * num_layers * ag_multiplier) / 1000 # ms + total_fsdp_ag_fwd = ( + ag_time_per_layer_us * num_layers * ag_multiplier + ) / 1000 # ms total_fsdp_rs_bwd = (rs_time_per_layer_us * num_layers) / 1000 # ms breakdown["fsdp_allgather_fwd"] = total_fsdp_ag_fwd breakdown["fsdp_reducescatter_bwd"] = total_fsdp_rs_bwd - message_info["fsdp_weight_size_per_layer_mb"] = weight_size_per_layer / (1024 * 1024) + message_info["fsdp_weight_size_per_layer_mb"] = weight_size_per_layer / ( + 1024 * 1024 + ) message_info["fsdp_ag_per_layer_ms"] = ag_time_per_layer_us / 1000 message_info["fsdp_rs_per_layer_ms"] = rs_time_per_layer_us / 1000 message_info["fsdp_ag_multiplier"] = ag_multiplier @@ -296,7 +423,9 @@ def calculate_collective_communication_time( total_comm_time = sum(breakdown.values()) # Check if gradient all-reduce should be overlapped - overlap_grad_reduce = getattr(mp_config, "overlap_grad_reduce", True) # Default to True + overlap_grad_reduce = getattr( + mp_config, "overlap_grad_reduce", True + ) # Default to True # If overlapped and NOT MoE-no-overlap, don't add to critical path moe_no_overlap = message_info.get("moe_ar_no_overlap", False) @@ -336,7 +465,9 @@ def calculate_collective_communication_time( # With full recompute the AG total already includes the 2× # multiplier. Split into forward AG and backward (recomp) AG. 
recomp_ratio = min(recomp_n, num_layers) / num_layers - ag_multiplier_val = message_info.get("fsdp_ag_multiplier", 1 + recomp_ratio) + ag_multiplier_val = message_info.get( + "fsdp_ag_multiplier", 1 + recomp_ratio + ) fwd_ag_total = total_fsdp_ag / ag_multiplier_val bwd_ag_total = total_fsdp_ag - fwd_ag_total else: @@ -344,13 +475,13 @@ def calculate_collective_communication_time( bwd_ag_total = 0.0 # Per-phase overlap percentages (from trace calibration) - FWD_AG_OVERLAP = 0.90 # forward AG hidden behind compute - BWD_AG_OVERLAP = 0.24 # backward recompute AG (structural limit) - RS_OVERLAP = 0.34 # ReduceScatter (structural limit) + FWD_AG_OVERLAP = 0.90 # forward AG hidden behind compute + BWD_AG_OVERLAP = 0.24 # backward recompute AG (structural limit) + RS_OVERLAP = 0.34 # ReduceScatter (structural limit) hidden_fwd_ag = fwd_ag_total * FWD_AG_OVERLAP hidden_bwd_ag = bwd_ag_total * BWD_AG_OVERLAP - hidden_rs = total_fsdp_rs * RS_OVERLAP + hidden_rs = total_fsdp_rs * RS_OVERLAP total_hidden = hidden_fwd_ag + hidden_bwd_ag + hidden_rs total_comm_time -= total_hidden @@ -359,7 +490,9 @@ def calculate_collective_communication_time( message_info["fsdp_bwd_ag_overlap"] = BWD_AG_OVERLAP message_info["fsdp_rs_overlap"] = RS_OVERLAP total_fsdp = total_fsdp_ag + total_fsdp_rs - message_info["fsdp_overall_overlap"] = total_hidden / total_fsdp if total_fsdp > 0 else 0 + message_info["fsdp_overall_overlap"] = ( + total_hidden / total_fsdp if total_fsdp > 0 else 0 + ) message_info["fsdp_exposed_ms"] = total_fsdp - total_hidden else: message_info["fsdp_overlapped"] = False @@ -367,7 +500,9 @@ def calculate_collective_communication_time( return total_comm_time, breakdown, message_info, per_layer_info -def extract_single_node_time_from_profiling(profiling_results: dict, training_config) -> float: +def extract_single_node_time_from_profiling( + profiling_results: dict, training_config +) -> float: """ Extract total single-node time from profiling results. 
@@ -384,7 +519,9 @@ def extract_single_node_time_from_profiling(profiling_results: dict, training_co is_rank_0 = int(os.getenv("RANK", "0")) == 0 if is_rank_0: - print("[Primus:Performance Projection] Extracting timing from benchmark results...") + print( + "[Primus:Performance Projection] Extracting timing from benchmark results..." + ) print("-" * 100) model_config = training_config.model_config @@ -397,12 +534,16 @@ def extract_single_node_time_from_profiling(profiling_results: dict, training_co num_total_layers = len(moe_pattern) # Get profiled layer indices - profiled_layer_indices = sorted([k for k in profiling_results.keys() if isinstance(k, int)]) + profiled_layer_indices = sorted( + [k for k in profiling_results.keys() if isinstance(k, int)] + ) if is_rank_0: print(f" Profiled layers: {profiled_layer_indices}") print(f" Full model has {num_total_layers} transformer layers") if recompute_granularity == "full" and recompute_num_layers > 0: - print(f" Recomputation: {recompute_num_layers} layers (granularity={recompute_granularity})") + print( + f" Recomputation: {recompute_num_layers} layers (granularity={recompute_granularity})" + ) total_time_ms = 0.0 @@ -435,12 +576,24 @@ def extract_single_node_time_from_profiling(profiling_results: dict, training_co profiled_moe_fwd_times.append(fwd_time) # Calculate averages from profiled layers - avg_dense_time = sum(profiled_dense_times) / len(profiled_dense_times) if profiled_dense_times else 0 + avg_dense_time = ( + sum(profiled_dense_times) / len(profiled_dense_times) + if profiled_dense_times + else 0 + ) avg_dense_fwd = ( - sum(profiled_dense_fwd_times) / len(profiled_dense_fwd_times) if profiled_dense_fwd_times else 0 + sum(profiled_dense_fwd_times) / len(profiled_dense_fwd_times) + if profiled_dense_fwd_times + else 0 + ) + avg_moe_time = ( + sum(profiled_moe_times) / len(profiled_moe_times) if profiled_moe_times else 0 + ) + avg_moe_fwd = ( + sum(profiled_moe_fwd_times) / len(profiled_moe_fwd_times) + if 
profiled_moe_fwd_times + else 0 ) - avg_moe_time = sum(profiled_moe_times) / len(profiled_moe_times) if profiled_moe_times else 0 - avg_moe_fwd = sum(profiled_moe_fwd_times) / len(profiled_moe_fwd_times) if profiled_moe_fwd_times else 0 # Count total dense and MoE layers in full model num_dense_layers = sum(1 for x in moe_pattern if x == 0) @@ -456,13 +609,21 @@ def extract_single_node_time_from_profiling(profiling_results: dict, training_co # Print detailed breakdown if is_rank_0: if profiled_dense_times: - print(f" Dense Layers: {len(profiled_dense_times)} profiled → {num_dense_layers} total") - print(f" Avg per layer: {avg_dense_time:.2f} ms (fwd={avg_dense_fwd:.2f} ms)") + print( + f" Dense Layers: {len(profiled_dense_times)} profiled → {num_dense_layers} total" + ) + print( + f" Avg per layer: {avg_dense_time:.2f} ms (fwd={avg_dense_fwd:.2f} ms)" + ) print(f" Total time: {total_dense_time:.2f} ms") if profiled_moe_times: - print(f" MoE Layers: {len(profiled_moe_times)} profiled → {num_moe_layers} total") - print(f" Avg per layer: {avg_moe_time:.2f} ms (fwd={avg_moe_fwd:.2f} ms)") + print( + f" MoE Layers: {len(profiled_moe_times)} profiled → {num_moe_layers} total" + ) + print( + f" Avg per layer: {avg_moe_time:.2f} ms (fwd={avg_moe_fwd:.2f} ms)" + ) print(f" Total time: {total_moe_time:.2f} ms") # Output layer @@ -493,14 +654,20 @@ def extract_single_node_time_from_profiling(profiling_results: dict, training_co if is_rank_0: print(f" Recomputation Overhead: {recompute_overhead_ms:.2f} ms") - print(f" ({recompute_dense_layers} dense + {recompute_moe_layers} MoE layers recomputed)") + print( + f" ({recompute_dense_layers} dense + {recompute_moe_layers} MoE layers recomputed)" + ) if is_rank_0: print("-" * 100) - print(f"[Primus:Performance Projection] Extrapolated Baseline Time: {total_time_ms:.2f} ms/iteration") + print( + f"[Primus:Performance Projection] Extrapolated Baseline Time: {total_time_ms:.2f} ms/iteration" + ) if recompute_overhead_ms > 0: print(f" 
(Includes {recompute_overhead_ms:.2f} ms recomputation overhead)") - print(f" (Based on {len(profiled_layer_indices)} profiled layers → {num_total_layers} total layers)") + print( + f" (Based on {len(profiled_layer_indices)} profiled layers → {num_total_layers} total layers)" + ) print("=" * 100) return total_time_ms @@ -674,7 +841,9 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): f"[Primus:Performance Projection] After reducing PP to 1, " f"config still requires {benchmark_gpus_required} GPUs (TP={tp}, EP={ep}, CP={cp})." ) - print(f"[Primus:Performance Projection] Rescaling EP to fit on {gpus_per_node} GPUs...") + print( + f"[Primus:Performance Projection] Rescaling EP to fit on {gpus_per_node} GPUs..." + ) # Rescale EP to fit rescale_info = _rescale_expert_parallelism(original_config) @@ -723,7 +892,9 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): } -def _estimate_pp_communication_overhead(training_config, pp_size, hardware_config_dict=None): +def _estimate_pp_communication_overhead( + training_config, pp_size, hardware_config_dict=None +): """ Estimate the PP P2P communication overhead for a given PP size. 
@@ -782,7 +953,9 @@ def _estimate_pp_communication_overhead(training_config, pp_size, hardware_confi # Total P2P time per iteration # Forward: (PP-1) sends, Backward: (PP-1) sends # Times number of microbatches - total_p2p_time_ms = 2 * (pp_size - 1) * num_microbatches * p2p_time_per_transfer / 1000 + total_p2p_time_ms = ( + 2 * (pp_size - 1) * num_microbatches * p2p_time_per_transfer / 1000 + ) return total_p2p_time_ms @@ -900,14 +1073,24 @@ def _estimate_ep_communication_overhead( dispatch_size = tokens_per_batch * hidden_size * moe_router_topk * 2 # BF16 # Calculate All-to-All time for original EP (dispatch + combine) - a2a_dispatch_original = cm.alltoall(coll_args_original, dispatch_size, original_ep, groups=["ep"]) - a2a_combine_original = cm.alltoall(coll_args_original, dispatch_size, original_ep, groups=["ep"]) + a2a_dispatch_original = cm.alltoall( + coll_args_original, dispatch_size, original_ep, groups=["ep"] + ) + a2a_combine_original = cm.alltoall( + coll_args_original, dispatch_size, original_ep, groups=["ep"] + ) a2a_time_original_fwd = (a2a_dispatch_original + a2a_combine_original) / 1000 # ms # Calculate All-to-All time for benchmark EP (dispatch + combine) - a2a_dispatch_benchmark = cm.alltoall(coll_args_benchmark, dispatch_size, benchmark_ep, groups=["ep"]) - a2a_combine_benchmark = cm.alltoall(coll_args_benchmark, dispatch_size, benchmark_ep, groups=["ep"]) - a2a_time_benchmark_fwd = (a2a_dispatch_benchmark + a2a_combine_benchmark) / 1000 # ms + a2a_dispatch_benchmark = cm.alltoall( + coll_args_benchmark, dispatch_size, benchmark_ep, groups=["ep"] + ) + a2a_combine_benchmark = cm.alltoall( + coll_args_benchmark, dispatch_size, benchmark_ep, groups=["ep"] + ) + a2a_time_benchmark_fwd = ( + a2a_dispatch_benchmark + a2a_combine_benchmark + ) / 1000 # ms # The overhead is the difference (original is larger due to inter-node communication) fwd_overhead_per_layer = a2a_time_original_fwd - a2a_time_benchmark_fwd @@ -928,7 +1111,9 @@ def 
_extract_layer_type_timings(layer_results: dict) -> Dict[str, dict[str, floa continue forward = float(result.get("forward_time_ms", 0.0) or 0.0) backward = float(result.get("backward_time_ms", 0.0) or 0.0) - activation = float(result.get("activation_memory_bytes", 0.0) or 0.0) / _BYTES_PER_GB + activation = ( + float(result.get("activation_memory_bytes", 0.0) or 0.0) / _BYTES_PER_GB + ) type_timings[layer_type] = { "forward": forward, "backward": backward, @@ -951,7 +1136,9 @@ def _add_io_layer_timings(chunk_timings: List[list[dict]], profiling_results: di emb_bwd = embedding.get("backward_time_ms", 0.0) or 0.0 first_chunk["bwd"] += emb_bwd # wgrad already included in backward, don't add again - first_chunk["activation"] += (embedding.get("activation_memory_bytes", 0.0) or 0.0) / _BYTES_PER_GB + first_chunk["activation"] += ( + embedding.get("activation_memory_bytes", 0.0) or 0.0 + ) / _BYTES_PER_GB output = profiling_results.get("output") if output and chunk_timings[-1]: @@ -960,10 +1147,14 @@ def _add_io_layer_timings(chunk_timings: List[list[dict]], profiling_results: di out_bwd = output.get("backward_time_ms", 0.0) or 0.0 last_chunk["bwd"] += out_bwd # wgrad already included in backward, don't add again - last_chunk["activation"] += (output.get("activation_memory_bytes", 0.0) or 0.0) / _BYTES_PER_GB + last_chunk["activation"] += ( + output.get("activation_memory_bytes", 0.0) or 0.0 + ) / _BYTES_PER_GB -def _build_chunk_time_matrix(training_config, layer_results: dict) -> Optional[List[List[dict]]]: +def _build_chunk_time_matrix( + training_config, layer_results: dict +) -> Optional[List[List[dict]]]: model_cfg = getattr(training_config, "model_config", None) mp_cfg = getattr(training_config, "model_parallel_config", None) if model_cfg is None or mp_cfg is None: @@ -974,7 +1165,10 @@ def _build_chunk_time_matrix(training_config, layer_results: dict) -> Optional[L return None layer_type_pattern = getattr(model_cfg, "moe_pattern", None) - if not 
isinstance(layer_type_pattern, (list, tuple)) or len(layer_type_pattern) != total_layers: + if ( + not isinstance(layer_type_pattern, (list, tuple)) + or len(layer_type_pattern) != total_layers + ): layer_type_pattern = [0] * total_layers type_timings = _extract_layer_type_timings(layer_results) if not type_timings: @@ -1006,14 +1200,20 @@ def _build_chunk_time_matrix(training_config, layer_results: dict) -> Optional[L ) if not layers: chunk_timings.append( - [{"fwd": 0.0, "bwd": 0.0, "wgrad": 0.0, "activation": 0.0} for _ in range(vpp_size)] + [ + {"fwd": 0.0, "bwd": 0.0, "wgrad": 0.0, "activation": 0.0} + for _ in range(vpp_size) + ] ) continue layers_per_chunk = len(layers) // vpp_size if vpp_size else len(layers) if layers_per_chunk == 0: chunk_timings.append( - [{"fwd": 0.0, "bwd": 0.0, "wgrad": 0.0, "activation": 0.0} for _ in range(vpp_size)] + [ + {"fwd": 0.0, "bwd": 0.0, "wgrad": 0.0, "activation": 0.0} + for _ in range(vpp_size) + ] ) continue @@ -1049,7 +1249,9 @@ def _compute_micro_batches(runtime_cfg, model_parallel_config) -> int: return max(1, math.ceil(global_batch / denominator)) -def _build_scheduler_sim_config(training_config, profiling_results, enable_zero_bubble=False): +def _build_scheduler_sim_config( + training_config, profiling_results, enable_zero_bubble=False +): chunk_time_matrix = _build_chunk_time_matrix(training_config, profiling_results) assert chunk_time_matrix is not None @@ -1057,7 +1259,9 @@ def _build_scheduler_sim_config(training_config, profiling_results, enable_zero_ # The zero-bubble scheduler schedules these separately to minimize pipeline bubbles. # Typically B and W are roughly equal in duration (each ~50% of total backward). 
if enable_zero_bubble: - print("[Primus:Performance Projection] Splitting backward time for zero-bubble scheduling:") + print( + "[Primus:Performance Projection] Splitting backward time for zero-bubble scheduling:" + ) print(" B (input grad) = 50% of backward, W (weight grad) = 50% of backward") for rank_chunks in chunk_time_matrix: for chunk in rank_chunks: @@ -1102,7 +1306,9 @@ def _build_scheduler_sim_config(training_config, profiling_results, enable_zero_ "vpp_size": 1, "micro_batches": micro_batches, } - print("[Primus:Performance Projection] Using zero-bubble scheduler (enable_zero_bubble=True)") + print( + "[Primus:Performance Projection] Using zero-bubble scheduler (enable_zero_bubble=True)" + ) elif vpp_size > 1: scheduler = { "name": "interleaved_1f1b", @@ -1181,14 +1387,18 @@ def _report_simulation_results(sim_results, training_config): activation_trace = scheduled_layers.get("activation_memory_usage") or [] peak_activation = ( - max(activation_trace) if activation_trace else scheduled_layers.get("memory", 0.0) + max(activation_trace) + if activation_trace + else scheduled_layers.get("memory", 0.0) ) # Map rank_idx to pipeline rank (rank_idx // vpp_size) vpp_size = mp_cfg.virtual_pipeline_model_parallel_size or 1 pp_rank = rank_idx // vpp_size if pp_rank not in param_mem_cache: - param_mem_cache[pp_rank] = _get_parameter_memory(training_config, pp_rank) + param_mem_cache[pp_rank] = _get_parameter_memory( + training_config, pp_rank + ) param_mem_gb = param_mem_cache[pp_rank] total_peak_gb = peak_activation + param_mem_gb rank_stats.append( @@ -1250,9 +1460,15 @@ def _run_layer_benchmark(primus_config, unknown_overrides): primus_config.get_module_config("pre_trainer").overlap_param_gather = False primus_config.get_module_config("pre_trainer").use_torch_fsdp2 = False print("[Primus:Performance Projection] Config (with profiling overrides):") - print(f" overlap_grad_reduce: {primus_config.get_module_config('pre_trainer').overlap_grad_reduce}") - print(f" 
overlap_param_gather: {primus_config.get_module_config('pre_trainer').overlap_param_gather}") - print(f" use_torch_fsdp2: {primus_config.get_module_config('pre_trainer').use_torch_fsdp2}") + print( + f" overlap_grad_reduce: {primus_config.get_module_config('pre_trainer').overlap_grad_reduce}" + ) + print( + f" overlap_param_gather: {primus_config.get_module_config('pre_trainer').overlap_param_gather}" + ) + print( + f" use_torch_fsdp2: {primus_config.get_module_config('pre_trainer').use_torch_fsdp2}" + ) trainer = MegatronPretrainTrainer( module_name="pre_trainer", primus_config=primus_config, @@ -1331,9 +1547,16 @@ def _run_layer_simulation(primus_config, args): # ---- Create simulation backends ---- gemm_backend_name = getattr(args, "gemm_backend", None) gpu_arch = getattr(args, "gpu_arch", None) + gpu_clock_mhz = getattr(args, "gpu_clock_mhz", None) - gemm_backend = get_gemm_simulation_backend(backend_name=gemm_backend_name, gpu_arch=gpu_arch) - sdpa_backend = get_sdpa_simulation_backend(gpu_arch=gpu_arch) + gemm_backend = get_gemm_simulation_backend( + backend_name=gemm_backend_name, + gpu_arch=gpu_arch, + gpu_clock_mhz=gpu_clock_mhz, + ) + sdpa_backend = get_sdpa_simulation_backend( + gpu_arch=gpu_arch, gpu_clock_mhz=gpu_clock_mhz + ) # ---- Build profiler tree (no model needed) ---- if is_rank_0: @@ -1403,7 +1626,9 @@ def _run_pipeline_simulation_megatron_zb(training_config, profiling_results): mem_b = [] mem_w = [] - print("[Primus:Performance Projection] Using Megatron zero-bubble scheduler (ILP-based)") + print( + "[Primus:Performance Projection] Using Megatron zero-bubble scheduler (ILP-based)" + ) print(f" PP size: {pp_size}, Microbatches: {micro_batches}") for rank_idx, rank_chunks in enumerate(chunk_time_matrix): @@ -1427,7 +1652,9 @@ def _run_pipeline_simulation_megatron_zb(training_config, profiling_results): mem_b.append(float(-act_gb * 0.5)) # B releases half mem_w.append(float(-act_gb * 0.5)) # W releases remaining half - print(f" Stage 
{rank_idx}: F={fwd:.2f}ms, B={b_time:.2f}ms, W={w_time:.2f}ms, act={act_gb:.2f}GB") + print( + f" Stage {rank_idx}: F={fwd:.2f}ms, B={b_time:.2f}ms, W={w_time:.2f}ms, act={act_gb:.2f}GB" + ) # Estimate communication cost (P2P latency) # Use a small default value; actual value depends on hardware @@ -1456,7 +1683,9 @@ def _run_pipeline_simulation_megatron_zb(training_config, profiling_results): step_time_ms = best_time # Calculate bubble time - total_compute_per_mb = sum(cost_f) / pp_size + sum(cost_b) / pp_size + sum(cost_w) / pp_size + total_compute_per_mb = ( + sum(cost_f) / pp_size + sum(cost_b) / pp_size + sum(cost_w) / pp_size + ) ideal_time = total_compute_per_mb * micro_batches bubble_time = step_time_ms - ideal_time bubble_ratio = bubble_time / step_time_ms if step_time_ms > 0 else 0 @@ -1469,7 +1698,9 @@ def _run_pipeline_simulation_megatron_zb(training_config, profiling_results): return step_time_ms -def _run_pipeline_simulation(training_config, profiling_results, enable_zero_bubble=False): +def _run_pipeline_simulation( + training_config, profiling_results, enable_zero_bubble=False +): """ Run pipeline simulation and return the step time. 
@@ -1484,12 +1715,16 @@ def _run_pipeline_simulation(training_config, profiling_results, enable_zero_bub # Use Megatron's actual ZB scheduler for more accurate simulation if enable_zero_bubble: try: - return _run_pipeline_simulation_megatron_zb(training_config, profiling_results) + return _run_pipeline_simulation_megatron_zb( + training_config, profiling_results + ) except Exception as e: print(f"[Primus:Performance Projection] Megatron ZB scheduler failed: {e}") print("[Primus:Performance Projection] Falling back to simple simulator...") - sim_config = _build_scheduler_sim_config(training_config, profiling_results, enable_zero_bubble) + sim_config = _build_scheduler_sim_config( + training_config, profiling_results, enable_zero_bubble + ) if sim_config is None: return None print("[Primus:Performance Projection] Running pipeline schedule simulator...") @@ -1605,19 +1840,23 @@ def _run_multinode_projection( print(f" Using custom hardware config from: {args.hardware_config}") else: if is_rank_0: - print(" Using default hardware parameters from custom_hardware_example.yaml") + print( + " Using default hardware parameters from custom_hardware_example.yaml" + ) # Calculate communication times - total_comm_time_ms, breakdown, message_info, per_layer_info = calculate_collective_communication_time( - training_config, - target_nodes, - gpus_per_node, - tp, - pp, - ep, - cp, - dp_target, - hardware_config_dict, + total_comm_time_ms, breakdown, message_info, per_layer_info = ( + calculate_collective_communication_time( + training_config, + target_nodes, + gpus_per_node, + tp, + pp, + ep, + cp, + dp_target, + hardware_config_dict, + ) ) # Benchmarked time is for the minimum node configuration @@ -1647,16 +1886,18 @@ def _run_multinode_projection( grad_ar_per_iteration_ms = 0.0 # Non-overlapped allreduce time (added once) if dp_target > 1: # Calculate gradient all-reduce for target - _, target_breakdown, target_message_info, _ = calculate_collective_communication_time( - 
training_config, - target_nodes, - gpus_per_node, - tp, - pp, - ep, - cp, - dp_target, - hardware_config_dict, + _, target_breakdown, target_message_info, _ = ( + calculate_collective_communication_time( + training_config, + target_nodes, + gpus_per_node, + tp, + pp, + ep, + cp, + dp_target, + hardware_config_dict, + ) ) target_grad_ar = target_breakdown.get("gradient_allreduce", 0) moe_ar_no_overlap = target_message_info.get("moe_ar_no_overlap", False) @@ -1680,16 +1921,18 @@ def _run_multinode_projection( projected_time_ms = projected_compute_time_ms # For reporting, get full breakdown for target - total_comm_time_ms, breakdown, message_info, per_layer_info = calculate_collective_communication_time( - training_config, - target_nodes, - gpus_per_node, - tp, - pp, - ep, - cp, - dp_target, - hardware_config_dict, + total_comm_time_ms, breakdown, message_info, per_layer_info = ( + calculate_collective_communication_time( + training_config, + target_nodes, + gpus_per_node, + tp, + pp, + ep, + cp, + dp_target, + hardware_config_dict, + ) ) # Add exposed FSDP communication time to projected time @@ -1710,7 +1953,9 @@ def _run_multinode_projection( # Calculate number of microbatches per GPU for the target configuration target_microbatches_per_gpu = ( - global_batch // (micro_batch * target_dp_for_microbatch) if target_dp_for_microbatch > 0 else 1 + global_batch // (micro_batch * target_dp_for_microbatch) + if target_dp_for_microbatch > 0 + else 1 ) # Handle edge case where global_batch is smaller than micro_batch * target_dp @@ -1728,21 +1973,33 @@ def _run_multinode_projection( ) target_microbatches_per_gpu = 1 + # Estimate optimizer step time (once per iteration, after all microbatches) + gpu_arch = getattr(args, "gpu_arch", None) + optimizer_step_ms = _estimate_optimizer_step_ms( + training_config, dp_target, gpu_arch + ) + # Build full iteration time: - # compute (per-microbatch) × num_microbatches + gradient allreduce (once per iter) + # compute (per-microbatch) × 
num_microbatches + gradient allreduce + optimizer step if time_includes_all_microbatches: - full_iteration_time_ms = projected_time_ms + grad_ar_per_iteration_ms - time_breakdown_str = f"{full_iteration_time_ms:.3f} ms (from pipeline simulation" + full_iteration_time_ms = ( + projected_time_ms + grad_ar_per_iteration_ms + optimizer_step_ms + ) + time_breakdown_str = ( + f"{full_iteration_time_ms:.3f} ms (from pipeline simulation" + ) if grad_ar_per_iteration_ms > 0: time_breakdown_str += f" + {grad_ar_per_iteration_ms:.1f} ms grad AR" - time_breakdown_str += ")" + time_breakdown_str += f" + {optimizer_step_ms:.1f} ms optimizer)" else: compute_total = projected_time_ms * target_microbatches_per_gpu - full_iteration_time_ms = compute_total + grad_ar_per_iteration_ms + full_iteration_time_ms = ( + compute_total + grad_ar_per_iteration_ms + optimizer_step_ms + ) time_breakdown_str = f"{full_iteration_time_ms:.3f} ms ({target_microbatches_per_gpu} microbatches × {projected_time_ms:.3f} ms" if grad_ar_per_iteration_ms > 0: time_breakdown_str += f" + {grad_ar_per_iteration_ms:.1f} ms grad AR" - time_breakdown_str += ")" + time_breakdown_str += f" + {optimizer_step_ms:.1f} ms optimizer)" # Calculate tokens/s/GPU (tokens processed per second per GPU) tokens_per_iter = global_batch * seq_len @@ -1764,7 +2021,10 @@ def _run_multinode_projection( for op_name, op_time in breakdown.items(): if op_time > 0: print(f" {op_name}: {op_time:.3f} ms", end="") - if op_name == "gradient_allreduce" and "gradient_allreduce_size_mb" in message_info: + if ( + op_name == "gradient_allreduce" + and "gradient_allreduce_size_mb" in message_info + ): moe_no_overlap = message_info.get("moe_ar_no_overlap", False) if moe_no_overlap: detail = " [MoE: NOT overlapped]" @@ -1774,9 +2034,13 @@ def _run_multinode_projection( detail += f"\n Expert AR: {expert_ms:.1f} ms (across {dp_reps} nodes)" detail += f" | Non-expert AR: {non_expert_ms:.1f} ms" else: - overlapped_flag = 
message_info.get("gradient_allreduce_overlapped", False) + overlapped_flag = message_info.get( + "gradient_allreduce_overlapped", False + ) detail = " [OVERLAPPED]" if overlapped_flag else "" - print(f" (message: {message_info['gradient_allreduce_size_mb']:.2f} MB){detail}") + print( + f" (message: {message_info['gradient_allreduce_size_mb']:.2f} MB){detail}" + ) elif op_name == "moe_a2a_fwd" and "moe_a2a_size_mb" in message_info: print( f" (message: {message_info['moe_a2a_size_mb']:.2f} MB, {message_info['num_moe_layers']} layers × {message_info['moe_a2a_per_layer_fwd']:.3f} ms/layer)" @@ -1827,7 +2091,9 @@ def launch_projection_from_cli(args, overrides): """ cfg_path = Path(args.config) if not cfg_path.exists(): - raise FileNotFoundError(f"[Primus:Performance Projection] Config file '{cfg_path}' not found.") + raise FileNotFoundError( + f"[Primus:Performance Projection] Config file '{cfg_path}' not found." + ) # Load Primus configuration primus_config, unknown_overrides = load_primus_config(args, overrides) @@ -1841,7 +2107,9 @@ def launch_projection_from_cli(args, overrides): # Store original parallelism before any modifications module_config = primus_config.get_module_config("pre_trainer") - reduction_info = _calculate_single_node_config(copy.deepcopy(module_config), gpus_per_node) + reduction_info = _calculate_single_node_config( + copy.deepcopy(module_config), gpus_per_node + ) # Calculate minimum nodes required min_nodes_required = reduction_info["original_nodes_required"] @@ -1868,9 +2136,13 @@ def launch_projection_from_cli(args, overrides): # Show what was changed changes = [] if reduction_info["original_pp"] != reduction_info["benchmark_pp"]: - changes.append(f"PP {reduction_info['original_pp']} → {reduction_info['benchmark_pp']}") + changes.append( + f"PP {reduction_info['original_pp']} → {reduction_info['benchmark_pp']}" + ) if reduction_info["original_ep"] != reduction_info["benchmark_ep"]: - changes.append(f"EP {reduction_info['original_ep']} → 
{reduction_info['benchmark_ep']}") + changes.append( + f"EP {reduction_info['original_ep']} → {reduction_info['benchmark_ep']}" + ) if changes: print(f" ({', '.join(changes)})") @@ -1879,12 +2151,12 @@ def launch_projection_from_cli(args, overrides): print("=" * 100) # Apply the reduction to the config used for benchmarking - primus_config.get_module_config("pre_trainer").pipeline_model_parallel_size = reduction_info[ - "benchmark_pp" - ] - primus_config.get_module_config("pre_trainer").expert_model_parallel_size = reduction_info[ - "benchmark_ep" - ] + primus_config.get_module_config("pre_trainer").pipeline_model_parallel_size = ( + reduction_info["benchmark_pp"] + ) + primus_config.get_module_config("pre_trainer").expert_model_parallel_size = ( + reduction_info["benchmark_ep"] + ) # Determine profiling mode profiling_mode = getattr(args, "profiling_mode", "benchmark") @@ -1919,8 +2191,12 @@ def launch_projection_from_cli(args, overrides): fwd_err = ((s_fwd - b_fwd) / b_fwd * 100) if b_fwd else 0 bwd_err = ((s_bwd - b_bwd) / b_bwd * 100) if b_bwd else 0 print(f" Layer type: {lt}") - print(f" Forward: bench={b_fwd:.2f} ms sim={s_fwd:.2f} ms (err={fwd_err:+.1f}%)") - print(f" Backward: bench={b_bwd:.2f} ms sim={s_bwd:.2f} ms (err={bwd_err:+.1f}%)") + print( + f" Forward: bench={b_fwd:.2f} ms sim={s_fwd:.2f} ms (err={fwd_err:+.1f}%)" + ) + print( + f" Backward: bench={b_bwd:.2f} ms sim={s_bwd:.2f} ms (err={bwd_err:+.1f}%)" + ) print("=" * 100) # Use benchmark results for the rest of the pipeline @@ -1963,7 +2239,9 @@ def launch_projection_from_cli(args, overrides): print( f" Benchmark Config: PP={benchmark_pp}, EP={benchmark_ep}, TP={tp}, CP={cp}, DP={benchmark_dp} (1 node)" ) - print(f" Target Config: PP={pp}, EP={ep}, TP={tp}, CP={cp}, DP={target_dp} ({target_nodes} nodes)") + print( + f" Target Config: PP={pp}, EP={ep}, TP={tp}, CP={cp}, DP={target_dp} ({target_nodes} nodes)" + ) # Use BENCHMARK DP for pipeline simulation to get consistent baseline # The multinode 
projection will then scale from this baseline to target @@ -1975,7 +2253,9 @@ def launch_projection_from_cli(args, overrides): # common case for configs that already require all target GPUs for their # parallelism dims). Using benchmark_dp here would give 2× too many # microbatches when benchmark_dp < target_dp. - target_microbatches = global_batch // (micro_batch * target_dp) if target_dp > 0 else 1 + target_microbatches = ( + global_batch // (micro_batch * target_dp) if target_dp > 0 else 1 + ) target_microbatches = max(1, target_microbatches) benchmark_microbatches = global_batch // (micro_batch * benchmark_dp) if is_rank_0: @@ -1992,7 +2272,10 @@ def launch_projection_from_cli(args, overrides): # If EP was rescaled, adjust profiling_results to add EP overhead BEFORE pipeline simulation ep_overhead_applied = False - if reduction_info["adjusted"] and reduction_info["original_ep"] != reduction_info["benchmark_ep"]: + if ( + reduction_info["adjusted"] + and reduction_info["original_ep"] != reduction_info["benchmark_ep"] + ): original_ep = reduction_info["original_ep"] benchmark_ep = reduction_info["benchmark_ep"] @@ -2002,20 +2285,26 @@ def launch_projection_from_cli(args, overrides): hardware_config_dict = load_hardware_config(args.hardware_config) # Calculate EP communication overhead per layer - fwd_overhead_per_layer, bwd_overhead_per_layer = _estimate_ep_communication_overhead( - training_config, - original_ep, - benchmark_ep, - hardware_config_dict, + fwd_overhead_per_layer, bwd_overhead_per_layer = ( + _estimate_ep_communication_overhead( + training_config, + original_ep, + benchmark_ep, + hardware_config_dict, + ) ) # EP compute scaling: when EP increases, each GPU handles fewer routed # expert tokens, but shared expert compute stays constant. # Use _compute_ep_mlp_scale to get the correct fraction-aware scale. 
- ep_mlp_scale = _compute_ep_mlp_scale(training_config.model_config, benchmark_ep, original_ep) + ep_mlp_scale = _compute_ep_mlp_scale( + training_config.model_config, benchmark_ep, original_ep + ) if is_rank_0: - print("[Primus:Performance Projection] Adjusting profiling results for EP scaling:") + print( + "[Primus:Performance Projection] Adjusting profiling results for EP scaling:" + ) print(f" EP rescaled: {benchmark_ep} → {original_ep}") print(f" MLP time scale factor: {ep_mlp_scale:.3f}") # Show shared vs routed breakdown @@ -2026,7 +2315,9 @@ def launch_projection_from_cli(args, overrides): "moe_shared_expert_intermediate_size", None, ) - num_shared = getattr(training_config.model_config, "num_shared_experts", 0) or 0 + num_shared = ( + getattr(training_config.model_config, "num_shared_experts", 0) or 0 + ) if moe_ffn and num_shared > 0 and shared_ffn: routed_flops = (topk / benchmark_ep) * moe_ffn shared_flops = num_shared * shared_ffn @@ -2038,7 +2329,9 @@ def launch_projection_from_cli(args, overrides): f" Shared fraction: {shared_flops/total_flops:.1%} ({num_shared} shared expert(s), ffn={shared_ffn})" ) else: - print(f" No shared experts — full routed scaling ({benchmark_ep}/{original_ep})") + print( + f" No shared experts — full routed scaling ({benchmark_ep}/{original_ep})" + ) if fwd_overhead_per_layer > 0 or bwd_overhead_per_layer > 0: print(f" Adding per-layer All-to-All overhead:") print(f" Forward: +{fwd_overhead_per_layer:.3f} ms/layer") @@ -2067,8 +2360,12 @@ def launch_projection_from_cli(args, overrides): old_fwd = layer_data.get("forward_time_ms", 0) old_bwd = layer_data.get("backward_time_ms", 0) print(f" MoE layer adjustment (per layer):") - print(f" MLP fwd: {mlp_fwd:.2f} → {new_mlp_fwd:.2f} ms (×{ep_mlp_scale:.3f})") - print(f" MLP bwd: {mlp_bwd:.2f} → {new_mlp_bwd:.2f} ms (×{ep_mlp_scale:.3f})") + print( + f" MLP fwd: {mlp_fwd:.2f} → {new_mlp_fwd:.2f} ms (×{ep_mlp_scale:.3f})" + ) + print( + f" MLP bwd: {mlp_bwd:.2f} → {new_mlp_bwd:.2f} 
ms (×{ep_mlp_scale:.3f})" + ) print(f" Attn fwd: {attn_fwd:.2f} ms (unchanged)") print(f" Attn bwd: {attn_bwd:.2f} ms (unchanged)") print(f" Layer fwd: {old_fwd:.2f} → {new_fwd:.2f} ms") @@ -2118,13 +2415,17 @@ def launch_projection_from_cli(args, overrides): # No need to add additional PP overhead benchmarked_time_ms = pipeline_simulation_time_ms if is_rank_0: - print(f" (Pipeline simulation already includes PP={reduction_info['original_pp']} effects)") + print( + f" (Pipeline simulation already includes PP={reduction_info['original_pp']} effects)" + ) else: if is_rank_0: print( "[Primus:Performance Projection] Pipeline simulation not available, using extrapolated time from profiling" ) - measured_time_ms = extract_single_node_time_from_profiling(profiling_results, training_config) + measured_time_ms = extract_single_node_time_from_profiling( + profiling_results, training_config + ) # If we reduced PP for benchmarking, estimate the time with PP overhead if reduction_info["adjusted"]: @@ -2142,7 +2443,9 @@ def launch_projection_from_cli(args, overrides): if is_rank_0: print("[Primus:Performance Projection] Time Adjustment:") - print(f" Measured time (PP={reduction_info['benchmark_pp']}): {measured_time_ms:.2f} ms") + print( + f" Measured time (PP={reduction_info['benchmark_pp']}): {measured_time_ms:.2f} ms" + ) print( f" Estimated PP overhead (PP={reduction_info['original_pp']}): {pp_overhead_ms:.2f} ms" ) @@ -2162,24 +2465,32 @@ def launch_projection_from_cli(args, overrides): benchmark_ep_val = reduction_info["benchmark_ep"] # Get the number of MoE layers - moe_pattern = getattr(training_config.model_config, "moe_layer_pattern", []) + moe_pattern = getattr( + training_config.model_config, "moe_layer_pattern", [] + ) if not moe_pattern: # If no pattern, check if model has MoE layers - num_moe_layers = getattr(training_config.model_config, "num_moe_layers", 0) + num_moe_layers = getattr( + training_config.model_config, "num_moe_layers", 0 + ) else: num_moe_layers = 
sum(1 for x in moe_pattern if x == 1) if num_moe_layers > 0: # Calculate EP communication overhead per layer - fwd_overhead_per_layer, bwd_overhead_per_layer = _estimate_ep_communication_overhead( - training_config, - original_ep, - benchmark_ep_val, - hardware_config_dict, + fwd_overhead_per_layer, bwd_overhead_per_layer = ( + _estimate_ep_communication_overhead( + training_config, + original_ep, + benchmark_ep_val, + hardware_config_dict, + ) ) # Total EP overhead = per-layer overhead * number of MoE layers - total_ep_overhead_ms = (fwd_overhead_per_layer + bwd_overhead_per_layer) * num_moe_layers + total_ep_overhead_ms = ( + fwd_overhead_per_layer + bwd_overhead_per_layer + ) * num_moe_layers # EP compute scaling (shared-expert-aware) ep_mlp_scale = _compute_ep_mlp_scale( @@ -2188,23 +2499,32 @@ def launch_projection_from_cli(args, overrides): # Estimate MLP portion of MoE layer time from profiling results mlp_time_reduction = 0.0 for layer_idx, layer_data in profiling_results.items(): - if isinstance(layer_data, dict) and layer_data.get("type") == "moe": + if ( + isinstance(layer_data, dict) + and layer_data.get("type") == "moe" + ): mlp_info = layer_data.get("mlp", {}) - mlp_total = mlp_info.get("forward_time_ms", 0) + mlp_info.get( - "backward_time_ms", 0 - ) + mlp_total = mlp_info.get( + "forward_time_ms", 0 + ) + mlp_info.get("backward_time_ms", 0) mlp_time_reduction = mlp_total * (1 - ep_mlp_scale) break # All MoE layers have same profiled time total_mlp_reduction_ms = mlp_time_reduction * num_moe_layers if is_rank_0: - print("[Primus:Performance Projection] EP Compute + Communication Adjustment:") + print( + "[Primus:Performance Projection] EP Compute + Communication Adjustment:" + ) print(f" EP rescaled: {benchmark_ep_val} → {original_ep}") print(f" Number of MoE layers: {num_moe_layers}") print(f" MLP time scale factor: {ep_mlp_scale:.3f}") - print(f" Total MLP compute reduction: -{total_mlp_reduction_ms:.3f} ms") - print(f" Total A2A comm overhead: 
+{total_ep_overhead_ms:.3f} ms") + print( + f" Total MLP compute reduction: -{total_mlp_reduction_ms:.3f} ms" + ) + print( + f" Total A2A comm overhead: +{total_ep_overhead_ms:.3f} ms" + ) net_change = total_ep_overhead_ms - total_mlp_reduction_ms print(f" Net adjustment: {net_change:+.3f} ms") diff --git a/primus/core/projection/simulation_backends/__init__.py b/primus/core/projection/simulation_backends/__init__.py new file mode 100644 index 000000000..898900ddc --- /dev/null +++ b/primus/core/projection/simulation_backends/__init__.py @@ -0,0 +1,23 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +from primus.core.projection.simulation_backends.base import ( + GEMMSimulationBackend, + SDPASimulationBackend, + SimulationResult, +) +from primus.core.projection.simulation_backends.factory import ( + get_gemm_simulation_backend, + get_sdpa_simulation_backend, +) + +__all__ = [ + "GEMMSimulationBackend", + "SDPASimulationBackend", + "SimulationResult", + "get_gemm_simulation_backend", + "get_sdpa_simulation_backend", +] diff --git a/primus/core/projection/simulation_backends/base.py b/primus/core/projection/simulation_backends/base.py new file mode 100644 index 000000000..91b5aecd1 --- /dev/null +++ b/primus/core/projection/simulation_backends/base.py @@ -0,0 +1,213 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +Abstract base classes for GEMM and SDPA simulation backends. 
+
+These backends provide simulated (analytical/model-based) timing for GEMM and
+SDPA operations, allowing performance projection without running actual GPU
+kernels. One concrete GEMM backend is shipped:
+
+- **Origami** (open-source, default) – ``origami_backend.py``
+
+An SDPA simulation backend is provided in ``sdpa_simulator.py``.
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
+
+
+@dataclass
+class SimulationResult:
+    """Result from a simulation backend."""
+
+    # Predicted time in milliseconds
+    forward_time_ms: float = 0.0
+    backward_time_ms: float = 0.0
+
+    # Optional: predicted TFLOPS / bandwidth
+    tflops: Optional[float] = None
+    bandwidth_gbps: Optional[float] = None
+
+    # Optional: extra metadata from the backend
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+class GEMMSimulationBackend(ABC):
+    """Abstract interface for GEMM simulation backends."""
+
+    @abstractmethod
+    def name(self) -> str:
+        """Return human-readable backend name."""
+        ...
+
+    @abstractmethod
+    def is_available(self) -> bool:
+        """Return True if this backend can be used in the current environment."""
+        ...
+
+    @abstractmethod
+    def simulate_gemm(
+        self,
+        m: int,
+        n: int,
+        k: int,
+        dtype: str = "bf16",
+        trans_a: bool = False,
+        trans_b: bool = False,
+    ) -> SimulationResult:
+        """
+        Simulate a single GEMM operation and return predicted timing.
+
+        Args:
+            m, n, k: Matrix dimensions (C = A @ B, A:[M,K] B:[K,N] C:[M,N])
+            dtype: Data type string ("bf16", "fp16", "fp8", "fp32")
+            trans_a: Whether A is transposed
+            trans_b: Whether B is transposed
+
+        Returns:
+            SimulationResult with forward_time_ms populated.
+        """
+        ...
+
+    def simulate_mlp_gemms(
+        self,
+        batch_tokens: int,
+        hidden_size: int,
+        ffn_hidden_size: int,
+        dtype: str = "bf16",
+        swiglu: bool = False,
+    ) -> SimulationResult:
+        """
+        Simulate the GEMM operations in a dense MLP (gate/up/down projections). 
+ + Default implementation calls ``simulate_gemm`` for each projection and + sums the times. Backends may override for better accuracy. + + Args: + batch_tokens: Number of tokens (batch_size * seq_len / TP / CP) + hidden_size: Model hidden dimension + ffn_hidden_size: FFN intermediate dimension + dtype: Data type string + swiglu: Whether SwiGLU activation is used (3 projections vs 2) + + Returns: + SimulationResult with forward_time_ms and backward_time_ms. + """ + fwd_time = 0.0 + bwd_time = 0.0 + + if swiglu: + # Gate projection: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] + gate_res = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype) + # Up projection: same shape + up_res = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype) + # Down projection: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] + down_res = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype) + + fwd_time = gate_res.forward_time_ms + up_res.forward_time_ms + down_res.forward_time_ms + # Backward is approximately 2x forward (dgrad + wgrad per projection) + bwd_time = fwd_time * 2.0 + else: + # Up projection + up_res = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype) + # Down projection + down_res = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype) + + fwd_time = up_res.forward_time_ms + down_res.forward_time_ms + bwd_time = fwd_time * 2.0 + + return SimulationResult(forward_time_ms=fwd_time, backward_time_ms=bwd_time) + + def simulate_attention_gemms( + self, + batch_tokens: int, + hidden_size: int, + num_attention_heads: int, + kv_channels: int, + num_query_groups: int, + dtype: str = "bf16", + ) -> SimulationResult: + """ + Simulate the linear projection GEMMs in the attention block + (QKV projections + output projection). Does NOT include the SDPA + computation itself – use SDPASimulationBackend for that. + + Default implementation calls ``simulate_gemm`` for Q, K, V, O projections. 
+ + Returns: + SimulationResult with forward_time_ms and backward_time_ms. + """ + fwd_time = 0.0 + + # Q projection: [tokens, hidden] x [hidden, heads*kv_channels] + q_out = num_attention_heads * kv_channels + q_res = self.simulate_gemm(batch_tokens, q_out, hidden_size, dtype) + fwd_time += q_res.forward_time_ms + + # K projection: [tokens, hidden] x [hidden, num_query_groups*kv_channels] + k_out = num_query_groups * kv_channels + k_res = self.simulate_gemm(batch_tokens, k_out, hidden_size, dtype) + fwd_time += k_res.forward_time_ms + + # V projection: same shape as K + v_res = self.simulate_gemm(batch_tokens, k_out, hidden_size, dtype) + fwd_time += v_res.forward_time_ms + + # Output projection: [tokens, heads*kv_channels] x [heads*kv_channels, hidden] + o_res = self.simulate_gemm(batch_tokens, hidden_size, q_out, dtype) + fwd_time += o_res.forward_time_ms + + bwd_time = fwd_time * 2.0 # dgrad + wgrad + + return SimulationResult(forward_time_ms=fwd_time, backward_time_ms=bwd_time) + + +class SDPASimulationBackend(ABC): + """Abstract interface for Scaled Dot-Product Attention simulation.""" + + @abstractmethod + def name(self) -> str: + """Return human-readable backend name.""" + ... + + @abstractmethod + def is_available(self) -> bool: + """Return True if this backend can be used in the current environment.""" + ... + + @abstractmethod + def simulate_sdpa( + self, + batch_size: int, + num_heads: int, + seq_len: int, + head_dim: int, + causal: bool = True, + dtype: str = "bf16", + seq_len_kv: Optional[int] = None, + num_heads_kv: Optional[int] = None, + ) -> SimulationResult: + """ + Simulate a Scaled Dot-Product Attention operation. + + Args: + batch_size: Batch size + num_heads: Number of query attention heads (per TP rank) + seq_len: Query sequence length (per CP rank) + head_dim: Head dimension (kv_channels) + causal: Whether causal masking is used + dtype: Data type string + seq_len_kv: Key/Value sequence length. Defaults to ``seq_len`` + (self-attention). 
+ num_heads_kv: Number of KV heads. Defaults to ``num_heads`` + (MHA). Set lower for GQA / MQA. + + Returns: + SimulationResult with forward_time_ms and backward_time_ms. + """ + ... diff --git a/primus/core/projection/simulation_backends/factory.py b/primus/core/projection/simulation_backends/factory.py new file mode 100644 index 000000000..277406db6 --- /dev/null +++ b/primus/core/projection/simulation_backends/factory.py @@ -0,0 +1,118 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +Factory functions for creating simulation backends. + +Backend selection for GEMM: + 1. If ``PRIMUS_GEMM_BACKEND`` is set, use that backend explicitly. + 2. Otherwise, use **origami** (the default, open-source backend). + +SDPA always uses the built-in analytical simulator. +""" + +import os +from typing import Optional + +from primus.core.projection.simulation_backends.base import ( + GEMMSimulationBackend, + SDPASimulationBackend, +) + + +def get_gemm_simulation_backend( + backend_name: Optional[str] = None, + gpu_arch: Optional[str] = None, + gpu_clock_mhz: Optional[int] = None, +) -> GEMMSimulationBackend: + """ + Create and return the GEMM simulation backend (origami). + + Args: + backend_name: Explicit backend name. Currently only "origami" is supported. + If None, defaults to origami. + gpu_arch: GPU architecture override (e.g. "gfx942", "mi300x", "mi325x"). + gpu_clock_mhz: Override the GPU compute clock frequency in MHz. + + Returns: + A GEMMSimulationBackend instance. + + Raises: + RuntimeError: If the backend is not available. 
+ """ + name = backend_name or os.getenv("PRIMUS_GEMM_BACKEND", None) + + if name is not None: + name = name.lower().strip() + + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + + if name is not None and name != "origami": + raise ValueError( + f"Unknown GEMM simulation backend: '{name}'. " + f"Supported backend: 'origami'" + ) + + from primus.core.projection.simulation_backends.origami_backend import ( + OrigamiGEMMBackend, + ) + + backend = OrigamiGEMMBackend(gpu_arch=gpu_arch, gpu_clock_mhz=gpu_clock_mhz) + if not backend.is_available(): + raise RuntimeError( + "Origami GEMM simulation backend is not available.\n" + "Install it with: pip install origami" + ) + + if is_rank_0: + print("[Primus:Simulation] Using GEMM backend: origami") + return backend + + +def get_sdpa_simulation_backend( + gpu_arch: Optional[str] = None, + compute_efficiency: float = 0.51, + memory_efficiency: float = 0.85, + gpu_clock_mhz: Optional[int] = None, +) -> SDPASimulationBackend: + """ + Create and return the SDPA simulation backend. + + The default backend is an analytical model of the FAv3 (Flash Attention v3) + kernels, with tile sizes, wavefront counts, and efficiency factors + derived from the kernel configurations. + + Args: + gpu_arch: GPU architecture override (e.g. "mi300x", "mi355x"). + compute_efficiency: Fraction of peak compute achieved (0-1). + Defaults to 0.51 — calibrated against measured FAv3 traces on + MI300X (B=3, H_Q=64, S=8192, D=128, H_KV=8, GQA, causal, BF16). + The lower-than-theoretical efficiency accounts for GQA head + broadcasting overhead, LDS bank conflicts, barrier synchronisation, + and register pressure. + memory_efficiency: Fraction of peak HBM bandwidth achieved (0-1). + Defaults to 0.85 — FAv3 streaming pattern typically achieves 0.80-0.90. + gpu_clock_mhz: Override the GPU compute clock frequency in MHz. + + Returns: + An SDPASimulationBackend instance. 
+ """ + from primus.core.projection.simulation_backends.sdpa_simulator import ( + SDPASimulator, + ) + + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + if is_rank_0: + print( + "[Primus:Simulation] Using SDPA backend: sdpa_simulator (FAv3 analytical model)" + ) + + return SDPASimulator( + gpu_arch=gpu_arch, + compute_efficiency=compute_efficiency, + memory_efficiency=memory_efficiency, + gpu_clock_mhz=gpu_clock_mhz, + ) diff --git a/primus/core/projection/simulation_backends/origami_backend.py b/primus/core/projection/simulation_backends/origami_backend.py new file mode 100644 index 000000000..5d344fdd4 --- /dev/null +++ b/primus/core/projection/simulation_backends/origami_backend.py @@ -0,0 +1,437 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +Origami GEMM simulation backend. + +Origami is an open-source analytical performance model for GEMM kernels on +AMD GPUs (part of the ROCm ecosystem). It predicts kernel execution time +based on matrix dimensions, data type, tile configuration, and hardware +characteristics. + +This is the **default** backend for GEMM simulation in Primus performance +projection. + +Installation: + pip install git+https://github.com/ROCm/rocm-libraries.git#subdirectory=shared/origami/python + +Environment variables: + PRIMUS_GEMM_BACKEND – set to "origami" (or leave unset) to use this backend. + PRIMUS_GPU_ARCH – GPU architecture override (e.g. "gfx942", "gfx950"). + PRIMUS_GPU_DEVICE – GPU device index for hardware detection (default: 0). 
+""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +from primus.core.projection.simulation_backends.base import ( + GEMMSimulationBackend, + SimulationResult, +) + +# --------------------------------------------------------------------------- +# Lazy import – we don't want to fail at module-import time. +# --------------------------------------------------------------------------- +_origami = None +_origami_available: Optional[bool] = None + + +def _try_import_origami(): + """Try to import origami and cache the result.""" + global _origami, _origami_available + if _origami_available is not None: + return _origami_available + + try: + import origami # type: ignore[import-untyped] + + _origami = origami + _origami_available = True + except ImportError: + _origami = None + _origami_available = False + + return _origami_available + + +# --------------------------------------------------------------------------- +# Known hardware profiles for GPU-less simulation via get_hardware_for_arch. 
+# --------------------------------------------------------------------------- +@dataclass +class _HardwareProfile: + """Parameters required by ``origami.get_hardware_for_arch``.""" + + arch_enum_name: str # attribute name on ``origami.architecture_t`` + n_cu: int + lds_capacity: int # bytes + l2_capacity: int # bytes (per XCD) + compute_clock_khz: int + + +_KNOWN_PROFILES: Dict[str, _HardwareProfile] = { + # MI300X / gfx942 + "mi300x": _HardwareProfile("gfx942", 304, 65536, 4_194_304, 2_100_000), + "gfx942": _HardwareProfile("gfx942", 304, 65536, 4_194_304, 2_100_000), + # MI325X / gfx942 (same die as MI300X, HBM3E upgrade) + "mi325x": _HardwareProfile("gfx942", 304, 65536, 4_194_304, 2_100_000), + # MI355X / gfx950 + "mi355x": _HardwareProfile("gfx950", 256, 65536, 4_194_304, 2_100_000), + "gfx950": _HardwareProfile("gfx950", 256, 65536, 4_194_304, 2_100_000), + # MI300A + "mi300a": _HardwareProfile("gfx942", 228, 65536, 4_194_304, 2_100_000), +} + +# --------------------------------------------------------------------------- +# Dtype mapping: Primus string → origami datatype string +# (origami.string_to_datatype accepts these short-hand names) +# --------------------------------------------------------------------------- +_DTYPE_MAP: Dict[str, str] = { + "bf16": "bf16", + "fp16": "f16", + "fp32": "f32", + "fp8": "bf8_fnuz", +} + +# --------------------------------------------------------------------------- +# Default candidate tile configurations (macro-tile M×N×K + occupancy). +# A wider search space yields better latency predictions at the cost of +# slightly longer selection time (still << 1 ms per GEMM). 
+# --------------------------------------------------------------------------- +_DEFAULT_TILE_SIZES: List[Tuple[int, int, int]] = [ + (64, 64, 32), + (64, 64, 64), + (64, 128, 32), + (64, 128, 64), + (128, 64, 32), + (128, 64, 64), + (128, 128, 32), + (128, 128, 64), + (128, 256, 32), + (128, 256, 64), + (256, 128, 32), + (256, 128, 64), + (256, 256, 32), + (256, 256, 64), +] +_DEFAULT_OCCUPANCIES: List[int] = [1, 2, 4] + + +class OrigamiGEMMBackend(GEMMSimulationBackend): + """ + GEMM simulation backend using Origami (open-source). + + Hardware is obtained in one of two ways (in priority order): + + 1. **From the local GPU** – ``origami.get_hardware_for_device(device_idx)`` + is called when a ROCm-capable GPU is present. The device index defaults + to 0 and can be overridden via the ``PRIMUS_GPU_DEVICE`` env var. + 2. **From a known profile** – when no GPU is available *and* a + ``--gpu-arch`` / ``PRIMUS_GPU_ARCH`` is provided, the backend falls back + to ``origami.get_hardware_for_arch`` with hard-coded parameters for + MI300X, MI355X, etc. + """ + + def __init__( + self, + gpu_arch: Optional[str] = None, + gpu_clock_mhz: Optional[int] = None, + ): + """ + Args: + gpu_arch: Target GPU architecture string (e.g. "gfx942", "mi300x"). + If *None*, auto-detected from the current GPU or the + ``PRIMUS_GPU_ARCH`` env var. + gpu_clock_mhz: Override the compute clock frequency in MHz. + If *None*, uses the profile default or the + ``PRIMUS_GPU_CLOCK_MHZ`` env var. + """ + self._gpu_arch = gpu_arch or os.getenv("PRIMUS_GPU_ARCH", None) + if self._gpu_arch is not None: + self._gpu_arch = self._gpu_arch.lower().strip() + + # Clock override: CLI > env var > profile default + _env_clock = os.getenv("PRIMUS_GPU_CLOCK_MHZ", None) + self._clock_override_mhz: Optional[int] = gpu_clock_mhz or ( + int(_env_clock) if _env_clock else None + ) + + # Lazily initialised origami objects – see ``_ensure_initialized``. 
+ self._hardware = None # origami.hardware_t + self._configs = None # list[origami.config_t] + self._clock_ghz: Optional[float] = None + self._initialized = False + self._init_dtype: Optional[str] = None # tracks dtype used for config init + self._fp8_mi_unavailable = False # set True if FP8 MI is 0x0x0 + + # ------------------------------------------------------------------ + # GEMMSimulationBackend interface + # ------------------------------------------------------------------ + + def name(self) -> str: + return "origami" + + def is_available(self) -> bool: + return _try_import_origami() + + def simulate_gemm( + self, + m: int, + n: int, + k: int, + dtype: str = "bf16", + trans_a: bool = False, + trans_b: bool = False, + batch: int = 1, + ) -> SimulationResult: + if not self.is_available(): + raise RuntimeError( + "Origami is not installed. Install with:\n" + " pip install git+https://github.com/ROCm/rocm-libraries.git" + "#subdirectory=shared/origami/python" + ) + + # FP8 fallback: if Origami doesn't support FP8 MI for this arch, + # simulate as BF16. On MI300X, Origami BF16 predictions already + # closely match FP8-hybrid measured performance (the natural model + # overestimation roughly offsets the FP8 compute gain). 
+ fp8_fallback = False + sim_dtype = dtype + if dtype == "fp8": + self._ensure_initialized("fp8") + if self._fp8_mi_unavailable: + sim_dtype = "bf16" + fp8_fallback = True + self._ensure_initialized("bf16") + # else: origami supports FP8 natively, proceed normally + else: + self._ensure_initialized(dtype) + + # ----- Build origami problem_t ----- + problem = _origami.problem_t() + problem.size = _origami.dim3_t(m, n, k) + problem.batch = batch + problem.a_transpose = ( + _origami.transpose_t.T if trans_a else _origami.transpose_t.N + ) + problem.b_transpose = ( + _origami.transpose_t.T if trans_b else _origami.transpose_t.N + ) + + origami_dtype = _origami.string_to_datatype(_DTYPE_MAP.get(sim_dtype, "bf16")) + problem.a_dtype = origami_dtype + problem.b_dtype = origami_dtype + problem.c_dtype = origami_dtype + problem.d_dtype = origami_dtype + problem.mi_dtype = origami_dtype + problem.a_mx_block_size = 0 + problem.b_mx_block_size = 0 + + # ----- Select best config & predict latency (in clock cycles) ----- + try: + result = _origami.select_config(problem, self._hardware, self._configs) + except Exception as e: + raise RuntimeError( + f"Origami select_config failed for " + f"(M={m}, N={n}, K={k}, dtype={dtype}): {e}" + ) from e + + latency_cycles = result.latency + time_ms = latency_cycles / (self._clock_ghz * 1e6) + + # FP8 speedup: when Origami can't natively simulate FP8 (no FP8 MI) + # we ran the simulation in BF16. FP8 has 2x the throughput of BF16 + # on gfx942, so divide the time by 2. 
+ if fp8_fallback: + time_ms /= 2.0 + + # Compute achieved TFLOPS for metadata + flops = 2.0 * m * n * k * batch + time_s = time_ms / 1e3 + tflops = (flops / time_s / 1e12) if time_s > 0 else 0.0 + + return SimulationResult( + forward_time_ms=time_ms, + backward_time_ms=0.0, # Caller computes bwd from fwd + tflops=tflops, + metadata={ + "backend": "origami", + "gpu_arch": self._gpu_arch, + "latency_cycles": latency_cycles, + "clock_ghz": self._clock_ghz, + "dtype": dtype, + "batch": batch, + "fp8_fallback": fp8_fallback, + "best_tile": ( + result.config.mt.m, + result.config.mt.n, + result.config.mt.k, + ), + "best_occupancy": result.config.occupancy, + }, + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _ensure_initialized(self, dtype: str = "bf16") -> None: + """Lazily initialise hardware and candidate configs. + + If called again with a *different* dtype the candidate config list is + rebuilt (the matrix-instruction size changes between BF16 and FP8). + """ + # Hardware only needs to be detected once. + if self._hardware is None: + self._hardware = self._get_hardware() + self._clock_ghz = self._hardware.compute_clock_ghz + + # (Re-)build candidate configs when the dtype changes. 
+ if self._initialized and self._init_dtype == dtype: + return + + origami_str = _DTYPE_MAP.get(dtype, "bf16") + try: + origami_dtype = _origami.string_to_datatype(origami_str) + except (ValueError, RuntimeError): + # Origami doesn't know this dtype string — treat as unavailable + if dtype == "fp8": + self._fp8_mi_unavailable = True + return + + mi = self._hardware.get_recommended_matrix_instruction(origami_dtype) + + # Detect MI=0x0x0 (datatype not supported by this arch) + if mi.m == 0 and mi.n == 0 and mi.k == 0: + if dtype == "fp8": + self._fp8_mi_unavailable = True + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + if is_rank_0: + print( + f"[Primus:Origami] FP8 matrix instruction not available " + f"for this hardware; will use BF16 with 2x speedup factor" + ) + return + + configs: list = [] + for mt_m, mt_n, mt_k in _DEFAULT_TILE_SIZES: + for occ in _DEFAULT_OCCUPANCIES: + cfg = _origami.config_t() + cfg.mt = _origami.dim3_t(mt_m, mt_n, mt_k) + cfg.mi = mi + cfg.occupancy = occ + configs.append(cfg) + self._configs = configs + self._init_dtype = dtype + + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + if is_rank_0: + print( + f"[Primus:Origami] Initialised: " + f"N_CU={self._hardware.N_CU}, " + f"NUM_XCD={self._hardware.NUM_XCD}, " + f"clock={self._clock_ghz} GHz, " + f"MI={mi.m}x{mi.n}x{mi.k}, " + f"dtype={dtype}, " + f"{len(configs)} candidate configs" + ) + + self._initialized = True + + def _get_hardware(self): + """ + Obtain an ``origami.hardware_t`` instance. + + Priority: + 1. If ``--gpu-arch`` was explicitly provided AND we have a known + profile for it, use the profile. This ensures consistent + results regardless of the local GPU (important for simulation + mode targeting a *different* GPU, e.g. simulating MI325X on + MI300X). + 2. Otherwise, try the local GPU. + 3. Fall back to the arch profile if no GPU is available. + """ + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + + # 1. If an explicit arch was requested AND we have a profile, use it. 
+ if self._gpu_arch is not None and self._gpu_arch in _KNOWN_PROFILES: + profile = _KNOWN_PROFILES[self._gpu_arch] + clock_khz = profile.compute_clock_khz + if self._clock_override_mhz is not None: + clock_khz = self._clock_override_mhz * 1000 + arch_enum = getattr(_origami.architecture_t, profile.arch_enum_name) + hw = _origami.get_hardware_for_arch( + arch_enum, + profile.n_cu, + profile.lds_capacity, + profile.l2_capacity, + clock_khz, + ) + if is_rank_0: + override_tag = "" + if self._clock_override_mhz is not None: + override_tag = " (overridden via --gpu-clock-mhz)" + print( + f"[Primus:Origami] Using hardware profile for " + f"'{self._gpu_arch}': N_CU={profile.n_cu}, " + f"clock={clock_khz / 1e6:.1f} GHz{override_tag}" + ) + return hw + + # 2. Try local GPU + device_idx = int(os.getenv("PRIMUS_GPU_DEVICE", "0")) + try: + hw = _origami.get_hardware_for_device(device_idx) + if is_rank_0: + print( + f"[Primus:Origami] Hardware detected from device {device_idx}: " + f"N_CU={hw.N_CU}, NUM_XCD={hw.NUM_XCD}, " + f"clock={hw.compute_clock_ghz} GHz" + ) + return hw + except Exception: + pass # No GPU – try arch-based profile + + # 3. Fall back to known profile + if self._gpu_arch is None: + raise RuntimeError( + "Origami could not detect a GPU and no --gpu-arch / " + "PRIMUS_GPU_ARCH was specified. Either run on a machine with " + "a ROCm GPU or provide a target architecture " + "(e.g. --gpu-arch mi300x)." + ) + + profile = _KNOWN_PROFILES.get(self._gpu_arch) + if profile is None: + supported = ", ".join(sorted(_KNOWN_PROFILES.keys())) + raise RuntimeError( + f"Unknown GPU architecture '{self._gpu_arch}' for origami. 
" + f"Supported architectures: {supported}" + ) + + clock_khz = profile.compute_clock_khz + if self._clock_override_mhz is not None: + clock_khz = self._clock_override_mhz * 1000 + arch_enum = getattr(_origami.architecture_t, profile.arch_enum_name) + hw = _origami.get_hardware_for_arch( + arch_enum, + profile.n_cu, + profile.lds_capacity, + profile.l2_capacity, + clock_khz, + ) + if is_rank_0: + override_tag = "" + if self._clock_override_mhz is not None: + override_tag = " (overridden via --gpu-clock-mhz)" + print( + f"[Primus:Origami] Using known hardware profile for " + f"'{self._gpu_arch}': N_CU={profile.n_cu}, " + f"clock={clock_khz / 1e6:.1f} GHz{override_tag}" + ) + return hw diff --git a/primus/core/projection/simulation_backends/sdpa_simulator.py b/primus/core/projection/simulation_backends/sdpa_simulator.py new file mode 100644 index 000000000..060335864 --- /dev/null +++ b/primus/core/projection/simulation_backends/sdpa_simulator.py @@ -0,0 +1,543 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +SDPA simulation backend modelling the **FAv3** (Flash Attention v3) kernels. 
+ +The forward and backward kernel parameters are extracted from FAv3 kernel +configurations: + + Forward: + Config : BF16 ; FMHA FWD ; D128 ; 1TG ; 8W ; 32m×8 ; 64n×1 ; 32×32×16 + • 1 Thread-Group, 8 Wavefronts per workgroup (512 threads) + • Q-tile = 256 rows (32m × 8 wavefronts) + • KV-tile = 64 columns per loop iteration + • MFMA instruction: ``v_mfma_f32_32x32x16_bf16`` + • 64 MFMAs per loop iteration (QKᵀ + softmax + PV, pipelined) + • Workgroups = ⌈S / 256⌉ × B × H + + Backward: + Config : BF16 ; FMHA BWD ; D128 ; 1TG ; 4W ; 16m×1 ; 64n×4 ; A32 + • 1 Thread-Group, 4 Wavefronts per workgroup (256 threads) + • Q-tile = 16 rows per inner loop step + • KV-tile = 256 columns per workgroup (64n × 4) + • 256 MFMAs per inner-loop iteration (dV, dP, dS, dQ, dK phases) + • Workgroups = ⌈S / 256⌉ × B × H + • Inner-loop iterations = ⌈S / 16⌉ (over Q blocks) + +The model uses a **roofline** approach: + time = max(compute_time, memory_time, atomic_time) +with FAv3-specific compute/memory efficiency factors and CU utilisation +derived from the tile sizes. + +In the backward pass, the dQ gradient is accumulated across KV-workgroups +using ``buffer_atomic_add_f32`` (72 atomic instructions in the kernel). +Each KV-workgroup processes all Q positions and atomically adds its partial +dQ contribution, leading to contention proportional to ⌈S / 256⌉ concurrent +writers per dQ cache line. The atomic overhead is modelled as a separate +bottleneck dimension in the roofline. 
+""" + +from __future__ import annotations + +import math +import os +from dataclasses import dataclass +from typing import Dict, Optional + +from primus.core.projection.simulation_backends.base import ( + SDPASimulationBackend, + SimulationResult, +) + +# ========================================================================= +# FAv3 kernel tile parameters +# ========================================================================= + + +@dataclass(frozen=True) +class _FAv3TileConfig: + """Tile & occupancy parameters extracted from a FAv3 kernel.""" + + q_tile_m: int # Q rows per workgroup + kv_tile_n: int # K/V positions per loop iteration + n_wavefronts: int # Wavefronts per workgroup + mfma_m: int = 32 # MFMA instruction M + mfma_n: int = 32 # MFMA instruction N + mfma_k: int = 16 # MFMA instruction K (BF16 on gfx950) + + +# Forward: 256 Q-rows, 64 KV-cols/iter, 8 wavefronts +_FAV3_FWD = _FAv3TileConfig(q_tile_m=256, kv_tile_n=64, n_wavefronts=8) + +# Backward: 16 Q-rows/inner-iter, 256 KV-cols/workgroup, 4 wavefronts +_FAV3_BWD = _FAv3TileConfig(q_tile_m=16, kv_tile_n=256, n_wavefronts=4) + + +# ========================================================================= +# GPU hardware specs +# ========================================================================= + + +@dataclass +class GPUHardwareSpec: + """Hardware specification for roofline modelling.""" + + # Peak compute throughput in TFLOPS (tera floating-point ops / sec) + peak_tflops_bf16: float = 1307.0 # MI300X BF16 peak + peak_tflops_fp16: float = 1307.0 + peak_tflops_fp8: float = 2614.0 + + # HBM bandwidth in GB/s + hbm_bandwidth_gbps: float = 5300.0 # MI300X HBM3 + + # Total CUs on the device + n_cu: int = 304 # MI300X + + # Max wavefronts per CU (SIMD occupancy limit) + max_waves_per_cu: int = 8 + + # Number of XCDs on the device (cross-die atomics are more expensive) + n_xcd: int = 8 # MI300X has 8 XCDs + + +# Pre-defined hardware profiles +_HW_PROFILES: Dict[str, GPUHardwareSpec] = { + 
"mi300x": GPUHardwareSpec( + peak_tflops_bf16=1307.0, + peak_tflops_fp16=1307.0, + peak_tflops_fp8=2614.0, + hbm_bandwidth_gbps=5300.0, + n_cu=304, + n_xcd=8, + ), + "gfx942": GPUHardwareSpec( # same as MI300X + peak_tflops_bf16=1307.0, + peak_tflops_fp16=1307.0, + peak_tflops_fp8=2614.0, + hbm_bandwidth_gbps=5300.0, + n_cu=304, + n_xcd=8, + ), + "mi325x": GPUHardwareSpec( # gfx942 die, HBM3E (use --gpu-clock-mhz to override clock) + peak_tflops_bf16=1307.0, + peak_tflops_fp16=1307.0, + peak_tflops_fp8=2614.0, + hbm_bandwidth_gbps=6000.0, # HBM3E ~6 TB/s (vs 5.3 on MI300X) + n_cu=304, + n_xcd=8, + ), + "mi355x": GPUHardwareSpec( + peak_tflops_bf16=2384.0, + peak_tflops_fp16=2384.0, + peak_tflops_fp8=4768.0, + hbm_bandwidth_gbps=8000.0, + n_cu=256, + n_xcd=4, + ), + "gfx950": GPUHardwareSpec( # same as MI355X + peak_tflops_bf16=2384.0, + peak_tflops_fp16=2384.0, + peak_tflops_fp8=4768.0, + hbm_bandwidth_gbps=8000.0, + n_cu=256, + n_xcd=4, + ), +} + + +def _get_hardware_spec( + gpu_arch: Optional[str] = None, + gpu_clock_mhz: Optional[int] = None, +) -> GPUHardwareSpec: + """Get hardware spec for the given (or detected) GPU architecture. + + If *gpu_clock_mhz* is provided, the profile's TFLOPS values are scaled + proportionally (compute throughput is linear in clock frequency). + """ + arch = gpu_arch or os.getenv("PRIMUS_GPU_ARCH", "mi300x") + arch = arch.lower().strip() + spec = _HW_PROFILES.get(arch, _HW_PROFILES["mi300x"]) + + # Apply clock override — scale TFLOPS linearly + clock_override = gpu_clock_mhz or ( + int(v) if (v := os.getenv("PRIMUS_GPU_CLOCK_MHZ")) else None + ) + if clock_override is not None: + # Derive the profile's implicit clock from a known reference. 
+        _PROFILE_CLOCK_MHZ = {
+            "mi300x": 2100,
+            "gfx942": 2100,
+            "mi325x": 1200,  # NOTE(review): origami's MI325X profile uses 2100 MHz and the shared 1307-TFLOPS rating implies ~2.1 GHz — confirm this base clock
+            "mi355x": 2100,
+            "gfx950": 2100,
+            "mi300a": 2100,
+        }
+        base_clock = _PROFILE_CLOCK_MHZ.get(arch, 2100)
+        scale = clock_override / base_clock
+        spec = GPUHardwareSpec(
+            peak_tflops_bf16=spec.peak_tflops_bf16 * scale,
+            peak_tflops_fp16=spec.peak_tflops_fp16 * scale,
+            peak_tflops_fp8=spec.peak_tflops_fp8 * scale,
+            hbm_bandwidth_gbps=spec.hbm_bandwidth_gbps,  # BW doesn't change with clock
+            n_cu=spec.n_cu,
+            n_xcd=spec.n_xcd,
+        )
+    return spec
+
+
+# =========================================================================
+# SDPASimulator — FAv3-based analytical model
+# =========================================================================
+
+
+class SDPASimulator(SDPASimulationBackend):
+    """
+    Analytical SDPA simulation modelling the FAv3 kernel structure.
+
+    The model captures:
+      1. **Total FLOPs** from the SDPA math (QKᵀ, softmax, PV for fwd;
+         dV, dP/dS, dQ, dK, softmax-bwd for bwd).
+      2. **Flash-Attention memory IO** — Q/K/V are streamed from HBM;
+         the full S/P matrices are never materialised.
+      3. **CU utilisation** — derived from the FAv3 tile sizes and the
+         number of workgroups that can execute concurrently.
+      4. **Achieved efficiency** — higher than generic kernels because
+         FAv3 is hand-tuned ISA with software pipelining and LDS-based
+         data movement.
+      5. **Atomic overhead (BWD only)** — dQ is accumulated across
+         KV-workgroups via ``buffer_atomic_add_f32`` in FP32. The model
+         accounts for the read-modify-write penalty and contention from
+         ⌈S / 256⌉ concurrent writers per dQ cache line.
+    """
+
+    def __init__(
+        self,
+        gpu_arch: Optional[str] = None,
+        hardware_spec: Optional[GPUHardwareSpec] = None,
+        compute_efficiency: float = 0.51,
+        memory_efficiency: float = 0.85,
+        atomic_rmw_factor: float = 4.0,
+        gpu_clock_mhz: Optional[int] = None,
+    ):
+        """
+        Args:
+            gpu_arch: GPU architecture string (e.g. "mi300x", "gfx942",
+                "mi355x", "gfx950").
+ hardware_spec: Override hardware spec directly. + compute_efficiency: Fraction of peak TFLOPS achieved (0-1). + Calibrated against measured FAv3 traces on MI300X: + * Measured FA fwd = 5.05 ms, bwd = 10.00 ms + (B=3, H_Q=64, S=8192, D=128, H_KV=8, causal, BF16) + * 0.51 matches measured within 1%. + The lower-than-peak efficiency (vs theoretical 0.75-0.85) + accounts for GQA head broadcasting, LDS bank conflicts, + barrier synchronisation, and register pressure. + memory_efficiency: Fraction of peak HBM bandwidth achieved (0-1). + FAv3 streaming pattern typically achieves 0.80-0.90. + atomic_rmw_factor: Base slowdown of ``buffer_atomic_add_f32`` + relative to a plain ``buffer_store`` (read-modify-write + overhead). Typical range 3-6 on CDNA3. Contention from + multiple writers is modelled *on top* of this factor. + gpu_clock_mhz: Override the GPU compute clock frequency in MHz. + If provided, the profile's TFLOPS are scaled proportionally. + """ + self._hw = hardware_spec or _get_hardware_spec(gpu_arch, gpu_clock_mhz) + self._compute_eff = compute_efficiency + self._memory_eff = memory_efficiency + self._atomic_rmw_factor = atomic_rmw_factor + + def name(self) -> str: + return "sdpa_simulator (FAv3)" + + def is_available(self) -> bool: + return True # Pure-Python analytical model, always available + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def simulate_sdpa( + self, + batch_size: int, + num_heads: int, + seq_len: int, + head_dim: int, + causal: bool = True, + dtype: str = "bf16", + seq_len_kv: Optional[int] = None, + num_heads_kv: Optional[int] = None, + ) -> SimulationResult: + """ + Simulate FAv3 SDPA execution time using a roofline model + parameterised by the actual FAv3 tile configuration. + + Args: + batch_size: Batch size (B). + num_heads: Number of query heads (H_Q). + seq_len: Query sequence length (S_Q). + head_dim: Head dimension (D). 
+ causal: Whether causal masking is applied. + dtype: Data type ("bf16", "fp16", "fp8", "fp32"). + seq_len_kv: Key/Value sequence length (S_K). Defaults to + ``seq_len`` (self-attention). Set differently for + cross-attention or prefill with separate KV cache length. + num_heads_kv: Number of KV heads. Defaults to ``num_heads`` + (MHA). Set lower for GQA/MQA. + """ + B = batch_size + H_Q = num_heads + S_Q = seq_len + S_K = seq_len_kv if seq_len_kv is not None else seq_len + H_K = num_heads_kv if num_heads_kv is not None else num_heads + D = head_dim + bpe = self._bytes_per_element(dtype) + + # GQA ratio: each KV head serves (H_Q / H_K) query heads. + # The FLOPs are still per-query-head, so total FLOPs scale with H_Q. + # Memory for K/V scales with H_K, memory for Q/O scales with H_Q. + + causal_factor = 0.5 if causal else 1.0 + + # ============================================================== + # 1. COMPUTE (FLOP counts) + # ============================================================== + # Forward (per query head, then × H_Q) + # QKᵀ : 2·B·H_Q·S_Q·S_K·D (batched GEMM) + # softmax : ~5·B·H_Q·S_Q·S_K (exp, sub-max, sum, div, mul) + # PV : 2·B·H_Q·S_Q·S_K·D (batched GEMM — P is S_Q×S_K, V is S_K×D) + # NOTE for PV: output is (S_Q, D), inner dim is S_K. + # For causal masking, only ~half the S_Q×S_K elements are computed + # (only valid when S_Q == S_K; for cross-attn causal is usually False). 
+ fwd_gemm_flops = 2.0 * (2.0 * B * H_Q * S_Q * S_K * D) * causal_factor
+ fwd_softmax_flops = 5.0 * B * H_Q * S_Q * S_K * causal_factor
+ fwd_flops = fwd_gemm_flops + fwd_softmax_flops
+
+ # Backward (4 batched GEMMs + softmax backward)
+ # dV = Pᵀ @ dO : 2·B·H_Q·S_K·S_Q → (S_K, D) inner dim S_Q
+ # dP = dO @ Vᵀ : 2·B·H_Q·S_Q·D → (S_Q, S_K) inner dim D
+ # dS = softmax_bwd : ~5·B·H_Q·S_Q·S_K
+ # dQ = dS @ K : 2·B·H_Q·S_Q·S_K → (S_Q, D) inner dim S_K
+ # dK = dSᵀ @ Q : 2·B·H_Q·S_K·S_Q → (S_K, D) inner dim S_Q
+ bwd_gemm_flops = 2.0 * (4.0 * B * H_Q * S_Q * S_K * D) * causal_factor
+ bwd_softmax_flops = 5.0 * B * H_Q * S_Q * S_K * causal_factor
+ bwd_flops = bwd_gemm_flops + bwd_softmax_flops
+
+ # ==============================================================
+ # 2. MEMORY IO (Flash Attention – no S/P materialised to HBM)
+ # ==============================================================
+ # Forward reads: Q (B·H_Q·S_Q·D), K (B·H_K·S_K·D), V (B·H_K·S_K·D)
+ # Forward writes: O (B·H_Q·S_Q·D) + logsumexp (B·H_Q·S_Q, fp32)
+ fwd_read_bytes = (
+ B * H_Q * S_Q * D * bpe # Q
+ + B * H_K * S_K * D * bpe # K
+ + B * H_K * S_K * D * bpe # V
+ )
+ fwd_write_bytes = (
+ B * H_Q * S_Q * D * bpe + B * H_Q * S_Q * 4 # O + logsumexp (fp32)
+ )
+ fwd_bytes = fwd_read_bytes + fwd_write_bytes
+
+ # Backward reads: Q, K, V, O, dO + logsumexp
+ # Backward regular writes: dK (B·H_K·S_K·D) + dV (B·H_K·S_K·D)
+ # NOTE: dQ uses buffer_atomic_add_f32 — accounted separately.
+ bwd_read_bytes = (
+ B * H_Q * S_Q * D * bpe # Q
+ + B * H_K * S_K * D * bpe # K
+ + B * H_K * S_K * D * bpe # V
+ + B * H_Q * S_Q * D * bpe # O
+ + B * H_Q * S_Q * D * bpe # dO
+ + B * H_Q * S_Q * 4 # logsumexp (fp32)
+ )
+ bwd_regular_write_bytes = (
+ B * H_K * S_K * D * bpe + B * H_K * S_K * D * bpe # dK + dV
+ )
+ bwd_bytes = bwd_read_bytes + bwd_regular_write_bytes
+
+ # ==============================================================
+ # 3. 
dQ ATOMIC OVERHEAD (BWD only) + # ============================================================== + # In FAv3 backward, each KV-workgroup loops over ALL Q positions + # and atomically accumulates its partial dQ via buffer_atomic_add_f32. + # + # From the FAv3 backward kernel: + # - 72 buffer_atomic_add_f32 instructions in the kernel + # - 8 atomics per Q-block (per wavefront, 64 threads each) + # - 4 wavefronts per workgroup + # - Per Q-block: 8 × 64 × 4W = 2048 atomic ops = 8 KB (FP32) + # = 16 rows × 128 cols × 4 bytes = 8192 bytes ✓ + # + # Contention & L2 coalescing: + # ceil(S_K/256) KV-workgroups all write to the same dQ rows. + # Workgroups on the SAME XCD can coalesce their atomics in the + # local L2 cache (the add is accumulated in L2, only the final + # value is flushed to HBM). So the effective number of HBM + # atomic writes per dQ element is min(n_kv_wgs, n_xcd) rather + # than the full n_kv_wgs. + # + # Each HBM atomic write is a read-modify-write, which costs + # ~rmw_factor × the bandwidth of a regular store. + n_kv_workgroups = math.ceil(S_K / _FAV3_BWD.kv_tile_n) + + # How many KV-workgroups per XCD (for L2 coalescing estimate) + hbm_writers_per_element = min(n_kv_workgroups, self._hw.n_xcd) + + # Effective dQ bytes hitting HBM (after L2 coalescing) + # dQ shape is (B, H_Q, S_Q, D), stored in FP32 (4 bytes) + dq_atomic_bytes = float(hbm_writers_per_element) * B * H_Q * S_Q * D * 4.0 + + # Atomic slowdown = just the RMW factor (contention within-XCD + # is absorbed by L2; cross-XCD traffic goes to different memory + # channels and can proceed in parallel) + atomic_slowdown = self._atomic_rmw_factor + + # ============================================================== + # 4. 
CU UTILISATION (from FAv3 tile config) + # ============================================================== + fwd_cu_util = self._cu_utilisation(B, H_Q, S_Q, _FAV3_FWD) + bwd_cu_util = self._cu_utilisation(B, H_Q, S_K, _FAV3_BWD) + + # ============================================================== + # 5. ROOFLINE: time = max(compute, memory, atomics) + # ============================================================== + peak_tflops = self._peak_tflops(dtype) + + # Effective throughput = peak × efficiency × CU utilisation + fwd_eff_tflops = peak_tflops * self._compute_eff * fwd_cu_util + bwd_eff_tflops = peak_tflops * self._compute_eff * bwd_cu_util + + fwd_eff_bw = self._hw.hbm_bandwidth_gbps * self._memory_eff + bwd_eff_bw = self._hw.hbm_bandwidth_gbps * self._memory_eff + + # Effective atomic bandwidth (HBM BW reduced by RMW + contention) + bwd_eff_atomic_bw = ( + self._hw.hbm_bandwidth_gbps * self._memory_eff / atomic_slowdown + ) + + # Compute-bound time (ms) + fwd_compute_ms = (fwd_flops / (fwd_eff_tflops * 1e12)) * 1e3 + bwd_compute_ms = (bwd_flops / (bwd_eff_tflops * 1e12)) * 1e3 + + # Memory-bound time (ms) — regular (non-atomic) IO + fwd_memory_ms = (fwd_bytes / (fwd_eff_bw * 1e9)) * 1e3 + bwd_memory_ms = (bwd_bytes / (bwd_eff_bw * 1e9)) * 1e3 + + # Atomic-bound time (ms) — dQ accumulation via buffer_atomic_add_f32 + bwd_atomic_ms = (dq_atomic_bytes / (bwd_eff_atomic_bw * 1e9)) * 1e3 + + fwd_time_ms = max(fwd_compute_ms, fwd_memory_ms) + bwd_time_ms = max(bwd_compute_ms, bwd_memory_ms, bwd_atomic_ms) + + # Achieved metrics + fwd_achieved_tflops = ( + (fwd_flops / (fwd_time_ms * 1e-3)) / 1e12 if fwd_time_ms > 0 else 0 + ) + + # Determine what bounds each pass + bwd_bottleneck = "compute" + if bwd_atomic_ms >= bwd_compute_ms and bwd_atomic_ms >= bwd_memory_ms: + bwd_bottleneck = "atomic" + elif bwd_memory_ms >= bwd_compute_ms: + bwd_bottleneck = "memory" + + return SimulationResult( + forward_time_ms=fwd_time_ms, + backward_time_ms=bwd_time_ms, + 
tflops=fwd_achieved_tflops, + bandwidth_gbps=( + (fwd_bytes / (fwd_time_ms * 1e-3)) / 1e9 if fwd_time_ms > 0 else 0 + ), + metadata={ + "backend": "sdpa_simulator (FAv3)", + "fwd_compute_bound": fwd_compute_ms >= fwd_memory_ms, + "fwd_compute_ms": fwd_compute_ms, + "fwd_memory_ms": fwd_memory_ms, + "bwd_bottleneck": bwd_bottleneck, + "bwd_compute_ms": bwd_compute_ms, + "bwd_memory_ms": bwd_memory_ms, + "bwd_atomic_ms": bwd_atomic_ms, + "fwd_flops": fwd_flops, + "bwd_flops": bwd_flops, + "fwd_bytes": fwd_bytes, + "bwd_bytes": bwd_bytes, + "seq_len_q": S_Q, + "seq_len_kv": S_K, + "num_heads_q": H_Q, + "num_heads_kv": H_K, + # dQ atomic details (buffer_atomic_add_f32) + "bwd_dq_kv_workgroups": n_kv_workgroups, + "bwd_dq_hbm_writers_per_elem": hbm_writers_per_element, + "bwd_dq_atomic_hbm_bytes": dq_atomic_bytes, + "bwd_dq_rmw_factor": atomic_slowdown, + "bwd_eff_atomic_bw_gbps": bwd_eff_atomic_bw, + # CU utilisation + "fwd_cu_utilisation": fwd_cu_util, + "bwd_cu_utilisation": bwd_cu_util, + "causal": causal, + # FAv3 tile parameters + "fwd_q_tile_m": _FAV3_FWD.q_tile_m, + "fwd_kv_tile_n": _FAV3_FWD.kv_tile_n, + "fwd_wavefronts": _FAV3_FWD.n_wavefronts, + "bwd_q_tile_m": _FAV3_BWD.q_tile_m, + "bwd_kv_tile_n": _FAV3_BWD.kv_tile_n, + "bwd_wavefronts": _FAV3_BWD.n_wavefronts, + }, + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _bytes_per_element(self, dtype: str) -> int: + return {"bf16": 2, "fp16": 2, "fp32": 4, "fp8": 1}.get(dtype, 2) + + def _peak_tflops(self, dtype: str) -> float: + return { + "bf16": self._hw.peak_tflops_bf16, + "fp16": self._hw.peak_tflops_fp16, + "fp8": self._hw.peak_tflops_fp8, + "fp32": self._hw.peak_tflops_bf16 / 4, + }.get(dtype, self._hw.peak_tflops_bf16) + + def _cu_utilisation( + self, + batch_size: int, + num_heads: int, + seq_len: int, + tile_cfg: _FAv3TileConfig, + ) -> float: + """ + Estimate CU utilisation 
for a FAv3 kernel launch. + + FAv3 forward dispatches one workgroup per Q-tile per (batch, head). + FAv3 backward dispatches one workgroup per KV-tile per (batch, head). + + Each workgroup occupies ``n_wavefronts`` wavefront slots on a CU. + If the workgroup uses fewer than ``max_waves_per_cu`` wavefronts, + multiple workgroups *may* share a CU (higher occupancy). + + CU utilisation = min(active_CUs, N_CU) / N_CU + """ + # Number of workgroups + # For FWD: each wg handles q_tile_m rows → ceil(S / q_tile_m) wgs per (B,H) + # For BWD: each wg handles kv_tile_n cols → ceil(S / kv_tile_n) wgs per (B,H) + if tile_cfg is _FAV3_FWD: + n_tiles = math.ceil(seq_len / tile_cfg.q_tile_m) + else: + # BWD: workgroups over KV dimension + n_tiles = math.ceil(seq_len / tile_cfg.kv_tile_n) + + n_workgroups = n_tiles * batch_size * num_heads + + # How many workgroups can share a single CU? + wgs_per_cu = self._hw.max_waves_per_cu // tile_cfg.n_wavefronts + wgs_per_cu = max(wgs_per_cu, 1) + + # Effective CU slots + cu_slots = self._hw.n_cu * wgs_per_cu + active_slots = min(n_workgroups, cu_slots) + + return active_slots / cu_slots diff --git a/primus/core/projection/training_config.py b/primus/core/projection/training_config.py index 81c964f3d..641ec6566 100644 --- a/primus/core/projection/training_config.py +++ b/primus/core/projection/training_config.py @@ -63,6 +63,8 @@ class ModelConfig: moe_shared_expert_intermediate_size: int = 0 # Misc share_embeddings_and_output_weights: bool = False + # Precision – None means bf16, "hybrid" means FP8-hybrid (linear GEMMs in FP8) + fp8: str = None @dataclass @@ -94,7 +96,9 @@ def megatron_derive_default_args(args): if not hasattr(args, "data_parallel_size") or args.data_parallel_size is None: args.data_parallel_size = world_size // ( - args.tensor_model_parallel_size * args.pipeline_model_parallel_size * args.context_parallel_size + args.tensor_model_parallel_size + * args.pipeline_model_parallel_size + * args.context_parallel_size ) if not 
hasattr(args, "virtual_pipeline_model_parallel_size"): args.virtual_pipeline_model_parallel_size = None @@ -105,23 +109,31 @@ def megatron_derive_default_args(args): args.virtual_pipeline_model_parallel_size = 1 elif args.num_layers_per_virtual_pipeline_stage is not None: args.virtual_pipeline_model_parallel_size = args.num_layers // ( - args.num_layers_per_virtual_pipeline_stage * args.pipeline_model_parallel_size + args.num_layers_per_virtual_pipeline_stage + * args.pipeline_model_parallel_size ) - args.share_embeddings_and_output_weights = not args.untie_embeddings_and_output_weights + args.share_embeddings_and_output_weights = ( + not args.untie_embeddings_and_output_weights + ) if args.num_experts is None: args.moe_pattern = [0] * args.num_layers else: if isinstance(args.moe_layer_freq, int): - args.moe_pattern = [1 if (i % args.moe_layer_freq == 0) else 0 for i in range(args.num_layers)] + args.moe_pattern = [ + 1 if (i % args.moe_layer_freq == 0) else 0 + for i in range(args.num_layers) + ] elif isinstance(args.moe_layer_freq, list): args.moe_pattern = args.moe_layer_freq elif isinstance(args.moe_layer_freq, str): try: parsed = eval(args.moe_layer_freq) except Exception: - raise ValueError(f"Invalid moe_layer_freq format: {args.moe_layer_freq}") + raise ValueError( + f"Invalid moe_layer_freq format: {args.moe_layer_freq}" + ) # Handle case where eval returns an int (e.g., "1" -> 1 means all layers are MoE) if isinstance(parsed, int): @@ -130,14 +142,18 @@ def megatron_derive_default_args(args): args.moe_pattern = [1] * args.num_layers else: # Every Nth layer is MoE - args.moe_pattern = [1 if (i % parsed == 0) else 0 for i in range(args.num_layers)] + args.moe_pattern = [ + 1 if (i % parsed == 0) else 0 for i in range(args.num_layers) + ] elif isinstance(parsed, list): args.moe_pattern = parsed assert ( len(args.moe_pattern) == args.num_layers ), f"Invalid moe_layer_freq length: {len(args.moe_pattern)} (expected {args.num_layers})" else: - raise 
ValueError(f"Invalid moe_layer_freq format after eval: {type(parsed)}") + raise ValueError( + f"Invalid moe_layer_freq format after eval: {type(parsed)}" + ) # naming conversion args.sequence_length = args.seq_length From af0d6af558e7a6e26500186aaae43a6fe63d3d30 Mon Sep 17 00:00:00 2001 From: Anshu Raina Date: Thu, 19 Feb 2026 03:42:04 +0000 Subject: [PATCH 04/12] feat(projection): MLA support, explicit backward GEMM simulation, and MoE overhead modeling - Add Multi-Latent Attention (MLA) GEMM simulation with LoRA-factored Q and compressed KV projections (6 fwd + 12 bwd GEMMs) in attention profiler - Add MLA-aware SDPA simulation with split D_qk/D_v head dimensions - Replace 2x-forward backward approximation with explicit dgrad + wgrad GEMM simulation in attention, MLP, MoE, and output layer profilers - Add batched GEMM support (batch param) for Turbo grouped-GEMM modeling vs legacy sequential per-expert execution in MoE MLP profiler - Add router overhead, token permutation, and activation function overhead estimation in MoE MLP simulation - Add TP AllReduce and MoE All-to-All communication overhead estimation in transformer layer simulation mode - Switch EP MLP scaling to delta-based approach preserving profiled layer components (TP AR, A2A, norms) and EP-invariant routed compute model - Add enable_primus_turbo and use_turbo_grouped_mlp config flags --- .../projection/module_profilers/attention.py | 163 +++++++++++++-- .../projection/module_profilers/moe_mlp.py | 178 ++++++++++++++-- .../module_profilers/output_layer.py | 24 ++- .../module_profilers/transformer_layer.py | 135 +++++++++++- .../performance_projection/projection.py | 196 +++++++++++------- .../projection/simulation_backends/base.py | 160 ++++++++++---- .../projection/simulation_backends/factory.py | 8 +- .../simulation_backends/origami_backend.py | 4 + .../simulation_backends/sdpa_simulator.py | 78 ++++--- primus/core/projection/training_config.py | 4 + 10 files changed, 754 insertions(+), 196 
deletions(-) diff --git a/primus/core/projection/module_profilers/attention.py b/primus/core/projection/module_profilers/attention.py index 400139aff..b45bebb4c 100644 --- a/primus/core/projection/module_profilers/attention.py +++ b/primus/core/projection/module_profilers/attention.py @@ -159,6 +159,111 @@ def _num_query_groups() -> int: return tokens_per_rank * (activation_width + ln_width) * bytes_per_value + def _simulate_mla_gemms( + self, batch_tokens: int, dtype: str + ) -> tuple[float, float]: + """Simulate MLA (Multi-Latent Attention) projection GEMMs. + + MLA uses LoRA-factored Q and compressed KV projections instead of + standard Q/K/V projections: + Forward (6 GEMMs): Q_down, Q_up, KV_down, KV_up, RoPE_proj, O_proj + Backward (12 GEMMs): dgrad + wgrad for each of the 6 projections + """ + args = self.config.model_config + backend = self._gemm_backend + + hidden = args.hidden_size + heads = args.num_attention_heads + q_lora_rank = args.q_lora_rank + kv_lora_rank = args.kv_lora_rank + qk_head_dim = args.qk_head_dim + qk_pos_emb_head_dim = args.qk_pos_emb_head_dim + v_head_dim = args.v_head_dim + + fwd_time = 0.0 + bwd_time = 0.0 + T = batch_tokens + + # ---------- Forward ---------- + if q_lora_rank is not None: + # Q down-proj: [T, hidden] × [hidden, q_lora_rank] + q_down_out = q_lora_rank + r = backend.simulate_gemm(T, q_down_out, hidden, dtype) + fwd_time += r.forward_time_ms + # Q up-proj: [T, q_lora_rank] × [q_lora_rank, heads*(qk_hd+qk_pe_hd)] + q_up_out = heads * (qk_head_dim + qk_pos_emb_head_dim) + r = backend.simulate_gemm(T, q_up_out, q_lora_rank, dtype) + fwd_time += r.forward_time_ms + else: + # Direct Q projection (no LoRA): [T, hidden] × [hidden, heads*(qk_hd+qk_pe_hd)] + q_up_out = heads * (qk_head_dim + qk_pos_emb_head_dim) + r = backend.simulate_gemm(T, q_up_out, hidden, dtype) + fwd_time += r.forward_time_ms + + # KV down-proj: [T, hidden] × [hidden, kv_lora_rank] + kv_down_out = kv_lora_rank + r = backend.simulate_gemm(T, kv_down_out, 
hidden, dtype) + fwd_time += r.forward_time_ms + # KV up-proj: [T, kv_lora_rank] × [kv_lora_rank, heads*(qk_hd+v_hd)] + kv_up_out = heads * (qk_head_dim + v_head_dim) + r = backend.simulate_gemm(T, kv_up_out, kv_lora_rank, dtype) + fwd_time += r.forward_time_ms + + # RoPE positional embedding projection: [T, hidden] × [hidden, qk_pos_emb_head_dim] + r = backend.simulate_gemm(T, qk_pos_emb_head_dim, hidden, dtype) + fwd_time += r.forward_time_ms + + # Output projection: [T, heads*v_hd] × [heads*v_hd, hidden] + o_in = heads * v_head_dim + r = backend.simulate_gemm(T, hidden, o_in, dtype) + fwd_time += r.forward_time_ms + + # ---------- Backward (dgrad + wgrad for each projection) ---------- + if q_lora_rank is not None: + # Q down-proj dgrad: [T, q_down_out] × [q_down_out, hidden] → [T, hidden] + r = backend.simulate_gemm(T, hidden, q_down_out, dtype) + bwd_time += r.forward_time_ms + # Q down-proj wgrad: [hidden, T] × [T, q_down_out] → [hidden, q_down_out] + r = backend.simulate_gemm(hidden, q_down_out, T, dtype) + bwd_time += r.forward_time_ms + # Q up-proj dgrad: [T, q_up_out] × [q_up_out, q_lora_rank] → [T, q_lora_rank] + r = backend.simulate_gemm(T, q_lora_rank, q_up_out, dtype) + bwd_time += r.forward_time_ms + # Q up-proj wgrad: [q_lora_rank, T] × [T, q_up_out] → [q_lora_rank, q_up_out] + r = backend.simulate_gemm(q_lora_rank, q_up_out, T, dtype) + bwd_time += r.forward_time_ms + else: + # Direct Q dgrad + wgrad + r = backend.simulate_gemm(T, hidden, q_up_out, dtype) + bwd_time += r.forward_time_ms + r = backend.simulate_gemm(hidden, q_up_out, T, dtype) + bwd_time += r.forward_time_ms + + # KV down-proj dgrad + wgrad + r = backend.simulate_gemm(T, hidden, kv_down_out, dtype) + bwd_time += r.forward_time_ms + r = backend.simulate_gemm(hidden, kv_down_out, T, dtype) + bwd_time += r.forward_time_ms + # KV up-proj dgrad + wgrad + r = backend.simulate_gemm(T, kv_lora_rank, kv_up_out, dtype) + bwd_time += r.forward_time_ms + r = backend.simulate_gemm(kv_lora_rank, 
kv_up_out, T, dtype) + bwd_time += r.forward_time_ms + + # RoPE proj dgrad + wgrad + r = backend.simulate_gemm(T, hidden, qk_pos_emb_head_dim, dtype) + bwd_time += r.forward_time_ms + r = backend.simulate_gemm(hidden, qk_pos_emb_head_dim, T, dtype) + bwd_time += r.forward_time_ms + + # O proj dgrad + wgrad + r = backend.simulate_gemm(T, o_in, hidden, dtype) + bwd_time += r.forward_time_ms + r = backend.simulate_gemm(o_in, hidden, T, dtype) + bwd_time += r.forward_time_ms + + return fwd_time, bwd_time + def _get_simulated_results( self, batch_size: int, seq_len: int ) -> tuple[float, float, int]: @@ -174,36 +279,58 @@ def _get_simulated_results( fwd_time = 0.0 bwd_time = 0.0 - # 1. Simulate linear projection GEMMs (Q, K, V, O) using GEMM backend + # 1. Simulate linear projection GEMMs using GEMM backend if self._gemm_backend is not None: - num_query_groups = ( - args.num_query_groups - if args.group_query_attention and args.num_query_groups - else args.num_attention_heads - ) - # FP8-hybrid: linear projections (QKV, O) run in FP8 gemm_dtype = "fp8" if getattr(args, "fp8", None) else "bf16" - gemm_result = self._gemm_backend.simulate_attention_gemms( - batch_tokens=batch_tokens, - hidden_size=args.hidden_size, - num_attention_heads=args.num_attention_heads, - kv_channels=args.kv_channels, - num_query_groups=num_query_groups, - dtype=gemm_dtype, - ) - fwd_time += gemm_result.forward_time_ms - bwd_time += gemm_result.backward_time_ms + + if getattr(args, "multi_latent_attention", False): + # MLA: LoRA-factored Q and compressed KV projections + # 6 forward GEMMs + 12 backward GEMMs + mla_fwd, mla_bwd = self._simulate_mla_gemms( + batch_tokens, gemm_dtype + ) + fwd_time += mla_fwd + bwd_time += mla_bwd + else: + # Standard attention: Q, K, V, O projections + # 4 forward GEMMs + 8 backward GEMMs + num_query_groups = ( + args.num_query_groups + if args.group_query_attention and args.num_query_groups + else args.num_attention_heads + ) + gemm_result = 
self._gemm_backend.simulate_attention_gemms( + batch_tokens=batch_tokens, + hidden_size=args.hidden_size, + num_attention_heads=args.num_attention_heads, + kv_channels=args.kv_channels, + num_query_groups=num_query_groups, + dtype=gemm_dtype, + ) + fwd_time += gemm_result.forward_time_ms + bwd_time += gemm_result.backward_time_ms # 2. Simulate SDPA core computation using SDPA backend if self._sdpa_backend is not None: heads_per_rank = max(1, args.num_attention_heads // tp_size) + + if getattr(args, "multi_latent_attention", False): + # MLA: Q·Kᵀ uses qk_head_dim + qk_pos_emb_head_dim (e.g. 192), + # P·V uses v_head_dim (e.g. 128). + sdpa_head_dim = args.qk_head_dim + args.qk_pos_emb_head_dim + sdpa_head_dim_v = args.v_head_dim + else: + sdpa_head_dim = args.kv_channels + sdpa_head_dim_v = None # same as head_dim + sdpa_result = self._sdpa_backend.simulate_sdpa( batch_size=batch_size, num_heads=heads_per_rank, seq_len=slen_per_cp, - head_dim=args.kv_channels, + head_dim=sdpa_head_dim, causal=True, dtype="bf16", + head_dim_v=sdpa_head_dim_v, ) fwd_time += sdpa_result.forward_time_ms bwd_time += sdpa_result.backward_time_ms diff --git a/primus/core/projection/module_profilers/moe_mlp.py b/primus/core/projection/module_profilers/moe_mlp.py index f51fa1269..40e2d0946 100644 --- a/primus/core/projection/module_profilers/moe_mlp.py +++ b/primus/core/projection/module_profilers/moe_mlp.py @@ -4,6 +4,8 @@ # See LICENSE for license information. 
############################################################################### +import math +import os from typing import Optional from primus.core.projection.base_module_profiler import BaseModuleProfiler @@ -12,6 +14,10 @@ from .utils import benchmark_layer +# Memory-bandwidth constants for non-GEMM MoE overhead estimation +_PERMUTE_EFF_BW_GBPS = 300.0 # scatter/gather effective BW (~5-10% peak HBM) +_ACTIVATION_BW_GBPS = 3000.0 # sequential element-wise ops (~60% peak HBM) + class MoEMLPProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): @@ -112,43 +118,183 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: def _get_simulated_results( self, batch_size: int, seq_len: int ) -> tuple[float, float, int]: - """Get simulated results from the GEMM simulation backend for MoE MLP.""" + """Get simulated results from the GEMM simulation backend for MoE MLP. + + In addition to expert GEMM time, this method estimates several + components of MoE execution that the GEMM simulation alone misses: + + 1. **Router overhead** — gate linear projection + softmax/top-K. + 2. **Token permutation** — dispatch (scatter) and combine (gather) + memory traffic with random-access patterns. + 3. **Activation function** — SwiGLU / GELU element-wise overhead. + + **Grouped GEMM performance model selection**: + When ``enable_primus_turbo`` and ``use_turbo_grouped_mlp`` are both + ``True`` in the training config, the expert GEMMs are modelled using + Origami's *batched* GEMM path (``batch=num_local_experts``). Primus + Turbo's grouped-GEMM kernel achieves near-ideal batched execution, + so the batched model is an accurate proxy. + + Otherwise (legacy ``grouped_gemm`` package), each expert is simulated + independently (``batch=1``) and the result is scaled by the number of + local experts. This more closely reflects the sequential per-expert + execution of the legacy kernel. 
+ """ tp_size = self.config.model_parallel_config.tensor_model_parallel_size cp_size = self.config.model_parallel_config.context_model_parallel_size ep_size = self.config.model_parallel_config.expert_model_parallel_size + hidden_size = self.config.model_config.hidden_size batch_tokens = batch_size * seq_len // tp_size // cp_size - topk_tokens = batch_tokens * self.config.model_config.moe_router_topk + topk = self.config.model_config.moe_router_topk + topk_tokens = batch_tokens * topk if self.config.model_config.moe_ffn_hidden_size is not None: moe_ffn = self.config.model_config.moe_ffn_hidden_size else: moe_ffn = self.config.model_config.ffn_hidden_size - # Simulate routed expert MLP GEMMs (topk tokens through experts / EP) - # Each expert processes topk_tokens / num_local_experts tokens on average - num_local_experts = (self.config.model_config.num_experts or 1) // ep_size + num_experts = self.config.model_config.num_experts or 1 + num_local_experts = num_experts // ep_size tokens_per_expert = topk_tokens // max(num_local_experts, 1) # FP8-hybrid: MoE expert MLP projections run in FP8 gemm_dtype = "fp8" if getattr(self.config.model_config, "fp8", None) else "bf16" - sim_result = self._gemm_backend.simulate_mlp_gemms( - batch_tokens=tokens_per_expert, - hidden_size=self.config.model_config.hidden_size, - ffn_hidden_size=moe_ffn, - dtype=gemm_dtype, - swiglu=self.config.model_config.swiglu, + bytes_per_el = 1 if gemm_dtype == "fp8" else 2 + + # ── 1. Routed expert GEMMs ── + M = tokens_per_expert + H = hidden_size + F = moe_ffn + + # Determine grouped-GEMM performance model. + # Primus Turbo's grouped-GEMM kernel achieves near-ideal batched + # execution → model as Origami batched GEMM (batch=num_local_experts). + # Legacy grouped_gemm executes experts more sequentially → model as + # individual GEMM (batch=1) × num_local_experts. 
+ use_turbo = ( + getattr(self.config.model_config, "enable_primus_turbo", False) + and getattr(self.config.model_config, "use_turbo_grouped_mlp", False) ) - # Scale by number of local experts (they run sequentially or in grouped GEMM) - fwd_time = sim_result.forward_time_ms * num_local_experts - bwd_time = sim_result.backward_time_ms * num_local_experts - # Shared experts (if any) + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + if is_rank_0 and num_local_experts > 1: + mode = "Turbo (batched)" if use_turbo else "Legacy (sequential)" + print(f" [MoE MLP] Grouped-GEMM model: {mode}" + f" ({num_local_experts} local experts, M={M}, H={H}, F={F})") + + expert_fwd_ms = 0.0 + expert_bwd_ms = 0.0 + + if use_turbo: + # ── Turbo model: batched GEMM (all experts in parallel) ── + B = num_local_experts + if self.config.model_config.swiglu: + gate_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) + up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) + down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) + expert_fwd_ms = (gate_fwd.forward_time_ms + + up_fwd.forward_time_ms + + down_fwd.forward_time_ms) + gate_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) + gate_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=B) + up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) + up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=B) + down_dg = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) + down_wg = self._gemm_backend.simulate_gemm(F, H, M, gemm_dtype, batch=B) + expert_bwd_ms = (gate_dg.forward_time_ms + gate_wg.forward_time_ms + + up_dg.forward_time_ms + up_wg.forward_time_ms + + down_dg.forward_time_ms + down_wg.forward_time_ms) + else: + up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) + down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) + expert_fwd_ms = up_fwd.forward_time_ms + 
down_fwd.forward_time_ms + up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) + up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=B) + down_dg = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) + down_wg = self._gemm_backend.simulate_gemm(F, H, M, gemm_dtype, batch=B) + expert_bwd_ms = (up_dg.forward_time_ms + up_wg.forward_time_ms + + down_dg.forward_time_ms + down_wg.forward_time_ms) + + expert_fwd = expert_fwd_ms + expert_bwd = expert_bwd_ms + else: + # ── Legacy model: individual GEMM × num_local_experts ── + if self.config.model_config.swiglu: + gate_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) + up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) + down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) + expert_fwd_ms = (gate_fwd.forward_time_ms + + up_fwd.forward_time_ms + + down_fwd.forward_time_ms) + gate_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) + gate_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=1) + up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) + up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=1) + down_dg = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) + down_wg = self._gemm_backend.simulate_gemm(F, H, M, gemm_dtype, batch=1) + expert_bwd_ms = (gate_dg.forward_time_ms + gate_wg.forward_time_ms + + up_dg.forward_time_ms + up_wg.forward_time_ms + + down_dg.forward_time_ms + down_wg.forward_time_ms) + else: + up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) + down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) + expert_fwd_ms = up_fwd.forward_time_ms + down_fwd.forward_time_ms + up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) + up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=1) + down_dg = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, 
batch=1) + down_wg = self._gemm_backend.simulate_gemm(F, H, M, gemm_dtype, batch=1) + expert_bwd_ms = (up_dg.forward_time_ms + up_wg.forward_time_ms + + down_dg.forward_time_ms + down_wg.forward_time_ms) + + expert_fwd = expert_fwd_ms * num_local_experts + expert_bwd = expert_bwd_ms * num_local_experts + + fwd_time = expert_fwd + bwd_time = expert_bwd + + # ── 2. Router overhead ── + # Gate linear: [batch_tokens, num_experts, hidden_size] + router_gemm = self._gemm_backend.simulate_gemm( + batch_tokens, num_experts, hidden_size, gemm_dtype + ) + router_fwd_ms = router_gemm.forward_time_ms + # Softmax + top-K selection + auxiliary loss overhead (empirical) + topk_overhead_ms = 0.1 + 0.002 * num_experts + router_fwd_ms += topk_overhead_ms + # Backward: dgrad + wgrad for gate linear + router_bwd_ms = 2.0 * router_gemm.forward_time_ms + topk_overhead_ms + + fwd_time += router_fwd_ms + bwd_time += router_bwd_ms + + # ── 3. Token permutation overhead (dispatch + combine) ── + # Dispatch: gather tokens by expert assignment → irregular memory access + # Combine: scatter expert outputs back → weighted reduce + dispatch_bytes = (batch_tokens + topk_tokens) * hidden_size * bytes_per_el + combine_bytes = (topk_tokens + batch_tokens) * hidden_size * bytes_per_el + permute_fwd_ms = dispatch_bytes / (_PERMUTE_EFF_BW_GBPS * 1e6) + permute_bwd_ms = combine_bytes / (_PERMUTE_EFF_BW_GBPS * 1e6) + + fwd_time += permute_fwd_ms + bwd_time += permute_bwd_ms + + # ── 4. Activation function overhead (SwiGLU / GELU) ── + if self.config.model_config.swiglu: + act_bytes = 3 * topk_tokens * moe_ffn * bytes_per_el # gate+up read, result write + else: + act_bytes = 2 * topk_tokens * moe_ffn * bytes_per_el # read + write + activation_ms = act_bytes / (_ACTIVATION_BW_GBPS * 1e6) + + fwd_time += activation_ms + bwd_time += activation_ms + + # ── 5. 
Shared experts (if any) ── shared_sz = self.config.model_config.moe_shared_expert_intermediate_size if shared_sz: shared_result = self._gemm_backend.simulate_mlp_gemms( batch_tokens=batch_tokens, - hidden_size=self.config.model_config.hidden_size, + hidden_size=hidden_size, ffn_hidden_size=shared_sz, dtype=gemm_dtype, swiglu=self.config.model_config.swiglu, diff --git a/primus/core/projection/module_profilers/output_layer.py b/primus/core/projection/module_profilers/output_layer.py index 90575882c..c2e27555b 100644 --- a/primus/core/projection/module_profilers/output_layer.py +++ b/primus/core/projection/module_profilers/output_layer.py @@ -53,15 +53,31 @@ def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, hidden_size = self.config.model_config.hidden_size vocab_size = self.config.model_config.padded_vocab_size - # Output projection GEMM: [batch_tokens, hidden_size] x [hidden_size, vocab_size] - sim_result = self._gemm_backend.simulate_gemm( + # Output projection GEMM fwd: [batch_tokens, hidden_size] x [hidden_size, vocab_size] + fwd_result = self._gemm_backend.simulate_gemm( m=batch_tokens, n=vocab_size, k=hidden_size, dtype="bf16", ) - fwd_time = sim_result.forward_time_ms - bwd_time = fwd_time * 2.0 # dgrad + wgrad + fwd_time = fwd_result.forward_time_ms + + # Backward: simulate actual dgrad + wgrad GEMMs + # dgrad: [batch_tokens, vocab_size] x [vocab_size, hidden_size] -> [batch_tokens, hidden_size] + dgrad_result = self._gemm_backend.simulate_gemm( + m=batch_tokens, + n=hidden_size, + k=vocab_size, + dtype="bf16", + ) + # wgrad: [hidden_size, batch_tokens] x [batch_tokens, vocab_size] -> [hidden_size, vocab_size] + wgrad_result = self._gemm_backend.simulate_gemm( + m=hidden_size, + n=vocab_size, + k=batch_tokens, + dtype="bf16", + ) + bwd_time = dgrad_result.forward_time_ms + wgrad_result.forward_time_ms activation_memory = self.estimated_activation_memory(batch_size, seq_len) return (fwd_time, bwd_time, activation_memory) diff 
--git a/primus/core/projection/module_profilers/transformer_layer.py b/primus/core/projection/module_profilers/transformer_layer.py index 41320443c..329634819 100644 --- a/primus/core/projection/module_profilers/transformer_layer.py +++ b/primus/core/projection/module_profilers/transformer_layer.py @@ -4,12 +4,15 @@ # See LICENSE for license information. ############################################################################### +import os from typing import Optional from primus.core.projection.base_module_profiler import BaseModuleProfiler from primus.core.projection.profiler_spec import ModuleProfilerSpec from primus.core.projection.training_config import TrainingConfig +from . import collective_model as cm +from .collective_args import get_default_args from .attention import AttentionProfiler from .dense_mlp import DenseMLPProfiler from .layer_norm import LayerNormProfiler @@ -18,6 +21,98 @@ from .router import RouterProfiler from .utils import benchmark_layer +def _estimate_tp_allreduce_time_ms(config, batch_size: int, seq_len: int) -> float: + """ + Estimate TP AllReduce time for a single AllReduce operation (in ms). + + In Megatron-style tensor parallelism each transformer layer performs + 2 AllReduces in forward (after attention row-parallel output projection, + after MLP row-parallel down projection) and 2 in backward. + With sequence parallelism the AllReduce is replaced by + ReduceScatter + AllGather pairs, but the total data volume is equivalent. + + Returns 0.0 when TP <= 1 (no communication needed). 
+ """ + tp = config.model_parallel_config.tensor_model_parallel_size + if tp <= 1: + return 0.0 + + cp = getattr(config.model_parallel_config, "context_parallel_size", 1) or 1 + pp = config.model_parallel_config.pipeline_model_parallel_size + ep = getattr(config.model_parallel_config, "expert_model_parallel_size", 1) or 1 + hidden_size = config.model_config.hidden_size + + # Message size: activations after row-parallel projection + # Shape: [batch_size * seq_len / CP, hidden_size], BF16 (2 bytes) + message_size_bytes = batch_size * seq_len * hidden_size * 2 // cp + + # Setup collective communication args + gpus_per_node = int(os.environ.get("GPUS_PER_NODE", "8")) + num_nodes = int(os.environ.get("NNODES", "1")) + + coll_args = get_default_args( + num_nodes=num_nodes, + gpus_per_node=gpus_per_node, + tp=tp, pp=pp, ep=ep, cp=cp, + ) + + # TP AllReduce is across tp ranks (typically intra-node) + ar_time_us = cm.allreduce(coll_args, message_size_bytes, tp, groups=["tp"]) + return ar_time_us / 1000.0 # Convert microseconds → milliseconds + + +def _estimate_moe_a2a_time_ms(config, batch_size: int, seq_len: int) -> float: + """ + Estimate MoE All-to-All time (dispatch + combine) per layer per direction (in ms). + + Each MoE layer performs two A2A operations per direction: + 1. **Dispatch**: scatter tokens to the EP ranks that own the assigned experts. + 2. **Combine**: gather expert outputs back to the originating ranks. + + In benchmark mode this cost is captured inside the measured layer time. + In simulation mode we must add it explicitly because the layer profiler + only simulates GEMM / SDPA compute and TP AllReduce. + + Returns 0.0 when EP <= 1 (all experts are local, no A2A needed). 
+ """ + ep = getattr(config.model_parallel_config, "expert_model_parallel_size", 1) or 1 + if ep <= 1: + return 0.0 + + tp = config.model_parallel_config.tensor_model_parallel_size + pp = config.model_parallel_config.pipeline_model_parallel_size + cp = getattr(config.model_parallel_config, "context_parallel_size", 1) or 1 + hidden_size = config.model_config.hidden_size + moe_router_topk = getattr(config.model_config, "moe_router_topk", 2) + + # A2A message size: each rank sends/receives all routed tokens + # Shape: [batch_size * seq_len * topk, hidden_size], BF16 (2 bytes) + tokens_per_batch = batch_size * seq_len + dispatch_size_bytes = tokens_per_batch * hidden_size * moe_router_topk * 2 + + # Setup collective communication args + gpus_per_node = int(os.environ.get("GPUS_PER_NODE", "8")) + num_nodes = int(os.environ.get("NNODES", "1")) + + coll_args = get_default_args( + num_nodes=num_nodes, + gpus_per_node=gpus_per_node, + tp=tp, pp=pp, ep=ep, cp=cp, + ) + + # Propagate DeepEP setting if present (affects A2A algorithm selection) + moe_enable_deepep = getattr(config.model_parallel_config, "moe_enable_deepep", False) + use_turbo_deepep = getattr(config.model_parallel_config, "use_turbo_deepep", False) + coll_args.moe_enable_deepep = moe_enable_deepep + coll_args.use_turbo_deepep = use_turbo_deepep + + # Dispatch A2A + Combine A2A (same message size, same time) + a2a_dispatch_us = cm.alltoall(coll_args, dispatch_size_bytes, ep, groups=["ep"]) + a2a_combine_us = cm.alltoall(coll_args, dispatch_size_bytes, ep, groups=["ep"]) + + return (a2a_dispatch_us + a2a_combine_us) / 1000.0 # Convert us → ms + + # Transformer Layer Data Flow # # +----------------+ @@ -121,13 +216,21 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: ) def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: - """Aggregate simulated results from sub-profilers.""" + """Aggregate simulated results from sub-profilers, including TP 
AllReduce.""" attn_fwd = self.sub_profilers["self_attention"].measured_forward_time(batch_size, seq_len) attn_bwd = self.sub_profilers["self_attention"].measured_backward_time(batch_size, seq_len) mlp_fwd = self.sub_profilers["mlp"].measured_forward_time(batch_size, seq_len) mlp_bwd = self.sub_profilers["mlp"].measured_backward_time(batch_size, seq_len) - fwd_time = attn_fwd + mlp_fwd - bwd_time = attn_bwd + mlp_bwd + + # Add TP AllReduce communication overhead (simulation only). + # Each transformer layer has 2 AllReduces per direction: + # - After attention row-parallel output projection + # - After MLP row-parallel down projection + # (With sequence parallelism these become RS+AG pairs with equal volume.) + tp_ar_ms = _estimate_tp_allreduce_time_ms(self.config, batch_size, seq_len) + + fwd_time = attn_fwd + mlp_fwd + 2 * tp_ar_ms + bwd_time = attn_bwd + mlp_bwd + 2 * tp_ar_ms activation_memory = self.estimated_activation_memory(batch_size, seq_len) return (fwd_time, bwd_time, activation_memory) @@ -222,13 +325,33 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: ) def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: - """Aggregate simulated results from sub-profilers.""" + """Aggregate simulated results from sub-profilers. + + Includes TP AllReduce and MoE All-to-All communication overhead that + would be captured in the measured layer time during benchmark mode but + must be added explicitly in simulation mode. + """ attn_fwd = self.sub_profilers["self_attention"].measured_forward_time(batch_size, seq_len) attn_bwd = self.sub_profilers["self_attention"].measured_backward_time(batch_size, seq_len) mlp_fwd = self.sub_profilers["mlp"].measured_forward_time(batch_size, seq_len) mlp_bwd = self.sub_profilers["mlp"].measured_backward_time(batch_size, seq_len) - fwd_time = attn_fwd + mlp_fwd - bwd_time = attn_bwd + mlp_bwd + + # Add TP AllReduce communication overhead (simulation only). 
+ # Each transformer layer has 2 AllReduces per direction: + # - After attention row-parallel output projection + # - After MLP row-parallel down projection + # (With sequence parallelism these become RS+AG pairs with equal volume.) + tp_ar_ms = _estimate_tp_allreduce_time_ms(self.config, batch_size, seq_len) + + # Add MoE All-to-All communication overhead (simulation only). + # Each MoE layer performs dispatch A2A + combine A2A per direction. + # In benchmark mode this is captured in the measured layer time; + # in simulation mode the layer profiler only computes GEMM / SDPA + # so we must add it here. + moe_a2a_ms = _estimate_moe_a2a_time_ms(self.config, batch_size, seq_len) + + fwd_time = attn_fwd + mlp_fwd + 2 * tp_ar_ms + moe_a2a_ms + bwd_time = attn_bwd + mlp_bwd + 2 * tp_ar_ms + moe_a2a_ms activation_memory = self.estimated_activation_memory(batch_size, seq_len) return (fwd_time, bwd_time, activation_memory) diff --git a/primus/core/projection/performance_projection/projection.py b/primus/core/projection/performance_projection/projection.py index 320ba6c78..3b04e8d1f 100644 --- a/primus/core/projection/performance_projection/projection.py +++ b/primus/core/projection/performance_projection/projection.py @@ -435,7 +435,7 @@ def calculate_collective_communication_time( else: message_info["gradient_allreduce_overlapped"] = False - # FSDP overlap model (calibrated against LLaMA3-70B MI300X trace) + # FSDP overlap model # --------------------------------------------------------------- # FSDP2 prefetches next layer's AllGather while the current layer # computes, and ReduceScatter runs after backward completes. 
@@ -474,7 +474,7 @@ def calculate_collective_communication_time( fwd_ag_total = total_fsdp_ag bwd_ag_total = 0.0 - # Per-phase overlap percentages (from trace calibration) + # Per-phase overlap percentages FWD_AG_OVERLAP = 0.90 # forward AG hidden behind compute BWD_AG_OVERLAP = 0.24 # backward recompute AG (structural limit) RS_OVERLAP = 0.34 # ReduceScatter (structural limit) @@ -813,6 +813,7 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): pp = getattr(original_config, "pipeline_model_parallel_size", 1) or 1 ep = getattr(original_config, "expert_model_parallel_size", 1) or 1 cp = getattr(original_config, "context_parallel_size", 1) or 1 + num_experts = getattr(original_config, "num_experts", None) gpus_required = tp * pp * ep * cp nodes_required = (gpus_required + gpus_per_node - 1) // gpus_per_node @@ -828,6 +829,8 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): "original_ep": ep, "benchmark_ep": ep, "original_cp": cp, + "original_num_experts": num_experts, + "benchmark_num_experts": num_experts, } # Step 1: Reduce PP to 1 @@ -836,6 +839,7 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): # Step 2: If still doesn't fit, rescale EP benchmark_ep = ep + benchmark_num_experts = num_experts if benchmark_gpus_required > gpus_per_node: print( f"[Primus:Performance Projection] After reducing PP to 1, " @@ -849,6 +853,7 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): rescale_info = _rescale_expert_parallelism(original_config) if rescale_info: benchmark_ep = rescale_info["ep_after"] + benchmark_num_experts = rescale_info.get("num_experts_after", num_experts) benchmark_gpus_required = tp * benchmark_pp * benchmark_ep * cp if benchmark_gpus_required > gpus_per_node: @@ -889,6 +894,8 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): "original_ep": ep, "benchmark_ep": benchmark_ep, "original_cp": cp, + "original_num_experts": num_experts, + 
"benchmark_num_experts": benchmark_num_experts, } @@ -960,47 +967,69 @@ def _estimate_pp_communication_overhead( return total_p2p_time_ms -def _compute_ep_mlp_scale(model_config, benchmark_ep, original_ep): +def _compute_ep_mlp_scale( + model_config, + benchmark_ep, + original_ep, + original_num_experts=None, + benchmark_num_experts=None, +): """ Compute the MLP time scaling factor when EP changes, accounting for - shared experts (EP-independent) vs routed experts (EP-dependent). + shared experts (EP-independent) vs routed experts. + + Key insight — per-GPU routed compute is EP-invariant + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + In Megatron MoE each GPU in the EP group processes the **same** + micro-batch. After A2A token redistribution every GPU ends up + computing ``batch_tokens × topk`` total token-expert pairs, + regardless of EP. Therefore: + + * When ``_rescale_expert_parallelism`` adjusts ``num_experts`` + proportionally (preserving ``experts_per_rank``), the profiled / + simulated MLP time already reflects the correct per-rank workload + → **no scaling needed** (returns 1.0). + + * When EP changes but ``num_experts`` stays fixed (hypothetical — + our rescaling always preserves experts_per_rank), the GEMM shapes + change (fewer, larger GEMMs vs. more, smaller GEMMs) but total + FLOPs remain identical. We conservatively return 1.0 because the + simple ``benchmark_ep / original_ep`` ratio does not capture GEMM- + efficiency differences. - In Megatron MoE: - - Routed expert compute per GPU ∝ (topk / EP) × moe_ffn_hidden_size - - Shared expert compute is constant regardless of EP + Shared expert compute is constant regardless of EP (no A2A needed). - The profiled MLP time at benchmark_ep includes both. When scaling to - original_ep, only the routed portion changes. + Args: + model_config: Model configuration. + benchmark_ep: EP used during profiling / simulation. + original_ep: EP of the target deployment. 
+ original_num_experts: Total expert count in the target config. + benchmark_num_experts: Total expert count after rescaling + (may differ from ``original_num_experts`` if + ``_rescale_expert_parallelism`` adjusted it). Returns: float: Scale factor to apply to the profiled MLP time. + 1.0 when experts_per_rank is preserved (the common case). """ - topk = getattr(model_config, "moe_router_topk", 1) or 1 - moe_ffn = getattr(model_config, "moe_ffn_hidden_size", None) - shared_ffn = getattr(model_config, "moe_shared_expert_intermediate_size", None) - # Derive num_shared_experts: explicit attribute, or infer from - # moe_shared_expert_intermediate_size // moe_ffn_hidden_size - num_shared = getattr(model_config, "num_shared_experts", 0) or 0 - if num_shared == 0 and shared_ffn and moe_ffn: - num_shared = shared_ffn // moe_ffn - - if not moe_ffn or num_shared == 0 or not shared_ffn: - # No shared experts — all MLP compute is routed, scales with 1/EP - return benchmark_ep / original_ep - - # FLOPs proportional to tokens × ffn_size - # Routed: (topk / benchmark_ep) tokens per expert-slot, through moe_ffn - # Shared: all tokens (1.0), through shared_ffn - routed_flops = (topk / benchmark_ep) * moe_ffn - shared_flops = num_shared * shared_ffn - total_flops = routed_flops + shared_flops - - routed_fraction = routed_flops / total_flops - shared_fraction = shared_flops / total_flops - - # Routed portion scales by benchmark_ep / original_ep; shared stays constant - scale = shared_fraction + routed_fraction * (benchmark_ep / original_ep) - return scale + if benchmark_ep == original_ep: + return 1.0 + + # Determine whether experts_per_rank was preserved by rescaling. + # When it is, per-GPU routed compute is identical — no scaling. 
+ if original_num_experts is not None and benchmark_num_experts is not None: + orig_epr = original_num_experts / original_ep + bench_epr = benchmark_num_experts / benchmark_ep + if abs(orig_epr - bench_epr) < 0.5: + # experts_per_rank preserved — per-GPU routed compute unchanged. + return 1.0 + + # Fallback: per-GPU MoE routed compute is EP-invariant (A2A + # redistributes tokens so each GPU processes batch_tokens × topk). + # The simple benchmark_ep / original_ep ratio is NOT correct because + # total FLOPs are constant; only GEMM shapes differ. We return 1.0 + # rather than an inaccurate heuristic. + return 1.0 def _estimate_ep_communication_overhead( @@ -2157,6 +2186,12 @@ def launch_projection_from_cli(args, overrides): primus_config.get_module_config("pre_trainer").expert_model_parallel_size = ( reduction_info["benchmark_ep"] ) + # Also propagate num_experts adjustment so that the profiler sees + # the correct experts_per_rank (e.g. 128/4=32, not 256/4=64). + if reduction_info.get("benchmark_num_experts") is not None: + primus_config.get_module_config("pre_trainer").num_experts = ( + reduction_info["benchmark_num_experts"] + ) # Determine profiling mode profiling_mode = getattr(args, "profiling_mode", "benchmark") @@ -2278,13 +2313,15 @@ def launch_projection_from_cli(args, overrides): ): original_ep = reduction_info["original_ep"] benchmark_ep = reduction_info["benchmark_ep"] + original_num_experts = reduction_info.get("original_num_experts") + benchmark_num_experts = reduction_info.get("benchmark_num_experts") # Load hardware config if provided hardware_config_dict = None if hasattr(args, "hardware_config") and args.hardware_config: hardware_config_dict = load_hardware_config(args.hardware_config) - # Calculate EP communication overhead per layer + # Calculate EP communication overhead per layer (A2A delta) fwd_overhead_per_layer, bwd_overhead_per_layer = ( _estimate_ep_communication_overhead( training_config, @@ -2294,11 +2331,18 @@ def 
launch_projection_from_cli(args, overrides): ) ) - # EP compute scaling: when EP increases, each GPU handles fewer routed - # expert tokens, but shared expert compute stays constant. - # Use _compute_ep_mlp_scale to get the correct fraction-aware scale. + # EP compute scaling. Per-GPU routed compute is EP-invariant + # (A2A redistributes tokens so each GPU always processes + # batch_tokens × topk total token-expert pairs). When + # _rescale_expert_parallelism preserves experts_per_rank the + # profiled/simulated MLP time already reflects the correct + # per-rank workload, so ep_mlp_scale == 1.0. ep_mlp_scale = _compute_ep_mlp_scale( - training_config.model_config, benchmark_ep, original_ep + training_config.model_config, + benchmark_ep, + original_ep, + original_num_experts=original_num_experts, + benchmark_num_experts=benchmark_num_experts, ) if is_rank_0: @@ -2306,59 +2350,48 @@ def launch_projection_from_cli(args, overrides): "[Primus:Performance Projection] Adjusting profiling results for EP scaling:" ) print(f" EP rescaled: {benchmark_ep} → {original_ep}") - print(f" MLP time scale factor: {ep_mlp_scale:.3f}") - # Show shared vs routed breakdown - topk = getattr(training_config.model_config, "moe_router_topk", 1) or 1 - moe_ffn = getattr(training_config.model_config, "moe_ffn_hidden_size", None) - shared_ffn = getattr( - training_config.model_config, - "moe_shared_expert_intermediate_size", - None, - ) - num_shared = ( - getattr(training_config.model_config, "num_shared_experts", 0) or 0 - ) - if moe_ffn and num_shared > 0 and shared_ffn: - routed_flops = (topk / benchmark_ep) * moe_ffn - shared_flops = num_shared * shared_ffn - total_flops = routed_flops + shared_flops - print( - f" Routed fraction: {routed_flops/total_flops:.1%} (topk={topk}, EP={benchmark_ep}, ffn={moe_ffn})" - ) - print( - f" Shared fraction: {shared_flops/total_flops:.1%} ({num_shared} shared expert(s), ffn={shared_ffn})" - ) - else: + if original_num_experts is not None and 
benchmark_num_experts is not None: + orig_epr = original_num_experts // original_ep + bench_epr = benchmark_num_experts // benchmark_ep print( - f" No shared experts — full routed scaling ({benchmark_ep}/{original_ep})" + f" Experts per rank: benchmark={bench_epr} " + f"(E={benchmark_num_experts}, EP={benchmark_ep}), " + f"target={orig_epr} " + f"(E={original_num_experts}, EP={original_ep})" ) + print(f" MLP time scale factor: {ep_mlp_scale:.3f}") if fwd_overhead_per_layer > 0 or bwd_overhead_per_layer > 0: print(f" Adding per-layer All-to-All overhead:") print(f" Forward: +{fwd_overhead_per_layer:.3f} ms/layer") print(f" Backward: +{bwd_overhead_per_layer:.3f} ms/layer") - # Adjust MoE layer times in profiling_results + # Adjust MoE layer times in profiling_results using a DELTA approach: + # new_layer_time = old_layer_time + (mlp_delta) + (a2a_delta) + # This preserves TP AllReduce, A2A(benchmark_ep), LayerNorm, and + # other components that are baked into the profiled layer time. moe_layers_adjusted = 0 for layer_idx, layer_data in profiling_results.items(): if isinstance(layer_data, dict) and layer_data.get("type") == "moe": + old_fwd = layer_data.get("forward_time_ms", 0) + old_bwd = layer_data.get("backward_time_ms", 0) + mlp_info = layer_data.get("mlp", {}) mlp_fwd = mlp_info.get("forward_time_ms", 0) mlp_bwd = mlp_info.get("backward_time_ms", 0) - attn_info = layer_data.get("attention", {}) - attn_fwd = attn_info.get("forward_time_ms", 0) - attn_bwd = attn_info.get("backward_time_ms", 0) - # Scale MLP compute (shared-expert-aware), keep attention unchanged + # Compute MLP delta (usually 0 when scale == 1.0) new_mlp_fwd = mlp_fwd * ep_mlp_scale new_mlp_bwd = mlp_bwd * ep_mlp_scale + mlp_delta_fwd = new_mlp_fwd - mlp_fwd + mlp_delta_bwd = new_mlp_bwd - mlp_bwd - # New layer time = attention + scaled MLP + A2A comm overhead - new_fwd = attn_fwd + new_mlp_fwd + fwd_overhead_per_layer - new_bwd = attn_bwd + new_mlp_bwd + bwd_overhead_per_layer + # Delta 
approach: start from the full profiled layer time + # (which includes TP AR, A2A(benchmark_ep), norms, etc.) + # and add the MLP compute delta + A2A comm delta. + new_fwd = old_fwd + mlp_delta_fwd + fwd_overhead_per_layer + new_bwd = old_bwd + mlp_delta_bwd + bwd_overhead_per_layer if is_rank_0 and moe_layers_adjusted == 0: - old_fwd = layer_data.get("forward_time_ms", 0) - old_bwd = layer_data.get("backward_time_ms", 0) print(f" MoE layer adjustment (per layer):") print( f" MLP fwd: {mlp_fwd:.2f} → {new_mlp_fwd:.2f} ms (×{ep_mlp_scale:.3f})" @@ -2366,8 +2399,8 @@ def launch_projection_from_cli(args, overrides): print( f" MLP bwd: {mlp_bwd:.2f} → {new_mlp_bwd:.2f} ms (×{ep_mlp_scale:.3f})" ) - print(f" Attn fwd: {attn_fwd:.2f} ms (unchanged)") - print(f" Attn bwd: {attn_bwd:.2f} ms (unchanged)") + print(f" A2A fwd delta: +{fwd_overhead_per_layer:.3f} ms") + print(f" A2A bwd delta: +{bwd_overhead_per_layer:.3f} ms") print(f" Layer fwd: {old_fwd:.2f} → {new_fwd:.2f} ms") print(f" Layer bwd: {old_bwd:.2f} → {new_bwd:.2f} ms") @@ -2492,9 +2525,16 @@ def launch_projection_from_cli(args, overrides): fwd_overhead_per_layer + bwd_overhead_per_layer ) * num_moe_layers - # EP compute scaling (shared-expert-aware) + # EP compute scaling — per-GPU MoE routed compute is + # EP-invariant (see _compute_ep_mlp_scale docstring). 
+ original_num_experts = reduction_info.get("original_num_experts") + benchmark_num_experts = reduction_info.get("benchmark_num_experts") ep_mlp_scale = _compute_ep_mlp_scale( - training_config.model_config, benchmark_ep_val, original_ep + training_config.model_config, + benchmark_ep_val, + original_ep, + original_num_experts=original_num_experts, + benchmark_num_experts=benchmark_num_experts, ) # Estimate MLP portion of MoE layer time from profiling results mlp_time_reduction = 0.0 diff --git a/primus/core/projection/simulation_backends/base.py b/primus/core/projection/simulation_backends/base.py index 91b5aecd1..542eca1db 100644 --- a/primus/core/projection/simulation_backends/base.py +++ b/primus/core/projection/simulation_backends/base.py @@ -59,15 +59,27 @@ def simulate_gemm( dtype: str = "bf16", trans_a: bool = False, trans_b: bool = False, + batch: int = 1, ) -> SimulationResult: """ - Simulate a single GEMM operation and return predicted timing. + Simulate a single GEMM (or batched GEMM) and return predicted timing. Args: m, n, k: Matrix dimensions (C = A @ B, A:[M,K] B:[K,N] C:[M,N]) dtype: Data type string ("bf16", "fp16", "fp8", "fp32") trans_a: Whether A is transposed trans_b: Whether B is transposed + batch: Number of independent GEMMs with the **same** (M, N, K) to + run as a **batched** GEMM. For MoE experts this is used as an + approximation of grouped GEMM under the assumption of uniform + token distribution across experts (so every sub-problem has the + same M = tokens_per_expert). Defaults to 1 (single GEMM). + + NOTE: Origami's ``problem.batch`` models batched GEMM, not true + grouped GEMM (where each sub-problem can have a different M). + This is an acceptable approximation when token distribution is + assumed uniform. If Origami adds native grouped-GEMM support + in the future, this should be updated. Returns: SimulationResult with forward_time_ms populated. 
@@ -81,6 +93,7 @@ def simulate_mlp_gemms( ffn_hidden_size: int, dtype: str = "bf16", swiglu: bool = False, + num_experts: int = 1, ) -> SimulationResult: """ Simulate the GEMM operations in a dense MLP (gate/up/down projections). @@ -88,38 +101,83 @@ def simulate_mlp_gemms( Default implementation calls ``simulate_gemm`` for each projection and sums the times. Backends may override for better accuracy. + When ``num_experts > 1``, each GEMM is simulated as a **batched** GEMM + (``batch = num_experts``) with per-expert token count, which serves as + an approximation of grouped GEMM under uniform token distribution. + See ``simulate_gemm`` docstring for caveats. + Args: - batch_tokens: Number of tokens (batch_size * seq_len / TP / CP) + batch_tokens: Number of tokens per expert + (batch_size * seq_len / TP / CP for dense; + topk_tokens / num_local_experts for MoE routed experts) hidden_size: Model hidden dimension ffn_hidden_size: FFN intermediate dimension dtype: Data type string swiglu: Whether SwiGLU activation is used (3 projections vs 2) + num_experts: Number of local experts. When > 1, each projection + GEMM uses ``batch = num_experts`` as a batched-GEMM + approximation of grouped GEMM. Defaults to 1 (dense MLP). Returns: SimulationResult with forward_time_ms and backward_time_ms. """ fwd_time = 0.0 bwd_time = 0.0 + # Use batched GEMM (batch=num_experts) as approximation of grouped GEMM. + # Valid under uniform token distribution (all experts get the same M). + # TODO: switch to native grouped-GEMM simulation if/when Origami supports it. 
+ b = num_experts if swiglu: - # Gate projection: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] - gate_res = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype) - # Up projection: same shape - up_res = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype) - # Down projection: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] - down_res = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype) - - fwd_time = gate_res.forward_time_ms + up_res.forward_time_ms + down_res.forward_time_ms - # Backward is approximately 2x forward (dgrad + wgrad per projection) - bwd_time = fwd_time * 2.0 + # Gate projection fwd: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] + gate_fwd = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) + # Up projection fwd: same shape as gate + up_fwd = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) + # Down projection fwd: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] + down_fwd = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) + + fwd_time = gate_fwd.forward_time_ms + up_fwd.forward_time_ms + down_fwd.forward_time_ms + + # Backward: simulate actual dgrad + wgrad GEMMs per projection + # Gate dgrad: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] + gate_dgrad = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) + # Gate wgrad: [hidden, tokens] x [tokens, ffn] -> [hidden, ffn] + gate_wgrad = self.simulate_gemm(hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b) + # Up dgrad + wgrad: same shapes as gate + up_dgrad = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) + up_wgrad = self.simulate_gemm(hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b) + # Down dgrad: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] + down_dgrad = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) + # Down wgrad: [ffn, 
tokens] x [tokens, hidden] -> [ffn, hidden] + down_wgrad = self.simulate_gemm(ffn_hidden_size, hidden_size, batch_tokens, dtype, batch=b) + + bwd_time = ( + gate_dgrad.forward_time_ms + gate_wgrad.forward_time_ms + + up_dgrad.forward_time_ms + up_wgrad.forward_time_ms + + down_dgrad.forward_time_ms + down_wgrad.forward_time_ms + ) else: - # Up projection - up_res = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype) - # Down projection - down_res = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype) - - fwd_time = up_res.forward_time_ms + down_res.forward_time_ms - bwd_time = fwd_time * 2.0 + # Up projection fwd: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] + up_fwd = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) + # Down projection fwd: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] + down_fwd = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) + + fwd_time = up_fwd.forward_time_ms + down_fwd.forward_time_ms + + # Backward: simulate actual dgrad + wgrad GEMMs per projection + # Up dgrad: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] + up_dgrad = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) + # Up wgrad: [hidden, tokens] x [tokens, ffn] -> [hidden, ffn] + up_wgrad = self.simulate_gemm(hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b) + # Down dgrad: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] + down_dgrad = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) + # Down wgrad: [ffn, tokens] x [tokens, hidden] -> [ffn, hidden] + down_wgrad = self.simulate_gemm(ffn_hidden_size, hidden_size, batch_tokens, dtype, batch=b) + + bwd_time = ( + up_dgrad.forward_time_ms + up_wgrad.forward_time_ms + + down_dgrad.forward_time_ms + down_wgrad.forward_time_ms + ) return SimulationResult(forward_time_ms=fwd_time, backward_time_ms=bwd_time) @@ -143,26 +201,49 @@ def simulate_attention_gemms( 
SimulationResult with forward_time_ms and backward_time_ms. """ fwd_time = 0.0 + bwd_time = 0.0 - # Q projection: [tokens, hidden] x [hidden, heads*kv_channels] + # Q projection fwd: [tokens, hidden] x [hidden, heads*kv_channels] q_out = num_attention_heads * kv_channels - q_res = self.simulate_gemm(batch_tokens, q_out, hidden_size, dtype) - fwd_time += q_res.forward_time_ms + q_fwd = self.simulate_gemm(batch_tokens, q_out, hidden_size, dtype) + fwd_time += q_fwd.forward_time_ms - # K projection: [tokens, hidden] x [hidden, num_query_groups*kv_channels] + # K projection fwd: [tokens, hidden] x [hidden, num_query_groups*kv_channels] k_out = num_query_groups * kv_channels - k_res = self.simulate_gemm(batch_tokens, k_out, hidden_size, dtype) - fwd_time += k_res.forward_time_ms - - # V projection: same shape as K - v_res = self.simulate_gemm(batch_tokens, k_out, hidden_size, dtype) - fwd_time += v_res.forward_time_ms - - # Output projection: [tokens, heads*kv_channels] x [heads*kv_channels, hidden] - o_res = self.simulate_gemm(batch_tokens, hidden_size, q_out, dtype) - fwd_time += o_res.forward_time_ms - - bwd_time = fwd_time * 2.0 # dgrad + wgrad + k_fwd = self.simulate_gemm(batch_tokens, k_out, hidden_size, dtype) + fwd_time += k_fwd.forward_time_ms + + # V projection fwd: same shape as K + v_fwd = self.simulate_gemm(batch_tokens, k_out, hidden_size, dtype) + fwd_time += v_fwd.forward_time_ms + + # Output projection fwd: [tokens, heads*kv_channels] x [heads*kv_channels, hidden] + o_fwd = self.simulate_gemm(batch_tokens, hidden_size, q_out, dtype) + fwd_time += o_fwd.forward_time_ms + + # Backward: simulate actual dgrad + wgrad GEMMs per projection + # Q dgrad: [tokens, q_out] x [q_out, hidden] -> [tokens, hidden] + q_dgrad = self.simulate_gemm(batch_tokens, hidden_size, q_out, dtype) + # Q wgrad: [hidden, tokens] x [tokens, q_out] -> [hidden, q_out] + q_wgrad = self.simulate_gemm(hidden_size, q_out, batch_tokens, dtype) + bwd_time += q_dgrad.forward_time_ms + 
q_wgrad.forward_time_ms + + # K dgrad: [tokens, k_out] x [k_out, hidden] -> [tokens, hidden] + k_dgrad = self.simulate_gemm(batch_tokens, hidden_size, k_out, dtype) + # K wgrad: [hidden, tokens] x [tokens, k_out] -> [hidden, k_out] + k_wgrad = self.simulate_gemm(hidden_size, k_out, batch_tokens, dtype) + bwd_time += k_dgrad.forward_time_ms + k_wgrad.forward_time_ms + + # V dgrad + wgrad: same shapes as K + v_dgrad = self.simulate_gemm(batch_tokens, hidden_size, k_out, dtype) + v_wgrad = self.simulate_gemm(hidden_size, k_out, batch_tokens, dtype) + bwd_time += v_dgrad.forward_time_ms + v_wgrad.forward_time_ms + + # O dgrad: [tokens, hidden] x [hidden, q_out] -> [tokens, q_out] + o_dgrad = self.simulate_gemm(batch_tokens, q_out, hidden_size, dtype) + # O wgrad: [q_out, tokens] x [tokens, hidden] -> [q_out, hidden] + o_wgrad = self.simulate_gemm(q_out, hidden_size, batch_tokens, dtype) + bwd_time += o_dgrad.forward_time_ms + o_wgrad.forward_time_ms return SimulationResult(forward_time_ms=fwd_time, backward_time_ms=bwd_time) @@ -191,6 +272,7 @@ def simulate_sdpa( dtype: str = "bf16", seq_len_kv: Optional[int] = None, num_heads_kv: Optional[int] = None, + head_dim_v: Optional[int] = None, ) -> SimulationResult: """ Simulate a Scaled Dot-Product Attention operation. @@ -199,13 +281,19 @@ def simulate_sdpa( batch_size: Batch size num_heads: Number of query attention heads (per TP rank) seq_len: Query sequence length (per CP rank) - head_dim: Head dimension (kv_channels) + head_dim: Head dimension used in the Q·Kᵀ dot-product. For + standard MHA/GQA this equals ``kv_channels``; for MLA it + equals ``qk_head_dim + qk_pos_emb_head_dim`` (e.g. 192). causal: Whether causal masking is used dtype: Data type string seq_len_kv: Key/Value sequence length. Defaults to ``seq_len`` (self-attention). num_heads_kv: Number of KV heads. Defaults to ``num_heads`` (MHA). Set lower for GQA / MQA. + head_dim_v: Value head dimension used in the P·V product and + for sizing the output O. 
Defaults to ``head_dim`` (standard + attention where Q/K/V all share the same dimension). For + MLA this should be set to ``v_head_dim`` (e.g. 128). Returns: SimulationResult with forward_time_ms and backward_time_ms. diff --git a/primus/core/projection/simulation_backends/factory.py b/primus/core/projection/simulation_backends/factory.py index 277406db6..a9baeb689 100644 --- a/primus/core/projection/simulation_backends/factory.py +++ b/primus/core/projection/simulation_backends/factory.py @@ -88,11 +88,9 @@ def get_sdpa_simulation_backend( Args: gpu_arch: GPU architecture override (e.g. "mi300x", "mi355x"). compute_efficiency: Fraction of peak compute achieved (0-1). - Defaults to 0.51 — calibrated against measured FAv3 traces on - MI300X (B=3, H_Q=64, S=8192, D=128, H_KV=8, GQA, causal, BF16). - The lower-than-theoretical efficiency accounts for GQA head - broadcasting overhead, LDS bank conflicts, barrier synchronisation, - and register pressure. + Defaults to 0.51. The lower-than-theoretical efficiency accounts + for GQA head broadcasting overhead, LDS bank conflicts, barrier + synchronisation, and register pressure. memory_efficiency: Fraction of peak HBM bandwidth achieved (0-1). Defaults to 0.85 — FAv3 streaming pattern typically achieves 0.80-0.90. gpu_clock_mhz: Override the GPU compute clock frequency in MHz. diff --git a/primus/core/projection/simulation_backends/origami_backend.py b/primus/core/projection/simulation_backends/origami_backend.py index 5d344fdd4..5b82c5bd3 100644 --- a/primus/core/projection/simulation_backends/origami_backend.py +++ b/primus/core/projection/simulation_backends/origami_backend.py @@ -213,6 +213,10 @@ def simulate_gemm( self._ensure_initialized(dtype) # ----- Build origami problem_t ----- + # NOTE: problem.batch models **batched** GEMM (all sub-problems share + # the same M, N, K). For MoE experts this is used as an approximation + # of grouped GEMM under uniform token distribution. 
Origami does not + # currently expose a native grouped-GEMM model. problem = _origami.problem_t() problem.size = _origami.dim3_t(m, n, k) problem.batch = batch diff --git a/primus/core/projection/simulation_backends/sdpa_simulator.py b/primus/core/projection/simulation_backends/sdpa_simulator.py index 060335864..1b8052630 100644 --- a/primus/core/projection/simulation_backends/sdpa_simulator.py +++ b/primus/core/projection/simulation_backends/sdpa_simulator.py @@ -229,10 +229,6 @@ def __init__( "mi355x", "gfx950"). hardware_spec: Override hardware spec directly. compute_efficiency: Fraction of peak TFLOPS achieved (0-1). - Calibrated against measured FAv3 traces on MI300X: - * Measured FA fwd = 5.05 ms, bwd = 10.00 ms - (B=3, H_Q=64, S=8192, D=128, H_KV=8, causal, BF16) - * 0.51 matches measured within 1%. The lower-than-peak efficiency (vs theoretical 0.75-0.85) accounts for GQA head broadcasting, LDS bank conflicts, barrier synchronisation, and register pressure. @@ -270,6 +266,7 @@ def simulate_sdpa( dtype: str = "bf16", seq_len_kv: Optional[int] = None, num_heads_kv: Optional[int] = None, + head_dim_v: Optional[int] = None, ) -> SimulationResult: """ Simulate FAv3 SDPA execution time using a roofline model @@ -279,7 +276,9 @@ def simulate_sdpa( batch_size: Batch size (B). num_heads: Number of query heads (H_Q). seq_len: Query sequence length (S_Q). - head_dim: Head dimension (D). + head_dim: Head dimension for Q·Kᵀ (D_qk). For standard + MHA/GQA this is ``kv_channels``; for MLA this is + ``qk_head_dim + qk_pos_emb_head_dim`` (e.g. 192). causal: Whether causal masking is applied. dtype: Data type ("bf16", "fp16", "fp8", "fp32"). seq_len_kv: Key/Value sequence length (S_K). Defaults to @@ -287,13 +286,17 @@ def simulate_sdpa( cross-attention or prefill with separate KV cache length. num_heads_kv: Number of KV heads. Defaults to ``num_heads`` (MHA). Set lower for GQA/MQA. + head_dim_v: Value head dimension for P·V (D_v). 
Defaults to + ``head_dim`` (standard attention). For MLA set to + ``v_head_dim`` (e.g. 128). """ B = batch_size H_Q = num_heads S_Q = seq_len S_K = seq_len_kv if seq_len_kv is not None else seq_len H_K = num_heads_kv if num_heads_kv is not None else num_heads - D = head_dim + D_qk = head_dim + D_v = head_dim_v if head_dim_v is not None else head_dim bpe = self._bytes_per_element(dtype) # GQA ratio: each KV head serves (H_Q / H_K) query heads. @@ -306,54 +309,63 @@ def simulate_sdpa( # 1. COMPUTE (FLOP counts) # ============================================================== # Forward (per query head, then × H_Q) - # QKᵀ : 2·B·H_Q·S_Q·S_K·D (batched GEMM) - # softmax : ~5·B·H_Q·S_Q·S_K (exp, sub-max, sum, div, mul) - # PV : 2·B·H_Q·S_Q·S_K·D (batched GEMM — P is S_Q×S_K, V is S_K×D) - # NOTE for PV: output is (S_Q, D), inner dim is S_K. + # QKᵀ : 2·B·H_Q·S_Q·S_K·D_qk (batched GEMM) + # softmax : ~5·B·H_Q·S_Q·S_K (exp, sub-max, sum, div, mul) + # PV : 2·B·H_Q·S_Q·S_K·D_v (batched GEMM — P is S_Q×S_K, V is S_K×D_v) + # NOTE for PV: output is (S_Q, D_v), inner dim is S_K. # For causal masking, only ~half the S_Q×S_K elements are computed # (only valid when S_Q == S_K; for cross-attn causal is usually False). - fwd_gemm_flops = 2.0 * (2.0 * B * H_Q * S_Q * S_K * D) * causal_factor + # + # When D_qk == D_v (standard MHA/GQA) this reduces to the familiar + # 2 × (2·B·H·S·S·D) formula. For MLA, D_qk > D_v (e.g. 192 vs 128). 
+ fwd_qk_flops = 2.0 * B * H_Q * S_Q * S_K * D_qk * causal_factor + fwd_pv_flops = 2.0 * B * H_Q * S_Q * S_K * D_v * causal_factor fwd_softmax_flops = 5.0 * B * H_Q * S_Q * S_K * causal_factor - fwd_flops = fwd_gemm_flops + fwd_softmax_flops + fwd_flops = fwd_qk_flops + fwd_pv_flops + fwd_softmax_flops # Backward (4 batched GEMMs + softmax backward) - # dV = Pᵀ @ dO : 2·B·H_Q·S_K·S_Q → (S_K, D) inner dim S_Q - # dP = dO @ Vᵀ : 2·B·H_Q·S_Q·D → (S_Q, S_K) inner dim D + # dV = Pᵀ @ dO : 2·B·H_Q·S_K·S_Q·D_v (inner dim S_Q, out S_K×D_v) + # dP = dO @ Vᵀ : 2·B·H_Q·S_Q·S_K·D_v (inner dim D_v, out S_Q×S_K) # dS = softmax_bwd : ~5·B·H_Q·S_Q·S_K - # dQ = dS @ K : 2·B·H_Q·S_Q·S_K → (S_Q, D) inner dim S_K - # dK = dSᵀ @ Q : 2·B·H_Q·S_K·S_Q → (S_K, D) inner dim S_Q - bwd_gemm_flops = 2.0 * (4.0 * B * H_Q * S_Q * S_K * D) * causal_factor + # dQ = dS @ K : 2·B·H_Q·S_Q·S_K·D_qk (inner dim S_K, out S_Q×D_qk) + # dK = dSᵀ @ Q : 2·B·H_Q·S_K·S_Q·D_qk (inner dim S_Q, out S_K×D_qk) + bwd_dv_flops = 2.0 * B * H_Q * S_K * S_Q * D_v * causal_factor + bwd_dp_flops = 2.0 * B * H_Q * S_Q * S_K * D_v * causal_factor + bwd_dq_flops = 2.0 * B * H_Q * S_Q * S_K * D_qk * causal_factor + bwd_dk_flops = 2.0 * B * H_Q * S_K * S_Q * D_qk * causal_factor bwd_softmax_flops = 5.0 * B * H_Q * S_Q * S_K * causal_factor - bwd_flops = bwd_gemm_flops + bwd_softmax_flops + bwd_flops = bwd_dv_flops + bwd_dp_flops + bwd_dq_flops + bwd_dk_flops + bwd_softmax_flops # ============================================================== # 2. MEMORY IO (Flash Attention – no S/P materialised to HBM) # ============================================================== - # Forward reads: Q (B·H_Q·S_Q·D), K (B·H_K·S_K·D), V (B·H_K·S_K·D) - # Forward writes: O (B·H_Q·S_Q·D) + logsumexp (B·H_Q·S_Q, fp32) + # Q and K use D_qk; V and O use D_v. 
+ # Forward reads: Q (B·H_Q·S_Q·D_qk), K (B·H_K·S_K·D_qk), V (B·H_K·S_K·D_v) + # Forward writes: O (B·H_Q·S_Q·D_v) + logsumexp (B·H_Q·S_Q, fp32) fwd_read_bytes = ( - B * H_Q * S_Q * D * bpe # Q - + B * H_K * S_K * D * bpe # K - + B * H_K * S_K * D * bpe # V + B * H_Q * S_Q * D_qk * bpe # Q + + B * H_K * S_K * D_qk * bpe # K + + B * H_K * S_K * D_v * bpe # V ) fwd_write_bytes = ( - B * H_Q * S_Q * D * bpe + B * H_Q * S_Q * 4 # O # logsumexp (fp32) + B * H_Q * S_Q * D_v * bpe + B * H_Q * S_Q * 4 # O # logsumexp (fp32) ) fwd_bytes = fwd_read_bytes + fwd_write_bytes # Backward reads: Q, K, V, O, dO + logsumexp - # Backward regular writes: dK (B·H_K·S_K·D) + dV (B·H_K·S_K·D) + # Backward regular writes: dK (B·H_K·S_K·D_qk) + dV (B·H_K·S_K·D_v) # NOTE: dQ uses buffer_atomic_add_f32 — accounted separately. bwd_read_bytes = ( - B * H_Q * S_Q * D * bpe # Q - + B * H_K * S_K * D * bpe # K - + B * H_K * S_K * D * bpe # V - + B * H_Q * S_Q * D * bpe # O - + B * H_Q * S_Q * D * bpe # dO + B * H_Q * S_Q * D_qk * bpe # Q + + B * H_K * S_K * D_qk * bpe # K + + B * H_K * S_K * D_v * bpe # V + + B * H_Q * S_Q * D_v * bpe # O + + B * H_Q * S_Q * D_v * bpe # dO + B * H_Q * S_Q * 4 # logsumexp (fp32) ) bwd_regular_write_bytes = ( - B * H_K * S_K * D * bpe + B * H_K * S_K * D * bpe # dK # dV + B * H_K * S_K * D_qk * bpe # dK + + B * H_K * S_K * D_v * bpe # dV ) bwd_bytes = bwd_read_bytes + bwd_regular_write_bytes @@ -386,8 +398,8 @@ def simulate_sdpa( hbm_writers_per_element = min(n_kv_workgroups, self._hw.n_xcd) # Effective dQ bytes hitting HBM (after L2 coalescing) - # dQ shape is (B, H_Q, S_Q, D), stored in FP32 (4 bytes) - dq_atomic_bytes = float(hbm_writers_per_element) * B * H_Q * S_Q * D * 4.0 + # dQ shape is (B, H_Q, S_Q, D_qk), stored in FP32 (4 bytes) + dq_atomic_bytes = float(hbm_writers_per_element) * B * H_Q * S_Q * D_qk * 4.0 # Atomic slowdown = just the RMW factor (contention within-XCD # is absorbed by L2; cross-XCD traffic goes to different memory diff --git 
a/primus/core/projection/training_config.py b/primus/core/projection/training_config.py index 641ec6566..475a3bb50 100644 --- a/primus/core/projection/training_config.py +++ b/primus/core/projection/training_config.py @@ -66,6 +66,10 @@ class ModelConfig: # Precision – None means bf16, "hybrid" means FP8-hybrid (linear GEMMs in FP8) fp8: str = None + # Primus Turbo flags — used to select the grouped-GEMM performance model + enable_primus_turbo: bool = False + use_turbo_grouped_mlp: bool = False + @dataclass class TrainingConfig: From 0668c5c6423b28758c93050575779dd406cead8c Mon Sep 17 00:00:00 2001 From: Anshu Raina Date: Thu, 19 Feb 2026 22:50:40 +0000 Subject: [PATCH 05/12] feat(simulation): replace SDPA roofline model with Origami 1-CU tile-level simulation Replace the analytical roofline SDPA simulator with an Origami-based tile-level model that simulates each per-workgroup GEMM on a single CU. This captures wave quantisation, LDS traffic, and pipeline effects that the global max(compute, memory) roofline missed, eliminating the need for empirical compute_efficiency / memory_efficiency parameters. --- primus/cli/subcommands/projection.py | 2 +- .../projection/module_profilers/moe_mlp.py | 37 +- .../projection/simulation_backends/base.py | 11 + .../projection/simulation_backends/factory.py | 20 +- .../simulation_backends/origami_backend.py | 87 +++- .../simulation_backends/sdpa_simulator.py | 492 +++++++++--------- 6 files changed, 367 insertions(+), 282 deletions(-) diff --git a/primus/cli/subcommands/projection.py b/primus/cli/subcommands/projection.py index 821bb4a0c..8ef62985a 100644 --- a/primus/cli/subcommands/projection.py +++ b/primus/cli/subcommands/projection.py @@ -141,7 +141,7 @@ def register_subcommand(subparsers): help=( "Override the GPU compute clock frequency in MHz for simulation.\n" "If not specified, uses the default from the hardware profile for the\n" - "given --gpu-arch (e.g. 
2100 MHz for MI300X, 1200 MHz for MI325X).\n" + "given --gpu-arch (e.g. 2100 MHz for MI300X/MI325X).\n" "Can also be set via the PRIMUS_GPU_CLOCK_MHZ env var.\n" "Example: --gpu-clock-mhz 1500\n" ), diff --git a/primus/core/projection/module_profilers/moe_mlp.py b/primus/core/projection/module_profilers/moe_mlp.py index 40e2d0946..0f4dc585a 100644 --- a/primus/core/projection/module_profilers/moe_mlp.py +++ b/primus/core/projection/module_profilers/moe_mlp.py @@ -14,9 +14,22 @@ from .utils import benchmark_layer -# Memory-bandwidth constants for non-GEMM MoE overhead estimation -_PERMUTE_EFF_BW_GBPS = 300.0 # scatter/gather effective BW (~5-10% peak HBM) -_ACTIVATION_BW_GBPS = 3000.0 # sequential element-wise ops (~60% peak HBM) +# Efficiency fractions for non-GEMM MoE overhead estimation. +# These express achievable bandwidth as a fraction of peak HBM bandwidth. +# The actual BW is ``fraction × peak_hbm_bw`` for the target architecture, +# so the model scales automatically across MI300X (5.3 TB/s), MI325X (6.0 +# TB/s), MI355X (8.0 TB/s), etc. +# +# PERMUTE (scatter/gather) — random-access token dispatch/combine. Irregular +# access patterns achieve only ~5-7 % of peak HBM bandwidth. +_PERMUTE_BW_FRACTION = 0.057 +# +# ACTIVATION (SwiGLU / GELU) — sequential element-wise ops that stream over +# contiguous buffers. Typically ~55-60 % of peak HBM bandwidth. +_ACTIVATION_BW_FRACTION = 0.566 +# +# Fallback absolute values used when the backend cannot report HBM bandwidth. +_FALLBACK_HBM_BW_GBPS = 5300.0 # MI300X default class MoEMLPProfiler(BaseModuleProfiler): @@ -271,10 +284,22 @@ def _get_simulated_results( # ── 3. Token permutation overhead (dispatch + combine) ── # Dispatch: gather tokens by expert assignment → irregular memory access # Combine: scatter expert outputs back → weighted reduce + # + # Derive effective BW from the target GPU's peak HBM bandwidth so the + # model adapts automatically to different architectures. 
+ peak_hbm = ( + self._gemm_backend.hbm_bandwidth_gbps + if self._gemm_backend is not None + and self._gemm_backend.hbm_bandwidth_gbps is not None + else _FALLBACK_HBM_BW_GBPS + ) + permute_eff_bw_gbps = peak_hbm * _PERMUTE_BW_FRACTION + activation_bw_gbps = peak_hbm * _ACTIVATION_BW_FRACTION + dispatch_bytes = (batch_tokens + topk_tokens) * hidden_size * bytes_per_el combine_bytes = (topk_tokens + batch_tokens) * hidden_size * bytes_per_el - permute_fwd_ms = dispatch_bytes / (_PERMUTE_EFF_BW_GBPS * 1e6) - permute_bwd_ms = combine_bytes / (_PERMUTE_EFF_BW_GBPS * 1e6) + permute_fwd_ms = dispatch_bytes / (permute_eff_bw_gbps * 1e6) + permute_bwd_ms = combine_bytes / (permute_eff_bw_gbps * 1e6) fwd_time += permute_fwd_ms bwd_time += permute_bwd_ms @@ -284,7 +309,7 @@ def _get_simulated_results( act_bytes = 3 * topk_tokens * moe_ffn * bytes_per_el # gate+up read, result write else: act_bytes = 2 * topk_tokens * moe_ffn * bytes_per_el # read + write - activation_ms = act_bytes / (_ACTIVATION_BW_GBPS * 1e6) + activation_ms = act_bytes / (activation_bw_gbps * 1e6) fwd_time += activation_ms bwd_time += activation_ms diff --git a/primus/core/projection/simulation_backends/base.py b/primus/core/projection/simulation_backends/base.py index 542eca1db..b9daea699 100644 --- a/primus/core/projection/simulation_backends/base.py +++ b/primus/core/projection/simulation_backends/base.py @@ -50,6 +50,17 @@ def is_available(self) -> bool: """Return True if this backend can be used in the current environment.""" ... + @property + def hbm_bandwidth_gbps(self) -> Optional[float]: + """Peak HBM bandwidth in GB/s for the target GPU, or *None* if unknown. + + Concrete backends should override this when the target architecture is + known so that downstream profilers (e.g. MoE non-GEMM overhead) can + derive memory-bandwidth estimates from the actual hardware rather than + relying on hardcoded absolute numbers. 
+ """ + return None + @abstractmethod def simulate_gemm( self, diff --git a/primus/core/projection/simulation_backends/factory.py b/primus/core/projection/simulation_backends/factory.py index a9baeb689..114b85cb4 100644 --- a/primus/core/projection/simulation_backends/factory.py +++ b/primus/core/projection/simulation_backends/factory.py @@ -74,29 +74,23 @@ def get_gemm_simulation_backend( def get_sdpa_simulation_backend( gpu_arch: Optional[str] = None, - compute_efficiency: float = 0.51, - memory_efficiency: float = 0.85, gpu_clock_mhz: Optional[int] = None, ) -> SDPASimulationBackend: """ Create and return the SDPA simulation backend. - The default backend is an analytical model of the FAv3 (Flash Attention v3) - kernels, with tile sizes, wavefront counts, and efficiency factors - derived from the kernel configurations. + Uses the Origami 1-CU tile-level model of the FAv3 (Flash Attention v3) + kernels. Origami must be installed. Args: gpu_arch: GPU architecture override (e.g. "mi300x", "mi355x"). - compute_efficiency: Fraction of peak compute achieved (0-1). - Defaults to 0.51. The lower-than-theoretical efficiency accounts - for GQA head broadcasting overhead, LDS bank conflicts, barrier - synchronisation, and register pressure. - memory_efficiency: Fraction of peak HBM bandwidth achieved (0-1). - Defaults to 0.85 — FAv3 streaming pattern typically achieves 0.80-0.90. gpu_clock_mhz: Override the GPU compute clock frequency in MHz. Returns: An SDPASimulationBackend instance. + + Raises: + RuntimeError: If the Origami backend is not available. 
""" from primus.core.projection.simulation_backends.sdpa_simulator import ( SDPASimulator, @@ -105,12 +99,10 @@ def get_sdpa_simulation_backend( is_rank_0 = int(os.getenv("RANK", "0")) == 0 if is_rank_0: print( - "[Primus:Simulation] Using SDPA backend: sdpa_simulator (FAv3 analytical model)" + "[Primus:Simulation] Using SDPA backend: sdpa_simulator (FAv3 Origami 1-CU)" ) return SDPASimulator( gpu_arch=gpu_arch, - compute_efficiency=compute_efficiency, - memory_efficiency=memory_efficiency, gpu_clock_mhz=gpu_clock_mhz, ) diff --git a/primus/core/projection/simulation_backends/origami_backend.py b/primus/core/projection/simulation_backends/origami_backend.py index 5b82c5bd3..94e29d3e1 100644 --- a/primus/core/projection/simulation_backends/origami_backend.py +++ b/primus/core/projection/simulation_backends/origami_backend.py @@ -72,19 +72,20 @@ class _HardwareProfile: lds_capacity: int # bytes l2_capacity: int # bytes (per XCD) compute_clock_khz: int + hbm_bandwidth_gbps: float = 5300.0 # peak HBM bandwidth (GB/s) _KNOWN_PROFILES: Dict[str, _HardwareProfile] = { - # MI300X / gfx942 - "mi300x": _HardwareProfile("gfx942", 304, 65536, 4_194_304, 2_100_000), - "gfx942": _HardwareProfile("gfx942", 304, 65536, 4_194_304, 2_100_000), - # MI325X / gfx942 (same die as MI300X, HBM3E upgrade) - "mi325x": _HardwareProfile("gfx942", 304, 65536, 4_194_304, 2_100_000), - # MI355X / gfx950 - "mi355x": _HardwareProfile("gfx950", 256, 65536, 4_194_304, 2_100_000), - "gfx950": _HardwareProfile("gfx950", 256, 65536, 4_194_304, 2_100_000), + # MI300X / gfx942: HBM3 ~5.3 TB/s + "mi300x": _HardwareProfile("gfx942", 304, 65536, 4_194_304, 2_100_000, 5300.0), + "gfx942": _HardwareProfile("gfx942", 304, 65536, 4_194_304, 2_100_000, 5300.0), + # MI325X / gfx942 (same die as MI300X, HBM3E upgrade): ~6.0 TB/s + "mi325x": _HardwareProfile("gfx942", 304, 65536, 4_194_304, 2_100_000, 6000.0), + # MI355X / gfx950: HBM3E ~8.0 TB/s + "mi355x": _HardwareProfile("gfx950", 256, 65536, 4_194_304, 
2_100_000, 8000.0), + "gfx950": _HardwareProfile("gfx950", 256, 65536, 4_194_304, 2_100_000, 8000.0), # MI300A - "mi300a": _HardwareProfile("gfx942", 228, 65536, 4_194_304, 2_100_000), + "mi300a": _HardwareProfile("gfx942", 228, 65536, 4_194_304, 2_100_000, 4000.0), } # --------------------------------------------------------------------------- @@ -141,6 +142,7 @@ def __init__( self, gpu_arch: Optional[str] = None, gpu_clock_mhz: Optional[int] = None, + n_cu_override: Optional[int] = None, ): """ Args: @@ -150,6 +152,11 @@ def __init__( gpu_clock_mhz: Override the compute clock frequency in MHz. If *None*, uses the profile default or the ``PRIMUS_GPU_CLOCK_MHZ`` env var. + n_cu_override: Override the number of Compute Units used by + Origami's performance model. Set to ``1`` for + per-tile / single-CU simulation (e.g. SDPA tile-level + modelling). If *None*, the profile's default CU + count is used. """ self._gpu_arch = gpu_arch or os.getenv("PRIMUS_GPU_ARCH", None) if self._gpu_arch is not None: @@ -161,6 +168,8 @@ def __init__( int(_env_clock) if _env_clock else None ) + self._n_cu_override = n_cu_override + # Lazily initialised origami objects – see ``_ensure_initialized``. self._hardware = None # origami.hardware_t self._configs = None # list[origami.config_t] @@ -179,6 +188,19 @@ def name(self) -> str: def is_available(self) -> bool: return _try_import_origami() + @property + def hbm_bandwidth_gbps(self) -> Optional[float]: + """Peak HBM bandwidth for the target architecture (GB/s). + + Resolved from the arch profile (``_KNOWN_PROFILES``) or the + ``PRIMUS_GPU_ARCH`` env var. Returns ``None`` only when no + architecture could be determined. 
+ """ + arch = self._gpu_arch or os.getenv("PRIMUS_GPU_ARCH", "mi300x") + arch = arch.lower().strip() + profile = _KNOWN_PROFILES.get(arch) + return profile.hbm_bandwidth_gbps if profile is not None else None + def simulate_gemm( self, m: int, @@ -196,10 +218,26 @@ def simulate_gemm( "#subdirectory=shared/origami/python" ) - # FP8 fallback: if Origami doesn't support FP8 MI for this arch, - # simulate as BF16. On MI300X, Origami BF16 predictions already - # closely match FP8-hybrid measured performance (the natural model - # overestimation roughly offsets the FP8 compute gain). + # FP8 fallback strategy + # ~~~~~~~~~~~~~~~~~~~~~~ + # Origami v0.1.0 maps the Primus "fp8" dtype to "bf8_fnuz" (the FNUZ + # variant used by gfx942 / MI300X), but this datatype string is not + # yet recognised by Origami's ``string_to_datatype``. As a result, + # ``_ensure_initialized("fp8")`` sets ``_fp8_mi_unavailable = True`` + # on *all* current architectures. + # + # When the FP8 MI is unavailable we simulate in BF16 and halve the + # time (``÷ 2.0`` below). This is a first-order approximation: + # FP8 doubles the matrix-instruction throughput relative to BF16 on + # both gfx942 and gfx950. + # + # Origami does expose *native* ``bf8`` / ``f8`` MI for gfx950 + # (MI=16×16×128 vs BF16 MI=16×16×32), but its tile-level performance + # model currently does not translate the higher MI throughput into + # proportional time savings — native FP8 predictions are only ~1.03× + # faster than BF16 instead of the expected ~2×. Until Origami's FP8 + # performance model is updated, the BF16 ÷ 2 heuristic is more + # accurate for all supported architectures. fp8_fallback = False sim_dtype = dtype if dtype == "fp8": @@ -248,9 +286,10 @@ def simulate_gemm( latency_cycles = result.latency time_ms = latency_cycles / (self._clock_ghz * 1e6) - # FP8 speedup: when Origami can't natively simulate FP8 (no FP8 MI) - # we ran the simulation in BF16. 
FP8 has 2x the throughput of BF16 - # on gfx942, so divide the time by 2. + # FP8 speedup heuristic: we simulated in BF16 because Origami could + # not use a native FP8 matrix instruction. FP8 MFMA throughput is 2× + # that of BF16 on both gfx942 and gfx950, so halving the BF16 time is + # a reasonable first-order approximation for compute-bound GEMMs. if fp8_fallback: time_ms /= 2.0 @@ -368,10 +407,11 @@ def _get_hardware(self): clock_khz = profile.compute_clock_khz if self._clock_override_mhz is not None: clock_khz = self._clock_override_mhz * 1000 + n_cu = self._n_cu_override if self._n_cu_override is not None else profile.n_cu arch_enum = getattr(_origami.architecture_t, profile.arch_enum_name) hw = _origami.get_hardware_for_arch( arch_enum, - profile.n_cu, + n_cu, profile.lds_capacity, profile.l2_capacity, clock_khz, @@ -380,10 +420,11 @@ def _get_hardware(self): override_tag = "" if self._clock_override_mhz is not None: override_tag = " (overridden via --gpu-clock-mhz)" + cu_tag = f" (n_cu_override={n_cu})" if self._n_cu_override is not None else "" print( f"[Primus:Origami] Using hardware profile for " - f"'{self._gpu_arch}': N_CU={profile.n_cu}, " - f"clock={clock_khz / 1e6:.1f} GHz{override_tag}" + f"'{self._gpu_arch}': N_CU={n_cu}, " + f"clock={clock_khz / 1e6:.1f} GHz{override_tag}{cu_tag}" ) return hw @@ -421,10 +462,11 @@ def _get_hardware(self): clock_khz = profile.compute_clock_khz if self._clock_override_mhz is not None: clock_khz = self._clock_override_mhz * 1000 + n_cu = self._n_cu_override if self._n_cu_override is not None else profile.n_cu arch_enum = getattr(_origami.architecture_t, profile.arch_enum_name) hw = _origami.get_hardware_for_arch( arch_enum, - profile.n_cu, + n_cu, profile.lds_capacity, profile.l2_capacity, clock_khz, @@ -433,9 +475,10 @@ def _get_hardware(self): override_tag = "" if self._clock_override_mhz is not None: override_tag = " (overridden via --gpu-clock-mhz)" + cu_tag = f" (n_cu_override={n_cu})" if self._n_cu_override 
is not None else "" print( f"[Primus:Origami] Using known hardware profile for " - f"'{self._gpu_arch}': N_CU={profile.n_cu}, " - f"clock={clock_khz / 1e6:.1f} GHz{override_tag}" + f"'{self._gpu_arch}': N_CU={n_cu}, " + f"clock={clock_khz / 1e6:.1f} GHz{override_tag}{cu_tag}" ) return hw diff --git a/primus/core/projection/simulation_backends/sdpa_simulator.py b/primus/core/projection/simulation_backends/sdpa_simulator.py index 1b8052630..8b957321d 100644 --- a/primus/core/projection/simulation_backends/sdpa_simulator.py +++ b/primus/core/projection/simulation_backends/sdpa_simulator.py @@ -15,7 +15,6 @@ • 1 Thread-Group, 8 Wavefronts per workgroup (512 threads) • Q-tile = 256 rows (32m × 8 wavefronts) • KV-tile = 64 columns per loop iteration - • MFMA instruction: ``v_mfma_f32_32x32x16_bf16`` • 64 MFMAs per loop iteration (QKᵀ + softmax + PV, pipelined) • Workgroups = ⌈S / 256⌉ × B × H @@ -28,17 +27,21 @@ • Workgroups = ⌈S / 256⌉ × B × H • Inner-loop iterations = ⌈S / 16⌉ (over Q blocks) -The model uses a **roofline** approach: - time = max(compute_time, memory_time, atomic_time) -with FAv3-specific compute/memory efficiency factors and CU utilisation -derived from the tile sizes. +Tile-level simulation using Origami with a 1-CU backend. +Flash Attention is a fused kernel where sub-operations (QKᵀ, softmax, PV) +execute **sequentially** within each workgroup tile. By running Origami on +a single CU for each per-tile GEMM and then multiplying by the number of +waves (workgroups / N_CU), we capture the additive cost structure and +per-tile wave quantisation effects that a global roofline misses. + +**Origami is required** — this simulator cannot operate without it. In the backward pass, the dQ gradient is accumulated across KV-workgroups using ``buffer_atomic_add_f32`` (72 atomic instructions in the kernel). 
Each KV-workgroup processes all Q positions and atomically adds its partial dQ contribution, leading to contention proportional to ⌈S / 256⌉ concurrent -writers per dQ cache line. The atomic overhead is modelled as a separate -bottleneck dimension in the roofline. +writers per dQ cache line. The atomic overhead is modelled as an additive +cost on top of the compute/memory time. """ from __future__ import annotations @@ -77,6 +80,16 @@ class _FAv3TileConfig: _FAV3_BWD = _FAv3TileConfig(q_tile_m=16, kv_tile_n=256, n_wavefronts=4) +# ========================================================================= +# Backward dQ atomic latencies (for tile-level model) +# ========================================================================= +# dQ is accumulated via buffer_atomic_add_f32 across KV-workgroups. +# Latency estimates for CDNA3 (gfx942) at typical clocks: +_ATOMIC_LATENCY_GLOBAL_NS = 400 # HBM read-modify-write latency per atomic op +_ATOMIC_LATENCY_LOCAL_NS = 40 # L1 / LDS atomic latency per op +_WARP_SIZE = 64 # CDNA wavefront width + + # ========================================================================= # GPU hardware specs # ========================================================================= @@ -97,9 +110,6 @@ class GPUHardwareSpec: # Total CUs on the device n_cu: int = 304 # MI300X - # Max wavefronts per CU (SIMD occupancy limit) - max_waves_per_cu: int = 8 - # Number of XCDs on the device (cross-die atomics are more expensive) n_xcd: int = 8 # MI300X has 8 XCDs @@ -171,7 +181,7 @@ def _get_hardware_spec( _PROFILE_CLOCK_MHZ = { "mi300x": 2100, "gfx942": 2100, - "mi325x": 1200, + "mi325x": 2100, # same gfx942 compute die as MI300X "mi355x": 2100, "gfx950": 2100, "mi300a": 2100, @@ -198,29 +208,24 @@ class SDPASimulator(SDPASimulationBackend): """ Analytical SDPA simulation modelling the FAv3 kernel structure. - The model captures: - 1. **Total FLOPs** from the SDPA math (QKᵀ, softmax, PV for fwd; - dV, dP/dS, dQ, dK, softmax-bwd for bwd). - 2. 
**Flash-Attention memory IO** — Q/K/V are streamed from HBM; - the full S/P matrices are never materialised. - 3. **CU utilisation** — derived from the FAv3 tile sizes and the - number of workgroups that can execute concurrently. - 4. **Achieved efficiency** — higher than generic kernels because - FAv3 is hand-tuned ISA with software pipelining and LDS-based - data movement. - 5. **Atomic overhead (BWD only)** — dQ is accumulated across - KV-workgroups via ``buffer_atomic_add_f32`` in FP32. The model - accounts for the read-modify-write penalty and contention from - ⌈S / 256⌉ concurrent writers per dQ cache line. + Uses an Origami GEMM backend with ``n_cu=1`` to simulate per-tile + GEMM execution time. Flash Attention is modelled as a fused kernel + where QKᵀ, softmax, and PV are sequential within each workgroup. + The total time = (per-tile-QKᵀ + per-tile-PV) × num_waves. This + naturally captures wave quantisation and per-tile efficiency without + needing an empirical ``compute_efficiency`` parameter. + + Also models the backward dQ atomic overhead from + ``buffer_atomic_add_f32`` accumulation across KV-workgroups. + + **Origami is required** — instantiation will fail if the Origami + backend is not available. """ def __init__( self, gpu_arch: Optional[str] = None, hardware_spec: Optional[GPUHardwareSpec] = None, - compute_efficiency: float = 0.51, - memory_efficiency: float = 0.85, - atomic_rmw_factor: float = 4.0, gpu_clock_mhz: Optional[int] = None, ): """ @@ -228,29 +233,30 @@ def __init__( gpu_arch: GPU architecture string (e.g. "mi300x", "gfx942", "mi355x", "gfx950"). hardware_spec: Override hardware spec directly. - compute_efficiency: Fraction of peak TFLOPS achieved (0-1). - The lower-than-peak efficiency (vs theoretical 0.75-0.85) - accounts for GQA head broadcasting, LDS bank conflicts, - barrier synchronisation, and register pressure. - memory_efficiency: Fraction of peak HBM bandwidth achieved (0-1). 
- FAv3 streaming pattern typically achieves 0.80-0.90. - atomic_rmw_factor: Base slowdown of ``buffer_atomic_add_f32`` - relative to a plain ``buffer_store`` (read-modify-write - overhead). Typical range 3-6 on CDNA3. Contention from - multiple writers is modelled *on top* of this factor. gpu_clock_mhz: Override the GPU compute clock frequency in MHz. If provided, the profile's TFLOPS are scaled proportionally. + + Raises: + RuntimeError: If the Origami backend is not available. """ self._hw = hardware_spec or _get_hardware_spec(gpu_arch, gpu_clock_mhz) - self._compute_eff = compute_efficiency - self._memory_eff = memory_efficiency - self._atomic_rmw_factor = atomic_rmw_factor + + # Create the Origami 1-CU backend for tile-level simulation. + # This is required — SDPA simulation cannot proceed without it. + self._tile_gemm = self._create_tile_gemm_backend(gpu_arch, gpu_clock_mhz) + if self._tile_gemm is None: + raise RuntimeError( + "SDPASimulator requires the Origami backend but it is not " + "available. Please install the 'origami' package or ensure " + "primus.core.projection.simulation_backends.origami_backend " + "is importable." + ) def name(self) -> str: return "sdpa_simulator (FAv3)" def is_available(self) -> bool: - return True # Pure-Python analytical model, always available + return self._tile_gemm is not None and self._tile_gemm.is_available() # ------------------------------------------------------------------ # Public API @@ -269,8 +275,8 @@ def simulate_sdpa( head_dim_v: Optional[int] = None, ) -> SimulationResult: """ - Simulate FAv3 SDPA execution time using a roofline model - parameterised by the actual FAv3 tile configuration. + Simulate FAv3 SDPA execution time using Origami 1-CU tile-level + simulation parameterised by the actual FAv3 tile configuration. Args: batch_size: Batch size (B). 
@@ -299,162 +305,214 @@ def simulate_sdpa( D_v = head_dim_v if head_dim_v is not None else head_dim bpe = self._bytes_per_element(dtype) - # GQA ratio: each KV head serves (H_Q / H_K) query heads. - # The FLOPs are still per-query-head, so total FLOPs scale with H_Q. - # Memory for K/V scales with H_K, memory for Q/O scales with H_Q. + return self._simulate_tile_level( + B, H_Q, S_Q, S_K, H_K, D_qk, D_v, causal, dtype, bpe, + ) + # ------------------------------------------------------------------ + # Tile-level simulation (primary mode — requires Origami) + # ------------------------------------------------------------------ + + def _create_tile_gemm_backend( + self, + gpu_arch: Optional[str], + gpu_clock_mhz: Optional[int], + ): + """Try to create an Origami backend with 1 CU for per-tile simulation. + + Returns the backend on success, or ``None`` if Origami is not available. + """ + try: + from primus.core.projection.simulation_backends.origami_backend import ( + OrigamiGEMMBackend, + ) + + backend = OrigamiGEMMBackend( + gpu_arch=gpu_arch, + gpu_clock_mhz=gpu_clock_mhz, + n_cu_override=1, + ) + if backend.is_available(): + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + if is_rank_0: + print( + "[Primus:SDPA] Using Origami 1-CU tile-level simulation " + "for Flash Attention" + ) + return backend + except Exception: + pass + return None + + def _simulate_tile_level( + self, + B: int, + H_Q: int, + S_Q: int, + S_K: int, + H_K: int, + D_qk: int, + D_v: int, + causal: bool, + dtype: str, + bpe: int, + ) -> SimulationResult: + """ + Tile-level SDPA simulation using Origami on a single CU. + + Flash Attention is a **fused** kernel — within each workgroup, the + sub-operations (QKᵀ, softmax, PV) execute **sequentially** and the + intermediate S/P matrices stay in LDS/registers (never written to + HBM). This means the correct timing model is **additive** across + sub-operations per tile, not a global ``max(compute, memory)`` + roofline. + + Approach: + 1. 
Simulate each per-tile GEMM using Origami with ``n_cu=1``. + This captures tile-level wave quantisation, LDS traffic, and + pipeline effects that a global FLOP-rate model misses. + 2. Sum the per-tile GEMM times (additive, sequential execution). + 3. Multiply by ``num_waves = ⌈workgroups / N_CU⌉`` to account for + CU-level parallelism across tiles. + 4. Add dQ atomic overhead (backward only). + + Forward per-workgroup (q_tile_m=256 Q rows, sweeps all S_K): + QKᵀ: Q_tile[256, D_qk] × Kᵀ[D_qk, S_K] → S[256, S_K] + PV: P_tile[256, S_K] × V[S_K, D_v] → O[256, D_v] + Workgroups = ⌈S_Q / 256⌉ × B × H_Q + + Backward per-workgroup (kv_tile_n=256 KV cols, sweeps all S_Q): + 5 GEMMs per workgroup (FA backward algorithm): + 1. QKᵀ recompute: Q[S_Q, D_qk] × Kᵀ[D_qk, 256] → S[S_Q, 256] + 2. dP = dO × Vᵀ: dO[S_Q, D_v] × Vᵀ[D_v, 256] → dP[S_Q, 256] + 3. dV = Pᵀ × dO: Pᵀ[256, S_Q] × dO[S_Q, D_v] → dV[256, D_v] + 4. dQ = dS × K: dS[S_Q, 256] × K[256, D_qk] → dQ[S_Q, D_qk] + 5. dK = dSᵀ × Q: dSᵀ[256, S_Q] × Q[S_Q, D_qk] → dK[256, D_qk] + Workgroups = ⌈S_K / 256⌉ × B × H_Q + """ + assert self._tile_gemm is not None + N_CU = self._hw.n_cu causal_factor = 0.5 if causal else 1.0 # ============================================================== - # 1. COMPUTE (FLOP counts) + # FORWARD # ============================================================== - # Forward (per query head, then × H_Q) - # QKᵀ : 2·B·H_Q·S_Q·S_K·D_qk (batched GEMM) - # softmax : ~5·B·H_Q·S_Q·S_K (exp, sub-max, sum, div, mul) - # PV : 2·B·H_Q·S_Q·S_K·D_v (batched GEMM — P is S_Q×S_K, V is S_K×D_v) - # NOTE for PV: output is (S_Q, D_v), inner dim is S_K. - # For causal masking, only ~half the S_Q×S_K elements are computed - # (only valid when S_Q == S_K; for cross-attn causal is usually False). - # - # When D_qk == D_v (standard MHA/GQA) this reduces to the familiar - # 2 × (2·B·H·S·S·D) formula. For MLA, D_qk > D_v (e.g. 192 vs 128). 
- fwd_qk_flops = 2.0 * B * H_Q * S_Q * S_K * D_qk * causal_factor - fwd_pv_flops = 2.0 * B * H_Q * S_Q * S_K * D_v * causal_factor - fwd_softmax_flops = 5.0 * B * H_Q * S_Q * S_K * causal_factor - fwd_flops = fwd_qk_flops + fwd_pv_flops + fwd_softmax_flops - - # Backward (4 batched GEMMs + softmax backward) - # dV = Pᵀ @ dO : 2·B·H_Q·S_K·S_Q·D_v (inner dim S_Q, out S_K×D_v) - # dP = dO @ Vᵀ : 2·B·H_Q·S_Q·S_K·D_v (inner dim D_v, out S_Q×S_K) - # dS = softmax_bwd : ~5·B·H_Q·S_Q·S_K - # dQ = dS @ K : 2·B·H_Q·S_Q·S_K·D_qk (inner dim S_K, out S_Q×D_qk) - # dK = dSᵀ @ Q : 2·B·H_Q·S_K·S_Q·D_qk (inner dim S_Q, out S_K×D_qk) - bwd_dv_flops = 2.0 * B * H_Q * S_K * S_Q * D_v * causal_factor - bwd_dp_flops = 2.0 * B * H_Q * S_Q * S_K * D_v * causal_factor - bwd_dq_flops = 2.0 * B * H_Q * S_Q * S_K * D_qk * causal_factor - bwd_dk_flops = 2.0 * B * H_Q * S_K * S_Q * D_qk * causal_factor - bwd_softmax_flops = 5.0 * B * H_Q * S_Q * S_K * causal_factor - bwd_flops = bwd_dv_flops + bwd_dp_flops + bwd_dq_flops + bwd_dk_flops + bwd_softmax_flops + fwd_n_wgs = math.ceil(S_Q / _FAV3_FWD.q_tile_m) * B * H_Q + fwd_waves = math.ceil(fwd_n_wgs / N_CU) + + # Per-workgroup GEMMs on 1 CU (tile sweeps all S_K positions): + # QKᵀ: [q_tile_m, D_qk, S_K] + r_fwd_qk = self._tile_gemm.simulate_gemm( + m=_FAV3_FWD.q_tile_m, n=S_K, k=D_qk, dtype=dtype, + ) + # PV: [q_tile_m, S_K, D_v] + r_fwd_pv = self._tile_gemm.simulate_gemm( + m=_FAV3_FWD.q_tile_m, n=D_v, k=S_K, dtype=dtype, + ) + + fwd_time_ms = ( + r_fwd_qk.forward_time_ms + r_fwd_pv.forward_time_ms + ) * fwd_waves # ============================================================== - # 2. MEMORY IO (Flash Attention – no S/P materialised to HBM) + # BACKWARD # ============================================================== - # Q and K use D_qk; V and O use D_v. 
- # Forward reads: Q (B·H_Q·S_Q·D_qk), K (B·H_K·S_K·D_qk), V (B·H_K·S_K·D_v) - # Forward writes: O (B·H_Q·S_Q·D_v) + logsumexp (B·H_Q·S_Q, fp32) - fwd_read_bytes = ( + kv_tile = _FAV3_BWD.kv_tile_n # 256 + bwd_n_wgs = math.ceil(S_K / kv_tile) * B * H_Q + bwd_waves = math.ceil(bwd_n_wgs / N_CU) + + # Per-workgroup GEMMs (5 operations, full Q-sweep on 1 CU): + # 1. QKᵀ recompute: [S_Q, D_qk, kv_tile] + r_bwd_qk = self._tile_gemm.simulate_gemm( + m=S_Q, n=kv_tile, k=D_qk, dtype=dtype, + ) + # 2. dP = dO × Vᵀ: [S_Q, D_v, kv_tile] + r_bwd_dp = self._tile_gemm.simulate_gemm( + m=S_Q, n=kv_tile, k=D_v, dtype=dtype, + ) + # 3. dV = Pᵀ × dO: [kv_tile, S_Q, D_v] + r_bwd_dv = self._tile_gemm.simulate_gemm( + m=kv_tile, n=D_v, k=S_Q, dtype=dtype, + ) + # 4. dQ = dS × K: [S_Q, kv_tile, D_qk] + r_bwd_dq = self._tile_gemm.simulate_gemm( + m=S_Q, n=D_qk, k=kv_tile, dtype=dtype, + ) + # 5. dK = dSᵀ × Q: [kv_tile, S_Q, D_qk] + r_bwd_dk = self._tile_gemm.simulate_gemm( + m=kv_tile, n=D_qk, k=S_Q, dtype=dtype, + ) + + bwd_compute_ms = ( + r_bwd_qk.forward_time_ms + + r_bwd_dp.forward_time_ms + + r_bwd_dv.forward_time_ms + + r_bwd_dq.forward_time_ms + + r_bwd_dk.forward_time_ms + ) * bwd_waves + + # ── Backward dQ atomics (latency-based model) ── + # Each KV-workgroup atomically accumulates dQ via buffer_atomic_add_f32. + # The latency model counts warp-level reduction updates (global and + # local) and multiplies by the per-op latency. 
+ num_k_tiles = math.ceil(kv_tile / kv_tile) # = 1 + warp_updates_global = math.ceil( + num_k_tiles * math.ceil(D_qk / _WARP_SIZE) + ) + total_updates_global = warp_updates_global * bwd_waves + + warp_updates_local = math.ceil( + kv_tile * math.ceil(D_qk / _WARP_SIZE) + ) + total_updates_local = warp_updates_local * bwd_waves + + bwd_atomic_ms = ( + _ATOMIC_LATENCY_GLOBAL_NS * total_updates_global + + _ATOMIC_LATENCY_LOCAL_NS * total_updates_local + ) / 1e6 # ns → ms + + bwd_time_ms = bwd_compute_ms + bwd_atomic_ms + + # ============================================================== + # METADATA (FLOPs, bytes — for achieved-TFLOPS reporting) + # ============================================================== + fwd_flops = ( + 2.0 * B * H_Q * S_Q * S_K * D_qk # QKᵀ + + 2.0 * B * H_Q * S_Q * S_K * D_v # PV + + 5.0 * B * H_Q * S_Q * S_K # softmax + ) * causal_factor + + bwd_flops = ( + 2.0 * B * H_Q * S_Q * S_K * D_qk # QKᵀ recomp + + 2.0 * B * H_Q * S_Q * S_K * D_v # dP + + 2.0 * B * H_Q * S_K * S_Q * D_v # dV + + 2.0 * B * H_Q * S_Q * S_K * D_qk # dQ + + 2.0 * B * H_Q * S_K * S_Q * D_qk # dK + + 5.0 * B * H_Q * S_Q * S_K # softmax bwd + ) * causal_factor + + fwd_bytes = ( B * H_Q * S_Q * D_qk * bpe # Q + B * H_K * S_K * D_qk * bpe # K + B * H_K * S_K * D_v * bpe # V + + B * H_Q * S_Q * D_v * bpe # O + + B * H_Q * S_Q * 4 # logsumexp (fp32) ) - fwd_write_bytes = ( - B * H_Q * S_Q * D_v * bpe + B * H_Q * S_Q * 4 # O # logsumexp (fp32) - ) - fwd_bytes = fwd_read_bytes + fwd_write_bytes - - # Backward reads: Q, K, V, O, dO + logsumexp - # Backward regular writes: dK (B·H_K·S_K·D_qk) + dV (B·H_K·S_K·D_v) - # NOTE: dQ uses buffer_atomic_add_f32 — accounted separately. 
- bwd_read_bytes = ( + bwd_bytes = ( B * H_Q * S_Q * D_qk * bpe # Q + B * H_K * S_K * D_qk * bpe # K + B * H_K * S_K * D_v * bpe # V + B * H_Q * S_Q * D_v * bpe # O + B * H_Q * S_Q * D_v * bpe # dO + B * H_Q * S_Q * 4 # logsumexp (fp32) - ) - bwd_regular_write_bytes = ( - B * H_K * S_K * D_qk * bpe # dK + + B * H_K * S_K * D_qk * bpe # dK + B * H_K * S_K * D_v * bpe # dV ) - bwd_bytes = bwd_read_bytes + bwd_regular_write_bytes - - # ============================================================== - # 3. dQ ATOMIC OVERHEAD (BWD only) - # ============================================================== - # In FAv3 backward, each KV-workgroup loops over ALL Q positions - # and atomically accumulates its partial dQ via buffer_atomic_add_f32. - # - # From the FAv3 backward kernel: - # - 72 buffer_atomic_add_f32 instructions in the kernel - # - 8 atomics per Q-block (per wavefront, 64 threads each) - # - 4 wavefronts per workgroup - # - Per Q-block: 8 × 64 × 4W = 2048 atomic ops = 8 KB (FP32) - # = 16 rows × 128 cols × 4 bytes = 8192 bytes ✓ - # - # Contention & L2 coalescing: - # ceil(S_K/256) KV-workgroups all write to the same dQ rows. - # Workgroups on the SAME XCD can coalesce their atomics in the - # local L2 cache (the add is accumulated in L2, only the final - # value is flushed to HBM). So the effective number of HBM - # atomic writes per dQ element is min(n_kv_wgs, n_xcd) rather - # than the full n_kv_wgs. - # - # Each HBM atomic write is a read-modify-write, which costs - # ~rmw_factor × the bandwidth of a regular store. 
- n_kv_workgroups = math.ceil(S_K / _FAV3_BWD.kv_tile_n) - - # How many KV-workgroups per XCD (for L2 coalescing estimate) - hbm_writers_per_element = min(n_kv_workgroups, self._hw.n_xcd) - - # Effective dQ bytes hitting HBM (after L2 coalescing) - # dQ shape is (B, H_Q, S_Q, D_qk), stored in FP32 (4 bytes) - dq_atomic_bytes = float(hbm_writers_per_element) * B * H_Q * S_Q * D_qk * 4.0 - - # Atomic slowdown = just the RMW factor (contention within-XCD - # is absorbed by L2; cross-XCD traffic goes to different memory - # channels and can proceed in parallel) - atomic_slowdown = self._atomic_rmw_factor - - # ============================================================== - # 4. CU UTILISATION (from FAv3 tile config) - # ============================================================== - fwd_cu_util = self._cu_utilisation(B, H_Q, S_Q, _FAV3_FWD) - bwd_cu_util = self._cu_utilisation(B, H_Q, S_K, _FAV3_BWD) - - # ============================================================== - # 5. ROOFLINE: time = max(compute, memory, atomics) - # ============================================================== - peak_tflops = self._peak_tflops(dtype) - - # Effective throughput = peak × efficiency × CU utilisation - fwd_eff_tflops = peak_tflops * self._compute_eff * fwd_cu_util - bwd_eff_tflops = peak_tflops * self._compute_eff * bwd_cu_util - - fwd_eff_bw = self._hw.hbm_bandwidth_gbps * self._memory_eff - bwd_eff_bw = self._hw.hbm_bandwidth_gbps * self._memory_eff - - # Effective atomic bandwidth (HBM BW reduced by RMW + contention) - bwd_eff_atomic_bw = ( - self._hw.hbm_bandwidth_gbps * self._memory_eff / atomic_slowdown - ) - - # Compute-bound time (ms) - fwd_compute_ms = (fwd_flops / (fwd_eff_tflops * 1e12)) * 1e3 - bwd_compute_ms = (bwd_flops / (bwd_eff_tflops * 1e12)) * 1e3 - - # Memory-bound time (ms) — regular (non-atomic) IO - fwd_memory_ms = (fwd_bytes / (fwd_eff_bw * 1e9)) * 1e3 - bwd_memory_ms = (bwd_bytes / (bwd_eff_bw * 1e9)) * 1e3 - # Atomic-bound time (ms) — dQ accumulation 
via buffer_atomic_add_f32 - bwd_atomic_ms = (dq_atomic_bytes / (bwd_eff_atomic_bw * 1e9)) * 1e3 - - fwd_time_ms = max(fwd_compute_ms, fwd_memory_ms) - bwd_time_ms = max(bwd_compute_ms, bwd_memory_ms, bwd_atomic_ms) - - # Achieved metrics fwd_achieved_tflops = ( (fwd_flops / (fwd_time_ms * 1e-3)) / 1e12 if fwd_time_ms > 0 else 0 ) - # Determine what bounds each pass - bwd_bottleneck = "compute" - if bwd_atomic_ms >= bwd_compute_ms and bwd_atomic_ms >= bwd_memory_ms: - bwd_bottleneck = "atomic" - elif bwd_memory_ms >= bwd_compute_ms: - bwd_bottleneck = "memory" - return SimulationResult( forward_time_ms=fwd_time_ms, backward_time_ms=bwd_time_ms, @@ -463,13 +521,14 @@ def simulate_sdpa( (fwd_bytes / (fwd_time_ms * 1e-3)) / 1e9 if fwd_time_ms > 0 else 0 ), metadata={ - "backend": "sdpa_simulator (FAv3)", - "fwd_compute_bound": fwd_compute_ms >= fwd_memory_ms, - "fwd_compute_ms": fwd_compute_ms, - "fwd_memory_ms": fwd_memory_ms, - "bwd_bottleneck": bwd_bottleneck, + "backend": "sdpa_simulator (FAv3 tile-level, Origami 1-CU)", + # Standard metadata keys (for compatibility) + "fwd_compute_bound": True, + "fwd_compute_ms": fwd_time_ms, + "fwd_memory_ms": 0.0, # included in per-tile Origami model + "bwd_bottleneck": "compute+atomic", "bwd_compute_ms": bwd_compute_ms, - "bwd_memory_ms": bwd_memory_ms, + "bwd_memory_ms": 0.0, # included in per-tile Origami model "bwd_atomic_ms": bwd_atomic_ms, "fwd_flops": fwd_flops, "bwd_flops": bwd_flops, @@ -479,23 +538,25 @@ def simulate_sdpa( "seq_len_kv": S_K, "num_heads_q": H_Q, "num_heads_kv": H_K, - # dQ atomic details (buffer_atomic_add_f32) - "bwd_dq_kv_workgroups": n_kv_workgroups, - "bwd_dq_hbm_writers_per_elem": hbm_writers_per_element, - "bwd_dq_atomic_hbm_bytes": dq_atomic_bytes, - "bwd_dq_rmw_factor": atomic_slowdown, - "bwd_eff_atomic_bw_gbps": bwd_eff_atomic_bw, - # CU utilisation - "fwd_cu_utilisation": fwd_cu_util, - "bwd_cu_utilisation": bwd_cu_util, "causal": causal, + # Tile-level details + "fwd_waves": fwd_waves, + 
"fwd_n_workgroups": fwd_n_wgs, + "fwd_qk_per_tile_ms": r_fwd_qk.forward_time_ms, + "fwd_pv_per_tile_ms": r_fwd_pv.forward_time_ms, + "bwd_waves": bwd_waves, + "bwd_n_workgroups": bwd_n_wgs, + "bwd_qk_recomp_per_tile_ms": r_bwd_qk.forward_time_ms, + "bwd_dp_per_tile_ms": r_bwd_dp.forward_time_ms, + "bwd_dv_per_tile_ms": r_bwd_dv.forward_time_ms, + "bwd_dq_per_tile_ms": r_bwd_dq.forward_time_ms, + "bwd_dk_per_tile_ms": r_bwd_dk.forward_time_ms, + "n_cu": N_CU, # FAv3 tile parameters "fwd_q_tile_m": _FAV3_FWD.q_tile_m, "fwd_kv_tile_n": _FAV3_FWD.kv_tile_n, - "fwd_wavefronts": _FAV3_FWD.n_wavefronts, "bwd_q_tile_m": _FAV3_BWD.q_tile_m, "bwd_kv_tile_n": _FAV3_BWD.kv_tile_n, - "bwd_wavefronts": _FAV3_BWD.n_wavefronts, }, ) @@ -506,50 +567,3 @@ def simulate_sdpa( def _bytes_per_element(self, dtype: str) -> int: return {"bf16": 2, "fp16": 2, "fp32": 4, "fp8": 1}.get(dtype, 2) - def _peak_tflops(self, dtype: str) -> float: - return { - "bf16": self._hw.peak_tflops_bf16, - "fp16": self._hw.peak_tflops_fp16, - "fp8": self._hw.peak_tflops_fp8, - "fp32": self._hw.peak_tflops_bf16 / 4, - }.get(dtype, self._hw.peak_tflops_bf16) - - def _cu_utilisation( - self, - batch_size: int, - num_heads: int, - seq_len: int, - tile_cfg: _FAv3TileConfig, - ) -> float: - """ - Estimate CU utilisation for a FAv3 kernel launch. - - FAv3 forward dispatches one workgroup per Q-tile per (batch, head). - FAv3 backward dispatches one workgroup per KV-tile per (batch, head). - - Each workgroup occupies ``n_wavefronts`` wavefront slots on a CU. - If the workgroup uses fewer than ``max_waves_per_cu`` wavefronts, - multiple workgroups *may* share a CU (higher occupancy). 
- - CU utilisation = min(active_CUs, N_CU) / N_CU - """ - # Number of workgroups - # For FWD: each wg handles q_tile_m rows → ceil(S / q_tile_m) wgs per (B,H) - # For BWD: each wg handles kv_tile_n cols → ceil(S / kv_tile_n) wgs per (B,H) - if tile_cfg is _FAV3_FWD: - n_tiles = math.ceil(seq_len / tile_cfg.q_tile_m) - else: - # BWD: workgroups over KV dimension - n_tiles = math.ceil(seq_len / tile_cfg.kv_tile_n) - - n_workgroups = n_tiles * batch_size * num_heads - - # How many workgroups can share a single CU? - wgs_per_cu = self._hw.max_waves_per_cu // tile_cfg.n_wavefronts - wgs_per_cu = max(wgs_per_cu, 1) - - # Effective CU slots - cu_slots = self._hw.n_cu * wgs_per_cu - active_slots = min(n_workgroups, cu_slots) - - return active_slots / cu_slots From 23291fd05576cf701e5cdecdc25fa311d97ac48c Mon Sep 17 00:00:00 2001 From: Anshu Raina Date: Sat, 21 Feb 2026 00:17:39 +0000 Subject: [PATCH 06/12] fix: correct A2A message size for TP>1 and TP-striding in collective model - A2A operates on S/TP tokens per GPU due to sequence parallelism. Divide dispatch_size by max(tp, 1) in both _estimate_ep_communication_overhead and calculate_collective_communication_time. - Account for TP striding (hp) in single_shot_alltoall, hierarchical_alltoall, and pxn_alltoall: effective EP ranks per node = node_size / hp. 
--- .../module_profilers/collective_model.py | 21 +++++++++++++------ .../performance_projection/projection.py | 14 +++++++++---- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/primus/core/projection/module_profilers/collective_model.py b/primus/core/projection/module_profilers/collective_model.py index f8a929d4e..66d11f1fc 100644 --- a/primus/core/projection/module_profilers/collective_model.py +++ b/primus/core/projection/module_profilers/collective_model.py @@ -443,7 +443,10 @@ def single_shot_alltoall(args, msg_size, gpus, groups=None, protocol=None): return 0 intra_node_fanout, inter_node_fanout = get_max_fanout(args) msg_size_per_peer = ceil(msg_size / gpus) - gpus_per_node = min(gpus, args.node_size) + # Account for TP striding: with TP (hp) > 1, each EP rank occupies + # hp GPUs, so only node_size/hp EP ranks fit on a single node. + hp = getattr(args, "hp", 1) + gpus_per_node = min(gpus, args.node_size // max(hp, 1)) nics_per_node = args.nics_per_node if args.nics_per_node else gpus_per_node intra_node_gpus = gpus_per_node - 1 inter_node_gpus = max(0, gpus - gpus_per_node) @@ -487,8 +490,11 @@ def hierarchical_alltoall(args, msg_size, gpus, groups=None, protocol=None): if gpus == 1 or msg_size == 0: return 0 - gpus_per_node = min(gpus, args.node_size) - num_nodes = ceil(gpus / args.node_size) + # Account for TP striding: with TP (hp) > 1, each EP rank occupies + # hp GPUs, so only node_size/hp EP ranks fit on a single node. 
+ hp = getattr(args, "hp", 1) + gpus_per_node = min(gpus, args.node_size // max(hp, 1)) + num_nodes = ceil(gpus / max(gpus_per_node, 1)) nics_per_node = args.nics_per_node if args.nics_per_node else gpus_per_node if num_nodes == 1: @@ -535,8 +541,11 @@ def pxn_alltoall(args, msg_size, gpus, groups=None, protocol=None): original_msg_size = msg_size - # Nodes participating in the exchange - num_nodes = ceil(gpus / args.node_size) + # Account for TP striding: with TP (hp) > 1, each EP rank occupies + # hp GPUs, so only node_size/hp EP ranks fit on a single node. + hp = getattr(args, "hp", 1) + effective_gpus_per_node = min(gpus, args.node_size // max(hp, 1)) + num_nodes = ceil(gpus / max(effective_gpus_per_node, 1)) # If A2A is not crossing node boundaries, fall back to regular alltoall if num_nodes <= 1: @@ -564,7 +573,7 @@ def pxn_alltoall(args, msg_size, gpus, groups=None, protocol=None): scaleup_delay = node_lat + chunk_size / args.node_bw * 1.0e-3 # Assume PXN style alltoall with overlapped scale-up and scale-out - node_msg_size = int(original_msg_size * (args.node_size - 1) / args.node_size) + node_msg_size = int(original_msg_size * (effective_gpus_per_node - 1) / effective_gpus_per_node) scale_out_msg_size = int(original_msg_size * (num_nodes - 1) / num_nodes) # Calculate latencies with protocol inflation diff --git a/primus/core/projection/performance_projection/projection.py b/primus/core/projection/performance_projection/projection.py index 3b04e8d1f..6a27e544f 100644 --- a/primus/core/projection/performance_projection/projection.py +++ b/primus/core/projection/performance_projection/projection.py @@ -294,9 +294,11 @@ def calculate_collective_communication_time( message_info["moe_ar_no_overlap"] = False # 2. MoE All-to-All (EP group) + # With TP > 1 and sequence parallelism, each GPU holds S/TP tokens. + # The A2A dispatches these S/TP tokens; AG(TP) recovers full S after A2A. 
if ep > 1 and num_moe_layers > 0: - tokens_per_batch = seq_len * batch_size - dispatch_size = tokens_per_batch * hidden_size * moe_router_topk * 2 # BF16 + tokens_per_gpu = seq_len * batch_size // max(tp, 1) # S/TP with seq parallel + dispatch_size = tokens_per_gpu * hidden_size * moe_router_topk * 2 # BF16 a2a_dispatch = cm.alltoall(coll_args, dispatch_size, ep, groups=["ep"]) a2a_combine = cm.alltoall(coll_args, dispatch_size, ep, groups=["ep"]) @@ -1093,13 +1095,17 @@ def _estimate_ep_communication_overhead( ) # Calculate All-to-All message size for MoE layers + # With TP > 1 and sequence parallelism, each GPU holds S/TP tokens. + # The AlltoAll dispatcher sends these S/TP tokens across the EP group, + # then AllGather(TP) recovers full S tokens after A2A (see workflow in + # MoEAlltoAllTokenDispatcher: step 3 = A2A(EP), step 4 = AG(TP)). hidden_size = model_config.hidden_size batch_size = runtime_config.micro_batch_size seq_len = runtime_config.sequence_length moe_router_topk = getattr(model_config, "moe_router_topk", 2) - tokens_per_batch = seq_len * batch_size - dispatch_size = tokens_per_batch * hidden_size * moe_router_topk * 2 # BF16 + tokens_per_gpu = seq_len * batch_size // max(tp, 1) # S/TP with seq parallel + dispatch_size = tokens_per_gpu * hidden_size * moe_router_topk * 2 # BF16 # Calculate All-to-All time for original EP (dispatch + combine) a2a_dispatch_original = cm.alltoall( From 68a793515d9483b5aa9196ca8e0f480b1f1ca684 Mon Sep 17 00:00:00 2001 From: Anshu Raina Date: Sat, 21 Feb 2026 01:40:04 +0000 Subject: [PATCH 07/12] refactor: simplify FSDP overlap model to single 93% uniform factor Replace the three separate per-phase overlap percentages (FWD_AG=90%, BWD_AG=24%, RS=34%) with a single uniform 93% overlap applied to all FSDP communication (AllGather fwd, AllGather recompute, ReduceScatter). 93% is observed on the actual run for llama3-70b. 
--- .../performance_projection/projection.py | 39 ++++--------------- 1 file changed, 7 insertions(+), 32 deletions(-) diff --git a/primus/core/projection/performance_projection/projection.py b/primus/core/projection/performance_projection/projection.py index 6a27e544f..f16220ee3 100644 --- a/primus/core/projection/performance_projection/projection.py +++ b/primus/core/projection/performance_projection/projection.py @@ -437,22 +437,6 @@ def calculate_collective_communication_time( else: message_info["gradient_allreduce_overlapped"] = False - # FSDP overlap model - # --------------------------------------------------------------- - # FSDP2 prefetches next layer's AllGather while the current layer - # computes, and ReduceScatter runs after backward completes. - # Overlap differs significantly between forward and backward: - # - # Forward AG: ~90-95% overlap — prefetch hides AG behind compute - # Backward AG: ~24% overlap — eager prefetch finishes long before - # the compute stream is ready (dependency - # chain through previous layer's backward) - # RS: ~34% overlap — inherently post-compute, only partially - # overlaps with next layer's recompute AG - # - # These per-phase percentages are largely model-independent for FSDP2 - # with full recompute; the overall overlap is ~50% for 70B-class models - # and ~64% without recompute. 
if use_fsdp and dp > 1: overlap_fsdp = getattr(mp_config, "use_torch_fsdp2", False) if overlap_fsdp: @@ -476,25 +460,16 @@ def calculate_collective_communication_time( fwd_ag_total = total_fsdp_ag bwd_ag_total = 0.0 - # Per-phase overlap percentages - FWD_AG_OVERLAP = 0.90 # forward AG hidden behind compute - BWD_AG_OVERLAP = 0.24 # backward recompute AG (structural limit) - RS_OVERLAP = 0.34 # ReduceScatter (structural limit) - - hidden_fwd_ag = fwd_ag_total * FWD_AG_OVERLAP - hidden_bwd_ag = bwd_ag_total * BWD_AG_OVERLAP - hidden_rs = total_fsdp_rs * RS_OVERLAP + # Overlap factor applied uniformly to all FSDP + # communication - (AllGather fwd, AllGather recompute, ReduceScatter). + FSDP_OVERLAP = 0.93 - total_hidden = hidden_fwd_ag + hidden_bwd_ag + hidden_rs + total_fsdp = total_fsdp_ag + total_fsdp_rs + total_hidden = total_fsdp * FSDP_OVERLAP total_comm_time -= total_hidden message_info["fsdp_overlapped"] = True - message_info["fsdp_fwd_ag_overlap"] = FWD_AG_OVERLAP - message_info["fsdp_bwd_ag_overlap"] = BWD_AG_OVERLAP - message_info["fsdp_rs_overlap"] = RS_OVERLAP - total_fsdp = total_fsdp_ag + total_fsdp_rs - message_info["fsdp_overall_overlap"] = ( - total_hidden / total_fsdp if total_fsdp > 0 else 0 - ) + message_info["fsdp_overlap"] = FSDP_OVERLAP + message_info["fsdp_overall_overlap"] = FSDP_OVERLAP message_info["fsdp_exposed_ms"] = total_fsdp - total_hidden else: message_info["fsdp_overlapped"] = False From c73dbd8ab4e6d4ff941e0032dec7d705d67deb86 Mon Sep 17 00:00:00 2001 From: Anshu Raina Date: Sat, 21 Feb 2026 02:02:30 +0000 Subject: [PATCH 08/12] style: run black formatter on projection module Fix formatting issues flagged by CI code-lint job. 
--- .../memory_projection/projection.py | 20 ++- .../projection/module_profilers/attention.py | 8 +- .../module_profilers/collective_args.py | 4 +- .../module_profilers/collective_model.py | 138 ++++++++++----- .../projection/module_profilers/embedding.py | 17 +- .../module_profilers/language_model.py | 165 +++++++++++++----- .../projection/module_profilers/moe_mlp.py | 119 ++++++++----- .../module_profilers/output_layer.py | 17 +- .../projection/module_profilers/router.py | 4 +- .../module_profilers/transformer_layer.py | 87 ++++++--- .../core/projection/module_profilers/utils.py | 14 +- .../performance_projection/projection.py | 6 +- .../performance_projection/simulator.py | 82 ++++++--- primus/core/projection/profiler_spec.py | 6 +- .../projection/simulation_backends/base.py | 81 ++++++--- .../simulation_backends/origami_backend.py | 14 +- .../simulation_backends/sdpa_simulator.py | 59 +++++-- 17 files changed, 593 insertions(+), 248 deletions(-) diff --git a/primus/core/projection/memory_projection/projection.py b/primus/core/projection/memory_projection/projection.py index c1f988c35..e83b5e957 100644 --- a/primus/core/projection/memory_projection/projection.py +++ b/primus/core/projection/memory_projection/projection.py @@ -17,7 +17,9 @@ ) -def print_profiler_hierarchy(profiler, batch_size, seq_len, rank=None, name="root", depth=0, visited=None): +def print_profiler_hierarchy( + profiler, batch_size, seq_len, rank=None, name="root", depth=0, visited=None +): """ Recursively print the profiler hierarchy with num_params and activation_memory for each component. @@ -46,13 +48,17 @@ def print_profiler_hierarchy(profiler, batch_size, seq_len, rank=None, name="roo if depth == 0: # Only output the total number of parameters for the entire model for depth 0. 
num_params = profiler.estimated_num_params(rank=None) - print(f"{indent} Total Number of Parameters: {num_params / 1e9:.6f} Billion ({num_params:,})") + print( + f"{indent} Total Number of Parameters: {num_params / 1e9:.6f} Billion ({num_params:,})" + ) else: num_params = profiler.estimated_num_params(rank=rank) activation_mem = profiler.estimated_activation_memory(batch_size, seq_len) print(f"{indent}[{name}]") print(f"{indent} Params: {num_params / 1e9:.6f} Billion ({num_params:,})") - print(f"{indent} Activation Memory: {activation_mem / 1024 / 1024 / 1024:.4f} GB") + print( + f"{indent} Activation Memory: {activation_mem / 1024 / 1024 / 1024:.4f} GB" + ) # Recursively process sub_profilers if they exist if hasattr(profiler, "sub_profilers") and profiler.sub_profilers: @@ -79,7 +85,9 @@ def launch_projection_from_cli(args, overrides): """ cfg_path = Path(args.config) if not cfg_path.exists(): - raise FileNotFoundError(f"[Primus:Projection] Config file '{cfg_path}' not found.") + raise FileNotFoundError( + f"[Primus:Projection] Config file '{cfg_path}' not found." 
+ ) config_parser = PrimusParser() primus_config = config_parser.parse(args) @@ -117,7 +125,9 @@ def launch_projection_from_cli(args, overrides): print("=" * 100) print(f"[Primus:Projection] Memory Projection Summary on Rank {rank}:") print(f" Params: {num_params / 1e9:.6f} Billion ({num_params:,})") - print(f" Param+Optimizer Memory: {num_params * num_bytes_per_param / 1024 / 1024 / 1024:.4f} GB") + print( + f" Param+Optimizer Memory: {num_params * num_bytes_per_param / 1024 / 1024 / 1024:.4f} GB" + ) print( f" Activation Memory (per batch size {batch_size}, seq len {seq_len}): " f"{activation_memory / 1024 / 1024 / 1024:.4f} GB" diff --git a/primus/core/projection/module_profilers/attention.py b/primus/core/projection/module_profilers/attention.py index b45bebb4c..5dc6eb88e 100644 --- a/primus/core/projection/module_profilers/attention.py +++ b/primus/core/projection/module_profilers/attention.py @@ -159,9 +159,7 @@ def _num_query_groups() -> int: return tokens_per_rank * (activation_width + ln_width) * bytes_per_value - def _simulate_mla_gemms( - self, batch_tokens: int, dtype: str - ) -> tuple[float, float]: + def _simulate_mla_gemms(self, batch_tokens: int, dtype: str) -> tuple[float, float]: """Simulate MLA (Multi-Latent Attention) projection GEMMs. 
MLA uses LoRA-factored Q and compressed KV projections instead of @@ -286,9 +284,7 @@ def _get_simulated_results( if getattr(args, "multi_latent_attention", False): # MLA: LoRA-factored Q and compressed KV projections # 6 forward GEMMs + 12 backward GEMMs - mla_fwd, mla_bwd = self._simulate_mla_gemms( - batch_tokens, gemm_dtype - ) + mla_fwd, mla_bwd = self._simulate_mla_gemms(batch_tokens, gemm_dtype) fwd_time += mla_fwd bwd_time += mla_bwd else: diff --git a/primus/core/projection/module_profilers/collective_args.py b/primus/core/projection/module_profilers/collective_args.py index b377e153e..cf6246d60 100644 --- a/primus/core/projection/module_profilers/collective_args.py +++ b/primus/core/projection/module_profilers/collective_args.py @@ -52,7 +52,9 @@ class CollectiveArgs: # All-to-all specific a2a_peer_lat: float = 0.45 # Per-peer latency overhead for inter-node a2a - a2a_intra_node_peer_lat: float = 28.0 # Per-peer latency overhead for intra-node a2a + a2a_intra_node_peer_lat: float = ( + 28.0 # Per-peer latency overhead for intra-node a2a + ) # Intra-node overhead is higher (~19-28 us) due to: # - P2P scatter/gather scheduling overhead # - RCCL internal synchronization barriers diff --git a/primus/core/projection/module_profilers/collective_model.py b/primus/core/projection/module_profilers/collective_model.py index 66d11f1fc..f22548758 100644 --- a/primus/core/projection/module_profilers/collective_model.py +++ b/primus/core/projection/module_profilers/collective_model.py @@ -110,7 +110,9 @@ def sendrecv(args, msg_size): return t -def direct_alltoall(args, msg_size, gpus, groups=["ep"], protocol=None, original_msg_size=None): +def direct_alltoall( + args, msg_size, gpus, groups=["ep"], protocol=None, original_msg_size=None +): """ Direct alltoall for HP=1, hierarchical with parallel NIC utilization. 
@@ -139,7 +141,9 @@ def direct_alltoall(args, msg_size, gpus, groups=["ep"], protocol=None, original intra_node_volume = msg_size * intra_fraction inter_node_volume_per_gpu = msg_size * inter_fraction - node_lat, intra_vol_adj = node_latency_and_volume_protocol(args, intra_node_volume, protocol) + node_lat, intra_vol_adj = node_latency_and_volume_protocol( + args, intra_node_volume, protocol + ) pod_lat = args.pod_lat # Intra-node time @@ -154,7 +158,10 @@ def direct_alltoall(args, msg_size, gpus, groups=["ep"], protocol=None, original t_inter = total_inter_volume / aggregate_inter_bw * 1.0e-3 + pod_lat else: remote_nodes = num_nodes - 1 - t_inter = inter_node_volume_per_gpu / (args.bw_eff * args.pod_bw) * 1.0e-3 + pod_lat * remote_nodes + t_inter = ( + inter_node_volume_per_gpu / (args.bw_eff * args.pod_bw) * 1.0e-3 + + pod_lat * remote_nodes + ) # Overlap intra and inter t_a2a = max(t_intra, t_inter) @@ -195,7 +202,9 @@ def run_alltoall(args, msg_size, gpus, groups=["ep"], protocol=None): elif (args.hp * gpus > args.node_size) and (args.hp * gpus) <= args.pod_size: # Alltoall fits within pod if args.hp == 1: - return direct_alltoall(args, msg_size, gpus, groups, protocol, original_msg_size) + return direct_alltoall( + args, msg_size, gpus, groups, protocol, original_msg_size + ) bw = args.bw_eff * args.pod_bw lat = args.pod_lat else: @@ -231,7 +240,11 @@ def cp_allgather(args, msg_size, gpus, protocol=None): bw = args.cluster_bw * args.bw_eff lat = args.cluster_lat # Logarithmic steps for tree allgather - t = msg_size / bw * 1.0e-3 + lat * np.ceil(np.log2(gpus)) + args.kernel_launch_latency + t = ( + msg_size / bw * 1.0e-3 + + lat * np.ceil(np.log2(gpus)) + + args.kernel_launch_latency + ) return t @@ -454,7 +467,9 @@ def single_shot_alltoall(args, msg_size, gpus, groups=None, protocol=None): t_intra_node = 0 t_inter_node = 0 if intra_node_gpus > 0: - node_lat, msg_size_per_peer_adj = node_latency_and_volume_protocol(args, msg_size_per_peer, protocol) + node_lat, 
msg_size_per_peer_adj = node_latency_and_volume_protocol( + args, msg_size_per_peer, protocol + ) node_bw = args.bw_eff * args.node_bw intra_node_rounds = ceil(intra_node_gpus / intra_node_fanout) t_intra_node = intra_node_rounds * node_lat @@ -505,7 +520,9 @@ def hierarchical_alltoall(args, msg_size, gpus, groups=None, protocol=None): inter_node_volume_per_gpu = msg_size * (gpus - gpus_per_node) / gpus # Intra-node time - node_lat, intra_vol_adj = node_latency_and_volume_protocol(args, intra_node_volume, protocol) + node_lat, intra_vol_adj = node_latency_and_volume_protocol( + args, intra_node_volume, protocol + ) node_bw = args.bw_eff * args.node_bw t_intra = node_lat + intra_vol_adj / node_bw * 1.0e-3 @@ -516,7 +533,10 @@ def hierarchical_alltoall(args, msg_size, gpus, groups=None, protocol=None): t_inter = args.pod_lat + total_inter_volume / aggregate_inter_bw * 1.0e-3 else: effective_pod_bw = args.bw_eff * args.pod_bw - t_inter = args.pod_lat * num_nodes + inter_node_volume_per_gpu / effective_pod_bw * 1.0e-3 + t_inter = ( + args.pod_lat * num_nodes + + inter_node_volume_per_gpu / effective_pod_bw * 1.0e-3 + ) t_total = max(t_intra, t_inter) t_total += args.kernel_launch_latency @@ -527,10 +547,10 @@ def hierarchical_alltoall(args, msg_size, gpus, groups=None, protocol=None): def pxn_alltoall(args, msg_size, gpus, groups=None, protocol=None): """ PXN All-to-All - pipelined implementation for DeepEP. - - Based on DeepEP implementation of pipelined PXN-A2A with pipelined + + Based on DeepEP implementation of pipelined PXN-A2A with pipelined scale-up (intra-node) and scale-out (inter-node) communication. 
- + Key features: - Overlaps scale-up and scale-out communication - Scale-out doesn't start until 4MB is accumulated for dispatch @@ -538,60 +558,68 @@ def pxn_alltoall(args, msg_size, gpus, groups=None, protocol=None): """ if gpus == 1 or msg_size == 0: return 0 - + original_msg_size = msg_size - + # Account for TP striding: with TP (hp) > 1, each EP rank occupies # hp GPUs, so only node_size/hp EP ranks fit on a single node. hp = getattr(args, "hp", 1) effective_gpus_per_node = min(gpus, args.node_size // max(hp, 1)) num_nodes = ceil(gpus / max(effective_gpus_per_node, 1)) - + # If A2A is not crossing node boundaries, fall back to regular alltoall if num_nodes <= 1: return min( single_shot_alltoall(args, msg_size, gpus, groups, protocol), - run_alltoall(args, msg_size, gpus, groups, protocol) + run_alltoall(args, msg_size, gpus, groups, protocol), ) - + # PXN - AlltoAll - pipeline implementation chunk_size = 4194304 # 4MB - DeepEP waits until 4MB is accumulated - + # Scale-out message size (inter-node communication) scale_out_msg_size = int(original_msg_size * (num_nodes - 1) / num_nodes) - + # Scale-up delay: time to accumulate 4MB before scale-out starts scaleup_delay = 0.0 if scale_out_msg_size < chunk_size: - # If total scale-out msg size is less than 4MB, + # If total scale-out msg size is less than 4MB, # total time = scaleup_delay + scaleout_time - node_lat, _ = node_latency_and_volume_protocol(args, scale_out_msg_size, protocol) + node_lat, _ = node_latency_and_volume_protocol( + args, scale_out_msg_size, protocol + ) scaleup_delay = node_lat + scale_out_msg_size / args.node_bw * 1.0e-3 else: # Scale-out comm doesn't start until 4MB is accumulated node_lat, _ = node_latency_and_volume_protocol(args, chunk_size, protocol) scaleup_delay = node_lat + chunk_size / args.node_bw * 1.0e-3 - + # Assume PXN style alltoall with overlapped scale-up and scale-out - node_msg_size = int(original_msg_size * (effective_gpus_per_node - 1) / effective_gpus_per_node) + 
node_msg_size = int( + original_msg_size * (effective_gpus_per_node - 1) / effective_gpus_per_node + ) scale_out_msg_size = int(original_msg_size * (num_nodes - 1) / num_nodes) - + # Calculate latencies with protocol inflation - node_lat, node_msg_size_adj = node_latency_and_volume_protocol(args, node_msg_size, protocol) - pod_lat, scale_out_msg_size_adj = pod_latency_and_volume_protocol(args, scale_out_msg_size, protocol) - + node_lat, node_msg_size_adj = node_latency_and_volume_protocol( + args, node_msg_size, protocol + ) + pod_lat, scale_out_msg_size_adj = pod_latency_and_volume_protocol( + args, scale_out_msg_size, protocol + ) + # Scale-up (intra-node) time node_bw = args.bw_eff * args.node_bw t_a2a_node = node_lat + node_msg_size_adj / node_bw * 1.0e-3 - + # Scale-out (inter-node) time with scaleup delay pod_bw = args.bw_eff * args.pod_bw t_a2a_scale_out = pod_lat + scale_out_msg_size_adj / pod_bw * 1.0e-3 + scaleup_delay - + # Total time is max of scale-up and scale-out (they overlap) t_total = max(t_a2a_node, t_a2a_scale_out) t_total += args.kernel_launch_latency - + return t_total @@ -610,7 +638,9 @@ def single_shot_allgather(args, msg_size, gpus, groups=None, protocol=None): t_intra_node = 0 t_inter_node = 0 if intra_node_gpus > 0: - node_lat, msg_size_per_peer_node = node_latency_and_volume_protocol(args, msg_size_per_peer, protocol) + node_lat, msg_size_per_peer_node = node_latency_and_volume_protocol( + args, msg_size_per_peer, protocol + ) node_bw = args.bw_eff * args.node_bw intra_node_rounds = ceil(intra_node_gpus / intra_node_fanout) t_intra_node = intra_node_rounds * node_lat @@ -643,7 +673,9 @@ def single_shot_reduce_scatter(args, msg_size, gpus, groups=["hp"], protocol=Non t_intra_node = 0 t_inter_node = 0 if intra_node_gpus > 0: - node_lat, msg_size_per_peer_node = node_latency_and_volume_protocol(args, msg_size_per_peer, protocol) + node_lat, msg_size_per_peer_node = node_latency_and_volume_protocol( + args, msg_size_per_peer, protocol + ) 
node_bw = args.bw_eff * args.node_bw intra_node_rounds = ceil(intra_node_gpus / intra_node_fanout) t_intra_node = intra_node_rounds * node_lat @@ -673,7 +705,9 @@ def single_shot_allreduce(args, msg_size, gpus, groups=["hp"], protocol=None): return 0 t_rs = single_shot_reduce_scatter(args, msg_size, gpus, groups, protocol) t_ag = single_shot_allgather(args, msg_size, gpus, groups, protocol) - t_ar = t_rs + t_ag - args.kernel_launch_latency # Remove duplicate kernel launch latency + t_ar = ( + t_rs + t_ag - args.kernel_launch_latency + ) # Remove duplicate kernel launch latency return t_ar @@ -695,7 +729,9 @@ def allreduce(args, msg_size, gpus, groups=["dp"]): hypercubeallreduce = oneshotHCallreduce(args, msg_size, gpus, protocol=p) ss_allreduce = single_shot_allreduce(args, msg_size, gpus, protocol=p) ringallreduce = RingAllreduce(args, msg_size, gpus, protocol=p) - min_ar_alg_time = min(ringallreduce, bruck_time, hypercubeallreduce, ss_allreduce) + min_ar_alg_time = min( + ringallreduce, bruck_time, hypercubeallreduce, ss_allreduce + ) if min_ar_alg_time < min_ar_time: min_ar_time = min_ar_alg_time return min_ar_time @@ -706,13 +742,15 @@ def alltoall(args, msg_size, gpus, groups=["ep"]): Select best alltoall algorithm among several options. Tries multiple protocols and algorithms, returns fastest. Applies per-peer latency overhead and minimum latency floor. - + If DeepEP is enabled (moe_enable_deepep=True), uses PXN All-to-All which pipelines scale-up and scale-out communication. 
""" # Check if DeepEP is enabled - use_deepep = getattr(args, "moe_enable_deepep", False) or getattr(args, "use_turbo_deepep", False) - + use_deepep = getattr(args, "moe_enable_deepep", False) or getattr( + args, "use_turbo_deepep", False + ) + min_a2a_time = float("inf") for p in ["simple", "ll", "ll64", "ll128"]: if use_deepep: @@ -722,10 +760,14 @@ def alltoall(args, msg_size, gpus, groups=["ep"]): else: # Use regular All-to-All algorithms direct_a2a_time = run_alltoall(args, msg_size, gpus, protocol=p) - single_shot_a2a_time = single_shot_alltoall(args, msg_size, gpus, protocol=p) - hierarchical_a2a_time = hierarchical_alltoall(args, msg_size, gpus, protocol=p) + single_shot_a2a_time = single_shot_alltoall( + args, msg_size, gpus, protocol=p + ) + hierarchical_a2a_time = hierarchical_alltoall( + args, msg_size, gpus, protocol=p + ) a2a_time = min(direct_a2a_time, single_shot_a2a_time, hierarchical_a2a_time) - + if a2a_time < min_a2a_time: min_a2a_time = a2a_time @@ -740,15 +782,21 @@ def alltoall(args, msg_size, gpus, groups=["ep"]): gpus_per_node = args.node_size intra_node_peers = min(gpus - 1, gpus_per_node - 1) # Peers within same node inter_node_peers = max(0, gpus - gpus_per_node) # Peers on other nodes - + # Intra-node overhead is much higher due to synchronization and scheduling # Based on preflight measurements: EP=8 intra-node A2A needs ~19-28 us per peer # Inter-node overhead is lower (~0.45 us per peer) due to RDMA efficiency - intra_node_overhead_per_peer = getattr(args, "a2a_intra_node_peer_lat", 28.0) # Default 28 us - inter_node_overhead_per_peer = getattr(args, "a2a_peer_lat", 0.45) # Default 0.45 us - - peer_overhead = (intra_node_overhead_per_peer * intra_node_peers + - inter_node_overhead_per_peer * inter_node_peers) + intra_node_overhead_per_peer = getattr( + args, "a2a_intra_node_peer_lat", 28.0 + ) # Default 28 us + inter_node_overhead_per_peer = getattr( + args, "a2a_peer_lat", 0.45 + ) # Default 0.45 us + + peer_overhead = ( + 
intra_node_overhead_per_peer * intra_node_peers + + inter_node_overhead_per_peer * inter_node_peers + ) min_a2a_time += peer_overhead return min_a2a_time diff --git a/primus/core/projection/module_profilers/embedding.py b/primus/core/projection/module_profilers/embedding.py index cfd2c3a90..6ca7ea409 100644 --- a/primus/core/projection/module_profilers/embedding.py +++ b/primus/core/projection/module_profilers/embedding.py @@ -17,7 +17,9 @@ class EmbeddingProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): super().__init__(config, sub_profilers) self.module = None # Will be set during benchmarking - self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) + self._cached_results = ( + None # Cache for (forward_time, backward_time, activation_memory) + ) self._cache_key = None # Cache key (batch_size, seq_len) self._simulation_mode = False # Set to True when simulation backends are active @@ -35,7 +37,10 @@ def set_simulation_mode(self, enabled: bool = True): self._cache_key = None def estimated_num_params(self, rank: Optional[int] = None) -> int: - return self.config.model_config.padded_vocab_size * self.config.model_config.hidden_size + return ( + self.config.model_config.padded_vocab_size + * self.config.model_config.hidden_size + ) def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: return ( @@ -47,7 +52,9 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: * 2 ) # bf16 - def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + def _get_simulated_results( + self, batch_size: int, seq_len: int + ) -> tuple[float, float, int]: """Estimate embedding time analytically (lookup is memory-bound, very fast).""" tp_size = self.config.model_parallel_config.tensor_model_parallel_size cp_size = self.config.model_parallel_config.context_model_parallel_size @@ -63,7 +70,9 @@ def _get_simulated_results(self, batch_size: int, 
seq_len: int) -> tuple[float, activation_memory = self.estimated_activation_memory(batch_size, seq_len) return (fwd_time, bwd_time, activation_memory) - def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + def _get_benchmark_results( + self, batch_size: int, seq_len: int + ) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) diff --git a/primus/core/projection/module_profilers/language_model.py b/primus/core/projection/module_profilers/language_model.py index 815366522..0145f7078 100644 --- a/primus/core/projection/module_profilers/language_model.py +++ b/primus/core/projection/module_profilers/language_model.py @@ -26,7 +26,9 @@ def build_profiler(spec: ModuleProfilerSpec, depth=0) -> BaseModuleProfiler: Recursively build a profiler instance from a ModuleProfilerSpec. """ if not issubclass(spec.profiler, BaseModuleProfiler): - raise TypeError(f"spec.profiler must be subclass of BaseModuleProfiler, got {spec.profiler}") + raise TypeError( + f"spec.profiler must be subclass of BaseModuleProfiler, got {spec.profiler}" + ) if depth == 0: print(f"Begin build profiler: {spec.profiler.__name__}") @@ -47,7 +49,9 @@ def build_profiler(spec: ModuleProfilerSpec, depth=0) -> BaseModuleProfiler: print(f"{'--'*(depth+1)}[{sub_spec.__name__}]({name})") sub_profilers[name] = sub_spec(spec.config, sub_profilers=None) else: - raise TypeError(f"Invalid type for sub_profiler_specs['{name}']: {type(sub_spec)}") + raise TypeError( + f"Invalid type for sub_profiler_specs['{name}']: {type(sub_spec)}" + ) return spec.profiler(config=spec.config, sub_profilers=sub_profilers) @@ -58,7 +62,9 @@ def get_language_model_profiler_spec(config: TrainingConfig) -> ModuleProfilerSp config=config, sub_profiler_specs={ "embedding": EmbeddingProfiler, - "dense_transformer_layer": get_dense_transformer_layer_profiler_spec(config), + "dense_transformer_layer": get_dense_transformer_layer_profiler_spec( 
+ config + ), "moe_transformer_layer": get_moe_transformer_layer_profiler_spec(config), "final_layernorm": LayerNormProfiler, "output_layer": OutputLayerProfiler, @@ -112,7 +118,11 @@ def _get_explicit_layer_distribution( middle_stages = ( total_stages - 2 if (decoder_first is not None and decoder_last is not None) - else (total_stages - 1 if (decoder_first is not None or decoder_last is not None) else total_stages) + else ( + total_stages - 1 + if (decoder_first is not None or decoder_last is not None) + else total_stages + ) ) if middle_stages > 0 and remaining_layers > 0: @@ -174,13 +184,21 @@ def set_simulation_backends(self, gemm_backend=None, sdpa_backend=None): layer_profiler.set_simulation_backends(gemm_backend, sdpa_backend) # Propagate to embedding (uses simple analytical estimate in sim mode). - if "embedding" in self.sub_profilers and self.sub_profilers["embedding"] is not None: + if ( + "embedding" in self.sub_profilers + and self.sub_profilers["embedding"] is not None + ): emb = self.sub_profilers["embedding"] if hasattr(emb, "set_simulation_mode"): - emb.set_simulation_mode(gemm_backend is not None or sdpa_backend is not None) + emb.set_simulation_mode( + gemm_backend is not None or sdpa_backend is not None + ) # Propagate GEMM backend to output layer (vocab projection GEMM). - if "output_layer" in self.sub_profilers and self.sub_profilers["output_layer"] is not None: + if ( + "output_layer" in self.sub_profilers + and self.sub_profilers["output_layer"] is not None + ): out = self.sub_profilers["output_layer"] if gemm_backend is not None and hasattr(out, "set_gemm_backend"): out.set_gemm_backend(gemm_backend) @@ -204,7 +222,11 @@ def get_layers_for_rank( to the first virtual stages (or use decoder_first/last_pipeline_num_layers if set). 
""" total_stages = pp_size - vpp_size = num_virtual_pipeline_stages if num_virtual_pipeline_stages is not None else 1 + vpp_size = ( + num_virtual_pipeline_stages + if num_virtual_pipeline_stages is not None + else 1 + ) total_stages = pp_size * vpp_size model_parallel_size = pp_size * tp_size * cp_size * ep_size @@ -218,9 +240,13 @@ def get_layers_for_rank( if self is not None and hasattr(self, "config") and self.config is not None: mp_config = self.config.model_parallel_config if decoder_first is None: - decoder_first = getattr(mp_config, "decoder_first_pipeline_num_layers", None) + decoder_first = getattr( + mp_config, "decoder_first_pipeline_num_layers", None + ) if decoder_last is None: - decoder_last = getattr(mp_config, "decoder_last_pipeline_num_layers", None) + decoder_last = getattr( + mp_config, "decoder_last_pipeline_num_layers", None + ) # Build layer counts per virtual stage if decoder_first is not None or decoder_last is not None: @@ -359,14 +385,22 @@ def estimated_num_params(self, rank: Optional[int] = None) -> int: for layer in layers: is_moe = self.config.model_config.moe_pattern[layer] if is_moe: - total_params += self.sub_profilers["moe_transformer_layer"].estimated_num_params(rank) + total_params += self.sub_profilers[ + "moe_transformer_layer" + ].estimated_num_params(rank) else: - total_params += self.sub_profilers["dense_transformer_layer"].estimated_num_params(rank) + total_params += self.sub_profilers[ + "dense_transformer_layer" + ].estimated_num_params(rank) if 0 in self.layers: total_params += self.sub_profilers["embedding"].estimated_num_params(rank) if self.config.model_config.num_layers - 1 in self.layers: - total_params += self.sub_profilers["final_layernorm"].estimated_num_params(rank) - total_params += self.sub_profilers["output_layer"].estimated_num_params(rank) + total_params += self.sub_profilers["final_layernorm"].estimated_num_params( + rank + ) + total_params += self.sub_profilers["output_layer"].estimated_num_params( + 
rank + ) total_params += self.sub_profilers["calc_loss"].estimated_num_params(rank) return total_params @@ -378,13 +412,17 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: (hidden_size * batch_size * seq_len * dtype_bytes), not full intermediate activations. """ pp_size = self.config.model_parallel_config.pipeline_model_parallel_size - vpp_size = self.config.model_parallel_config.virtual_pipeline_model_parallel_size + vpp_size = ( + self.config.model_parallel_config.virtual_pipeline_model_parallel_size + ) recompute_granularity = self.config.model_parallel_config.recompute_granularity recompute_num_layers = self.config.model_parallel_config.recompute_num_layers # Calculate number of layers per virtual pipeline stage on this rank layers_per_rank = len(self.layers) - layers_per_vpp_stage = layers_per_rank // vpp_size if vpp_size > 0 else layers_per_rank + layers_per_vpp_stage = ( + layers_per_rank // vpp_size if vpp_size > 0 else layers_per_rank + ) # Input activation size per layer (only thing stored for recomputed layers) # hidden_size * batch_size * seq_len * dtype_bytes (bf16 = 2 bytes) @@ -401,7 +439,9 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: layer_act = 0 for i, layer in enumerate(self.layers): # Determine if this layer is recomputed - local_layer_idx = i % layers_per_vpp_stage if layers_per_vpp_stage > 0 else i + local_layer_idx = ( + i % layers_per_vpp_stage if layers_per_vpp_stage > 0 else i + ) is_recomputed = ( recompute_granularity == "full" and recompute_num_layers is not None @@ -415,25 +455,31 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: # Non-recomputed layer: store full activations is_moe = self.config.model_config.moe_pattern[layer] if is_moe: - layer_act += self.sub_profilers["moe_transformer_layer"].estimated_activation_memory( - batch_size, seq_len - ) + layer_act += self.sub_profilers[ + "moe_transformer_layer" + 
].estimated_activation_memory(batch_size, seq_len) else: - layer_act += self.sub_profilers["dense_transformer_layer"].estimated_activation_memory( - batch_size, seq_len - ) + layer_act += self.sub_profilers[ + "dense_transformer_layer" + ].estimated_activation_memory(batch_size, seq_len) total_act = layer_act # Add embedding/output activations if 0 in self.layers: - total_act += self.sub_profilers["embedding"].estimated_activation_memory(batch_size, seq_len) + total_act += self.sub_profilers["embedding"].estimated_activation_memory( + batch_size, seq_len + ) if self.config.model_config.num_layers - 1 in self.layers: - total_act += self.sub_profilers["final_layernorm"].estimated_activation_memory( + total_act += self.sub_profilers[ + "final_layernorm" + ].estimated_activation_memory(batch_size, seq_len) + total_act += self.sub_profilers["output_layer"].estimated_activation_memory( + batch_size, seq_len + ) + total_act += self.sub_profilers["calc_loss"].estimated_activation_memory( batch_size, seq_len ) - total_act += self.sub_profilers["output_layer"].estimated_activation_memory(batch_size, seq_len) - total_act += self.sub_profilers["calc_loss"].estimated_activation_memory(batch_size, seq_len) # 1F1B pipeline schedule: need to store activations for pp_size microbatches total_act *= pp_size @@ -460,7 +506,9 @@ def run_layer_benchmark(self, model, batch_size: int, seq_len: int) -> dict: The mode is automatically selected based on whether simulation backends have been set via :meth:`set_simulation_backends`. 
""" - is_simulation_mode = self._gemm_backend is not None or self._sdpa_backend is not None + is_simulation_mode = ( + self._gemm_backend is not None or self._sdpa_backend is not None + ) # ----------------------------------------------------------------- # Unwrap model (only when an actual model is provided) @@ -470,9 +518,14 @@ def run_layer_benchmark(self, model, batch_size: int, seq_len: int) -> dict: all_layers = [] if model is not None: + def unwrap_module(module): """Recursively unwrap DistributedDataParallel / pipeline wrappers.""" - return unwrap_module(module.module) if hasattr(module, "module") else module + return ( + unwrap_module(module.module) + if hasattr(module, "module") + else module + ) model_chunks = model if isinstance(model, list) else [model] @@ -486,15 +539,21 @@ def unwrap_module(module): if hasattr(language_model, "output_layer"): output_module = language_model.output_layer - if hasattr(language_model, "encoder") and hasattr(language_model.encoder, "layers"): + if hasattr(language_model, "encoder") and hasattr( + language_model.encoder, "layers" + ): all_layers.extend(language_model.encoder.layers) - elif hasattr(language_model, "decoder") and hasattr(language_model.decoder, "layers"): + elif hasattr(language_model, "decoder") and hasattr( + language_model.decoder, "layers" + ): all_layers.extend(language_model.decoder.layers) elif hasattr(language_model, "layers"): all_layers.extend(language_model.layers) continue - if hasattr(unwrapped, "decoder") and hasattr(unwrapped.decoder, "layers"): + if hasattr(unwrapped, "decoder") and hasattr( + unwrapped.decoder, "layers" + ): all_layers.extend(unwrapped.decoder.layers) elif hasattr(unwrapped, "layers"): all_layers.extend(unwrapped.layers) @@ -516,17 +575,25 @@ def unwrap_module(module): mode_label = "Simulating" if is_simulation_mode else "Benchmarking" if is_rank_0: if model is not None: - print(f"\n[Primus:Performance Projection] Found {len(all_layers)} transformer layers") + print( + 
f"\n[Primus:Performance Projection] Found {len(all_layers)} transformer layers" + ) else: - print(f"\n[Primus:Performance Projection] Pure simulation mode (no model)") - print(f"[Primus:Performance Projection] This rank is responsible for layers: {self.layers}") + print( + f"\n[Primus:Performance Projection] Pure simulation mode (no model)" + ) + print( + f"[Primus:Performance Projection] This rank is responsible for layers: {self.layers}" + ) if is_simulation_mode: backends = [] if self._gemm_backend is not None: backends.append(f"GEMM={self._gemm_backend.name()}") if self._sdpa_backend is not None: backends.append(f"SDPA={self._sdpa_backend.name()}") - print(f"[Primus:Performance Projection] Mode: SIMULATION ({', '.join(backends)})") + print( + f"[Primus:Performance Projection] Mode: SIMULATION ({', '.join(backends)})" + ) embedding_stats = None output_stats = None @@ -535,14 +602,20 @@ def unwrap_module(module): # Benchmark / simulate embedding layer (if this rank hosts it) # ---------------------------------------------------------------------- if 0 in self.layers: - if model is not None and embedding_module is None and not is_simulation_mode: + if ( + model is not None + and embedding_module is None + and not is_simulation_mode + ): if is_rank_0: print( "[Primus:Performance Projection] WARNING: Embedding module not found on this rank." ) else: if is_rank_0: - print(f"[Primus:Performance Projection] {mode_label} embedding layer...") + print( + f"[Primus:Performance Projection] {mode_label} embedding layer..." + ) profiler = self.sub_profilers["embedding"] if embedding_module is not None: module = ( @@ -582,7 +655,9 @@ def unwrap_module(module): ) else: if is_rank_0: - print(f"[Primus:Performance Projection] {mode_label} output layer...") + print( + f"[Primus:Performance Projection] {mode_label} output layer..." 
+ ) profiler = self.sub_profilers["output_layer"] if output_module is not None: profiler.set_module(output_module) @@ -627,7 +702,9 @@ def unwrap_module(module): continue if is_rank_0: - print(f"\n[Primus:Performance Projection] {mode_label} Layer {layer_idx} ({layer_type})...") + print( + f"\n[Primus:Performance Projection] {mode_label} Layer {layer_idx} ({layer_type})..." + ) # Get the appropriate profiler if is_moe: @@ -643,7 +720,9 @@ def unwrap_module(module): # Benchmark/simulate full layer forward_time = layer_profiler.measured_forward_time(batch_size, seq_len) backward_time = layer_profiler.measured_backward_time(batch_size, seq_len) - activation_memory = layer_profiler.measured_activation_memory(batch_size, seq_len) + activation_memory = layer_profiler.measured_activation_memory( + batch_size, seq_len + ) # Benchmark/simulate Attention attn_profiler = layer_profiler.get_sub_profiler("self_attention") @@ -683,7 +762,9 @@ def unwrap_module(module): print(f" Backward time: {backward_time:.2f} ms {src}") print(f" Total: {forward_time + backward_time:.2f} ms {src}") print(f" Activation memory: {activation_memory / (1024**2):.2f} MB") - print(f" Attention: fwd={attn_forward:.2f} ms, bwd={attn_backward:.2f} ms") + print( + f" Attention: fwd={attn_forward:.2f} ms, bwd={attn_backward:.2f} ms" + ) print(f" MLP: fwd={mlp_forward:.2f} ms, bwd={mlp_backward:.2f} ms") # Expand results to all layers diff --git a/primus/core/projection/module_profilers/moe_mlp.py b/primus/core/projection/module_profilers/moe_mlp.py index 0f4dc585a..a7ca7a424 100644 --- a/primus/core/projection/module_profilers/moe_mlp.py +++ b/primus/core/projection/module_profilers/moe_mlp.py @@ -185,16 +185,17 @@ def _get_simulated_results( # execution → model as Origami batched GEMM (batch=num_local_experts). # Legacy grouped_gemm executes experts more sequentially → model as # individual GEMM (batch=1) × num_local_experts. 
- use_turbo = ( - getattr(self.config.model_config, "enable_primus_turbo", False) - and getattr(self.config.model_config, "use_turbo_grouped_mlp", False) - ) + use_turbo = getattr( + self.config.model_config, "enable_primus_turbo", False + ) and getattr(self.config.model_config, "use_turbo_grouped_mlp", False) is_rank_0 = int(os.getenv("RANK", "0")) == 0 if is_rank_0 and num_local_experts > 1: mode = "Turbo (batched)" if use_turbo else "Legacy (sequential)" - print(f" [MoE MLP] Grouped-GEMM model: {mode}" - f" ({num_local_experts} local experts, M={M}, H={H}, F={F})") + print( + f" [MoE MLP] Grouped-GEMM model: {mode}" + f" ({num_local_experts} local experts, M={M}, H={H}, F={F})" + ) expert_fwd_ms = 0.0 expert_bwd_ms = 0.0 @@ -203,62 +204,96 @@ def _get_simulated_results( # ── Turbo model: batched GEMM (all experts in parallel) ── B = num_local_experts if self.config.model_config.swiglu: - gate_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) - up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) - down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) - expert_fwd_ms = (gate_fwd.forward_time_ms - + up_fwd.forward_time_ms - + down_fwd.forward_time_ms) + gate_fwd = self._gemm_backend.simulate_gemm( + M, F, H, gemm_dtype, batch=B + ) + up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) + down_fwd = self._gemm_backend.simulate_gemm( + M, H, F, gemm_dtype, batch=B + ) + expert_fwd_ms = ( + gate_fwd.forward_time_ms + + up_fwd.forward_time_ms + + down_fwd.forward_time_ms + ) gate_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) gate_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=B) - up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) - up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=B) + up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) + up_wg = self._gemm_backend.simulate_gemm(H, F, M, 
gemm_dtype, batch=B) down_dg = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) down_wg = self._gemm_backend.simulate_gemm(F, H, M, gemm_dtype, batch=B) - expert_bwd_ms = (gate_dg.forward_time_ms + gate_wg.forward_time_ms - + up_dg.forward_time_ms + up_wg.forward_time_ms - + down_dg.forward_time_ms + down_wg.forward_time_ms) + expert_bwd_ms = ( + gate_dg.forward_time_ms + + gate_wg.forward_time_ms + + up_dg.forward_time_ms + + up_wg.forward_time_ms + + down_dg.forward_time_ms + + down_wg.forward_time_ms + ) else: - up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) - down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) + up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) + down_fwd = self._gemm_backend.simulate_gemm( + M, H, F, gemm_dtype, batch=B + ) expert_fwd_ms = up_fwd.forward_time_ms + down_fwd.forward_time_ms - up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) - up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=B) + up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) + up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=B) down_dg = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) down_wg = self._gemm_backend.simulate_gemm(F, H, M, gemm_dtype, batch=B) - expert_bwd_ms = (up_dg.forward_time_ms + up_wg.forward_time_ms - + down_dg.forward_time_ms + down_wg.forward_time_ms) + expert_bwd_ms = ( + up_dg.forward_time_ms + + up_wg.forward_time_ms + + down_dg.forward_time_ms + + down_wg.forward_time_ms + ) expert_fwd = expert_fwd_ms expert_bwd = expert_bwd_ms else: # ── Legacy model: individual GEMM × num_local_experts ── if self.config.model_config.swiglu: - gate_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) - up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) - down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) - expert_fwd_ms = 
(gate_fwd.forward_time_ms - + up_fwd.forward_time_ms - + down_fwd.forward_time_ms) + gate_fwd = self._gemm_backend.simulate_gemm( + M, F, H, gemm_dtype, batch=1 + ) + up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) + down_fwd = self._gemm_backend.simulate_gemm( + M, H, F, gemm_dtype, batch=1 + ) + expert_fwd_ms = ( + gate_fwd.forward_time_ms + + up_fwd.forward_time_ms + + down_fwd.forward_time_ms + ) gate_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) gate_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=1) - up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) - up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=1) + up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) + up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=1) down_dg = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) down_wg = self._gemm_backend.simulate_gemm(F, H, M, gemm_dtype, batch=1) - expert_bwd_ms = (gate_dg.forward_time_ms + gate_wg.forward_time_ms - + up_dg.forward_time_ms + up_wg.forward_time_ms - + down_dg.forward_time_ms + down_wg.forward_time_ms) + expert_bwd_ms = ( + gate_dg.forward_time_ms + + gate_wg.forward_time_ms + + up_dg.forward_time_ms + + up_wg.forward_time_ms + + down_dg.forward_time_ms + + down_wg.forward_time_ms + ) else: - up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) - down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) + up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) + down_fwd = self._gemm_backend.simulate_gemm( + M, H, F, gemm_dtype, batch=1 + ) expert_fwd_ms = up_fwd.forward_time_ms + down_fwd.forward_time_ms - up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) - up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=1) + up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) + up_wg = 
self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=1) down_dg = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) down_wg = self._gemm_backend.simulate_gemm(F, H, M, gemm_dtype, batch=1) - expert_bwd_ms = (up_dg.forward_time_ms + up_wg.forward_time_ms - + down_dg.forward_time_ms + down_wg.forward_time_ms) + expert_bwd_ms = ( + up_dg.forward_time_ms + + up_wg.forward_time_ms + + down_dg.forward_time_ms + + down_wg.forward_time_ms + ) expert_fwd = expert_fwd_ms * num_local_experts expert_bwd = expert_bwd_ms * num_local_experts @@ -306,7 +341,9 @@ def _get_simulated_results( # ── 4. Activation function overhead (SwiGLU / GELU) ── if self.config.model_config.swiglu: - act_bytes = 3 * topk_tokens * moe_ffn * bytes_per_el # gate+up read, result write + act_bytes = ( + 3 * topk_tokens * moe_ffn * bytes_per_el + ) # gate+up read, result write else: act_bytes = 2 * topk_tokens * moe_ffn * bytes_per_el # read + write activation_ms = act_bytes / (activation_bw_gbps * 1e6) diff --git a/primus/core/projection/module_profilers/output_layer.py b/primus/core/projection/module_profilers/output_layer.py index c2e27555b..8c979a16d 100644 --- a/primus/core/projection/module_profilers/output_layer.py +++ b/primus/core/projection/module_profilers/output_layer.py @@ -15,7 +15,9 @@ class OutputLayerProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): super().__init__(config, sub_profilers) self.module = None # Will be set during benchmarking - self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) + self._cached_results = ( + None # Cache for (forward_time, backward_time, activation_memory) + ) self._cache_key = None # Cache key (batch_size, seq_len) self._gemm_backend = None # Optional: GEMM simulation backend @@ -33,7 +35,10 @@ def set_gemm_backend(self, backend): self._cache_key = None def estimated_num_params(self, rank: Optional[int] = None) -> int: - return self.config.model_config.padded_vocab_size * 
self.config.model_config.hidden_size + return ( + self.config.model_config.padded_vocab_size + * self.config.model_config.hidden_size + ) def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: return ( @@ -45,7 +50,9 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: * 2 ) # bf16 - def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + def _get_simulated_results( + self, batch_size: int, seq_len: int + ) -> tuple[float, float, int]: """Simulate output layer using GEMM backend (vocab projection GEMM).""" tp_size = self.config.model_parallel_config.tensor_model_parallel_size cp_size = self.config.model_parallel_config.context_model_parallel_size @@ -82,7 +89,9 @@ def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, activation_memory = self.estimated_activation_memory(batch_size, seq_len) return (fwd_time, bwd_time, activation_memory) - def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + def _get_benchmark_results( + self, batch_size: int, seq_len: int + ) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) diff --git a/primus/core/projection/module_profilers/router.py b/primus/core/projection/module_profilers/router.py index efdd8338b..994d556ef 100644 --- a/primus/core/projection/module_profilers/router.py +++ b/primus/core/projection/module_profilers/router.py @@ -11,7 +11,9 @@ class RouterProfiler(BaseModuleProfiler): def estimated_num_params(self, rank: Optional[int] = None) -> int: - return self.config.model_config.hidden_size * self.config.model_config.num_experts + return ( + self.config.model_config.hidden_size * self.config.model_config.num_experts + ) def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: return ( diff --git a/primus/core/projection/module_profilers/transformer_layer.py 
b/primus/core/projection/module_profilers/transformer_layer.py index 329634819..a3a115149 100644 --- a/primus/core/projection/module_profilers/transformer_layer.py +++ b/primus/core/projection/module_profilers/transformer_layer.py @@ -21,6 +21,7 @@ from .router import RouterProfiler from .utils import benchmark_layer + def _estimate_tp_allreduce_time_ms(config, batch_size: int, seq_len: int) -> float: """ Estimate TP AllReduce time for a single AllReduce operation (in ms). @@ -53,7 +54,10 @@ def _estimate_tp_allreduce_time_ms(config, batch_size: int, seq_len: int) -> flo coll_args = get_default_args( num_nodes=num_nodes, gpus_per_node=gpus_per_node, - tp=tp, pp=pp, ep=ep, cp=cp, + tp=tp, + pp=pp, + ep=ep, + cp=cp, ) # TP AllReduce is across tp ranks (typically intra-node) @@ -97,11 +101,16 @@ def _estimate_moe_a2a_time_ms(config, batch_size: int, seq_len: int) -> float: coll_args = get_default_args( num_nodes=num_nodes, gpus_per_node=gpus_per_node, - tp=tp, pp=pp, ep=ep, cp=cp, + tp=tp, + pp=pp, + ep=ep, + cp=cp, ) # Propagate DeepEP setting if present (affects A2A algorithm selection) - moe_enable_deepep = getattr(config.model_parallel_config, "moe_enable_deepep", False) + moe_enable_deepep = getattr( + config.model_parallel_config, "moe_enable_deepep", False + ) use_turbo_deepep = getattr(config.model_parallel_config, "use_turbo_deepep", False) coll_args.moe_enable_deepep = moe_enable_deepep coll_args.use_turbo_deepep = use_turbo_deepep @@ -162,7 +171,9 @@ class DenseTransformerLayerProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): super().__init__(config, sub_profilers) self.layer_module = None # Will be set during benchmarking - self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) + self._cached_results = ( + None # Cache for (forward_time, backward_time, activation_memory) + ) self._cache_key = None # Cache key (batch_size, seq_len) self._gemm_backend = None # Optional: GEMM simulation backend 
self._sdpa_backend = None # Optional: SDPA simulation backend @@ -209,16 +220,30 @@ def estimated_num_params(self, rank: Optional[int] = None) -> int: def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: return ( - self.sub_profilers["layer_norm"].estimated_activation_memory(batch_size, seq_len) * 3 - + self.sub_profilers["self_attention"].estimated_activation_memory(batch_size, seq_len) + self.sub_profilers["layer_norm"].estimated_activation_memory( + batch_size, seq_len + ) + * 3 + + self.sub_profilers["self_attention"].estimated_activation_memory( + batch_size, seq_len + ) + self.sub_profilers["mlp"].estimated_activation_memory(batch_size, seq_len) - + self.sub_profilers["residual_add"].estimated_activation_memory(batch_size, seq_len) * 2 + + self.sub_profilers["residual_add"].estimated_activation_memory( + batch_size, seq_len + ) + * 2 ) - def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + def _get_simulated_results( + self, batch_size: int, seq_len: int + ) -> tuple[float, float, int]: """Aggregate simulated results from sub-profilers, including TP AllReduce.""" - attn_fwd = self.sub_profilers["self_attention"].measured_forward_time(batch_size, seq_len) - attn_bwd = self.sub_profilers["self_attention"].measured_backward_time(batch_size, seq_len) + attn_fwd = self.sub_profilers["self_attention"].measured_forward_time( + batch_size, seq_len + ) + attn_bwd = self.sub_profilers["self_attention"].measured_backward_time( + batch_size, seq_len + ) mlp_fwd = self.sub_profilers["mlp"].measured_forward_time(batch_size, seq_len) mlp_bwd = self.sub_profilers["mlp"].measured_backward_time(batch_size, seq_len) @@ -234,7 +259,9 @@ def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, activation_memory = self.estimated_activation_memory(batch_size, seq_len) return (fwd_time, bwd_time, activation_memory) - def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, 
float, int]: + def _get_benchmark_results( + self, batch_size: int, seq_len: int + ) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: @@ -269,7 +296,9 @@ class MoETransformerLayerProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): super().__init__(config, sub_profilers) self.layer_module = None # Will be set during benchmarking - self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) + self._cached_results = ( + None # Cache for (forward_time, backward_time, activation_memory) + ) self._cache_key = None # Cache key (batch_size, seq_len) self._gemm_backend = None # Optional: GEMM simulation backend self._sdpa_backend = None # Optional: SDPA simulation backend @@ -317,22 +346,38 @@ def estimated_num_params(self, rank: Optional[int] = None) -> int: def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: return ( - self.sub_profilers["layer_norm"].estimated_activation_memory(batch_size, seq_len) * 3 - + self.sub_profilers["self_attention"].estimated_activation_memory(batch_size, seq_len) + self.sub_profilers["layer_norm"].estimated_activation_memory( + batch_size, seq_len + ) + * 3 + + self.sub_profilers["self_attention"].estimated_activation_memory( + batch_size, seq_len + ) + self.sub_profilers["mlp"].estimated_activation_memory(batch_size, seq_len) - + self.sub_profilers["router"].estimated_activation_memory(batch_size, seq_len) - + self.sub_profilers["residual_add"].estimated_activation_memory(batch_size, seq_len) * 2 + + self.sub_profilers["router"].estimated_activation_memory( + batch_size, seq_len + ) + + self.sub_profilers["residual_add"].estimated_activation_memory( + batch_size, seq_len + ) + * 2 ) - def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + def _get_simulated_results( + self, batch_size: int, seq_len: 
int + ) -> tuple[float, float, int]: """Aggregate simulated results from sub-profilers. Includes TP AllReduce and MoE All-to-All communication overhead that would be captured in the measured layer time during benchmark mode but must be added explicitly in simulation mode. """ - attn_fwd = self.sub_profilers["self_attention"].measured_forward_time(batch_size, seq_len) - attn_bwd = self.sub_profilers["self_attention"].measured_backward_time(batch_size, seq_len) + attn_fwd = self.sub_profilers["self_attention"].measured_forward_time( + batch_size, seq_len + ) + attn_bwd = self.sub_profilers["self_attention"].measured_backward_time( + batch_size, seq_len + ) mlp_fwd = self.sub_profilers["mlp"].measured_forward_time(batch_size, seq_len) mlp_bwd = self.sub_profilers["mlp"].measured_backward_time(batch_size, seq_len) @@ -355,7 +400,9 @@ def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, activation_memory = self.estimated_activation_memory(batch_size, seq_len) return (fwd_time, bwd_time, activation_memory) - def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: + def _get_benchmark_results( + self, batch_size: int, seq_len: int + ) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: diff --git a/primus/core/projection/module_profilers/utils.py b/primus/core/projection/module_profilers/utils.py index c2cad601e..b64db4528 100644 --- a/primus/core/projection/module_profilers/utils.py +++ b/primus/core/projection/module_profilers/utils.py @@ -16,7 +16,9 @@ class _FP8ContextFactory: def __init__(self, transformer_config): self.transformer_config = transformer_config - self.fp8_enabled = getattr(transformer_config, "fp8", None) if transformer_config else None + self.fp8_enabled = ( + getattr(transformer_config, "fp8", None) if transformer_config else None + ) self._printed = False def 
__enter__(self): @@ -28,7 +30,9 @@ def __enter__(self): self._ctx = get_fp8_context(self.transformer_config, layer_no=-1) if not self._printed: - print(f" [FP8] Using FP8 autocast context for benchmarking (fp8={self.fp8_enabled})") + print( + f" [FP8] Using FP8 autocast context for benchmarking (fp8={self.fp8_enabled})" + ) self._printed = True except Exception as e: try: @@ -88,7 +92,11 @@ def benchmark_layer( device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def create_input(spec): - if isinstance(spec, tuple) and len(spec) == 2 and isinstance(spec[1], torch.dtype): + if ( + isinstance(spec, tuple) + and len(spec) == 2 + and isinstance(spec[1], torch.dtype) + ): shape, dtype = spec else: shape = spec diff --git a/primus/core/projection/performance_projection/projection.py b/primus/core/projection/performance_projection/projection.py index f16220ee3..a76c85f91 100644 --- a/primus/core/projection/performance_projection/projection.py +++ b/primus/core/projection/performance_projection/projection.py @@ -2170,9 +2170,9 @@ def launch_projection_from_cli(args, overrides): # Also propagate num_experts adjustment so that the profiler sees # the correct experts_per_rank (e.g. 128/4=32, not 256/4=64). 
if reduction_info.get("benchmark_num_experts") is not None: - primus_config.get_module_config("pre_trainer").num_experts = ( - reduction_info["benchmark_num_experts"] - ) + primus_config.get_module_config("pre_trainer").num_experts = reduction_info[ + "benchmark_num_experts" + ] # Determine profiling mode profiling_mode = getattr(args, "profiling_mode", "benchmark") diff --git a/primus/core/projection/performance_projection/simulator.py b/primus/core/projection/performance_projection/simulator.py index 00e3d6f6d..f39bfb88e 100644 --- a/primus/core/projection/performance_projection/simulator.py +++ b/primus/core/projection/performance_projection/simulator.py @@ -32,11 +32,15 @@ def __init__(self, config: dict): self.debug_simulator = int(os.getenv("DEBUG_SIMULATOR", "0") == "1") - def _summarize_simulation_result(self, simulation_result: list[dict], scheduler_config: dict) -> dict: + def _summarize_simulation_result( + self, simulation_result: list[dict], scheduler_config: dict + ) -> dict: rank_totals = [rank.get("total", 0.0) for rank in simulation_result] step_time_ms = max(rank_totals) if rank_totals else 0.0 critical_rank = rank_totals.index(step_time_ms) if rank_totals else None - max_memory = max((rank.get("memory", 0.0) for rank in simulation_result), default=0.0) + max_memory = max( + (rank.get("memory", 0.0) for rank in simulation_result), default=0.0 + ) return { "step_time_ms": step_time_ms, "rank_totals": rank_totals, @@ -85,7 +89,9 @@ def _chunk_duration( else: raise ValueError("Duration is not found.") - def _chunk_activation(self, rank: int, chunk: int | None, vpp_size: int | None) -> float: + def _chunk_activation( + self, rank: int, chunk: int | None, vpp_size: int | None + ) -> float: if self.chunk_time_ms is None: if vpp_size is None: vpp_size = 1 @@ -163,7 +169,9 @@ def run(self): module = importlib.import_module(module_path) scheduler_class = getattr(module, class_name) - scheduler_params = {k: v for k, v in scheduler_config.items() if k not in 
["name", "class"]} + scheduler_params = { + k: v for k, v in scheduler_config.items() if k not in ["name", "class"] + } scheduler_instance = scheduler_class(**scheduler_params) schedule_table = scheduler_instance.generate_schedule_table() @@ -174,9 +182,13 @@ def run(self): print(f"{'='*20 * scheduler_config['pp_size']}") if self.debug_simulator: scheduler_instance.print_schedule_table(schedule_table) - simulation_result = self.simulate_scheduler_table(schedule_table, scheduler_config) + simulation_result = self.simulate_scheduler_table( + schedule_table, scheduler_config + ) self.dump_simulation_result(simulation_result, scheduler_config) - summary = self._summarize_simulation_result(simulation_result, scheduler_config) + summary = self._summarize_simulation_result( + simulation_result, scheduler_config + ) run_summaries.append( { "name": scheduler_config["name"], @@ -191,7 +203,9 @@ def run(self): return run_summaries - def simulate_scheduler_table(self, schedule_table: list[list[SchedulerNode]], scheduler_config: dict): + def simulate_scheduler_table( + self, schedule_table: list[list[SchedulerNode]], scheduler_config: dict + ): current_rank = 0 rank_clock = [0.0 for _ in range(len(schedule_table))] @@ -251,7 +265,9 @@ def simulate_scheduler_table(self, schedule_table: list[list[SchedulerNode]], sc print(f"rank {current_rank} send_key: {send_key}") communication_map[send_key] = rank_clock[current_rank] if node.func_type in [FuncType.RF, FuncType.RB]: - send_func_type = FuncType.SF if node.func_type == FuncType.RF else FuncType.SB + send_func_type = ( + FuncType.SF if node.func_type == FuncType.RF else FuncType.SB + ) send_key = f"{node.args['from_pp_rank']}_{node.args['to_pp_rank']}_{send_func_type}_{node.mini_batch}_{node.chunk}" if send_key not in communication_map: merge_comm_index = rank_idx[current_rank] + 1 @@ -259,7 +275,9 @@ def simulate_scheduler_table(self, schedule_table: list[list[SchedulerNode]], sc print(f"rank {current_rank} wait send_key 
{send_key}") # merge the send op behind - for i in range(merge_comm_index, len(schedule_table[current_rank])): + for i in range( + merge_comm_index, len(schedule_table[current_rank]) + ): if schedule_table[current_rank][i].func_type in [ FuncType.SF, FuncType.SB, @@ -276,7 +294,9 @@ def simulate_scheduler_table(self, schedule_table: list[list[SchedulerNode]], sc send_time = communication_map.pop(send_key) - recv_time_map[f"{node.mini_batch}_{node.chunk}_{node.func_type}"] = send_time + recv_time_map[ + f"{node.mini_batch}_{node.chunk}_{node.func_type}" + ] = send_time if node.func_type in [ FuncType.F, @@ -286,17 +306,21 @@ def simulate_scheduler_table(self, schedule_table: list[list[SchedulerNode]], sc FuncType.FB, ]: if node.func_type in [FuncType.F, FuncType.B, FuncType.BW]: - recv_node_type = FuncType.RF if node.func_type == FuncType.F else FuncType.RB + recv_node_type = ( + FuncType.RF if node.func_type == FuncType.F else FuncType.RB + ) recv_key = f"{node.mini_batch}_{node.chunk}_{recv_node_type}" if recv_key in recv_time_map: rank_clock[current_rank] = max( rank_clock[current_rank], - recv_time_map[f"{node.mini_batch}_{node.chunk}_{recv_node_type}"], + recv_time_map[ + f"{node.mini_batch}_{node.chunk}_{recv_node_type}" + ], ) - simulation_result[current_rank][f"{self._result_key_dict[node.func_type]}_start"].append( - rank_clock[current_rank] - ) + simulation_result[current_rank][ + f"{self._result_key_dict[node.func_type]}_start" + ].append(rank_clock[current_rank]) duration = self._chunk_duration( current_rank, getattr(node, "chunk", 0), @@ -304,15 +328,15 @@ def simulate_scheduler_table(self, schedule_table: list[list[SchedulerNode]], sc scheduler_config, ) rank_clock[current_rank] += duration - simulation_result[current_rank][f"{self._result_key_dict[node.func_type]}_end"].append( - rank_clock[current_rank] - ) + simulation_result[current_rank][ + f"{self._result_key_dict[node.func_type]}_end" + ].append(rank_clock[current_rank]) 
simulation_result[current_rank][ f"{self._result_key_dict[node.func_type]}_minibatch" ].append(node.mini_batch) - simulation_result[current_rank][f"{self._result_key_dict[node.func_type]}_chunk"].append( - node.chunk - ) + simulation_result[current_rank][ + f"{self._result_key_dict[node.func_type]}_chunk" + ].append(node.chunk) if node.func_type == FuncType.F: act_gb = self._chunk_activation( current_rank, @@ -324,9 +348,9 @@ def simulate_scheduler_table(self, schedule_table: list[list[SchedulerNode]], sc simulation_result[current_rank]["memory"], rank_memory[current_rank], ) - simulation_result[current_rank]["activation_memory_usage"].append( - rank_memory[current_rank] - ) + simulation_result[current_rank][ + "activation_memory_usage" + ].append(rank_memory[current_rank]) elif node.func_type in [FuncType.BW, FuncType.W]: act_gb = self._chunk_activation( current_rank, @@ -334,9 +358,9 @@ def simulate_scheduler_table(self, schedule_table: list[list[SchedulerNode]], sc scheduler_config["vpp_size"], ) rank_memory[current_rank] = rank_memory[current_rank] - act_gb - simulation_result[current_rank]["activation_memory_usage"].append( - rank_memory[current_rank] - ) + simulation_result[current_rank][ + "activation_memory_usage" + ].append(rank_memory[current_rank]) rank_idx[current_rank] += 1 current_rank = (current_rank + 1) % len(schedule_table) @@ -346,7 +370,9 @@ def simulate_scheduler_table(self, schedule_table: list[list[SchedulerNode]], sc return simulation_result - def dump_simulation_result(self, simulation_result: list[dict], scheduler_config: dict): + def dump_simulation_result( + self, simulation_result: list[dict], scheduler_config: dict + ): result_dir = f"{self.config['output_dir']}/{scheduler_config['name']}" os.makedirs(result_dir, exist_ok=True) with open(f"{result_dir}/config.json", "w") as f: diff --git a/primus/core/projection/profiler_spec.py b/primus/core/projection/profiler_spec.py index 02c8d9e25..e053f02fa 100644 --- 
a/primus/core/projection/profiler_spec.py +++ b/primus/core/projection/profiler_spec.py @@ -15,6 +15,6 @@ class ModuleProfilerSpec: profiler: Type[BaseModuleProfiler] config: Type[TrainingConfig] - sub_profiler_specs: Optional[Dict[str, Union[Type[BaseModuleProfiler], "ModuleProfilerSpec", None]]] = ( - field(default_factory=lambda: {}) - ) + sub_profiler_specs: Optional[ + Dict[str, Union[Type[BaseModuleProfiler], "ModuleProfilerSpec", None]] + ] = field(default_factory=lambda: {}) diff --git a/primus/core/projection/simulation_backends/base.py b/primus/core/projection/simulation_backends/base.py index b9daea699..ae68b7b80 100644 --- a/primus/core/projection/simulation_backends/base.py +++ b/primus/core/projection/simulation_backends/base.py @@ -141,53 +141,92 @@ def simulate_mlp_gemms( if swiglu: # Gate projection fwd: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] - gate_fwd = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) + gate_fwd = self.simulate_gemm( + batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b + ) # Up projection fwd: same shape as gate - up_fwd = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) + up_fwd = self.simulate_gemm( + batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b + ) # Down projection fwd: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] - down_fwd = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) + down_fwd = self.simulate_gemm( + batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b + ) - fwd_time = gate_fwd.forward_time_ms + up_fwd.forward_time_ms + down_fwd.forward_time_ms + fwd_time = ( + gate_fwd.forward_time_ms + + up_fwd.forward_time_ms + + down_fwd.forward_time_ms + ) # Backward: simulate actual dgrad + wgrad GEMMs per projection # Gate dgrad: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] - gate_dgrad = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) + gate_dgrad = 
self.simulate_gemm( + batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b + ) # Gate wgrad: [hidden, tokens] x [tokens, ffn] -> [hidden, ffn] - gate_wgrad = self.simulate_gemm(hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b) + gate_wgrad = self.simulate_gemm( + hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b + ) # Up dgrad + wgrad: same shapes as gate - up_dgrad = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) - up_wgrad = self.simulate_gemm(hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b) + up_dgrad = self.simulate_gemm( + batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b + ) + up_wgrad = self.simulate_gemm( + hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b + ) # Down dgrad: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] - down_dgrad = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) + down_dgrad = self.simulate_gemm( + batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b + ) # Down wgrad: [ffn, tokens] x [tokens, hidden] -> [ffn, hidden] - down_wgrad = self.simulate_gemm(ffn_hidden_size, hidden_size, batch_tokens, dtype, batch=b) + down_wgrad = self.simulate_gemm( + ffn_hidden_size, hidden_size, batch_tokens, dtype, batch=b + ) bwd_time = ( - gate_dgrad.forward_time_ms + gate_wgrad.forward_time_ms - + up_dgrad.forward_time_ms + up_wgrad.forward_time_ms - + down_dgrad.forward_time_ms + down_wgrad.forward_time_ms + gate_dgrad.forward_time_ms + + gate_wgrad.forward_time_ms + + up_dgrad.forward_time_ms + + up_wgrad.forward_time_ms + + down_dgrad.forward_time_ms + + down_wgrad.forward_time_ms ) else: # Up projection fwd: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] - up_fwd = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) + up_fwd = self.simulate_gemm( + batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b + ) # Down projection fwd: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] - down_fwd = 
self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) + down_fwd = self.simulate_gemm( + batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b + ) fwd_time = up_fwd.forward_time_ms + down_fwd.forward_time_ms # Backward: simulate actual dgrad + wgrad GEMMs per projection # Up dgrad: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] - up_dgrad = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) + up_dgrad = self.simulate_gemm( + batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b + ) # Up wgrad: [hidden, tokens] x [tokens, ffn] -> [hidden, ffn] - up_wgrad = self.simulate_gemm(hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b) + up_wgrad = self.simulate_gemm( + hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b + ) # Down dgrad: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] - down_dgrad = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) + down_dgrad = self.simulate_gemm( + batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b + ) # Down wgrad: [ffn, tokens] x [tokens, hidden] -> [ffn, hidden] - down_wgrad = self.simulate_gemm(ffn_hidden_size, hidden_size, batch_tokens, dtype, batch=b) + down_wgrad = self.simulate_gemm( + ffn_hidden_size, hidden_size, batch_tokens, dtype, batch=b + ) bwd_time = ( - up_dgrad.forward_time_ms + up_wgrad.forward_time_ms - + down_dgrad.forward_time_ms + down_wgrad.forward_time_ms + up_dgrad.forward_time_ms + + up_wgrad.forward_time_ms + + down_dgrad.forward_time_ms + + down_wgrad.forward_time_ms ) return SimulationResult(forward_time_ms=fwd_time, backward_time_ms=bwd_time) diff --git a/primus/core/projection/simulation_backends/origami_backend.py b/primus/core/projection/simulation_backends/origami_backend.py index 94e29d3e1..72febb90b 100644 --- a/primus/core/projection/simulation_backends/origami_backend.py +++ b/primus/core/projection/simulation_backends/origami_backend.py @@ -407,7 +407,9 @@ def _get_hardware(self): 
clock_khz = profile.compute_clock_khz if self._clock_override_mhz is not None: clock_khz = self._clock_override_mhz * 1000 - n_cu = self._n_cu_override if self._n_cu_override is not None else profile.n_cu + n_cu = ( + self._n_cu_override if self._n_cu_override is not None else profile.n_cu + ) arch_enum = getattr(_origami.architecture_t, profile.arch_enum_name) hw = _origami.get_hardware_for_arch( arch_enum, @@ -420,7 +422,11 @@ def _get_hardware(self): override_tag = "" if self._clock_override_mhz is not None: override_tag = " (overridden via --gpu-clock-mhz)" - cu_tag = f" (n_cu_override={n_cu})" if self._n_cu_override is not None else "" + cu_tag = ( + f" (n_cu_override={n_cu})" + if self._n_cu_override is not None + else "" + ) print( f"[Primus:Origami] Using hardware profile for " f"'{self._gpu_arch}': N_CU={n_cu}, " @@ -475,7 +481,9 @@ def _get_hardware(self): override_tag = "" if self._clock_override_mhz is not None: override_tag = " (overridden via --gpu-clock-mhz)" - cu_tag = f" (n_cu_override={n_cu})" if self._n_cu_override is not None else "" + cu_tag = ( + f" (n_cu_override={n_cu})" if self._n_cu_override is not None else "" + ) print( f"[Primus:Origami] Using known hardware profile for " f"'{self._gpu_arch}': N_CU={n_cu}, " diff --git a/primus/core/projection/simulation_backends/sdpa_simulator.py b/primus/core/projection/simulation_backends/sdpa_simulator.py index 8b957321d..ad1e2a907 100644 --- a/primus/core/projection/simulation_backends/sdpa_simulator.py +++ b/primus/core/projection/simulation_backends/sdpa_simulator.py @@ -306,7 +306,16 @@ def simulate_sdpa( bpe = self._bytes_per_element(dtype) return self._simulate_tile_level( - B, H_Q, S_Q, S_K, H_K, D_qk, D_v, causal, dtype, bpe, + B, + H_Q, + S_Q, + S_K, + H_K, + D_qk, + D_v, + causal, + dtype, + bpe, ) # ------------------------------------------------------------------ @@ -403,16 +412,20 @@ def _simulate_tile_level( # Per-workgroup GEMMs on 1 CU (tile sweeps all S_K positions): # QKᵀ: 
[q_tile_m, D_qk, S_K] r_fwd_qk = self._tile_gemm.simulate_gemm( - m=_FAV3_FWD.q_tile_m, n=S_K, k=D_qk, dtype=dtype, + m=_FAV3_FWD.q_tile_m, + n=S_K, + k=D_qk, + dtype=dtype, ) # PV: [q_tile_m, S_K, D_v] r_fwd_pv = self._tile_gemm.simulate_gemm( - m=_FAV3_FWD.q_tile_m, n=D_v, k=S_K, dtype=dtype, + m=_FAV3_FWD.q_tile_m, + n=D_v, + k=S_K, + dtype=dtype, ) - fwd_time_ms = ( - r_fwd_qk.forward_time_ms + r_fwd_pv.forward_time_ms - ) * fwd_waves + fwd_time_ms = (r_fwd_qk.forward_time_ms + r_fwd_pv.forward_time_ms) * fwd_waves # ============================================================== # BACKWARD @@ -424,23 +437,38 @@ def _simulate_tile_level( # Per-workgroup GEMMs (5 operations, full Q-sweep on 1 CU): # 1. QKᵀ recompute: [S_Q, D_qk, kv_tile] r_bwd_qk = self._tile_gemm.simulate_gemm( - m=S_Q, n=kv_tile, k=D_qk, dtype=dtype, + m=S_Q, + n=kv_tile, + k=D_qk, + dtype=dtype, ) # 2. dP = dO × Vᵀ: [S_Q, D_v, kv_tile] r_bwd_dp = self._tile_gemm.simulate_gemm( - m=S_Q, n=kv_tile, k=D_v, dtype=dtype, + m=S_Q, + n=kv_tile, + k=D_v, + dtype=dtype, ) # 3. dV = Pᵀ × dO: [kv_tile, S_Q, D_v] r_bwd_dv = self._tile_gemm.simulate_gemm( - m=kv_tile, n=D_v, k=S_Q, dtype=dtype, + m=kv_tile, + n=D_v, + k=S_Q, + dtype=dtype, ) # 4. dQ = dS × K: [S_Q, kv_tile, D_qk] r_bwd_dq = self._tile_gemm.simulate_gemm( - m=S_Q, n=D_qk, k=kv_tile, dtype=dtype, + m=S_Q, + n=D_qk, + k=kv_tile, + dtype=dtype, ) # 5. dK = dSᵀ × Q: [kv_tile, S_Q, D_qk] r_bwd_dk = self._tile_gemm.simulate_gemm( - m=kv_tile, n=D_qk, k=S_Q, dtype=dtype, + m=kv_tile, + n=D_qk, + k=S_Q, + dtype=dtype, ) bwd_compute_ms = ( @@ -456,14 +484,10 @@ def _simulate_tile_level( # The latency model counts warp-level reduction updates (global and # local) and multiplies by the per-op latency. 
num_k_tiles = math.ceil(kv_tile / kv_tile) # = 1 - warp_updates_global = math.ceil( - num_k_tiles * math.ceil(D_qk / _WARP_SIZE) - ) + warp_updates_global = math.ceil(num_k_tiles * math.ceil(D_qk / _WARP_SIZE)) total_updates_global = warp_updates_global * bwd_waves - warp_updates_local = math.ceil( - kv_tile * math.ceil(D_qk / _WARP_SIZE) - ) + warp_updates_local = math.ceil(kv_tile * math.ceil(D_qk / _WARP_SIZE)) total_updates_local = warp_updates_local * bwd_waves bwd_atomic_ms = ( @@ -566,4 +590,3 @@ def _simulate_tile_level( def _bytes_per_element(self, dtype: str) -> int: return {"bf16": 2, "fp16": 2, "fp32": 4, "fp8": 1}.get(dtype, 2) - From e21d40f04f7607a01d3fec213b664a7f48a4caa0 Mon Sep 17 00:00:00 2001 From: Primus Date: Wed, 25 Feb 2026 06:45:35 +0000 Subject: [PATCH 09/12] Add inference (prefill/decode) modes to performance projection tool Extend the projection tool to support inference workloads in addition to training: - Add --mode {training,prefill,decode} CLI argument - Prefill mode: forward-only benchmark, reports prefill latency and tokens/s with inference-specific communication modeling (no backward, no optimizer, no gradient AllReduce) - Decode mode: analytical model based on HBM bandwidth roofline, with optional benchmark-enhanced layer timing. 
Models KV cache read time, autoregressive token generation, and per-token latency - Update docs/projection.md with usage examples, output samples, and detailed explanations of both inference modes --- docs/projection.md | 288 ++++- primus/cli/subcommands/projection.py | 65 +- .../performance_projection/projection.py | 1133 ++++++++++++++++- 3 files changed, 1455 insertions(+), 31 deletions(-) diff --git a/docs/projection.md b/docs/projection.md index 5c5212026..5a2d93f58 100644 --- a/docs/projection.md +++ b/docs/projection.md @@ -1,6 +1,6 @@ # Performance Projection -Primus includes a performance projection tool that benchmarks transformer layers on a single node and projects training iteration times to multi-node configurations. +Primus includes a performance projection tool that benchmarks transformer layers on a single node and projects performance to multi-node configurations for **training**, **prefill**, and **decode** workloads. - **User-facing entry**: `primus-cli … -- projection performance [options]` - **Implementation entrypoint**: `primus/cli/subcommands/projection.py` @@ -14,15 +14,32 @@ The performance projection tool: 2. **Simulates** pipeline parallelism scheduling (including zero-bubble optimization) 3. **Projects** performance to multi-node configurations by modeling: - Data Parallelism (DP) scaling - - Gradient AllReduce communication overhead + - Gradient AllReduce communication overhead (training only) - Expert Parallelism (EP) All-to-All communication overhead - Inter-node vs intra-node communication differences -This allows you to estimate training performance on larger clusters without actually running on them. +This allows you to estimate training or inference performance on larger clusters without actually running on them. 
+ +### Training vs Prefill vs Decode + +| Aspect | Training (`--mode training`) | Prefill (`--mode prefill`) | Decode (`--mode decode`) | +|--------|------------------------------|----------------------------|--------------------------| +| Compute | Forward + Backward + Wgrad | Forward only | Analytical (tiny GEMMs) | +| Bottleneck | Compute-bound | Compute-bound | **Memory-bandwidth-bound** (typically) | +| Optimizer | Adam step (HBM-bound) | None | None | +| Gradient sync | AllReduce across DP ranks | None | None | +| MoE A2A | Forward + Backward | Forward only | Analytical latency estimate | +| Pipeline sim | 1F1B / Zero-bubble schedule | Skipped | Skipped | +| FSDP | AllGather + ReduceScatter | None | None | +| DP scaling | Reduces microbatches per GPU | Adds independent replicas | Adds independent replicas | +| GPU required | Yes (benchmark) | Yes (benchmark) | No (simulate) / Yes (benchmark) | +| Metrics | Iteration time, tokens/s/GPU | Prefill latency, tokens/s/replica | Time/token, tokens/s, generation time | ## Quick Start -Run a basic performance projection for the minimum required nodes: +### Training Projection (default) + +Run a basic training performance projection for the minimum required nodes: ```bash export NNODES=1 @@ -33,7 +50,7 @@ bash runner/primus-cli direct --script primus/cli/main.py -- \ --config examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml ``` -Project performance to a specific number of nodes: +Project training performance to a specific number of nodes: ```bash export NNODES=1 @@ -45,6 +62,74 @@ bash runner/primus-cli direct --script primus/cli/main.py -- \ --target-nodes 4 ``` +### Prefill Projection + +Project inference prefill performance using the same config (`--mode inference` is an alias for `--mode prefill`): + +```bash +export NNODES=1 +export HSA_NO_SCRATCH_RECLAIM=1 + +bash runner/primus-cli direct --script primus/cli/main.py -- \ + projection performance \ + --config 
examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml \ + --mode prefill +``` + +Project prefill performance to multiple nodes (scales DP replicas): + +```bash +bash runner/primus-cli direct --script primus/cli/main.py -- \ + projection performance \ + --config examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml \ + --mode prefill --target-nodes 4 +``` + +### Decode Projection + +Project autoregressive decode (token generation) latency using the **analytical model** (no GPU required): + +```bash +bash runner/primus-cli direct --script primus/cli/main.py -- \ + projection performance \ + --config examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml \ + --mode decode --profiling-mode simulate +``` + +For **benchmark-enhanced** decode (real GPU measurement + analytical KV cache overlay): + +```bash +export NNODES=1 +export HSA_NO_SCRATCH_RECLAIM=1 + +bash runner/primus-cli direct --script primus/cli/main.py -- \ + projection performance \ + --config examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml \ + --mode decode +``` + +Customize decode parameters: + +```bash +bash runner/primus-cli direct --script primus/cli/main.py -- \ + projection performance \ + --config examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml \ + --mode decode \ + --decode-batch-size 32 \ + --decode-context-length 4096 \ + --num-generated-tokens 256 \ + --target-nodes 4 +``` + +Compare analytical and benchmark-enhanced side by side: + +```bash +bash runner/primus-cli direct --script primus/cli/main.py -- \ + projection performance \ + --config examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml \ + --mode decode --profiling-mode both +``` + ## Command Syntax ```bash @@ -58,6 +143,10 @@ primus-cli [global-options] [mode-args] -- projection performance [option | `--config` | string | Path to the Primus YAML configuration file (required) | | `--target-nodes` | int | Target number of nodes for 
projection. Defaults to minimum required by parallelism config | | `--hardware-config` | string | Path to YAML file with custom hardware parameters for communication modeling | +| `--mode` | string | `training` (default), `prefill`, `decode`, or `inference` (alias for prefill). Training projects full iteration time; prefill projects forward-only latency; decode projects per-token generation latency (supports both analytical and benchmark-enhanced profiling) | +| `--decode-batch-size` | int | Number of concurrent sequences being decoded (decode mode only). Defaults to micro_batch_size | +| `--decode-context-length` | int | Number of tokens already in the KV cache (decode mode only). Defaults to sequence_length | +| `--num-generated-tokens` | int | Number of tokens to generate (decode mode only). Used for total generation time estimate. Default: 128 | ### Parallelism Overrides @@ -163,6 +252,22 @@ Gradient AllReduce Size = num_params × 4 (FP32 gradients) **What it does**: Splits sequence length across GPUs for long-context training. +**How it interacts with other dimensions**: CP behaves differently for dense and MoE models: + +- **Dense models**: CP is an independent parallelism axis. It directly multiplies the GPU count: + ``` + Total GPUs = TP × PP × CP × DP + ``` + Each CP rank holds a different chunk of the sequence and they communicate via AllGather/AllToAll during attention. + +- **MoE models (with Parallel Folding)**: CP is **folded into** EP — the CP ranks are a subset of the EP ranks. The constraints are `CP ≤ EP` and `EP % CP == 0`. CP does **not** add extra GPUs beyond EP: + ``` + Total GPUs = TP × PP × EP × DP + ``` + Within the EP group, CP determines how many of the EP ranks share context-parallel work on attention. The remaining `EP / CP` factor provides inner data-parallel streams for the attention layers. For the MoE FFN layers, all EP ranks participate in expert parallelism as usual. 
+ + For example, with EP=8 and CP=4: the 8 EP ranks form 2 groups of 4 for context-parallel attention, giving 2 inner-DP streams. Traditional (unfolded) parallelism would require `EP × CP = 32` GPUs; with folding it requires only `EP = 8`. + **How it's modeled**: CP affects the GPU topology for communication routing. Currently included in minimum GPU requirements calculation. ## Communication Modeling @@ -183,16 +288,25 @@ Communication times differ significantly based on: #### Minimum Nodes Required -The minimum nodes required is determined by: +The minimum nodes required depends on the model type: + +**Dense models** (no expert parallelism): ``` -Min Nodes = ceil(TP × PP × EP × CP / GPUs_per_node) +Min GPUs = TP × PP × CP +Min Nodes = ceil(Min GPUs / GPUs_per_node) +``` + +**MoE models** (with MoE Parallel Folding, where CP is folded into EP): +``` +Min GPUs = TP × PP × EP (CP ≤ EP, folded in) +Min Nodes = ceil(Min GPUs / GPUs_per_node) ``` #### Scaling Behavior - **DP scaling**: Linear speedup. Doubling DP halves iteration time (minus communication overhead). - **PP scaling**: Happens in multiples of pipeline replicas. With PP=3, you need 3, 6, 9... nodes to increase scaling. -- **EP scaling**: Divides the experts on EP nodes. +- **EP scaling**: Divides the experts across EP ranks. For MoE models, EP also subsumes the CP dimension. ## Example Output @@ -226,12 +340,170 @@ Multinode Scaling Projection Results ==================================================================================================== ``` +## Prefill Mode Details + +When `--mode prefill` (or `--mode inference`) is specified, the projection tool: + +1. **Benchmarks layers the same way** as training (same `--profiling-mode` options work) +2. **Uses only forward pass times** — backward/wgrad times are discarded +3. **Skips pipeline simulation** — inference doesn't use 1F1B scheduling +4. **Skips optimizer and gradient communication** — no backward pass means no gradients +5. 
**Reports prefill latency** — time for one forward pass through the full model +6. **Scales DP as replicas** — more nodes = more independent serving replicas, each with the same latency + +### Prefill Output Example + +``` +==================================================================================================== +Inference Projection Results +==================================================================================================== +📊 Parallelism: TP=1, PP=1, EP=8, CP=1 + +🎯 Target Configuration (4 nodes): + Nodes: 4, GPUs: 32 + TP=1, PP=1, EP=8, CP=1, DP(replicas)=4 + Prefill Latency: 125.456 ms (seq_len=4096, micro_batch=1) + Tokens/s per replica: 32,642 + Tokens/s total (4 replicas): 130,568 + Tokens/s/GPU: 4,080 +==================================================================================================== +``` + +### Key Prefill Metrics + +| Metric | Description | +|--------|-------------| +| Prefill Latency | Time for one forward pass (one batch of sequences) | +| Tokens/s per replica | Throughput of a single model instance | +| Tokens/s total | Aggregate throughput across all DP replicas | +| Tokens/s/GPU | Per-GPU throughput efficiency | + +## Decode Mode Details + +Decode mode supports two profiling approaches: + +| Approach | `--profiling-mode` | GPU Required | What it measures | +|----------|-------------------|--------------|-----------------| +| **Analytical** | `simulate` | No | Roofline model: HBM BW × (weights + KV cache) | +| **Benchmark-enhanced** | `benchmark` (default) | Yes | Real GPU GEMM timing (seq_len=1) + analytical KV cache overlay | +| **Both** | `both` | Yes | Runs both side-by-side for comparison | + +### How Decode Differs + +Autoregressive decode generates tokens **one at a time**. 
Each decode step: +- Loads all model weights from HBM (weights dominate memory access) +- Reads the KV cache (grows with context length) +- Performs tiny GEMMs of shape `[batch, 1, …]` — far below the compute roofline + +This makes decode fundamentally **memory-bandwidth-bound** for typical batch sizes. + +### Analytical Model (`--profiling-mode simulate`) + +Pure roofline model — no GPU needed. The decode time per token is estimated as: + +``` +time_per_token ≈ max(memory_time, compute_time) + comm_overhead +``` + +Where: +- **memory_time** = (total_weight_bytes + kv_cache_bytes) / HBM_bandwidth +- **compute_time** = total_FLOPs / peak_TFLOPS (usually much smaller) +- **comm_overhead** = TP AllReduce latency + PP P2P latency + MoE A2A latency + +### Benchmark-Enhanced Model (`--profiling-mode benchmark`) + +When a GPU is available, the tool can benchmark the actual transformer layers with `seq_len=1` to capture real kernel timings: + +1. **Benchmarks** each layer with `seq_len=1` and `batch_size=decode_batch_size` +2. **Extracts forward-only time** — captures real GEMM costs, TP AllReduce, kernel launch overhead +3. **Overlays analytical KV cache model** — the benchmark's 1-token self-attention doesn't include KV cache reads, so we add them analytically + +``` +decode_time = benchmarked_layer_forward(seq_len=1) + analytical_kv_cache_read + comm_overhead +``` + +This gives the best of both worlds: real hardware measurements for GEMMs plus a physics-based model for KV cache. 
+ +### Decode Output Example (Analytical) + +``` +==================================================================================================== +Decode Projection Results (Analytical Model) +==================================================================================================== +📊 Parallelism: TP=1, PP=1, EP=8, CP=1 + Decode batch size: 32 + Context length (KV cache): 4096 + GPU arch: mi300x + HBM bandwidth: 5300 GB/s + Peak BF16 compute: 1307 TFLOPS + +⏱️ Per-Token Decode Time Breakdown: + Weight loading: 2.8453 ms (14450.0 MB) + KV cache read: 0.3124 ms (1587.5 MB) + Compute: 0.0012 ms (1.20 GFLOPS) + Bottleneck: MEMORY-BOUND (arith intensity=0.07 FLOPs/B, balance=247 FLOPs/B) + ───────────────────────────────────── + Total per token: 3.1577 ms + +🎯 Target Configuration (4 nodes): + ... +==================================================================================================== +``` + +### Decode Output Example (Benchmark-Enhanced) + +``` +==================================================================================================== +Decode Projection Results (Benchmark-Enhanced) +==================================================================================================== +📊 Parallelism: TP=1, PP=1, EP=8, CP=1 + Decode batch size: 32 + Context length (KV cache): 4096 + +⏱️ Per-Token Decode Time Breakdown: + Layer fwd (benchmarked, seq_len=1): 3.0215 ms (includes GEMMs + TP AllReduce) + KV cache read (analytical): 0.3124 ms (1587.5 MB) + ─── Analytical comparison ─── + Weight loading (analytical): 2.8453 ms (14450.0 MB) + Total (analytical): 3.1577 ms + ──────────────────────────── + ───────────────────────────────────── + Total per token: 3.3339 ms + +🎯 Target Configuration (4 nodes): + ... 
+==================================================================================================== +``` + +### Key Decode Metrics + +| Metric | Description | +|--------|-------------| +| Per-token latency | Time to generate one output token (all sequences in the batch) | +| Tokens/s per replica | Decode throughput of a single model instance | +| Tokens/s total | Aggregate decode throughput across all DP replicas | +| Tokens/s/GPU | Per-GPU decode throughput efficiency | +| Generation time | Estimated wall-clock time to generate N tokens | +| Bottleneck | Whether the decode step is memory-bound or compute-bound (analytical only) | +| KV cache memory | Memory consumed by the KV cache (grows with batch × context length) | + +### Decode-Specific CLI Options + +| Option | Default | Description | +|--------|---------|-------------| +| `--decode-batch-size` | micro_batch_size | Number of sequences decoded concurrently. Larger batches amortize weight loading but increase KV cache memory | +| `--decode-context-length` | sequence_length | Tokens in the KV cache. Longer contexts increase KV cache read time per step | +| `--num-generated-tokens` | 128 | How many tokens to generate. Used to estimate total generation wall-clock time | +| `--profiling-mode` | benchmark | `simulate` for pure analytical (no GPU), `benchmark` for real GPU measurement + KV analytical overlay, `both` for side-by-side comparison | + ## Tips - **Start with 1 node**: Always benchmark on 1 node first to establish baseline performance. - **Understand scaling limits**: DP scaling is limited by global_batch_size / micro_batch_size. If you run out of microbatches, adding more nodes won't help. - **Check minimum nodes**: If your config requires multiple nodes (e.g., PP=4 with 8 GPUs/node), projection will automatically reduce PP for benchmarking. - **Pipeline scaling**: With PP > 1, you can only scale in multiples of the pipeline replica size. 
+- **Prefill mode**: Use `--mode prefill` (or `--mode inference`) to project prefill latency. The same config file works for both training and prefill projection. +- **Decode mode**: Use `--mode decode` to estimate per-token generation latency. With `--profiling-mode simulate` it runs instantly without a GPU; with `--profiling-mode benchmark` (default) it runs real GEMMs with seq_len=1 and overlays analytical KV cache timing. Use `--profiling-mode both` to compare the two approaches. Tune `--decode-batch-size` and `--decode-context-length` to explore the latency/memory trade-off. ## Related Documentation diff --git a/primus/cli/subcommands/projection.py b/primus/cli/subcommands/projection.py index 8ef62985a..bb0ecfcfb 100644 --- a/primus/cli/subcommands/projection.py +++ b/primus/cli/subcommands/projection.py @@ -14,10 +14,21 @@ def run(args, overrides): launch_projection_from_cli(args, overrides) elif args.suite == "performance": + # Normalise mode: "inference" is an alias for "prefill" + mode = getattr(args, "mode", "training") + if mode == "inference": + args.mode = "prefill" + profiling_mode = getattr(args, "profiling_mode", "benchmark") - if profiling_mode != "simulate": - # Benchmark or "both" modes need the Megatron backend + # Decode + simulate is fully analytical — no backend needed. + # Decode + benchmark/both needs the Megatron backend (runs real layers + # with seq_len=1 to measure decode-step GEMMs on the GPU). 
+ needs_backend = profiling_mode != "simulate" + if mode == "decode" and profiling_mode == "simulate": + needs_backend = False + + if needs_backend: from primus.pretrain import setup_backend_path setup_backend_path(framework="megatron", verbose=True) @@ -146,6 +157,56 @@ def register_subcommand(subparsers): "Example: --gpu-clock-mhz 1500\n" ), ) + performance.add_argument( + "--mode", + type=str, + required=False, + default="training", + choices=["training", "inference", "prefill", "decode"], + help=( + "Projection mode:\n" + " training - Project training iteration time (forward + backward +\n" + " optimizer step + gradient AllReduce). Default.\n" + " inference - Alias for 'prefill'.\n" + " prefill - Project inference prefill latency (forward-only, no\n" + " backward pass, optimizer, or gradient communication).\n" + " decode - Project autoregressive decode latency per token.\n" + " With --profiling-mode simulate: fully analytical (no GPU).\n" + " With --profiling-mode benchmark: benchmarks GEMMs with\n" + " seq_len=1 on GPU, overlays analytical KV cache model.\n" + ), + ) + performance.add_argument( + "--decode-batch-size", + type=int, + required=False, + default=None, + help=( + "Number of sequences being decoded concurrently (decode mode only).\n" + "Defaults to micro_batch_size from the config.\n" + ), + ) + performance.add_argument( + "--decode-context-length", + type=int, + required=False, + default=None, + help=( + "Current context length during decode, i.e. number of previous tokens\n" + "in the KV cache (decode mode only). Affects KV cache read time.\n" + "Defaults to sequence_length from the config.\n" + ), + ) + performance.add_argument( + "--num-generated-tokens", + type=int, + required=False, + default=None, + help=( + "Number of tokens to generate (decode mode only). Used to estimate\n" + "total generation time. 
Defaults to 128.\n" + ), + ) parser.set_defaults(func=run) diff --git a/primus/core/projection/performance_projection/projection.py b/primus/core/projection/performance_projection/projection.py index a76c85f91..52d29b169 100644 --- a/primus/core/projection/performance_projection/projection.py +++ b/primus/core/projection/performance_projection/projection.py @@ -38,6 +38,33 @@ _MAX_EXPERT_PARALLEL_SIZE = 8 _BYTES_PER_GB = 1024**3 +# Projection mode constants +MODE_TRAINING = "training" +MODE_INFERENCE = "inference" +MODE_PREFILL = "prefill" +MODE_DECODE = "decode" + + +def _calculate_min_gpus(tp, pp, ep, cp): + """Calculate minimum GPUs required by parallelism config. + + For MoE models (EP > 1), CP is folded into EP via MoE Parallel Folding: + the CP ranks are a subset of the EP ranks, so the minimum GPU count is + TP × PP × EP. Constraints: CP ≤ EP and EP % CP == 0. + + For dense models (EP ≤ 1), CP is an independent parallelism axis, so the + minimum GPU count is TP × PP × CP. + + Note: DP is *not* affected by this folding — EP borrows from the DP + dimension, so DP = world_size / (TP × PP × CP) in both cases. 
+ """ + if ep > 1: + # MoE: CP is folded into EP (MoE Parallel Folding) + return tp * pp * ep + else: + # Dense: CP is an independent axis + return tp * pp * cp + # HBM bandwidth (GB/s) by GPU architecture — used for optimizer step estimation _HBM_BANDWIDTH_GBPS: Dict[str, float] = { "mi300x": 5300.0, @@ -47,6 +74,235 @@ "gfx950": 8000.0, } +# Peak BF16 TFLOPS per GPU — used for compute roofline in decode estimation +_PEAK_TFLOPS_BF16: Dict[str, float] = { + "mi300x": 1307.0, + "gfx942": 1307.0, + "mi325x": 1307.0, + "mi355x": 2611.0, + "gfx950": 2611.0, +} + + +def _get_hw_params(gpu_arch: Optional[str] = None): + """Return (hbm_bandwidth_gb_s, peak_tflops_bf16) for the given GPU arch.""" + arch = (gpu_arch or os.getenv("PRIMUS_GPU_ARCH", "mi300x")).lower().strip() + hbm_bw = _HBM_BANDWIDTH_GBPS.get(arch, 5300.0) + peak_tf = _PEAK_TFLOPS_BF16.get(arch, 1307.0) + return hbm_bw, peak_tf + + +# ========================================================================= +# Analytical Decode Model +# ========================================================================= + + +def _estimate_decode_time_per_token( + training_config, + decode_batch_size: int, + context_length: int, + tp: int = 1, + pp: int = 1, + ep: int = 1, + gpu_arch: Optional[str] = None, + hardware_config: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """ + Analytically estimate the time for **one decode step** (generating one token + per sequence in the batch). + + Decode is memory-bandwidth-bound for small batch sizes: the GPU must load + all model weights once per step, and read the KV cache. The computation + (tiny GEMMs of shape ``[batch, 1, …]``) is far below the compute roofline. + + Model: + time_per_token ≈ (weight_bytes + kv_cache_bytes) / HBM_bandwidth + + TP_allreduce_latency + + Args: + training_config: TrainingConfig with model / parallel config. + decode_batch_size: Number of sequences decoded concurrently. + context_length: Number of tokens already in the KV cache. 
+ tp, pp, ep: Parallelism dimensions. + gpu_arch: GPU architecture string (for HBM BW / peak TFLOPS). + hardware_config: Optional custom HW config dict. + + Returns: + Dict with detailed breakdown (all times in milliseconds). + """ + mc = training_config.model_config + hidden = mc.hidden_size + num_layers = mc.num_layers + num_heads = mc.num_attention_heads + head_dim = mc.kv_channels or (hidden // num_heads) + gqa = mc.group_query_attention + num_kv_heads = mc.num_query_groups if gqa else num_heads + ffn_hidden = mc.ffn_hidden_size or (hidden * 4) + vocab_size = mc.padded_vocab_size or 100352 + moe_pattern = mc.moe_pattern or [0] * num_layers + num_experts = mc.num_experts or 0 + moe_ffn = mc.moe_ffn_hidden_size or ffn_hidden + moe_topk = mc.moe_router_topk or 1 + shared_expert_size = mc.moe_shared_expert_intermediate_size or 0 + + bytes_per_param = 2 # BF16 + + hbm_bw_gbps, peak_tflops = _get_hw_params(gpu_arch) + hbm_bw_bytes_per_ms = hbm_bw_gbps * 1e9 / 1e3 # bytes / ms + peak_flops_per_ms = peak_tflops * 1e12 / 1e3 # FLOPs / ms + + # Per-layer weight bytes (sharded by TP) + # Attention: Q(h→h) + K(h→kv_dim) + V(h→kv_dim) + O(h→h) + kv_dim = num_kv_heads * head_dim + attn_weight_bytes = ( + hidden * hidden # Q + + hidden * kv_dim # K + + hidden * kv_dim # V + + hidden * hidden # O + ) * bytes_per_param // tp + + # Dense MLP: gate(h→ffn) + up(h→ffn) + down(ffn→h) (SwiGLU has gate+up) + dense_mlp_weight_bytes = 3 * hidden * ffn_hidden * bytes_per_param // tp + + # MoE MLP weights per GPU (num_experts/EP experts, each gate+up+down, divided by expert_TP) + expert_tp = 1 # expert TP typically 1 + if num_experts > 0: + experts_per_gpu = max(num_experts // max(ep, 1), 1) + moe_mlp_weight_bytes = ( + 3 * hidden * moe_ffn * experts_per_gpu * bytes_per_param // expert_tp + ) + # Router weight: (hidden → num_experts) + router_weight_bytes = hidden * num_experts * bytes_per_param + # Shared expert (if any) + shared_expert_weight_bytes = ( + 3 * hidden * shared_expert_size 
* bytes_per_param // tp + if shared_expert_size > 0 + else 0 + ) + else: + moe_mlp_weight_bytes = 0 + router_weight_bytes = 0 + shared_expert_weight_bytes = 0 + + # KV cache bytes per layer per decode step + # Read K: batch × num_kv_heads/TP × context_len × head_dim × bytes + # Read V: same + kv_heads_per_gpu = max(num_kv_heads // tp, 1) + kv_cache_per_layer_bytes = ( + 2 * decode_batch_size * kv_heads_per_gpu * context_length * head_dim + * bytes_per_param + ) + + # Compute FLOPs per layer per decode step (batch × 1 token) + # Linear projections: 2 * batch * M * N (M=1, N=weight cols) + attn_proj_flops = ( + 2 * decode_batch_size * ( + hidden * hidden # Q + + hidden * kv_dim # K + + hidden * kv_dim # V + + hidden * hidden # O + ) + ) // tp + dense_mlp_flops = 2 * decode_batch_size * 3 * hidden * ffn_hidden // tp + + # Attention score + V multiply: batch × num_heads/TP × context_len × head_dim × 2 + heads_per_gpu = max(num_heads // tp, 1) + attn_score_flops = 2 * decode_batch_size * heads_per_gpu * context_length * head_dim # Q·K^T + attn_v_flops = 2 * decode_batch_size * heads_per_gpu * context_length * head_dim # score·V + attn_kv_flops = attn_score_flops + attn_v_flops + + # MoE compute: each token routed to topk experts + if num_experts > 0: + moe_mlp_flops = ( + 2 * decode_batch_size * moe_topk * 3 * hidden * moe_ffn // max(ep, 1) + ) + else: + moe_mlp_flops = 0 + + # ---- Aggregate across all layers ---- + num_dense_layers = sum(1 for p in moe_pattern if p == 0) + num_moe_layers = sum(1 for p in moe_pattern if p == 1) + layers_per_pp = num_layers // max(pp, 1) + + total_weight_bytes = ( + attn_weight_bytes * num_layers # attention in every layer + + dense_mlp_weight_bytes * num_dense_layers + + (moe_mlp_weight_bytes + router_weight_bytes + shared_expert_weight_bytes) + * num_moe_layers + ) // max(pp, 1) + + total_kv_bytes = kv_cache_per_layer_bytes * layers_per_pp + + # Embedding + output layer weights + embedding_weight_bytes = vocab_size * hidden * 
bytes_per_param // tp # only on first PP stage + output_weight_bytes = vocab_size * hidden * bytes_per_param // tp # only on last PP stage + # Amortise across PP stages + total_weight_bytes += (embedding_weight_bytes + output_weight_bytes) // max(pp, 1) + + total_compute_flops = ( + (attn_proj_flops + attn_kv_flops) * layers_per_pp + + dense_mlp_flops * (num_dense_layers // max(pp, 1)) + + moe_mlp_flops * (num_moe_layers // max(pp, 1)) + ) + + # ---- Roofline: memory time vs compute time ---- + memory_time_ms = (total_weight_bytes + total_kv_bytes) / hbm_bw_bytes_per_ms + compute_time_ms = total_compute_flops / peak_flops_per_ms + + # TP AllReduce latency (2× per layer: after attention + after MLP) + # For decode, messages are tiny (batch × hidden × 2 bytes) — latency-dominated + tp_latency_per_ar_us = 5.0 # ~5 µs typical intra-node AllReduce latency + if hardware_config: + tp_latency_per_ar_us = hardware_config.get("intra_node_latency_us", 1.0) * 3 + tp_allreduce_count = 2 * layers_per_pp # attn + MLP per layer + tp_overhead_ms = (tp_allreduce_count * tp_latency_per_ar_us / 1000) if tp > 1 else 0 + + # PP overhead (serial forward through PP stages, activation P2P) + pp_p2p_latency_us = 5.0 + if hardware_config: + pp_p2p_latency_us = hardware_config.get("intra_node_latency_us", 1.0) * 3 + pp_overhead_ms = ((pp - 1) * pp_p2p_latency_us / 1000) if pp > 1 else 0 + + # EP All-to-All for MoE decode (tiny messages, latency-dominated) + ep_a2a_latency_us = 10.0 + if hardware_config: + ep_a2a_latency_us = hardware_config.get("intra_node_latency_us", 1.0) * 5 + moe_a2a_per_layer_ms = (2 * ep_a2a_latency_us / 1000) if ep > 1 else 0 # dispatch + combine + moe_a2a_total_ms = moe_a2a_per_layer_ms * (num_moe_layers // max(pp, 1)) + + # Total decode time per token + decode_time_ms = max(memory_time_ms, compute_time_ms) + tp_overhead_ms + pp_overhead_ms + moe_a2a_total_ms + + # Arithmetic intensity (FLOPs / byte) — shows how memory-bound we are + total_bytes_accessed = 
total_weight_bytes + total_kv_bytes + arith_intensity = total_compute_flops / total_bytes_accessed if total_bytes_accessed > 0 else 0 + # Machine balance point: peak_flops / hbm_bw + balance_point = (peak_tflops * 1e12) / (hbm_bw_gbps * 1e9) + + return { + "decode_time_ms": decode_time_ms, + "memory_time_ms": memory_time_ms, + "compute_time_ms": compute_time_ms, + "tp_overhead_ms": tp_overhead_ms, + "pp_overhead_ms": pp_overhead_ms, + "moe_a2a_total_ms": moe_a2a_total_ms, + "total_weight_bytes": total_weight_bytes, + "total_kv_bytes": total_kv_bytes, + "total_weight_mb": total_weight_bytes / (1024 * 1024), + "total_kv_mb": total_kv_bytes / (1024 * 1024), + "total_compute_tflops": total_compute_flops / 1e12, + "arithmetic_intensity": arith_intensity, + "balance_point": balance_point, + "is_memory_bound": arith_intensity < balance_point, + "hbm_bw_gbps": hbm_bw_gbps, + "peak_tflops": peak_tflops, + "decode_batch_size": decode_batch_size, + "context_length": context_length, + "layers_per_pp": layers_per_pp, + "num_dense_layers": num_dense_layers, + "num_moe_layers": num_moe_layers, + } + def _estimate_optimizer_step_ms( training_config, @@ -650,6 +906,121 @@ def extract_single_node_time_from_profiling( return total_time_ms +def extract_single_node_time_inference( + profiling_results: dict, training_config +) -> float: + """ + Extract total single-node **forward-only** time from profiling results. + + This is the inference counterpart of :func:`extract_single_node_time_from_profiling`. + It uses only the forward pass timings (no backward, no recomputation overhead). 
+ + Args: + profiling_results: Dict with integer keys for layers and "embedding", "output" + training_config: Training configuration containing model config + + Returns: + Total forward-only time in milliseconds for the full model (one pass) + """ + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + + if is_rank_0: + print( + "[Primus:Inference Projection] Extracting forward-only timing from benchmark results..." + ) + print("-" * 100) + + model_config = training_config.model_config + moe_pattern = model_config.moe_pattern + + num_total_layers = len(moe_pattern) + + profiled_layer_indices = sorted( + [k for k in profiling_results.keys() if isinstance(k, int)] + ) + if is_rank_0: + print(f" Profiled layers: {profiled_layer_indices}") + print(f" Full model has {num_total_layers} transformer layers") + + total_time_ms = 0.0 + + # Embedding layer (forward only) + if "embedding" in profiling_results: + emb = profiling_results["embedding"] + emb_time = emb.get("forward_time_ms", 0) + total_time_ms += emb_time + if is_rank_0: + print(f" Embedding (fwd): {emb_time:.2f} ms") + + # Analyze profiled transformer layers — forward only + profiled_dense_fwd = [] + profiled_moe_fwd = [] + + for layer_idx in profiled_layer_indices: + if layer_idx < len(moe_pattern): + layer_data = profiling_results[layer_idx] + fwd_time = layer_data.get("forward_time_ms", 0) + + if moe_pattern[layer_idx] == 0: + profiled_dense_fwd.append(fwd_time) + else: + profiled_moe_fwd.append(fwd_time) + + avg_dense_fwd = ( + sum(profiled_dense_fwd) / len(profiled_dense_fwd) + if profiled_dense_fwd + else 0 + ) + avg_moe_fwd = ( + sum(profiled_moe_fwd) / len(profiled_moe_fwd) + if profiled_moe_fwd + else 0 + ) + + num_dense_layers = sum(1 for x in moe_pattern if x == 0) + num_moe_layers = sum(1 for x in moe_pattern if x == 1) + + total_dense_time = avg_dense_fwd * num_dense_layers + total_moe_time = avg_moe_fwd * num_moe_layers + total_transformer_time = total_dense_time + total_moe_time + total_time_ms += 
total_transformer_time + + if is_rank_0: + if profiled_dense_fwd: + print( + f" Dense Layers: {len(profiled_dense_fwd)} profiled → {num_dense_layers} total" + ) + print(f" Avg fwd per layer: {avg_dense_fwd:.2f} ms") + print(f" Total fwd time: {total_dense_time:.2f} ms") + + if profiled_moe_fwd: + print( + f" MoE Layers: {len(profiled_moe_fwd)} profiled → {num_moe_layers} total" + ) + print(f" Avg fwd per layer: {avg_moe_fwd:.2f} ms") + print(f" Total fwd time: {total_moe_time:.2f} ms") + + # Output layer (forward only) + if "output" in profiling_results: + out = profiling_results["output"] + out_time = out.get("forward_time_ms", 0) + total_time_ms += out_time + if is_rank_0: + print(f" Output Layer (fwd): {out_time:.2f} ms") + + if is_rank_0: + print("-" * 100) + print( + f"[Primus:Inference Projection] Extrapolated Forward-Only Time: {total_time_ms:.2f} ms" + ) + print( + f" (Based on {len(profiled_layer_indices)} profiled layers → {num_total_layers} total layers)" + ) + print("=" * 100) + + return total_time_ms + + # ============================================================================= # Layer Configuration Functions # ============================================================================= @@ -712,7 +1083,10 @@ def _limit_layers_for_projection(module_config): def _rescale_expert_parallelism(module_config): """ - Cap expert_model_parallel_size so that EP * TP * CP <= 8 and adjust num_experts. + Cap expert_model_parallel_size so that EP * TP <= GPUs_per_node and adjust num_experts. + + With MoE Parallel Folding, CP is folded into EP (CP ranks are a subset of + EP ranks), so the minimum GPUs for a MoE config is EP * TP, not EP * TP * CP. 
""" expert_mp_size = getattr(module_config, "expert_model_parallel_size", None) if expert_mp_size is None or expert_mp_size <= _MAX_EXPERT_PARALLEL_SIZE: @@ -720,13 +1094,16 @@ def _rescale_expert_parallelism(module_config): current_cp = getattr(module_config, "context_parallel_size", 1) or 1 if expert_mp_size is None: expert_mp_size = 1 - if expert_mp_size * current_tp * current_cp <= _MAX_EXPERT_PARALLEL_SIZE: + # MoE Parallel Folding: CP is folded into EP, so min GPUs = EP * TP + if expert_mp_size * current_tp <= _MAX_EXPERT_PARALLEL_SIZE: return None num_experts = getattr(module_config, "num_experts", None) current_tp = getattr(module_config, "tensor_model_parallel_size", 1) or 1 current_cp = getattr(module_config, "context_parallel_size", 1) or 1 - total_parallel_product = max(1, current_tp * current_cp) + # MoE Parallel Folding: CP is folded into EP, so only TP contributes to + # the per-EP-rank GPU cost. + total_parallel_product = max(1, current_tp) max_ep_allowed = max(1, _MAX_EXPERT_PARALLEL_SIZE // total_parallel_product) new_expert_mp = min(expert_mp_size, _MAX_EXPERT_PARALLEL_SIZE, max_ep_allowed) @@ -792,7 +1169,7 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): cp = getattr(original_config, "context_parallel_size", 1) or 1 num_experts = getattr(original_config, "num_experts", None) - gpus_required = tp * pp * ep * cp + gpus_required = _calculate_min_gpus(tp, pp, ep, cp) nodes_required = (gpus_required + gpus_per_node - 1) // gpus_per_node # If already fits on 1 node, no adjustment needed @@ -812,7 +1189,7 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): # Step 1: Reduce PP to 1 benchmark_pp = 1 - benchmark_gpus_required = tp * benchmark_pp * ep * cp + benchmark_gpus_required = _calculate_min_gpus(tp, benchmark_pp, ep, cp) # Step 2: If still doesn't fit, rescale EP benchmark_ep = ep @@ -831,7 +1208,7 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): if rescale_info: benchmark_ep = 
rescale_info["ep_after"] benchmark_num_experts = rescale_info.get("num_experts_after", num_experts) - benchmark_gpus_required = tp * benchmark_pp * benchmark_ep * cp + benchmark_gpus_required = _calculate_min_gpus(tp, benchmark_pp, benchmark_ep, cp) if benchmark_gpus_required > gpus_per_node: raise ValueError( @@ -903,7 +1280,7 @@ def _estimate_pp_communication_overhead( # Get hardware setup gpus_per_node = int(os.getenv("GPUS_PER_NODE", "8")) - gpus_required = tp * pp_size * ep * cp + gpus_required = _calculate_min_gpus(tp, pp_size, ep, cp) num_nodes = (gpus_required + gpus_per_node - 1) // gpus_per_node # Get collective model args @@ -926,8 +1303,9 @@ def _estimate_pp_communication_overhead( p2p_size = batch_size * seq_len * hidden_size * 2 # BF16 # Number of microbatches + # DP = world_size / (TP × PP × CP) — EP excluded (borrows from DP via folding) global_batch_size = runtime_config.global_batch_size - data_parallel_size = (num_nodes * gpus_per_node) // (tp * pp_size * ep * cp) + data_parallel_size = (num_nodes * gpus_per_node) // (tp * pp_size * cp) num_microbatches = global_batch_size // (batch_size * data_parallel_size) # P2P time: 2 * (PP-1) sends per microbatch (forward + backward) @@ -1040,11 +1418,11 @@ def _estimate_ep_communication_overhead( gpus_per_node = int(os.getenv("GPUS_PER_NODE", "8")) # Calculate nodes required for original EP - gpus_required_original = tp * pp * original_ep * cp + gpus_required_original = _calculate_min_gpus(tp, pp, original_ep, cp) num_nodes_original = (gpus_required_original + gpus_per_node - 1) // gpus_per_node # Calculate nodes for benchmark EP (should be 1) - gpus_required_benchmark = tp * pp * benchmark_ep * cp + gpus_required_benchmark = _calculate_min_gpus(tp, pp, benchmark_ep, cp) num_nodes_benchmark = (gpus_required_benchmark + gpus_per_node - 1) // gpus_per_node # Get collective model args for original EP configuration @@ -1491,6 +1869,7 @@ def _run_layer_benchmark(primus_config, unknown_overrides): 
print("[Primus:Performance Projection] Initializing Megatron...") trainer.init() + print("[Primus:Performance Projection] Setting up model and optimizer...") trainer.setup() @@ -1813,9 +2192,10 @@ def _run_multinode_projection( cp = getattr(mp_config, "context_model_parallel_size", 1) gpus_per_node = int(os.getenv("GPUS_PER_NODE", "8")) - # Calculate minimum nodes required by parallelism config - # EP is included in the minimum GPUs calculation (need GPUs to hold experts) - gpus_required = tp * pp * ep * cp + # Calculate minimum nodes required by parallelism config. + # For MoE (EP > 1): CP is folded into EP via MoE Parallel Folding, + # so min GPUs = TP × PP × EP. For dense: min GPUs = TP × PP × CP. + gpus_required = _calculate_min_gpus(tp, pp, ep, cp) min_nodes_required = (gpus_required + gpus_per_node - 1) // gpus_per_node # Validate target >= minimum required @@ -1826,9 +2206,12 @@ def _run_multinode_projection( f"--target-nodes must be >= {min_nodes_required}." ) - # Calculate DP for scaling - EXCLUDES EP (DP scaling is independent of EP) - # EP distributes experts but doesn't affect how many data batches can be processed in parallel - gpus_for_dp = tp * pp * cp # EP excluded for DP calculation + # Calculate DP for scaling. EP is excluded from this divisor because, with + # MoE Parallel Folding, EP borrows from the DP dimension (not from extra + # GPUs). Data-loading DP = world_size / (TP × PP × CP) for both dense and + # MoE models. Within each EP group the CP ranks share context-parallel + # attention work while EP/CP ranks provide inner data-parallel streams. 
+ gpus_for_dp = tp * pp * cp # EP excluded — it borrows from DP total_gpus_target = target_nodes * gpus_per_node dp_target = total_gpus_target // gpus_for_dp @@ -2085,6 +2468,583 @@ def _run_multinode_projection( } +# ============================================================================= +# Inference Projection Functions +# ============================================================================= + + +def calculate_inference_communication_time( + training_config, + num_nodes: int, + gpus_per_node: int, + tp: int, + pp: int, + ep: int, + cp: int, + hardware_config: Dict[str, Any] = None, +) -> Tuple[float, Dict[str, float], Dict[str, Any]]: + """ + Calculate collective communication time for **inference** (forward-only). + + Compared to the training version, this skips: + - Gradient AllReduce (no backward pass) + - Backward MoE All-to-All + - FSDP reduce-scatter (no gradient sharding) + + Returns: + (total_comm_time_ms, breakdown_dict, message_info_dict) + """ + model_config = training_config.model_config + runtime_config = training_config.runtime_config + + coll_args = get_default_args( + num_nodes=num_nodes, + gpus_per_node=gpus_per_node, + tp=tp, + pp=pp, + ep=ep, + cp=cp, + hardware_config=hardware_config, + ) + + hidden_size = model_config.hidden_size + num_layers = model_config.num_layers + moe_router_topk = model_config.moe_router_topk + moe_pattern = model_config.moe_pattern + batch_size = runtime_config.micro_batch_size + seq_len = runtime_config.sequence_length + num_moe_layers = sum(1 for p in moe_pattern if p == 1) + + breakdown = {} + message_info = {} + + # No gradient AllReduce for inference + breakdown["gradient_allreduce"] = 0.0 + message_info["gradient_allreduce_size"] = 0 + message_info["gradient_allreduce_size_mb"] = 0.0 + + # MoE All-to-All — forward only (dispatch + combine, no backward) + if ep > 1 and num_moe_layers > 0: + tokens_per_gpu = seq_len * batch_size // max(tp, 1) + dispatch_size = tokens_per_gpu * hidden_size * 
moe_router_topk * 2 # BF16 + + a2a_dispatch = cm.alltoall(coll_args, dispatch_size, ep, groups=["ep"]) + a2a_combine = cm.alltoall(coll_args, dispatch_size, ep, groups=["ep"]) + + total_a2a_fwd = (a2a_dispatch + a2a_combine) * num_moe_layers / 1000 # ms + + breakdown["moe_a2a_fwd"] = total_a2a_fwd + message_info["moe_a2a_size"] = dispatch_size + message_info["moe_a2a_size_mb"] = dispatch_size / (1024 * 1024) + message_info["moe_a2a_per_layer_fwd"] = (a2a_dispatch + a2a_combine) / 1000 + message_info["num_moe_layers"] = num_moe_layers + else: + breakdown["moe_a2a_fwd"] = 0.0 + message_info["moe_a2a_size"] = 0 + message_info["moe_a2a_size_mb"] = 0.0 + message_info["moe_a2a_per_layer_fwd"] = 0.0 + message_info["num_moe_layers"] = 0 + + # No backward A2A + # No FSDP communication (no gradient sharding in inference) + breakdown["fsdp_allgather_fwd"] = 0.0 + breakdown["fsdp_reducescatter_bwd"] = 0.0 + message_info["fsdp_enabled"] = False + + message_info["num_layers"] = num_layers + + total_comm_time = sum(breakdown.values()) + return total_comm_time, breakdown, message_info + + +def _run_inference_projection( + training_config, + forward_time_ms: float, + profiling_results, + args, + target_nodes: int, +): + """ + Run inference-mode multinode projection. 
+
+    Unlike the training projection, this:
+    - Uses only forward pass time (no backward, no wgrad)
+    - Skips optimizer step estimation
+    - Skips gradient AllReduce
+    - Reports prefill latency and throughput
+
+    Args:
+        training_config: Configuration object
+        forward_time_ms: Measured forward-only time in ms for one pass
+        profiling_results: Layer profiling results
+        args: CLI arguments
+        target_nodes: Target number of nodes for projection
+    """
+    import torch.distributed as dist
+
+    is_rank_0 = not dist.is_initialized() or dist.get_rank() == 0
+
+    mp_config = training_config.model_parallel_config
+
+    tp = mp_config.tensor_model_parallel_size
+    pp = mp_config.pipeline_model_parallel_size
+    ep = getattr(mp_config, "expert_model_parallel_size", 1)
+    cp = getattr(mp_config, "context_model_parallel_size", 1)
+    gpus_per_node = int(os.getenv("GPUS_PER_NODE", "8"))
+
+    gpus_required = _calculate_min_gpus(tp, pp, ep, cp)
+    min_nodes_required = (gpus_required + gpus_per_node - 1) // gpus_per_node
+
+    if target_nodes < min_nodes_required:
+        raise ValueError(
+            f"[Primus:Inference Projection] ERROR: Cannot project to {target_nodes} nodes. "
+            f"Minimum required by parallelism config is {min_nodes_required} nodes."
+ ) + + total_gpus_target = target_nodes * gpus_per_node + # For inference, DP means we can serve independent requests in parallel + gpus_for_dp = tp * pp * cp + dp_target = total_gpus_target // gpus_for_dp + + if is_rank_0: + print("" + "=" * 100) + print("Inference Parallelism Configuration") + print("=" * 100) + print(f" TP: {tp}, PP: {pp}, EP: {ep}, CP: {cp}") + print(f" GPUs per Node: {gpus_per_node}") + print(f" Minimum GPUs Required: {gpus_required}") + print(f" Minimum Nodes Required: {min_nodes_required}") + print(f" Target Nodes: {target_nodes}") + + # Load hardware config if provided + hardware_config_dict = None + if hasattr(args, "hardware_config") and args.hardware_config: + hardware_config_dict = load_hardware_config(args.hardware_config) + if is_rank_0: + print(f" Using custom hardware config from: {args.hardware_config}") + + # Calculate inference communication times (forward-only) + total_comm_time_ms, breakdown, message_info = ( + calculate_inference_communication_time( + training_config, + target_nodes, + gpus_per_node, + tp, + pp, + ep, + cp, + hardware_config_dict, + ) + ) + + # Inference projected time: forward compute + forward communication + projected_time_ms = forward_time_ms + + # For PP > 1, add estimated PP P2P overhead (forward only — one direction) + if pp > 1: + if hardware_config_dict: + pp_bw = hardware_config_dict.get("intra_node_bandwidth_gbps", 896.0) + else: + pp_bw = 896.0 # Default xGMI bandwidth + hidden = training_config.model_config.hidden_size + seq_len = training_config.runtime_config.sequence_length + micro_batch = training_config.runtime_config.micro_batch_size + activation_size = seq_len * micro_batch * hidden * 2 # BF16 + pp_latency_us = 1.0 + pp_time_per_stage = (activation_size / (pp_bw * 1e9) * 1e6 + pp_latency_us) / 1000 # ms + pp_overhead_ms = pp_time_per_stage * (pp - 1) + projected_time_ms += pp_overhead_ms + if is_rank_0: + print(f" PP forward overhead ({pp - 1} hops): {pp_overhead_ms:.3f} ms") + + # Get 
runtime config for throughput calculation + runtime_config = training_config.runtime_config + seq_len = getattr(runtime_config, "sequence_length", 4096) + micro_batch = getattr(runtime_config, "micro_batch_size", 1) + + # Prefill latency = forward time for one batch on one model replica + prefill_latency_ms = projected_time_ms + + # Tokens per pass (one forward pass processes all tokens in the sequence) + tokens_per_pass = seq_len * micro_batch + + # Throughput per replica + tokens_per_sec_per_replica = tokens_per_pass * 1000 / prefill_latency_ms if prefill_latency_ms > 0 else 0 + + # Number of independent replicas with DP + num_replicas = dp_target + + # Aggregate throughput across all replicas + total_tokens_per_sec = tokens_per_sec_per_replica * num_replicas + tokens_per_sec_per_gpu = total_tokens_per_sec / total_gpus_target if total_gpus_target > 0 else 0 + + if is_rank_0: + print("" + "=" * 100) + print("Inference Projection Results") + print("=" * 100) + print(f"📊 Parallelism: TP={tp}, PP={pp}, EP={ep}, CP={cp}") + + # Communication Breakdown + if total_comm_time_ms > 0: + print("📡 Communication Breakdown (forward-only):") + for op_name, op_time in breakdown.items(): + if op_time > 0: + print(f" {op_name}: {op_time:.3f} ms", end="") + if op_name == "moe_a2a_fwd" and "moe_a2a_size_mb" in message_info: + print( + f" (message: {message_info['moe_a2a_size_mb']:.2f} MB, " + f"{message_info['num_moe_layers']} layers × " + f"{message_info['moe_a2a_per_layer_fwd']:.3f} ms/layer)" + ) + else: + print("") + print(f" Total Communication: {total_comm_time_ms:.3f} ms") + + print(f"🎯 Target Configuration ({target_nodes} nodes):") + print(f" Nodes: {target_nodes}, GPUs: {total_gpus_target}") + print(f" TP={tp}, PP={pp}, EP={ep}, CP={cp}, DP(replicas)={num_replicas}") + print(f" Prefill Latency: {prefill_latency_ms:.3f} ms " + f"(seq_len={seq_len}, micro_batch={micro_batch})") + print(f" Tokens/s per replica: {tokens_per_sec_per_replica:,.0f}") + print(f" Tokens/s total 
({num_replicas} replicas): {total_tokens_per_sec:,.0f}") + print(f" Tokens/s/GPU: {tokens_per_sec_per_gpu:,.0f}") + print("=" * 100) + + return { + "target_nodes": target_nodes, + "target_gpus": total_gpus_target, + "tp": tp, + "pp": pp, + "ep": ep, + "cp": cp, + "dp_replicas": num_replicas, + "prefill_latency_ms": prefill_latency_ms, + "tokens_per_sec_per_replica": tokens_per_sec_per_replica, + "tokens_per_sec_total": total_tokens_per_sec, + "tokens_per_sec_per_gpu": tokens_per_sec_per_gpu, + } + + +def _run_decode_layer_benchmark(primus_config, unknown_overrides, decode_batch_size): + """ + Benchmark transformer layers with seq_len=1 to measure decode-step GEMM + times on the GPU. + + The Megatron trainer is initialized with the **original** config (so that + all Megatron assertions pass normally). Only the ``run_layer_benchmark`` + call uses ``seq_len=1`` and ``batch_size=decode_batch_size`` — these + override the input tensor shapes without touching validated Megatron args. + + The resulting forward_time_ms per layer captures real kernel timings for + the tiny GEMMs characteristic of autoregressive decode. Attention timing + will be unrealistically small (1-token self-attention without KV cache), + so callers should overlay an analytical KV cache model. + + Args: + primus_config: Primus configuration (will be mutated for layer + limiting / EP rescaling only — seq_length is NOT changed). + unknown_overrides: Extra CLI overrides for the trainer. + decode_batch_size: Number of concurrent sequences. + + Returns: + dict: Profiling results in the same format as _run_layer_benchmark. + """ + from primus.modules.trainer.megatron.pre_trainer import MegatronPretrainTrainer + + module_config = primus_config.get_module_config("pre_trainer") + + # NOTE: Do NOT change seq_length or micro_batch_size in the module config — + # Megatron's argument validation has many assertions on seq_length (e.g. + # divisibility by CP, TP, position-embedding limits). 
We keep the original + # config for trainer init and only pass seq_len=1 to run_layer_benchmark(). + _limit_layers_for_projection(module_config) + rescale_info = _rescale_expert_parallelism(module_config) + training_config = convert_primus_config_to_projection_config(primus_config) + + master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") + master_port = int(os.getenv("MASTER_PORT", "29500")) + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + + is_rank_0 = rank == 0 + if is_rank_0: + print("[Primus:Decode Benchmark] Initializing MegatronPretrainTrainer...") + print("[Primus:Decode Benchmark] (trainer uses original seq_length for init; " + "benchmark will use seq_len=1)") + + # Disable overlap/FSDP features for profiling + primus_config.get_module_config("pre_trainer").overlap_grad_reduce = False + primus_config.get_module_config("pre_trainer").overlap_param_gather = False + primus_config.get_module_config("pre_trainer").use_torch_fsdp2 = False + + trainer = MegatronPretrainTrainer( + module_name="pre_trainer", + primus_config=primus_config, + module_rank=rank, + module_world_size=world_size, + module_master_addr=master_addr, + module_master_port=master_port, + extra_args=unknown_overrides, + ) + + if is_rank_0: + print("[Primus:Decode Benchmark] Initializing Megatron...") + trainer.init() + if is_rank_0: + print("[Primus:Decode Benchmark] Setting up model and optimizer...") + trainer.setup() + + if is_rank_0: + print("[Primus:Decode Benchmark] Building model profiler...") + model_profiler_spec = get_language_model_profiler_spec(training_config) + model_profiler = build_profiler(model_profiler_spec) + + # Override seq_len and batch_size ONLY for the benchmark call + seq_len = 1 + batch_size = decode_batch_size + + if is_rank_0: + print("[Primus:Decode Benchmark] Benchmarking with:") + print(f" Rank: {rank}") + print(f" World Size: {world_size}") + print(f" Batch Size: {batch_size} (decode concurrent sequences)") + print(f" Sequence 
Length: {seq_len} (1 token per decode step)") + if rescale_info: + note = ( + f" NOTE: MoE rescaled -> EP {rescale_info['ep_before']} -> {rescale_info['ep_after']}" + f" (TP={rescale_info['tp']}, CP={rescale_info['cp']})" + ) + print(note) + + if is_rank_0: + print("" + "=" * 100) + print("[Primus:Decode Benchmark] Starting layer benchmarking (seq_len=1)...") + print("=" * 100) + + profiling_results = model_profiler.run_layer_benchmark( + model=trainer.model, + batch_size=batch_size, + seq_len=seq_len, + ) + return profiling_results + + +def _run_decode_projection( + training_config, + args, + target_nodes: int, + profiling_results: Optional[Dict] = None, +): + """ + Run decode-mode projection. + + Supports two modes: + - **Analytical** (profiling_results=None): Fully analytical model based + on HBM bandwidth and model architecture. No GPU needed. + - **Benchmark-enhanced** (profiling_results provided): Uses real GPU + benchmark timings for GEMMs (seq_len=1) and overlays an analytical + KV cache attention model for the portion that benchmarks can't capture. + + Args: + training_config: Configuration object. + args: CLI arguments (decode_batch_size, decode_context_length, etc.). + target_nodes: Target number of nodes for projection. + profiling_results: Optional benchmark results from _run_decode_layer_benchmark. 
+ """ + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + use_benchmark = profiling_results is not None + + mp_config = training_config.model_parallel_config + tp = mp_config.tensor_model_parallel_size + pp = mp_config.pipeline_model_parallel_size + ep = getattr(mp_config, "expert_model_parallel_size", 1) + cp = getattr(mp_config, "context_model_parallel_size", 1) + gpus_per_node = int(os.getenv("GPUS_PER_NODE", "8")) + total_gpus = target_nodes * gpus_per_node + + # DP replicas for decode + gpus_for_dp = tp * pp * cp + dp_replicas = total_gpus // gpus_for_dp + + # Decode-specific parameters from CLI or config defaults + decode_batch = getattr(args, "decode_batch_size", None) + if decode_batch is None: + decode_batch = training_config.runtime_config.micro_batch_size + context_len = getattr(args, "decode_context_length", None) + if context_len is None: + context_len = training_config.runtime_config.sequence_length + num_gen_tokens = getattr(args, "num_generated_tokens", None) or 128 + + gpu_arch = getattr(args, "gpu_arch", None) + hardware_config_dict = None + if hasattr(args, "hardware_config") and args.hardware_config: + hardware_config_dict = load_hardware_config(args.hardware_config) + + # Always run analytical model (used for KV cache estimate and as baseline) + analytical = _estimate_decode_time_per_token( + training_config, + decode_batch_size=decode_batch, + context_length=context_len, + tp=tp, + pp=pp, + ep=ep, + gpu_arch=gpu_arch, + hardware_config=hardware_config_dict, + ) + + # ── Compute decode time ────────────────────────────────────────────── + if use_benchmark: + # Benchmark-enhanced: real GPU GEMM timings + analytical KV cache + # + # The benchmark ran layers with seq_len=1. The forward_time_ms per + # layer captures the real GEMM cost (QKV proj, output proj, MLP) + # plus TP AllReduce — but attention was only 1-token self-attention + # (no KV cache). We add the analytical KV cache read time on top. 
+ mc = training_config.model_config + num_layers = mc.num_layers + layers_per_pp = num_layers // max(pp, 1) + + # Sum forward times from benchmark (only transformer layers) + benchmark_fwd_ms = 0.0 + layer_count = 0 + for layer_idx, layer_data in profiling_results.items(): + if isinstance(layer_data, dict) and "forward_time_ms" in layer_data: + benchmark_fwd_ms += layer_data["forward_time_ms"] + layer_count += 1 + + # Scale to full model depth (benchmark may run fewer representative layers) + if layer_count > 0 and layer_count < layers_per_pp: + benchmark_fwd_ms = benchmark_fwd_ms * layers_per_pp / layer_count + + # KV cache read time (analytical — benchmark doesn't have a KV cache) + hbm_bw_bytes_per_ms = analytical["hbm_bw_gbps"] * 1e9 / 1e3 + kv_cache_ms = analytical["total_kv_bytes"] / hbm_bw_bytes_per_ms + + # PP overhead (serial stages) + pp_overhead_ms = analytical["pp_overhead_ms"] + + # MoE EP All-to-All (latency-dominated for 1-token messages) + moe_a2a_ms = analytical["moe_a2a_total_ms"] + + decode_time_ms = benchmark_fwd_ms + kv_cache_ms + pp_overhead_ms + moe_a2a_ms + method_label = "Benchmark-Enhanced" + else: + # Pure analytical + decode_time_ms = analytical["decode_time_ms"] + benchmark_fwd_ms = None + kv_cache_ms = analytical["total_kv_bytes"] / (analytical["hbm_bw_gbps"] * 1e9 / 1e3) + pp_overhead_ms = analytical["pp_overhead_ms"] + moe_a2a_ms = analytical["moe_a2a_total_ms"] + method_label = "Analytical Model" + + tokens_per_sec_per_replica = ( + decode_batch * 1000 / decode_time_ms if decode_time_ms > 0 else 0 + ) + total_tokens_per_sec = tokens_per_sec_per_replica * dp_replicas + tokens_per_sec_per_gpu = total_tokens_per_sec / total_gpus if total_gpus > 0 else 0 + + total_generation_time_ms = decode_time_ms * num_gen_tokens + + # KV cache memory per GPU + kv_mb_per_gpu = analytical["total_kv_mb"] + weight_mb_per_gpu = analytical["total_weight_mb"] + + if is_rank_0: + print("\n" + "=" * 100) + print(f"Decode Projection Results ({method_label})") 
+ print("=" * 100) + + print(f"📊 Parallelism: TP={tp}, PP={pp}, EP={ep}, CP={cp}") + print(f" Decode batch size: {decode_batch}") + print(f" Context length (KV cache): {context_len}") + print(f" GPU arch: {(gpu_arch or os.getenv('PRIMUS_GPU_ARCH', 'mi300x')).lower()}") + print(f" HBM bandwidth: {analytical['hbm_bw_gbps']:.0f} GB/s") + print(f" Peak BF16 compute: {analytical['peak_tflops']:.0f} TFLOPS") + + print() + print("⏱️ Per-Token Decode Time Breakdown:") + + if use_benchmark: + # Show benchmarked GEMM time + analytical KV cache overlay + print(f" Layer fwd (benchmarked, seq_len=1): {benchmark_fwd_ms:.4f} ms " + f"(includes GEMMs + TP AllReduce)") + print(f" KV cache read (analytical): {kv_cache_ms:.4f} ms " + f"({analytical['total_kv_mb']:.1f} MB)") + # Also show analytical-only for comparison + analytical_weight_ms = analytical["total_weight_bytes"] / (analytical["hbm_bw_gbps"] * 1e9 / 1e3) + analytical_total = analytical["decode_time_ms"] + print(f" ─── Analytical comparison ───") + print(f" Weight loading (analytical): {analytical_weight_ms:.4f} ms " + f"({analytical['total_weight_mb']:.1f} MB)") + print(f" Total (analytical): {analytical_total:.4f} ms") + print(f" ────────────────────────────") + else: + weight_only_ms = analytical["total_weight_bytes"] / (analytical["hbm_bw_gbps"] * 1e9 / 1e3) + print(f" Weight loading: {weight_only_ms:.4f} ms ({analytical['total_weight_mb']:.1f} MB)") + print(f" KV cache read: {kv_cache_ms:.4f} ms ({analytical['total_kv_mb']:.1f} MB)") + print(f" Compute: {analytical['compute_time_ms']:.4f} ms " + f"({analytical['total_compute_tflops'] * 1000:.2f} GFLOPS)") + if analytical["tp_overhead_ms"] > 0: + print(f" TP AllReduce: {analytical['tp_overhead_ms']:.4f} ms") + + if pp_overhead_ms > 0: + print(f" PP P2P: {pp_overhead_ms:.4f} ms") + if moe_a2a_ms > 0: + print(f" MoE All-to-All: {moe_a2a_ms:.4f} ms") + + if not use_benchmark: + bound = "MEMORY-BOUND" if analytical["is_memory_bound"] else "COMPUTE-BOUND" + print(f" 
Bottleneck: {bound} " + f"(arith intensity={analytical['arithmetic_intensity']:.2f} FLOPs/B, " + f"balance={analytical['balance_point']:.0f} FLOPs/B)") + + print(f" ─────────────────────────────────────") + print(f" Total per token: {decode_time_ms:.4f} ms") + + print() + print(f"🎯 Target Configuration ({target_nodes} nodes):") + print(f" Nodes: {target_nodes}, GPUs: {total_gpus}") + print(f" TP={tp}, PP={pp}, EP={ep}, CP={cp}, DP(replicas)={dp_replicas}") + print(f" Per-token latency: {decode_time_ms:.4f} ms") + print(f" Tokens/s per replica: {tokens_per_sec_per_replica:,.0f}") + print(f" Tokens/s total ({dp_replicas} replicas): {total_tokens_per_sec:,.0f}") + print(f" Tokens/s/GPU: {tokens_per_sec_per_gpu:,.0f}") + print() + print(f" Generation of {num_gen_tokens} tokens: {total_generation_time_ms:.1f} ms " + f"({total_generation_time_ms / 1000:.2f} s)") + + print() + print("💾 Memory Estimate (per GPU):") + print(f" Model weights: {weight_mb_per_gpu:.1f} MB ({weight_mb_per_gpu / 1024:.2f} GB)") + print(f" KV cache: {kv_mb_per_gpu:.1f} MB ({kv_mb_per_gpu / 1024:.2f} GB) " + f"(batch={decode_batch}, ctx={context_len})") + total_mem_gb = (weight_mb_per_gpu + kv_mb_per_gpu) / 1024 + print(f" Total: {total_mem_gb:.2f} GB") + + print("=" * 100) + + return { + "target_nodes": target_nodes, + "target_gpus": total_gpus, + "tp": tp, + "pp": pp, + "ep": ep, + "cp": cp, + "dp_replicas": dp_replicas, + "decode_batch_size": decode_batch, + "context_length": context_len, + "decode_time_per_token_ms": decode_time_ms, + "tokens_per_sec_per_replica": tokens_per_sec_per_replica, + "tokens_per_sec_total": total_tokens_per_sec, + "tokens_per_sec_per_gpu": tokens_per_sec_per_gpu, + "total_generation_time_ms": total_generation_time_ms, + "weight_mb_per_gpu": weight_mb_per_gpu, + "kv_cache_mb_per_gpu": kv_mb_per_gpu, + "is_memory_bound": analytical["is_memory_bound"], + "method": method_label, + } + + def launch_projection_from_cli(args, overrides): """ Entry point for the 
'performance_projection' subcommand. @@ -2174,6 +3134,56 @@ def launch_projection_from_cli(args, overrides): "benchmark_num_experts" ] + # ========================================================================= + # DECODE MODE — analytical or benchmark-enhanced + # ========================================================================= + _projection_mode = getattr(args, "mode", MODE_TRAINING) + if _projection_mode == MODE_DECODE: + training_config = convert_primus_config_to_projection_config(primus_config_original) + + profiling_mode = getattr(args, "profiling_mode", "benchmark") + decode_profiling_results = None + + if profiling_mode in ("benchmark", "both"): + # Benchmark layers with seq_len=1 to get real GPU GEMM timings + decode_batch = getattr(args, "decode_batch_size", None) + if decode_batch is None: + decode_batch = training_config.runtime_config.micro_batch_size + + decode_profiling_results = _run_decode_layer_benchmark( + copy.deepcopy(primus_config), + unknown_overrides, + decode_batch_size=decode_batch, + ) + + if profiling_mode == "both": + # Run both: show analytical and benchmark-enhanced side by side + is_rank_0 = int(os.getenv("RANK", "0")) == 0 + + # Analytical + if is_rank_0: + print("\n" + "=" * 100) + print("[Primus:Decode] Running ANALYTICAL projection...") + print("=" * 100) + _run_decode_projection(training_config, args, target_nodes, profiling_results=None) + + # Benchmark-enhanced + if is_rank_0: + print("\n" + "=" * 100) + print("[Primus:Decode] Running BENCHMARK-ENHANCED projection...") + print("=" * 100) + _run_decode_projection( + training_config, args, target_nodes, + profiling_results=decode_profiling_results, + ) + else: + # Single mode: benchmark or simulate + _run_decode_projection( + training_config, args, target_nodes, + profiling_results=decode_profiling_results, + ) + return + # Determine profiling mode profiling_mode = getattr(args, "profiling_mode", "benchmark") @@ -2224,10 +3234,17 @@ def 
launch_projection_from_cli(args, overrides): # Use original config for projection calculations training_config = convert_primus_config_to_projection_config(primus_config_original) + # Determine projection mode (training / prefill / decode) + projection_mode = getattr(args, "mode", MODE_TRAINING) + # Normalise: "inference" is an alias for "prefill" + if projection_mode == MODE_INFERENCE: + projection_mode = MODE_PREFILL + # Update data_parallel_size based on target_nodes - # This ensures the pipeline simulation calculates the correct number of microbatches - # NOTE: For MoE models, EP does NOT reduce DP (experts are distributed but tokens are replicated) - # DP = world_size / (TP × PP × CP) [EP is excluded] + # This ensures the pipeline simulation calculates the correct number of microbatches. + # DP = world_size / (TP × PP × CP) for both dense and MoE. With MoE Parallel + # Folding, EP borrows from the DP dimension (CP is folded into EP), so EP + # does not appear in the DP divisor. mp_config = training_config.model_parallel_config tp = mp_config.tensor_model_parallel_size pp = mp_config.pipeline_model_parallel_size @@ -2238,7 +3255,7 @@ def launch_projection_from_cli(args, overrides): # The pipeline simulator simulates the target config, so it needs target DP for microbatch calculation target_world_size = target_nodes * gpus_per_node - # For MoE models: DP calculation excludes EP since experts are distributed but data is replicated + # DP = world_size / (TP × PP × CP) — EP excluded (borrows from DP via folding) target_dp = target_world_size // (tp * pp * cp) # Also show benchmark config for reference @@ -2250,14 +3267,88 @@ def launch_projection_from_cli(args, overrides): # Only print from rank 0 is_rank_0 = int(os.getenv("RANK", "0")) == 0 + mode_label = "Prefill" if projection_mode == MODE_PREFILL else "Training" if is_rank_0: - print("[Primus:Performance Projection] Configuration Summary:") + print(f"[Primus:{mode_label} Projection] Configuration Summary:") 
print( f" Benchmark Config: PP={benchmark_pp}, EP={benchmark_ep}, TP={tp}, CP={cp}, DP={benchmark_dp} (1 node)" ) print( f" Target Config: PP={pp}, EP={ep}, TP={tp}, CP={cp}, DP={target_dp} ({target_nodes} nodes)" ) + print(f" Mode: {projection_mode}") + + # ========================================================================= + # PREFILL MODE — forward-only projection, no backward/optimizer/gradient + # ========================================================================= + if projection_mode == MODE_PREFILL: + # For inference, EP overhead adjustment is forward-only + if ( + reduction_info["adjusted"] + and reduction_info["original_ep"] != reduction_info["benchmark_ep"] + ): + original_ep = reduction_info["original_ep"] + benchmark_ep = reduction_info["benchmark_ep"] + original_num_experts = reduction_info.get("original_num_experts") + benchmark_num_experts = reduction_info.get("benchmark_num_experts") + + hardware_config_dict = None + if hasattr(args, "hardware_config") and args.hardware_config: + hardware_config_dict = load_hardware_config(args.hardware_config) + + fwd_overhead_per_layer, _ = _estimate_ep_communication_overhead( + training_config, original_ep, benchmark_ep, hardware_config_dict, + ) + ep_mlp_scale = _compute_ep_mlp_scale( + training_config.model_config, + benchmark_ep, original_ep, + original_num_experts=original_num_experts, + benchmark_num_experts=benchmark_num_experts, + ) + + if is_rank_0: + print("[Primus:Inference Projection] Adjusting profiling for EP (forward-only):") + print(f" EP rescaled: {benchmark_ep} → {original_ep}") + print(f" MLP fwd scale factor: {ep_mlp_scale:.3f}") + if fwd_overhead_per_layer > 0: + print(f" A2A fwd delta: +{fwd_overhead_per_layer:.3f} ms/layer") + + for layer_idx, layer_data in profiling_results.items(): + if isinstance(layer_data, dict) and layer_data.get("type") == "moe": + old_fwd = layer_data.get("forward_time_ms", 0) + mlp_info = layer_data.get("mlp", {}) + mlp_fwd = 
mlp_info.get("forward_time_ms", 0) + new_mlp_fwd = mlp_fwd * ep_mlp_scale + mlp_delta_fwd = new_mlp_fwd - mlp_fwd + new_fwd = old_fwd + mlp_delta_fwd + fwd_overhead_per_layer + layer_data["forward_time_ms"] = new_fwd + if mlp_info: + mlp_info["forward_time_ms"] = new_mlp_fwd + + # Extract forward-only time + forward_time_ms = extract_single_node_time_inference( + profiling_results, training_config + ) + + # Run inference projection + if target_nodes >= min_nodes_required: + if is_rank_0: + print("" + "=" * 100) + print("[Primus:Inference] Running inference projection") + print("=" * 100) + + _run_inference_projection( + training_config, + forward_time_ms, + profiling_results, + args, + target_nodes, + ) + return + + # ========================================================================= + # TRAINING MODE — full forward + backward + optimizer + gradient AllReduce + # ========================================================================= # Use BENCHMARK DP for pipeline simulation to get consistent baseline # The multinode projection will then scale from this baseline to target From 6fdf7deb7776102e03016562573e555982359300 Mon Sep 17 00:00:00 2001 From: Anshu Raina Date: Tue, 24 Feb 2026 23:26:41 -0800 Subject: [PATCH 10/12] =?UTF-8?q?fix:=20resolve=20CodeQL=20warnings=20?= =?UTF-8?q?=E2=80=94=20remove=20dead=20stores,=20unused=20imports,=20and?= =?UTF-8?q?=20empty=20except.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../module_profilers/collective_model.py | 1 - .../core/projection/module_profilers/moe_mlp.py | 4 ---- .../performance_projection/projection.py | 17 ----------------- .../core/projection/simulation_backends/base.py | 2 -- .../simulation_backends/sdpa_simulator.py | 11 +++++++++-- 5 files changed, 9 insertions(+), 26 deletions(-) diff --git a/primus/core/projection/module_profilers/collective_model.py b/primus/core/projection/module_profilers/collective_model.py index f22548758..43af2db81 100644 
--- a/primus/core/projection/module_profilers/collective_model.py +++ b/primus/core/projection/module_profilers/collective_model.py @@ -581,7 +581,6 @@ def pxn_alltoall(args, msg_size, gpus, groups=None, protocol=None): scale_out_msg_size = int(original_msg_size * (num_nodes - 1) / num_nodes) # Scale-up delay: time to accumulate 4MB before scale-out starts - scaleup_delay = 0.0 if scale_out_msg_size < chunk_size: # If total scale-out msg size is less than 4MB, # total time = scaleup_delay + scaleout_time diff --git a/primus/core/projection/module_profilers/moe_mlp.py b/primus/core/projection/module_profilers/moe_mlp.py index a7ca7a424..84acd7362 100644 --- a/primus/core/projection/module_profilers/moe_mlp.py +++ b/primus/core/projection/module_profilers/moe_mlp.py @@ -4,7 +4,6 @@ # See LICENSE for license information. ############################################################################### -import math import os from typing import Optional @@ -197,9 +196,6 @@ def _get_simulated_results( f" ({num_local_experts} local experts, M={M}, H={H}, F={F})" ) - expert_fwd_ms = 0.0 - expert_bwd_ms = 0.0 - if use_turbo: # ── Turbo model: batched GEMM (all experts in parallel) ── B = num_local_experts diff --git a/primus/core/projection/performance_projection/projection.py b/primus/core/projection/performance_projection/projection.py index 52d29b169..18b2ae15b 100644 --- a/primus/core/projection/performance_projection/projection.py +++ b/primus/core/projection/performance_projection/projection.py @@ -696,26 +696,9 @@ def calculate_collective_communication_time( if use_fsdp and dp > 1: overlap_fsdp = getattr(mp_config, "use_torch_fsdp2", False) if overlap_fsdp: - recompute_gran = getattr(mp_config, "recompute_granularity", None) - recomp_n = getattr(mp_config, "recompute_num_layers", 0) or 0 - has_recompute = recompute_gran == "full" and recomp_n > 0 - total_fsdp_ag = breakdown.get("fsdp_allgather_fwd", 0) total_fsdp_rs = breakdown.get("fsdp_reducescatter_bwd", 0) - if 
has_recompute: - # With full recompute the AG total already includes the 2× - # multiplier. Split into forward AG and backward (recomp) AG. - recomp_ratio = min(recomp_n, num_layers) / num_layers - ag_multiplier_val = message_info.get( - "fsdp_ag_multiplier", 1 + recomp_ratio - ) - fwd_ag_total = total_fsdp_ag / ag_multiplier_val - bwd_ag_total = total_fsdp_ag - fwd_ag_total - else: - fwd_ag_total = total_fsdp_ag - bwd_ag_total = 0.0 - # Overlap factor applied uniformly to all FSDP # communication - (AllGather fwd, AllGather recompute, ReduceScatter). FSDP_OVERLAP = 0.93 diff --git a/primus/core/projection/simulation_backends/base.py b/primus/core/projection/simulation_backends/base.py index ae68b7b80..941be3d53 100644 --- a/primus/core/projection/simulation_backends/base.py +++ b/primus/core/projection/simulation_backends/base.py @@ -132,8 +132,6 @@ def simulate_mlp_gemms( Returns: SimulationResult with forward_time_ms and backward_time_ms. """ - fwd_time = 0.0 - bwd_time = 0.0 # Use batched GEMM (batch=num_experts) as approximation of grouped GEMM. # Valid under uniform token distribution (all experts get the same M). # TODO: switch to native grouped-GEMM simulation if/when Origami supports it. diff --git a/primus/core/projection/simulation_backends/sdpa_simulator.py b/primus/core/projection/simulation_backends/sdpa_simulator.py index ad1e2a907..c4723e89c 100644 --- a/primus/core/projection/simulation_backends/sdpa_simulator.py +++ b/primus/core/projection/simulation_backends/sdpa_simulator.py @@ -349,8 +349,15 @@ def _create_tile_gemm_backend( "for Flash Attention" ) return backend - except Exception: - pass + except Exception as exc: + # If Origami is not available or fails to initialize, fall back to + # the analytic SDPA model by returning None here. 
+ is_rank_0 = int(os.getenv("RANK", "0")) == 0 + if is_rank_0: + print( + "[Primus:SDPA] Origami 1-CU tile-level simulation disabled " + f"due to error: {exc}" + ) return None def _simulate_tile_level( From a2a968afb07b621693263a7ba297ae803583a59b Mon Sep 17 00:00:00 2001 From: Anshu Raina Date: Wed, 25 Feb 2026 10:00:11 -0800 Subject: [PATCH 11/12] style: run black formatter (line-length=110) on projection module --- primus/cli/subcommands/projection.py | 4 +- .../memory_projection/projection.py | 20 +- .../projection/module_profilers/attention.py | 34 +- .../module_profilers/collective_args.py | 4 +- .../module_profilers/collective_model.py | 91 +-- .../projection/module_profilers/dense_mlp.py | 24 +- .../projection/module_profilers/embedding.py | 17 +- .../module_profilers/language_model.py | 176 ++--- .../projection/module_profilers/moe_mlp.py | 87 +-- .../module_profilers/output_layer.py | 17 +- .../projection/module_profilers/router.py | 4 +- .../module_profilers/transformer_layer.py | 76 +- .../core/projection/module_profilers/utils.py | 14 +- .../performance_projection/projection.py | 657 ++++++------------ .../performance_projection/simulator.py | 82 +-- primus/core/projection/profiler_spec.py | 6 +- .../projection/simulation_backends/base.py | 66 +- .../projection/simulation_backends/factory.py | 12 +- .../simulation_backends/origami_backend.py | 29 +- .../simulation_backends/sdpa_simulator.py | 25 +- primus/core/projection/training_config.py | 28 +- 21 files changed, 431 insertions(+), 1042 deletions(-) diff --git a/primus/cli/subcommands/projection.py b/primus/cli/subcommands/projection.py index bb0ecfcfb..b2917a890 100644 --- a/primus/cli/subcommands/projection.py +++ b/primus/cli/subcommands/projection.py @@ -75,9 +75,7 @@ def register_subcommand(subparsers): suite_parsers = parser.add_subparsers(dest="suite", required=True) # ---------- memory ---------- - memory = suite_parsers.add_parser( - "memory", help="Memory projection only (per-GPU memory 
analysis)." - ) + memory = suite_parsers.add_parser("memory", help="Memory projection only (per-GPU memory analysis).") from primus.core.launcher.parser import add_pretrain_parser add_pretrain_parser(memory) diff --git a/primus/core/projection/memory_projection/projection.py b/primus/core/projection/memory_projection/projection.py index e83b5e957..c1f988c35 100644 --- a/primus/core/projection/memory_projection/projection.py +++ b/primus/core/projection/memory_projection/projection.py @@ -17,9 +17,7 @@ ) -def print_profiler_hierarchy( - profiler, batch_size, seq_len, rank=None, name="root", depth=0, visited=None -): +def print_profiler_hierarchy(profiler, batch_size, seq_len, rank=None, name="root", depth=0, visited=None): """ Recursively print the profiler hierarchy with num_params and activation_memory for each component. @@ -48,17 +46,13 @@ def print_profiler_hierarchy( if depth == 0: # Only output the total number of parameters for the entire model for depth 0. num_params = profiler.estimated_num_params(rank=None) - print( - f"{indent} Total Number of Parameters: {num_params / 1e9:.6f} Billion ({num_params:,})" - ) + print(f"{indent} Total Number of Parameters: {num_params / 1e9:.6f} Billion ({num_params:,})") else: num_params = profiler.estimated_num_params(rank=rank) activation_mem = profiler.estimated_activation_memory(batch_size, seq_len) print(f"{indent}[{name}]") print(f"{indent} Params: {num_params / 1e9:.6f} Billion ({num_params:,})") - print( - f"{indent} Activation Memory: {activation_mem / 1024 / 1024 / 1024:.4f} GB" - ) + print(f"{indent} Activation Memory: {activation_mem / 1024 / 1024 / 1024:.4f} GB") # Recursively process sub_profilers if they exist if hasattr(profiler, "sub_profilers") and profiler.sub_profilers: @@ -85,9 +79,7 @@ def launch_projection_from_cli(args, overrides): """ cfg_path = Path(args.config) if not cfg_path.exists(): - raise FileNotFoundError( - f"[Primus:Projection] Config file '{cfg_path}' not found." 
- ) + raise FileNotFoundError(f"[Primus:Projection] Config file '{cfg_path}' not found.") config_parser = PrimusParser() primus_config = config_parser.parse(args) @@ -125,9 +117,7 @@ def launch_projection_from_cli(args, overrides): print("=" * 100) print(f"[Primus:Projection] Memory Projection Summary on Rank {rank}:") print(f" Params: {num_params / 1e9:.6f} Billion ({num_params:,})") - print( - f" Param+Optimizer Memory: {num_params * num_bytes_per_param / 1024 / 1024 / 1024:.4f} GB" - ) + print(f" Param+Optimizer Memory: {num_params * num_bytes_per_param / 1024 / 1024 / 1024:.4f} GB") print( f" Activation Memory (per batch size {batch_size}, seq len {seq_len}): " f"{activation_memory / 1024 / 1024 / 1024:.4f} GB" diff --git a/primus/core/projection/module_profilers/attention.py b/primus/core/projection/module_profilers/attention.py index 5dc6eb88e..f227a00fa 100644 --- a/primus/core/projection/module_profilers/attention.py +++ b/primus/core/projection/module_profilers/attention.py @@ -17,9 +17,7 @@ class AttentionProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): super().__init__(config, sub_profilers) self.module = None # Will be set during benchmarking - self._cached_results = ( - None # Cache for (forward_time, backward_time, activation_memory) - ) + self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) self._gemm_backend = None # Optional: GEMM simulation backend self._sdpa_backend = None # Optional: SDPA simulation backend @@ -54,9 +52,7 @@ def estimated_num_params(self, rank: Optional[int] = None) -> int: ) # Projection ratio: (kv_channels * n_heads) / hidden_size - query_proj_to_hidden = ( - args.kv_channels * args.num_attention_heads - ) / args.hidden_size + query_proj_to_hidden = (args.kv_channels * args.num_attention_heads) / args.hidden_size if args.multi_latent_attention: # q_term: either dense or LoRA factored Q with RoPE/Q-norm @@ 
-69,19 +65,14 @@ def estimated_num_params(self, rank: Optional[int] = None) -> int: else: q_term = args.q_lora_rank * ( args.hidden_size - + args.num_attention_heads - * (args.qk_head_dim + args.qk_pos_emb_head_dim) + + args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim) + 1 ) attn = ( q_term # kv lora + rope + kv norm + args.kv_lora_rank - * ( - args.hidden_size - + args.num_attention_heads * (args.qk_head_dim + args.v_head_dim) - + 1 - ) + * (args.hidden_size + args.num_attention_heads * (args.qk_head_dim + args.v_head_dim) + 1) # pos emb + args.hidden_size * args.qk_pos_emb_head_dim # out proj @@ -94,10 +85,7 @@ def estimated_num_params(self, rank: Optional[int] = None) -> int: 2 * args.hidden_size * args.hidden_size - * ( - (1 + (num_query_groups / args.num_attention_heads)) - * query_proj_to_hidden - ) + * ((1 + (num_query_groups / args.num_attention_heads)) * query_proj_to_hidden) ) def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: @@ -142,9 +130,7 @@ def _num_query_groups() -> int: kv_projection_size = args.kv_channels * _num_query_groups() # Need to retain Q, K, V as well as the projected context/output. 
- activation_width = ( - query_projection_size + 2 * kv_projection_size + args.hidden_size - ) + activation_width = query_projection_size + 2 * kv_projection_size + args.hidden_size if args.qk_layernorm: ln_width += kv_projection_size * 2 @@ -262,9 +248,7 @@ def _simulate_mla_gemms(self, batch_tokens: int, dtype: str) -> tuple[float, flo return fwd_time, bwd_time - def _get_simulated_results( - self, batch_size: int, seq_len: int - ) -> tuple[float, float, int]: + def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get simulated results from GEMM + SDPA simulation backends.""" args = self.config.model_config mp = self.config.model_parallel_config @@ -334,9 +318,7 @@ def _get_simulated_results( activation_memory = self.estimated_activation_memory(batch_size, seq_len) return (fwd_time, bwd_time, activation_memory) - def _get_benchmark_results( - self, batch_size: int, seq_len: int - ) -> tuple[float, float, int]: + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) diff --git a/primus/core/projection/module_profilers/collective_args.py b/primus/core/projection/module_profilers/collective_args.py index cf6246d60..b377e153e 100644 --- a/primus/core/projection/module_profilers/collective_args.py +++ b/primus/core/projection/module_profilers/collective_args.py @@ -52,9 +52,7 @@ class CollectiveArgs: # All-to-all specific a2a_peer_lat: float = 0.45 # Per-peer latency overhead for inter-node a2a - a2a_intra_node_peer_lat: float = ( - 28.0 # Per-peer latency overhead for intra-node a2a - ) + a2a_intra_node_peer_lat: float = 28.0 # Per-peer latency overhead for intra-node a2a # Intra-node overhead is higher (~19-28 us) due to: # - P2P scatter/gather scheduling overhead # - RCCL internal synchronization barriers diff --git a/primus/core/projection/module_profilers/collective_model.py 
b/primus/core/projection/module_profilers/collective_model.py index 43af2db81..1b7358cc4 100644 --- a/primus/core/projection/module_profilers/collective_model.py +++ b/primus/core/projection/module_profilers/collective_model.py @@ -110,9 +110,7 @@ def sendrecv(args, msg_size): return t -def direct_alltoall( - args, msg_size, gpus, groups=["ep"], protocol=None, original_msg_size=None -): +def direct_alltoall(args, msg_size, gpus, groups=["ep"], protocol=None, original_msg_size=None): """ Direct alltoall for HP=1, hierarchical with parallel NIC utilization. @@ -141,9 +139,7 @@ def direct_alltoall( intra_node_volume = msg_size * intra_fraction inter_node_volume_per_gpu = msg_size * inter_fraction - node_lat, intra_vol_adj = node_latency_and_volume_protocol( - args, intra_node_volume, protocol - ) + node_lat, intra_vol_adj = node_latency_and_volume_protocol(args, intra_node_volume, protocol) pod_lat = args.pod_lat # Intra-node time @@ -158,10 +154,7 @@ def direct_alltoall( t_inter = total_inter_volume / aggregate_inter_bw * 1.0e-3 + pod_lat else: remote_nodes = num_nodes - 1 - t_inter = ( - inter_node_volume_per_gpu / (args.bw_eff * args.pod_bw) * 1.0e-3 - + pod_lat * remote_nodes - ) + t_inter = inter_node_volume_per_gpu / (args.bw_eff * args.pod_bw) * 1.0e-3 + pod_lat * remote_nodes # Overlap intra and inter t_a2a = max(t_intra, t_inter) @@ -202,9 +195,7 @@ def run_alltoall(args, msg_size, gpus, groups=["ep"], protocol=None): elif (args.hp * gpus > args.node_size) and (args.hp * gpus) <= args.pod_size: # Alltoall fits within pod if args.hp == 1: - return direct_alltoall( - args, msg_size, gpus, groups, protocol, original_msg_size - ) + return direct_alltoall(args, msg_size, gpus, groups, protocol, original_msg_size) bw = args.bw_eff * args.pod_bw lat = args.pod_lat else: @@ -240,11 +231,7 @@ def cp_allgather(args, msg_size, gpus, protocol=None): bw = args.cluster_bw * args.bw_eff lat = args.cluster_lat # Logarithmic steps for tree allgather - t = ( - msg_size / bw * 
1.0e-3 - + lat * np.ceil(np.log2(gpus)) - + args.kernel_launch_latency - ) + t = msg_size / bw * 1.0e-3 + lat * np.ceil(np.log2(gpus)) + args.kernel_launch_latency return t @@ -467,9 +454,7 @@ def single_shot_alltoall(args, msg_size, gpus, groups=None, protocol=None): t_intra_node = 0 t_inter_node = 0 if intra_node_gpus > 0: - node_lat, msg_size_per_peer_adj = node_latency_and_volume_protocol( - args, msg_size_per_peer, protocol - ) + node_lat, msg_size_per_peer_adj = node_latency_and_volume_protocol(args, msg_size_per_peer, protocol) node_bw = args.bw_eff * args.node_bw intra_node_rounds = ceil(intra_node_gpus / intra_node_fanout) t_intra_node = intra_node_rounds * node_lat @@ -520,9 +505,7 @@ def hierarchical_alltoall(args, msg_size, gpus, groups=None, protocol=None): inter_node_volume_per_gpu = msg_size * (gpus - gpus_per_node) / gpus # Intra-node time - node_lat, intra_vol_adj = node_latency_and_volume_protocol( - args, intra_node_volume, protocol - ) + node_lat, intra_vol_adj = node_latency_and_volume_protocol(args, intra_node_volume, protocol) node_bw = args.bw_eff * args.node_bw t_intra = node_lat + intra_vol_adj / node_bw * 1.0e-3 @@ -533,10 +516,7 @@ def hierarchical_alltoall(args, msg_size, gpus, groups=None, protocol=None): t_inter = args.pod_lat + total_inter_volume / aggregate_inter_bw * 1.0e-3 else: effective_pod_bw = args.bw_eff * args.pod_bw - t_inter = ( - args.pod_lat * num_nodes - + inter_node_volume_per_gpu / effective_pod_bw * 1.0e-3 - ) + t_inter = args.pod_lat * num_nodes + inter_node_volume_per_gpu / effective_pod_bw * 1.0e-3 t_total = max(t_intra, t_inter) t_total += args.kernel_launch_latency @@ -584,9 +564,7 @@ def pxn_alltoall(args, msg_size, gpus, groups=None, protocol=None): if scale_out_msg_size < chunk_size: # If total scale-out msg size is less than 4MB, # total time = scaleup_delay + scaleout_time - node_lat, _ = node_latency_and_volume_protocol( - args, scale_out_msg_size, protocol - ) + node_lat, _ = 
node_latency_and_volume_protocol(args, scale_out_msg_size, protocol) scaleup_delay = node_lat + scale_out_msg_size / args.node_bw * 1.0e-3 else: # Scale-out comm doesn't start until 4MB is accumulated @@ -594,18 +572,12 @@ def pxn_alltoall(args, msg_size, gpus, groups=None, protocol=None): scaleup_delay = node_lat + chunk_size / args.node_bw * 1.0e-3 # Assume PXN style alltoall with overlapped scale-up and scale-out - node_msg_size = int( - original_msg_size * (effective_gpus_per_node - 1) / effective_gpus_per_node - ) + node_msg_size = int(original_msg_size * (effective_gpus_per_node - 1) / effective_gpus_per_node) scale_out_msg_size = int(original_msg_size * (num_nodes - 1) / num_nodes) # Calculate latencies with protocol inflation - node_lat, node_msg_size_adj = node_latency_and_volume_protocol( - args, node_msg_size, protocol - ) - pod_lat, scale_out_msg_size_adj = pod_latency_and_volume_protocol( - args, scale_out_msg_size, protocol - ) + node_lat, node_msg_size_adj = node_latency_and_volume_protocol(args, node_msg_size, protocol) + pod_lat, scale_out_msg_size_adj = pod_latency_and_volume_protocol(args, scale_out_msg_size, protocol) # Scale-up (intra-node) time node_bw = args.bw_eff * args.node_bw @@ -637,9 +609,7 @@ def single_shot_allgather(args, msg_size, gpus, groups=None, protocol=None): t_intra_node = 0 t_inter_node = 0 if intra_node_gpus > 0: - node_lat, msg_size_per_peer_node = node_latency_and_volume_protocol( - args, msg_size_per_peer, protocol - ) + node_lat, msg_size_per_peer_node = node_latency_and_volume_protocol(args, msg_size_per_peer, protocol) node_bw = args.bw_eff * args.node_bw intra_node_rounds = ceil(intra_node_gpus / intra_node_fanout) t_intra_node = intra_node_rounds * node_lat @@ -672,9 +642,7 @@ def single_shot_reduce_scatter(args, msg_size, gpus, groups=["hp"], protocol=Non t_intra_node = 0 t_inter_node = 0 if intra_node_gpus > 0: - node_lat, msg_size_per_peer_node = node_latency_and_volume_protocol( - args, msg_size_per_peer, 
protocol - ) + node_lat, msg_size_per_peer_node = node_latency_and_volume_protocol(args, msg_size_per_peer, protocol) node_bw = args.bw_eff * args.node_bw intra_node_rounds = ceil(intra_node_gpus / intra_node_fanout) t_intra_node = intra_node_rounds * node_lat @@ -704,9 +672,7 @@ def single_shot_allreduce(args, msg_size, gpus, groups=["hp"], protocol=None): return 0 t_rs = single_shot_reduce_scatter(args, msg_size, gpus, groups, protocol) t_ag = single_shot_allgather(args, msg_size, gpus, groups, protocol) - t_ar = ( - t_rs + t_ag - args.kernel_launch_latency - ) # Remove duplicate kernel launch latency + t_ar = t_rs + t_ag - args.kernel_launch_latency # Remove duplicate kernel launch latency return t_ar @@ -728,9 +694,7 @@ def allreduce(args, msg_size, gpus, groups=["dp"]): hypercubeallreduce = oneshotHCallreduce(args, msg_size, gpus, protocol=p) ss_allreduce = single_shot_allreduce(args, msg_size, gpus, protocol=p) ringallreduce = RingAllreduce(args, msg_size, gpus, protocol=p) - min_ar_alg_time = min( - ringallreduce, bruck_time, hypercubeallreduce, ss_allreduce - ) + min_ar_alg_time = min(ringallreduce, bruck_time, hypercubeallreduce, ss_allreduce) if min_ar_alg_time < min_ar_time: min_ar_time = min_ar_alg_time return min_ar_time @@ -746,9 +710,7 @@ def alltoall(args, msg_size, gpus, groups=["ep"]): which pipelines scale-up and scale-out communication. 
""" # Check if DeepEP is enabled - use_deepep = getattr(args, "moe_enable_deepep", False) or getattr( - args, "use_turbo_deepep", False - ) + use_deepep = getattr(args, "moe_enable_deepep", False) or getattr(args, "use_turbo_deepep", False) min_a2a_time = float("inf") for p in ["simple", "ll", "ll64", "ll128"]: @@ -759,12 +721,8 @@ def alltoall(args, msg_size, gpus, groups=["ep"]): else: # Use regular All-to-All algorithms direct_a2a_time = run_alltoall(args, msg_size, gpus, protocol=p) - single_shot_a2a_time = single_shot_alltoall( - args, msg_size, gpus, protocol=p - ) - hierarchical_a2a_time = hierarchical_alltoall( - args, msg_size, gpus, protocol=p - ) + single_shot_a2a_time = single_shot_alltoall(args, msg_size, gpus, protocol=p) + hierarchical_a2a_time = hierarchical_alltoall(args, msg_size, gpus, protocol=p) a2a_time = min(direct_a2a_time, single_shot_a2a_time, hierarchical_a2a_time) if a2a_time < min_a2a_time: @@ -785,16 +743,11 @@ def alltoall(args, msg_size, gpus, groups=["ep"]): # Intra-node overhead is much higher due to synchronization and scheduling # Based on preflight measurements: EP=8 intra-node A2A needs ~19-28 us per peer # Inter-node overhead is lower (~0.45 us per peer) due to RDMA efficiency - intra_node_overhead_per_peer = getattr( - args, "a2a_intra_node_peer_lat", 28.0 - ) # Default 28 us - inter_node_overhead_per_peer = getattr( - args, "a2a_peer_lat", 0.45 - ) # Default 0.45 us + intra_node_overhead_per_peer = getattr(args, "a2a_intra_node_peer_lat", 28.0) # Default 28 us + inter_node_overhead_per_peer = getattr(args, "a2a_peer_lat", 0.45) # Default 0.45 us peer_overhead = ( - intra_node_overhead_per_peer * intra_node_peers - + inter_node_overhead_per_peer * inter_node_peers + intra_node_overhead_per_peer * intra_node_peers + inter_node_overhead_per_peer * inter_node_peers ) min_a2a_time += peer_overhead diff --git a/primus/core/projection/module_profilers/dense_mlp.py b/primus/core/projection/module_profilers/dense_mlp.py index 
83d1db655..91c4ce0f2 100644 --- a/primus/core/projection/module_profilers/dense_mlp.py +++ b/primus/core/projection/module_profilers/dense_mlp.py @@ -17,9 +17,7 @@ class DenseMLPProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): super().__init__(config, sub_profilers) self.module = None # Will be set during benchmarking - self._cached_results = ( - None # Cache for (forward_time, backward_time, activation_memory) - ) + self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) self._gemm_backend = None # Optional: GEMM simulation backend @@ -57,26 +55,18 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: # Memory after first projection(s) if self.config.model_config.swiglu: # Need to store both gate and up projections for backward - intermediate_memory = ( - 2 * num_tokens * self.config.model_config.ffn_hidden_size * 2 - ) # bf16 + intermediate_memory = 2 * num_tokens * self.config.model_config.ffn_hidden_size * 2 # bf16 else: - intermediate_memory = ( - num_tokens * self.config.model_config.ffn_hidden_size * 2 - ) # bf16 + intermediate_memory = num_tokens * self.config.model_config.ffn_hidden_size * 2 # bf16 # After activation - activation_memory = ( - num_tokens * self.config.model_config.ffn_hidden_size * 2 - ) # bf16 + activation_memory = num_tokens * self.config.model_config.ffn_hidden_size * 2 # bf16 output_memory = num_tokens * self.config.model_config.hidden_size * 2 # bf16 # Peak memory is input + intermediate (both needed for backward) return intermediate_memory + activation_memory + output_memory - def _get_simulated_results( - self, batch_size: int, seq_len: int - ) -> Tuple[float, float, int]: + def _get_simulated_results(self, batch_size: int, seq_len: int) -> Tuple[float, float, int]: """Get simulated results from the GEMM simulation backend.""" tp_size = 
self.config.model_parallel_config.tensor_model_parallel_size cp_size = self.config.model_parallel_config.context_model_parallel_size @@ -98,9 +88,7 @@ def _get_simulated_results( activation_memory, ) - def _get_benchmark_results( - self, batch_size: int, seq_len: int - ) -> Tuple[float, float, int]: + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> Tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: diff --git a/primus/core/projection/module_profilers/embedding.py b/primus/core/projection/module_profilers/embedding.py index 6ca7ea409..cfd2c3a90 100644 --- a/primus/core/projection/module_profilers/embedding.py +++ b/primus/core/projection/module_profilers/embedding.py @@ -17,9 +17,7 @@ class EmbeddingProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): super().__init__(config, sub_profilers) self.module = None # Will be set during benchmarking - self._cached_results = ( - None # Cache for (forward_time, backward_time, activation_memory) - ) + self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) self._simulation_mode = False # Set to True when simulation backends are active @@ -37,10 +35,7 @@ def set_simulation_mode(self, enabled: bool = True): self._cache_key = None def estimated_num_params(self, rank: Optional[int] = None) -> int: - return ( - self.config.model_config.padded_vocab_size - * self.config.model_config.hidden_size - ) + return self.config.model_config.padded_vocab_size * self.config.model_config.hidden_size def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: return ( @@ -52,9 +47,7 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: * 2 ) # bf16 - def _get_simulated_results( - self, batch_size: int, seq_len: int - ) -> tuple[float, float, int]: + 
def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Estimate embedding time analytically (lookup is memory-bound, very fast).""" tp_size = self.config.model_parallel_config.tensor_model_parallel_size cp_size = self.config.model_parallel_config.context_model_parallel_size @@ -70,9 +63,7 @@ def _get_simulated_results( activation_memory = self.estimated_activation_memory(batch_size, seq_len) return (fwd_time, bwd_time, activation_memory) - def _get_benchmark_results( - self, batch_size: int, seq_len: int - ) -> tuple[float, float, int]: + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) diff --git a/primus/core/projection/module_profilers/language_model.py b/primus/core/projection/module_profilers/language_model.py index 0145f7078..a975ab73b 100644 --- a/primus/core/projection/module_profilers/language_model.py +++ b/primus/core/projection/module_profilers/language_model.py @@ -26,9 +26,7 @@ def build_profiler(spec: ModuleProfilerSpec, depth=0) -> BaseModuleProfiler: Recursively build a profiler instance from a ModuleProfilerSpec. 
""" if not issubclass(spec.profiler, BaseModuleProfiler): - raise TypeError( - f"spec.profiler must be subclass of BaseModuleProfiler, got {spec.profiler}" - ) + raise TypeError(f"spec.profiler must be subclass of BaseModuleProfiler, got {spec.profiler}") if depth == 0: print(f"Begin build profiler: {spec.profiler.__name__}") @@ -49,9 +47,7 @@ def build_profiler(spec: ModuleProfilerSpec, depth=0) -> BaseModuleProfiler: print(f"{'--'*(depth+1)}[{sub_spec.__name__}]({name})") sub_profilers[name] = sub_spec(spec.config, sub_profilers=None) else: - raise TypeError( - f"Invalid type for sub_profiler_specs['{name}']: {type(sub_spec)}" - ) + raise TypeError(f"Invalid type for sub_profiler_specs['{name}']: {type(sub_spec)}") return spec.profiler(config=spec.config, sub_profilers=sub_profilers) @@ -62,9 +58,7 @@ def get_language_model_profiler_spec(config: TrainingConfig) -> ModuleProfilerSp config=config, sub_profiler_specs={ "embedding": EmbeddingProfiler, - "dense_transformer_layer": get_dense_transformer_layer_profiler_spec( - config - ), + "dense_transformer_layer": get_dense_transformer_layer_profiler_spec(config), "moe_transformer_layer": get_moe_transformer_layer_profiler_spec(config), "final_layernorm": LayerNormProfiler, "output_layer": OutputLayerProfiler, @@ -118,11 +112,7 @@ def _get_explicit_layer_distribution( middle_stages = ( total_stages - 2 if (decoder_first is not None and decoder_last is not None) - else ( - total_stages - 1 - if (decoder_first is not None or decoder_last is not None) - else total_stages - ) + else (total_stages - 1 if (decoder_first is not None or decoder_last is not None) else total_stages) ) if middle_stages > 0 and remaining_layers > 0: @@ -184,21 +174,13 @@ def set_simulation_backends(self, gemm_backend=None, sdpa_backend=None): layer_profiler.set_simulation_backends(gemm_backend, sdpa_backend) # Propagate to embedding (uses simple analytical estimate in sim mode). 
- if ( - "embedding" in self.sub_profilers - and self.sub_profilers["embedding"] is not None - ): + if "embedding" in self.sub_profilers and self.sub_profilers["embedding"] is not None: emb = self.sub_profilers["embedding"] if hasattr(emb, "set_simulation_mode"): - emb.set_simulation_mode( - gemm_backend is not None or sdpa_backend is not None - ) + emb.set_simulation_mode(gemm_backend is not None or sdpa_backend is not None) # Propagate GEMM backend to output layer (vocab projection GEMM). - if ( - "output_layer" in self.sub_profilers - and self.sub_profilers["output_layer"] is not None - ): + if "output_layer" in self.sub_profilers and self.sub_profilers["output_layer"] is not None: out = self.sub_profilers["output_layer"] if gemm_backend is not None and hasattr(out, "set_gemm_backend"): out.set_gemm_backend(gemm_backend) @@ -222,11 +204,7 @@ def get_layers_for_rank( to the first virtual stages (or use decoder_first/last_pipeline_num_layers if set). """ total_stages = pp_size - vpp_size = ( - num_virtual_pipeline_stages - if num_virtual_pipeline_stages is not None - else 1 - ) + vpp_size = num_virtual_pipeline_stages if num_virtual_pipeline_stages is not None else 1 total_stages = pp_size * vpp_size model_parallel_size = pp_size * tp_size * cp_size * ep_size @@ -240,13 +218,9 @@ def get_layers_for_rank( if self is not None and hasattr(self, "config") and self.config is not None: mp_config = self.config.model_parallel_config if decoder_first is None: - decoder_first = getattr( - mp_config, "decoder_first_pipeline_num_layers", None - ) + decoder_first = getattr(mp_config, "decoder_first_pipeline_num_layers", None) if decoder_last is None: - decoder_last = getattr( - mp_config, "decoder_last_pipeline_num_layers", None - ) + decoder_last = getattr(mp_config, "decoder_last_pipeline_num_layers", None) # Build layer counts per virtual stage if decoder_first is not None or decoder_last is not None: @@ -385,22 +359,14 @@ def estimated_num_params(self, rank: Optional[int] 
= None) -> int: for layer in layers: is_moe = self.config.model_config.moe_pattern[layer] if is_moe: - total_params += self.sub_profilers[ - "moe_transformer_layer" - ].estimated_num_params(rank) + total_params += self.sub_profilers["moe_transformer_layer"].estimated_num_params(rank) else: - total_params += self.sub_profilers[ - "dense_transformer_layer" - ].estimated_num_params(rank) + total_params += self.sub_profilers["dense_transformer_layer"].estimated_num_params(rank) if 0 in self.layers: total_params += self.sub_profilers["embedding"].estimated_num_params(rank) if self.config.model_config.num_layers - 1 in self.layers: - total_params += self.sub_profilers["final_layernorm"].estimated_num_params( - rank - ) - total_params += self.sub_profilers["output_layer"].estimated_num_params( - rank - ) + total_params += self.sub_profilers["final_layernorm"].estimated_num_params(rank) + total_params += self.sub_profilers["output_layer"].estimated_num_params(rank) total_params += self.sub_profilers["calc_loss"].estimated_num_params(rank) return total_params @@ -412,17 +378,13 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: (hidden_size * batch_size * seq_len * dtype_bytes), not full intermediate activations. 
""" pp_size = self.config.model_parallel_config.pipeline_model_parallel_size - vpp_size = ( - self.config.model_parallel_config.virtual_pipeline_model_parallel_size - ) + vpp_size = self.config.model_parallel_config.virtual_pipeline_model_parallel_size recompute_granularity = self.config.model_parallel_config.recompute_granularity recompute_num_layers = self.config.model_parallel_config.recompute_num_layers # Calculate number of layers per virtual pipeline stage on this rank layers_per_rank = len(self.layers) - layers_per_vpp_stage = ( - layers_per_rank // vpp_size if vpp_size > 0 else layers_per_rank - ) + layers_per_vpp_stage = layers_per_rank // vpp_size if vpp_size > 0 else layers_per_rank # Input activation size per layer (only thing stored for recomputed layers) # hidden_size * batch_size * seq_len * dtype_bytes (bf16 = 2 bytes) @@ -439,9 +401,7 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: layer_act = 0 for i, layer in enumerate(self.layers): # Determine if this layer is recomputed - local_layer_idx = ( - i % layers_per_vpp_stage if layers_per_vpp_stage > 0 else i - ) + local_layer_idx = i % layers_per_vpp_stage if layers_per_vpp_stage > 0 else i is_recomputed = ( recompute_granularity == "full" and recompute_num_layers is not None @@ -455,31 +415,25 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: # Non-recomputed layer: store full activations is_moe = self.config.model_config.moe_pattern[layer] if is_moe: - layer_act += self.sub_profilers[ - "moe_transformer_layer" - ].estimated_activation_memory(batch_size, seq_len) + layer_act += self.sub_profilers["moe_transformer_layer"].estimated_activation_memory( + batch_size, seq_len + ) else: - layer_act += self.sub_profilers[ - "dense_transformer_layer" - ].estimated_activation_memory(batch_size, seq_len) + layer_act += self.sub_profilers["dense_transformer_layer"].estimated_activation_memory( + batch_size, seq_len + ) total_act = layer_act # Add 
embedding/output activations if 0 in self.layers: - total_act += self.sub_profilers["embedding"].estimated_activation_memory( - batch_size, seq_len - ) + total_act += self.sub_profilers["embedding"].estimated_activation_memory(batch_size, seq_len) if self.config.model_config.num_layers - 1 in self.layers: - total_act += self.sub_profilers[ - "final_layernorm" - ].estimated_activation_memory(batch_size, seq_len) - total_act += self.sub_profilers["output_layer"].estimated_activation_memory( - batch_size, seq_len - ) - total_act += self.sub_profilers["calc_loss"].estimated_activation_memory( + total_act += self.sub_profilers["final_layernorm"].estimated_activation_memory( batch_size, seq_len ) + total_act += self.sub_profilers["output_layer"].estimated_activation_memory(batch_size, seq_len) + total_act += self.sub_profilers["calc_loss"].estimated_activation_memory(batch_size, seq_len) # 1F1B pipeline schedule: need to store activations for pp_size microbatches total_act *= pp_size @@ -506,9 +460,7 @@ def run_layer_benchmark(self, model, batch_size: int, seq_len: int) -> dict: The mode is automatically selected based on whether simulation backends have been set via :meth:`set_simulation_backends`. 
""" - is_simulation_mode = ( - self._gemm_backend is not None or self._sdpa_backend is not None - ) + is_simulation_mode = self._gemm_backend is not None or self._sdpa_backend is not None # ----------------------------------------------------------------- # Unwrap model (only when an actual model is provided) @@ -521,11 +473,7 @@ def run_layer_benchmark(self, model, batch_size: int, seq_len: int) -> dict: def unwrap_module(module): """Recursively unwrap DistributedDataParallel / pipeline wrappers.""" - return ( - unwrap_module(module.module) - if hasattr(module, "module") - else module - ) + return unwrap_module(module.module) if hasattr(module, "module") else module model_chunks = model if isinstance(model, list) else [model] @@ -539,28 +487,20 @@ def unwrap_module(module): if hasattr(language_model, "output_layer"): output_module = language_model.output_layer - if hasattr(language_model, "encoder") and hasattr( - language_model.encoder, "layers" - ): + if hasattr(language_model, "encoder") and hasattr(language_model.encoder, "layers"): all_layers.extend(language_model.encoder.layers) - elif hasattr(language_model, "decoder") and hasattr( - language_model.decoder, "layers" - ): + elif hasattr(language_model, "decoder") and hasattr(language_model.decoder, "layers"): all_layers.extend(language_model.decoder.layers) elif hasattr(language_model, "layers"): all_layers.extend(language_model.layers) continue - if hasattr(unwrapped, "decoder") and hasattr( - unwrapped.decoder, "layers" - ): + if hasattr(unwrapped, "decoder") and hasattr(unwrapped.decoder, "layers"): all_layers.extend(unwrapped.decoder.layers) elif hasattr(unwrapped, "layers"): all_layers.extend(unwrapped.layers) else: - raise ValueError( - f"Cannot find transformer layers in model chunk: {type(unwrapped)}" - ) + raise ValueError(f"Cannot find transformer layers in model chunk: {type(unwrapped)}") if hasattr(unwrapped, "embedding"): embedding_module = unwrapped.embedding if hasattr(unwrapped, 
"output_layer"): @@ -575,25 +515,17 @@ def unwrap_module(module): mode_label = "Simulating" if is_simulation_mode else "Benchmarking" if is_rank_0: if model is not None: - print( - f"\n[Primus:Performance Projection] Found {len(all_layers)} transformer layers" - ) + print(f"\n[Primus:Performance Projection] Found {len(all_layers)} transformer layers") else: - print( - f"\n[Primus:Performance Projection] Pure simulation mode (no model)" - ) - print( - f"[Primus:Performance Projection] This rank is responsible for layers: {self.layers}" - ) + print(f"\n[Primus:Performance Projection] Pure simulation mode (no model)") + print(f"[Primus:Performance Projection] This rank is responsible for layers: {self.layers}") if is_simulation_mode: backends = [] if self._gemm_backend is not None: backends.append(f"GEMM={self._gemm_backend.name()}") if self._sdpa_backend is not None: backends.append(f"SDPA={self._sdpa_backend.name()}") - print( - f"[Primus:Performance Projection] Mode: SIMULATION ({', '.join(backends)})" - ) + print(f"[Primus:Performance Projection] Mode: SIMULATION ({', '.join(backends)})") embedding_stats = None output_stats = None @@ -602,20 +534,12 @@ def unwrap_module(module): # Benchmark / simulate embedding layer (if this rank hosts it) # ---------------------------------------------------------------------- if 0 in self.layers: - if ( - model is not None - and embedding_module is None - and not is_simulation_mode - ): + if model is not None and embedding_module is None and not is_simulation_mode: if is_rank_0: - print( - "[Primus:Performance Projection] WARNING: Embedding module not found on this rank." - ) + print("[Primus:Performance Projection] WARNING: Embedding module not found on this rank.") else: if is_rank_0: - print( - f"[Primus:Performance Projection] {mode_label} embedding layer..." 
- ) + print(f"[Primus:Performance Projection] {mode_label} embedding layer...") profiler = self.sub_profilers["embedding"] if embedding_module is not None: module = ( @@ -655,9 +579,7 @@ def unwrap_module(module): ) else: if is_rank_0: - print( - f"[Primus:Performance Projection] {mode_label} output layer..." - ) + print(f"[Primus:Performance Projection] {mode_label} output layer...") profiler = self.sub_profilers["output_layer"] if output_module is not None: profiler.set_module(output_module) @@ -690,9 +612,7 @@ def unwrap_module(module): # In benchmark mode, guard against out-of-range layer indices. if model is not None and layer_idx >= len(all_layers): if is_rank_0: - print( - f"[WARNING] Layer index {layer_idx} exceeds available layers ({len(all_layers)})" - ) + print(f"[WARNING] Layer index {layer_idx} exceeds available layers ({len(all_layers)})") continue is_moe = self.config.model_config.moe_pattern[layer_idx] @@ -702,9 +622,7 @@ def unwrap_module(module): continue if is_rank_0: - print( - f"\n[Primus:Performance Projection] {mode_label} Layer {layer_idx} ({layer_type})..." 
- ) + print(f"\n[Primus:Performance Projection] {mode_label} Layer {layer_idx} ({layer_type})...") # Get the appropriate profiler if is_moe: @@ -720,9 +638,7 @@ def unwrap_module(module): # Benchmark/simulate full layer forward_time = layer_profiler.measured_forward_time(batch_size, seq_len) backward_time = layer_profiler.measured_backward_time(batch_size, seq_len) - activation_memory = layer_profiler.measured_activation_memory( - batch_size, seq_len - ) + activation_memory = layer_profiler.measured_activation_memory(batch_size, seq_len) # Benchmark/simulate Attention attn_profiler = layer_profiler.get_sub_profiler("self_attention") @@ -762,9 +678,7 @@ def unwrap_module(module): print(f" Backward time: {backward_time:.2f} ms {src}") print(f" Total: {forward_time + backward_time:.2f} ms {src}") print(f" Activation memory: {activation_memory / (1024**2):.2f} MB") - print( - f" Attention: fwd={attn_forward:.2f} ms, bwd={attn_backward:.2f} ms" - ) + print(f" Attention: fwd={attn_forward:.2f} ms, bwd={attn_backward:.2f} ms") print(f" MLP: fwd={mlp_forward:.2f} ms, bwd={mlp_backward:.2f} ms") # Expand results to all layers diff --git a/primus/core/projection/module_profilers/moe_mlp.py b/primus/core/projection/module_profilers/moe_mlp.py index 84acd7362..8a9a06c5f 100644 --- a/primus/core/projection/module_profilers/moe_mlp.py +++ b/primus/core/projection/module_profilers/moe_mlp.py @@ -35,9 +35,7 @@ class MoEMLPProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): super().__init__(config, sub_profilers) self.module = None # Will be set during benchmarking - self._cached_results = ( - None # Cache for (forward_time, backward_time, activation_memory) - ) + self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) self._gemm_backend = None # Optional: GEMM simulation backend @@ -64,26 +62,16 @@ def estimated_num_params(self, rank: Optional[int] = None) -> int: # 
For SwiGLU: 3 projections per expert (gate, up, down) # For standard FFN: 2 projections per expert (up, down) num_ffn_projections = 3 if self.config.model_config.swiglu else 2 - per_expert_params = ( - num_ffn_projections * self.config.model_config.hidden_size * moe_ffn - ) - ep = ( - 1 - if rank is None - else self.config.model_parallel_config.expert_model_parallel_size - ) + per_expert_params = num_ffn_projections * self.config.model_config.hidden_size * moe_ffn + ep = 1 if rank is None else self.config.model_parallel_config.expert_model_parallel_size - all_experts_params = ( - self.config.model_config.num_experts * per_expert_params // ep - ) + all_experts_params = self.config.model_config.num_experts * per_expert_params // ep # Shared experts (if any) shared_sz = 0 if self.config.model_config.moe_shared_expert_intermediate_size is not None: shared_sz = self.config.model_config.moe_shared_expert_intermediate_size - shared_params = ( - num_ffn_projections * self.config.model_config.hidden_size * shared_sz - ) + shared_params = num_ffn_projections * self.config.model_config.hidden_size * shared_sz return all_experts_params + shared_params @@ -120,16 +108,12 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: # After activation activation_memory = num_tokens * moe_ffn * 2 # bf16 - output_memory = ( - num_tokens * self.config.model_config.hidden_size * 2 - ) # bf16 + output_memory = num_tokens * self.config.model_config.hidden_size * 2 # bf16 total += intermediate_memory + activation_memory + output_memory return total - def _get_simulated_results( - self, batch_size: int, seq_len: int - ) -> tuple[float, float, int]: + def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get simulated results from the GEMM simulation backend for MoE MLP. 
In addition to expert GEMM time, this method estimates several @@ -184,9 +168,9 @@ def _get_simulated_results( # execution → model as Origami batched GEMM (batch=num_local_experts). # Legacy grouped_gemm executes experts more sequentially → model as # individual GEMM (batch=1) × num_local_experts. - use_turbo = getattr( - self.config.model_config, "enable_primus_turbo", False - ) and getattr(self.config.model_config, "use_turbo_grouped_mlp", False) + use_turbo = getattr(self.config.model_config, "enable_primus_turbo", False) and getattr( + self.config.model_config, "use_turbo_grouped_mlp", False + ) is_rank_0 = int(os.getenv("RANK", "0")) == 0 if is_rank_0 and num_local_experts > 1: @@ -200,18 +184,10 @@ def _get_simulated_results( # ── Turbo model: batched GEMM (all experts in parallel) ── B = num_local_experts if self.config.model_config.swiglu: - gate_fwd = self._gemm_backend.simulate_gemm( - M, F, H, gemm_dtype, batch=B - ) + gate_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) - down_fwd = self._gemm_backend.simulate_gemm( - M, H, F, gemm_dtype, batch=B - ) - expert_fwd_ms = ( - gate_fwd.forward_time_ms - + up_fwd.forward_time_ms - + down_fwd.forward_time_ms - ) + down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) + expert_fwd_ms = gate_fwd.forward_time_ms + up_fwd.forward_time_ms + down_fwd.forward_time_ms gate_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) gate_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=B) up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) @@ -228,9 +204,7 @@ def _get_simulated_results( ) else: up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=B) - down_fwd = self._gemm_backend.simulate_gemm( - M, H, F, gemm_dtype, batch=B - ) + down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) expert_fwd_ms = up_fwd.forward_time_ms + 
down_fwd.forward_time_ms up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=B) up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=B) @@ -248,18 +222,10 @@ def _get_simulated_results( else: # ── Legacy model: individual GEMM × num_local_experts ── if self.config.model_config.swiglu: - gate_fwd = self._gemm_backend.simulate_gemm( - M, F, H, gemm_dtype, batch=1 - ) + gate_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) - down_fwd = self._gemm_backend.simulate_gemm( - M, H, F, gemm_dtype, batch=1 - ) - expert_fwd_ms = ( - gate_fwd.forward_time_ms - + up_fwd.forward_time_ms - + down_fwd.forward_time_ms - ) + down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) + expert_fwd_ms = gate_fwd.forward_time_ms + up_fwd.forward_time_ms + down_fwd.forward_time_ms gate_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) gate_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=1) up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) @@ -276,9 +242,7 @@ def _get_simulated_results( ) else: up_fwd = self._gemm_backend.simulate_gemm(M, F, H, gemm_dtype, batch=1) - down_fwd = self._gemm_backend.simulate_gemm( - M, H, F, gemm_dtype, batch=1 - ) + down_fwd = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) expert_fwd_ms = up_fwd.forward_time_ms + down_fwd.forward_time_ms up_dg = self._gemm_backend.simulate_gemm(M, H, F, gemm_dtype, batch=1) up_wg = self._gemm_backend.simulate_gemm(H, F, M, gemm_dtype, batch=1) @@ -299,9 +263,7 @@ def _get_simulated_results( # ── 2. 
Router overhead ── # Gate linear: [batch_tokens, num_experts, hidden_size] - router_gemm = self._gemm_backend.simulate_gemm( - batch_tokens, num_experts, hidden_size, gemm_dtype - ) + router_gemm = self._gemm_backend.simulate_gemm(batch_tokens, num_experts, hidden_size, gemm_dtype) router_fwd_ms = router_gemm.forward_time_ms # Softmax + top-K selection + auxiliary loss overhead (empirical) topk_overhead_ms = 0.1 + 0.002 * num_experts @@ -320,8 +282,7 @@ def _get_simulated_results( # model adapts automatically to different architectures. peak_hbm = ( self._gemm_backend.hbm_bandwidth_gbps - if self._gemm_backend is not None - and self._gemm_backend.hbm_bandwidth_gbps is not None + if self._gemm_backend is not None and self._gemm_backend.hbm_bandwidth_gbps is not None else _FALLBACK_HBM_BW_GBPS ) permute_eff_bw_gbps = peak_hbm * _PERMUTE_BW_FRACTION @@ -337,9 +298,7 @@ def _get_simulated_results( # ── 4. Activation function overhead (SwiGLU / GELU) ── if self.config.model_config.swiglu: - act_bytes = ( - 3 * topk_tokens * moe_ffn * bytes_per_el - ) # gate+up read, result write + act_bytes = 3 * topk_tokens * moe_ffn * bytes_per_el # gate+up read, result write else: act_bytes = 2 * topk_tokens * moe_ffn * bytes_per_el # read + write activation_ms = act_bytes / (activation_bw_gbps * 1e6) @@ -363,9 +322,7 @@ def _get_simulated_results( activation_memory = self.estimated_activation_memory(batch_size, seq_len) return (fwd_time, bwd_time, activation_memory) - def _get_benchmark_results( - self, batch_size: int, seq_len: int - ) -> tuple[float, float, int]: + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: diff --git a/primus/core/projection/module_profilers/output_layer.py b/primus/core/projection/module_profilers/output_layer.py index 8c979a16d..c2e27555b 100644 --- 
a/primus/core/projection/module_profilers/output_layer.py +++ b/primus/core/projection/module_profilers/output_layer.py @@ -15,9 +15,7 @@ class OutputLayerProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): super().__init__(config, sub_profilers) self.module = None # Will be set during benchmarking - self._cached_results = ( - None # Cache for (forward_time, backward_time, activation_memory) - ) + self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) self._gemm_backend = None # Optional: GEMM simulation backend @@ -35,10 +33,7 @@ def set_gemm_backend(self, backend): self._cache_key = None def estimated_num_params(self, rank: Optional[int] = None) -> int: - return ( - self.config.model_config.padded_vocab_size - * self.config.model_config.hidden_size - ) + return self.config.model_config.padded_vocab_size * self.config.model_config.hidden_size def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: return ( @@ -50,9 +45,7 @@ def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: * 2 ) # bf16 - def _get_simulated_results( - self, batch_size: int, seq_len: int - ) -> tuple[float, float, int]: + def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Simulate output layer using GEMM backend (vocab projection GEMM).""" tp_size = self.config.model_parallel_config.tensor_model_parallel_size cp_size = self.config.model_parallel_config.context_model_parallel_size @@ -89,9 +82,7 @@ def _get_simulated_results( activation_memory = self.estimated_activation_memory(batch_size, seq_len) return (fwd_time, bwd_time, activation_memory) - def _get_benchmark_results( - self, batch_size: int, seq_len: int - ) -> tuple[float, float, int]: + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key 
= (batch_size, seq_len) diff --git a/primus/core/projection/module_profilers/router.py b/primus/core/projection/module_profilers/router.py index 994d556ef..efdd8338b 100644 --- a/primus/core/projection/module_profilers/router.py +++ b/primus/core/projection/module_profilers/router.py @@ -11,9 +11,7 @@ class RouterProfiler(BaseModuleProfiler): def estimated_num_params(self, rank: Optional[int] = None) -> int: - return ( - self.config.model_config.hidden_size * self.config.model_config.num_experts - ) + return self.config.model_config.hidden_size * self.config.model_config.num_experts def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: return ( diff --git a/primus/core/projection/module_profilers/transformer_layer.py b/primus/core/projection/module_profilers/transformer_layer.py index a3a115149..af4fe0526 100644 --- a/primus/core/projection/module_profilers/transformer_layer.py +++ b/primus/core/projection/module_profilers/transformer_layer.py @@ -108,9 +108,7 @@ def _estimate_moe_a2a_time_ms(config, batch_size: int, seq_len: int) -> float: ) # Propagate DeepEP setting if present (affects A2A algorithm selection) - moe_enable_deepep = getattr( - config.model_parallel_config, "moe_enable_deepep", False - ) + moe_enable_deepep = getattr(config.model_parallel_config, "moe_enable_deepep", False) use_turbo_deepep = getattr(config.model_parallel_config, "use_turbo_deepep", False) coll_args.moe_enable_deepep = moe_enable_deepep coll_args.use_turbo_deepep = use_turbo_deepep @@ -171,9 +169,7 @@ class DenseTransformerLayerProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): super().__init__(config, sub_profilers) self.layer_module = None # Will be set during benchmarking - self._cached_results = ( - None # Cache for (forward_time, backward_time, activation_memory) - ) + self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) 
self._gemm_backend = None # Optional: GEMM simulation backend self._sdpa_backend = None # Optional: SDPA simulation backend @@ -220,30 +216,16 @@ def estimated_num_params(self, rank: Optional[int] = None) -> int: def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: return ( - self.sub_profilers["layer_norm"].estimated_activation_memory( - batch_size, seq_len - ) - * 3 - + self.sub_profilers["self_attention"].estimated_activation_memory( - batch_size, seq_len - ) + self.sub_profilers["layer_norm"].estimated_activation_memory(batch_size, seq_len) * 3 + + self.sub_profilers["self_attention"].estimated_activation_memory(batch_size, seq_len) + self.sub_profilers["mlp"].estimated_activation_memory(batch_size, seq_len) - + self.sub_profilers["residual_add"].estimated_activation_memory( - batch_size, seq_len - ) - * 2 + + self.sub_profilers["residual_add"].estimated_activation_memory(batch_size, seq_len) * 2 ) - def _get_simulated_results( - self, batch_size: int, seq_len: int - ) -> tuple[float, float, int]: + def _get_simulated_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Aggregate simulated results from sub-profilers, including TP AllReduce.""" - attn_fwd = self.sub_profilers["self_attention"].measured_forward_time( - batch_size, seq_len - ) - attn_bwd = self.sub_profilers["self_attention"].measured_backward_time( - batch_size, seq_len - ) + attn_fwd = self.sub_profilers["self_attention"].measured_forward_time(batch_size, seq_len) + attn_bwd = self.sub_profilers["self_attention"].measured_backward_time(batch_size, seq_len) mlp_fwd = self.sub_profilers["mlp"].measured_forward_time(batch_size, seq_len) mlp_bwd = self.sub_profilers["mlp"].measured_backward_time(batch_size, seq_len) @@ -259,9 +241,7 @@ def _get_simulated_results( activation_memory = self.estimated_activation_memory(batch_size, seq_len) return (fwd_time, bwd_time, activation_memory) - def _get_benchmark_results( - self, batch_size: int, seq_len: int - ) 
-> tuple[float, float, int]: + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: @@ -296,9 +276,7 @@ class MoETransformerLayerProfiler(BaseModuleProfiler): def __init__(self, config, sub_profilers=None): super().__init__(config, sub_profilers) self.layer_module = None # Will be set during benchmarking - self._cached_results = ( - None # Cache for (forward_time, backward_time, activation_memory) - ) + self._cached_results = None # Cache for (forward_time, backward_time, activation_memory) self._cache_key = None # Cache key (batch_size, seq_len) self._gemm_backend = None # Optional: GEMM simulation backend self._sdpa_backend = None # Optional: SDPA simulation backend @@ -346,38 +324,22 @@ def estimated_num_params(self, rank: Optional[int] = None) -> int: def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int: return ( - self.sub_profilers["layer_norm"].estimated_activation_memory( - batch_size, seq_len - ) - * 3 - + self.sub_profilers["self_attention"].estimated_activation_memory( - batch_size, seq_len - ) + self.sub_profilers["layer_norm"].estimated_activation_memory(batch_size, seq_len) * 3 + + self.sub_profilers["self_attention"].estimated_activation_memory(batch_size, seq_len) + self.sub_profilers["mlp"].estimated_activation_memory(batch_size, seq_len) - + self.sub_profilers["router"].estimated_activation_memory( - batch_size, seq_len - ) - + self.sub_profilers["residual_add"].estimated_activation_memory( - batch_size, seq_len - ) - * 2 + + self.sub_profilers["router"].estimated_activation_memory(batch_size, seq_len) + + self.sub_profilers["residual_add"].estimated_activation_memory(batch_size, seq_len) * 2 ) - def _get_simulated_results( - self, batch_size: int, seq_len: int - ) -> tuple[float, float, int]: + def _get_simulated_results(self, batch_size: 
int, seq_len: int) -> tuple[float, float, int]: """Aggregate simulated results from sub-profilers. Includes TP AllReduce and MoE All-to-All communication overhead that would be captured in the measured layer time during benchmark mode but must be added explicitly in simulation mode. """ - attn_fwd = self.sub_profilers["self_attention"].measured_forward_time( - batch_size, seq_len - ) - attn_bwd = self.sub_profilers["self_attention"].measured_backward_time( - batch_size, seq_len - ) + attn_fwd = self.sub_profilers["self_attention"].measured_forward_time(batch_size, seq_len) + attn_bwd = self.sub_profilers["self_attention"].measured_backward_time(batch_size, seq_len) mlp_fwd = self.sub_profilers["mlp"].measured_forward_time(batch_size, seq_len) mlp_bwd = self.sub_profilers["mlp"].measured_backward_time(batch_size, seq_len) @@ -400,9 +362,7 @@ def _get_simulated_results( activation_memory = self.estimated_activation_memory(batch_size, seq_len) return (fwd_time, bwd_time, activation_memory) - def _get_benchmark_results( - self, batch_size: int, seq_len: int - ) -> tuple[float, float, int]: + def _get_benchmark_results(self, batch_size: int, seq_len: int) -> tuple[float, float, int]: """Get or compute benchmark results (cached).""" cache_key = (batch_size, seq_len) if self._cached_results is None or self._cache_key != cache_key: diff --git a/primus/core/projection/module_profilers/utils.py b/primus/core/projection/module_profilers/utils.py index b64db4528..c2cad601e 100644 --- a/primus/core/projection/module_profilers/utils.py +++ b/primus/core/projection/module_profilers/utils.py @@ -16,9 +16,7 @@ class _FP8ContextFactory: def __init__(self, transformer_config): self.transformer_config = transformer_config - self.fp8_enabled = ( - getattr(transformer_config, "fp8", None) if transformer_config else None - ) + self.fp8_enabled = getattr(transformer_config, "fp8", None) if transformer_config else None self._printed = False def __enter__(self): @@ -30,9 +28,7 @@ def 
__enter__(self): self._ctx = get_fp8_context(self.transformer_config, layer_no=-1) if not self._printed: - print( - f" [FP8] Using FP8 autocast context for benchmarking (fp8={self.fp8_enabled})" - ) + print(f" [FP8] Using FP8 autocast context for benchmarking (fp8={self.fp8_enabled})") self._printed = True except Exception as e: try: @@ -92,11 +88,7 @@ def benchmark_layer( device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def create_input(spec): - if ( - isinstance(spec, tuple) - and len(spec) == 2 - and isinstance(spec[1], torch.dtype) - ): + if isinstance(spec, tuple) and len(spec) == 2 and isinstance(spec[1], torch.dtype): shape, dtype = spec else: shape = spec diff --git a/primus/core/projection/performance_projection/projection.py b/primus/core/projection/performance_projection/projection.py index 18b2ae15b..7ba62180d 100644 --- a/primus/core/projection/performance_projection/projection.py +++ b/primus/core/projection/performance_projection/projection.py @@ -65,6 +65,7 @@ def _calculate_min_gpus(tp, pp, ep, cp): # Dense: CP is an independent axis return tp * pp * cp + # HBM bandwidth (GB/s) by GPU architecture — used for optimizer step estimation _HBM_BANDWIDTH_GBPS: Dict[str, float] = { "mi300x": 5300.0, @@ -155,11 +156,10 @@ def _estimate_decode_time_per_token( # Attention: Q(h→h) + K(h→kv_dim) + V(h→kv_dim) + O(h→h) kv_dim = num_kv_heads * head_dim attn_weight_bytes = ( - hidden * hidden # Q - + hidden * kv_dim # K - + hidden * kv_dim # V - + hidden * hidden # O - ) * bytes_per_param // tp + (hidden * hidden + hidden * kv_dim + hidden * kv_dim + hidden * hidden) # Q # K # V # O + * bytes_per_param + // tp + ) # Dense MLP: gate(h→ffn) + up(h→ffn) + down(ffn→h) (SwiGLU has gate+up) dense_mlp_weight_bytes = 3 * hidden * ffn_hidden * bytes_per_param // tp @@ -168,16 +168,12 @@ def _estimate_decode_time_per_token( expert_tp = 1 # expert TP typically 1 if num_experts > 0: experts_per_gpu = max(num_experts // max(ep, 1), 1) - 
moe_mlp_weight_bytes = ( - 3 * hidden * moe_ffn * experts_per_gpu * bytes_per_param // expert_tp - ) + moe_mlp_weight_bytes = 3 * hidden * moe_ffn * experts_per_gpu * bytes_per_param // expert_tp # Router weight: (hidden → num_experts) router_weight_bytes = hidden * num_experts * bytes_per_param # Shared expert (if any) shared_expert_weight_bytes = ( - 3 * hidden * shared_expert_size * bytes_per_param // tp - if shared_expert_size > 0 - else 0 + 3 * hidden * shared_expert_size * bytes_per_param // tp if shared_expert_size > 0 else 0 ) else: moe_mlp_weight_bytes = 0 @@ -189,19 +185,15 @@ def _estimate_decode_time_per_token( # Read V: same kv_heads_per_gpu = max(num_kv_heads // tp, 1) kv_cache_per_layer_bytes = ( - 2 * decode_batch_size * kv_heads_per_gpu * context_length * head_dim - * bytes_per_param + 2 * decode_batch_size * kv_heads_per_gpu * context_length * head_dim * bytes_per_param ) # Compute FLOPs per layer per decode step (batch × 1 token) # Linear projections: 2 * batch * M * N (M=1, N=weight cols) attn_proj_flops = ( - 2 * decode_batch_size * ( - hidden * hidden # Q - + hidden * kv_dim # K - + hidden * kv_dim # V - + hidden * hidden # O - ) + 2 + * decode_batch_size + * (hidden * hidden + hidden * kv_dim + hidden * kv_dim + hidden * hidden) # Q # K # V # O ) // tp dense_mlp_flops = 2 * decode_batch_size * 3 * hidden * ffn_hidden // tp @@ -213,9 +205,7 @@ def _estimate_decode_time_per_token( # MoE compute: each token routed to topk experts if num_experts > 0: - moe_mlp_flops = ( - 2 * decode_batch_size * moe_topk * 3 * hidden * moe_ffn // max(ep, 1) - ) + moe_mlp_flops = 2 * decode_batch_size * moe_topk * 3 * hidden * moe_ffn // max(ep, 1) else: moe_mlp_flops = 0 @@ -227,8 +217,7 @@ def _estimate_decode_time_per_token( total_weight_bytes = ( attn_weight_bytes * num_layers # attention in every layer + dense_mlp_weight_bytes * num_dense_layers - + (moe_mlp_weight_bytes + router_weight_bytes + shared_expert_weight_bytes) - * num_moe_layers + + 
(moe_mlp_weight_bytes + router_weight_bytes + shared_expert_weight_bytes) * num_moe_layers ) // max(pp, 1) total_kv_bytes = kv_cache_per_layer_bytes * layers_per_pp @@ -352,14 +341,9 @@ def _estimate_optimizer_step_ms( expert_mlp_params_per_expert = 3 * hidden * moe_ffn // expert_tp # Non-expert params across all layers (sharded by TP, PP) - non_expert_params = ( - num_layers * attn_params_per_layer - + num_dense_layers * dense_mlp_params_per_layer - ) + non_expert_params = num_layers * attn_params_per_layer + num_dense_layers * dense_mlp_params_per_layer # Expert params (sharded by EP, expert_TP, PP) - expert_params = ( - num_moe_layers * num_experts * expert_mlp_params_per_expert // max(ep, 1) - ) + expert_params = num_moe_layers * num_experts * expert_mlp_params_per_expert // max(ep, 1) # Shared experts (if any) shared_sz = getattr(model_config, "moe_shared_expert_intermediate_size", 0) or 0 @@ -367,9 +351,7 @@ def _estimate_optimizer_step_ms( if shared_sz and num_moe_layers > 0: shared_expert_params = num_moe_layers * 3 * hidden * shared_sz // tp - total_params_per_gpu = ( - non_expert_params + expert_params + shared_expert_params - ) // pp + total_params_per_gpu = (non_expert_params + expert_params + shared_expert_params) // pp # Embedding + output layer params (only on first / last PP rank, amortise) vocab_size = getattr(model_config, "vocab_size", 0) or 0 @@ -513,16 +495,12 @@ def calculate_collective_communication_time( bw_eff = getattr(coll_args, "bw_eff", 0.91) inter_bw = pod_bw * bw_eff # GB/s per link msg_scale = (dp_replicas - 1) / dp_replicas - expert_ar_time_ms = ( - 2 * expert_grad_size * msg_scale / (inter_bw * 1e9) * 1e3 - ) + expert_ar_time_ms = 2 * expert_grad_size * msg_scale / (inter_bw * 1e9) * 1e3 # Non-expert gradient allreduce: across full DP group non_expert_per_rank = non_expert_params // (tp * pp) non_expert_grad_size = non_expert_per_rank * 4 # FP32 - non_expert_ar_time = cm.allreduce( - coll_args, non_expert_grad_size, dp, 
groups=["dp"] - ) + non_expert_ar_time = cm.allreduce(coll_args, non_expert_grad_size, dp, groups=["dp"]) non_expert_ar_ms = non_expert_ar_time / 1000 total_ar_ms = expert_ar_time_ms + non_expert_ar_ms @@ -594,29 +572,19 @@ def calculate_collective_communication_time( # Dense layer: ~12 * hidden^2 params (qkv_proj, o_proj, mlp up/down/gate) # MoE layer: similar attention + num_experts * expert_params ffn_hidden = model_config.ffn_hidden_size or hidden_size * 4 - params_per_dense_layer = ( - hidden_size * hidden_size * 4 + hidden_size * ffn_hidden * 3 - ) # attn + MLP - params_per_dense_layer = ( - params_per_dense_layer // tp - ) # Divide by TP (params are TP-sharded) + params_per_dense_layer = hidden_size * hidden_size * 4 + hidden_size * ffn_hidden * 3 # attn + MLP + params_per_dense_layer = params_per_dense_layer // tp # Divide by TP (params are TP-sharded) # Weight size in bytes (BF16 = 2 bytes) weight_size_per_layer = params_per_dense_layer * 2 # All-gather: each rank sends its shard (1/DP), receives full weights # Total data moved = weight_size * (DP-1)/DP per rank - ag_time_per_layer_us = cm.allgather( - coll_args, weight_size_per_layer, dp, groups=["dp"] - ) + ag_time_per_layer_us = cm.allgather(coll_args, weight_size_per_layer, dp, groups=["dp"]) # Reduce-scatter: each rank sends full gradients, receives its shard - grad_size_per_layer = ( - params_per_dense_layer * 2 - ) # BF16 gradients for communication - rs_time_per_layer_us = cm.reduce_scatter( - coll_args, grad_size_per_layer, dp, groups=["dp"] - ) + grad_size_per_layer = params_per_dense_layer * 2 # BF16 gradients for communication + rs_time_per_layer_us = cm.reduce_scatter(coll_args, grad_size_per_layer, dp, groups=["dp"]) # --- Recompute correction --- # With recompute_granularity="full", during the backward pass each layer @@ -633,16 +601,12 @@ def calculate_collective_communication_time( ag_multiplier = 1 + recomp_ratio # e.g. 
2.0 when all layers recomputed # Calculate total FSDP time for all layers - total_fsdp_ag_fwd = ( - ag_time_per_layer_us * num_layers * ag_multiplier - ) / 1000 # ms + total_fsdp_ag_fwd = (ag_time_per_layer_us * num_layers * ag_multiplier) / 1000 # ms total_fsdp_rs_bwd = (rs_time_per_layer_us * num_layers) / 1000 # ms breakdown["fsdp_allgather_fwd"] = total_fsdp_ag_fwd breakdown["fsdp_reducescatter_bwd"] = total_fsdp_rs_bwd - message_info["fsdp_weight_size_per_layer_mb"] = weight_size_per_layer / ( - 1024 * 1024 - ) + message_info["fsdp_weight_size_per_layer_mb"] = weight_size_per_layer / (1024 * 1024) message_info["fsdp_ag_per_layer_ms"] = ag_time_per_layer_us / 1000 message_info["fsdp_rs_per_layer_ms"] = rs_time_per_layer_us / 1000 message_info["fsdp_ag_multiplier"] = ag_multiplier @@ -681,9 +645,7 @@ def calculate_collective_communication_time( total_comm_time = sum(breakdown.values()) # Check if gradient all-reduce should be overlapped - overlap_grad_reduce = getattr( - mp_config, "overlap_grad_reduce", True - ) # Default to True + overlap_grad_reduce = getattr(mp_config, "overlap_grad_reduce", True) # Default to True # If overlapped and NOT MoE-no-overlap, don't add to critical path moe_no_overlap = message_info.get("moe_ar_no_overlap", False) @@ -716,9 +678,7 @@ def calculate_collective_communication_time( return total_comm_time, breakdown, message_info, per_layer_info -def extract_single_node_time_from_profiling( - profiling_results: dict, training_config -) -> float: +def extract_single_node_time_from_profiling(profiling_results: dict, training_config) -> float: """ Extract total single-node time from profiling results. @@ -735,9 +695,7 @@ def extract_single_node_time_from_profiling( is_rank_0 = int(os.getenv("RANK", "0")) == 0 if is_rank_0: - print( - "[Primus:Performance Projection] Extracting timing from benchmark results..." 
- ) + print("[Primus:Performance Projection] Extracting timing from benchmark results...") print("-" * 100) model_config = training_config.model_config @@ -750,16 +708,12 @@ def extract_single_node_time_from_profiling( num_total_layers = len(moe_pattern) # Get profiled layer indices - profiled_layer_indices = sorted( - [k for k in profiling_results.keys() if isinstance(k, int)] - ) + profiled_layer_indices = sorted([k for k in profiling_results.keys() if isinstance(k, int)]) if is_rank_0: print(f" Profiled layers: {profiled_layer_indices}") print(f" Full model has {num_total_layers} transformer layers") if recompute_granularity == "full" and recompute_num_layers > 0: - print( - f" Recomputation: {recompute_num_layers} layers (granularity={recompute_granularity})" - ) + print(f" Recomputation: {recompute_num_layers} layers (granularity={recompute_granularity})") total_time_ms = 0.0 @@ -792,24 +746,12 @@ def extract_single_node_time_from_profiling( profiled_moe_fwd_times.append(fwd_time) # Calculate averages from profiled layers - avg_dense_time = ( - sum(profiled_dense_times) / len(profiled_dense_times) - if profiled_dense_times - else 0 - ) + avg_dense_time = sum(profiled_dense_times) / len(profiled_dense_times) if profiled_dense_times else 0 avg_dense_fwd = ( - sum(profiled_dense_fwd_times) / len(profiled_dense_fwd_times) - if profiled_dense_fwd_times - else 0 - ) - avg_moe_time = ( - sum(profiled_moe_times) / len(profiled_moe_times) if profiled_moe_times else 0 - ) - avg_moe_fwd = ( - sum(profiled_moe_fwd_times) / len(profiled_moe_fwd_times) - if profiled_moe_fwd_times - else 0 + sum(profiled_dense_fwd_times) / len(profiled_dense_fwd_times) if profiled_dense_fwd_times else 0 ) + avg_moe_time = sum(profiled_moe_times) / len(profiled_moe_times) if profiled_moe_times else 0 + avg_moe_fwd = sum(profiled_moe_fwd_times) / len(profiled_moe_fwd_times) if profiled_moe_fwd_times else 0 # Count total dense and MoE layers in full model num_dense_layers = sum(1 for x in 
moe_pattern if x == 0) @@ -825,21 +767,13 @@ def extract_single_node_time_from_profiling( # Print detailed breakdown if is_rank_0: if profiled_dense_times: - print( - f" Dense Layers: {len(profiled_dense_times)} profiled → {num_dense_layers} total" - ) - print( - f" Avg per layer: {avg_dense_time:.2f} ms (fwd={avg_dense_fwd:.2f} ms)" - ) + print(f" Dense Layers: {len(profiled_dense_times)} profiled → {num_dense_layers} total") + print(f" Avg per layer: {avg_dense_time:.2f} ms (fwd={avg_dense_fwd:.2f} ms)") print(f" Total time: {total_dense_time:.2f} ms") if profiled_moe_times: - print( - f" MoE Layers: {len(profiled_moe_times)} profiled → {num_moe_layers} total" - ) - print( - f" Avg per layer: {avg_moe_time:.2f} ms (fwd={avg_moe_fwd:.2f} ms)" - ) + print(f" MoE Layers: {len(profiled_moe_times)} profiled → {num_moe_layers} total") + print(f" Avg per layer: {avg_moe_time:.2f} ms (fwd={avg_moe_fwd:.2f} ms)") print(f" Total time: {total_moe_time:.2f} ms") # Output layer @@ -870,28 +804,20 @@ def extract_single_node_time_from_profiling( if is_rank_0: print(f" Recomputation Overhead: {recompute_overhead_ms:.2f} ms") - print( - f" ({recompute_dense_layers} dense + {recompute_moe_layers} MoE layers recomputed)" - ) + print(f" ({recompute_dense_layers} dense + {recompute_moe_layers} MoE layers recomputed)") if is_rank_0: print("-" * 100) - print( - f"[Primus:Performance Projection] Extrapolated Baseline Time: {total_time_ms:.2f} ms/iteration" - ) + print(f"[Primus:Performance Projection] Extrapolated Baseline Time: {total_time_ms:.2f} ms/iteration") if recompute_overhead_ms > 0: print(f" (Includes {recompute_overhead_ms:.2f} ms recomputation overhead)") - print( - f" (Based on {len(profiled_layer_indices)} profiled layers → {num_total_layers} total layers)" - ) + print(f" (Based on {len(profiled_layer_indices)} profiled layers → {num_total_layers} total layers)") print("=" * 100) return total_time_ms -def extract_single_node_time_inference( - profiling_results: dict, 
training_config -) -> float: +def extract_single_node_time_inference(profiling_results: dict, training_config) -> float: """ Extract total single-node **forward-only** time from profiling results. @@ -908,9 +834,7 @@ def extract_single_node_time_inference( is_rank_0 = int(os.getenv("RANK", "0")) == 0 if is_rank_0: - print( - "[Primus:Inference Projection] Extracting forward-only timing from benchmark results..." - ) + print("[Primus:Inference Projection] Extracting forward-only timing from benchmark results...") print("-" * 100) model_config = training_config.model_config @@ -918,9 +842,7 @@ def extract_single_node_time_inference( num_total_layers = len(moe_pattern) - profiled_layer_indices = sorted( - [k for k in profiling_results.keys() if isinstance(k, int)] - ) + profiled_layer_indices = sorted([k for k in profiling_results.keys() if isinstance(k, int)]) if is_rank_0: print(f" Profiled layers: {profiled_layer_indices}") print(f" Full model has {num_total_layers} transformer layers") @@ -949,16 +871,8 @@ def extract_single_node_time_inference( else: profiled_moe_fwd.append(fwd_time) - avg_dense_fwd = ( - sum(profiled_dense_fwd) / len(profiled_dense_fwd) - if profiled_dense_fwd - else 0 - ) - avg_moe_fwd = ( - sum(profiled_moe_fwd) / len(profiled_moe_fwd) - if profiled_moe_fwd - else 0 - ) + avg_dense_fwd = sum(profiled_dense_fwd) / len(profiled_dense_fwd) if profiled_dense_fwd else 0 + avg_moe_fwd = sum(profiled_moe_fwd) / len(profiled_moe_fwd) if profiled_moe_fwd else 0 num_dense_layers = sum(1 for x in moe_pattern if x == 0) num_moe_layers = sum(1 for x in moe_pattern if x == 1) @@ -970,16 +884,12 @@ def extract_single_node_time_inference( if is_rank_0: if profiled_dense_fwd: - print( - f" Dense Layers: {len(profiled_dense_fwd)} profiled → {num_dense_layers} total" - ) + print(f" Dense Layers: {len(profiled_dense_fwd)} profiled → {num_dense_layers} total") print(f" Avg fwd per layer: {avg_dense_fwd:.2f} ms") print(f" Total fwd time: {total_dense_time:.2f} ms") 
if profiled_moe_fwd: - print( - f" MoE Layers: {len(profiled_moe_fwd)} profiled → {num_moe_layers} total" - ) + print(f" MoE Layers: {len(profiled_moe_fwd)} profiled → {num_moe_layers} total") print(f" Avg fwd per layer: {avg_moe_fwd:.2f} ms") print(f" Total fwd time: {total_moe_time:.2f} ms") @@ -993,12 +903,8 @@ def extract_single_node_time_inference( if is_rank_0: print("-" * 100) - print( - f"[Primus:Inference Projection] Extrapolated Forward-Only Time: {total_time_ms:.2f} ms" - ) - print( - f" (Based on {len(profiled_layer_indices)} profiled layers → {num_total_layers} total layers)" - ) + print(f"[Primus:Inference Projection] Extrapolated Forward-Only Time: {total_time_ms:.2f} ms") + print(f" (Based on {len(profiled_layer_indices)} profiled layers → {num_total_layers} total layers)") print("=" * 100) return total_time_ms @@ -1182,9 +1088,7 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): f"[Primus:Performance Projection] After reducing PP to 1, " f"config still requires {benchmark_gpus_required} GPUs (TP={tp}, EP={ep}, CP={cp})." ) - print( - f"[Primus:Performance Projection] Rescaling EP to fit on {gpus_per_node} GPUs..." - ) + print(f"[Primus:Performance Projection] Rescaling EP to fit on {gpus_per_node} GPUs...") # Rescale EP to fit rescale_info = _rescale_expert_parallelism(original_config) @@ -1236,9 +1140,7 @@ def _calculate_single_node_config(original_config, gpus_per_node=8): } -def _estimate_pp_communication_overhead( - training_config, pp_size, hardware_config_dict=None -): +def _estimate_pp_communication_overhead(training_config, pp_size, hardware_config_dict=None): """ Estimate the PP P2P communication overhead for a given PP size. 
@@ -1298,9 +1200,7 @@ def _estimate_pp_communication_overhead( # Total P2P time per iteration # Forward: (PP-1) sends, Backward: (PP-1) sends # Times number of microbatches - total_p2p_time_ms = ( - 2 * (pp_size - 1) * num_microbatches * p2p_time_per_transfer / 1000 - ) + total_p2p_time_ms = 2 * (pp_size - 1) * num_microbatches * p2p_time_per_transfer / 1000 return total_p2p_time_ms @@ -1444,24 +1344,14 @@ def _estimate_ep_communication_overhead( dispatch_size = tokens_per_gpu * hidden_size * moe_router_topk * 2 # BF16 # Calculate All-to-All time for original EP (dispatch + combine) - a2a_dispatch_original = cm.alltoall( - coll_args_original, dispatch_size, original_ep, groups=["ep"] - ) - a2a_combine_original = cm.alltoall( - coll_args_original, dispatch_size, original_ep, groups=["ep"] - ) + a2a_dispatch_original = cm.alltoall(coll_args_original, dispatch_size, original_ep, groups=["ep"]) + a2a_combine_original = cm.alltoall(coll_args_original, dispatch_size, original_ep, groups=["ep"]) a2a_time_original_fwd = (a2a_dispatch_original + a2a_combine_original) / 1000 # ms # Calculate All-to-All time for benchmark EP (dispatch + combine) - a2a_dispatch_benchmark = cm.alltoall( - coll_args_benchmark, dispatch_size, benchmark_ep, groups=["ep"] - ) - a2a_combine_benchmark = cm.alltoall( - coll_args_benchmark, dispatch_size, benchmark_ep, groups=["ep"] - ) - a2a_time_benchmark_fwd = ( - a2a_dispatch_benchmark + a2a_combine_benchmark - ) / 1000 # ms + a2a_dispatch_benchmark = cm.alltoall(coll_args_benchmark, dispatch_size, benchmark_ep, groups=["ep"]) + a2a_combine_benchmark = cm.alltoall(coll_args_benchmark, dispatch_size, benchmark_ep, groups=["ep"]) + a2a_time_benchmark_fwd = (a2a_dispatch_benchmark + a2a_combine_benchmark) / 1000 # ms # The overhead is the difference (original is larger due to inter-node communication) fwd_overhead_per_layer = a2a_time_original_fwd - a2a_time_benchmark_fwd @@ -1482,9 +1372,7 @@ def _extract_layer_type_timings(layer_results: dict) -> 
Dict[str, dict[str, floa continue forward = float(result.get("forward_time_ms", 0.0) or 0.0) backward = float(result.get("backward_time_ms", 0.0) or 0.0) - activation = ( - float(result.get("activation_memory_bytes", 0.0) or 0.0) / _BYTES_PER_GB - ) + activation = float(result.get("activation_memory_bytes", 0.0) or 0.0) / _BYTES_PER_GB type_timings[layer_type] = { "forward": forward, "backward": backward, @@ -1507,9 +1395,7 @@ def _add_io_layer_timings(chunk_timings: List[list[dict]], profiling_results: di emb_bwd = embedding.get("backward_time_ms", 0.0) or 0.0 first_chunk["bwd"] += emb_bwd # wgrad already included in backward, don't add again - first_chunk["activation"] += ( - embedding.get("activation_memory_bytes", 0.0) or 0.0 - ) / _BYTES_PER_GB + first_chunk["activation"] += (embedding.get("activation_memory_bytes", 0.0) or 0.0) / _BYTES_PER_GB output = profiling_results.get("output") if output and chunk_timings[-1]: @@ -1518,14 +1404,10 @@ def _add_io_layer_timings(chunk_timings: List[list[dict]], profiling_results: di out_bwd = output.get("backward_time_ms", 0.0) or 0.0 last_chunk["bwd"] += out_bwd # wgrad already included in backward, don't add again - last_chunk["activation"] += ( - output.get("activation_memory_bytes", 0.0) or 0.0 - ) / _BYTES_PER_GB + last_chunk["activation"] += (output.get("activation_memory_bytes", 0.0) or 0.0) / _BYTES_PER_GB -def _build_chunk_time_matrix( - training_config, layer_results: dict -) -> Optional[List[List[dict]]]: +def _build_chunk_time_matrix(training_config, layer_results: dict) -> Optional[List[List[dict]]]: model_cfg = getattr(training_config, "model_config", None) mp_cfg = getattr(training_config, "model_parallel_config", None) if model_cfg is None or mp_cfg is None: @@ -1536,10 +1418,7 @@ def _build_chunk_time_matrix( return None layer_type_pattern = getattr(model_cfg, "moe_pattern", None) - if ( - not isinstance(layer_type_pattern, (list, tuple)) - or len(layer_type_pattern) != total_layers - ): + if not 
isinstance(layer_type_pattern, (list, tuple)) or len(layer_type_pattern) != total_layers: layer_type_pattern = [0] * total_layers type_timings = _extract_layer_type_timings(layer_results) if not type_timings: @@ -1571,20 +1450,14 @@ def _build_chunk_time_matrix( ) if not layers: chunk_timings.append( - [ - {"fwd": 0.0, "bwd": 0.0, "wgrad": 0.0, "activation": 0.0} - for _ in range(vpp_size) - ] + [{"fwd": 0.0, "bwd": 0.0, "wgrad": 0.0, "activation": 0.0} for _ in range(vpp_size)] ) continue layers_per_chunk = len(layers) // vpp_size if vpp_size else len(layers) if layers_per_chunk == 0: chunk_timings.append( - [ - {"fwd": 0.0, "bwd": 0.0, "wgrad": 0.0, "activation": 0.0} - for _ in range(vpp_size) - ] + [{"fwd": 0.0, "bwd": 0.0, "wgrad": 0.0, "activation": 0.0} for _ in range(vpp_size)] ) continue @@ -1620,9 +1493,7 @@ def _compute_micro_batches(runtime_cfg, model_parallel_config) -> int: return max(1, math.ceil(global_batch / denominator)) -def _build_scheduler_sim_config( - training_config, profiling_results, enable_zero_bubble=False -): +def _build_scheduler_sim_config(training_config, profiling_results, enable_zero_bubble=False): chunk_time_matrix = _build_chunk_time_matrix(training_config, profiling_results) assert chunk_time_matrix is not None @@ -1630,9 +1501,7 @@ def _build_scheduler_sim_config( # The zero-bubble scheduler schedules these separately to minimize pipeline bubbles. # Typically B and W are roughly equal in duration (each ~50% of total backward). 
if enable_zero_bubble: - print( - "[Primus:Performance Projection] Splitting backward time for zero-bubble scheduling:" - ) + print("[Primus:Performance Projection] Splitting backward time for zero-bubble scheduling:") print(" B (input grad) = 50% of backward, W (weight grad) = 50% of backward") for rank_chunks in chunk_time_matrix: for chunk in rank_chunks: @@ -1677,9 +1546,7 @@ def _build_scheduler_sim_config( "vpp_size": 1, "micro_batches": micro_batches, } - print( - "[Primus:Performance Projection] Using zero-bubble scheduler (enable_zero_bubble=True)" - ) + print("[Primus:Performance Projection] Using zero-bubble scheduler (enable_zero_bubble=True)") elif vpp_size > 1: scheduler = { "name": "interleaved_1f1b", @@ -1758,18 +1625,14 @@ def _report_simulation_results(sim_results, training_config): activation_trace = scheduled_layers.get("activation_memory_usage") or [] peak_activation = ( - max(activation_trace) - if activation_trace - else scheduled_layers.get("memory", 0.0) + max(activation_trace) if activation_trace else scheduled_layers.get("memory", 0.0) ) # Map rank_idx to pipeline rank (rank_idx // vpp_size) vpp_size = mp_cfg.virtual_pipeline_model_parallel_size or 1 pp_rank = rank_idx // vpp_size if pp_rank not in param_mem_cache: - param_mem_cache[pp_rank] = _get_parameter_memory( - training_config, pp_rank - ) + param_mem_cache[pp_rank] = _get_parameter_memory(training_config, pp_rank) param_mem_gb = param_mem_cache[pp_rank] total_peak_gb = peak_activation + param_mem_gb rank_stats.append( @@ -1831,15 +1694,9 @@ def _run_layer_benchmark(primus_config, unknown_overrides): primus_config.get_module_config("pre_trainer").overlap_param_gather = False primus_config.get_module_config("pre_trainer").use_torch_fsdp2 = False print("[Primus:Performance Projection] Config (with profiling overrides):") - print( - f" overlap_grad_reduce: {primus_config.get_module_config('pre_trainer').overlap_grad_reduce}" - ) - print( - f" overlap_param_gather: 
{primus_config.get_module_config('pre_trainer').overlap_param_gather}" - ) - print( - f" use_torch_fsdp2: {primus_config.get_module_config('pre_trainer').use_torch_fsdp2}" - ) + print(f" overlap_grad_reduce: {primus_config.get_module_config('pre_trainer').overlap_grad_reduce}") + print(f" overlap_param_gather: {primus_config.get_module_config('pre_trainer').overlap_param_gather}") + print(f" use_torch_fsdp2: {primus_config.get_module_config('pre_trainer').use_torch_fsdp2}") trainer = MegatronPretrainTrainer( module_name="pre_trainer", primus_config=primus_config, @@ -1926,9 +1783,7 @@ def _run_layer_simulation(primus_config, args): gpu_arch=gpu_arch, gpu_clock_mhz=gpu_clock_mhz, ) - sdpa_backend = get_sdpa_simulation_backend( - gpu_arch=gpu_arch, gpu_clock_mhz=gpu_clock_mhz - ) + sdpa_backend = get_sdpa_simulation_backend(gpu_arch=gpu_arch, gpu_clock_mhz=gpu_clock_mhz) # ---- Build profiler tree (no model needed) ---- if is_rank_0: @@ -1998,9 +1853,7 @@ def _run_pipeline_simulation_megatron_zb(training_config, profiling_results): mem_b = [] mem_w = [] - print( - "[Primus:Performance Projection] Using Megatron zero-bubble scheduler (ILP-based)" - ) + print("[Primus:Performance Projection] Using Megatron zero-bubble scheduler (ILP-based)") print(f" PP size: {pp_size}, Microbatches: {micro_batches}") for rank_idx, rank_chunks in enumerate(chunk_time_matrix): @@ -2024,9 +1877,7 @@ def _run_pipeline_simulation_megatron_zb(training_config, profiling_results): mem_b.append(float(-act_gb * 0.5)) # B releases half mem_w.append(float(-act_gb * 0.5)) # W releases remaining half - print( - f" Stage {rank_idx}: F={fwd:.2f}ms, B={b_time:.2f}ms, W={w_time:.2f}ms, act={act_gb:.2f}GB" - ) + print(f" Stage {rank_idx}: F={fwd:.2f}ms, B={b_time:.2f}ms, W={w_time:.2f}ms, act={act_gb:.2f}GB") # Estimate communication cost (P2P latency) # Use a small default value; actual value depends on hardware @@ -2055,9 +1906,7 @@ def _run_pipeline_simulation_megatron_zb(training_config, 
profiling_results): step_time_ms = best_time # Calculate bubble time - total_compute_per_mb = ( - sum(cost_f) / pp_size + sum(cost_b) / pp_size + sum(cost_w) / pp_size - ) + total_compute_per_mb = sum(cost_f) / pp_size + sum(cost_b) / pp_size + sum(cost_w) / pp_size ideal_time = total_compute_per_mb * micro_batches bubble_time = step_time_ms - ideal_time bubble_ratio = bubble_time / step_time_ms if step_time_ms > 0 else 0 @@ -2070,9 +1919,7 @@ def _run_pipeline_simulation_megatron_zb(training_config, profiling_results): return step_time_ms -def _run_pipeline_simulation( - training_config, profiling_results, enable_zero_bubble=False -): +def _run_pipeline_simulation(training_config, profiling_results, enable_zero_bubble=False): """ Run pipeline simulation and return the step time. @@ -2087,16 +1934,12 @@ def _run_pipeline_simulation( # Use Megatron's actual ZB scheduler for more accurate simulation if enable_zero_bubble: try: - return _run_pipeline_simulation_megatron_zb( - training_config, profiling_results - ) + return _run_pipeline_simulation_megatron_zb(training_config, profiling_results) except Exception as e: print(f"[Primus:Performance Projection] Megatron ZB scheduler failed: {e}") print("[Primus:Performance Projection] Falling back to simple simulator...") - sim_config = _build_scheduler_sim_config( - training_config, profiling_results, enable_zero_bubble - ) + sim_config = _build_scheduler_sim_config(training_config, profiling_results, enable_zero_bubble) if sim_config is None: return None print("[Primus:Performance Projection] Running pipeline schedule simulator...") @@ -2216,23 +2059,19 @@ def _run_multinode_projection( print(f" Using custom hardware config from: {args.hardware_config}") else: if is_rank_0: - print( - " Using default hardware parameters from custom_hardware_example.yaml" - ) + print(" Using default hardware parameters from custom_hardware_example.yaml") # Calculate communication times - total_comm_time_ms, breakdown, message_info, 
per_layer_info = ( - calculate_collective_communication_time( - training_config, - target_nodes, - gpus_per_node, - tp, - pp, - ep, - cp, - dp_target, - hardware_config_dict, - ) + total_comm_time_ms, breakdown, message_info, per_layer_info = calculate_collective_communication_time( + training_config, + target_nodes, + gpus_per_node, + tp, + pp, + ep, + cp, + dp_target, + hardware_config_dict, ) # Benchmarked time is for the minimum node configuration @@ -2262,18 +2101,16 @@ def _run_multinode_projection( grad_ar_per_iteration_ms = 0.0 # Non-overlapped allreduce time (added once) if dp_target > 1: # Calculate gradient all-reduce for target - _, target_breakdown, target_message_info, _ = ( - calculate_collective_communication_time( - training_config, - target_nodes, - gpus_per_node, - tp, - pp, - ep, - cp, - dp_target, - hardware_config_dict, - ) + _, target_breakdown, target_message_info, _ = calculate_collective_communication_time( + training_config, + target_nodes, + gpus_per_node, + tp, + pp, + ep, + cp, + dp_target, + hardware_config_dict, ) target_grad_ar = target_breakdown.get("gradient_allreduce", 0) moe_ar_no_overlap = target_message_info.get("moe_ar_no_overlap", False) @@ -2297,18 +2134,16 @@ def _run_multinode_projection( projected_time_ms = projected_compute_time_ms # For reporting, get full breakdown for target - total_comm_time_ms, breakdown, message_info, per_layer_info = ( - calculate_collective_communication_time( - training_config, - target_nodes, - gpus_per_node, - tp, - pp, - ep, - cp, - dp_target, - hardware_config_dict, - ) + total_comm_time_ms, breakdown, message_info, per_layer_info = calculate_collective_communication_time( + training_config, + target_nodes, + gpus_per_node, + tp, + pp, + ep, + cp, + dp_target, + hardware_config_dict, ) # Add exposed FSDP communication time to projected time @@ -2329,9 +2164,7 @@ def _run_multinode_projection( # Calculate number of microbatches per GPU for the target configuration target_microbatches_per_gpu 
= ( - global_batch // (micro_batch * target_dp_for_microbatch) - if target_dp_for_microbatch > 0 - else 1 + global_batch // (micro_batch * target_dp_for_microbatch) if target_dp_for_microbatch > 0 else 1 ) # Handle edge case where global_batch is smaller than micro_batch * target_dp @@ -2351,27 +2184,19 @@ def _run_multinode_projection( # Estimate optimizer step time (once per iteration, after all microbatches) gpu_arch = getattr(args, "gpu_arch", None) - optimizer_step_ms = _estimate_optimizer_step_ms( - training_config, dp_target, gpu_arch - ) + optimizer_step_ms = _estimate_optimizer_step_ms(training_config, dp_target, gpu_arch) # Build full iteration time: # compute (per-microbatch) × num_microbatches + gradient allreduce + optimizer step if time_includes_all_microbatches: - full_iteration_time_ms = ( - projected_time_ms + grad_ar_per_iteration_ms + optimizer_step_ms - ) - time_breakdown_str = ( - f"{full_iteration_time_ms:.3f} ms (from pipeline simulation" - ) + full_iteration_time_ms = projected_time_ms + grad_ar_per_iteration_ms + optimizer_step_ms + time_breakdown_str = f"{full_iteration_time_ms:.3f} ms (from pipeline simulation" if grad_ar_per_iteration_ms > 0: time_breakdown_str += f" + {grad_ar_per_iteration_ms:.1f} ms grad AR" time_breakdown_str += f" + {optimizer_step_ms:.1f} ms optimizer)" else: compute_total = projected_time_ms * target_microbatches_per_gpu - full_iteration_time_ms = ( - compute_total + grad_ar_per_iteration_ms + optimizer_step_ms - ) + full_iteration_time_ms = compute_total + grad_ar_per_iteration_ms + optimizer_step_ms time_breakdown_str = f"{full_iteration_time_ms:.3f} ms ({target_microbatches_per_gpu} microbatches × {projected_time_ms:.3f} ms" if grad_ar_per_iteration_ms > 0: time_breakdown_str += f" + {grad_ar_per_iteration_ms:.1f} ms grad AR" @@ -2397,10 +2222,7 @@ def _run_multinode_projection( for op_name, op_time in breakdown.items(): if op_time > 0: print(f" {op_name}: {op_time:.3f} ms", end="") - if ( - op_name == 
"gradient_allreduce" - and "gradient_allreduce_size_mb" in message_info - ): + if op_name == "gradient_allreduce" and "gradient_allreduce_size_mb" in message_info: moe_no_overlap = message_info.get("moe_ar_no_overlap", False) if moe_no_overlap: detail = " [MoE: NOT overlapped]" @@ -2410,13 +2232,9 @@ def _run_multinode_projection( detail += f"\n Expert AR: {expert_ms:.1f} ms (across {dp_reps} nodes)" detail += f" | Non-expert AR: {non_expert_ms:.1f} ms" else: - overlapped_flag = message_info.get( - "gradient_allreduce_overlapped", False - ) + overlapped_flag = message_info.get("gradient_allreduce_overlapped", False) detail = " [OVERLAPPED]" if overlapped_flag else "" - print( - f" (message: {message_info['gradient_allreduce_size_mb']:.2f} MB){detail}" - ) + print(f" (message: {message_info['gradient_allreduce_size_mb']:.2f} MB){detail}") elif op_name == "moe_a2a_fwd" and "moe_a2a_size_mb" in message_info: print( f" (message: {message_info['moe_a2a_size_mb']:.2f} MB, {message_info['num_moe_layers']} layers × {message_info['moe_a2a_per_layer_fwd']:.3f} ms/layer)" @@ -2607,17 +2425,15 @@ def _run_inference_projection( print(f" Using custom hardware config from: {args.hardware_config}") # Calculate inference communication times (forward-only) - total_comm_time_ms, breakdown, message_info = ( - calculate_inference_communication_time( - training_config, - target_nodes, - gpus_per_node, - tp, - pp, - ep, - cp, - hardware_config_dict, - ) + total_comm_time_ms, breakdown, message_info = calculate_inference_communication_time( + training_config, + target_nodes, + gpus_per_node, + tp, + pp, + ep, + cp, + hardware_config_dict, ) # Inference projected time: forward compute + forward communication @@ -2686,8 +2502,10 @@ def _run_inference_projection( print(f"🎯 Target Configuration ({target_nodes} nodes):") print(f" Nodes: {target_nodes}, GPUs: {total_gpus_target}") print(f" TP={tp}, PP={pp}, EP={ep}, CP={cp}, DP(replicas)={num_replicas}") - print(f" Prefill Latency: 
{prefill_latency_ms:.3f} ms " - f"(seq_len={seq_len}, micro_batch={micro_batch})") + print( + f" Prefill Latency: {prefill_latency_ms:.3f} ms " + f"(seq_len={seq_len}, micro_batch={micro_batch})" + ) print(f" Tokens/s per replica: {tokens_per_sec_per_replica:,.0f}") print(f" Tokens/s total ({num_replicas} replicas): {total_tokens_per_sec:,.0f}") print(f" Tokens/s/GPU: {tokens_per_sec_per_gpu:,.0f}") @@ -2752,8 +2570,10 @@ def _run_decode_layer_benchmark(primus_config, unknown_overrides, decode_batch_s is_rank_0 = rank == 0 if is_rank_0: print("[Primus:Decode Benchmark] Initializing MegatronPretrainTrainer...") - print("[Primus:Decode Benchmark] (trainer uses original seq_length for init; " - "benchmark will use seq_len=1)") + print( + "[Primus:Decode Benchmark] (trainer uses original seq_length for init; " + "benchmark will use seq_len=1)" + ) # Disable overlap/FSDP features for profiling primus_config.get_module_config("pre_trainer").overlap_grad_reduce = False @@ -2920,9 +2740,7 @@ def _run_decode_projection( moe_a2a_ms = analytical["moe_a2a_total_ms"] method_label = "Analytical Model" - tokens_per_sec_per_replica = ( - decode_batch * 1000 / decode_time_ms if decode_time_ms > 0 else 0 - ) + tokens_per_sec_per_replica = decode_batch * 1000 / decode_time_ms if decode_time_ms > 0 else 0 total_tokens_per_sec = tokens_per_sec_per_replica * dp_replicas tokens_per_sec_per_gpu = total_tokens_per_sec / total_gpus if total_gpus > 0 else 0 @@ -2949,24 +2767,32 @@ def _run_decode_projection( if use_benchmark: # Show benchmarked GEMM time + analytical KV cache overlay - print(f" Layer fwd (benchmarked, seq_len=1): {benchmark_fwd_ms:.4f} ms " - f"(includes GEMMs + TP AllReduce)") - print(f" KV cache read (analytical): {kv_cache_ms:.4f} ms " - f"({analytical['total_kv_mb']:.1f} MB)") + print( + f" Layer fwd (benchmarked, seq_len=1): {benchmark_fwd_ms:.4f} ms " + f"(includes GEMMs + TP AllReduce)" + ) + print( + f" KV cache read (analytical): {kv_cache_ms:.4f} ms " + 
f"({analytical['total_kv_mb']:.1f} MB)" + ) # Also show analytical-only for comparison analytical_weight_ms = analytical["total_weight_bytes"] / (analytical["hbm_bw_gbps"] * 1e9 / 1e3) analytical_total = analytical["decode_time_ms"] print(f" ─── Analytical comparison ───") - print(f" Weight loading (analytical): {analytical_weight_ms:.4f} ms " - f"({analytical['total_weight_mb']:.1f} MB)") + print( + f" Weight loading (analytical): {analytical_weight_ms:.4f} ms " + f"({analytical['total_weight_mb']:.1f} MB)" + ) print(f" Total (analytical): {analytical_total:.4f} ms") print(f" ────────────────────────────") else: weight_only_ms = analytical["total_weight_bytes"] / (analytical["hbm_bw_gbps"] * 1e9 / 1e3) print(f" Weight loading: {weight_only_ms:.4f} ms ({analytical['total_weight_mb']:.1f} MB)") print(f" KV cache read: {kv_cache_ms:.4f} ms ({analytical['total_kv_mb']:.1f} MB)") - print(f" Compute: {analytical['compute_time_ms']:.4f} ms " - f"({analytical['total_compute_tflops'] * 1000:.2f} GFLOPS)") + print( + f" Compute: {analytical['compute_time_ms']:.4f} ms " + f"({analytical['total_compute_tflops'] * 1000:.2f} GFLOPS)" + ) if analytical["tp_overhead_ms"] > 0: print(f" TP AllReduce: {analytical['tp_overhead_ms']:.4f} ms") @@ -2977,9 +2803,11 @@ def _run_decode_projection( if not use_benchmark: bound = "MEMORY-BOUND" if analytical["is_memory_bound"] else "COMPUTE-BOUND" - print(f" Bottleneck: {bound} " - f"(arith intensity={analytical['arithmetic_intensity']:.2f} FLOPs/B, " - f"balance={analytical['balance_point']:.0f} FLOPs/B)") + print( + f" Bottleneck: {bound} " + f"(arith intensity={analytical['arithmetic_intensity']:.2f} FLOPs/B, " + f"balance={analytical['balance_point']:.0f} FLOPs/B)" + ) print(f" ─────────────────────────────────────") print(f" Total per token: {decode_time_ms:.4f} ms") @@ -2993,14 +2821,18 @@ def _run_decode_projection( print(f" Tokens/s total ({dp_replicas} replicas): {total_tokens_per_sec:,.0f}") print(f" Tokens/s/GPU: 
{tokens_per_sec_per_gpu:,.0f}") print() - print(f" Generation of {num_gen_tokens} tokens: {total_generation_time_ms:.1f} ms " - f"({total_generation_time_ms / 1000:.2f} s)") + print( + f" Generation of {num_gen_tokens} tokens: {total_generation_time_ms:.1f} ms " + f"({total_generation_time_ms / 1000:.2f} s)" + ) print() print("💾 Memory Estimate (per GPU):") print(f" Model weights: {weight_mb_per_gpu:.1f} MB ({weight_mb_per_gpu / 1024:.2f} GB)") - print(f" KV cache: {kv_mb_per_gpu:.1f} MB ({kv_mb_per_gpu / 1024:.2f} GB) " - f"(batch={decode_batch}, ctx={context_len})") + print( + f" KV cache: {kv_mb_per_gpu:.1f} MB ({kv_mb_per_gpu / 1024:.2f} GB) " + f"(batch={decode_batch}, ctx={context_len})" + ) total_mem_gb = (weight_mb_per_gpu + kv_mb_per_gpu) / 1024 print(f" Total: {total_mem_gb:.2f} GB") @@ -3044,9 +2876,7 @@ def launch_projection_from_cli(args, overrides): """ cfg_path = Path(args.config) if not cfg_path.exists(): - raise FileNotFoundError( - f"[Primus:Performance Projection] Config file '{cfg_path}' not found." 
- ) + raise FileNotFoundError(f"[Primus:Performance Projection] Config file '{cfg_path}' not found.") # Load Primus configuration primus_config, unknown_overrides = load_primus_config(args, overrides) @@ -3060,9 +2890,7 @@ def launch_projection_from_cli(args, overrides): # Store original parallelism before any modifications module_config = primus_config.get_module_config("pre_trainer") - reduction_info = _calculate_single_node_config( - copy.deepcopy(module_config), gpus_per_node - ) + reduction_info = _calculate_single_node_config(copy.deepcopy(module_config), gpus_per_node) # Calculate minimum nodes required min_nodes_required = reduction_info["original_nodes_required"] @@ -3089,13 +2917,9 @@ def launch_projection_from_cli(args, overrides): # Show what was changed changes = [] if reduction_info["original_pp"] != reduction_info["benchmark_pp"]: - changes.append( - f"PP {reduction_info['original_pp']} → {reduction_info['benchmark_pp']}" - ) + changes.append(f"PP {reduction_info['original_pp']} → {reduction_info['benchmark_pp']}") if reduction_info["original_ep"] != reduction_info["benchmark_ep"]: - changes.append( - f"EP {reduction_info['original_ep']} → {reduction_info['benchmark_ep']}" - ) + changes.append(f"EP {reduction_info['original_ep']} → {reduction_info['benchmark_ep']}") if changes: print(f" ({', '.join(changes)})") @@ -3104,12 +2928,12 @@ def launch_projection_from_cli(args, overrides): print("=" * 100) # Apply the reduction to the config used for benchmarking - primus_config.get_module_config("pre_trainer").pipeline_model_parallel_size = ( - reduction_info["benchmark_pp"] - ) - primus_config.get_module_config("pre_trainer").expert_model_parallel_size = ( - reduction_info["benchmark_ep"] - ) + primus_config.get_module_config("pre_trainer").pipeline_model_parallel_size = reduction_info[ + "benchmark_pp" + ] + primus_config.get_module_config("pre_trainer").expert_model_parallel_size = reduction_info[ + "benchmark_ep" + ] # Also propagate num_experts 
adjustment so that the profiler sees # the correct experts_per_rank (e.g. 128/4=32, not 256/4=64). if reduction_info.get("benchmark_num_experts") is not None: @@ -3156,13 +2980,17 @@ def launch_projection_from_cli(args, overrides): print("[Primus:Decode] Running BENCHMARK-ENHANCED projection...") print("=" * 100) _run_decode_projection( - training_config, args, target_nodes, + training_config, + args, + target_nodes, profiling_results=decode_profiling_results, ) else: # Single mode: benchmark or simulate _run_decode_projection( - training_config, args, target_nodes, + training_config, + args, + target_nodes, profiling_results=decode_profiling_results, ) return @@ -3200,12 +3028,8 @@ def launch_projection_from_cli(args, overrides): fwd_err = ((s_fwd - b_fwd) / b_fwd * 100) if b_fwd else 0 bwd_err = ((s_bwd - b_bwd) / b_bwd * 100) if b_bwd else 0 print(f" Layer type: {lt}") - print( - f" Forward: bench={b_fwd:.2f} ms sim={s_fwd:.2f} ms (err={fwd_err:+.1f}%)" - ) - print( - f" Backward: bench={b_bwd:.2f} ms sim={s_bwd:.2f} ms (err={bwd_err:+.1f}%)" - ) + print(f" Forward: bench={b_fwd:.2f} ms sim={s_fwd:.2f} ms (err={fwd_err:+.1f}%)") + print(f" Backward: bench={b_bwd:.2f} ms sim={s_bwd:.2f} ms (err={bwd_err:+.1f}%)") print("=" * 100) # Use benchmark results for the rest of the pipeline @@ -3256,9 +3080,7 @@ def launch_projection_from_cli(args, overrides): print( f" Benchmark Config: PP={benchmark_pp}, EP={benchmark_ep}, TP={tp}, CP={cp}, DP={benchmark_dp} (1 node)" ) - print( - f" Target Config: PP={pp}, EP={ep}, TP={tp}, CP={cp}, DP={target_dp} ({target_nodes} nodes)" - ) + print(f" Target Config: PP={pp}, EP={ep}, TP={tp}, CP={cp}, DP={target_dp} ({target_nodes} nodes)") print(f" Mode: {projection_mode}") # ========================================================================= @@ -3266,10 +3088,7 @@ def launch_projection_from_cli(args, overrides): # ========================================================================= if projection_mode == MODE_PREFILL: # 
For inference, EP overhead adjustment is forward-only - if ( - reduction_info["adjusted"] - and reduction_info["original_ep"] != reduction_info["benchmark_ep"] - ): + if reduction_info["adjusted"] and reduction_info["original_ep"] != reduction_info["benchmark_ep"]: original_ep = reduction_info["original_ep"] benchmark_ep = reduction_info["benchmark_ep"] original_num_experts = reduction_info.get("original_num_experts") @@ -3280,11 +3099,15 @@ def launch_projection_from_cli(args, overrides): hardware_config_dict = load_hardware_config(args.hardware_config) fwd_overhead_per_layer, _ = _estimate_ep_communication_overhead( - training_config, original_ep, benchmark_ep, hardware_config_dict, + training_config, + original_ep, + benchmark_ep, + hardware_config_dict, ) ep_mlp_scale = _compute_ep_mlp_scale( training_config.model_config, - benchmark_ep, original_ep, + benchmark_ep, + original_ep, original_num_experts=original_num_experts, benchmark_num_experts=benchmark_num_experts, ) @@ -3309,9 +3132,7 @@ def launch_projection_from_cli(args, overrides): mlp_info["forward_time_ms"] = new_mlp_fwd # Extract forward-only time - forward_time_ms = extract_single_node_time_inference( - profiling_results, training_config - ) + forward_time_ms = extract_single_node_time_inference(profiling_results, training_config) # Run inference projection if target_nodes >= min_nodes_required: @@ -3343,9 +3164,7 @@ def launch_projection_from_cli(args, overrides): # common case for configs that already require all target GPUs for their # parallelism dims). Using benchmark_dp here would give 2× too many # microbatches when benchmark_dp < target_dp. 
- target_microbatches = ( - global_batch // (micro_batch * target_dp) if target_dp > 0 else 1 - ) + target_microbatches = global_batch // (micro_batch * target_dp) if target_dp > 0 else 1 target_microbatches = max(1, target_microbatches) benchmark_microbatches = global_batch // (micro_batch * benchmark_dp) if is_rank_0: @@ -3362,10 +3181,7 @@ def launch_projection_from_cli(args, overrides): # If EP was rescaled, adjust profiling_results to add EP overhead BEFORE pipeline simulation ep_overhead_applied = False - if ( - reduction_info["adjusted"] - and reduction_info["original_ep"] != reduction_info["benchmark_ep"] - ): + if reduction_info["adjusted"] and reduction_info["original_ep"] != reduction_info["benchmark_ep"]: original_ep = reduction_info["original_ep"] benchmark_ep = reduction_info["benchmark_ep"] original_num_experts = reduction_info.get("original_num_experts") @@ -3377,13 +3193,11 @@ def launch_projection_from_cli(args, overrides): hardware_config_dict = load_hardware_config(args.hardware_config) # Calculate EP communication overhead per layer (A2A delta) - fwd_overhead_per_layer, bwd_overhead_per_layer = ( - _estimate_ep_communication_overhead( - training_config, - original_ep, - benchmark_ep, - hardware_config_dict, - ) + fwd_overhead_per_layer, bwd_overhead_per_layer = _estimate_ep_communication_overhead( + training_config, + original_ep, + benchmark_ep, + hardware_config_dict, ) # EP compute scaling. 
Per-GPU routed compute is EP-invariant @@ -3401,9 +3215,7 @@ def launch_projection_from_cli(args, overrides): ) if is_rank_0: - print( - "[Primus:Performance Projection] Adjusting profiling results for EP scaling:" - ) + print("[Primus:Performance Projection] Adjusting profiling results for EP scaling:") print(f" EP rescaled: {benchmark_ep} → {original_ep}") if original_num_experts is not None and benchmark_num_experts is not None: orig_epr = original_num_experts // original_ep @@ -3448,12 +3260,8 @@ def launch_projection_from_cli(args, overrides): if is_rank_0 and moe_layers_adjusted == 0: print(f" MoE layer adjustment (per layer):") - print( - f" MLP fwd: {mlp_fwd:.2f} → {new_mlp_fwd:.2f} ms (×{ep_mlp_scale:.3f})" - ) - print( - f" MLP bwd: {mlp_bwd:.2f} → {new_mlp_bwd:.2f} ms (×{ep_mlp_scale:.3f})" - ) + print(f" MLP fwd: {mlp_fwd:.2f} → {new_mlp_fwd:.2f} ms (×{ep_mlp_scale:.3f})") + print(f" MLP bwd: {mlp_bwd:.2f} → {new_mlp_bwd:.2f} ms (×{ep_mlp_scale:.3f})") print(f" A2A fwd delta: +{fwd_overhead_per_layer:.3f} ms") print(f" A2A bwd delta: +{bwd_overhead_per_layer:.3f} ms") print(f" Layer fwd: {old_fwd:.2f} → {new_fwd:.2f} ms") @@ -3503,17 +3311,13 @@ def launch_projection_from_cli(args, overrides): # No need to add additional PP overhead benchmarked_time_ms = pipeline_simulation_time_ms if is_rank_0: - print( - f" (Pipeline simulation already includes PP={reduction_info['original_pp']} effects)" - ) + print(f" (Pipeline simulation already includes PP={reduction_info['original_pp']} effects)") else: if is_rank_0: print( "[Primus:Performance Projection] Pipeline simulation not available, using extrapolated time from profiling" ) - measured_time_ms = extract_single_node_time_from_profiling( - profiling_results, training_config - ) + measured_time_ms = extract_single_node_time_from_profiling(profiling_results, training_config) # If we reduced PP for benchmarking, estimate the time with PP overhead if reduction_info["adjusted"]: @@ -3531,9 +3335,7 @@ def 
launch_projection_from_cli(args, overrides): if is_rank_0: print("[Primus:Performance Projection] Time Adjustment:") - print( - f" Measured time (PP={reduction_info['benchmark_pp']}): {measured_time_ms:.2f} ms" - ) + print(f" Measured time (PP={reduction_info['benchmark_pp']}): {measured_time_ms:.2f} ms") print( f" Estimated PP overhead (PP={reduction_info['original_pp']}): {pp_overhead_ms:.2f} ms" ) @@ -3553,32 +3355,24 @@ def launch_projection_from_cli(args, overrides): benchmark_ep_val = reduction_info["benchmark_ep"] # Get the number of MoE layers - moe_pattern = getattr( - training_config.model_config, "moe_layer_pattern", [] - ) + moe_pattern = getattr(training_config.model_config, "moe_layer_pattern", []) if not moe_pattern: # If no pattern, check if model has MoE layers - num_moe_layers = getattr( - training_config.model_config, "num_moe_layers", 0 - ) + num_moe_layers = getattr(training_config.model_config, "num_moe_layers", 0) else: num_moe_layers = sum(1 for x in moe_pattern if x == 1) if num_moe_layers > 0: # Calculate EP communication overhead per layer - fwd_overhead_per_layer, bwd_overhead_per_layer = ( - _estimate_ep_communication_overhead( - training_config, - original_ep, - benchmark_ep_val, - hardware_config_dict, - ) + fwd_overhead_per_layer, bwd_overhead_per_layer = _estimate_ep_communication_overhead( + training_config, + original_ep, + benchmark_ep_val, + hardware_config_dict, ) # Total EP overhead = per-layer overhead * number of MoE layers - total_ep_overhead_ms = ( - fwd_overhead_per_layer + bwd_overhead_per_layer - ) * num_moe_layers + total_ep_overhead_ms = (fwd_overhead_per_layer + bwd_overhead_per_layer) * num_moe_layers # EP compute scaling — per-GPU MoE routed compute is # EP-invariant (see _compute_ep_mlp_scale docstring). 
@@ -3594,32 +3388,23 @@ def launch_projection_from_cli(args, overrides): # Estimate MLP portion of MoE layer time from profiling results mlp_time_reduction = 0.0 for layer_idx, layer_data in profiling_results.items(): - if ( - isinstance(layer_data, dict) - and layer_data.get("type") == "moe" - ): + if isinstance(layer_data, dict) and layer_data.get("type") == "moe": mlp_info = layer_data.get("mlp", {}) - mlp_total = mlp_info.get( - "forward_time_ms", 0 - ) + mlp_info.get("backward_time_ms", 0) + mlp_total = mlp_info.get("forward_time_ms", 0) + mlp_info.get( + "backward_time_ms", 0 + ) mlp_time_reduction = mlp_total * (1 - ep_mlp_scale) break # All MoE layers have same profiled time total_mlp_reduction_ms = mlp_time_reduction * num_moe_layers if is_rank_0: - print( - "[Primus:Performance Projection] EP Compute + Communication Adjustment:" - ) + print("[Primus:Performance Projection] EP Compute + Communication Adjustment:") print(f" EP rescaled: {benchmark_ep_val} → {original_ep}") print(f" Number of MoE layers: {num_moe_layers}") print(f" MLP time scale factor: {ep_mlp_scale:.3f}") - print( - f" Total MLP compute reduction: -{total_mlp_reduction_ms:.3f} ms" - ) - print( - f" Total A2A comm overhead: +{total_ep_overhead_ms:.3f} ms" - ) + print(f" Total MLP compute reduction: -{total_mlp_reduction_ms:.3f} ms") + print(f" Total A2A comm overhead: +{total_ep_overhead_ms:.3f} ms") net_change = total_ep_overhead_ms - total_mlp_reduction_ms print(f" Net adjustment: {net_change:+.3f} ms") diff --git a/primus/core/projection/performance_projection/simulator.py b/primus/core/projection/performance_projection/simulator.py index f39bfb88e..00e3d6f6d 100644 --- a/primus/core/projection/performance_projection/simulator.py +++ b/primus/core/projection/performance_projection/simulator.py @@ -32,15 +32,11 @@ def __init__(self, config: dict): self.debug_simulator = int(os.getenv("DEBUG_SIMULATOR", "0") == "1") - def _summarize_simulation_result( - self, simulation_result: 
list[dict], scheduler_config: dict - ) -> dict: + def _summarize_simulation_result(self, simulation_result: list[dict], scheduler_config: dict) -> dict: rank_totals = [rank.get("total", 0.0) for rank in simulation_result] step_time_ms = max(rank_totals) if rank_totals else 0.0 critical_rank = rank_totals.index(step_time_ms) if rank_totals else None - max_memory = max( - (rank.get("memory", 0.0) for rank in simulation_result), default=0.0 - ) + max_memory = max((rank.get("memory", 0.0) for rank in simulation_result), default=0.0) return { "step_time_ms": step_time_ms, "rank_totals": rank_totals, @@ -89,9 +85,7 @@ def _chunk_duration( else: raise ValueError("Duration is not found.") - def _chunk_activation( - self, rank: int, chunk: int | None, vpp_size: int | None - ) -> float: + def _chunk_activation(self, rank: int, chunk: int | None, vpp_size: int | None) -> float: if self.chunk_time_ms is None: if vpp_size is None: vpp_size = 1 @@ -169,9 +163,7 @@ def run(self): module = importlib.import_module(module_path) scheduler_class = getattr(module, class_name) - scheduler_params = { - k: v for k, v in scheduler_config.items() if k not in ["name", "class"] - } + scheduler_params = {k: v for k, v in scheduler_config.items() if k not in ["name", "class"]} scheduler_instance = scheduler_class(**scheduler_params) schedule_table = scheduler_instance.generate_schedule_table() @@ -182,13 +174,9 @@ def run(self): print(f"{'='*20 * scheduler_config['pp_size']}") if self.debug_simulator: scheduler_instance.print_schedule_table(schedule_table) - simulation_result = self.simulate_scheduler_table( - schedule_table, scheduler_config - ) + simulation_result = self.simulate_scheduler_table(schedule_table, scheduler_config) self.dump_simulation_result(simulation_result, scheduler_config) - summary = self._summarize_simulation_result( - simulation_result, scheduler_config - ) + summary = self._summarize_simulation_result(simulation_result, scheduler_config) run_summaries.append( { "name": 
scheduler_config["name"], @@ -203,9 +191,7 @@ def run(self): return run_summaries - def simulate_scheduler_table( - self, schedule_table: list[list[SchedulerNode]], scheduler_config: dict - ): + def simulate_scheduler_table(self, schedule_table: list[list[SchedulerNode]], scheduler_config: dict): current_rank = 0 rank_clock = [0.0 for _ in range(len(schedule_table))] @@ -265,9 +251,7 @@ def simulate_scheduler_table( print(f"rank {current_rank} send_key: {send_key}") communication_map[send_key] = rank_clock[current_rank] if node.func_type in [FuncType.RF, FuncType.RB]: - send_func_type = ( - FuncType.SF if node.func_type == FuncType.RF else FuncType.SB - ) + send_func_type = FuncType.SF if node.func_type == FuncType.RF else FuncType.SB send_key = f"{node.args['from_pp_rank']}_{node.args['to_pp_rank']}_{send_func_type}_{node.mini_batch}_{node.chunk}" if send_key not in communication_map: merge_comm_index = rank_idx[current_rank] + 1 @@ -275,9 +259,7 @@ def simulate_scheduler_table( print(f"rank {current_rank} wait send_key {send_key}") # merge the send op behind - for i in range( - merge_comm_index, len(schedule_table[current_rank]) - ): + for i in range(merge_comm_index, len(schedule_table[current_rank])): if schedule_table[current_rank][i].func_type in [ FuncType.SF, FuncType.SB, @@ -294,9 +276,7 @@ def simulate_scheduler_table( send_time = communication_map.pop(send_key) - recv_time_map[ - f"{node.mini_batch}_{node.chunk}_{node.func_type}" - ] = send_time + recv_time_map[f"{node.mini_batch}_{node.chunk}_{node.func_type}"] = send_time if node.func_type in [ FuncType.F, @@ -306,21 +286,17 @@ def simulate_scheduler_table( FuncType.FB, ]: if node.func_type in [FuncType.F, FuncType.B, FuncType.BW]: - recv_node_type = ( - FuncType.RF if node.func_type == FuncType.F else FuncType.RB - ) + recv_node_type = FuncType.RF if node.func_type == FuncType.F else FuncType.RB recv_key = f"{node.mini_batch}_{node.chunk}_{recv_node_type}" if recv_key in recv_time_map: 
rank_clock[current_rank] = max( rank_clock[current_rank], - recv_time_map[ - f"{node.mini_batch}_{node.chunk}_{recv_node_type}" - ], + recv_time_map[f"{node.mini_batch}_{node.chunk}_{recv_node_type}"], ) - simulation_result[current_rank][ - f"{self._result_key_dict[node.func_type]}_start" - ].append(rank_clock[current_rank]) + simulation_result[current_rank][f"{self._result_key_dict[node.func_type]}_start"].append( + rank_clock[current_rank] + ) duration = self._chunk_duration( current_rank, getattr(node, "chunk", 0), @@ -328,15 +304,15 @@ def simulate_scheduler_table( scheduler_config, ) rank_clock[current_rank] += duration - simulation_result[current_rank][ - f"{self._result_key_dict[node.func_type]}_end" - ].append(rank_clock[current_rank]) + simulation_result[current_rank][f"{self._result_key_dict[node.func_type]}_end"].append( + rank_clock[current_rank] + ) simulation_result[current_rank][ f"{self._result_key_dict[node.func_type]}_minibatch" ].append(node.mini_batch) - simulation_result[current_rank][ - f"{self._result_key_dict[node.func_type]}_chunk" - ].append(node.chunk) + simulation_result[current_rank][f"{self._result_key_dict[node.func_type]}_chunk"].append( + node.chunk + ) if node.func_type == FuncType.F: act_gb = self._chunk_activation( current_rank, @@ -348,9 +324,9 @@ def simulate_scheduler_table( simulation_result[current_rank]["memory"], rank_memory[current_rank], ) - simulation_result[current_rank][ - "activation_memory_usage" - ].append(rank_memory[current_rank]) + simulation_result[current_rank]["activation_memory_usage"].append( + rank_memory[current_rank] + ) elif node.func_type in [FuncType.BW, FuncType.W]: act_gb = self._chunk_activation( current_rank, @@ -358,9 +334,9 @@ def simulate_scheduler_table( scheduler_config["vpp_size"], ) rank_memory[current_rank] = rank_memory[current_rank] - act_gb - simulation_result[current_rank][ - "activation_memory_usage" - ].append(rank_memory[current_rank]) + 
simulation_result[current_rank]["activation_memory_usage"].append( + rank_memory[current_rank] + ) rank_idx[current_rank] += 1 current_rank = (current_rank + 1) % len(schedule_table) @@ -370,9 +346,7 @@ def simulate_scheduler_table( return simulation_result - def dump_simulation_result( - self, simulation_result: list[dict], scheduler_config: dict - ): + def dump_simulation_result(self, simulation_result: list[dict], scheduler_config: dict): result_dir = f"{self.config['output_dir']}/{scheduler_config['name']}" os.makedirs(result_dir, exist_ok=True) with open(f"{result_dir}/config.json", "w") as f: diff --git a/primus/core/projection/profiler_spec.py b/primus/core/projection/profiler_spec.py index e053f02fa..02c8d9e25 100644 --- a/primus/core/projection/profiler_spec.py +++ b/primus/core/projection/profiler_spec.py @@ -15,6 +15,6 @@ class ModuleProfilerSpec: profiler: Type[BaseModuleProfiler] config: Type[TrainingConfig] - sub_profiler_specs: Optional[ - Dict[str, Union[Type[BaseModuleProfiler], "ModuleProfilerSpec", None]] - ] = field(default_factory=lambda: {}) + sub_profiler_specs: Optional[Dict[str, Union[Type[BaseModuleProfiler], "ModuleProfilerSpec", None]]] = ( + field(default_factory=lambda: {}) + ) diff --git a/primus/core/projection/simulation_backends/base.py b/primus/core/projection/simulation_backends/base.py index 941be3d53..0b4a8fa21 100644 --- a/primus/core/projection/simulation_backends/base.py +++ b/primus/core/projection/simulation_backends/base.py @@ -139,48 +139,26 @@ def simulate_mlp_gemms( if swiglu: # Gate projection fwd: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] - gate_fwd = self.simulate_gemm( - batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b - ) + gate_fwd = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) # Up projection fwd: same shape as gate - up_fwd = self.simulate_gemm( - batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b - ) + up_fwd = self.simulate_gemm(batch_tokens, 
ffn_hidden_size, hidden_size, dtype, batch=b) # Down projection fwd: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] - down_fwd = self.simulate_gemm( - batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b - ) + down_fwd = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) - fwd_time = ( - gate_fwd.forward_time_ms - + up_fwd.forward_time_ms - + down_fwd.forward_time_ms - ) + fwd_time = gate_fwd.forward_time_ms + up_fwd.forward_time_ms + down_fwd.forward_time_ms # Backward: simulate actual dgrad + wgrad GEMMs per projection # Gate dgrad: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] - gate_dgrad = self.simulate_gemm( - batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b - ) + gate_dgrad = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) # Gate wgrad: [hidden, tokens] x [tokens, ffn] -> [hidden, ffn] - gate_wgrad = self.simulate_gemm( - hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b - ) + gate_wgrad = self.simulate_gemm(hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b) # Up dgrad + wgrad: same shapes as gate - up_dgrad = self.simulate_gemm( - batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b - ) - up_wgrad = self.simulate_gemm( - hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b - ) + up_dgrad = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) + up_wgrad = self.simulate_gemm(hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b) # Down dgrad: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] - down_dgrad = self.simulate_gemm( - batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b - ) + down_dgrad = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) # Down wgrad: [ffn, tokens] x [tokens, hidden] -> [ffn, hidden] - down_wgrad = self.simulate_gemm( - ffn_hidden_size, hidden_size, batch_tokens, dtype, batch=b - ) + down_wgrad = self.simulate_gemm(ffn_hidden_size, hidden_size, batch_tokens, 
dtype, batch=b) bwd_time = ( gate_dgrad.forward_time_ms @@ -192,33 +170,21 @@ def simulate_mlp_gemms( ) else: # Up projection fwd: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] - up_fwd = self.simulate_gemm( - batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b - ) + up_fwd = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) # Down projection fwd: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] - down_fwd = self.simulate_gemm( - batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b - ) + down_fwd = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) fwd_time = up_fwd.forward_time_ms + down_fwd.forward_time_ms # Backward: simulate actual dgrad + wgrad GEMMs per projection # Up dgrad: [tokens, ffn] x [ffn, hidden] -> [tokens, hidden] - up_dgrad = self.simulate_gemm( - batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b - ) + up_dgrad = self.simulate_gemm(batch_tokens, hidden_size, ffn_hidden_size, dtype, batch=b) # Up wgrad: [hidden, tokens] x [tokens, ffn] -> [hidden, ffn] - up_wgrad = self.simulate_gemm( - hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b - ) + up_wgrad = self.simulate_gemm(hidden_size, ffn_hidden_size, batch_tokens, dtype, batch=b) # Down dgrad: [tokens, hidden] x [hidden, ffn] -> [tokens, ffn] - down_dgrad = self.simulate_gemm( - batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b - ) + down_dgrad = self.simulate_gemm(batch_tokens, ffn_hidden_size, hidden_size, dtype, batch=b) # Down wgrad: [ffn, tokens] x [tokens, hidden] -> [ffn, hidden] - down_wgrad = self.simulate_gemm( - ffn_hidden_size, hidden_size, batch_tokens, dtype, batch=b - ) + down_wgrad = self.simulate_gemm(ffn_hidden_size, hidden_size, batch_tokens, dtype, batch=b) bwd_time = ( up_dgrad.forward_time_ms diff --git a/primus/core/projection/simulation_backends/factory.py b/primus/core/projection/simulation_backends/factory.py index 114b85cb4..76db3df9b 100644 --- 
a/primus/core/projection/simulation_backends/factory.py +++ b/primus/core/projection/simulation_backends/factory.py @@ -51,10 +51,7 @@ def get_gemm_simulation_backend( is_rank_0 = int(os.getenv("RANK", "0")) == 0 if name is not None and name != "origami": - raise ValueError( - f"Unknown GEMM simulation backend: '{name}'. " - f"Supported backend: 'origami'" - ) + raise ValueError(f"Unknown GEMM simulation backend: '{name}'. " f"Supported backend: 'origami'") from primus.core.projection.simulation_backends.origami_backend import ( OrigamiGEMMBackend, @@ -63,8 +60,7 @@ def get_gemm_simulation_backend( backend = OrigamiGEMMBackend(gpu_arch=gpu_arch, gpu_clock_mhz=gpu_clock_mhz) if not backend.is_available(): raise RuntimeError( - "Origami GEMM simulation backend is not available.\n" - "Install it with: pip install origami" + "Origami GEMM simulation backend is not available.\n" "Install it with: pip install origami" ) if is_rank_0: @@ -98,9 +94,7 @@ def get_sdpa_simulation_backend( is_rank_0 = int(os.getenv("RANK", "0")) == 0 if is_rank_0: - print( - "[Primus:Simulation] Using SDPA backend: sdpa_simulator (FAv3 Origami 1-CU)" - ) + print("[Primus:Simulation] Using SDPA backend: sdpa_simulator (FAv3 Origami 1-CU)") return SDPASimulator( gpu_arch=gpu_arch, diff --git a/primus/core/projection/simulation_backends/origami_backend.py b/primus/core/projection/simulation_backends/origami_backend.py index 72febb90b..be2bad979 100644 --- a/primus/core/projection/simulation_backends/origami_backend.py +++ b/primus/core/projection/simulation_backends/origami_backend.py @@ -164,9 +164,7 @@ def __init__( # Clock override: CLI > env var > profile default _env_clock = os.getenv("PRIMUS_GPU_CLOCK_MHZ", None) - self._clock_override_mhz: Optional[int] = gpu_clock_mhz or ( - int(_env_clock) if _env_clock else None - ) + self._clock_override_mhz: Optional[int] = gpu_clock_mhz or (int(_env_clock) if _env_clock else None) self._n_cu_override = n_cu_override @@ -258,12 +256,8 @@ def 
simulate_gemm( problem = _origami.problem_t() problem.size = _origami.dim3_t(m, n, k) problem.batch = batch - problem.a_transpose = ( - _origami.transpose_t.T if trans_a else _origami.transpose_t.N - ) - problem.b_transpose = ( - _origami.transpose_t.T if trans_b else _origami.transpose_t.N - ) + problem.a_transpose = _origami.transpose_t.T if trans_a else _origami.transpose_t.N + problem.b_transpose = _origami.transpose_t.T if trans_b else _origami.transpose_t.N origami_dtype = _origami.string_to_datatype(_DTYPE_MAP.get(sim_dtype, "bf16")) problem.a_dtype = origami_dtype @@ -279,8 +273,7 @@ def simulate_gemm( result = _origami.select_config(problem, self._hardware, self._configs) except Exception as e: raise RuntimeError( - f"Origami select_config failed for " - f"(M={m}, N={n}, K={k}, dtype={dtype}): {e}" + f"Origami select_config failed for " f"(M={m}, N={n}, K={k}, dtype={dtype}): {e}" ) from e latency_cycles = result.latency @@ -407,9 +400,7 @@ def _get_hardware(self): clock_khz = profile.compute_clock_khz if self._clock_override_mhz is not None: clock_khz = self._clock_override_mhz * 1000 - n_cu = ( - self._n_cu_override if self._n_cu_override is not None else profile.n_cu - ) + n_cu = self._n_cu_override if self._n_cu_override is not None else profile.n_cu arch_enum = getattr(_origami.architecture_t, profile.arch_enum_name) hw = _origami.get_hardware_for_arch( arch_enum, @@ -422,11 +413,7 @@ def _get_hardware(self): override_tag = "" if self._clock_override_mhz is not None: override_tag = " (overridden via --gpu-clock-mhz)" - cu_tag = ( - f" (n_cu_override={n_cu})" - if self._n_cu_override is not None - else "" - ) + cu_tag = f" (n_cu_override={n_cu})" if self._n_cu_override is not None else "" print( f"[Primus:Origami] Using hardware profile for " f"'{self._gpu_arch}': N_CU={n_cu}, " @@ -481,9 +468,7 @@ def _get_hardware(self): override_tag = "" if self._clock_override_mhz is not None: override_tag = " (overridden via --gpu-clock-mhz)" - cu_tag = ( - f" 
(n_cu_override={n_cu})" if self._n_cu_override is not None else "" - ) + cu_tag = f" (n_cu_override={n_cu})" if self._n_cu_override is not None else "" print( f"[Primus:Origami] Using known hardware profile for " f"'{self._gpu_arch}': N_CU={n_cu}, " diff --git a/primus/core/projection/simulation_backends/sdpa_simulator.py b/primus/core/projection/simulation_backends/sdpa_simulator.py index c4723e89c..cd1629aaa 100644 --- a/primus/core/projection/simulation_backends/sdpa_simulator.py +++ b/primus/core/projection/simulation_backends/sdpa_simulator.py @@ -173,9 +173,7 @@ def _get_hardware_spec( spec = _HW_PROFILES.get(arch, _HW_PROFILES["mi300x"]) # Apply clock override — scale TFLOPS linearly - clock_override = gpu_clock_mhz or ( - int(v) if (v := os.getenv("PRIMUS_GPU_CLOCK_MHZ")) else None - ) + clock_override = gpu_clock_mhz or (int(v) if (v := os.getenv("PRIMUS_GPU_CLOCK_MHZ")) else None) if clock_override is not None: # Derive the profile's implicit clock from a known reference. _PROFILE_CLOCK_MHZ = { @@ -344,20 +342,14 @@ def _create_tile_gemm_backend( if backend.is_available(): is_rank_0 = int(os.getenv("RANK", "0")) == 0 if is_rank_0: - print( - "[Primus:SDPA] Using Origami 1-CU tile-level simulation " - "for Flash Attention" - ) + print("[Primus:SDPA] Using Origami 1-CU tile-level simulation " "for Flash Attention") return backend except Exception as exc: # If Origami is not available or fails to initialize, fall back to # the analytic SDPA model by returning None here. 
is_rank_0 = int(os.getenv("RANK", "0")) == 0 if is_rank_0: - print( - "[Primus:SDPA] Origami 1-CU tile-level simulation disabled " - f"due to error: {exc}" - ) + print("[Primus:SDPA] Origami 1-CU tile-level simulation disabled " f"due to error: {exc}") return None def _simulate_tile_level( @@ -498,8 +490,7 @@ def _simulate_tile_level( total_updates_local = warp_updates_local * bwd_waves bwd_atomic_ms = ( - _ATOMIC_LATENCY_GLOBAL_NS * total_updates_global - + _ATOMIC_LATENCY_LOCAL_NS * total_updates_local + _ATOMIC_LATENCY_GLOBAL_NS * total_updates_global + _ATOMIC_LATENCY_LOCAL_NS * total_updates_local ) / 1e6 # ns → ms bwd_time_ms = bwd_compute_ms + bwd_atomic_ms @@ -540,17 +531,13 @@ def _simulate_tile_level( + B * H_K * S_K * D_v * bpe # dV ) - fwd_achieved_tflops = ( - (fwd_flops / (fwd_time_ms * 1e-3)) / 1e12 if fwd_time_ms > 0 else 0 - ) + fwd_achieved_tflops = (fwd_flops / (fwd_time_ms * 1e-3)) / 1e12 if fwd_time_ms > 0 else 0 return SimulationResult( forward_time_ms=fwd_time_ms, backward_time_ms=bwd_time_ms, tflops=fwd_achieved_tflops, - bandwidth_gbps=( - (fwd_bytes / (fwd_time_ms * 1e-3)) / 1e9 if fwd_time_ms > 0 else 0 - ), + bandwidth_gbps=((fwd_bytes / (fwd_time_ms * 1e-3)) / 1e9 if fwd_time_ms > 0 else 0), metadata={ "backend": "sdpa_simulator (FAv3 tile-level, Origami 1-CU)", # Standard metadata keys (for compatibility) diff --git a/primus/core/projection/training_config.py b/primus/core/projection/training_config.py index 475a3bb50..7e84fbe6b 100644 --- a/primus/core/projection/training_config.py +++ b/primus/core/projection/training_config.py @@ -100,9 +100,7 @@ def megatron_derive_default_args(args): if not hasattr(args, "data_parallel_size") or args.data_parallel_size is None: args.data_parallel_size = world_size // ( - args.tensor_model_parallel_size - * args.pipeline_model_parallel_size - * args.context_parallel_size + args.tensor_model_parallel_size * args.pipeline_model_parallel_size * args.context_parallel_size ) if not hasattr(args, 
"virtual_pipeline_model_parallel_size"): args.virtual_pipeline_model_parallel_size = None @@ -113,31 +111,23 @@ def megatron_derive_default_args(args): args.virtual_pipeline_model_parallel_size = 1 elif args.num_layers_per_virtual_pipeline_stage is not None: args.virtual_pipeline_model_parallel_size = args.num_layers // ( - args.num_layers_per_virtual_pipeline_stage - * args.pipeline_model_parallel_size + args.num_layers_per_virtual_pipeline_stage * args.pipeline_model_parallel_size ) - args.share_embeddings_and_output_weights = ( - not args.untie_embeddings_and_output_weights - ) + args.share_embeddings_and_output_weights = not args.untie_embeddings_and_output_weights if args.num_experts is None: args.moe_pattern = [0] * args.num_layers else: if isinstance(args.moe_layer_freq, int): - args.moe_pattern = [ - 1 if (i % args.moe_layer_freq == 0) else 0 - for i in range(args.num_layers) - ] + args.moe_pattern = [1 if (i % args.moe_layer_freq == 0) else 0 for i in range(args.num_layers)] elif isinstance(args.moe_layer_freq, list): args.moe_pattern = args.moe_layer_freq elif isinstance(args.moe_layer_freq, str): try: parsed = eval(args.moe_layer_freq) except Exception: - raise ValueError( - f"Invalid moe_layer_freq format: {args.moe_layer_freq}" - ) + raise ValueError(f"Invalid moe_layer_freq format: {args.moe_layer_freq}") # Handle case where eval returns an int (e.g., "1" -> 1 means all layers are MoE) if isinstance(parsed, int): @@ -146,18 +136,14 @@ def megatron_derive_default_args(args): args.moe_pattern = [1] * args.num_layers else: # Every Nth layer is MoE - args.moe_pattern = [ - 1 if (i % parsed == 0) else 0 for i in range(args.num_layers) - ] + args.moe_pattern = [1 if (i % parsed == 0) else 0 for i in range(args.num_layers)] elif isinstance(parsed, list): args.moe_pattern = parsed assert ( len(args.moe_pattern) == args.num_layers ), f"Invalid moe_layer_freq length: {len(args.moe_pattern)} (expected {args.num_layers})" else: - raise ValueError( - f"Invalid 
moe_layer_freq format after eval: {type(parsed)}" - ) + raise ValueError(f"Invalid moe_layer_freq format after eval: {type(parsed)}") # naming conversion args.sequence_length = args.seq_length From d92406a743a52fa9c04c612bf835580082cf3e49 Mon Sep 17 00:00:00 2001 From: Anshu Raina Date: Wed, 25 Feb 2026 14:26:11 -0800 Subject: [PATCH 12/12] style: fix isort import ordering in transformer_layer.py and factory.py --- primus/core/projection/module_profilers/transformer_layer.py | 2 +- primus/core/projection/simulation_backends/factory.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/primus/core/projection/module_profilers/transformer_layer.py b/primus/core/projection/module_profilers/transformer_layer.py index af4fe0526..59f593615 100644 --- a/primus/core/projection/module_profilers/transformer_layer.py +++ b/primus/core/projection/module_profilers/transformer_layer.py @@ -12,8 +12,8 @@ from primus.core.projection.training_config import TrainingConfig from . import collective_model as cm -from .collective_args import get_default_args from .attention import AttentionProfiler +from .collective_args import get_default_args from .dense_mlp import DenseMLPProfiler from .layer_norm import LayerNormProfiler from .moe_mlp import MoEMLPProfiler diff --git a/primus/core/projection/simulation_backends/factory.py b/primus/core/projection/simulation_backends/factory.py index 76db3df9b..e050f1f8c 100644 --- a/primus/core/projection/simulation_backends/factory.py +++ b/primus/core/projection/simulation_backends/factory.py @@ -88,9 +88,7 @@ def get_sdpa_simulation_backend( Raises: RuntimeError: If the Origami backend is not available. """ - from primus.core.projection.simulation_backends.sdpa_simulator import ( - SDPASimulator, - ) + from primus.core.projection.simulation_backends.sdpa_simulator import SDPASimulator is_rank_0 = int(os.getenv("RANK", "0")) == 0 if is_rank_0: