#!/usr/bin/env python3
"""
Blackwell Primer — Programmatic Memory and Verification for SM120/SM121
Hardware-first, bottom-up verification and analysis tool for the Blackwell
LLM deployment lab. Covers all GPU architectures (SM120/SM121/SM89) and all
deployment targets (RTX Pro 6000, DGX Spark, RTX 4090).
This script IS the algorithm. It reads hardware, parses source code, checks
configuration, and reports exactly what works, what's blocked, and what to
do next. Every check traces to a file, line number, or hardware query.
Primary function: verify TRT-LLM fork readiness for NVFP4 KV cache on
SM120/SM121. Secondary: model-specific analysis (Qwen3-Next-80B kernel
roadmap, weight inventory, performance ceilings).
State is persisted in data/ directory. Re-run after any code change to
detect regressions. This is a climbing anchor — set it, verify it holds,
build on it.
Commands:
primer verify <fork_path> Programmatic code audit of TRT-LLM fork
primer unknowns Show open questions and design decisions
primer analyze Model analysis (architecture + roadmap)
primer next Single instruction: what to do RIGHT NOW
primer status Phase overview with prerequisite checks
primer record <phase> <key> <value> Save a measurement
primer history Show all recorded measurements
primer complete <phase> Manually mark phase done
primer note <phase> <text> Add a note to a phase
Usage:
python3 blackwell_primer.py verify SM12x.../TensorRT-LLM --arch sm121
python3 blackwell_primer.py verify SM12x.../TensorRT-LLM --arch sm120
python3 blackwell_primer.py unknowns
python3 blackwell_primer.py analyze [config.json] [--gpu sm120]
python3 blackwell_primer.py status
"""
import json
import sys
import re
import subprocess
import time
import concurrent.futures
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional, Any
from pathlib import Path
from collections import defaultdict
# =============================================================================
# Hardware Specifications (measured, not marketed)
# =============================================================================
@dataclass
class HardwareSpec:
name: str
sm_arch: str
sm_count: int
memory_bandwidth_gb_s: float # DRAM bandwidth (GB/s)
memory_total_gb: float # Total VRAM (GB)
registers_per_sm: int # 32-bit registers
smem_per_sm_bytes: int # Shared memory per SM (bytes)
max_warps_per_sm: int # Max concurrent warps
warp_size: int = 32
kernel_launch_overhead_us: float = 3.5 # Measured CPU→GPU dispatch (μs)
cuda_graph_launch_overhead_us: float = 0.5 # GPU-side replay (μs)
@property
def max_threads_per_sm(self) -> int:
return self.max_warps_per_sm * self.warp_size
SM120_RTX_PRO_6000 = HardwareSpec(
name="RTX Pro 6000 (SM120)",
sm_arch="sm_120a",
sm_count=188,
memory_bandwidth_gb_s=1792.0,
memory_total_gb=96.0,
registers_per_sm=65536,
smem_per_sm_bytes=101376, # 99KB per block (opt-in); 100KB (102400) per SM total
max_warps_per_sm=48,
kernel_launch_overhead_us=3.5,
cuda_graph_launch_overhead_us=0.5,
# Verified 2026-02-11 via CUDA driver API (cuDeviceGetAttribute).
# NOT 128KB — LLM research pipelines hallucinate this value.
)
SM121_DGX_SPARK = HardwareSpec(
name="DGX Spark (SM121)",
sm_arch="sm_121a",
sm_count=84,
memory_bandwidth_gb_s=273.0, # LPDDR5x
memory_total_gb=128.0,
registers_per_sm=65536,
smem_per_sm_bytes=101376,
max_warps_per_sm=48,
kernel_launch_overhead_us=3.5,
cuda_graph_launch_overhead_us=0.5,
)
GPU_SPECS = {
"sm120": SM120_RTX_PRO_6000,
"sm121": SM121_DGX_SPARK,
}
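# Quick sanity check derived from the specs above: with 48 resident warps of
# 32 threads each, both parts top out at 1536 concurrent threads per SM.
assert SM120_RTX_PRO_6000.max_threads_per_sm == 1536
assert SM121_DGX_SPARK.max_threads_per_sm == 1536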
# =============================================================================
# Quantization Format Definitions
# =============================================================================
@dataclass
class QuantFormat:
name: str
bits_per_element: float # Effective bits per weight element
scale_bits: int # Bits per scale factor
scale_group_size: int # Elements per scale group
bytes_per_element: float = 0.0 # Computed
def __post_init__(self):
# Total bytes = data bytes + scale bytes
data_bytes_per_elem = self.bits_per_element / 8.0
scale_bytes_per_elem = (self.scale_bits / 8.0) / self.scale_group_size
self.bytes_per_element = data_bytes_per_elem + scale_bytes_per_elem
NVFP4 = QuantFormat(
name="NVFP4",
bits_per_element=4,
scale_bits=8, # UE4M3
scale_group_size=16, # 1 scale per 16 elements
)
FP8_E4M3 = QuantFormat(
name="FP8",
bits_per_element=8,
scale_bits=0,
scale_group_size=1,
)
BF16 = QuantFormat(
name="BF16",
bits_per_element=16,
scale_bits=0,
scale_group_size=1,
)
FP16 = QuantFormat(
name="FP16",
bits_per_element=16,
scale_bits=0,
scale_group_size=1,
)
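# Worked example (sanity check, derived from the definitions above): NVFP4
# stores 4 data bits (0.5 B) plus one 8-bit UE4M3 scale per 16-element group
# (0.0625 B), giving 0.5625 B per element, roughly 3.6x smaller than BF16.
assert abs(NVFP4.bytes_per_element - 0.5625) < 1e-12
assert abs(FP8_E4M3.bytes_per_element - 1.0) < 1e-12
assert abs(BF16.bytes_per_element - 2.0) < 1e-12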
# =============================================================================
# Weight Matrix Descriptor
# =============================================================================
@dataclass
class WeightMatrix:
"""A single weight matrix in the model."""
name: str # e.g. "layer.0.linear_attn.in_proj_qkvz"
layer_index: int
layer_type: str # "linear_attention" or "full_attention"
operation: str # "attention_proj", "expert_ffn", "router", etc.
shape_N: int # Output dimension (rows)
shape_K: int # Input dimension (cols)
quant: QuantFormat # NVFP4, BF16, etc.
count_per_forward: int # How many times executed per forward pass
is_routed_expert: bool = False # True for per-expert weights (grouped)
expert_index: Optional[int] = None
@property
def weight_bytes(self) -> float:
"""Total bytes to read this weight matrix once."""
return self.shape_N * self.shape_K * self.quant.bytes_per_element
@property
def total_bytes_per_forward(self) -> float:
"""Total bytes read from this matrix across the full forward pass."""
return self.weight_bytes * self.count_per_forward
@property
def flops_per_call_m1(self) -> int:
"""FLOPs for M=1 (GEMV): 2*N*K."""
return 2 * self.shape_N * self.shape_K
@property
def arithmetic_intensity_m1(self) -> float:
"""FLOPs per byte at M=1."""
if self.weight_bytes == 0:
return 0.0
return self.flops_per_call_m1 / self.weight_bytes
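# Illustrative sketch (hypothetical shape, not taken from the model config):
# for any NVFP4 weight matrix at M=1, arithmetic intensity is a constant
# 2 FLOPs / 0.5625 B, about 3.6 FLOP/byte, far below the GPU's
# compute-to-bandwidth balance point, which is why decode GEMVs are treated
# as bandwidth-bound throughout this tool.
_EXAMPLE_NVFP4_GEMV = WeightMatrix(
    name="example.q_proj", layer_index=0, layer_type="full_attention",
    operation="example_projection", shape_N=4096, shape_K=2048,
    quant=NVFP4, count_per_forward=1,
)
assert 3.5 < _EXAMPLE_NVFP4_GEMV.arithmetic_intensity_m1 < 3.6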
# =============================================================================
# GEMV Shape Descriptor (aggregated)
# =============================================================================
@dataclass
class GemvShape:
"""Unique (N, K, quant) shape with aggregated count."""
shape_N: int
shape_K: int
quant: QuantFormat
total_calls_per_forward: int
operation_names: List[str] = field(default_factory=list)
is_routed_expert: bool = False
@property
def weight_bytes(self) -> float:
return self.shape_N * self.shape_K * self.quant.bytes_per_element
@property
def input_bytes_m1(self) -> float:
"""Input vector x is K elements in FP16."""
return self.shape_K * 2
@property
def total_read_bytes_per_call(self) -> float:
return self.weight_bytes + self.input_bytes_m1
@property
def total_read_bytes_per_forward(self) -> float:
return self.total_read_bytes_per_call * self.total_calls_per_forward
@property
def flops_per_call_m1(self) -> int:
return 2 * self.shape_N * self.shape_K
# =============================================================================
# Model Config Parser
# =============================================================================
def parse_model_config(config_path: str) -> dict:
"""Load and return the raw HuggingFace config.json."""
with open(config_path, 'r') as f:
return json.load(f)
def extract_architecture_parameters(config: dict) -> dict:
"""
Extract every architecture parameter from config.json into a flat dict.
Uses exact field names from the config. No guessing.
"""
params = {}
# Core dimensions
params["hidden_size"] = config["hidden_size"]
params["intermediate_size"] = config.get("intermediate_size", 0)
params["vocab_size"] = config["vocab_size"]
params["num_hidden_layers"] = config["num_hidden_layers"]
# Attention (full)
params["num_attention_heads"] = config["num_attention_heads"]
params["num_key_value_heads"] = config["num_key_value_heads"]
params["head_dim"] = config["head_dim"]
# Attention (linear / DeltaNet / Mamba)
params["linear_key_head_dim"] = config.get("linear_key_head_dim", 0)
params["linear_num_key_heads"] = config.get("linear_num_key_heads", 0)
params["linear_num_value_heads"] = config.get("linear_num_value_heads", 0)
params["linear_value_head_dim"] = config.get("linear_value_head_dim", 0)
params["linear_conv_kernel_dim"] = config.get("linear_conv_kernel_dim", 0)
# MoE
params["num_experts"] = config.get("num_experts", 0)
params["num_experts_per_tok"] = config.get("num_experts_per_tok", 0)
params["moe_intermediate_size"] = config.get("moe_intermediate_size", 0)
params["shared_expert_intermediate_size"] = config.get("shared_expert_intermediate_size", 0)
params["decoder_sparse_step"] = config.get("decoder_sparse_step", 1)
# Layer types
params["layer_types"] = config.get("layer_types", [])
params["full_attention_interval"] = config.get("full_attention_interval", 0)
# Context
params["max_position_embeddings"] = config.get("max_position_embeddings", 0)
# Quantization
qconfig = config.get("quantization_config", {})
params["quant_method"] = qconfig.get("quant_method", "none")
params["quant_algo"] = qconfig.get("quant_algo", "none")
params["quant_ignore_list"] = qconfig.get("ignore", [])
# KV cache quantization
kv_scheme = qconfig.get("kv_cache_scheme", {})
params["kv_cache_bits"] = kv_scheme.get("num_bits", 16)
params["kv_cache_type"] = kv_scheme.get("type", "float")
return params
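# Usage sketch (hypothetical values, not the real config.json): only the
# directly-indexed fields are required; everything else falls back to a default.
#   cfg = {"hidden_size": 2048, "vocab_size": 151936, "num_hidden_layers": 48,
#          "num_attention_heads": 16, "num_key_value_heads": 2, "head_dim": 256}
#   extract_architecture_parameters(cfg)["kv_cache_bits"]  # -> 16 (no quant config)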
def classify_layer_types(params: dict) -> Tuple[List[int], List[int]]:
"""
Return (linear_attention_indices, full_attention_indices) from layer_types.
"""
linear_indices = []
full_indices = []
for i, ltype in enumerate(params["layer_types"]):
if ltype == "linear_attention":
linear_indices.append(i)
elif ltype == "full_attention":
full_indices.append(i)
else:
print(f" WARNING: Unknown layer type '{ltype}' at index {i}")
return linear_indices, full_indices
def build_quantization_ignore_set(params: dict) -> set:
"""
Build a set of layer name patterns that are NOT quantized (stay in BF16).
The ignore list in config.json contains full layer paths.
"""
return set(params["quant_ignore_list"])
def determine_weight_quant(layer_idx: int, proj_suffix: str,
ignore_set: set, default_quant: QuantFormat) -> QuantFormat:
"""
Determine if a specific weight matrix is NVFP4 or BF16 by checking
the quantization ignore list.
"""
# Build possible ignore patterns
patterns = [
f"model.layers.{layer_idx}.{proj_suffix}",
]
for pattern in patterns:
if pattern in ignore_set:
return BF16
return default_quant
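# Usage sketch (hypothetical ignore list): if quantization_config.ignore in
# config.json lists "model.layers.0.self_attn.q_proj", that one projection
# stays BF16 while the same projection in every other layer keeps the default:
#   determine_weight_quant(0, "self_attn.q_proj",
#                          {"model.layers.0.self_attn.q_proj"}, NVFP4)  # -> BF16
#   determine_weight_quant(1, "self_attn.q_proj",
#                          {"model.layers.0.self_attn.q_proj"}, NVFP4)  # -> NVFP4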
# =============================================================================
# Weight Matrix Enumeration — The Complete Forward Pass
# =============================================================================
def enumerate_all_weight_matrices(params: dict) -> List[WeightMatrix]:
"""
Enumerate EVERY weight matrix read during a single forward pass.
This is exhaustive. If it reads memory, it's listed here.
"""
weights = []
ignore_set = build_quantization_ignore_set(params)
linear_layers, full_layers = classify_layer_types(params)
H = params["hidden_size"] # 2048
num_layers = params["num_hidden_layers"] # 48
num_experts = params["num_experts"] # 512
experts_per_tok = params["num_experts_per_tok"] # 10
moe_intermediate = params["moe_intermediate_size"] # 512
shared_intermediate = params["shared_expert_intermediate_size"] # 512
vocab = params["vocab_size"] # 151936
# Full attention dimensions
full_q_dim = params["num_attention_heads"] * params["head_dim"] # 16*256=4096
full_k_dim = params["num_key_value_heads"] * params["head_dim"] # 2*256=512
full_v_dim = params["num_key_value_heads"] * params["head_dim"] # 2*256=512
full_o_dim = full_q_dim # 4096
# Linear attention dimensions
lin_q_dim = params["linear_num_key_heads"] * params["linear_key_head_dim"] # 16*128=2048
lin_k_dim = params["linear_num_key_heads"] * params["linear_key_head_dim"] # 16*128=2048
lin_v_dim = params["linear_num_value_heads"] * params["linear_value_head_dim"] # 32*128=4096
# Z dimension: typically same as Q for DeltaNet gating
lin_z_dim = lin_q_dim # 2048
lin_qkvz_dim = lin_q_dim + lin_k_dim + lin_v_dim + lin_z_dim # 10240
# BA projection (B and A state matrices for DeltaNet)
# Typically 2 * num_heads * head_dim or similar — check model code
# For now estimate same as Q+K
lin_ba_dim = lin_q_dim + lin_k_dim # 4096
# Linear attention output projection
lin_o_dim = H # 2048
# -------------------------------------------------------------------------
# Process each layer
# -------------------------------------------------------------------------
for layer_idx in range(num_layers):
layer_type = params["layer_types"][layer_idx]
# --- Attention Projections ---
if layer_type == "linear_attention":
# Fused QKVZ projection: [H] -> [Q+K+V+Z]
q = determine_weight_quant(layer_idx, "linear_attn.in_proj_qkvz", ignore_set, NVFP4)
weights.append(WeightMatrix(
name=f"layer.{layer_idx}.linear_attn.in_proj_qkvz",
layer_index=layer_idx,
layer_type=layer_type,
operation="linear_attention_qkvz_projection",
shape_N=lin_qkvz_dim,
shape_K=H,
quant=q,
count_per_forward=1,
))
# BA projection: [H] -> [B+A]
q_ba = determine_weight_quant(layer_idx, "linear_attn.in_proj_ba", ignore_set, NVFP4)
weights.append(WeightMatrix(
name=f"layer.{layer_idx}.linear_attn.in_proj_ba",
layer_index=layer_idx,
layer_type=layer_type,
operation="linear_attention_ba_projection",
shape_N=lin_ba_dim,
shape_K=H,
quant=q_ba,
count_per_forward=1,
))
# Conv1D: small, [kernel_dim * channels]
# This is a 1D causal conv, not a GEMV. Include for byte accounting.
conv_kernel = params["linear_conv_kernel_dim"]
q_conv = determine_weight_quant(layer_idx, "linear_attn.conv1d", ignore_set, NVFP4)
weights.append(WeightMatrix(
name=f"layer.{layer_idx}.linear_attn.conv1d",
layer_index=layer_idx,
layer_type=layer_type,
operation="linear_attention_conv1d",
shape_N=H,
shape_K=conv_kernel, # Conv kernel is tiny
quant=q_conv,
count_per_forward=1,
))
# Output projection: [V_out] -> [H]
# This is NOT in the ignore list for linear attention
q_o = determine_weight_quant(layer_idx, "linear_attn.o_proj", ignore_set, NVFP4)
weights.append(WeightMatrix(
name=f"layer.{layer_idx}.linear_attn.o_proj",
layer_index=layer_idx,
layer_type=layer_type,
operation="linear_attention_output_projection",
shape_N=H,
shape_K=lin_v_dim, # Output proj takes V output
quant=q_o,
count_per_forward=1,
))
elif layer_type == "full_attention":
# Q projection: [H] -> [Q_dim]
q_q = determine_weight_quant(layer_idx, "self_attn.q_proj", ignore_set, NVFP4)
weights.append(WeightMatrix(
name=f"layer.{layer_idx}.self_attn.q_proj",
layer_index=layer_idx,
layer_type=layer_type,
operation="full_attention_q_projection",
shape_N=full_q_dim,
shape_K=H,
quant=q_q,
count_per_forward=1,
))
# K projection: [H] -> [K_dim]
q_k = determine_weight_quant(layer_idx, "self_attn.k_proj", ignore_set, NVFP4)
weights.append(WeightMatrix(
name=f"layer.{layer_idx}.self_attn.k_proj",
layer_index=layer_idx,
layer_type=layer_type,
operation="full_attention_k_projection",
shape_N=full_k_dim,
shape_K=H,
quant=q_k,
count_per_forward=1,
))
# V projection: [H] -> [V_dim]
q_v = determine_weight_quant(layer_idx, "self_attn.v_proj", ignore_set, NVFP4)
weights.append(WeightMatrix(
name=f"layer.{layer_idx}.self_attn.v_proj",
layer_index=layer_idx,
layer_type=layer_type,
operation="full_attention_v_projection",
shape_N=full_v_dim,
shape_K=H,
quant=q_v,
count_per_forward=1,
))
# O projection: [Q_dim] -> [H]
q_o = determine_weight_quant(layer_idx, "self_attn.o_proj", ignore_set, NVFP4)
weights.append(WeightMatrix(
name=f"layer.{layer_idx}.self_attn.o_proj",
layer_index=layer_idx,
layer_type=layer_type,
operation="full_attention_output_projection",
shape_N=H,
shape_K=full_q_dim,
quant=q_o,
count_per_forward=1,
))
# --- MoE: Router ---
q_gate = determine_weight_quant(layer_idx, "mlp.gate", ignore_set, NVFP4)
weights.append(WeightMatrix(
name=f"layer.{layer_idx}.mlp.gate",
layer_index=layer_idx,
layer_type=layer_type,
operation="moe_router",
shape_N=num_experts,
shape_K=H,
quant=q_gate,
count_per_forward=1,
))
# --- MoE: Routed Experts (only top-K active per token) ---
for expert_op, (N, K) in [
("expert_up_proj", (moe_intermediate, H)),
("expert_gate_proj", (moe_intermediate, H)),
("expert_down_proj", (H, moe_intermediate)),
]:
q_exp = determine_weight_quant(
layer_idx, f"mlp.experts.0.{expert_op.replace('expert_', '')}", ignore_set, NVFP4)
weights.append(WeightMatrix(
name=f"layer.{layer_idx}.mlp.experts.*.{expert_op}",
layer_index=layer_idx,
layer_type=layer_type,
operation=f"moe_routed_{expert_op}",
shape_N=N,
shape_K=K,
quant=q_exp,
count_per_forward=experts_per_tok, # 10 experts active
is_routed_expert=True,
))
# --- MoE: Shared Expert ---
q_shared_gate = determine_weight_quant(
layer_idx, "mlp.shared_expert_gate", ignore_set, NVFP4)
# The shared_expert_gate is a scalar gate, not a full FFN
# But we still account for its weight read
weights.append(WeightMatrix(
name=f"layer.{layer_idx}.mlp.shared_expert_gate",
layer_index=layer_idx,
layer_type=layer_type,
operation="moe_shared_expert_gate_scalar",
shape_N=1,
shape_K=H,
quant=q_shared_gate,
count_per_forward=1,
))
for shared_op, (N, K) in [
("shared_expert_up_proj", (shared_intermediate, H)),
("shared_expert_gate_proj", (shared_intermediate, H)),
("shared_expert_down_proj", (H, shared_intermediate)),
]:
q_shared = determine_weight_quant(
layer_idx, f"mlp.shared_expert.{shared_op.replace('shared_expert_', '')}", ignore_set, NVFP4)
weights.append(WeightMatrix(
name=f"layer.{layer_idx}.mlp.shared_expert.{shared_op}",
layer_index=layer_idx,
layer_type=layer_type,
operation=f"moe_{shared_op}",
shape_N=N,
shape_K=K,
quant=q_shared,
count_per_forward=1,
))
# --- LM Head ---
# lm_head is in the ignore list → BF16
weights.append(WeightMatrix(
name="lm_head",
layer_index=-1,
layer_type="output",
operation="lm_head",
shape_N=vocab,
shape_K=H,
quant=BF16, # Always ignored in quantization
count_per_forward=1,
))
# --- Embedding ---
# Embedding table is read once (lookup, not GEMV)
# Include for total memory accounting but not GEMV time
weights.append(WeightMatrix(
name="model.embed_tokens",
layer_index=-1,
layer_type="input",
operation="embedding_lookup",
shape_N=vocab,
shape_K=H,
quant=BF16,
count_per_forward=1, # lookup, not matmul
))
return weights
# =============================================================================
# Shape Aggregation
# =============================================================================
def aggregate_gemv_shapes(weights: List[WeightMatrix]) -> List[GemvShape]:
"""
Group weight matrices by unique (N, K, quant, is_routed) tuples.
Returns aggregated shapes sorted by total bytes descending.
"""
shape_map: Dict[Tuple, GemvShape] = {}
for w in weights:
# Skip embedding (it's a lookup, not a GEMV)
if w.operation == "embedding_lookup":
continue
key = (w.shape_N, w.shape_K, w.quant.name, w.is_routed_expert)
if key not in shape_map:
shape_map[key] = GemvShape(
shape_N=w.shape_N,
shape_K=w.shape_K,
quant=w.quant,
total_calls_per_forward=0,
operation_names=[],
is_routed_expert=w.is_routed_expert,
)
shape_map[key].total_calls_per_forward += w.count_per_forward
if w.operation not in shape_map[key].operation_names:
shape_map[key].operation_names.append(w.operation)
shapes = list(shape_map.values())
shapes.sort(key=lambda s: s.total_read_bytes_per_forward, reverse=True)
return shapes
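# Aggregation sketch (values as annotated in enumerate_all_weight_matrices,
# assuming the experts are not in the quantization ignore list): the routed
# expert up_proj and gate_proj share the key (N=512, K=2048, NVFP4, routed),
# so across 48 layers with 10 active experts each they collapse into a single
# GemvShape with total_calls_per_forward = 48 * 10 * 2 = 960 rather than
# hundreds of separate entries.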
# =============================================================================
# Performance Model
# =============================================================================
@dataclass
class ShapePerformanceModel:
"""Performance prediction for a single GEMV shape."""
shape: GemvShape
hw: HardwareSpec
# Computed fields
blocks_per_call: int = 0 # Threadblocks launched
sm_utilization: float = 0.0 # blocks / sm_count
waves: int = 0 # ceil(blocks / sm_count)
bandwidth_limited_time_us: float = 0.0 # bytes / bandwidth
launch_overhead_time_us: float = 0.0 # total launch overhead
total_time_per_forward_us: float = 0.0 # estimated wall time
# Classification
bottleneck: str = "" # "launch_overhead", "bandwidth", "sm_starvation"
def compute(self, rows_per_block: int = 8):
"""Compute all performance predictions."""
s = self.shape
hw = self.hw
# Threadblock count: ceil(N / rows_per_block)
self.blocks_per_call = (s.shape_N + rows_per_block - 1) // rows_per_block
self.sm_utilization = self.blocks_per_call / hw.sm_count
self.waves = max(1, (self.blocks_per_call + hw.sm_count - 1) // hw.sm_count)
# Bandwidth-limited time (pure data movement, no overhead)
bytes_per_call = s.total_read_bytes_per_call
self.bandwidth_limited_time_us = bytes_per_call / (hw.memory_bandwidth_gb_s * 1e3)
        # Launch overhead (CPU dispatch per kernel).
        # Both routed experts and regular projections are modeled here as one
        # unbatched launch per call; the savings from fusing routed-expert
        # launches into a single batched launch are credited in Phase 1 of
        # the roadmap, so the baseline gap stays visible.
        self.launch_overhead_time_us = (
            hw.kernel_launch_overhead_us * s.total_calls_per_forward
        )
# Total estimate (bandwidth + launch overhead, whichever dominates)
bw_time_total = self.bandwidth_limited_time_us * s.total_calls_per_forward
self.total_time_per_forward_us = bw_time_total + self.launch_overhead_time_us
# Classify bottleneck
launch_fraction = self.launch_overhead_time_us / max(self.total_time_per_forward_us, 1e-9)
if launch_fraction > 0.5:
self.bottleneck = "LAUNCH_OVERHEAD"
elif self.sm_utilization < 0.5:
self.bottleneck = "SM_STARVATION"
else:
self.bottleneck = "BANDWIDTH"
def compute_full_forward_pass_model(
shapes: List[GemvShape], hw: HardwareSpec
) -> List[ShapePerformanceModel]:
"""Compute performance model for every shape in the forward pass."""
models = []
for s in shapes:
m = ShapePerformanceModel(shape=s, hw=hw)
m.compute()
models.append(m)
return models
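# Minimal usage sketch for the performance model (hypothetical helper, not
# wired into any command):
def _lm_head_bandwidth_floor_example(hw: HardwareSpec = SM120_RTX_PRO_6000) -> float:
    """Illustrative sketch, not used by the CLI.

    Shows how a single shape is fed through the performance model. The BF16
    LM head (N=151936, K=2048, values as annotated above) reads ~622 MB per
    decoded token, so at 1792 GB/s the bandwidth-limited floor for this one
    GEMV is roughly 347 us, independent of launch overhead.
    """
    shape = GemvShape(shape_N=151936, shape_K=2048, quant=BF16,
                      total_calls_per_forward=1, operation_names=["lm_head"])
    model = ShapePerformanceModel(shape=shape, hw=hw)
    model.compute()
    return model.bandwidth_limited_time_us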
# =============================================================================
# Optimization Phase Generator
# =============================================================================
@dataclass
class OptimizationPhase:
"""A single step in the optimization roadmap."""
phase_number: int
name: str
prerequisite: str
what_to_build: str
expected_time_saved_us: float
expected_tps_after: float
measurement_command: str
done_when: str
def generate_optimization_roadmap(
models: List[ShapePerformanceModel],
hw: HardwareSpec,
params: dict,
) -> List[OptimizationPhase]:
"""
Generate the deterministic optimization roadmap.
Order is fixed by the performance model: fix the biggest bottleneck first.
"""
total_time_us = sum(m.total_time_per_forward_us for m in models)
total_launch_overhead_us = sum(m.launch_overhead_time_us for m in models)
total_bw_time_us = sum(
m.bandwidth_limited_time_us * m.shape.total_calls_per_forward
for m in models
)
total_calls = sum(m.shape.total_calls_per_forward for m in models)
# Count routed expert calls
routed_calls = sum(
m.shape.total_calls_per_forward
for m in models if m.shape.is_routed_expert
)
routed_overhead_us = sum(
m.launch_overhead_time_us
for m in models if m.shape.is_routed_expert
)
# Count non-routed calls
nonrouted_calls = total_calls - routed_calls
nonrouted_overhead_us = total_launch_overhead_us - routed_overhead_us
phases = []
running_time_us = total_time_us
# Phase 0: Baseline GEMV kernel (already done)
phases.append(OptimizationPhase(
phase_number=0,
name="BASELINE: Individual GEMV Kernels",
prerequisite="Working NVFP4 + BF16 GEMV kernels with LUT dequant",
what_to_build="nvfp4_gemv.cu (NVFP4 path) + bf16_gemv.cu (BF16 path)",
expected_time_saved_us=0,
expected_tps_after=1e6 / running_time_us if running_time_us > 0 else 0,
measurement_command="bench_nvfp4_gemv (all shapes, correctness + timing)",
done_when="All shapes pass correctness, bandwidth efficiency measured per shape",
))
# Phase 1: Batched GEMV for routed experts
# Fuse N_experts same-shape kernels into one launch
if routed_calls > 0:
# With batching: 1 launch per shape per layer instead of experts_per_tok launches
batched_routed_launches = routed_calls / params["num_experts_per_tok"]
saved_launches = routed_calls - batched_routed_launches
time_saved = saved_launches * hw.kernel_launch_overhead_us
running_time_us -= time_saved
phases.append(OptimizationPhase(
phase_number=1,
name="BATCHED GEMV: Fuse Routed Expert Launches",
prerequisite="Phase 0 baseline measured",
what_to_build=(
"nvfp4_batched_gemv.cu — single kernel launch processes "
f"{params['num_experts_per_tok']} expert weight matrices. "
"Grid: (ceil(N/rows_per_block) * num_active_experts) blocks. "
"Each block reads expert_id from a dispatch table to select "
"the correct weight pointer. Same inner loop as Phase 0. "
"Library alternatives exist (CUTLASS MoEProblemShape + GroupedGemmMoE, "
"cuBLAS cublasGemmGroupedBatchedEx) but are tile-based (128x128) — "
f"our custom GEMV handles M=1 with {params['moe_intermediate_size']}-wide N directly."
),
expected_time_saved_us=time_saved,
expected_tps_after=1e6 / running_time_us if running_time_us > 0 else 0,
measurement_command=(
"bench_batched_gemv: compare individual vs batched for "
f"(N={params['moe_intermediate_size']}, K={params['hidden_size']}) × "
f"{params['num_experts_per_tok']} experts"
),
done_when=(
f"Batched launch time < {params['num_experts_per_tok']}× individual launch time. "
"Correctness matches individual path."
),
))
# Phase 2: CUDA Graphs for the full forward pass
remaining_launches = nonrouted_calls + (routed_calls / params["num_experts_per_tok"] if routed_calls > 0 else 0)
graph_time_saved = remaining_launches * (hw.kernel_launch_overhead_us - hw.cuda_graph_launch_overhead_us)
running_time_us -= graph_time_saved
phases.append(OptimizationPhase(
phase_number=2,
name="CUDA GRAPHS: Eliminate CPU Launch Overhead",
prerequisite="Phase 1 batched GEMV working",
what_to_build=(
"forward_pass_graph.cu — Record all GEMV launches "
f"(~{int(remaining_launches)} kernels after batching) into a single "
"CUDA graph. Replay is GPU-side, ~0.5μs per launch instead of ~3.5μs. "
"Requires: all kernel shapes are static (they are at M=1 decode). "
"The graph captures kernel pointers, grid dims, and shared memory sizes. "
"Weight pointers are fixed (model weights don't move). "
"Only the input vector x changes between tokens — update via "
"cudaGraphExecKernelNodeSetParams or use a persistent x buffer."
),
expected_time_saved_us=graph_time_saved,
expected_tps_after=1e6 / running_time_us if running_time_us > 0 else 0,
measurement_command=(
"bench_cuda_graph: record + replay full forward pass, "
"compare to unbatched individual launches"
),
done_when=(
"Graph replay time < sum of individual launches. "
f"Target: eliminate ~{graph_time_saved:.0f}μs of launch overhead."
),
))
# Phase 3: Kernel occupancy tuning
# Find shapes where SM starvation is the bottleneck
starved_shapes = [m for m in models if m.bottleneck == "SM_STARVATION"]
starved_time = sum(m.total_time_per_forward_us for m in starved_shapes)
# Estimate: with better occupancy, recover ~30% of starvation penalty
occ_time_saved = starved_time * 0.3
running_time_us -= occ_time_saved
phases.append(OptimizationPhase(
phase_number=3,
name="OCCUPANCY TUNING: Reduce Register Pressure",
prerequisite="Phase 2 CUDA Graphs working",
what_to_build=(
"Tune launch_bounds and inner loop to reduce registers from 60 to ~48, "
"enabling 6 blocks/SM (100% occupancy at 256 threads/block). "
"Currently at 4 blocks/SM (67% occupancy). "
"Methods: reduce accumulator variables, use __half2 instead of float "
"for intermediate x values, limit unroll depth. "
"Profile with: ncu --metrics launch__registers_per_thread"
),
expected_time_saved_us=occ_time_saved,
expected_tps_after=1e6 / running_time_us if running_time_us > 0 else 0,
measurement_command=(
"ncu --metrics sm__warps_active.avg.pct_of_peak_sustained_active "
"./bench_nvfp4_gemv (before and after)"
),
done_when=(
"Warp occupancy > 90%. Registers per thread ≤ 48. "
"No performance regression on any shape."
),
))
# Phase 4: LM Head bandwidth optimization
lm_head_models = [m for m in models if "lm_head" in m.shape.operation_names]
if lm_head_models:
lm_time = sum(m.total_time_per_forward_us for m in lm_head_models)
lm_bw_time = sum(
m.bandwidth_limited_time_us * m.shape.total_calls_per_forward
for m in lm_head_models
)
        # LM head is BF16 and the single biggest matrix. Target: 70% bandwidth
        # efficiency. Current NVFP4 kernels reach ~53%; the BF16 path will
        # behave differently, so estimate a 25% improvement for now.
        lm_improvement = lm_time * 0.25
running_time_us -= lm_improvement
phases.append(OptimizationPhase(
phase_number=4,
name="LM HEAD: BF16 GEMV with Maximum Bandwidth",
prerequisite="Phase 3 occupancy tuned",
what_to_build=(
f"bf16_gemv.cu — Specialized kernel for LM head "
f"(N={params['vocab_size']}, K={params['hidden_size']}). "
"BF16 weights (not NVFP4) — no dequant needed. "
"Use 128-bit vectorized loads (4× BF16 per load). "
"This is the single largest GEMV in the forward pass "
f"({params['vocab_size'] * params['hidden_size'] * 2 / 1e6:.0f} MB). "
"Target: >70% bandwidth efficiency."
),
expected_time_saved_us=lm_improvement,
expected_tps_after=1e6 / running_time_us if running_time_us > 0 else 0,
measurement_command=(
f"bench_bf16_gemv: N={params['vocab_size']}, K={params['hidden_size']}, "
"measure BW efficiency"
),
done_when=(
"BF16 GEMV achieves >70% of peak memory bandwidth on LM head shape."
),
))
# Phase 5: Non-GEMV operations
# Research: top-K routing and fused elementwise queries ACCEPTED.
# RMSNorm and SiLU queries failed (docs insufficient / validation failed).
phases.append(OptimizationPhase(
phase_number=5,
name="NON-GEMV OPS: RMSNorm + SiLU + Routing + Softmax",
prerequisite="Phase 4 complete",
what_to_build=(
"Fused element-wise kernels for the non-GEMV operations: "
"RMSNorm (per layer), SiLU activation (per expert FFN), "
"top-K routing with softmax (per MoE layer). "
"These are tiny in bytes but add launch overhead. "
"Fuse adjacent operations: e.g., RMSNorm + input broadcast, "
"expert_up * SiLU(expert_gate) in a single kernel. "
"Note: top-K routing can reuse CUTLASS MoE epilogue (top_K <= 8 compile-time) "
f"but our model uses top-{params['num_experts_per_tok']} — needs 2-pass or custom."
),
expected_time_saved_us=0, # Hard to estimate without profiling
expected_tps_after=1e6 / running_time_us if running_time_us > 0 else 0,
measurement_command="Profile full forward pass with nsys to identify remaining gaps",
done_when="Non-GEMV operations add < 10% overhead to total forward time",
))
# Phase 6: Attention compute (full attention layers only)
#
# RESOLVED (Q1, Q2, Q3, Q5 via nvidia-docs-search 2026-02-11):
# - mxf8f6f4.block_scale: A=mx_float8_t(FP8) × B=mx_float4_t(FP4), UE8M0 scales, group=32
# - mxf4nvf4.block_scale: A=nv_float4_t × B=nv_float4_t, UE4M3 scales, group=16, TN only
# - Accumulator: FP32 for all block_scale variants
# - SM120 shapes: m16n8k64, m16n8k128
#
# THREE VIABLE PATHS (see D1_attention_mma_path in OPEN_QUESTIONS):
# Path A: mxf8f6f4 — Q quantized FP16→FP8, K stays FP4. KV cache needs MX format.
# Path B: mxf4nvf4 — Q quantized FP16→FP4, K stays NVFP4. Native format. 4x throughput.
# Path C: Software dequant — Q stays FP16, K dequanted via LUT. Max precision.
#
# UNRESOLVED: Q8 (Q precision loss), Q9 (missing .cuh headers), Q10 (head_dim=256)
#
# STRATEGY: Implement Path C first (bandwidth-bound decode, simplest).
# Then Path B (highest throughput). Then Path A (best Q precision).
# Benchmark all three. Decode at M=1 is bandwidth-bound, so Path C may win.
#
n_full_layers = len([t for t in params["layer_types"] if t == "full_attention"])
nvfp4_kv_bytes_per_elem = NVFP4.bytes_per_element # 0.5625 B (4-bit + scale)
kv_bytes_8k = (n_full_layers * 2 * params['num_key_value_heads']
* params['head_dim'] * 8192 * nvfp4_kv_bytes_per_elem)
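    # Scale check (hypothetical layer count, for illustration only): with 12
    # full-attention layers, 2 KV heads, head_dim 256 and 8K context, a decode
    # step reads ~56.6 MB of NVFP4 K/V versus ~100.7 MB in FP8.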
phases.append(OptimizationPhase(
phase_number=6,
name=f"ATTENTION: NVFP4 KV Cache Decode for {n_full_layers} Full Attention Layers",
prerequisite="Phase 5 complete",
what_to_build=(
f"nvfp4_attention_decode.cu — Only {n_full_layers} of {params['num_hidden_layers']} layers "
"use full attention with KV cache. "
f"GQA ratio: {params['num_attention_heads']}:{params['num_key_value_heads']} "
f"({params['num_attention_heads'] // params['num_key_value_heads']}:1). "
f"Head dim: {params['head_dim']} (NOTE: existing kernel hardcodes 128 — needs rewrite). "
"THREE PATHS (benchmark all): "
"(A) mxf8f6f4.block_scale: Q=FP8×K=MX_FP4, 2x Hopper throughput, "
"needs KV cache repacking NVFP4→MX (UE4M3→UE8M0, 16→32 group). "
"(B) mxf4nvf4.block_scale: Q=FP4×K=NVFP4, 4x Hopper throughput (FASTEST), "
"NVFP4 KV cache used directly, but Q loses precision (FP16→FP4). TN only. "
"(C) Software LUT dequant: Q stays FP16, K/V dequanted via LUT to float. "
"No tensor core for attention — pure bandwidth play. May win at M=1. "
f"At 8K context: NVFP4 KV cache reads = {kv_bytes_8k / 1e6:.1f} MB "
f"(vs {n_full_layers * 2 * params['num_key_value_heads'] * params['head_dim'] * 8192 / 1e6:.1f} MB for FP8). "
"BLOCKERS: Must write missing .cuh headers (Q9), support head_dim=256 (Q10), "
"add GQA config for 16Q/2KV/256dim. Use FlashDecoding split-K across GQA groups."
),
expected_time_saved_us=0,
expected_tps_after=0, # Can't predict without attention timing
measurement_command="bench_attention_decode: measure all 3 paths per-layer at 8K context",
done_when="Attention adds < 20% overhead to GEMV-only forward time at 8K context",
))
# Phase 7: Linear attention state update
# CONFIRMED CUSTOM-ONLY: Research pipeline (864 LLM calls, 9-consensus) verified
# NO cuDNN primitive for Mamba selective scan. 6/9 extractions hallucinated APIs
# (CUDNN_BACKEND_OPERATION_SEQSCAN_DESCRIPTOR, CUDNN_OP_SCAN — none exist).
# Linear attention O(1) per-token: all shipped kernels are O(N²). No library path.
n_linear_layers = len([t for t in params["layer_types"] if t == "linear_attention"])
phases.append(OptimizationPhase(
phase_number=7,
name=f"LINEAR ATTENTION: DeltaNet State Update for {n_linear_layers} Layers",
prerequisite="Phase 6 complete",
what_to_build=(
f"deltanet_state_update.cu — {n_linear_layers} layers use linear attention "
"(DeltaNet/Mamba hybrid). No KV cache — instead, maintain a recurrent state. "
"State update is O(1) per token (no sequence length dependency). "
"This is an element-wise operation on the state matrix, not a GEMV. "
"CONFIRMED: No cuDNN/cuBLAS primitive exists for selective scan or "
"linear attention state update. Fully custom CUDA required. "
"Conv1D causal convolution also needed (kernel_dim="
f"{params['linear_conv_kernel_dim']}). Conv is short — fuse with state update."
),
expected_time_saved_us=0,
expected_tps_after=0,
measurement_command="bench_deltanet_state: measure state update time per layer",
done_when="State update + conv1d < 5μs per layer",
))
return phases