From 3dc5c962f2c727f055eb51c21126ed26314ca215 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Wed, 8 Oct 2025 05:53:09 -0700 Subject: [PATCH 01/29] add simulated nvfp4 gemv example. --- problems/nvidia/gemv/reference.py | 118 ++++++++++++++ problems/nvidia/gemv/submission.py | 246 +++++++++++++++++++++++++++++ problems/nvidia/gemv/task.py | 10 ++ problems/nvidia/gemv/task.yml | 53 +++++++ problems/nvidia/gemv/template.py | 23 +++ 5 files changed, 450 insertions(+) create mode 100644 problems/nvidia/gemv/reference.py create mode 100644 problems/nvidia/gemv/submission.py create mode 100644 problems/nvidia/gemv/task.py create mode 100644 problems/nvidia/gemv/task.yml create mode 100644 problems/nvidia/gemv/template.py diff --git a/problems/nvidia/gemv/reference.py b/problems/nvidia/gemv/reference.py new file mode 100644 index 0000000..3a3be55 --- /dev/null +++ b/problems/nvidia/gemv/reference.py @@ -0,0 +1,118 @@ +import torch +from task import input_t, output_t +from utils import make_match_reference + +def ref_kernel( + data: input_t, +)->output_t: + """ + Highly inefficient torch reference implementation of a FFAM simulated NVFP4 block-scaled GEMV. + + a: [l, m, k] matrix + b: [l, 1, k] vector + scale_a: [l, m, k//16] blockwise scales for a + scale_b: [l, 1, k//16] blockwise scales for b + c: [l, m, 1] output + + Block size is 16 along the k dimension. + """ + a, b, scale_a, scale_b, c = data + + # Make contiguous for efficiency + a = a.contiguous() + b = b.contiguous() + + # Get dimensions + l, m, k = a.shape + block_size = 16 + scale_k = k // block_size + + #reshape makes memory contiguous + scale_a = ( + scale_a.permute(2, 0, 1) + .unsqueeze(-1) + .expand(l, m, scale_k, block_size) + .reshape(l, m, scale_k * block_size) + .permute(1, 2, 0) + ) + scale_a = scale_a[:, :k, :] + scale_b = ( + scale_b.permute(2, 0, 1) + .unsqueeze(-1) + .expand(l, 1, scale_k, block_size) + .reshape(l, 1, scale_k * block_size) + .permute(1, 2, 0) + ) + scale_b = scale_b[:, :k, :] + # scale_a = scale_a.contiguous() + # scale_b = scale_b.contiguous() + + + # # Apply blockwise scaling to input 'a' + # # scale_a shape: [l, m, scale_k] -> expand to [l, m, k] + # a_scale_expanded = scale_a.unsqueeze(-1).repeat(1, 1, 1, block_size) # Shape: [l, m, scale_k, block_size] + # a_scale_expanded = a_scale_expanded.reshape(l, m, scale_k * block_size) + # a_scale_expanded = a_scale_expanded[:, :, :k] # Handle case where k is not exactly divisible + + # Dequantize 'a' by applying scales, convert to float32 for computation + a_scaled = a.to(torch.float32) * scale_a + b_scaled = b.to(torch.float32) * scale_b + + # # Apply blockwise scaling to input 'b' + # # scale_b shape: [l, 1, scale_k] -> expand to [l, 1, k] + # b_scale_expanded = scale_b.unsqueeze(-1).repeat(1, 1, 1, block_size) # Shape: [l, 1, scale_k, block_size] + # b_scale_expanded = b_scale_expanded.reshape(l, 1, scale_k * block_size) + # b_scale_expanded = b_scale_expanded[:, :, :k] # Handle case where k is not exactly divisible + + # # Dequantize 'b' by applying scales, convert to float32 for computation + # b_scaled = b.to(torch.float32) * b_scale_expanded.to(torch.float32) + + # Compute GEMV using batched matmul: a_scaled [l, m, k] @ b_scaled [l, 1, k] -> [l, m, 1] + # For each batch l: a[i, :, :] @ b[i, 0, :].T + result = torch.zeros((l, m, 1), dtype=torch.float32, device=a.device) + for i in range(l): + result[i, :, 0] = (a_scaled[i, :, :] @ b_scaled[i, 0, :]).to(c.dtype) + c[...] = result.to(c.dtype) + + return c + +def generate_input( + m: int, + k: int, + l: int, + seed: int, +): + torch.manual_seed(seed) + block_size = 16 + scale_k = k // block_size + + # Create fp4 input a, b tensors with LxMxK layout + # torch.float4e2m1fn is not a standard torch dtype; use torch.uint8 as a placeholder for fp4 + a = torch.arange(l * m * k, dtype=torch.float32, device="cuda").reshape(m, k, l).to(torch.uint8) + b = torch.arange(l * 1 * k, dtype=torch.float32, device="cuda").reshape(1, k, l).to(torch.uint8) + # Create fp16 output tensor with LxMx1 layout + c = torch.arange(l * m * 1, dtype=torch.float32, device="cuda").reshape(m, 1, l).to(torch.float16) + + # Create scales factor with f32 data type + def ceil_div(a, b): + return (a + b - 1) // b + + # every 16 k elements share the same scale factor + # Set the block size for blockwise scaling + block_size = 16 + # Compute the number of scale factors needed along k (ceil division) + scale_k = ceil_div(k, block_size) + # Define the shape for scale_a: [l, m, scale_k] + scale_a_shape = (l, m, scale_k) + # Define the shape for scale_b: [l, 1, scale_k] + scale_b_shape = (l, 1, scale_k) + # Permute order to match expected layout: (m, scale_k, l) + scale_permute_order = (1, 2, 0) + # Generate random scale factors for a, then permute to (m, scale_k, l) + scale_a_f32 = torch.randint(1, 3, scale_a_shape, dtype=torch.float32, device="cuda").permute(scale_permute_order) + # Generate random scale factors for b, then permute to (1, scale_k, l) + scale_b_f32 = torch.randint(1, 3, scale_b_shape, dtype=torch.float32, device="cuda").permute(scale_permute_order) + + return (a, b, scale_a_f32, scale_b_f32, c) + +check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) \ No newline at end of file diff --git a/problems/nvidia/gemv/submission.py b/problems/nvidia/gemv/submission.py new file mode 100644 index 0000000..5269517 --- /dev/null +++ b/problems/nvidia/gemv/submission.py @@ -0,0 +1,246 @@ +import cuda.bindings.driver as cuda + +import torch +from task import input_t, output_t +from typing import Tuple + +import cutlass +import cutlass.cute as cute +import cutlass.torch as cutlass_torch +from cutlass.cute.runtime import from_dlpack +import cutlass.utils.blockscaled_layout as blockscaled_utils + +mma_tiler_mnk = (128, 1, 64) +ab_dtype = cutlass.Float4E2M1FN +sf_dtype = cutlass.Float8E4M3FN +c_dtype = cutlass.Float16 +block_size = 16 +threads_per_cta = 128 + +# Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout +@cute.jit +def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + sf_ref_tensor: cute.Tensor, + sf_mma_tensor: cute.Tensor, +): + # sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l) + # group to ((32, 4, rest_m), (4, rest_k), l) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) + for i in cutlass.range(cute.size(sf_ref_tensor)): + mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) + sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] + +# GPU device kernel +# Using FFMA to simulate NVFP4 block-scaled GEMV computation +@cute.kernel +def kernel( + mA_mkl: cute.Tensor, + mB_nkl: cute.Tensor, + mSFA_mkl: cute.Tensor, + mSFB_nkl: cute.Tensor, + mC_mnl: cute.Tensor, +): + bidx, bidy, bidz = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + # mma_coord_mnk = (bidx, bidy, bidz) + + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # (bM, bK, RestM, RestK, RestL) + # bM = (32, 4) + # bK = (16, 4) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) + ) + + tCgC = gC_mnl[tidx, None, bidx, bidy, bidz] + tCgC = cute.make_tensor(tCgC.iterator, 1) + res = cute.zeros_like(tCgC, cutlass.Float32) + + k_tile_cnt = gA_mkl.layout[3].shape + for k_tile in range(k_tile_cnt): + tAgA = gA_mkl[tidx, None, bidx, k_tile, bidz] + tBgB = gB_nkl[None, None, bidy, k_tile, bidz] + tAgSFA = gSFA_mkl[tidx, None, bidx, k_tile, bidz] + tBgSFB = gSFB_nkl[None, None, bidy, k_tile, bidz] + + # Load A/B/SFA/SFB tile from global memory + a_val_nvfp4 = tAgA.load() + b_val_nvfp4 = tBgB.load() + sfa_val_fp8 = tAgSFA.load() + sfb_val_fp8 = tBgSFB.load() + + # Convert to f32 for FFMA computation + a_val = a_val_nvfp4.to(cutlass.Float32) + b_val = b_val_nvfp4.to(cutlass.Float32) + sfa_val = sfa_val_fp8.to(cutlass.Float32) + sfb_val = sfb_val_fp8.to(cutlass.Float32) + + for i in cutlass.range_constexpr(mma_tiler_mnk[2] // block_size): + for j in cutlass.range_constexpr(block_size): + res += ( + a_val[i * block_size + j] + * sfa_val[i] + * b_val[i * block_size + j] + * sfb_val[i] + ) + tCgC.store(res.to(cutlass.Float16)) + return + +@cute.jit +def my_kernel( + a_tensor: cute.Tensor, + b_tensor: cute.Tensor, + sfa_tensor: cute.Tensor, + sfb_tensor: cute.Tensor, + c_tensor: cute.Tensor): + # (((32, 4), REST_M), ((SF_K, 4), REST_K), (1, REST_L)) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + a_tensor.shape, block_size + ) + sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout) + # (((32, 4), REST_M), ((SF_K, 4), REST_K), (1, REST_L)) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + b_tensor.shape, block_size + ) + sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout) + # Compute grid size + grid = ( + cute.ceil_div(c_tensor.shape[0], 128), + 1, + c_tensor.shape[2], + ) + # Launch the kernel synchronously + kernel(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor).launch( + grid=grid, + block=[threads_per_cta, 1, 1], + cluster=(1, 1, 1) + ) + +def ceil_div(a, b): + return (a + b - 1) // b + +def create_scale_factor_tensor(l, mn, sf_k, ref, dtype): + + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, # batch size + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + + mma_permute_order = (3, 4, 1, 5, 2, 0) + + # Create f32 cute torch tensor (cpu) + cute_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + mma_shape, + torch.float32, + permute_order=mma_permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=0, + max_val=1, + ), + ) + + # convert ref f32 tensor to cute f32 tensor + cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + from_dlpack(ref), + from_dlpack(cute_f32_torch_tensor_cpu), + ) + cute_f32_torch_tensor = cute_f32_torch_tensor_cpu.cuda() + + # Create dtype cute torch tensor (cpu) + cute_tensor, cute_torch_tensor = cutlass_torch.cute_tensor_like( + cute_f32_torch_tensor_cpu, + dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + + # Convert f32 cute tensor to dtype cute tensor + cute_tensor = cutlass_torch.convert_cute_tensor( + cute_f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=True, + ) + return cute_tensor, cute_torch_tensor + + +def custom_kernel(data: input_t): + """ + Execute the kernel. If not already compiled, compile it first. + + Args: + data: Tuple of (a, b, scale_a, scale_b, c) tensors + + Returns: + Output tensor c + """ + a, b, scale_a, scale_b, c = data + # Get dimensions from MxKxL layout + m = a.shape[0] + k = a.shape[1] + l = a.shape[2] + n = b.shape[0] # should be 1 for GEMV + + # Convert torch tensors to CuTe tensors + a_tensor, a_torch = cutlass_torch.cute_tensor_like( + a, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, b_torch = cutlass_torch.cute_tensor_like( + b, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, c_torch = cutlass_torch.cute_tensor_like( + c, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + # Mark tensor with element divisibility for 16B alignment + a_tensor.mark_compact_shape_dynamic( + mode=1, + stride_order=(2, 0, 1), + divisibility=32, + ) + b_tensor.mark_compact_shape_dynamic( + mode=1, + stride_order=(2, 0, 1), + divisibility=32, + ) + c_tensor.mark_compact_shape_dynamic( + 0, + (2, 1, 0), + divisibility=16, + ) + + # Create scale tensors + scale_k = ceil_div(k, block_size) + scale_a_tensor, scale_a_torch = create_scale_factor_tensor( + l, m, scale_k, scale_a, sf_dtype + ) + scale_b_tensor, scale_b_torch = create_scale_factor_tensor( + l, 1, scale_k, scale_b, sf_dtype + ) + # Run the compiled kernel + my_kernel(a_tensor, b_tensor, scale_a_tensor, scale_b_tensor, c_tensor) + return c + diff --git a/problems/nvidia/gemv/task.py b/problems/nvidia/gemv/task.py new file mode 100644 index 0000000..6005f59 --- /dev/null +++ b/problems/nvidia/gemv/task.py @@ -0,0 +1,10 @@ +import torch +from typing import TypedDict, TypeVar + +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) +output_t = TypeVar("output_t", bound=torch.Tensor) +class TestSpec(TypedDict): + m: int + k: int + l: int + seed: int \ No newline at end of file diff --git a/problems/nvidia/gemv/task.yml b/problems/nvidia/gemv/task.yml new file mode 100644 index 0000000..9583c0a --- /dev/null +++ b/problems/nvidia/gemv/task.yml @@ -0,0 +1,53 @@ +# name: nvfp4-ffma-gemv + +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + + You will implement a batched matrix-vector multiplication kernel optimized for NVIDIA B200. + To be explicit, you will be given a tuple of tensors: + ``` + (a, scale_a, b, scale_b, c) + ``` + where: + * `a` is L x M x K in row-major order in nvfp4(e2m1) + * `b` is L x 1 x K in nvfp4(e2m1) + * `scale_a` is L x M x K // 16 in row-major order in fp8(e4m3fnuz) + * `scale_b` is L x 1 x K // 16 in fp8(e4m3fnuz) + * `c` is L x M x 1 in fp16 + + Matrix sizes `M` is divisible by mma_tiler_mn[0] defined in the kernel, `K` is divisible by 64. + The computation is using FFMA instructions to simulate NVFP4 block-scaled GEMV computation and block_size is 16. + The ranking criteria is the geometric mean of the benchmark results. + For the grand price, your kernel will be evaluated against the speed of light analysis + and the solution closest to the speed of light will be awarded the grand price. + ``` + The speed of light analysis is (using 1.5Ghz clock): + M K L time[us] + 7168 16384 1 8.71 + 4096 7168 1 2.18 + 7168 2048 1 1.09 + ``` +config: + main: "eval.py" + +templates: + Python: "template.py" + +tests: + - {"m": 128, "k": 256, "l": 1, "seed": 1111} + - {"m": 2384, "k": 4608, "l": 2, "seed": 1111} + +benchmarks: + - {"m": 7168, "k": 16384, "l":1, "seed": 1111} + - {"m": 4096, "k": 7168, "l":1, "seed": 1111} + - {"m": 7168, "k": 2048, "l":1, "seed": 1111} + +ranking_by: "geom" \ No newline at end of file diff --git a/problems/nvidia/gemv/template.py b/problems/nvidia/gemv/template.py new file mode 100644 index 0000000..bdc1a22 --- /dev/null +++ b/problems/nvidia/gemv/template.py @@ -0,0 +1,23 @@ +from task import input_t, output_t + + +def custom_kernel(data: input_t) -> output_t: + """ + Reference implementation of block-scale fp8 gemm + Args: + data: Tuple that expands to: + a: torch.Tensor[float4e2m1fn] of shape [l, m, k], + b: torch.Tensor[float4e2m1fn] of shape [l, 1, k], + scale_a: torch.Tensor[float8_e4m3fnuz] of shape [l, m, k // 16], + scale_b: torch.Tensor[float8_e4m3fnuz] of shape [l, 1, k // 16], + c: torch.Tensor[float16] of shape [l, m, 1] + Returns: + Tensor containing output in float16 + c: torch.Tensor[float16] of shape [l, m, 1] + """ + # c: [l, m, 1] is pre-allocated memory to avoid timing allocation overhead. + a, b, scale_a, scale_b, c = data + + # Your implementation here + + return c \ No newline at end of file From 9b5aae3cf1d9d127ff32cb4708953b885553d84b Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Thu, 9 Oct 2025 01:32:47 -0700 Subject: [PATCH 02/29] modify the code --- problems/nvidia/gemv/reference.py | 209 ++++++++++++++----------- problems/nvidia/gemv/submission.py | 243 ++++++++++++++++++++--------- problems/nvidia/gemv/task.yml | 1 - 3 files changed, 293 insertions(+), 160 deletions(-) diff --git a/problems/nvidia/gemv/reference.py b/problems/nvidia/gemv/reference.py index 3a3be55..c430eec 100644 --- a/problems/nvidia/gemv/reference.py +++ b/problems/nvidia/gemv/reference.py @@ -2,117 +2,148 @@ from task import input_t, output_t from utils import make_match_reference +def ceil_div(a, b): + """Helper function for ceiling division""" + return (a + b - 1) // b + def ref_kernel( data: input_t, -)->output_t: +) -> output_t: """ - Highly inefficient torch reference implementation of a FFAM simulated NVFP4 block-scaled GEMV. + PyTorch reference implementation of NVFP4 block-scaled GEMV. + + This simulates the GEMV operation: C = A @ b + where A and b are block-scaled with FP4 values and FP8 scale factors. - a: [l, m, k] matrix - b: [l, 1, k] vector - scale_a: [l, m, k//16] blockwise scales for a - scale_b: [l, 1, k//16] blockwise scales for b - c: [l, m, 1] output + Tensor shapes (MxKxL layout): + a: [m, k, l] - Input matrix in FP4 + b: [1, k, l] - Input vector in FP4 + scale_a: [m, k, l] - Expanded blockwise scales for a in FP32 + scale_b: [1, k, l] - Expanded blockwise scales for b in FP32 + c: [m, 1, l] - Output vector in FP16 - Block size is 16 along the k dimension. + where: + m: number of rows in A + k: number of columns in A (must be multiple of block_size) + l: batch size + + The reference implementation follows the pattern: + res_a = einsum("mkl,mkl->mkl", a_ref, sfa_ref) + res_b = einsum("nkl,nkl->nkl", b_ref, sfb_ref) # n=1 for GEMV + ref = einsum("mkl,nkl->mnl", res_a, res_b) """ a, b, scale_a, scale_b, c = data - # Make contiguous for efficiency - a = a.contiguous() - b = b.contiguous() - - # Get dimensions - l, m, k = a.shape - block_size = 16 - scale_k = k // block_size + # Get dimensions from MxKxL layout + m, k, l = a.shape + n = 1 # GEMV: N dimension is always 1 - #reshape makes memory contiguous - scale_a = ( - scale_a.permute(2, 0, 1) - .unsqueeze(-1) - .expand(l, m, scale_k, block_size) - .reshape(l, m, scale_k * block_size) - .permute(1, 2, 0) - ) - scale_a = scale_a[:, :k, :] - scale_b = ( - scale_b.permute(2, 0, 1) - .unsqueeze(-1) - .expand(l, 1, scale_k, block_size) - .reshape(l, 1, scale_k * block_size) - .permute(1, 2, 0) - ) - scale_b = scale_b[:, :k, :] - # scale_a = scale_a.contiguous() - # scale_b = scale_b.contiguous() + # Convert to f32 for reference computation + a_ref = a.to(torch.float32) + b_ref = b.to(torch.float32) + sfa_ref = scale_a.to(torch.float32) + sfb_ref = scale_b.to(torch.float32) - - # # Apply blockwise scaling to input 'a' - # # scale_a shape: [l, m, scale_k] -> expand to [l, m, k] - # a_scale_expanded = scale_a.unsqueeze(-1).repeat(1, 1, 1, block_size) # Shape: [l, m, scale_k, block_size] - # a_scale_expanded = a_scale_expanded.reshape(l, m, scale_k * block_size) - # a_scale_expanded = a_scale_expanded[:, :, :k] # Handle case where k is not exactly divisible - - # Dequantize 'a' by applying scales, convert to float32 for computation - a_scaled = a.to(torch.float32) * scale_a - b_scaled = b.to(torch.float32) * scale_b - - # # Apply blockwise scaling to input 'b' - # # scale_b shape: [l, 1, scale_k] -> expand to [l, 1, k] - # b_scale_expanded = scale_b.unsqueeze(-1).repeat(1, 1, 1, block_size) # Shape: [l, 1, scale_k, block_size] - # b_scale_expanded = b_scale_expanded.reshape(l, 1, scale_k * block_size) - # b_scale_expanded = b_scale_expanded[:, :, :k] # Handle case where k is not exactly divisible - - # # Dequantize 'b' by applying scales, convert to float32 for computation - # b_scaled = b.to(torch.float32) * b_scale_expanded.to(torch.float32) - - # Compute GEMV using batched matmul: a_scaled [l, m, k] @ b_scaled [l, 1, k] -> [l, m, 1] - # For each batch l: a[i, :, :] @ b[i, 0, :].T - result = torch.zeros((l, m, 1), dtype=torch.float32, device=a.device) + # Apply blockwise scaling: elementwise multiplication + # This simulates NVFP4 GEMV via 2 FFMA based elementwise multiplication + # and 1 FFMA based matmul computations + res_a = a_ref * sfa_ref # [m, k, l] + res_b = b_ref * sfb_ref # [1, k, l] + + # Compute batched GEMV: C[m, n, l] = A[m, k, l] @ B[n, k, l] + # For each batch: c[:, :, i] = a[:, :, i] @ b[0, :, i].T + result = torch.zeros((m, n, l), dtype=torch.float32, device=a.device) for i in range(l): - result[i, :, 0] = (a_scaled[i, :, :] @ b_scaled[i, 0, :]).to(c.dtype) - c[...] = result.to(c.dtype) + # res_a[:, :, i] is [m, k], res_b[0, :, i] is [k] + # matmul gives [m], reshape to [m, 1] + result[:, 0, i] = res_a[:, :, i] @ res_b[0, :, i] + # Store result in output tensor + c[...] = result.to(c.dtype) return c + def generate_input( m: int, k: int, l: int, seed: int, ): + """ + Generate input tensors for NVFP4 block-scaled GEMV. + + This follows the pattern from nvfp4_gemv_cute_layout.py for tensor preparation. + + Args: + m: Number of rows in matrix A + k: Number of columns in A (and length of vector b) + l: Batch size + seed: Random seed for reproducibility + + Returns: + Tuple of (a, b, scale_a, scale_b, c) where: + a: [m, k, l] - Input matrix in FP4 (simulated with uint8) + b: [1, k, l] - Input vector in FP4 (simulated with uint8) + scale_a: [m, k, l] - Expanded scale factors for a in FP32 + scale_b: [1, k, l] - Expanded scale factors for b in FP32 + c: [m, 1, l] - Output vector in FP16 + """ torch.manual_seed(seed) block_size = 16 - scale_k = k // block_size - - # Create fp4 input a, b tensors with LxMxK layout - # torch.float4e2m1fn is not a standard torch dtype; use torch.uint8 as a placeholder for fp4 - a = torch.arange(l * m * k, dtype=torch.float32, device="cuda").reshape(m, k, l).to(torch.uint8) - b = torch.arange(l * 1 * k, dtype=torch.float32, device="cuda").reshape(1, k, l).to(torch.uint8) - # Create fp16 output tensor with LxMx1 layout - c = torch.arange(l * m * 1, dtype=torch.float32, device="cuda").reshape(m, 1, l).to(torch.float16) + n = 1 # GEMV: N dimension is always 1 + scale_k = ceil_div(k, block_size) - # Create scales factor with f32 data type - def ceil_div(a, b): - return (a + b - 1) // b + # Create input tensors A, b following the MxKxL memory layout + # This matches: cutlass_torch.matrix(l, m, k, False, cutlass.Float32) + # which creates tensors with contiguous k dimension (stride-1) - # every 16 k elements share the same scale factor - # Set the block size for blockwise scaling - block_size = 16 - # Compute the number of scale factors needed along k (ceil division) - scale_k = ceil_div(k, block_size) - # Define the shape for scale_a: [l, m, scale_k] - scale_a_shape = (l, m, scale_k) - # Define the shape for scale_b: [l, 1, scale_k] - scale_b_shape = (l, 1, scale_k) - # Permute order to match expected layout: (m, scale_k, l) - scale_permute_order = (1, 2, 0) - # Generate random scale factors for a, then permute to (m, scale_k, l) - scale_a_f32 = torch.randint(1, 3, scale_a_shape, dtype=torch.float32, device="cuda").permute(scale_permute_order) - # Generate random scale factors for b, then permute to (1, scale_k, l) - scale_b_f32 = torch.randint(1, 3, scale_b_shape, dtype=torch.float32, device="cuda").permute(scale_permute_order) - - return (a, b, scale_a_f32, scale_b_f32, c) + # Generate random FP32 values, then convert to uint8 (FP4 placeholder) + # Shape transformations: (l, m, k) -> permute to (m, k, l) for MxKxL layout + a = torch.randn(l, m, k, dtype=torch.float32, device="cuda") + a = a.permute(1, 2, 0).contiguous().to(torch.uint8) # [m, k, l] + + b = torch.randn(l, n, k, dtype=torch.float32, device="cuda") + b = b.permute(1, 2, 0).contiguous().to(torch.uint8) # [1, k, l] + + # Create output tensor C in FP16 with MxNxL layout (N=1 for GEMV) + c = torch.zeros(l, m, n, dtype=torch.float32, device="cuda") + c = c.permute(1, 2, 0).contiguous().to(torch.float16) # [m, 1, l] + + # Create scale factors with FP32 data type + # Original ref_shape is (l, mn, sf_k), then permuted to (mn, sf_k, l) + ref_shape = (l, m, scale_k) + ref_permute_order = (1, 2, 0) # Permute from LxMxScaleK to MxScaleKxL + + # Generate random scale factors in range [1, 3) for better numerical stability + scale_a_sf = torch.randint(1, 3, ref_shape, dtype=torch.float32, device="cuda") # [1, 3) + scale_a_sf = scale_a_sf.permute(ref_permute_order).contiguous() # [m, scale_k, l] + + ref_shape_b = (l, n, scale_k) + scale_b_sf = torch.randint(1, 3, ref_shape_b, dtype=torch.float32, device="cuda") # [1, 3) + scale_b_sf = scale_b_sf.permute(ref_permute_order).contiguous() # [n, scale_k, l] + + # Expand scale factors from [m, scale_k, l] to [m, k, l] + # This matches the expansion done in nvfp4_gemv_cute_layout.py lines 320-328 + # The pattern: permute -> unsqueeze -> expand -> reshape -> permute -> prune + scale_a_expanded = ( + scale_a_sf.permute(2, 0, 1) # [l, m, scale_k] + .unsqueeze(-1) # [l, m, scale_k, 1] + .expand(l, m, scale_k, block_size) # [l, m, scale_k, block_size] + .reshape(l, m, scale_k * block_size) # [l, m, k] + .permute(*ref_permute_order) # [m, k, l] + ) + scale_a_expanded = scale_a_expanded[:, :k, :] # Prune to exact k + + scale_b_expanded = ( + scale_b_sf.permute(2, 0, 1) # [l, n, scale_k] + .unsqueeze(-1) # [l, n, scale_k, 1] + .expand(l, n, scale_k, block_size) # [l, n, scale_k, block_size] + .reshape(l, n, scale_k * block_size) # [l, n, k] + .permute(*ref_permute_order) # [n, k, l] + ) + scale_b_expanded = scale_b_expanded[:, :k, :] # Prune to exact k + + return (a, b, scale_a_expanded, scale_b_expanded, c) + -check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) \ No newline at end of file +check_implementation = make_match_reference(ref_kernel, rtol=1e-02, atol=1e-01) diff --git a/problems/nvidia/gemv/submission.py b/problems/nvidia/gemv/submission.py index 5269517..ba43953 100644 --- a/problems/nvidia/gemv/submission.py +++ b/problems/nvidia/gemv/submission.py @@ -2,7 +2,6 @@ import torch from task import input_t, output_t -from typing import Tuple import cutlass import cutlass.cute as cute @@ -10,29 +9,42 @@ from cutlass.cute.runtime import from_dlpack import cutlass.utils.blockscaled_layout as blockscaled_utils -mma_tiler_mnk = (128, 1, 64) -ab_dtype = cutlass.Float4E2M1FN -sf_dtype = cutlass.Float8E4M3FN -c_dtype = cutlass.Float16 -block_size = 16 -threads_per_cta = 128 +# Kernel configuration parameters +mma_tiler_mnk = (128, 1, 64) # Tile sizes for M, N, K dimensions +ab_dtype = cutlass.Float4E2M1FN # FP4 data type for A and B +sf_dtype = cutlass.Float8E8M0FNU # FP8 data type for scale factors +c_dtype = cutlass.Float16 # FP16 output type +block_size = 16 # Scale factor block size (16 elements share one scale) +threads_per_cta = 128 # Number of threads per CUDA thread block + +def ceil_div(a, b): + """Helper function for ceiling division""" + return (a + b - 1) // b + -# Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout @cute.jit def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( sf_ref_tensor: cute.Tensor, sf_mma_tensor: cute.Tensor, ): - # sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l) - # group to ((32, 4, rest_m), (4, rest_k), l) + """ + Convert scale factor tensor from reference MxKxL layout to MMA layout. + + This follows the cuBLAS block-scaling factors layout specification: + https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout + + """ + # sf_mma_tensor has flattened shape (32, 4, rest_m, 4, rest_k, l) + # Group modes to ((32, 4, rest_m), (4, rest_k), l) for hierarchical indexing sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) + + # Copy data from reference layout to MMA layout for i in cutlass.range(cute.size(sf_ref_tensor)): mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] -# GPU device kernel -# Using FFMA to simulate NVFP4 block-scaled GEMV computation + @cute.kernel def kernel( mA_mkl: cute.Tensor, @@ -41,39 +53,62 @@ def kernel( mSFB_nkl: cute.Tensor, mC_mnl: cute.Tensor, ): + """ + GPU device kernel for NVFP4 block-scaled GEMV computation. + + This kernel simulates NVFP4 computation using FFMA (Fused Multiply-Add). + Each thread computes one element of the output vector. + + Args: + mA_mkl: Input matrix A in MxKxL layout (FP4) + mB_nkl: Input vector b in NxKxL layout where N=1 (FP4) + mSFA_mkl: Scale factors for A (FP8) + mSFB_nkl: Scale factors for b (FP8) + mC_mnl: Output vector c in MxNxL layout where N=1 (FP16) + """ + # Get thread and block indices bidx, bidy, bidz = cute.arch.block_idx() tidx, _, _ = cute.arch.thread_idx() - # mma_coord_mnk = (bidx, bidy, bidz) - # (bM, bK, RestM, RestK, RestL) + # Tile input tensors according to MMA configuration + # Each tile processes mma_tiler_mnk elements + + # Tile A: shape becomes (bM, bK, RestM, RestK, RestL) + # bM = (32, 4) for Tensor Core, bK = (16, 4) for scale blocks gA_mkl = cute.local_tile( mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) ) - # (bM, bK, RestM, RestK, RestL) - # bM = (32, 4) - # bK = (16, 4) + + # Tile scale factors for A with same pattern gSFA_mkl = cute.local_tile( mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) ) - # (bN, bK, RestN, RestK, RestL) + + # Tile B: shape becomes (bN, bK, RestN, RestK, RestL) where N=1 gB_nkl = cute.local_tile( mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) ) - # (bN, bK, RestN, RestK, RestL) + + # Tile scale factors for B with same pattern gSFB_nkl = cute.local_tile( mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) ) - # (bM, bN, RestM, RestN, RestL) + + # Tile output C: shape becomes (bM, bN, RestM, RestN, RestL) gC_mnl = cute.local_tile( mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) ) + # Each thread computes one output element + # Index into output tile using thread ID and block coordinates tCgC = gC_mnl[tidx, None, bidx, bidy, bidz] tCgC = cute.make_tensor(tCgC.iterator, 1) res = cute.zeros_like(tCgC, cutlass.Float32) + # Create Tensors for tile views with proper shapes k_tile_cnt = gA_mkl.layout[3].shape for k_tile in range(k_tile_cnt): + # Extract this thread's slice of A, B, and scale factors for current K tile tAgA = gA_mkl[tidx, None, bidx, k_tile, bidz] tBgB = gB_nkl[None, None, bidy, k_tile, bidz] tAgSFA = gSFA_mkl[tidx, None, bidx, k_tile, bidz] @@ -85,57 +120,106 @@ def kernel( sfa_val_fp8 = tAgSFA.load() sfb_val_fp8 = tBgSFB.load() - # Convert to f32 for FFMA computation + # Convert to FP32 for FFMA computation a_val = a_val_nvfp4.to(cutlass.Float32) b_val = b_val_nvfp4.to(cutlass.Float32) sfa_val = sfa_val_fp8.to(cutlass.Float32) sfb_val = sfb_val_fp8.to(cutlass.Float32) + # Compute block-scaled dot product + # Each scale factor applies to block_size consecutive elements for i in cutlass.range_constexpr(mma_tiler_mnk[2] // block_size): for j in cutlass.range_constexpr(block_size): + # Accumulate: res += (a * scale_a) * (b * scale_b) res += ( a_val[i * block_size + j] * sfa_val[i] * b_val[i * block_size + j] * sfb_val[i] ) + + # Store result back to global memory in FP16 tCgC.store(res.to(cutlass.Float16)) return + @cute.jit def my_kernel( a_tensor: cute.Tensor, - b_tensor: cute.Tensor, + b_tensor: cute.Tensor, sfa_tensor: cute.Tensor, sfb_tensor: cute.Tensor, - c_tensor: cute.Tensor): - # (((32, 4), REST_M), ((SF_K, 4), REST_K), (1, REST_L)) + c_tensor: cute.Tensor, +): + """ + Host-side JIT function to prepare tensors and launch GPU kernel. + + This function: + 1. Converts scale factor tensors to the correct MMA layout + 2. Computes grid dimensions based on tensor shapes + 3. Launches the CUDA kernel + + Args: + a_tensor: Input matrix A (CuTe tensor) + b_tensor: Input vector b (CuTe tensor) + sfa_tensor: Scale factors for A (CuTe tensor) + sfb_tensor: Scale factors for B (CuTe tensor) + c_tensor: Output vector c (CuTe tensor) + """ + # Convert scale factor tensors to MMA layout + # The layout matches Tensor Core requirements: (((32, 4), REST_M), ((SF_K, 4), REST_K), (1, REST_L)) sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( a_tensor.shape, block_size ) sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout) - # (((32, 4), REST_M), ((SF_K, 4), REST_K), (1, REST_L)) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( b_tensor.shape, block_size ) sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout) - # Compute grid size + + # Compute grid dimensions + # Grid is (M_blocks, 1, L) where: + # - M_blocks = ceil(M / 128) to cover all output rows + # - N=1 for GEMV (middle dimension) + # - L = batch size grid = ( cute.ceil_div(c_tensor.shape[0], 128), 1, c_tensor.shape[2], ) - # Launch the kernel synchronously + + # Launch the CUDA kernel kernel(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor).launch( grid=grid, block=[threads_per_cta, 1, 1], - cluster=(1, 1, 1) + cluster=(1, 1, 1), ) + return -def ceil_div(a, b): - return (a + b - 1) // b -def create_scale_factor_tensor(l, mn, sf_k, ref, dtype): +def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype, ref_sf): + """ + Create scale factor tensor in MMA layout for CuTe kernel. + + This function converts a reference scale tensor from MxScaleKxL layout + to the MMA-compatible layout required by Tensor Cores. + + Args: + l: Batch size + mn: M or N dimension (number of rows for A, 1 for b) + k: K dimension (full, not scale_k) + sf_vec_size: Scale factor block size (16) + dtype: Target CuTe data type (e.g., cutlass.Float8E8M0FNU) + ref_sf: Reference scale tensor in [mn, scale_k, l] layout (PyTorch tensor) + + Returns: + Tuple of (ref_expanded, cute_tensor, cute_torch_tensor): + - ref_expanded: Expanded reference tensor in [mn, k, l] layout for CPU validation + - cute_tensor: CuTe tensor with MMA layout + - cute_torch_tensor: Underlying PyTorch tensor + """ + sf_k = ceil_div(k, sf_vec_size) atom_m = (32, 4) atom_k = 4 @@ -147,29 +231,27 @@ def create_scale_factor_tensor(l, mn, sf_k, ref, dtype): atom_m[1], atom_k, ) - mma_permute_order = (3, 4, 1, 5, 2, 0) - # Create f32 cute torch tensor (cpu) - cute_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( - mma_shape, - torch.float32, - permute_order=mma_permute_order, - init_type=cutlass_torch.TensorInitType.RANDOM, - init_config=cutlass_torch.RandomInitConfig( - min_val=0, - max_val=1, - ), - ) + # Move reference scale factors to CPU if needed + ref_sf_cpu = ref_sf.cpu() if ref_sf.is_cuda else ref_sf + # Reshape to ref_shape format: [mn, scale_k, l] -> [l, mn, scale_k] + ref_f32_torch_tensor_cpu = ref_sf_cpu.permute(2, 0, 1).contiguous() + + # Create f32 MMA tensor on CPU using PyTorch + cute_f32_torch_tensor_cpu = torch.randint(0, 1, mma_shape, dtype=torch.float32) + # Permute to MMA layout + cute_f32_torch_tensor_cpu = cute_f32_torch_tensor_cpu.permute(mma_permute_order).contiguous() - # convert ref f32 tensor to cute f32 tensor + # Convert reference f32 tensor to CuTe f32 tensor using layout conversion cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - from_dlpack(ref), + from_dlpack(ref_f32_torch_tensor_cpu), from_dlpack(cute_f32_torch_tensor_cpu), ) + # Move to GPU cute_f32_torch_tensor = cute_f32_torch_tensor_cpu.cuda() - # Create dtype cute torch tensor (cpu) + # Create CuTe tensor with target dtype (FP8) cute_tensor, cute_torch_tensor = cutlass_torch.cute_tensor_like( cute_f32_torch_tensor_cpu, dtype, @@ -177,7 +259,7 @@ def create_scale_factor_tensor(l, mn, sf_k, ref, dtype): assumed_align=16, ) - # Convert f32 cute tensor to dtype cute tensor + # Convert f32 CuTe tensor to target dtype CuTe tensor cute_tensor = cutlass_torch.convert_cute_tensor( cute_f32_torch_tensor, cute_tensor, @@ -187,35 +269,52 @@ def create_scale_factor_tensor(l, mn, sf_k, ref, dtype): return cute_tensor, cute_torch_tensor -def custom_kernel(data: input_t): +def custom_kernel(data: input_t) -> output_t: """ - Execute the kernel. If not already compiled, compile it first. + Execute the block-scaled GEMV kernel. + + This is the main entry point called by the evaluation framework. + It converts PyTorch tensors to CuTe tensors, launches the kernel, + and returns the result. Args: - data: Tuple of (a, b, scale_a, scale_b, c) tensors + data: Tuple of (a, b, scale_a, scale_b, c) PyTorch tensors + a: [m, k, l] - Input matrix in FP4 (simulated with uint8) + b: [1, k, l] - Input vector in FP4 (simulated with uint8) + scale_a: [m, k, l] - Expanded scale factors for a in FP32 + scale_b: [1, k, l] - Expanded scale factors for b in FP32 + c: [m, 1, l] - Output vector in FP16 Returns: - Output tensor c + Output tensor c with computed GEMV results """ a, b, scale_a, scale_b, c = data - # Get dimensions from MxKxL layout - m = a.shape[0] - k = a.shape[1] - l = a.shape[2] - n = b.shape[0] # should be 1 for GEMV - # Convert torch tensors to CuTe tensors + # Get dimensions from MxKxL layout + m, k, l = a.shape + n = 1 # GEMV: N dimension is always 1 + scale_k = ceil_div(k, block_size) + + # GEMV, N must be 1 + assert n == 1, "GEMV requires N=1" + + # Create reference tensors in LxMxK layout (for CuTe compatibility) + a_ref = a.to(torch.float32).permute(2, 0, 1).contiguous() # [l, m, k] + b_ref = b.to(torch.float32).permute(2, 0, 1).contiguous() # [l, 1, k] + c_ref = torch.zeros(l, m, n, dtype=torch.float32, device=a.device) # [l, m, 1] + + # Create CuTe tensors for A, B, C a_tensor, a_torch = cutlass_torch.cute_tensor_like( - a, ab_dtype, is_dynamic_layout=True, assumed_align=16 + a_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 ) b_tensor, b_torch = cutlass_torch.cute_tensor_like( - b, ab_dtype, is_dynamic_layout=True, assumed_align=16 + b_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 ) c_tensor, c_torch = cutlass_torch.cute_tensor_like( - c, c_dtype, is_dynamic_layout=True, assumed_align=16 + c_ref, c_dtype, is_dynamic_layout=True, assumed_align=16 ) - # Mark tensor with element divisibility for 16B alignment + # Mark tensors with element divisibility for 16B alignment a_tensor.mark_compact_shape_dynamic( mode=1, stride_order=(2, 0, 1), @@ -231,16 +330,20 @@ def custom_kernel(data: input_t): (2, 1, 0), divisibility=16, ) - - # Create scale tensors - scale_k = ceil_div(k, block_size) - scale_a_tensor, scale_a_torch = create_scale_factor_tensor( - l, m, scale_k, scale_a, sf_dtype + + # Extract compact scale factors from expanded scales + # scale_a and scale_b are [m/n, k, l], we need [m/n, scale_k, l] + # Take every block_size-th element along k dimension + scale_a_compact = scale_a[:, ::block_size, :].contiguous() # [m, scale_k, l] + scale_b_compact = scale_b[:, ::block_size, :].contiguous() # [1, scale_k, l] + + # Create scale factor tensors in MMA layout + sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor( + l, m, k, block_size, sf_dtype, scale_a_compact ) - scale_b_tensor, scale_b_torch = create_scale_factor_tensor( - l, 1, scale_k, scale_b, sf_dtype + sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor( + l, n, k, block_size, sf_dtype, scale_b_compact ) # Run the compiled kernel - my_kernel(a_tensor, b_tensor, scale_a_tensor, scale_b_tensor, c_tensor) + my_kernel(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor) return c - diff --git a/problems/nvidia/gemv/task.yml b/problems/nvidia/gemv/task.yml index 9583c0a..7f59454 100644 --- a/problems/nvidia/gemv/task.yml +++ b/problems/nvidia/gemv/task.yml @@ -43,7 +43,6 @@ templates: tests: - {"m": 128, "k": 256, "l": 1, "seed": 1111} - - {"m": 2384, "k": 4608, "l": 2, "seed": 1111} benchmarks: - {"m": 7168, "k": 16384, "l":1, "seed": 1111} From 5726cbc968dc08f8a2c8fecedd220837caf754e2 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Thu, 9 Oct 2025 22:36:03 -0700 Subject: [PATCH 03/29] fix function failure for fp4 simulated gemv --- problems/nvidia/gemv/reference.py | 137 ++++++++++---------------- problems/nvidia/gemv/submission.py | 153 +++++++++-------------------- problems/nvidia/gemv/task.yml | 10 ++ 3 files changed, 109 insertions(+), 191 deletions(-) diff --git a/problems/nvidia/gemv/reference.py b/problems/nvidia/gemv/reference.py index c430eec..2773531 100644 --- a/problems/nvidia/gemv/reference.py +++ b/problems/nvidia/gemv/reference.py @@ -2,6 +2,8 @@ from task import input_t, output_t from utils import make_match_reference +block_size = 16 + def ceil_div(a, b): """Helper function for ceiling division""" return (a + b - 1) // b @@ -14,23 +16,6 @@ def ref_kernel( This simulates the GEMV operation: C = A @ b where A and b are block-scaled with FP4 values and FP8 scale factors. - - Tensor shapes (MxKxL layout): - a: [m, k, l] - Input matrix in FP4 - b: [1, k, l] - Input vector in FP4 - scale_a: [m, k, l] - Expanded blockwise scales for a in FP32 - scale_b: [1, k, l] - Expanded blockwise scales for b in FP32 - c: [m, 1, l] - Output vector in FP16 - - where: - m: number of rows in A - k: number of columns in A (must be multiple of block_size) - l: batch size - - The reference implementation follows the pattern: - res_a = einsum("mkl,mkl->mkl", a_ref, sfa_ref) - res_b = einsum("nkl,nkl->nkl", b_ref, sfb_ref) # n=1 for GEMV - ref = einsum("mkl,nkl->mnl", res_a, res_b) """ a, b, scale_a, scale_b, c = data @@ -38,31 +23,60 @@ def ref_kernel( m, k, l = a.shape n = 1 # GEMV: N dimension is always 1 + scale_k = ceil_div(k, block_size) + + # Extend scale factor tensor from [m, scale_k, l] to [m, k, l] + ref_permute_order = (1, 2, 0) + scale_a = ( + scale_a.permute(2, 0, 1) + .unsqueeze(-1) + .expand(l, m, scale_k, block_size) + .reshape(l, m, scale_k * block_size) + .permute(*ref_permute_order) + ) + # prune to mkl for reference check. + scale_a = scale_a[:, :k, :] + + scale_b = ( + scale_b.permute(2, 0, 1) + .unsqueeze(-1) + .expand(l, n, scale_k, block_size) + .reshape(l, n, scale_k * block_size) + .permute(*ref_permute_order) + ) + # prune to mkl for reference check. + scale_b = scale_b[:, :k, :] + # Convert to f32 for reference computation - a_ref = a.to(torch.float32) - b_ref = b.to(torch.float32) - sfa_ref = scale_a.to(torch.float32) - sfb_ref = scale_b.to(torch.float32) - # Apply blockwise scaling: elementwise multiplication # This simulates NVFP4 GEMV via 2 FFMA based elementwise multiplication # and 1 FFMA based matmul computations - res_a = a_ref * sfa_ref # [m, k, l] - res_b = b_ref * sfb_ref # [1, k, l] + res_a = a.to(torch.float32) * scale_a.cuda() # [m, k, l] + res_b = b.to(torch.float32) * scale_b.cuda() # [1, k, l] # Compute batched GEMV: C[m, n, l] = A[m, k, l] @ B[n, k, l] - # For each batch: c[:, :, i] = a[:, :, i] @ b[0, :, i].T - result = torch.zeros((m, n, l), dtype=torch.float32, device=a.device) - for i in range(l): - # res_a[:, :, i] is [m, k], res_b[0, :, i] is [k] - # matmul gives [m], reshape to [m, 1] - result[:, 0, i] = res_a[:, :, i] @ res_b[0, :, i] - - # Store result in output tensor - c[...] = result.to(c.dtype) + for i in range(c.shape[2]): + # matmul gives [m], convert to c.dtype then assign to [m, 1] + acc = res_a[:, :, i] @ res_b[0, :, i] + c[:, 0, i] = acc.to(torch.float16) return c +# Helper function to create reference scale factor tensor SFA/SFB +# for 1x16 block scaled wise use case and follow the layout requirement +# defined in https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout +def create_scale_factor_tensor(l, mn, k, block_size): + scale_k = ceil_div(k, block_size) + ref_shape = (l, mn, scale_k) + ref_permute_order = (1, 2, 0) + + # Create f32 ref torch tensor (cpu) + # After this line, ref_f32_torch_tensor_cpu has shape (mn, scale_k, l) + ref_f32_torch_tensor_cpu = torch.randint( + 1, 3, ref_shape, dtype=torch.float32 + ).permute(*ref_permute_order) + return ref_f32_torch_tensor_cpu + def generate_input( m: int, k: int, @@ -89,61 +103,18 @@ def generate_input( c: [m, 1, l] - Output vector in FP16 """ torch.manual_seed(seed) - block_size = 16 n = 1 # GEMV: N dimension is always 1 - scale_k = ceil_div(k, block_size) - - # Create input tensors A, b following the MxKxL memory layout - # This matches: cutlass_torch.matrix(l, m, k, False, cutlass.Float32) - # which creates tensors with contiguous k dimension (stride-1) # Generate random FP32 values, then convert to uint8 (FP4 placeholder) - # Shape transformations: (l, m, k) -> permute to (m, k, l) for MxKxL layout - a = torch.randn(l, m, k, dtype=torch.float32, device="cuda") - a = a.permute(1, 2, 0).contiguous().to(torch.uint8) # [m, k, l] - - b = torch.randn(l, n, k, dtype=torch.float32, device="cuda") - b = b.permute(1, 2, 0).contiguous().to(torch.uint8) # [1, k, l] - - # Create output tensor C in FP16 with MxNxL layout (N=1 for GEMV) - c = torch.zeros(l, m, n, dtype=torch.float32, device="cuda") - c = c.permute(1, 2, 0).contiguous().to(torch.float16) # [m, 1, l] + a = torch.randint(0, 2, (l, m, k), dtype=torch.uint8, device="cuda").permute(1, 2, 0) + b = torch.randint(1, 3, (l, n, k), dtype=torch.uint8, device="cuda").permute(1, 2, 0) + c = torch.randn((l, n, m), dtype=torch.float16, device="cuda").permute(2, 1, 0) # Create scale factors with FP32 data type - # Original ref_shape is (l, mn, sf_k), then permuted to (mn, sf_k, l) - ref_shape = (l, m, scale_k) - ref_permute_order = (1, 2, 0) # Permute from LxMxScaleK to MxScaleKxL - - # Generate random scale factors in range [1, 3) for better numerical stability - scale_a_sf = torch.randint(1, 3, ref_shape, dtype=torch.float32, device="cuda") # [1, 3) - scale_a_sf = scale_a_sf.permute(ref_permute_order).contiguous() # [m, scale_k, l] - - ref_shape_b = (l, n, scale_k) - scale_b_sf = torch.randint(1, 3, ref_shape_b, dtype=torch.float32, device="cuda") # [1, 3) - scale_b_sf = scale_b_sf.permute(ref_permute_order).contiguous() # [n, scale_k, l] - - # Expand scale factors from [m, scale_k, l] to [m, k, l] - # This matches the expansion done in nvfp4_gemv_cute_layout.py lines 320-328 - # The pattern: permute -> unsqueeze -> expand -> reshape -> permute -> prune - scale_a_expanded = ( - scale_a_sf.permute(2, 0, 1) # [l, m, scale_k] - .unsqueeze(-1) # [l, m, scale_k, 1] - .expand(l, m, scale_k, block_size) # [l, m, scale_k, block_size] - .reshape(l, m, scale_k * block_size) # [l, m, k] - .permute(*ref_permute_order) # [m, k, l] - ) - scale_a_expanded = scale_a_expanded[:, :k, :] # Prune to exact k - - scale_b_expanded = ( - scale_b_sf.permute(2, 0, 1) # [l, n, scale_k] - .unsqueeze(-1) # [l, n, scale_k, 1] - .expand(l, n, scale_k, block_size) # [l, n, scale_k, block_size] - .reshape(l, n, scale_k * block_size) # [l, n, k] - .permute(*ref_permute_order) # [n, k, l] - ) - scale_b_expanded = scale_b_expanded[:, :k, :] # Prune to exact k + scale_a = create_scale_factor_tensor(l, m, k, block_size) + scale_b = create_scale_factor_tensor(l, 1, k, block_size) - return (a, b, scale_a_expanded, scale_b_expanded, c) + return (a, b, scale_a, scale_b, c) -check_implementation = make_match_reference(ref_kernel, rtol=1e-02, atol=1e-01) +check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) diff --git a/problems/nvidia/gemv/submission.py b/problems/nvidia/gemv/submission.py index ba43953..563ef34 100644 --- a/problems/nvidia/gemv/submission.py +++ b/problems/nvidia/gemv/submission.py @@ -1,3 +1,4 @@ +from torch._higher_order_ops.torchbind import call_torchbind_fake import cuda.bindings.driver as cuda import torch @@ -38,7 +39,6 @@ def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( # Group modes to ((32, 4, rest_m), (4, rest_k), l) for hierarchical indexing sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) - # Copy data from reference layout to MMA layout for i in cutlass.range(cute.size(sf_ref_tensor)): mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) @@ -53,62 +53,38 @@ def kernel( mSFB_nkl: cute.Tensor, mC_mnl: cute.Tensor, ): - """ - GPU device kernel for NVFP4 block-scaled GEMV computation. - - This kernel simulates NVFP4 computation using FFMA (Fused Multiply-Add). - Each thread computes one element of the output vector. - - Args: - mA_mkl: Input matrix A in MxKxL layout (FP4) - mB_nkl: Input vector b in NxKxL layout where N=1 (FP4) - mSFA_mkl: Scale factors for A (FP8) - mSFB_nkl: Scale factors for b (FP8) - mC_mnl: Output vector c in MxNxL layout where N=1 (FP16) - """ - # Get thread and block indices bidx, bidy, bidz = cute.arch.block_idx() tidx, _, _ = cute.arch.thread_idx() - # Tile input tensors according to MMA configuration - # Each tile processes mma_tiler_mnk elements - - # Tile A: shape becomes (bM, bK, RestM, RestK, RestL) - # bM = (32, 4) for Tensor Core, bK = (16, 4) for scale blocks + # (bM, bK, RestM, RestK, RestL) gA_mkl = cute.local_tile( mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) ) - - # Tile scale factors for A with same pattern + # (bM, bK, RestM, RestK, RestL) + # bM = (32, 4) + # bK = (16, 4) gSFA_mkl = cute.local_tile( mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) ) - - # Tile B: shape becomes (bN, bK, RestN, RestK, RestL) where N=1 + # (bN, bK, RestN, RestK, RestL) gB_nkl = cute.local_tile( mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) ) - - # Tile scale factors for B with same pattern + # (bN, bK, RestN, RestK, RestL) gSFB_nkl = cute.local_tile( mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) ) - - # Tile output C: shape becomes (bM, bN, RestM, RestN, RestL) + # (bM, bN, RestM, RestN, RestL) gC_mnl = cute.local_tile( mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) ) - # Each thread computes one output element - # Index into output tile using thread ID and block coordinates tCgC = gC_mnl[tidx, None, bidx, bidy, bidz] tCgC = cute.make_tensor(tCgC.iterator, 1) res = cute.zeros_like(tCgC, cutlass.Float32) - # Create Tensors for tile views with proper shapes k_tile_cnt = gA_mkl.layout[3].shape for k_tile in range(k_tile_cnt): - # Extract this thread's slice of A, B, and scale factors for current K tile tAgA = gA_mkl[tidx, None, bidx, k_tile, bidz] tBgB = gB_nkl[None, None, bidy, k_tile, bidz] tAgSFA = gSFA_mkl[tidx, None, bidx, k_tile, bidz] @@ -120,25 +96,20 @@ def kernel( sfa_val_fp8 = tAgSFA.load() sfb_val_fp8 = tBgSFB.load() - # Convert to FP32 for FFMA computation + # Convert to f32 for FFMA computation a_val = a_val_nvfp4.to(cutlass.Float32) b_val = b_val_nvfp4.to(cutlass.Float32) sfa_val = sfa_val_fp8.to(cutlass.Float32) sfb_val = sfb_val_fp8.to(cutlass.Float32) - # Compute block-scaled dot product - # Each scale factor applies to block_size consecutive elements for i in cutlass.range_constexpr(mma_tiler_mnk[2] // block_size): for j in cutlass.range_constexpr(block_size): - # Accumulate: res += (a * scale_a) * (b * scale_b) res += ( a_val[i * block_size + j] * sfa_val[i] * b_val[i * block_size + j] * sfb_val[i] ) - - # Store result back to global memory in FP16 tCgC.store(res.to(cutlass.Float16)) return @@ -181,7 +152,6 @@ def my_kernel( # Compute grid dimensions # Grid is (M_blocks, 1, L) where: # - M_blocks = ceil(M / 128) to cover all output rows - # - N=1 for GEMV (middle dimension) # - L = batch size grid = ( cute.ceil_div(c_tensor.shape[0], 128), @@ -195,63 +165,44 @@ def my_kernel( block=[threads_per_cta, 1, 1], cluster=(1, 1, 1), ) - return -def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype, ref_sf): - """ - Create scale factor tensor in MMA layout for CuTe kernel. - - This function converts a reference scale tensor from MxScaleKxL layout - to the MMA-compatible layout required by Tensor Cores. +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + + +# Helper function to convert reference tensor to cute tensor +def create_scale_factor_cute_tensor(ref_tensor, l, mn, k, block_size, dtype): - Args: - l: Batch size - mn: M or N dimension (number of rows for A, 1 for b) - k: K dimension (full, not scale_k) - sf_vec_size: Scale factor block size (16) - dtype: Target CuTe data type (e.g., cutlass.Float8E8M0FNU) - ref_sf: Reference scale tensor in [mn, scale_k, l] layout (PyTorch tensor) + scale_k = ceil_div(k, block_size) - Returns: - Tuple of (ref_expanded, cute_tensor, cute_torch_tensor): - - ref_expanded: Expanded reference tensor in [mn, k, l] layout for CPU validation - - cute_tensor: CuTe tensor with MMA layout - - cute_torch_tensor: Underlying PyTorch tensor - """ - sf_k = ceil_div(k, sf_vec_size) - atom_m = (32, 4) atom_k = 4 mma_shape = ( l, # batch size ceil_div(mn, atom_m[0] * atom_m[1]), - ceil_div(sf_k, atom_k), + ceil_div(scale_k, atom_k), atom_m[0], atom_m[1], atom_k, ) - mma_permute_order = (3, 4, 1, 5, 2, 0) - # Move reference scale factors to CPU if needed - ref_sf_cpu = ref_sf.cpu() if ref_sf.is_cuda else ref_sf - # Reshape to ref_shape format: [mn, scale_k, l] -> [l, mn, scale_k] - ref_f32_torch_tensor_cpu = ref_sf_cpu.permute(2, 0, 1).contiguous() + mma_permute_order = (3, 4, 1, 5, 2, 0) - # Create f32 MMA tensor on CPU using PyTorch - cute_f32_torch_tensor_cpu = torch.randint(0, 1, mma_shape, dtype=torch.float32) - # Permute to MMA layout - cute_f32_torch_tensor_cpu = cute_f32_torch_tensor_cpu.permute(mma_permute_order).contiguous() + # Create f32 cute torch tensor (cpu) + cute_f32_torch_tensor_cpu = torch.randint( + 1, 3, mma_shape, dtype=torch.float32 + ).permute(*mma_permute_order) - # Convert reference f32 tensor to CuTe f32 tensor using layout conversion + # Copy reference tensor to cute tensor in the customized data layout cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - from_dlpack(ref_f32_torch_tensor_cpu), + from_dlpack(ref_tensor), from_dlpack(cute_f32_torch_tensor_cpu), ) - # Move to GPU cute_f32_torch_tensor = cute_f32_torch_tensor_cpu.cuda() - # Create CuTe tensor with target dtype (FP8) + # Create the desired data type cute torch tensor (cpu) cute_tensor, cute_torch_tensor = cutlass_torch.cute_tensor_like( cute_f32_torch_tensor_cpu, dtype, @@ -259,7 +210,7 @@ def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype, ref_sf): assumed_align=16, ) - # Convert f32 CuTe tensor to target dtype CuTe tensor + # Convert f32 cute tensor to the desired data type cute tensor cute_tensor = cutlass_torch.convert_cute_tensor( cute_f32_torch_tensor, cute_tensor, @@ -279,11 +230,11 @@ def custom_kernel(data: input_t) -> output_t: Args: data: Tuple of (a, b, scale_a, scale_b, c) PyTorch tensors - a: [m, k, l] - Input matrix in FP4 (simulated with uint8) - b: [1, k, l] - Input vector in FP4 (simulated with uint8) - scale_a: [m, k, l] - Expanded scale factors for a in FP32 - scale_b: [1, k, l] - Expanded scale factors for b in FP32 - c: [m, 1, l] - Output vector in FP16 + a: [m, k, l] - Input matrix in float4e2m1fn (simulated with uint8) + b: [1, k, l] - Input vector in float4e2m1fn (simulated with uint8) + scale_a: [m, k, l] - Scale factors in float8_e4m3fnuz (simulated with FP32) + scale_b: [1, k, l] - Scale factors in float8_e4m3fnuz (simulated with FP32) + c: [m, 1, l] - Output vector in float32 Returns: Output tensor c with computed GEMV results @@ -292,29 +243,18 @@ def custom_kernel(data: input_t) -> output_t: # Get dimensions from MxKxL layout m, k, l = a.shape - n = 1 # GEMV: N dimension is always 1 - scale_k = ceil_div(k, block_size) - - # GEMV, N must be 1 - assert n == 1, "GEMV requires N=1" - - # Create reference tensors in LxMxK layout (for CuTe compatibility) - a_ref = a.to(torch.float32).permute(2, 0, 1).contiguous() # [l, m, k] - b_ref = b.to(torch.float32).permute(2, 0, 1).contiguous() # [l, 1, k] - c_ref = torch.zeros(l, m, n, dtype=torch.float32, device=a.device) # [l, m, 1] # Create CuTe tensors for A, B, C a_tensor, a_torch = cutlass_torch.cute_tensor_like( - a_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + a, ab_dtype, is_dynamic_layout=True, assumed_align=16 ) b_tensor, b_torch = cutlass_torch.cute_tensor_like( - b_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + b, ab_dtype, is_dynamic_layout=True, assumed_align=16 ) c_tensor, c_torch = cutlass_torch.cute_tensor_like( - c_ref, c_dtype, is_dynamic_layout=True, assumed_align=16 + c, c_dtype, is_dynamic_layout=True, assumed_align=16 ) - - # Mark tensors with element divisibility for 16B alignment + # Mark tensor with element divisibility for 16B alignment a_tensor.mark_compact_shape_dynamic( mode=1, stride_order=(2, 0, 1), @@ -331,19 +271,16 @@ def custom_kernel(data: input_t) -> output_t: divisibility=16, ) - # Extract compact scale factors from expanded scales - # scale_a and scale_b are [m/n, k, l], we need [m/n, scale_k, l] - # Take every block_size-th element along k dimension - scale_a_compact = scale_a[:, ::block_size, :].contiguous() # [m, scale_k, l] - scale_b_compact = scale_b[:, ::block_size, :].contiguous() # [1, scale_k, l] - - # Create scale factor tensors in MMA layout - sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor( - l, m, k, block_size, sf_dtype, scale_a_compact + # Create cute tensors from reference tensors + sfa_tensor, sfa_torch = create_scale_factor_cute_tensor( + scale_a, l, m, k, block_size, sf_dtype ) - sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor( - l, n, k, block_size, sf_dtype, scale_b_compact + sfb_tensor, sfb_torch = create_scale_factor_cute_tensor( + scale_b, l, 1, k, block_size, sf_dtype ) + # Run the compiled kernel + # INSERT_YOUR_CODE my_kernel(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor) - return c + + return c_torch diff --git a/problems/nvidia/gemv/task.yml b/problems/nvidia/gemv/task.yml index 7f59454..84b0665 100644 --- a/problems/nvidia/gemv/task.yml +++ b/problems/nvidia/gemv/task.yml @@ -43,6 +43,16 @@ templates: tests: - {"m": 128, "k": 256, "l": 1, "seed": 1111} + - {"m": 128, "k": 1536, "l": 1, "seed": 1111} + - {"m": 128, "k": 3072, "l": 1, "seed": 1111} + - {"m": 256, "k": 7168, "l": 1, "seed": 1111} + - {"m": 256, "k": 7168, "l": 1, "seed": 1111} + - {"m": 2432, "k": 4608, "l": 2, "seed": 1111} + - {"m": 384, "k": 7168, "l": 2, "seed": 1111} + - {"m": 512, "k": 512, "l": 2, "seed": 1111} + - {"m": 512, "k": 4096, "l": 2, "seed": 1111} + - {"m": 512, "k": 1536, "l": 2, "seed": 1111} + benchmarks: - {"m": 7168, "k": 16384, "l":1, "seed": 1111} From 232ab548051ab0c8600e01f551522fad2a0a540a Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Thu, 9 Oct 2025 22:36:46 -0700 Subject: [PATCH 04/29] rename the folder --- problems/nvidia/nvfp4_gemv/eval.py | 426 ++++++++++++++++++ problems/nvidia/nvfp4_gemv/kernel.mlir | 409 +++++++++++++++++ problems/nvidia/nvfp4_gemv/log | 15 + .../nvfp4_gemv/nvfp4_gemv_cute_layout.py | 426 ++++++++++++++++++ .../nvidia/{gemv => nvfp4_gemv}/reference.py | 0 .../nvidia/{gemv => nvfp4_gemv}/submission.py | 0 problems/nvidia/{gemv => nvfp4_gemv}/task.py | 0 problems/nvidia/{gemv => nvfp4_gemv}/task.yml | 0 .../nvidia/{gemv => nvfp4_gemv}/template.py | 0 problems/nvidia/nvfp4_gemv/test_python_1.sh | 86 ++++ problems/nvidia/nvfp4_gemv/utils.py | 176 ++++++++ 11 files changed, 1538 insertions(+) create mode 100644 problems/nvidia/nvfp4_gemv/eval.py create mode 100644 problems/nvidia/nvfp4_gemv/kernel.mlir create mode 100644 problems/nvidia/nvfp4_gemv/log create mode 100644 problems/nvidia/nvfp4_gemv/nvfp4_gemv_cute_layout.py rename problems/nvidia/{gemv => nvfp4_gemv}/reference.py (100%) rename problems/nvidia/{gemv => nvfp4_gemv}/submission.py (100%) rename problems/nvidia/{gemv => nvfp4_gemv}/task.py (100%) rename problems/nvidia/{gemv => nvfp4_gemv}/task.yml (100%) rename problems/nvidia/{gemv => nvfp4_gemv}/template.py (100%) create mode 100644 problems/nvidia/nvfp4_gemv/test_python_1.sh create mode 100644 problems/nvidia/nvfp4_gemv/utils.py diff --git a/problems/nvidia/nvfp4_gemv/eval.py b/problems/nvidia/nvfp4_gemv/eval.py new file mode 100644 index 0000000..890668a --- /dev/null +++ b/problems/nvidia/nvfp4_gemv/eval.py @@ -0,0 +1,426 @@ +import base64 +import dataclasses +import multiprocessing +import re +import time +import os +import sys +import math +from pathlib import Path +from typing import Any, Optional +import tempfile + +import torch.cuda +from cutlass.cute.nvgpu.common import OpError + +from utils import set_seed, clear_l2_cache + +try: + from task import TestSpec +except ImportError: + TestSpec = dict + +from reference import check_implementation, generate_input + + +class PopcornOutput: + def __init__(self, fd: int): + self.file = os.fdopen(fd, "w") + os.set_inheritable(fd, False) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def print(self, *args, **kwargs): + print(*args, **kwargs, file=self.file, flush=True) + + def log(self, key, value): + self.print(f"{key}: {value}") + + +@dataclasses.dataclass +class TestCase: + args: dict + spec: str + + +def _combine(a: int, b: int) -> int: + # combine two integers into one: + # we need this to generate a secret seed based on the test-level seed and + # the global secret seed. + # the test-level seeds are public knowledge, and typically relatively small numbers, + # so we need to make sure they don't provide any useful info for the full seed. + # This Cantor construction ensures that if the secret seed is a large number, + # then so is the overall seed. + return int(a + (a + b) * (a + b + 1) // 2) + + +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: + try: + content = Path(file_name).read_text() + except Exception as E: + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) + exit(113) + + tests = [] + lines = content.splitlines() + match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" + for line in lines: + parts = line.split(";") + case = {} + for part in parts: + matched = re.match(match, part) + if not re.fullmatch(match, part): + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) + exit(113) + key = matched[1] + val = matched[2] + try: + val = int(val) + except ValueError: + pass + + case[key] = val + tests.append(TestCase(spec=line, args=case)) + + if seed is not None: + for test in tests: + if "seed" in test.args: + test.args["seed"] = _combine(test.args["seed"], seed) + + return tests + + +@dataclasses.dataclass +class Stats: + runs: int + mean: float + std: float + err: float + best: float + worst: float + + +def calculate_stats(durations: list[int]): + """ + Calculate statistical data from a list of durations. + @param durations: A list of durations in nanoseconds. + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. + """ + runs = len(durations) + total = sum(durations) + best = min(durations) + worst = max(durations) + + avg = total / runs + variance = sum(map(lambda x: (x - avg) ** 2, durations)) + std = math.sqrt(variance / (runs - 1)) + err = std / math.sqrt(runs) + + return Stats( + runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) + ) + + +def _clone_data(data): + """ + Recursively goes through data and clones all tensors. + """ + if isinstance(data, tuple): + return tuple(_clone_data(x) for x in data) + elif isinstance(data, list): + return [_clone_data(x) for x in data] + elif isinstance(data, dict): + return {k: _clone_data(v) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + return data.clone() + else: + return data + + +def _run_single_test(test: TestCase): + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + + data = generate_input(**test.args) + torch.cuda.synchronize() + try: + submission_output = custom_kernel(_clone_data(data)) + + except OpError as E: + print(f"Encountered {E}", file=sys.stderr) + return False, str(E) + torch.cuda.synchronize() + return check_implementation(data, submission_output) + + +def run_single_test(test: TestCase): + """ + Runs a single test directly (no multiprocessing). + """ + return _run_single_test(test) + + +def run_testing(logger: PopcornOutput, tests: list[TestCase]): + """ + Executes the actual test case code and checks for correctness. + @param logger: A PopcornOutput object used for logging test results. + @param tests: A list of TestCase objects representing the test cases to be executed. + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. + """ + passed = True + logger.log("test-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"test.{idx}.spec", test.spec) + good, message = run_single_test(test) + if not good: + logger.log(f"test.{idx}.status", "fail") + logger.log(f"test.{idx}.error", message) + passed = False + else: + logger.log(f"test.{idx}.status", "pass") + if message: + logger.log(f"test.{idx}.message", message) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def _run_single_benchmark( + test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float +) -> Stats | Any: + """ + Runs one benchmark. Do not call directly. + """ + from submission import custom_kernel + + durations = [] + # generate input data once + data = generate_input(**test.args) + check_copy = _clone_data(data) + # first, one obligatory correctness check + try: + output = custom_kernel(_clone_data(data)) + except OpError as E: + return f"Encountered {E}" + good, message = check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs without further correctness testing + # there is an upper bound of 100 runs, and a lower bound of 3 runs; + # otherwise, we repeat until we either measure at least 10 full seconds, + # or the relative error of the mean is below 1%. + + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args) + check_copy = _clone_data(data) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + clear_l2_cache() + + start_event.record() + output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + return message + + del output + durations.append(duration) + + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + if ( + stats.err / stats.mean < 0.001 + or stats.mean * stats.runs > max_time_ns + or total_bm_duration > 120e9 + ): + break + + return calculate_stats(durations) + + +def run_single_benchmark( + test: TestCase, + recheck: bool, + max_repeats: int, + max_time_ns: float, +): + """ + For a particular test case, check correctness (if applicable) and grab runtime results. + @param test: TestCase object. + @param recheck: Flag for whether to explicitly check functional correctness. + @param max_repeats: Number of trials to repeat. + @param max_time_ns: Timeout time in nanoseconds. + @return: A Stats object for this particular benchmark case or an error if the test fails. + """ + return _run_single_benchmark(test, recheck, max_repeats, max_time_ns) + + +def run_benchmarking(logger: PopcornOutput, tests: list[TestCase]): + """ + Executes benchmarking code for a CUDA Kernel and logs runtimes. + @param logger: A PopcornOutput object used for logging benchmark results. + @param tests: A list of TestCase objects representing the test cases to be benchmarked. + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. + """ + # warm up + run_single_benchmark(tests[0], False, 100, 10e7) + + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(test, False, 100, 10e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log( + f"benchmark.{idx}.report", + base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), + ) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + seed = int(seed) if seed else None + set_seed(seed or 42) + + filename = None + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + + def build_test_string(tests: list[dict]): + as_str = "" + for test in tests: + kvs = [] + for k, v in test.items(): + kvs.append(f"{k}: {v}") + as_str += "; ".join(kvs) + "\n" + return as_str + + import yaml + + yaml_content = yaml.safe_load(open(sys.argv[2], "r")) + if mode == "test": + tests_str = build_test_string(yaml_content.get("tests", [])) + elif mode in ("benchmark", "leaderboard", "profile"): + tests_str = build_test_string(yaml_content.get("benchmarks", [])) + + tmp.write(tests_str.encode("utf-8")) + tmp.flush() + filename = tmp.name + + tests = get_test_cases(filename, seed) + + os.unlink(filename) + + with PopcornOutput(int(fd)) as logger: + if mode == "test": + return run_testing(logger, tests) + if mode == "benchmark": + return run_benchmarking(logger, tests) + + if mode == "leaderboard": + # warmup + run_single_benchmark(tests[0], False, 100, 1e7) + logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(tests[i], True, 100, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log( + f"benchmark.{i}.{field.name}", + getattr(result, field.name), + ) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log( + f"benchmark.{i}.error", str(result) + ) # TODO: Make sure result implements __str__? + break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # TODO: Implement script mode + return 2 + + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemv/kernel.mlir b/problems/nvidia/nvfp4_gemv/kernel.mlir new file mode 100644 index 0000000..7a829fc --- /dev/null +++ b/problems/nvidia/nvfp4_gemv/kernel.mlir @@ -0,0 +1,409 @@ +!memref_gmem_f32 = !cute.memref +!memref_gmem_f32_1 = !cute.memref +!memref_gmem_f32_2 = !cute.memref +!memref_gmem_f32_3 = !cute.memref +!memref_gmem_f32_4 = !cute.memref +!memref_gmem_f4E2M1FN = !cute.memref, "(?,?,?):(?{i64},1,?{i64})"> +!memref_gmem_i8 = !cute.memref, "((1,512,1),(?,?,?)):((0,1,0),(?{i64},512,?{i64}))"> +!memref_gmem_i8_1 = !cute.memref +!memref_gmem_i8_2 = !cute.memref +!memref_gmem_i8_3 = !cute.memref +!memref_gmem_i8_4 = !cute.memref, "(?,?,?):(?{i64},1,?{i64})"> +!memref_rmem_f32 = !cute.memref, "8:1"> +!memref_rmem_i8 = !cute.memref, "4:1"> +module attributes {gpu.container_module} { + gpu.module @kernels { + func.func public @kernel_cutlass__convert_kernel_tensorptrf32gmemo11024100div10241_tensorptri8gmemalign16o15121010512_tensor000o1102410110101024112____Float32_Float4E2M1FN_0(%arg0: !memref_gmem_f32, %arg1: !memref_gmem_i8, %arg2: !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">, %arg3: !cute.layout<"(128,8):(8,1)">, %arg4: !cute.layout<"(128,4):(4,1)">, %arg5: i32, %arg6: i32, %arg7: i32) attributes {cute.kernel, gpu.kernel, nvvm.reqntid = array} { + %iter = cute.get_iter(%arg0) : !memref_gmem_f32 + %iter_0 = cute.get_iter(%arg1) : !memref_gmem_i8 + %iter_1 = cute.get_iter(%arg2) : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))"> + %tup = cute.deref_arith_tuple_iter(%iter_1) : !cute.arith_tuple_iter<"(0,0,0)"> + %e0, %e1, %e2 = cute.get_leaves(%tup) : !cute.int_tuple<"(0,0,0)"> + %iter_2 = cute.get_iter(%arg0) : !memref_gmem_f32 + %iter_3 = cute.get_iter(%arg1) : !memref_gmem_i8 + %iter_4 = cute.get_iter(%arg2) : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))"> + %tup_5 = cute.deref_arith_tuple_iter(%iter_4) : !cute.arith_tuple_iter<"(0,0,0)"> + %e0_6, %e1_7, %e2_8 = cute.get_leaves(%tup_5) : !cute.int_tuple<"(0,0,0)"> + %lay = cute.get_layout(%arg0) : !memref_gmem_f32 + %0 = cute.get_shape(%lay) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,?{i64},0),(?{i64},?{i64 div=1024},1))">) -> !cute.shape<"((1,1024,1),(?,?,?))"> + %e0_9, %e1_10, %e2_11, %e3, %e4, %e5 = cute.get_leaves(%0) : !cute.shape<"((1,1024,1),(?,?,?))"> + %itup = cute.to_int_tuple(%e3) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %1 = cute.get_scalars(%itup) : !cute.int_tuple<"?"> + %itup_12 = cute.to_int_tuple(%e4) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %2 = cute.get_scalars(%itup_12) : !cute.int_tuple<"?"> + %itup_13 = cute.to_int_tuple(%e5) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %3 = cute.get_scalars(%itup_13) : !cute.int_tuple<"?"> + %4 = cute.get_stride(%lay) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,?{i64},0),(?{i64},?{i64 div=1024},1))">) -> !cute.stride<"((0,?{i64},0),(?{i64},?{i64 div=1024},1))"> + %e0_14, %e1_15, %e2_16, %e3_17, %e4_18, %e5_19 = cute.get_leaves(%4) : !cute.stride<"((0,?{i64},0),(?{i64},?{i64 div=1024},1))"> + %itup_20 = cute.to_int_tuple(%e1_15) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %5 = cute.get_scalars(%itup_20) : !cute.int_tuple<"?{i64}"> + %itup_21 = cute.to_int_tuple(%e3_17) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %6 = cute.get_scalars(%itup_21) : !cute.int_tuple<"?{i64}"> + %itup_22 = cute.to_int_tuple(%e4_18) : !cute.stride<"?{i64 div=1024}"> to !cute.int_tuple<"?{i64 div=1024}"> + %7 = cute.get_scalars(%itup_22) : !cute.int_tuple<"?{i64 div=1024}"> + %lay_23 = cute.get_layout(%arg1) : !memref_gmem_i8 + %8 = cute.get_shape(%lay_23) : (!cute.layout<"((1,512,1),(?,?,?)):((0,1,0),(?{i64},512,?{i64}))">) -> !cute.shape<"((1,512,1),(?,?,?))"> + %e0_24, %e1_25, %e2_26, %e3_27, %e4_28, %e5_29 = cute.get_leaves(%8) : !cute.shape<"((1,512,1),(?,?,?))"> + %itup_30 = cute.to_int_tuple(%e3_27) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %9 = cute.get_scalars(%itup_30) : !cute.int_tuple<"?"> + %itup_31 = cute.to_int_tuple(%e4_28) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %10 = cute.get_scalars(%itup_31) : !cute.int_tuple<"?"> + %itup_32 = cute.to_int_tuple(%e5_29) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %11 = cute.get_scalars(%itup_32) : !cute.int_tuple<"?"> + %12 = cute.get_stride(%lay_23) : (!cute.layout<"((1,512,1),(?,?,?)):((0,1,0),(?{i64},512,?{i64}))">) -> !cute.stride<"((0,1,0),(?{i64},512,?{i64}))"> + %e0_33, %e1_34, %e2_35, %e3_36, %e4_37, %e5_38 = cute.get_leaves(%12) : !cute.stride<"((0,1,0),(?{i64},512,?{i64}))"> + %itup_39 = cute.to_int_tuple(%e3_36) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %13 = cute.get_scalars(%itup_39) : !cute.int_tuple<"?{i64}"> + %itup_40 = cute.to_int_tuple(%e5_38) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %14 = cute.get_scalars(%itup_40) : !cute.int_tuple<"?{i64}"> + %lay_41 = cute.get_layout(%arg2) : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))"> + %15 = cute.get_shape(%lay_41) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">) -> !cute.shape<"((1,1024,1),(?,?,?))"> + %e0_42, %e1_43, %e2_44, %e3_45, %e4_46, %e5_47 = cute.get_leaves(%15) : !cute.shape<"((1,1024,1),(?,?,?))"> + %itup_48 = cute.to_int_tuple(%e3_45) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %16 = cute.get_scalars(%itup_48) : !cute.int_tuple<"?"> + %itup_49 = cute.to_int_tuple(%e4_46) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %17 = cute.get_scalars(%itup_49) : !cute.int_tuple<"?"> + %itup_50 = cute.to_int_tuple(%e5_47) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %18 = cute.get_scalars(%itup_50) : !cute.int_tuple<"?"> + %19 = cute.get_stride(%lay_41) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">) -> !cute.stride<"((0,1@1,0),(1@0,1024@1,1@2))"> + %e0_51, %e1_52, %e2_53, %e3_54, %e4_55, %e5_56 = cute.get_leaves(%19) : !cute.stride<"((0,1@1,0),(1@0,1024@1,1@2))"> + %20 = cute.get_shape(%arg3) : (!cute.layout<"(128,8):(8,1)">) -> !cute.shape<"(128,8)"> + %e0_57, %e1_58 = cute.get_leaves(%20) : !cute.shape<"(128,8)"> + %21 = cute.get_stride(%arg3) : (!cute.layout<"(128,8):(8,1)">) -> !cute.stride<"(8,1)"> + %e0_59, %e1_60 = cute.get_leaves(%21) : !cute.stride<"(8,1)"> + %22 = cute.get_shape(%arg4) : (!cute.layout<"(128,4):(4,1)">) -> !cute.shape<"(128,4)"> + %e0_61, %e1_62 = cute.get_leaves(%22) : !cute.shape<"(128,4)"> + %23 = cute.get_stride(%arg4) : (!cute.layout<"(128,4):(4,1)">) -> !cute.stride<"(4,1)"> + %e0_63, %e1_64 = cute.get_leaves(%23) : !cute.stride<"(4,1)"> + %24 = cute.get_shape(%lay) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,?{i64},0),(?{i64},?{i64 div=1024},1))">) -> !cute.shape<"((1,1024,1),(?,?,?))"> + %e0_65, %e1_66, %e2_67, %e3_68, %e4_69, %e5_70 = cute.get_leaves(%24) : !cute.shape<"((1,1024,1),(?,?,?))"> + %itup_71 = cute.to_int_tuple(%e3_68) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %25 = cute.get_scalars(%itup_71) : !cute.int_tuple<"?"> + %itup_72 = cute.to_int_tuple(%e4_69) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %26 = cute.get_scalars(%itup_72) : !cute.int_tuple<"?"> + %itup_73 = cute.to_int_tuple(%e5_70) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %27 = cute.get_scalars(%itup_73) : !cute.int_tuple<"?"> + %28 = cute.get_stride(%lay) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,?{i64},0),(?{i64},?{i64 div=1024},1))">) -> !cute.stride<"((0,?{i64},0),(?{i64},?{i64 div=1024},1))"> + %e0_74, %e1_75, %e2_76, %e3_77, %e4_78, %e5_79 = cute.get_leaves(%28) : !cute.stride<"((0,?{i64},0),(?{i64},?{i64 div=1024},1))"> + %itup_80 = cute.to_int_tuple(%e1_75) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %29 = cute.get_scalars(%itup_80) : !cute.int_tuple<"?{i64}"> + %itup_81 = cute.to_int_tuple(%e3_77) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %30 = cute.get_scalars(%itup_81) : !cute.int_tuple<"?{i64}"> + %itup_82 = cute.to_int_tuple(%e4_78) : !cute.stride<"?{i64 div=1024}"> to !cute.int_tuple<"?{i64 div=1024}"> + %31 = cute.get_scalars(%itup_82) : !cute.int_tuple<"?{i64 div=1024}"> + %32 = cute.get_shape(%lay_23) : (!cute.layout<"((1,512,1),(?,?,?)):((0,1,0),(?{i64},512,?{i64}))">) -> !cute.shape<"((1,512,1),(?,?,?))"> + %e0_83, %e1_84, %e2_85, %e3_86, %e4_87, %e5_88 = cute.get_leaves(%32) : !cute.shape<"((1,512,1),(?,?,?))"> + %itup_89 = cute.to_int_tuple(%e3_86) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %33 = cute.get_scalars(%itup_89) : !cute.int_tuple<"?"> + %itup_90 = cute.to_int_tuple(%e4_87) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %34 = cute.get_scalars(%itup_90) : !cute.int_tuple<"?"> + %itup_91 = cute.to_int_tuple(%e5_88) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %35 = cute.get_scalars(%itup_91) : !cute.int_tuple<"?"> + %36 = cute.get_stride(%lay_23) : (!cute.layout<"((1,512,1),(?,?,?)):((0,1,0),(?{i64},512,?{i64}))">) -> !cute.stride<"((0,1,0),(?{i64},512,?{i64}))"> + %e0_92, %e1_93, %e2_94, %e3_95, %e4_96, %e5_97 = cute.get_leaves(%36) : !cute.stride<"((0,1,0),(?{i64},512,?{i64}))"> + %itup_98 = cute.to_int_tuple(%e3_95) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %37 = cute.get_scalars(%itup_98) : !cute.int_tuple<"?{i64}"> + %itup_99 = cute.to_int_tuple(%e5_97) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %38 = cute.get_scalars(%itup_99) : !cute.int_tuple<"?{i64}"> + %39 = cute.get_shape(%lay_41) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">) -> !cute.shape<"((1,1024,1),(?,?,?))"> + %e0_100, %e1_101, %e2_102, %e3_103, %e4_104, %e5_105 = cute.get_leaves(%39) : !cute.shape<"((1,1024,1),(?,?,?))"> + %itup_106 = cute.to_int_tuple(%e3_103) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %40 = cute.get_scalars(%itup_106) : !cute.int_tuple<"?"> + %itup_107 = cute.to_int_tuple(%e4_104) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %41 = cute.get_scalars(%itup_107) : !cute.int_tuple<"?"> + %itup_108 = cute.to_int_tuple(%e5_105) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %42 = cute.get_scalars(%itup_108) : !cute.int_tuple<"?"> + %43 = cute.get_stride(%lay_41) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">) -> !cute.stride<"((0,1@1,0),(1@0,1024@1,1@2))"> + %e0_109, %e1_110, %e2_111, %e3_112, %e4_113, %e5_114 = cute.get_leaves(%43) : !cute.stride<"((0,1@1,0),(1@0,1024@1,1@2))"> + %44 = cute.get_shape(%arg3) : (!cute.layout<"(128,8):(8,1)">) -> !cute.shape<"(128,8)"> + %e0_115, %e1_116 = cute.get_leaves(%44) : !cute.shape<"(128,8)"> + %45 = cute.get_stride(%arg3) : (!cute.layout<"(128,8):(8,1)">) -> !cute.stride<"(8,1)"> + %e0_117, %e1_118 = cute.get_leaves(%45) : !cute.stride<"(8,1)"> + %46 = cute.get_shape(%arg4) : (!cute.layout<"(128,4):(4,1)">) -> !cute.shape<"(128,4)"> + %e0_119, %e1_120 = cute.get_leaves(%46) : !cute.shape<"(128,4)"> + %47 = cute.get_stride(%arg4) : (!cute.layout<"(128,4):(4,1)">) -> !cute.stride<"(4,1)"> + %e0_121, %e1_122 = cute.get_leaves(%47) : !cute.stride<"(4,1)"> + %48 = nvvm.read.ptx.sreg.tid.x : i32 + %49 = nvvm.read.ptx.sreg.ctaid.x : i32 + %coord = cute.make_coord(%49) : (i32) -> !cute.coord<"(_,?)"> + %slice = cute.slice(%arg0, %coord) : !memref_gmem_f32, !cute.coord<"(_,?)"> + %iter_123 = cute.get_iter(%slice) : !memref_gmem_f32_1 + %iter_124 = cute.get_iter(%slice) : !memref_gmem_f32_1 + %coord_125 = cute.make_coord(%49) : (i32) -> !cute.coord<"(_,?)"> + %slice_126 = cute.slice(%arg1, %coord_125) : !memref_gmem_i8, !cute.coord<"(_,?)"> + %iter_127 = cute.get_iter(%slice_126) : !memref_gmem_i8_1 + %iter_128 = cute.get_iter(%slice_126) : !memref_gmem_i8_1 + %coord_129 = cute.make_coord(%49) : (i32) -> !cute.coord<"(_,?)"> + %slice_130 = cute.slice(%arg2, %coord_129) : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">, !cute.coord<"(_,?)"> + %iter_131 = cute.get_iter(%slice_130) : !cute.coord_tensor<"(?,?{div=1024},?)", "((1,1024,1)):((0,1@1,0))"> + %tup_132 = cute.deref_arith_tuple_iter(%iter_131) : !cute.arith_tuple_iter<"(?,?{div=1024},?)"> + %e0_133, %e1_134, %e2_135 = cute.get_leaves(%tup_132) : !cute.int_tuple<"(?,?{div=1024},?)"> + %50 = cute.get_scalars(%e0_133) : !cute.int_tuple<"?"> + %51 = cute.get_scalars(%e1_134) : !cute.int_tuple<"?{div=1024}"> + %52 = cute.get_scalars(%e2_135) : !cute.int_tuple<"?"> + %iter_136 = cute.get_iter(%slice_130) : !cute.coord_tensor<"(?,?{div=1024},?)", "((1,1024,1)):((0,1@1,0))"> + %tup_137 = cute.deref_arith_tuple_iter(%iter_136) : !cute.arith_tuple_iter<"(?,?{div=1024},?)"> + %e0_138, %e1_139, %e2_140 = cute.get_leaves(%tup_137) : !cute.int_tuple<"(?,?{div=1024},?)"> + %53 = cute.get_scalars(%e0_138) : !cute.int_tuple<"?"> + %54 = cute.get_scalars(%e1_139) : !cute.int_tuple<"?{div=1024}"> + %55 = cute.get_scalars(%e2_140) : !cute.int_tuple<"?"> + %56 = cute.composition(%slice, %arg3) : (!memref_gmem_f32_1, !cute.layout<"(128,8):(8,1)">) -> !memref_gmem_f32_2 + %iter_141 = cute.get_iter(%56) : !memref_gmem_f32_2 + %57 = cute.composition(%slice_126, %arg4) : (!memref_gmem_i8_1, !cute.layout<"(128,4):(4,1)">) -> !memref_gmem_i8_2 + %iter_142 = cute.get_iter(%57) : !memref_gmem_i8_2 + %58 = cute.composition(%slice_130, %arg3) : (!cute.coord_tensor<"(?,?{div=1024},?)", "((1,1024,1)):((0,1@1,0))">, !cute.layout<"(128,8):(8,1)">) -> !cute.coord_tensor<"(?,?{div=1024},?)", "(128,8):(8@1,1@1)"> + %iter_143 = cute.get_iter(%58) : !cute.coord_tensor<"(?,?{div=1024},?)", "(128,8):(8@1,1@1)"> + %tup_144 = cute.deref_arith_tuple_iter(%iter_143) : !cute.arith_tuple_iter<"(?,?{div=1024},?)"> + %e0_145, %e1_146, %e2_147 = cute.get_leaves(%tup_144) : !cute.int_tuple<"(?,?{div=1024},?)"> + %59 = cute.get_scalars(%e0_145) : !cute.int_tuple<"?"> + %60 = cute.get_scalars(%e1_146) : !cute.int_tuple<"?{div=1024}"> + %61 = cute.get_scalars(%e2_147) : !cute.int_tuple<"?"> + %coord_148 = cute.make_coord(%48) : (i32) -> !cute.coord<"(?,_)"> + %slice_149 = cute.slice(%56, %coord_148) : !memref_gmem_f32_2, !cute.coord<"(?,_)"> + %iter_150 = cute.get_iter(%slice_149) : !memref_gmem_f32_3 + %iter_151 = cute.get_iter(%slice_149) : !memref_gmem_f32_3 + %coord_152 = cute.make_coord(%48) : (i32) -> !cute.coord<"(?,_)"> + %slice_153 = cute.slice(%57, %coord_152) : !memref_gmem_i8_2, !cute.coord<"(?,_)"> + %iter_154 = cute.get_iter(%slice_153) : !memref_gmem_i8_3 + %iter_155 = cute.get_iter(%slice_153) : !memref_gmem_i8_3 + %coord_156 = cute.make_coord(%48) : (i32) -> !cute.coord<"(?,_)"> + %slice_157 = cute.slice(%58, %coord_156) : !cute.coord_tensor<"(?,?{div=1024},?)", "(128,8):(8@1,1@1)">, !cute.coord<"(?,_)"> + %iter_158 = cute.get_iter(%slice_157) : !cute.coord_tensor<"(?,?{div=8},?)", "(8):(1@1)"> + %tup_159 = cute.deref_arith_tuple_iter(%iter_158) : !cute.arith_tuple_iter<"(?,?{div=8},?)"> + %e0_160, %e1_161, %e2_162 = cute.get_leaves(%tup_159) : !cute.int_tuple<"(?,?{div=8},?)"> + %62 = cute.get_scalars(%e0_160) : !cute.int_tuple<"?"> + %63 = cute.get_scalars(%e1_161) : !cute.int_tuple<"?{div=8}"> + %64 = cute.get_scalars(%e2_162) : !cute.int_tuple<"?"> + %iter_163 = cute.get_iter(%slice_157) : !cute.coord_tensor<"(?,?{div=8},?)", "(8):(1@1)"> + %tup_164 = cute.deref_arith_tuple_iter(%iter_163) : !cute.arith_tuple_iter<"(?,?{div=8},?)"> + %e0_165, %e1_166, %e2_167 = cute.get_leaves(%tup_164) : !cute.int_tuple<"(?,?{div=8},?)"> + %65 = cute.get_scalars(%e0_165) : !cute.int_tuple<"?"> + %66 = cute.get_scalars(%e1_166) : !cute.int_tuple<"?{div=8}"> + %67 = cute.get_scalars(%e2_167) : !cute.int_tuple<"?"> + %coord_168 = cute.make_coord() : () -> !cute.coord<"0"> + %slice_169 = cute.slice(%slice_157, %coord_168) : !cute.coord_tensor<"(?,?{div=8},?)", "(8):(1@1)">, !cute.coord<"0"> + %iter_170 = cute.get_iter(%slice_169) : !cute.coord_tensor<"(?,?{div=8},?)", "():()"> + %tup_171 = cute.deref_arith_tuple_iter(%iter_170) : !cute.arith_tuple_iter<"(?,?{div=8},?)"> + %e0_172, %e1_173, %e2_174 = cute.get_leaves(%tup_171) : !cute.int_tuple<"(?,?{div=8},?)"> + %68 = cute.get_scalars(%e0_172) : !cute.int_tuple<"?"> + %69 = cute.get_scalars(%e1_173) : !cute.int_tuple<"?{div=8}"> + %70 = cute.get_scalars(%e2_174) : !cute.int_tuple<"?"> + %iter_175 = cute.get_iter(%slice_169) : !cute.coord_tensor<"(?,?{div=8},?)", "():()"> + %tup_176 = cute.deref_arith_tuple_iter(%iter_175) : !cute.arith_tuple_iter<"(?,?{div=8},?)"> + %e0_177, %e1_178, %e2_179 = cute.get_leaves(%tup_176) : !cute.int_tuple<"(?,?{div=8},?)"> + %71 = cute.get_scalars(%e0_177) : !cute.int_tuple<"?"> + %72 = cute.get_scalars(%e1_178) : !cute.int_tuple<"?{div=8}"> + %73 = cute.get_scalars(%e2_179) : !cute.int_tuple<"?"> + %iter_180 = cute.get_iter(%slice_169) : !cute.coord_tensor<"(?,?{div=8},?)", "():()"> + %tup_181 = cute.deref_arith_tuple_iter(%iter_180) : !cute.arith_tuple_iter<"(?,?{div=8},?)"> + %e0_182, %e1_183, %e2_184 = cute.get_leaves(%tup_181) : !cute.int_tuple<"(?,?{div=8},?)"> + %74 = cute.get_scalars(%e0_182) : !cute.int_tuple<"?"> + %75 = cute.get_scalars(%e1_183) : !cute.int_tuple<"?{div=8}"> + %76 = cute.get_scalars(%e2_184) : !cute.int_tuple<"?"> + %coord_185 = cute.make_coord(%e0_182, %e1_183, %e2_184) : (!cute.int_tuple<"?">, !cute.int_tuple<"?{div=8}">, !cute.int_tuple<"?">) -> !cute.coord<"(?,?{div=8},?)"> + %coord_186 = cute.make_coord(%arg5, %arg6, %arg7) : (i32, i32, i32) -> !cute.coord<"(?,?,?)"> + %77 = cute.elem_less(%coord_185, %coord_186) : !cute.coord<"(?,?{div=8},?)">, !cute.coord<"(?,?,?)"> + scf.if %77 { + %78 = cute.get_shape(%arg3) : (!cute.layout<"(128,8):(8,1)">) -> !cute.shape<"(128,8)"> + %e0_187, %e1_188 = cute.get_leaves(%78) : !cute.shape<"(128,8)"> + %79 = cute.get_shape(%arg3) : (!cute.layout<"(128,8):(8,1)">) -> !cute.shape<"(128,8)"> + %e0_189, %e1_190 = cute.get_leaves(%79) : !cute.shape<"(128,8)"> + %80 = cute.get(%arg3) <{mode = [1]}> : !cute.layout<"(128,8):(8,1)"> -> !cute.layout<"8:1"> + %rmem = cute.memref.alloca(%80) : !memref_rmem_f32 + %iter_191 = cute.get_iter(%rmem) : !memref_rmem_f32 + %iter_192 = cute.get_iter(%rmem) : !memref_rmem_f32 + %81 = cute.get_shape(%arg4) : (!cute.layout<"(128,4):(4,1)">) -> !cute.shape<"(128,4)"> + %e0_193, %e1_194 = cute.get_leaves(%81) : !cute.shape<"(128,4)"> + %82 = cute.get_shape(%arg4) : (!cute.layout<"(128,4):(4,1)">) -> !cute.shape<"(128,4)"> + %e0_195, %e1_196 = cute.get_leaves(%82) : !cute.shape<"(128,4)"> + %83 = cute.get(%arg4) <{mode = [1]}> : !cute.layout<"(128,4):(4,1)"> -> !cute.layout<"4:1"> + %rmem_197 = cute.memref.alloca(%83) : !memref_rmem_i8 + %iter_198 = cute.get_iter(%rmem_197) : !memref_rmem_i8 + %iter_199 = cute.get_iter(%rmem_197) : !memref_rmem_i8 + %atom = cute.make_atom() : () -> !cute_nvgpu.atom.universal_copy + cute.copy(%atom, %slice_149, %rmem) : (!cute_nvgpu.atom.universal_copy, !memref_gmem_f32_3, !memref_rmem_f32) + %lay_200 = cute.get_layout(%rmem) : !memref_rmem_f32 + %84 = cute.get_shape(%lay_200) : (!cute.layout<"8:1">) -> !cute.shape<"8"> + %e0_201 = cute.get_leaves(%84) : !cute.shape<"8"> + %85 = cute.memref.load_vec %rmem, row_major : !memref_rmem_f32 + %86 = nvgpu.cvt_fptrunc %85 : vector<8xf32> to vector<8xf4E2M1FN> + %shape = cute.make_shape() : () -> !cute.shape<"8"> + %lay_202 = cute.make_layout(%shape) : !cute.layout<"8:1"> + %87 = cute.recast_layout<8, 4> (%lay_202) : !cute.layout<"8:1"> to !cute.layout<"4:1"> + %88 = cute.get_shape(%87) : (!cute.layout<"4:1">) -> !cute.shape<"4"> + %e0_203 = cute.get_leaves(%88) : !cute.shape<"4"> + %89 = builtin.unrealized_conversion_cast %86 : vector<8xf4E2M1FN> to vector<8xi4> + %90 = vector.bitcast %89 : vector<8xi4> to vector<4xi8> + %lay_204 = cute.get_layout(%rmem_197) : !memref_rmem_i8 + %91 = cute.get_shape(%lay_204) : (!cute.layout<"4:1">) -> !cute.shape<"4"> + %e0_205 = cute.get_leaves(%91) : !cute.shape<"4"> + %int_tuple = cute.make_int_tuple() : () -> !cute.int_tuple<"4"> + %sz = cute.size(%int_tuple) : (!cute.int_tuple<"4">) -> !cute.int_tuple<"4"> + %e0_206 = cute.get_leaves(%sz) : !cute.int_tuple<"4"> + %int_tuple_207 = cute.make_int_tuple() : () -> !cute.int_tuple<"4"> + %sz_208 = cute.size(%int_tuple_207) : (!cute.int_tuple<"4">) -> !cute.int_tuple<"4"> + %e0_209 = cute.get_leaves(%sz_208) : !cute.int_tuple<"4"> + cute.memref.store_vec %90, %rmem_197, row_major : !memref_rmem_i8 + %atom_210 = cute.make_atom() : () -> !cute_nvgpu.atom.universal_copy + cute.copy(%atom_210, %rmem_197, %slice_153) : (!cute_nvgpu.atom.universal_copy, !memref_rmem_i8, !memref_gmem_i8_3) + } + return + } + } + func.func @cutlass__convert_Tensorgmemoi64i641_Tensorgmemoi641i64_1_8(%arg0: !memref_gmem_f32_4, %arg1: !memref_gmem_f4E2M1FN) attributes {llvm.emit_c_interface} { + %iter = cute.get_iter(%arg0) : !memref_gmem_f32_4 + %iter_0 = cute.get_iter(%arg1) : !memref_gmem_f4E2M1FN + %iter_1 = cute.get_iter(%arg0) : !memref_gmem_f32_4 + %iter_2 = cute.get_iter(%arg1) : !memref_gmem_f4E2M1FN + %lay = cute.get_layout(%arg0) : !memref_gmem_f32_4 + %0 = cute.get_shape(%lay) : (!cute.layout<"(?,?,?):(?{i64},?{i64},1)">) -> !cute.shape<"(?,?,?)"> + %e0, %e1, %e2 = cute.get_leaves(%0) : !cute.shape<"(?,?,?)"> + %itup = cute.to_int_tuple(%e0) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %1 = cute.get_scalars(%itup) : !cute.int_tuple<"?"> + %itup_3 = cute.to_int_tuple(%e1) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %2 = cute.get_scalars(%itup_3) : !cute.int_tuple<"?"> + %itup_4 = cute.to_int_tuple(%e2) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %3 = cute.get_scalars(%itup_4) : !cute.int_tuple<"?"> + %4 = cute.get_stride(%lay) : (!cute.layout<"(?,?,?):(?{i64},?{i64},1)">) -> !cute.stride<"(?{i64},?{i64},1)"> + %e0_5, %e1_6, %e2_7 = cute.get_leaves(%4) : !cute.stride<"(?{i64},?{i64},1)"> + %itup_8 = cute.to_int_tuple(%e0_5) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %5 = cute.get_scalars(%itup_8) : !cute.int_tuple<"?{i64}"> + %itup_9 = cute.to_int_tuple(%e1_6) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %6 = cute.get_scalars(%itup_9) : !cute.int_tuple<"?{i64}"> + %lay_10 = cute.get_layout(%arg1) : !memref_gmem_f4E2M1FN + %7 = cute.get_shape(%lay_10) : (!cute.layout<"(?,?,?):(?{i64},1,?{i64})">) -> !cute.shape<"(?,?,?)"> + %e0_11, %e1_12, %e2_13 = cute.get_leaves(%7) : !cute.shape<"(?,?,?)"> + %itup_14 = cute.to_int_tuple(%e0_11) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %8 = cute.get_scalars(%itup_14) : !cute.int_tuple<"?"> + %itup_15 = cute.to_int_tuple(%e1_12) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %9 = cute.get_scalars(%itup_15) : !cute.int_tuple<"?"> + %itup_16 = cute.to_int_tuple(%e2_13) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %10 = cute.get_scalars(%itup_16) : !cute.int_tuple<"?"> + %11 = cute.get_stride(%lay_10) : (!cute.layout<"(?,?,?):(?{i64},1,?{i64})">) -> !cute.stride<"(?{i64},1,?{i64})"> + %e0_17, %e1_18, %e2_19 = cute.get_leaves(%11) : !cute.stride<"(?{i64},1,?{i64})"> + %itup_20 = cute.to_int_tuple(%e0_17) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %12 = cute.get_scalars(%itup_20) : !cute.int_tuple<"?{i64}"> + %itup_21 = cute.to_int_tuple(%e2_19) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %13 = cute.get_scalars(%itup_21) : !cute.int_tuple<"?{i64}"> + %shape = cute.make_shape() : () -> !cute.shape<"(128,8)"> + %stride = cute.make_stride() : () -> !cute.stride<"(8,1)"> + %lay_22 = cute.make_layout(%shape, %stride) : !cute.layout<"(128,8):(8,1)"> + %14 = cute.recast_layout<8, 4> (%lay_22) : !cute.layout<"(128,8):(8,1)"> to !cute.layout<"(128,4):(4,1)"> + %iter_23 = cute.recast_iter(%iter_2) : !cute.ptr> to !cute.ptr> + %15 = cute.recast_layout<8, 4> (%lay_10) : !cute.layout<"(?,?,?):(?{i64},1,?{i64})"> to !cute.layout<"(?,?,?):(?{i64},1,?{i64})"> + %view = cute.make_view(%iter_23, %15) : !memref_gmem_i8_4 + %iter_24 = cute.get_iter(%view) : !memref_gmem_i8_4 + %16 = cute.get_shape(%lay) : (!cute.layout<"(?,?,?):(?{i64},?{i64},1)">) -> !cute.shape<"(?,?,?)"> + %e0_25, %e1_26, %e2_27 = cute.get_leaves(%16) : !cute.shape<"(?,?,?)"> + %itup_28 = cute.to_int_tuple(%e0_25) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %17 = cute.get_scalars(%itup_28) : !cute.int_tuple<"?"> + %itup_29 = cute.to_int_tuple(%e1_26) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %18 = cute.get_scalars(%itup_29) : !cute.int_tuple<"?"> + %itup_30 = cute.to_int_tuple(%e2_27) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %19 = cute.get_scalars(%itup_30) : !cute.int_tuple<"?"> + %shape_31 = cute.make_shape(%itup_28, %itup_29, %itup_30) : (!cute.int_tuple<"?">, !cute.int_tuple<"?">, !cute.int_tuple<"?">) -> !cute.shape<"(?,?,?)"> + %20 = cute.make_identity_tensor(%shape_31) : !cute.coord_tensor<"(0,0,0)", "(?,?,?):(1@0,1@1,1@2)"> + %iter_32 = cute.get_iter(%20) : !cute.coord_tensor<"(0,0,0)", "(?,?,?):(1@0,1@1,1@2)"> + %tup = cute.deref_arith_tuple_iter(%iter_32) : !cute.arith_tuple_iter<"(0,0,0)"> + %e0_33, %e1_34, %e2_35 = cute.get_leaves(%tup) : !cute.int_tuple<"(0,0,0)"> + %21 = cute.get_shape(%lay) : (!cute.layout<"(?,?,?):(?{i64},?{i64},1)">) -> !cute.shape<"(?,?,?)"> + %e0_36, %e1_37, %e2_38 = cute.get_leaves(%21) : !cute.shape<"(?,?,?)"> + %itup_39 = cute.to_int_tuple(%e0_36) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %22 = cute.get_scalars(%itup_39) : !cute.int_tuple<"?"> + %itup_40 = cute.to_int_tuple(%e1_37) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %23 = cute.get_scalars(%itup_40) : !cute.int_tuple<"?"> + %itup_41 = cute.to_int_tuple(%e2_38) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %24 = cute.get_scalars(%itup_41) : !cute.int_tuple<"?"> + %sz = cute.size(%lay_22) : (!cute.layout<"(128,8):(8,1)">) -> !cute.int_tuple<"1024"> + %e0_42 = cute.get_leaves(%sz) : !cute.int_tuple<"1024"> + %lay_43 = cute.get_layout(%view) : !memref_gmem_i8_4 + %25 = cute.get_shape(%lay_43) : (!cute.layout<"(?,?,?):(?{i64},1,?{i64})">) -> !cute.shape<"(?,?,?)"> + %e0_44, %e1_45, %e2_46 = cute.get_leaves(%25) : !cute.shape<"(?,?,?)"> + %itup_47 = cute.to_int_tuple(%e0_44) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %26 = cute.get_scalars(%itup_47) : !cute.int_tuple<"?"> + %itup_48 = cute.to_int_tuple(%e1_45) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %27 = cute.get_scalars(%itup_48) : !cute.int_tuple<"?"> + %itup_49 = cute.to_int_tuple(%e2_46) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %28 = cute.get_scalars(%itup_49) : !cute.int_tuple<"?"> + %sz_50 = cute.size(%14) : (!cute.layout<"(128,4):(4,1)">) -> !cute.int_tuple<"512"> + %e0_51 = cute.get_leaves(%sz_50) : !cute.int_tuple<"512"> + %tile = cute.make_tile() : () -> !cute.tile<"[1:0;1024:1;1:0]"> + %div = cute.zipped_divide(%arg0, %tile) : !memref_gmem_f32_4, !cute.tile<"[1:0;1024:1;1:0]"> + %iter_52 = cute.get_iter(%div) : !memref_gmem_f32 + %iter_53 = cute.get_iter(%div) : !memref_gmem_f32 + %tile_54 = cute.make_tile() : () -> !cute.tile<"[1:0;1024:1;1:0]"> + %div_55 = cute.zipped_divide(%20, %tile_54) : !cute.coord_tensor<"(0,0,0)", "(?,?,?):(1@0,1@1,1@2)">, !cute.tile<"[1:0;1024:1;1:0]"> + %iter_56 = cute.get_iter(%div_55) : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))"> + %tup_57 = cute.deref_arith_tuple_iter(%iter_56) : !cute.arith_tuple_iter<"(0,0,0)"> + %e0_58, %e1_59, %e2_60 = cute.get_leaves(%tup_57) : !cute.int_tuple<"(0,0,0)"> + %iter_61 = cute.get_iter(%div_55) : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))"> + %tup_62 = cute.deref_arith_tuple_iter(%iter_61) : !cute.arith_tuple_iter<"(0,0,0)"> + %e0_63, %e1_64, %e2_65 = cute.get_leaves(%tup_62) : !cute.int_tuple<"(0,0,0)"> + %tile_66 = cute.make_tile() : () -> !cute.tile<"[1:0;512:1;1:0]"> + %div_67 = cute.zipped_divide(%view, %tile_66) : !memref_gmem_i8_4, !cute.tile<"[1:0;512:1;1:0]"> + %iter_68 = cute.get_iter(%div_67) : !memref_gmem_i8 + %iter_69 = cute.get_iter(%div_67) : !memref_gmem_i8 + %sz_70 = cute.size(%div) <{mode = [1]}> : (!memref_gmem_f32) -> !cute.int_tuple<"?"> + %e0_71 = cute.get_leaves(%sz_70) : !cute.int_tuple<"?"> + %29 = cute.get_scalars(%e0_71) : !cute.int_tuple<"?"> + %sz_72 = cute.size(%lay_22) <{mode = [0]}> : (!cute.layout<"(128,8):(8,1)">) -> !cute.int_tuple<"128"> + %e0_73 = cute.get_leaves(%sz_72) : !cute.int_tuple<"128"> + %lay_74 = cute.get_layout(%div) : !memref_gmem_f32 + %30 = cute.get_shape(%lay_74) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,?{i64},0),(?{i64},?{i64 div=1024},1))">) -> !cute.shape<"((1,1024,1),(?,?,?))"> + %e0_75, %e1_76, %e2_77, %e3, %e4, %e5 = cute.get_leaves(%30) : !cute.shape<"((1,1024,1),(?,?,?))"> + %itup_78 = cute.to_int_tuple(%e3) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %31 = cute.get_scalars(%itup_78) : !cute.int_tuple<"?"> + %itup_79 = cute.to_int_tuple(%e4) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %32 = cute.get_scalars(%itup_79) : !cute.int_tuple<"?"> + %itup_80 = cute.to_int_tuple(%e5) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %33 = cute.get_scalars(%itup_80) : !cute.int_tuple<"?"> + %34 = cute.get_stride(%lay_74) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,?{i64},0),(?{i64},?{i64 div=1024},1))">) -> !cute.stride<"((0,?{i64},0),(?{i64},?{i64 div=1024},1))"> + %e0_81, %e1_82, %e2_83, %e3_84, %e4_85, %e5_86 = cute.get_leaves(%34) : !cute.stride<"((0,?{i64},0),(?{i64},?{i64 div=1024},1))"> + %itup_87 = cute.to_int_tuple(%e1_82) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %35 = cute.get_scalars(%itup_87) : !cute.int_tuple<"?{i64}"> + %itup_88 = cute.to_int_tuple(%e3_84) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %36 = cute.get_scalars(%itup_88) : !cute.int_tuple<"?{i64}"> + %itup_89 = cute.to_int_tuple(%e4_85) : !cute.stride<"?{i64 div=1024}"> to !cute.int_tuple<"?{i64 div=1024}"> + %37 = cute.get_scalars(%itup_89) : !cute.int_tuple<"?{i64 div=1024}"> + %lay_90 = cute.get_layout(%div_67) : !memref_gmem_i8 + %38 = cute.get_shape(%lay_90) : (!cute.layout<"((1,512,1),(?,?,?)):((0,1,0),(?{i64},512,?{i64}))">) -> !cute.shape<"((1,512,1),(?,?,?))"> + %e0_91, %e1_92, %e2_93, %e3_94, %e4_95, %e5_96 = cute.get_leaves(%38) : !cute.shape<"((1,512,1),(?,?,?))"> + %itup_97 = cute.to_int_tuple(%e3_94) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %39 = cute.get_scalars(%itup_97) : !cute.int_tuple<"?"> + %itup_98 = cute.to_int_tuple(%e4_95) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %40 = cute.get_scalars(%itup_98) : !cute.int_tuple<"?"> + %itup_99 = cute.to_int_tuple(%e5_96) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %41 = cute.get_scalars(%itup_99) : !cute.int_tuple<"?"> + %42 = cute.get_stride(%lay_90) : (!cute.layout<"((1,512,1),(?,?,?)):((0,1,0),(?{i64},512,?{i64}))">) -> !cute.stride<"((0,1,0),(?{i64},512,?{i64}))"> + %e0_100, %e1_101, %e2_102, %e3_103, %e4_104, %e5_105 = cute.get_leaves(%42) : !cute.stride<"((0,1,0),(?{i64},512,?{i64}))"> + %itup_106 = cute.to_int_tuple(%e3_103) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %43 = cute.get_scalars(%itup_106) : !cute.int_tuple<"?{i64}"> + %itup_107 = cute.to_int_tuple(%e5_105) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> + %44 = cute.get_scalars(%itup_107) : !cute.int_tuple<"?{i64}"> + %lay_108 = cute.get_layout(%div_55) : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))"> + %45 = cute.get_shape(%lay_108) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">) -> !cute.shape<"((1,1024,1),(?,?,?))"> + %e0_109, %e1_110, %e2_111, %e3_112, %e4_113, %e5_114 = cute.get_leaves(%45) : !cute.shape<"((1,1024,1),(?,?,?))"> + %itup_115 = cute.to_int_tuple(%e3_112) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %46 = cute.get_scalars(%itup_115) : !cute.int_tuple<"?"> + %itup_116 = cute.to_int_tuple(%e4_113) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %47 = cute.get_scalars(%itup_116) : !cute.int_tuple<"?"> + %itup_117 = cute.to_int_tuple(%e5_114) : !cute.shape<"?"> to !cute.int_tuple<"?"> + %48 = cute.get_scalars(%itup_117) : !cute.int_tuple<"?"> + %49 = cute.get_stride(%lay_108) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">) -> !cute.stride<"((0,1@1,0),(1@0,1024@1,1@2))"> + %e0_118, %e1_119, %e2_120, %e3_121, %e4_122, %e5_123 = cute.get_leaves(%49) : !cute.stride<"((0,1@1,0),(1@0,1024@1,1@2))"> + %50 = cute.get_shape(%lay_22) : (!cute.layout<"(128,8):(8,1)">) -> !cute.shape<"(128,8)"> + %e0_124, %e1_125 = cute.get_leaves(%50) : !cute.shape<"(128,8)"> + %51 = cute.get_stride(%lay_22) : (!cute.layout<"(128,8):(8,1)">) -> !cute.stride<"(8,1)"> + %e0_126, %e1_127 = cute.get_leaves(%51) : !cute.stride<"(8,1)"> + %52 = cute.get_shape(%14) : (!cute.layout<"(128,4):(4,1)">) -> !cute.shape<"(128,4)"> + %e0_128, %e1_129 = cute.get_leaves(%52) : !cute.shape<"(128,4)"> + %53 = cute.get_stride(%14) : (!cute.layout<"(128,4):(4,1)">) -> !cute.stride<"(4,1)"> + %e0_130, %e1_131 = cute.get_leaves(%53) : !cute.stride<"(4,1)"> + %c0_i32 = arith.constant 0 : i32 + %54 = arith.index_cast %29 : i32 to index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + gpu.launch_func @kernels::@kernel_cutlass__convert_kernel_tensorptrf32gmemo11024100div10241_tensorptri8gmemalign16o15121010512_tensor000o1102410110101024112____Float32_Float4E2M1FN_0 blocks in (%54, %c1, %c1) threads in (%c128, %c1, %c1) dynamic_shared_memory_size %c0_i32 args(%div : !memref_gmem_f32, %div_67 : !memref_gmem_i8, %div_55 : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">, %lay_22 : !cute.layout<"(128,8):(8,1)">, %14 : !cute.layout<"(128,4):(4,1)">, %17 : i32, %18 : i32, %19 : i32) {use_pdl = false} + return + } +} diff --git a/problems/nvidia/nvfp4_gemv/log b/problems/nvidia/nvfp4_gemv/log new file mode 100644 index 0000000..bc55151 --- /dev/null +++ b/problems/nvidia/nvfp4_gemv/log @@ -0,0 +1,15 @@ +============================================================ +Launching Blackwell NVFP4 GEMV Test +------------------------------------------------------------ +Input dimensions: + A: (1, 512, 256) [l: batch size, m: rows, k: cols] + b: (1, 256) [l: batch size, k: length] + c: (1, 512) [l: batch size, m: length] +Data types: + A/b dtype: Float4E2M1FN + Scaling factor dtype: Float8E8M0FNU (vector size: 16) + Output C dtype: Float16 +Validation tolerance: 0.1 +============================================================ +c_ref and ref are close within tolerance. +PASS diff --git a/problems/nvidia/nvfp4_gemv/nvfp4_gemv_cute_layout.py b/problems/nvidia/nvfp4_gemv/nvfp4_gemv_cute_layout.py new file mode 100644 index 0000000..662aa92 --- /dev/null +++ b/problems/nvidia/nvfp4_gemv/nvfp4_gemv_cute_layout.py @@ -0,0 +1,426 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import cuda.bindings.driver as cuda + +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.torch as cutlass_torch +from cutlass.cute.runtime import from_dlpack +import cutlass.utils.blockscaled_layout as blockscaled_utils + +mma_tiler_mnk = (128, 1, 64) +ab_dtype = cutlass.Float4E2M1FN +sf_dtype = cutlass.Float8E8M0FNU +c_dtype = cutlass.Float16 +sf_vec_size = 16 + +""" +Below code gives a reference for NVFP4 block-scaled GEMV (General Matrix-Vector Multiplication): + +Given: + - A: a matrix of shape (l, m, k), where l is the batch size, m is the number of rows, k is the number of columns. The data type is Float4E2M1FN + - SFA: a matrix of shape (l, m, k//scaling_factor_vector), where l is the batch size, m is the number of rows, k is the number of columns, and scaling factor vector size means these elements will share the same scaling factor. The data type is Float8E8M0FNU. The layout matches definition here https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout. + - b: a batched vector of shape (l, k) and the data type is Float4E2M1FN. + - SFB: a matrix of shape (l, k//scaling_factor_vector, 128) and the data type is Float8E8M0FNU. The layout matches definition here https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout. + - c: the output batched vector of shape (l, m) and the data type is Float16. + +Operation: + c = A * b + +Assumptions: + - The matrix A is stored in memory such that the k (column) dimension is contiguous + - The m dimension is a multiple of 128 + - The k dimension is a multiple of 64 + +""" + + +class Sm100BlockScaledDenseGemvKernel: + def __init__(self): + self.threads_per_cta = 128 + + @cute.jit + def __call__( + self, + a_tensor: cute.Tensor, + b_tensor: cute.Tensor, + sfa_tensor: cute.Tensor, + sfb_tensor: cute.Tensor, + c_tensor: cute.Tensor, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + # (((32, 4), REST_M), ((SF_K, 4), REST_K), (1, REST_L)) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + a_tensor.shape, sf_vec_size + ) + sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout) + # (((32, 4), REST_M), ((SF_K, 4), REST_K), (1, REST_L)) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + b_tensor.shape, sf_vec_size + ) + sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout) + # Compute grid size + grid = ( + cute.ceil_div(c_tensor.shape[0], 128), + 1, + c_tensor.shape[2], + ) + # Launch the kernel synchronously + self.kernel(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + mA_mkl: cute.Tensor, + mB_nkl: cute.Tensor, + mSFA_mkl: cute.Tensor, + mSFB_nkl: cute.Tensor, + mC_mnl: cute.Tensor, + ): + bidx, bidy, bidz = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + # mma_coord_mnk = (bidx, bidy, bidz) + + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # (bM, bK, RestM, RestK, RestL) + # bM = (32, 4) + # bK = (16, 4) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) + ) + + tCgC = gC_mnl[tidx, None, bidx, bidy, bidz] + tCgC = cute.make_tensor(tCgC.iterator, 1) + res = cute.zeros_like(tCgC, cutlass.Float32) + + k_tile_cnt = gA_mkl.layout[3].shape + for k_tile in range(k_tile_cnt): + tAgA = gA_mkl[tidx, None, bidx, k_tile, bidz] + tBgB = gB_nkl[None, None, bidy, k_tile, bidz] + tAgSFA = gSFA_mkl[tidx, None, bidx, k_tile, bidz] + tBgSFB = gSFB_nkl[None, None, bidy, k_tile, bidz] + + # Create Tensor for A/B/SFA/SFB tile + tAgA = cute.make_tensor(tAgA.iterator, mma_tiler_mnk[2]) + tBgB = cute.make_tensor(tBgB.iterator, mma_tiler_mnk[2]) + tAgSFA = cute.make_tensor(tAgSFA.iterator, 4) + tBgSFB = cute.make_tensor(tBgSFB.iterator, 4) + + # Load A/B/SFA/SFB tile from global memory + a_val_nvfp4 = tAgA.load() + b_val_nvfp4 = tBgB.load() + sfa_val_fp8 = tAgSFA.load() + sfb_val_fp8 = tBgSFB.load() + + # Convert to f32 for FFMA computation + a_val = a_val_nvfp4.to(cutlass.Float32) + b_val = b_val_nvfp4.to(cutlass.Float32) + sfa_val = sfa_val_fp8.to(cutlass.Float32) + sfb_val = sfb_val_fp8.to(cutlass.Float32) + + for i in cutlass.range_constexpr(mma_tiler_mnk[2] // sf_vec_size): + for j in cutlass.range_constexpr(sf_vec_size): + res += ( + a_val[i * sf_vec_size + j] + * sfa_val[i] + * b_val[i * sf_vec_size + j] + * sfb_val[i] + ) + tCgC.store(res.to(cutlass.Float16)) + return + + +@cute.jit +def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + sf_ref_tensor: cute.Tensor, + sf_mma_tensor: cute.Tensor, +): + """Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout""" + # sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l) + # group to ((32, 4, rest_m), (4, rest_k), l) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) + for i in cutlass.range(cute.size(sf_ref_tensor)): + mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) + sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] + + +def run_gemv( + m: int, + k: int, + l: int, + tolerance: float, +): + """ + Prepare A/B/SFA/SFB/C tensors, launch GPU kernel, and reference checking. + """ + print("=" * 60) + print("Launching Blackwell NVFP4 GEMV Test") + print("-" * 60) + print("Input dimensions:") + print(f" A: ({l}, {m}, {k}) [l: batch size, m: rows, k: cols]") + print(f" b: ({l}, {k}) [l: batch size, k: length]") + print(f" c: ({l}, {m}) [l: batch size, m: length]") + print("Data types:") + print(f" A/b dtype: {ab_dtype}") + print(f" Scaling factor dtype: {sf_dtype} (vector size: {sf_vec_size})") + print(f" Output C dtype: {c_dtype}") + print(f"Validation tolerance: {tolerance}") + print("=" * 60) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # GEMV, N must be 1 + n = 1 + + # Create tensor A/B/C + a_ref = cutlass_torch.matrix(l, m, k, False, cutlass.Float32) + b_ref = cutlass_torch.matrix(l, n, k, False, cutlass.Float32) + c_ref = cutlass_torch.matrix(l, m, n, True, cutlass.Float32) + a_tensor, a_torch = cutlass_torch.cute_tensor_like( + a_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, b_torch = cutlass_torch.cute_tensor_like( + b_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, c_torch = cutlass_torch.cute_tensor_like( + c_ref, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + # Mark tensor with element divisibility for 16B alignment + a_tensor.mark_compact_shape_dynamic( + mode=1, + stride_order=(2, 0, 1), + divisibility=32, + ) + b_tensor.mark_compact_shape_dynamic( + mode=1, + stride_order=(2, 0, 1), + divisibility=32, + ) + c_tensor.mark_compact_shape_dynamic( + 0, + (2, 1, 0), + divisibility=16, + ) + + # + # Helper function to create scale factor tensor SFA/SFB + # for 1x16 block scaled wise use case and follow the layout requirement + # defined in https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout + # + def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype): + def ceil_div(a, b): + return (a + b - 1) // b + + sf_k = ceil_div(k, sf_vec_size) + ref_shape = (l, mn, sf_k) + + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, # batch size + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + + ref_permute_order = (1, 2, 0) + mma_permute_order = (3, 4, 1, 5, 2, 0) + + # Create f32 ref torch tensor (cpu) + ref_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + ref_shape, + torch.float32, + permute_order=ref_permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=1, + max_val=3, + ), + ) + + # Create f32 cute torch tensor (cpu) + cute_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + mma_shape, + torch.float32, + permute_order=mma_permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=0, + max_val=1, + ), + ) + + # convert ref f32 tensor to cute f32 tensor + cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + from_dlpack(ref_f32_torch_tensor_cpu), + from_dlpack(cute_f32_torch_tensor_cpu), + ) + cute_f32_torch_tensor = cute_f32_torch_tensor_cpu.cuda() + + # reshape makes memory contiguous + ref_f32_torch_tensor_cpu = ( + ref_f32_torch_tensor_cpu.permute(2, 0, 1) + .unsqueeze(-1) + .expand(l, mn, sf_k, sf_vec_size) + .reshape(l, mn, sf_k * sf_vec_size) + .permute(*ref_permute_order) + ) + # prune to mkl for reference check. + ref_f32_torch_tensor_cpu = ref_f32_torch_tensor_cpu[:, :k, :] + + # Create dtype cute torch tensor (cpu) + cute_tensor, cute_torch_tensor = cutlass_torch.cute_tensor_like( + cute_f32_torch_tensor_cpu, + dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + + # Convert f32 cute tensor to dtype cute tensor + cute_tensor = cutlass_torch.convert_cute_tensor( + cute_f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=True, + ) + return ref_f32_torch_tensor_cpu, cute_tensor, cute_torch_tensor + + sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor( + l, m, k, sf_vec_size, sf_dtype + ) + sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor( + l, 1, k, sf_vec_size, sf_dtype + ) + + # Configure gemv kernel + gemv = Sm100BlockScaledDenseGemvKernel() + # Initialize Stream + current_stream = cutlass_torch.default_stream() + # Compile gemv kernel + compiled_gemv = cute.compile( + gemv, + a_tensor, + b_tensor, + sfa_tensor, + sfb_tensor, + c_tensor, + current_stream, + ) + + # Launch GPU kernel + compiled_gemv(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, current_stream) + + # Compute reference result, simulate NVFP4 GEMV via 2 FFMA based elementwise multiplication and 1 FFMA based matmul computations + res_a = torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref) + res_b = torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref) + ref = torch.einsum("mkl,nkl->mnl", res_a, res_b) + + # Convert c back to f32 for comparison. + c_ref_device = c_ref.cuda() + cute.testing.convert( + c_tensor, + from_dlpack(c_ref_device, assumed_align=16).mark_layout_dynamic(leading_dim=0), + ) + c_ref = c_ref_device.cpu() + torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="Example of Sm100 Dense BlockScaled GEMV." + ) + parser.add_argument( + "--m", + type=int, + default=512, + help="m dimensions", + ) + parser.add_argument( + "--k", + type=int, + default=256, + help="m dimensions", + ) + parser.add_argument( + "--l", + type=int, + default=1, + help="l dimension", + ) + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + args = parser.parse_args() + + if args.k % mma_tiler_mnk[2] != 0: + raise ValueError("K must be a multiple of 64 for this GEMV kernel.") + if args.m % mma_tiler_mnk[0] != 0: + raise ValueError("M must be a multiple of 128 for this GEMV kernel.") + + run_gemv( + args.m, + args.k, + args.l, + args.tolerance, + ) + print("PASS") diff --git a/problems/nvidia/gemv/reference.py b/problems/nvidia/nvfp4_gemv/reference.py similarity index 100% rename from problems/nvidia/gemv/reference.py rename to problems/nvidia/nvfp4_gemv/reference.py diff --git a/problems/nvidia/gemv/submission.py b/problems/nvidia/nvfp4_gemv/submission.py similarity index 100% rename from problems/nvidia/gemv/submission.py rename to problems/nvidia/nvfp4_gemv/submission.py diff --git a/problems/nvidia/gemv/task.py b/problems/nvidia/nvfp4_gemv/task.py similarity index 100% rename from problems/nvidia/gemv/task.py rename to problems/nvidia/nvfp4_gemv/task.py diff --git a/problems/nvidia/gemv/task.yml b/problems/nvidia/nvfp4_gemv/task.yml similarity index 100% rename from problems/nvidia/gemv/task.yml rename to problems/nvidia/nvfp4_gemv/task.yml diff --git a/problems/nvidia/gemv/template.py b/problems/nvidia/nvfp4_gemv/template.py similarity index 100% rename from problems/nvidia/gemv/template.py rename to problems/nvidia/nvfp4_gemv/template.py diff --git a/problems/nvidia/nvfp4_gemv/test_python_1.sh b/problems/nvidia/nvfp4_gemv/test_python_1.sh new file mode 100644 index 0000000..c7087c1 --- /dev/null +++ b/problems/nvidia/nvfp4_gemv/test_python_1.sh @@ -0,0 +1,86 @@ +# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build +BUILD_DIR=/home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/build_non_docker +LLVM_DIR=$BUILD_DIR/llvm-prebuilt +# # BUILD_DIR=/home/scratch.ftse_gpu/workspace/dkg/build +# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build +# #BUILD_DIR=/home/yanchengz/scratch_1/dynamic-kernel-generator/build_debug2 +# # sudo /home/scratch.computelab/utils/driver/install_driver.py --installer=/home/builds/daily/display/x86_64/rel/gpu_drv/r580/r580_00/20250527_36037303/NVIDIA-Linux-x86_64-rel_gpu_drv_r580_r580_00-20250527_36037303-internal.run --reason="Change to tot driver" + + +# # BUILD_DIR=/home/scratch.nbommi_gpu/warp-phase-trace/dynamic-kernel-generator/build_main + +export PYTHONPATH=$BUILD_DIR/cutlass_ir/python_packages +#export PYTHONPATH=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/scripts +export CUDA_TOOLKIT_PATH=$BUILD_DIR/compiler_next +MLIR_CUDA_RUNTIME="$LLVM_DIR/lib/libmlir_cuda_runtime.so" +MLIR_C_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_c_runner_utils.so" +MLIR_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_runner_utils.so" +CUDA_DIALECT_RUNTIME="$BUILD_DIR/lib/libcuda_dialect_runtime.so" +export CUTE_DSL_LIBS="$MLIR_CUDA_RUNTIME:$MLIR_C_RUNNER_UTILS:$MLIR_RUNNER_UTILS:$CUDA_DIALECT_RUNTIME" + + +#export CUTE_DSL_PREPROCESSOR=True + +# export CUTE_DSL_PRINT_IR=1 +# just compile the IR but not execute it +# export CUTE_DSL_DRYRUN=1 +# export CUTE_DSL_JIT_TIME_PROFILING=ON +# export CUTE_DSL_KEEP_IR=True +# export CUTE_DSL_PRINT_IR=1 +# export CUTE_DSL_KEEP_CUBIN=1 +# export CUTE_DSL_LINEINFO=True +# export CUTE_DSL_LOG_TO_CONSOLE=1 +# export PYTHONUNBUFFERED=1 +# export CUTE_DSL_KEEP_SASS=1 +# whether to show detailed log in preprocessing +# export CUTE_DSL_FILTER_STACKTRACE=10 +export CUTE_DSL_ARCH=sm_100a + +# +/home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py +/home/scratch.vickiw_gpu/env/bin/python3 eval.py test task.yml +/home/scratch.vickiw_gpu/env/bin/python3 eval.py benchmark task.yml +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/cuda-gdb --args + +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py +# # /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gated_dual_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gecccccbkvnjtrvtfreufijlfglnudnvuggvdfucidbnhk +# mm.py +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemm/nvfp4_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemv/nvfp4_gemv.py +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool memcheck \ +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,16384 #135us +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 4096,128,7168 #62 + +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,2048 #26 + + +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gated_dual_gemm/nvfp4_gated_dual_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_naive.py + + + +# print out ncu time +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# python3 vicki/tutorial_fp16_gemm_0__.py --mnk 7168,8,512 + +# use sanitizer to check race contention and memref error +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck|memcheck +# cutlass_ir/compiler/test/python/examples/sm_100a/test_nvfp4_gemv.py + +# capture ncu report +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --check-exit-code 0 -f --set full --import-source yes --target-processes all --clock-control base --cache-control none -o gemv_4.1 \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv.py --m 128 --k 128 --l 2 + +# regular run python example +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/min_latency_hmma.py --mnkl 7168,8,512,1 + +# run pytest +# pytest cutlass_ir/compiler/test/python/examples/sm_80/test_sgemm.py diff --git a/problems/nvidia/nvfp4_gemv/utils.py b/problems/nvidia/nvfp4_gemv/utils.py new file mode 100644 index 0000000..e8a9082 --- /dev/null +++ b/problems/nvidia/nvfp4_gemv/utils.py @@ -0,0 +1,176 @@ +import os +import random +import numpy as np +import torch + + +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_device(use_cuda: bool = True) -> torch.device: + """Get the appropriate device (GPU or CPU).""" + if use_cuda: + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + print("No compatible GPU found. Falling back to CPU.") + return torch.device("cpu") + + +# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py +@torch.no_grad() +def verbose_allclose( + received: torch.Tensor, + expected: torch.Tensor, + rtol=1e-05, + atol=1e-08, + max_print=5 +) -> list[str]: + """ + Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + rtol (float): Relative tolerance; relative to expected + atol (float): Absolute tolerance. + max_print (int): Maximum number of mismatched elements to print. + + Raises: + AssertionError: If the tensors are not all close within the given tolerance. + """ + # Check if the shapes of the tensors match + if received.shape != expected.shape: + return ["SIZE MISMATCH"] + + # Calculate the difference between the tensors + diff = torch.abs(received - expected) + + # Determine the tolerance + tolerance = atol + rtol * torch.abs(expected) + + # Find tolerance mismatched elements + tol_mismatched = diff > tolerance + + # Find nan mismatched elements + nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) + + # Find +inf mismatched elements + posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) + # Find -inf mismatched elements + neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) + + # Find all mismatched elements + mismatched = torch.logical_or( + torch.logical_or(tol_mismatched, nan_mismatched), + torch.logical_or(posinf_mismatched, neginf_mismatched), + ) + + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +@torch.no_grad() +def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5): + """ + Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + max_print (int): Maximum number of mismatched elements to print. + + Returns: + Empty string if tensors are equal, otherwise detailed error information + """ + mismatched = torch.not_equal(received, expected) + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]: + """ + Convenient "default" implementation for tasks' `check_implementation` function. + """ + expected = reference(data) + reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) + + if len(reasons) > 0: + return False, "mismatch found! custom implementation doesn't match reference: " + " ".join(reasons) + + return True, '' + + +def make_match_reference(reference: callable, **kwargs): + def wrapped(data, output): + return match_reference(data, output, reference=reference, **kwargs) + return wrapped + + +class DeterministicContext: + def __init__(self): + self.allow_tf32 = None + self.deterministic = None + self.cublas = None + + def __enter__(self): + self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '') + self.allow_tf32 = torch.backends.cudnn.allow_tf32 + self.deterministic = torch.backends.cudnn.deterministic + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + torch.use_deterministic_algorithms(True) + return self + + def __exit__(self, exc_type, exc_value, traceback): + torch.backends.cudnn.allow_tf32 = self.allow_tf32 + torch.backends.cudnn.deterministic = self.deterministic + torch.use_deterministic_algorithms(False) + os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas + +def clear_l2_cache(): + # import cupy as cp + # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) + # create a large dummy tensor + dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") + # write stuff to + dummy.fill_(42) + del dummy \ No newline at end of file From eb7cf9ec30e21d5e3e76a720721b0847d7b4e618 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Thu, 9 Oct 2025 22:52:43 -0700 Subject: [PATCH 05/29] remove useless files --- problems/nvidia/nvfp4_gemv/eval.py | 426 -------------------- problems/nvidia/nvfp4_gemv/kernel.mlir | 409 ------------------- problems/nvidia/nvfp4_gemv/log | 15 - problems/nvidia/nvfp4_gemv/test_python_1.sh | 86 ---- problems/nvidia/nvfp4_gemv/utils.py | 176 -------- 5 files changed, 1112 deletions(-) delete mode 100644 problems/nvidia/nvfp4_gemv/eval.py delete mode 100644 problems/nvidia/nvfp4_gemv/kernel.mlir delete mode 100644 problems/nvidia/nvfp4_gemv/log delete mode 100644 problems/nvidia/nvfp4_gemv/test_python_1.sh delete mode 100644 problems/nvidia/nvfp4_gemv/utils.py diff --git a/problems/nvidia/nvfp4_gemv/eval.py b/problems/nvidia/nvfp4_gemv/eval.py deleted file mode 100644 index 890668a..0000000 --- a/problems/nvidia/nvfp4_gemv/eval.py +++ /dev/null @@ -1,426 +0,0 @@ -import base64 -import dataclasses -import multiprocessing -import re -import time -import os -import sys -import math -from pathlib import Path -from typing import Any, Optional -import tempfile - -import torch.cuda -from cutlass.cute.nvgpu.common import OpError - -from utils import set_seed, clear_l2_cache - -try: - from task import TestSpec -except ImportError: - TestSpec = dict - -from reference import check_implementation, generate_input - - -class PopcornOutput: - def __init__(self, fd: int): - self.file = os.fdopen(fd, "w") - os.set_inheritable(fd, False) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.file.close() - - def print(self, *args, **kwargs): - print(*args, **kwargs, file=self.file, flush=True) - - def log(self, key, value): - self.print(f"{key}: {value}") - - -@dataclasses.dataclass -class TestCase: - args: dict - spec: str - - -def _combine(a: int, b: int) -> int: - # combine two integers into one: - # we need this to generate a secret seed based on the test-level seed and - # the global secret seed. - # the test-level seeds are public knowledge, and typically relatively small numbers, - # so we need to make sure they don't provide any useful info for the full seed. - # This Cantor construction ensures that if the secret seed is a large number, - # then so is the overall seed. - return int(a + (a + b) * (a + b + 1) // 2) - - -def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: - try: - content = Path(file_name).read_text() - except Exception as E: - print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) - exit(113) - - tests = [] - lines = content.splitlines() - match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" - for line in lines: - parts = line.split(";") - case = {} - for part in parts: - matched = re.match(match, part) - if not re.fullmatch(match, part): - print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) - exit(113) - key = matched[1] - val = matched[2] - try: - val = int(val) - except ValueError: - pass - - case[key] = val - tests.append(TestCase(spec=line, args=case)) - - if seed is not None: - for test in tests: - if "seed" in test.args: - test.args["seed"] = _combine(test.args["seed"], seed) - - return tests - - -@dataclasses.dataclass -class Stats: - runs: int - mean: float - std: float - err: float - best: float - worst: float - - -def calculate_stats(durations: list[int]): - """ - Calculate statistical data from a list of durations. - @param durations: A list of durations in nanoseconds. - @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. - """ - runs = len(durations) - total = sum(durations) - best = min(durations) - worst = max(durations) - - avg = total / runs - variance = sum(map(lambda x: (x - avg) ** 2, durations)) - std = math.sqrt(variance / (runs - 1)) - err = std / math.sqrt(runs) - - return Stats( - runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) - ) - - -def _clone_data(data): - """ - Recursively goes through data and clones all tensors. - """ - if isinstance(data, tuple): - return tuple(_clone_data(x) for x in data) - elif isinstance(data, list): - return [_clone_data(x) for x in data] - elif isinstance(data, dict): - return {k: _clone_data(v) for k, v in data.items()} - elif isinstance(data, torch.Tensor): - return data.clone() - else: - return data - - -def _run_single_test(test: TestCase): - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - - data = generate_input(**test.args) - torch.cuda.synchronize() - try: - submission_output = custom_kernel(_clone_data(data)) - - except OpError as E: - print(f"Encountered {E}", file=sys.stderr) - return False, str(E) - torch.cuda.synchronize() - return check_implementation(data, submission_output) - - -def run_single_test(test: TestCase): - """ - Runs a single test directly (no multiprocessing). - """ - return _run_single_test(test) - - -def run_testing(logger: PopcornOutput, tests: list[TestCase]): - """ - Executes the actual test case code and checks for correctness. - @param logger: A PopcornOutput object used for logging test results. - @param tests: A list of TestCase objects representing the test cases to be executed. - @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. - """ - passed = True - logger.log("test-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"test.{idx}.spec", test.spec) - good, message = run_single_test(test) - if not good: - logger.log(f"test.{idx}.status", "fail") - logger.log(f"test.{idx}.error", message) - passed = False - else: - logger.log(f"test.{idx}.status", "pass") - if message: - logger.log(f"test.{idx}.message", message) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def _run_single_benchmark( - test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float -) -> Stats | Any: - """ - Runs one benchmark. Do not call directly. - """ - from submission import custom_kernel - - durations = [] - # generate input data once - data = generate_input(**test.args) - check_copy = _clone_data(data) - # first, one obligatory correctness check - try: - output = custom_kernel(_clone_data(data)) - except OpError as E: - return f"Encountered {E}" - good, message = check_implementation(check_copy, output) - if not good: - return message - - # now, do multiple timing runs without further correctness testing - # there is an upper bound of 100 runs, and a lower bound of 3 runs; - # otherwise, we repeat until we either measure at least 10 full seconds, - # or the relative error of the mean is below 1%. - - bm_start_time = time.perf_counter_ns() - for i in range(max_repeats): - if recheck: - # ensure we use a different seed for every benchmark - if "seed" in test.args: - test.args["seed"] += 13 - - data = generate_input(**test.args) - check_copy = _clone_data(data) - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - clear_l2_cache() - - start_event.record() - output = custom_kernel(data) - end_event.record() - torch.cuda.synchronize() - duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns - - if recheck: - good, message = check_implementation(check_copy, output) - if not good: - return message - - del output - durations.append(duration) - - if i > 1: - total_bm_duration = time.perf_counter_ns() - bm_start_time - stats = calculate_stats(durations) - # stop if either - # a) relative error dips below 0.1% - # b) we exceed the total time limit for benchmarking the kernel - # c) we exceed 2 minutes of total wallclock time. - if ( - stats.err / stats.mean < 0.001 - or stats.mean * stats.runs > max_time_ns - or total_bm_duration > 120e9 - ): - break - - return calculate_stats(durations) - - -def run_single_benchmark( - test: TestCase, - recheck: bool, - max_repeats: int, - max_time_ns: float, -): - """ - For a particular test case, check correctness (if applicable) and grab runtime results. - @param test: TestCase object. - @param recheck: Flag for whether to explicitly check functional correctness. - @param max_repeats: Number of trials to repeat. - @param max_time_ns: Timeout time in nanoseconds. - @return: A Stats object for this particular benchmark case or an error if the test fails. - """ - return _run_single_benchmark(test, recheck, max_repeats, max_time_ns) - - -def run_benchmarking(logger: PopcornOutput, tests: list[TestCase]): - """ - Executes benchmarking code for a CUDA Kernel and logs runtimes. - @param logger: A PopcornOutput object used for logging benchmark results. - @param tests: A list of TestCase objects representing the test cases to be benchmarked. - @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. - """ - # warm up - run_single_benchmark(tests[0], False, 100, 10e7) - - passed = True - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - result = run_single_benchmark(test, False, 100, 10e9) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) - else: - passed = False - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", result) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def run_single_profile(test: TestCase) -> str: - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - from torch.profiler import profile, record_function, ProfilerActivity - - data = generate_input(**test.args) - torch.cuda.synchronize() - - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: - submission_output = custom_kernel(_clone_data(data)) - torch.cuda.synchronize() - return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) - - -def run_profiling(logger: PopcornOutput, tests: list[TestCase]): - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - report = run_single_profile(test) - logger.log( - f"benchmark.{idx}.report", - base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), - ) - logger.log("check", "pass") - return 0 - - -def main(): - fd = os.getenv("POPCORN_FD") - if not fd: - return 111 - - if len(sys.argv) < 3: - return 2 - - mode = sys.argv[1] - seed = os.getenv("POPCORN_SEED") - os.unsetenv("POPCORN_SEED") - seed = int(seed) if seed else None - set_seed(seed or 42) - - filename = None - - with tempfile.NamedTemporaryFile(delete=False) as tmp: - - def build_test_string(tests: list[dict]): - as_str = "" - for test in tests: - kvs = [] - for k, v in test.items(): - kvs.append(f"{k}: {v}") - as_str += "; ".join(kvs) + "\n" - return as_str - - import yaml - - yaml_content = yaml.safe_load(open(sys.argv[2], "r")) - if mode == "test": - tests_str = build_test_string(yaml_content.get("tests", [])) - elif mode in ("benchmark", "leaderboard", "profile"): - tests_str = build_test_string(yaml_content.get("benchmarks", [])) - - tmp.write(tests_str.encode("utf-8")) - tmp.flush() - filename = tmp.name - - tests = get_test_cases(filename, seed) - - os.unlink(filename) - - with PopcornOutput(int(fd)) as logger: - if mode == "test": - return run_testing(logger, tests) - if mode == "benchmark": - return run_benchmarking(logger, tests) - - if mode == "leaderboard": - # warmup - run_single_benchmark(tests[0], False, 100, 1e7) - logger.log("benchmark-count", len(tests)) - passed = True - for i in range(len(tests)): - result = run_single_benchmark(tests[i], True, 100, 30e9) - logger.log(f"benchmark.{i}.spec", tests[i].spec) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log( - f"benchmark.{i}.{field.name}", - getattr(result, field.name), - ) - else: - passed = False - logger.log(f"benchmark.{i}.status", "fail") - logger.log( - f"benchmark.{i}.error", str(result) - ) # TODO: Make sure result implements __str__? - break - - logger.log("check", "pass" if passed else "fail") - elif mode == "profile": - run_profiling(logger, tests) - else: - # TODO: Implement script mode - return 2 - - -if __name__ == "__main__": - sys.exit(main()) \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemv/kernel.mlir b/problems/nvidia/nvfp4_gemv/kernel.mlir deleted file mode 100644 index 7a829fc..0000000 --- a/problems/nvidia/nvfp4_gemv/kernel.mlir +++ /dev/null @@ -1,409 +0,0 @@ -!memref_gmem_f32 = !cute.memref -!memref_gmem_f32_1 = !cute.memref -!memref_gmem_f32_2 = !cute.memref -!memref_gmem_f32_3 = !cute.memref -!memref_gmem_f32_4 = !cute.memref -!memref_gmem_f4E2M1FN = !cute.memref, "(?,?,?):(?{i64},1,?{i64})"> -!memref_gmem_i8 = !cute.memref, "((1,512,1),(?,?,?)):((0,1,0),(?{i64},512,?{i64}))"> -!memref_gmem_i8_1 = !cute.memref -!memref_gmem_i8_2 = !cute.memref -!memref_gmem_i8_3 = !cute.memref -!memref_gmem_i8_4 = !cute.memref, "(?,?,?):(?{i64},1,?{i64})"> -!memref_rmem_f32 = !cute.memref, "8:1"> -!memref_rmem_i8 = !cute.memref, "4:1"> -module attributes {gpu.container_module} { - gpu.module @kernels { - func.func public @kernel_cutlass__convert_kernel_tensorptrf32gmemo11024100div10241_tensorptri8gmemalign16o15121010512_tensor000o1102410110101024112____Float32_Float4E2M1FN_0(%arg0: !memref_gmem_f32, %arg1: !memref_gmem_i8, %arg2: !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">, %arg3: !cute.layout<"(128,8):(8,1)">, %arg4: !cute.layout<"(128,4):(4,1)">, %arg5: i32, %arg6: i32, %arg7: i32) attributes {cute.kernel, gpu.kernel, nvvm.reqntid = array} { - %iter = cute.get_iter(%arg0) : !memref_gmem_f32 - %iter_0 = cute.get_iter(%arg1) : !memref_gmem_i8 - %iter_1 = cute.get_iter(%arg2) : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))"> - %tup = cute.deref_arith_tuple_iter(%iter_1) : !cute.arith_tuple_iter<"(0,0,0)"> - %e0, %e1, %e2 = cute.get_leaves(%tup) : !cute.int_tuple<"(0,0,0)"> - %iter_2 = cute.get_iter(%arg0) : !memref_gmem_f32 - %iter_3 = cute.get_iter(%arg1) : !memref_gmem_i8 - %iter_4 = cute.get_iter(%arg2) : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))"> - %tup_5 = cute.deref_arith_tuple_iter(%iter_4) : !cute.arith_tuple_iter<"(0,0,0)"> - %e0_6, %e1_7, %e2_8 = cute.get_leaves(%tup_5) : !cute.int_tuple<"(0,0,0)"> - %lay = cute.get_layout(%arg0) : !memref_gmem_f32 - %0 = cute.get_shape(%lay) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,?{i64},0),(?{i64},?{i64 div=1024},1))">) -> !cute.shape<"((1,1024,1),(?,?,?))"> - %e0_9, %e1_10, %e2_11, %e3, %e4, %e5 = cute.get_leaves(%0) : !cute.shape<"((1,1024,1),(?,?,?))"> - %itup = cute.to_int_tuple(%e3) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %1 = cute.get_scalars(%itup) : !cute.int_tuple<"?"> - %itup_12 = cute.to_int_tuple(%e4) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %2 = cute.get_scalars(%itup_12) : !cute.int_tuple<"?"> - %itup_13 = cute.to_int_tuple(%e5) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %3 = cute.get_scalars(%itup_13) : !cute.int_tuple<"?"> - %4 = cute.get_stride(%lay) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,?{i64},0),(?{i64},?{i64 div=1024},1))">) -> !cute.stride<"((0,?{i64},0),(?{i64},?{i64 div=1024},1))"> - %e0_14, %e1_15, %e2_16, %e3_17, %e4_18, %e5_19 = cute.get_leaves(%4) : !cute.stride<"((0,?{i64},0),(?{i64},?{i64 div=1024},1))"> - %itup_20 = cute.to_int_tuple(%e1_15) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %5 = cute.get_scalars(%itup_20) : !cute.int_tuple<"?{i64}"> - %itup_21 = cute.to_int_tuple(%e3_17) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %6 = cute.get_scalars(%itup_21) : !cute.int_tuple<"?{i64}"> - %itup_22 = cute.to_int_tuple(%e4_18) : !cute.stride<"?{i64 div=1024}"> to !cute.int_tuple<"?{i64 div=1024}"> - %7 = cute.get_scalars(%itup_22) : !cute.int_tuple<"?{i64 div=1024}"> - %lay_23 = cute.get_layout(%arg1) : !memref_gmem_i8 - %8 = cute.get_shape(%lay_23) : (!cute.layout<"((1,512,1),(?,?,?)):((0,1,0),(?{i64},512,?{i64}))">) -> !cute.shape<"((1,512,1),(?,?,?))"> - %e0_24, %e1_25, %e2_26, %e3_27, %e4_28, %e5_29 = cute.get_leaves(%8) : !cute.shape<"((1,512,1),(?,?,?))"> - %itup_30 = cute.to_int_tuple(%e3_27) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %9 = cute.get_scalars(%itup_30) : !cute.int_tuple<"?"> - %itup_31 = cute.to_int_tuple(%e4_28) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %10 = cute.get_scalars(%itup_31) : !cute.int_tuple<"?"> - %itup_32 = cute.to_int_tuple(%e5_29) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %11 = cute.get_scalars(%itup_32) : !cute.int_tuple<"?"> - %12 = cute.get_stride(%lay_23) : (!cute.layout<"((1,512,1),(?,?,?)):((0,1,0),(?{i64},512,?{i64}))">) -> !cute.stride<"((0,1,0),(?{i64},512,?{i64}))"> - %e0_33, %e1_34, %e2_35, %e3_36, %e4_37, %e5_38 = cute.get_leaves(%12) : !cute.stride<"((0,1,0),(?{i64},512,?{i64}))"> - %itup_39 = cute.to_int_tuple(%e3_36) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %13 = cute.get_scalars(%itup_39) : !cute.int_tuple<"?{i64}"> - %itup_40 = cute.to_int_tuple(%e5_38) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %14 = cute.get_scalars(%itup_40) : !cute.int_tuple<"?{i64}"> - %lay_41 = cute.get_layout(%arg2) : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))"> - %15 = cute.get_shape(%lay_41) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">) -> !cute.shape<"((1,1024,1),(?,?,?))"> - %e0_42, %e1_43, %e2_44, %e3_45, %e4_46, %e5_47 = cute.get_leaves(%15) : !cute.shape<"((1,1024,1),(?,?,?))"> - %itup_48 = cute.to_int_tuple(%e3_45) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %16 = cute.get_scalars(%itup_48) : !cute.int_tuple<"?"> - %itup_49 = cute.to_int_tuple(%e4_46) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %17 = cute.get_scalars(%itup_49) : !cute.int_tuple<"?"> - %itup_50 = cute.to_int_tuple(%e5_47) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %18 = cute.get_scalars(%itup_50) : !cute.int_tuple<"?"> - %19 = cute.get_stride(%lay_41) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">) -> !cute.stride<"((0,1@1,0),(1@0,1024@1,1@2))"> - %e0_51, %e1_52, %e2_53, %e3_54, %e4_55, %e5_56 = cute.get_leaves(%19) : !cute.stride<"((0,1@1,0),(1@0,1024@1,1@2))"> - %20 = cute.get_shape(%arg3) : (!cute.layout<"(128,8):(8,1)">) -> !cute.shape<"(128,8)"> - %e0_57, %e1_58 = cute.get_leaves(%20) : !cute.shape<"(128,8)"> - %21 = cute.get_stride(%arg3) : (!cute.layout<"(128,8):(8,1)">) -> !cute.stride<"(8,1)"> - %e0_59, %e1_60 = cute.get_leaves(%21) : !cute.stride<"(8,1)"> - %22 = cute.get_shape(%arg4) : (!cute.layout<"(128,4):(4,1)">) -> !cute.shape<"(128,4)"> - %e0_61, %e1_62 = cute.get_leaves(%22) : !cute.shape<"(128,4)"> - %23 = cute.get_stride(%arg4) : (!cute.layout<"(128,4):(4,1)">) -> !cute.stride<"(4,1)"> - %e0_63, %e1_64 = cute.get_leaves(%23) : !cute.stride<"(4,1)"> - %24 = cute.get_shape(%lay) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,?{i64},0),(?{i64},?{i64 div=1024},1))">) -> !cute.shape<"((1,1024,1),(?,?,?))"> - %e0_65, %e1_66, %e2_67, %e3_68, %e4_69, %e5_70 = cute.get_leaves(%24) : !cute.shape<"((1,1024,1),(?,?,?))"> - %itup_71 = cute.to_int_tuple(%e3_68) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %25 = cute.get_scalars(%itup_71) : !cute.int_tuple<"?"> - %itup_72 = cute.to_int_tuple(%e4_69) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %26 = cute.get_scalars(%itup_72) : !cute.int_tuple<"?"> - %itup_73 = cute.to_int_tuple(%e5_70) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %27 = cute.get_scalars(%itup_73) : !cute.int_tuple<"?"> - %28 = cute.get_stride(%lay) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,?{i64},0),(?{i64},?{i64 div=1024},1))">) -> !cute.stride<"((0,?{i64},0),(?{i64},?{i64 div=1024},1))"> - %e0_74, %e1_75, %e2_76, %e3_77, %e4_78, %e5_79 = cute.get_leaves(%28) : !cute.stride<"((0,?{i64},0),(?{i64},?{i64 div=1024},1))"> - %itup_80 = cute.to_int_tuple(%e1_75) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %29 = cute.get_scalars(%itup_80) : !cute.int_tuple<"?{i64}"> - %itup_81 = cute.to_int_tuple(%e3_77) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %30 = cute.get_scalars(%itup_81) : !cute.int_tuple<"?{i64}"> - %itup_82 = cute.to_int_tuple(%e4_78) : !cute.stride<"?{i64 div=1024}"> to !cute.int_tuple<"?{i64 div=1024}"> - %31 = cute.get_scalars(%itup_82) : !cute.int_tuple<"?{i64 div=1024}"> - %32 = cute.get_shape(%lay_23) : (!cute.layout<"((1,512,1),(?,?,?)):((0,1,0),(?{i64},512,?{i64}))">) -> !cute.shape<"((1,512,1),(?,?,?))"> - %e0_83, %e1_84, %e2_85, %e3_86, %e4_87, %e5_88 = cute.get_leaves(%32) : !cute.shape<"((1,512,1),(?,?,?))"> - %itup_89 = cute.to_int_tuple(%e3_86) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %33 = cute.get_scalars(%itup_89) : !cute.int_tuple<"?"> - %itup_90 = cute.to_int_tuple(%e4_87) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %34 = cute.get_scalars(%itup_90) : !cute.int_tuple<"?"> - %itup_91 = cute.to_int_tuple(%e5_88) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %35 = cute.get_scalars(%itup_91) : !cute.int_tuple<"?"> - %36 = cute.get_stride(%lay_23) : (!cute.layout<"((1,512,1),(?,?,?)):((0,1,0),(?{i64},512,?{i64}))">) -> !cute.stride<"((0,1,0),(?{i64},512,?{i64}))"> - %e0_92, %e1_93, %e2_94, %e3_95, %e4_96, %e5_97 = cute.get_leaves(%36) : !cute.stride<"((0,1,0),(?{i64},512,?{i64}))"> - %itup_98 = cute.to_int_tuple(%e3_95) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %37 = cute.get_scalars(%itup_98) : !cute.int_tuple<"?{i64}"> - %itup_99 = cute.to_int_tuple(%e5_97) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %38 = cute.get_scalars(%itup_99) : !cute.int_tuple<"?{i64}"> - %39 = cute.get_shape(%lay_41) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">) -> !cute.shape<"((1,1024,1),(?,?,?))"> - %e0_100, %e1_101, %e2_102, %e3_103, %e4_104, %e5_105 = cute.get_leaves(%39) : !cute.shape<"((1,1024,1),(?,?,?))"> - %itup_106 = cute.to_int_tuple(%e3_103) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %40 = cute.get_scalars(%itup_106) : !cute.int_tuple<"?"> - %itup_107 = cute.to_int_tuple(%e4_104) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %41 = cute.get_scalars(%itup_107) : !cute.int_tuple<"?"> - %itup_108 = cute.to_int_tuple(%e5_105) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %42 = cute.get_scalars(%itup_108) : !cute.int_tuple<"?"> - %43 = cute.get_stride(%lay_41) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">) -> !cute.stride<"((0,1@1,0),(1@0,1024@1,1@2))"> - %e0_109, %e1_110, %e2_111, %e3_112, %e4_113, %e5_114 = cute.get_leaves(%43) : !cute.stride<"((0,1@1,0),(1@0,1024@1,1@2))"> - %44 = cute.get_shape(%arg3) : (!cute.layout<"(128,8):(8,1)">) -> !cute.shape<"(128,8)"> - %e0_115, %e1_116 = cute.get_leaves(%44) : !cute.shape<"(128,8)"> - %45 = cute.get_stride(%arg3) : (!cute.layout<"(128,8):(8,1)">) -> !cute.stride<"(8,1)"> - %e0_117, %e1_118 = cute.get_leaves(%45) : !cute.stride<"(8,1)"> - %46 = cute.get_shape(%arg4) : (!cute.layout<"(128,4):(4,1)">) -> !cute.shape<"(128,4)"> - %e0_119, %e1_120 = cute.get_leaves(%46) : !cute.shape<"(128,4)"> - %47 = cute.get_stride(%arg4) : (!cute.layout<"(128,4):(4,1)">) -> !cute.stride<"(4,1)"> - %e0_121, %e1_122 = cute.get_leaves(%47) : !cute.stride<"(4,1)"> - %48 = nvvm.read.ptx.sreg.tid.x : i32 - %49 = nvvm.read.ptx.sreg.ctaid.x : i32 - %coord = cute.make_coord(%49) : (i32) -> !cute.coord<"(_,?)"> - %slice = cute.slice(%arg0, %coord) : !memref_gmem_f32, !cute.coord<"(_,?)"> - %iter_123 = cute.get_iter(%slice) : !memref_gmem_f32_1 - %iter_124 = cute.get_iter(%slice) : !memref_gmem_f32_1 - %coord_125 = cute.make_coord(%49) : (i32) -> !cute.coord<"(_,?)"> - %slice_126 = cute.slice(%arg1, %coord_125) : !memref_gmem_i8, !cute.coord<"(_,?)"> - %iter_127 = cute.get_iter(%slice_126) : !memref_gmem_i8_1 - %iter_128 = cute.get_iter(%slice_126) : !memref_gmem_i8_1 - %coord_129 = cute.make_coord(%49) : (i32) -> !cute.coord<"(_,?)"> - %slice_130 = cute.slice(%arg2, %coord_129) : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">, !cute.coord<"(_,?)"> - %iter_131 = cute.get_iter(%slice_130) : !cute.coord_tensor<"(?,?{div=1024},?)", "((1,1024,1)):((0,1@1,0))"> - %tup_132 = cute.deref_arith_tuple_iter(%iter_131) : !cute.arith_tuple_iter<"(?,?{div=1024},?)"> - %e0_133, %e1_134, %e2_135 = cute.get_leaves(%tup_132) : !cute.int_tuple<"(?,?{div=1024},?)"> - %50 = cute.get_scalars(%e0_133) : !cute.int_tuple<"?"> - %51 = cute.get_scalars(%e1_134) : !cute.int_tuple<"?{div=1024}"> - %52 = cute.get_scalars(%e2_135) : !cute.int_tuple<"?"> - %iter_136 = cute.get_iter(%slice_130) : !cute.coord_tensor<"(?,?{div=1024},?)", "((1,1024,1)):((0,1@1,0))"> - %tup_137 = cute.deref_arith_tuple_iter(%iter_136) : !cute.arith_tuple_iter<"(?,?{div=1024},?)"> - %e0_138, %e1_139, %e2_140 = cute.get_leaves(%tup_137) : !cute.int_tuple<"(?,?{div=1024},?)"> - %53 = cute.get_scalars(%e0_138) : !cute.int_tuple<"?"> - %54 = cute.get_scalars(%e1_139) : !cute.int_tuple<"?{div=1024}"> - %55 = cute.get_scalars(%e2_140) : !cute.int_tuple<"?"> - %56 = cute.composition(%slice, %arg3) : (!memref_gmem_f32_1, !cute.layout<"(128,8):(8,1)">) -> !memref_gmem_f32_2 - %iter_141 = cute.get_iter(%56) : !memref_gmem_f32_2 - %57 = cute.composition(%slice_126, %arg4) : (!memref_gmem_i8_1, !cute.layout<"(128,4):(4,1)">) -> !memref_gmem_i8_2 - %iter_142 = cute.get_iter(%57) : !memref_gmem_i8_2 - %58 = cute.composition(%slice_130, %arg3) : (!cute.coord_tensor<"(?,?{div=1024},?)", "((1,1024,1)):((0,1@1,0))">, !cute.layout<"(128,8):(8,1)">) -> !cute.coord_tensor<"(?,?{div=1024},?)", "(128,8):(8@1,1@1)"> - %iter_143 = cute.get_iter(%58) : !cute.coord_tensor<"(?,?{div=1024},?)", "(128,8):(8@1,1@1)"> - %tup_144 = cute.deref_arith_tuple_iter(%iter_143) : !cute.arith_tuple_iter<"(?,?{div=1024},?)"> - %e0_145, %e1_146, %e2_147 = cute.get_leaves(%tup_144) : !cute.int_tuple<"(?,?{div=1024},?)"> - %59 = cute.get_scalars(%e0_145) : !cute.int_tuple<"?"> - %60 = cute.get_scalars(%e1_146) : !cute.int_tuple<"?{div=1024}"> - %61 = cute.get_scalars(%e2_147) : !cute.int_tuple<"?"> - %coord_148 = cute.make_coord(%48) : (i32) -> !cute.coord<"(?,_)"> - %slice_149 = cute.slice(%56, %coord_148) : !memref_gmem_f32_2, !cute.coord<"(?,_)"> - %iter_150 = cute.get_iter(%slice_149) : !memref_gmem_f32_3 - %iter_151 = cute.get_iter(%slice_149) : !memref_gmem_f32_3 - %coord_152 = cute.make_coord(%48) : (i32) -> !cute.coord<"(?,_)"> - %slice_153 = cute.slice(%57, %coord_152) : !memref_gmem_i8_2, !cute.coord<"(?,_)"> - %iter_154 = cute.get_iter(%slice_153) : !memref_gmem_i8_3 - %iter_155 = cute.get_iter(%slice_153) : !memref_gmem_i8_3 - %coord_156 = cute.make_coord(%48) : (i32) -> !cute.coord<"(?,_)"> - %slice_157 = cute.slice(%58, %coord_156) : !cute.coord_tensor<"(?,?{div=1024},?)", "(128,8):(8@1,1@1)">, !cute.coord<"(?,_)"> - %iter_158 = cute.get_iter(%slice_157) : !cute.coord_tensor<"(?,?{div=8},?)", "(8):(1@1)"> - %tup_159 = cute.deref_arith_tuple_iter(%iter_158) : !cute.arith_tuple_iter<"(?,?{div=8},?)"> - %e0_160, %e1_161, %e2_162 = cute.get_leaves(%tup_159) : !cute.int_tuple<"(?,?{div=8},?)"> - %62 = cute.get_scalars(%e0_160) : !cute.int_tuple<"?"> - %63 = cute.get_scalars(%e1_161) : !cute.int_tuple<"?{div=8}"> - %64 = cute.get_scalars(%e2_162) : !cute.int_tuple<"?"> - %iter_163 = cute.get_iter(%slice_157) : !cute.coord_tensor<"(?,?{div=8},?)", "(8):(1@1)"> - %tup_164 = cute.deref_arith_tuple_iter(%iter_163) : !cute.arith_tuple_iter<"(?,?{div=8},?)"> - %e0_165, %e1_166, %e2_167 = cute.get_leaves(%tup_164) : !cute.int_tuple<"(?,?{div=8},?)"> - %65 = cute.get_scalars(%e0_165) : !cute.int_tuple<"?"> - %66 = cute.get_scalars(%e1_166) : !cute.int_tuple<"?{div=8}"> - %67 = cute.get_scalars(%e2_167) : !cute.int_tuple<"?"> - %coord_168 = cute.make_coord() : () -> !cute.coord<"0"> - %slice_169 = cute.slice(%slice_157, %coord_168) : !cute.coord_tensor<"(?,?{div=8},?)", "(8):(1@1)">, !cute.coord<"0"> - %iter_170 = cute.get_iter(%slice_169) : !cute.coord_tensor<"(?,?{div=8},?)", "():()"> - %tup_171 = cute.deref_arith_tuple_iter(%iter_170) : !cute.arith_tuple_iter<"(?,?{div=8},?)"> - %e0_172, %e1_173, %e2_174 = cute.get_leaves(%tup_171) : !cute.int_tuple<"(?,?{div=8},?)"> - %68 = cute.get_scalars(%e0_172) : !cute.int_tuple<"?"> - %69 = cute.get_scalars(%e1_173) : !cute.int_tuple<"?{div=8}"> - %70 = cute.get_scalars(%e2_174) : !cute.int_tuple<"?"> - %iter_175 = cute.get_iter(%slice_169) : !cute.coord_tensor<"(?,?{div=8},?)", "():()"> - %tup_176 = cute.deref_arith_tuple_iter(%iter_175) : !cute.arith_tuple_iter<"(?,?{div=8},?)"> - %e0_177, %e1_178, %e2_179 = cute.get_leaves(%tup_176) : !cute.int_tuple<"(?,?{div=8},?)"> - %71 = cute.get_scalars(%e0_177) : !cute.int_tuple<"?"> - %72 = cute.get_scalars(%e1_178) : !cute.int_tuple<"?{div=8}"> - %73 = cute.get_scalars(%e2_179) : !cute.int_tuple<"?"> - %iter_180 = cute.get_iter(%slice_169) : !cute.coord_tensor<"(?,?{div=8},?)", "():()"> - %tup_181 = cute.deref_arith_tuple_iter(%iter_180) : !cute.arith_tuple_iter<"(?,?{div=8},?)"> - %e0_182, %e1_183, %e2_184 = cute.get_leaves(%tup_181) : !cute.int_tuple<"(?,?{div=8},?)"> - %74 = cute.get_scalars(%e0_182) : !cute.int_tuple<"?"> - %75 = cute.get_scalars(%e1_183) : !cute.int_tuple<"?{div=8}"> - %76 = cute.get_scalars(%e2_184) : !cute.int_tuple<"?"> - %coord_185 = cute.make_coord(%e0_182, %e1_183, %e2_184) : (!cute.int_tuple<"?">, !cute.int_tuple<"?{div=8}">, !cute.int_tuple<"?">) -> !cute.coord<"(?,?{div=8},?)"> - %coord_186 = cute.make_coord(%arg5, %arg6, %arg7) : (i32, i32, i32) -> !cute.coord<"(?,?,?)"> - %77 = cute.elem_less(%coord_185, %coord_186) : !cute.coord<"(?,?{div=8},?)">, !cute.coord<"(?,?,?)"> - scf.if %77 { - %78 = cute.get_shape(%arg3) : (!cute.layout<"(128,8):(8,1)">) -> !cute.shape<"(128,8)"> - %e0_187, %e1_188 = cute.get_leaves(%78) : !cute.shape<"(128,8)"> - %79 = cute.get_shape(%arg3) : (!cute.layout<"(128,8):(8,1)">) -> !cute.shape<"(128,8)"> - %e0_189, %e1_190 = cute.get_leaves(%79) : !cute.shape<"(128,8)"> - %80 = cute.get(%arg3) <{mode = [1]}> : !cute.layout<"(128,8):(8,1)"> -> !cute.layout<"8:1"> - %rmem = cute.memref.alloca(%80) : !memref_rmem_f32 - %iter_191 = cute.get_iter(%rmem) : !memref_rmem_f32 - %iter_192 = cute.get_iter(%rmem) : !memref_rmem_f32 - %81 = cute.get_shape(%arg4) : (!cute.layout<"(128,4):(4,1)">) -> !cute.shape<"(128,4)"> - %e0_193, %e1_194 = cute.get_leaves(%81) : !cute.shape<"(128,4)"> - %82 = cute.get_shape(%arg4) : (!cute.layout<"(128,4):(4,1)">) -> !cute.shape<"(128,4)"> - %e0_195, %e1_196 = cute.get_leaves(%82) : !cute.shape<"(128,4)"> - %83 = cute.get(%arg4) <{mode = [1]}> : !cute.layout<"(128,4):(4,1)"> -> !cute.layout<"4:1"> - %rmem_197 = cute.memref.alloca(%83) : !memref_rmem_i8 - %iter_198 = cute.get_iter(%rmem_197) : !memref_rmem_i8 - %iter_199 = cute.get_iter(%rmem_197) : !memref_rmem_i8 - %atom = cute.make_atom() : () -> !cute_nvgpu.atom.universal_copy - cute.copy(%atom, %slice_149, %rmem) : (!cute_nvgpu.atom.universal_copy, !memref_gmem_f32_3, !memref_rmem_f32) - %lay_200 = cute.get_layout(%rmem) : !memref_rmem_f32 - %84 = cute.get_shape(%lay_200) : (!cute.layout<"8:1">) -> !cute.shape<"8"> - %e0_201 = cute.get_leaves(%84) : !cute.shape<"8"> - %85 = cute.memref.load_vec %rmem, row_major : !memref_rmem_f32 - %86 = nvgpu.cvt_fptrunc %85 : vector<8xf32> to vector<8xf4E2M1FN> - %shape = cute.make_shape() : () -> !cute.shape<"8"> - %lay_202 = cute.make_layout(%shape) : !cute.layout<"8:1"> - %87 = cute.recast_layout<8, 4> (%lay_202) : !cute.layout<"8:1"> to !cute.layout<"4:1"> - %88 = cute.get_shape(%87) : (!cute.layout<"4:1">) -> !cute.shape<"4"> - %e0_203 = cute.get_leaves(%88) : !cute.shape<"4"> - %89 = builtin.unrealized_conversion_cast %86 : vector<8xf4E2M1FN> to vector<8xi4> - %90 = vector.bitcast %89 : vector<8xi4> to vector<4xi8> - %lay_204 = cute.get_layout(%rmem_197) : !memref_rmem_i8 - %91 = cute.get_shape(%lay_204) : (!cute.layout<"4:1">) -> !cute.shape<"4"> - %e0_205 = cute.get_leaves(%91) : !cute.shape<"4"> - %int_tuple = cute.make_int_tuple() : () -> !cute.int_tuple<"4"> - %sz = cute.size(%int_tuple) : (!cute.int_tuple<"4">) -> !cute.int_tuple<"4"> - %e0_206 = cute.get_leaves(%sz) : !cute.int_tuple<"4"> - %int_tuple_207 = cute.make_int_tuple() : () -> !cute.int_tuple<"4"> - %sz_208 = cute.size(%int_tuple_207) : (!cute.int_tuple<"4">) -> !cute.int_tuple<"4"> - %e0_209 = cute.get_leaves(%sz_208) : !cute.int_tuple<"4"> - cute.memref.store_vec %90, %rmem_197, row_major : !memref_rmem_i8 - %atom_210 = cute.make_atom() : () -> !cute_nvgpu.atom.universal_copy - cute.copy(%atom_210, %rmem_197, %slice_153) : (!cute_nvgpu.atom.universal_copy, !memref_rmem_i8, !memref_gmem_i8_3) - } - return - } - } - func.func @cutlass__convert_Tensorgmemoi64i641_Tensorgmemoi641i64_1_8(%arg0: !memref_gmem_f32_4, %arg1: !memref_gmem_f4E2M1FN) attributes {llvm.emit_c_interface} { - %iter = cute.get_iter(%arg0) : !memref_gmem_f32_4 - %iter_0 = cute.get_iter(%arg1) : !memref_gmem_f4E2M1FN - %iter_1 = cute.get_iter(%arg0) : !memref_gmem_f32_4 - %iter_2 = cute.get_iter(%arg1) : !memref_gmem_f4E2M1FN - %lay = cute.get_layout(%arg0) : !memref_gmem_f32_4 - %0 = cute.get_shape(%lay) : (!cute.layout<"(?,?,?):(?{i64},?{i64},1)">) -> !cute.shape<"(?,?,?)"> - %e0, %e1, %e2 = cute.get_leaves(%0) : !cute.shape<"(?,?,?)"> - %itup = cute.to_int_tuple(%e0) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %1 = cute.get_scalars(%itup) : !cute.int_tuple<"?"> - %itup_3 = cute.to_int_tuple(%e1) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %2 = cute.get_scalars(%itup_3) : !cute.int_tuple<"?"> - %itup_4 = cute.to_int_tuple(%e2) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %3 = cute.get_scalars(%itup_4) : !cute.int_tuple<"?"> - %4 = cute.get_stride(%lay) : (!cute.layout<"(?,?,?):(?{i64},?{i64},1)">) -> !cute.stride<"(?{i64},?{i64},1)"> - %e0_5, %e1_6, %e2_7 = cute.get_leaves(%4) : !cute.stride<"(?{i64},?{i64},1)"> - %itup_8 = cute.to_int_tuple(%e0_5) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %5 = cute.get_scalars(%itup_8) : !cute.int_tuple<"?{i64}"> - %itup_9 = cute.to_int_tuple(%e1_6) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %6 = cute.get_scalars(%itup_9) : !cute.int_tuple<"?{i64}"> - %lay_10 = cute.get_layout(%arg1) : !memref_gmem_f4E2M1FN - %7 = cute.get_shape(%lay_10) : (!cute.layout<"(?,?,?):(?{i64},1,?{i64})">) -> !cute.shape<"(?,?,?)"> - %e0_11, %e1_12, %e2_13 = cute.get_leaves(%7) : !cute.shape<"(?,?,?)"> - %itup_14 = cute.to_int_tuple(%e0_11) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %8 = cute.get_scalars(%itup_14) : !cute.int_tuple<"?"> - %itup_15 = cute.to_int_tuple(%e1_12) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %9 = cute.get_scalars(%itup_15) : !cute.int_tuple<"?"> - %itup_16 = cute.to_int_tuple(%e2_13) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %10 = cute.get_scalars(%itup_16) : !cute.int_tuple<"?"> - %11 = cute.get_stride(%lay_10) : (!cute.layout<"(?,?,?):(?{i64},1,?{i64})">) -> !cute.stride<"(?{i64},1,?{i64})"> - %e0_17, %e1_18, %e2_19 = cute.get_leaves(%11) : !cute.stride<"(?{i64},1,?{i64})"> - %itup_20 = cute.to_int_tuple(%e0_17) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %12 = cute.get_scalars(%itup_20) : !cute.int_tuple<"?{i64}"> - %itup_21 = cute.to_int_tuple(%e2_19) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %13 = cute.get_scalars(%itup_21) : !cute.int_tuple<"?{i64}"> - %shape = cute.make_shape() : () -> !cute.shape<"(128,8)"> - %stride = cute.make_stride() : () -> !cute.stride<"(8,1)"> - %lay_22 = cute.make_layout(%shape, %stride) : !cute.layout<"(128,8):(8,1)"> - %14 = cute.recast_layout<8, 4> (%lay_22) : !cute.layout<"(128,8):(8,1)"> to !cute.layout<"(128,4):(4,1)"> - %iter_23 = cute.recast_iter(%iter_2) : !cute.ptr> to !cute.ptr> - %15 = cute.recast_layout<8, 4> (%lay_10) : !cute.layout<"(?,?,?):(?{i64},1,?{i64})"> to !cute.layout<"(?,?,?):(?{i64},1,?{i64})"> - %view = cute.make_view(%iter_23, %15) : !memref_gmem_i8_4 - %iter_24 = cute.get_iter(%view) : !memref_gmem_i8_4 - %16 = cute.get_shape(%lay) : (!cute.layout<"(?,?,?):(?{i64},?{i64},1)">) -> !cute.shape<"(?,?,?)"> - %e0_25, %e1_26, %e2_27 = cute.get_leaves(%16) : !cute.shape<"(?,?,?)"> - %itup_28 = cute.to_int_tuple(%e0_25) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %17 = cute.get_scalars(%itup_28) : !cute.int_tuple<"?"> - %itup_29 = cute.to_int_tuple(%e1_26) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %18 = cute.get_scalars(%itup_29) : !cute.int_tuple<"?"> - %itup_30 = cute.to_int_tuple(%e2_27) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %19 = cute.get_scalars(%itup_30) : !cute.int_tuple<"?"> - %shape_31 = cute.make_shape(%itup_28, %itup_29, %itup_30) : (!cute.int_tuple<"?">, !cute.int_tuple<"?">, !cute.int_tuple<"?">) -> !cute.shape<"(?,?,?)"> - %20 = cute.make_identity_tensor(%shape_31) : !cute.coord_tensor<"(0,0,0)", "(?,?,?):(1@0,1@1,1@2)"> - %iter_32 = cute.get_iter(%20) : !cute.coord_tensor<"(0,0,0)", "(?,?,?):(1@0,1@1,1@2)"> - %tup = cute.deref_arith_tuple_iter(%iter_32) : !cute.arith_tuple_iter<"(0,0,0)"> - %e0_33, %e1_34, %e2_35 = cute.get_leaves(%tup) : !cute.int_tuple<"(0,0,0)"> - %21 = cute.get_shape(%lay) : (!cute.layout<"(?,?,?):(?{i64},?{i64},1)">) -> !cute.shape<"(?,?,?)"> - %e0_36, %e1_37, %e2_38 = cute.get_leaves(%21) : !cute.shape<"(?,?,?)"> - %itup_39 = cute.to_int_tuple(%e0_36) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %22 = cute.get_scalars(%itup_39) : !cute.int_tuple<"?"> - %itup_40 = cute.to_int_tuple(%e1_37) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %23 = cute.get_scalars(%itup_40) : !cute.int_tuple<"?"> - %itup_41 = cute.to_int_tuple(%e2_38) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %24 = cute.get_scalars(%itup_41) : !cute.int_tuple<"?"> - %sz = cute.size(%lay_22) : (!cute.layout<"(128,8):(8,1)">) -> !cute.int_tuple<"1024"> - %e0_42 = cute.get_leaves(%sz) : !cute.int_tuple<"1024"> - %lay_43 = cute.get_layout(%view) : !memref_gmem_i8_4 - %25 = cute.get_shape(%lay_43) : (!cute.layout<"(?,?,?):(?{i64},1,?{i64})">) -> !cute.shape<"(?,?,?)"> - %e0_44, %e1_45, %e2_46 = cute.get_leaves(%25) : !cute.shape<"(?,?,?)"> - %itup_47 = cute.to_int_tuple(%e0_44) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %26 = cute.get_scalars(%itup_47) : !cute.int_tuple<"?"> - %itup_48 = cute.to_int_tuple(%e1_45) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %27 = cute.get_scalars(%itup_48) : !cute.int_tuple<"?"> - %itup_49 = cute.to_int_tuple(%e2_46) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %28 = cute.get_scalars(%itup_49) : !cute.int_tuple<"?"> - %sz_50 = cute.size(%14) : (!cute.layout<"(128,4):(4,1)">) -> !cute.int_tuple<"512"> - %e0_51 = cute.get_leaves(%sz_50) : !cute.int_tuple<"512"> - %tile = cute.make_tile() : () -> !cute.tile<"[1:0;1024:1;1:0]"> - %div = cute.zipped_divide(%arg0, %tile) : !memref_gmem_f32_4, !cute.tile<"[1:0;1024:1;1:0]"> - %iter_52 = cute.get_iter(%div) : !memref_gmem_f32 - %iter_53 = cute.get_iter(%div) : !memref_gmem_f32 - %tile_54 = cute.make_tile() : () -> !cute.tile<"[1:0;1024:1;1:0]"> - %div_55 = cute.zipped_divide(%20, %tile_54) : !cute.coord_tensor<"(0,0,0)", "(?,?,?):(1@0,1@1,1@2)">, !cute.tile<"[1:0;1024:1;1:0]"> - %iter_56 = cute.get_iter(%div_55) : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))"> - %tup_57 = cute.deref_arith_tuple_iter(%iter_56) : !cute.arith_tuple_iter<"(0,0,0)"> - %e0_58, %e1_59, %e2_60 = cute.get_leaves(%tup_57) : !cute.int_tuple<"(0,0,0)"> - %iter_61 = cute.get_iter(%div_55) : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))"> - %tup_62 = cute.deref_arith_tuple_iter(%iter_61) : !cute.arith_tuple_iter<"(0,0,0)"> - %e0_63, %e1_64, %e2_65 = cute.get_leaves(%tup_62) : !cute.int_tuple<"(0,0,0)"> - %tile_66 = cute.make_tile() : () -> !cute.tile<"[1:0;512:1;1:0]"> - %div_67 = cute.zipped_divide(%view, %tile_66) : !memref_gmem_i8_4, !cute.tile<"[1:0;512:1;1:0]"> - %iter_68 = cute.get_iter(%div_67) : !memref_gmem_i8 - %iter_69 = cute.get_iter(%div_67) : !memref_gmem_i8 - %sz_70 = cute.size(%div) <{mode = [1]}> : (!memref_gmem_f32) -> !cute.int_tuple<"?"> - %e0_71 = cute.get_leaves(%sz_70) : !cute.int_tuple<"?"> - %29 = cute.get_scalars(%e0_71) : !cute.int_tuple<"?"> - %sz_72 = cute.size(%lay_22) <{mode = [0]}> : (!cute.layout<"(128,8):(8,1)">) -> !cute.int_tuple<"128"> - %e0_73 = cute.get_leaves(%sz_72) : !cute.int_tuple<"128"> - %lay_74 = cute.get_layout(%div) : !memref_gmem_f32 - %30 = cute.get_shape(%lay_74) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,?{i64},0),(?{i64},?{i64 div=1024},1))">) -> !cute.shape<"((1,1024,1),(?,?,?))"> - %e0_75, %e1_76, %e2_77, %e3, %e4, %e5 = cute.get_leaves(%30) : !cute.shape<"((1,1024,1),(?,?,?))"> - %itup_78 = cute.to_int_tuple(%e3) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %31 = cute.get_scalars(%itup_78) : !cute.int_tuple<"?"> - %itup_79 = cute.to_int_tuple(%e4) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %32 = cute.get_scalars(%itup_79) : !cute.int_tuple<"?"> - %itup_80 = cute.to_int_tuple(%e5) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %33 = cute.get_scalars(%itup_80) : !cute.int_tuple<"?"> - %34 = cute.get_stride(%lay_74) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,?{i64},0),(?{i64},?{i64 div=1024},1))">) -> !cute.stride<"((0,?{i64},0),(?{i64},?{i64 div=1024},1))"> - %e0_81, %e1_82, %e2_83, %e3_84, %e4_85, %e5_86 = cute.get_leaves(%34) : !cute.stride<"((0,?{i64},0),(?{i64},?{i64 div=1024},1))"> - %itup_87 = cute.to_int_tuple(%e1_82) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %35 = cute.get_scalars(%itup_87) : !cute.int_tuple<"?{i64}"> - %itup_88 = cute.to_int_tuple(%e3_84) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %36 = cute.get_scalars(%itup_88) : !cute.int_tuple<"?{i64}"> - %itup_89 = cute.to_int_tuple(%e4_85) : !cute.stride<"?{i64 div=1024}"> to !cute.int_tuple<"?{i64 div=1024}"> - %37 = cute.get_scalars(%itup_89) : !cute.int_tuple<"?{i64 div=1024}"> - %lay_90 = cute.get_layout(%div_67) : !memref_gmem_i8 - %38 = cute.get_shape(%lay_90) : (!cute.layout<"((1,512,1),(?,?,?)):((0,1,0),(?{i64},512,?{i64}))">) -> !cute.shape<"((1,512,1),(?,?,?))"> - %e0_91, %e1_92, %e2_93, %e3_94, %e4_95, %e5_96 = cute.get_leaves(%38) : !cute.shape<"((1,512,1),(?,?,?))"> - %itup_97 = cute.to_int_tuple(%e3_94) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %39 = cute.get_scalars(%itup_97) : !cute.int_tuple<"?"> - %itup_98 = cute.to_int_tuple(%e4_95) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %40 = cute.get_scalars(%itup_98) : !cute.int_tuple<"?"> - %itup_99 = cute.to_int_tuple(%e5_96) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %41 = cute.get_scalars(%itup_99) : !cute.int_tuple<"?"> - %42 = cute.get_stride(%lay_90) : (!cute.layout<"((1,512,1),(?,?,?)):((0,1,0),(?{i64},512,?{i64}))">) -> !cute.stride<"((0,1,0),(?{i64},512,?{i64}))"> - %e0_100, %e1_101, %e2_102, %e3_103, %e4_104, %e5_105 = cute.get_leaves(%42) : !cute.stride<"((0,1,0),(?{i64},512,?{i64}))"> - %itup_106 = cute.to_int_tuple(%e3_103) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %43 = cute.get_scalars(%itup_106) : !cute.int_tuple<"?{i64}"> - %itup_107 = cute.to_int_tuple(%e5_105) : !cute.stride<"?{i64}"> to !cute.int_tuple<"?{i64}"> - %44 = cute.get_scalars(%itup_107) : !cute.int_tuple<"?{i64}"> - %lay_108 = cute.get_layout(%div_55) : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))"> - %45 = cute.get_shape(%lay_108) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">) -> !cute.shape<"((1,1024,1),(?,?,?))"> - %e0_109, %e1_110, %e2_111, %e3_112, %e4_113, %e5_114 = cute.get_leaves(%45) : !cute.shape<"((1,1024,1),(?,?,?))"> - %itup_115 = cute.to_int_tuple(%e3_112) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %46 = cute.get_scalars(%itup_115) : !cute.int_tuple<"?"> - %itup_116 = cute.to_int_tuple(%e4_113) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %47 = cute.get_scalars(%itup_116) : !cute.int_tuple<"?"> - %itup_117 = cute.to_int_tuple(%e5_114) : !cute.shape<"?"> to !cute.int_tuple<"?"> - %48 = cute.get_scalars(%itup_117) : !cute.int_tuple<"?"> - %49 = cute.get_stride(%lay_108) : (!cute.layout<"((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">) -> !cute.stride<"((0,1@1,0),(1@0,1024@1,1@2))"> - %e0_118, %e1_119, %e2_120, %e3_121, %e4_122, %e5_123 = cute.get_leaves(%49) : !cute.stride<"((0,1@1,0),(1@0,1024@1,1@2))"> - %50 = cute.get_shape(%lay_22) : (!cute.layout<"(128,8):(8,1)">) -> !cute.shape<"(128,8)"> - %e0_124, %e1_125 = cute.get_leaves(%50) : !cute.shape<"(128,8)"> - %51 = cute.get_stride(%lay_22) : (!cute.layout<"(128,8):(8,1)">) -> !cute.stride<"(8,1)"> - %e0_126, %e1_127 = cute.get_leaves(%51) : !cute.stride<"(8,1)"> - %52 = cute.get_shape(%14) : (!cute.layout<"(128,4):(4,1)">) -> !cute.shape<"(128,4)"> - %e0_128, %e1_129 = cute.get_leaves(%52) : !cute.shape<"(128,4)"> - %53 = cute.get_stride(%14) : (!cute.layout<"(128,4):(4,1)">) -> !cute.stride<"(4,1)"> - %e0_130, %e1_131 = cute.get_leaves(%53) : !cute.stride<"(4,1)"> - %c0_i32 = arith.constant 0 : i32 - %54 = arith.index_cast %29 : i32 to index - %c1 = arith.constant 1 : index - %c128 = arith.constant 128 : index - gpu.launch_func @kernels::@kernel_cutlass__convert_kernel_tensorptrf32gmemo11024100div10241_tensorptri8gmemalign16o15121010512_tensor000o1102410110101024112____Float32_Float4E2M1FN_0 blocks in (%54, %c1, %c1) threads in (%c128, %c1, %c1) dynamic_shared_memory_size %c0_i32 args(%div : !memref_gmem_f32, %div_67 : !memref_gmem_i8, %div_55 : !cute.coord_tensor<"(0,0,0)", "((1,1024,1),(?,?,?)):((0,1@1,0),(1@0,1024@1,1@2))">, %lay_22 : !cute.layout<"(128,8):(8,1)">, %14 : !cute.layout<"(128,4):(4,1)">, %17 : i32, %18 : i32, %19 : i32) {use_pdl = false} - return - } -} diff --git a/problems/nvidia/nvfp4_gemv/log b/problems/nvidia/nvfp4_gemv/log deleted file mode 100644 index bc55151..0000000 --- a/problems/nvidia/nvfp4_gemv/log +++ /dev/null @@ -1,15 +0,0 @@ -============================================================ -Launching Blackwell NVFP4 GEMV Test ------------------------------------------------------------- -Input dimensions: - A: (1, 512, 256) [l: batch size, m: rows, k: cols] - b: (1, 256) [l: batch size, k: length] - c: (1, 512) [l: batch size, m: length] -Data types: - A/b dtype: Float4E2M1FN - Scaling factor dtype: Float8E8M0FNU (vector size: 16) - Output C dtype: Float16 -Validation tolerance: 0.1 -============================================================ -c_ref and ref are close within tolerance. -PASS diff --git a/problems/nvidia/nvfp4_gemv/test_python_1.sh b/problems/nvidia/nvfp4_gemv/test_python_1.sh deleted file mode 100644 index c7087c1..0000000 --- a/problems/nvidia/nvfp4_gemv/test_python_1.sh +++ /dev/null @@ -1,86 +0,0 @@ -# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build -BUILD_DIR=/home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/build_non_docker -LLVM_DIR=$BUILD_DIR/llvm-prebuilt -# # BUILD_DIR=/home/scratch.ftse_gpu/workspace/dkg/build -# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build -# #BUILD_DIR=/home/yanchengz/scratch_1/dynamic-kernel-generator/build_debug2 -# # sudo /home/scratch.computelab/utils/driver/install_driver.py --installer=/home/builds/daily/display/x86_64/rel/gpu_drv/r580/r580_00/20250527_36037303/NVIDIA-Linux-x86_64-rel_gpu_drv_r580_r580_00-20250527_36037303-internal.run --reason="Change to tot driver" - - -# # BUILD_DIR=/home/scratch.nbommi_gpu/warp-phase-trace/dynamic-kernel-generator/build_main - -export PYTHONPATH=$BUILD_DIR/cutlass_ir/python_packages -#export PYTHONPATH=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/scripts -export CUDA_TOOLKIT_PATH=$BUILD_DIR/compiler_next -MLIR_CUDA_RUNTIME="$LLVM_DIR/lib/libmlir_cuda_runtime.so" -MLIR_C_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_c_runner_utils.so" -MLIR_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_runner_utils.so" -CUDA_DIALECT_RUNTIME="$BUILD_DIR/lib/libcuda_dialect_runtime.so" -export CUTE_DSL_LIBS="$MLIR_CUDA_RUNTIME:$MLIR_C_RUNNER_UTILS:$MLIR_RUNNER_UTILS:$CUDA_DIALECT_RUNTIME" - - -#export CUTE_DSL_PREPROCESSOR=True - -# export CUTE_DSL_PRINT_IR=1 -# just compile the IR but not execute it -# export CUTE_DSL_DRYRUN=1 -# export CUTE_DSL_JIT_TIME_PROFILING=ON -# export CUTE_DSL_KEEP_IR=True -# export CUTE_DSL_PRINT_IR=1 -# export CUTE_DSL_KEEP_CUBIN=1 -# export CUTE_DSL_LINEINFO=True -# export CUTE_DSL_LOG_TO_CONSOLE=1 -# export PYTHONUNBUFFERED=1 -# export CUTE_DSL_KEEP_SASS=1 -# whether to show detailed log in preprocessing -# export CUTE_DSL_FILTER_STACKTRACE=10 -export CUTE_DSL_ARCH=sm_100a - -# -/home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py -/home/scratch.vickiw_gpu/env/bin/python3 eval.py test task.yml -/home/scratch.vickiw_gpu/env/bin/python3 eval.py benchmark task.yml -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/cuda-gdb --args - -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py -# # /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gated_dual_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gecccccbkvnjtrvtfreufijlfglnudnvuggvdfucidbnhk -# mm.py -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemm/nvfp4_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemv/nvfp4_gemv.py -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool memcheck \ -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,16384 #135us -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 4096,128,7168 #62 - -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,2048 #26 - - -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gated_dual_gemm/nvfp4_gated_dual_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_naive.py - - - -# print out ncu time -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# python3 vicki/tutorial_fp16_gemm_0__.py --mnk 7168,8,512 - -# use sanitizer to check race contention and memref error -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck|memcheck -# cutlass_ir/compiler/test/python/examples/sm_100a/test_nvfp4_gemv.py - -# capture ncu report -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --check-exit-code 0 -f --set full --import-source yes --target-processes all --clock-control base --cache-control none -o gemv_4.1 \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv.py --m 128 --k 128 --l 2 - -# regular run python example -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/min_latency_hmma.py --mnkl 7168,8,512,1 - -# run pytest -# pytest cutlass_ir/compiler/test/python/examples/sm_80/test_sgemm.py diff --git a/problems/nvidia/nvfp4_gemv/utils.py b/problems/nvidia/nvfp4_gemv/utils.py deleted file mode 100644 index e8a9082..0000000 --- a/problems/nvidia/nvfp4_gemv/utils.py +++ /dev/null @@ -1,176 +0,0 @@ -import os -import random -import numpy as np -import torch - - -def set_seed(seed=42): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def get_device(use_cuda: bool = True) -> torch.device: - """Get the appropriate device (GPU or CPU).""" - if use_cuda: - if torch.cuda.is_available(): - return torch.device("cuda") - elif torch.backends.mps.is_available(): - return torch.device("mps") - else: - print("No compatible GPU found. Falling back to CPU.") - return torch.device("cpu") - - -# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py -@torch.no_grad() -def verbose_allclose( - received: torch.Tensor, - expected: torch.Tensor, - rtol=1e-05, - atol=1e-08, - max_print=5 -) -> list[str]: - """ - Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. - - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - rtol (float): Relative tolerance; relative to expected - atol (float): Absolute tolerance. - max_print (int): Maximum number of mismatched elements to print. - - Raises: - AssertionError: If the tensors are not all close within the given tolerance. - """ - # Check if the shapes of the tensors match - if received.shape != expected.shape: - return ["SIZE MISMATCH"] - - # Calculate the difference between the tensors - diff = torch.abs(received - expected) - - # Determine the tolerance - tolerance = atol + rtol * torch.abs(expected) - - # Find tolerance mismatched elements - tol_mismatched = diff > tolerance - - # Find nan mismatched elements - nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) - - # Find +inf mismatched elements - posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) - # Find -inf mismatched elements - neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) - - # Find all mismatched elements - mismatched = torch.logical_or( - torch.logical_or(tol_mismatched, nan_mismatched), - torch.logical_or(posinf_mismatched, neginf_mismatched), - ) - - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") - return mismatch_details - - return [] - - -@torch.no_grad() -def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5): - """ - Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. - - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - max_print (int): Maximum number of mismatched elements to print. - - Returns: - Empty string if tensors are equal, otherwise detailed error information - """ - mismatched = torch.not_equal(received, expected) - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") - return mismatch_details - - return [] - - -def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]: - """ - Convenient "default" implementation for tasks' `check_implementation` function. - """ - expected = reference(data) - reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) - - if len(reasons) > 0: - return False, "mismatch found! custom implementation doesn't match reference: " + " ".join(reasons) - - return True, '' - - -def make_match_reference(reference: callable, **kwargs): - def wrapped(data, output): - return match_reference(data, output, reference=reference, **kwargs) - return wrapped - - -class DeterministicContext: - def __init__(self): - self.allow_tf32 = None - self.deterministic = None - self.cublas = None - - def __enter__(self): - self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '') - self.allow_tf32 = torch.backends.cudnn.allow_tf32 - self.deterministic = torch.backends.cudnn.deterministic - torch.backends.cudnn.allow_tf32 = False - torch.backends.cudnn.deterministic = True - torch.use_deterministic_algorithms(True) - return self - - def __exit__(self, exc_type, exc_value, traceback): - torch.backends.cudnn.allow_tf32 = self.allow_tf32 - torch.backends.cudnn.deterministic = self.deterministic - torch.use_deterministic_algorithms(False) - os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas - -def clear_l2_cache(): - # import cupy as cp - # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) - # create a large dummy tensor - dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") - # write stuff to - dummy.fill_(42) - del dummy \ No newline at end of file From 9a1d6c9b839b89605ed72324665c1eb850f4a6fa Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Fri, 10 Oct 2025 07:42:12 -0700 Subject: [PATCH 06/29] improve testing time. --- problems/nvidia/nvfp4_gemv/reference.py | 5 +- problems/nvidia/nvfp4_gemv/submission.py | 80 ++++++++++++++++++++++-- 2 files changed, 77 insertions(+), 8 deletions(-) diff --git a/problems/nvidia/nvfp4_gemv/reference.py b/problems/nvidia/nvfp4_gemv/reference.py index 2773531..b96a724 100644 --- a/problems/nvidia/nvfp4_gemv/reference.py +++ b/problems/nvidia/nvfp4_gemv/reference.py @@ -13,6 +13,8 @@ def ref_kernel( ) -> output_t: """ PyTorch reference implementation of NVFP4 block-scaled GEMV. + This is a very slow reference implementation to show the computation details + of a block-scaled GEMV. This simulates the GEMV operation: C = A @ b where A and b are block-scaled with FP4 values and FP8 scale factors. @@ -47,7 +49,7 @@ def ref_kernel( # prune to mkl for reference check. scale_b = scale_b[:, :k, :] - # Convert to f32 for reference computation + # Convert to f32 for computation # Apply blockwise scaling: elementwise multiplication # This simulates NVFP4 GEMV via 2 FFMA based elementwise multiplication # and 1 FFMA based matmul computations @@ -71,7 +73,6 @@ def create_scale_factor_tensor(l, mn, k, block_size): ref_permute_order = (1, 2, 0) # Create f32 ref torch tensor (cpu) - # After this line, ref_f32_torch_tensor_cpu has shape (mn, scale_k, l) ref_f32_torch_tensor_cpu = torch.randint( 1, 3, ref_shape, dtype=torch.float32 ).permute(*ref_permute_order) diff --git a/problems/nvidia/nvfp4_gemv/submission.py b/problems/nvidia/nvfp4_gemv/submission.py index 563ef34..8f5a87c 100644 --- a/problems/nvidia/nvfp4_gemv/submission.py +++ b/problems/nvidia/nvfp4_gemv/submission.py @@ -220,6 +220,72 @@ def create_scale_factor_cute_tensor(ref_tensor, l, mn, k, block_size, dtype): return cute_tensor, cute_torch_tensor +# Global cache for compiled kernel +_compiled_kernel_cache = None + + +def compile_kernel(data: input_t): + """ + Compile the kernel once and cache it. + This should be called before any timing measurements. + + Args: + a, b, scale_a, scale_b, c: Sample tensors with the expected shapes and types + + Returns: + The compiled kernel function + """ + global _compiled_kernel_cache + + a, b, scale_a, scale_b, c = data + if _compiled_kernel_cache is not None: + return _compiled_kernel_cache + + # Get dimensions from MxKxL layout + m, k, l = a.shape + + # Create CuTe tensors for A, B, C + a_tensor, a_torch = cutlass_torch.cute_tensor_like( + a, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, b_torch = cutlass_torch.cute_tensor_like( + b, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, c_torch = cutlass_torch.cute_tensor_like( + c, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + # Mark tensor with element divisibility for 16B alignment + a_tensor.mark_compact_shape_dynamic( + mode=1, + stride_order=(2, 0, 1), + divisibility=32, + ) + b_tensor.mark_compact_shape_dynamic( + mode=1, + stride_order=(2, 0, 1), + divisibility=32, + ) + c_tensor.mark_compact_shape_dynamic( + 0, + (2, 1, 0), + divisibility=16, + ) + + # Create cute tensors from reference tensors + sfa_tensor, sfa_torch = create_scale_factor_cute_tensor( + scale_a, l, m, k, block_size, sf_dtype + ) + sfb_tensor, sfb_torch = create_scale_factor_cute_tensor( + scale_b, l, 1, k, block_size, sf_dtype + ) + + # Compile the kernel + _compiled_kernel_cache = cute.compile(my_kernel, a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor) + + return _compiled_kernel_cache + + def custom_kernel(data: input_t) -> output_t: """ Execute the block-scaled GEMV kernel. @@ -232,15 +298,18 @@ def custom_kernel(data: input_t) -> output_t: data: Tuple of (a, b, scale_a, scale_b, c) PyTorch tensors a: [m, k, l] - Input matrix in float4e2m1fn (simulated with uint8) b: [1, k, l] - Input vector in float4e2m1fn (simulated with uint8) - scale_a: [m, k, l] - Scale factors in float8_e4m3fnuz (simulated with FP32) - scale_b: [1, k, l] - Scale factors in float8_e4m3fnuz (simulated with FP32) - c: [m, 1, l] - Output vector in float32 + scale_a: [m, k, l] - Scale factors in float8_e8m0fnu (simulated with FP32) + scale_b: [1, k, l] - Scale factors in float8_e8m0fnu (simulated with FP32) + c: [m, 1, l] - Output vector in float16 Returns: Output tensor c with computed GEMV results """ a, b, scale_a, scale_b, c = data + # Ensure kernel is compiled (will use cached version if available) + compiled_func = compile_kernel(data) + # Get dimensions from MxKxL layout m, k, l = a.shape @@ -279,8 +348,7 @@ def custom_kernel(data: input_t) -> output_t: scale_b, l, 1, k, block_size, sf_dtype ) - # Run the compiled kernel - # INSERT_YOUR_CODE - my_kernel(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor) + # Execute the compiled kernel + compiled_func(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor) return c_torch From 2eaef115e86e6e66ef0f68d24edf39021b79a084 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Fri, 10 Oct 2025 07:49:35 -0700 Subject: [PATCH 07/29] remove useless file. --- .../nvfp4_gemv/nvfp4_gemv_cute_layout.py | 426 ------------------ 1 file changed, 426 deletions(-) delete mode 100644 problems/nvidia/nvfp4_gemv/nvfp4_gemv_cute_layout.py diff --git a/problems/nvidia/nvfp4_gemv/nvfp4_gemv_cute_layout.py b/problems/nvidia/nvfp4_gemv/nvfp4_gemv_cute_layout.py deleted file mode 100644 index 662aa92..0000000 --- a/problems/nvidia/nvfp4_gemv/nvfp4_gemv_cute_layout.py +++ /dev/null @@ -1,426 +0,0 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: - -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. - -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. - -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. - -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import argparse -import cuda.bindings.driver as cuda - -import torch - -import cutlass -import cutlass.cute as cute -import cutlass.torch as cutlass_torch -from cutlass.cute.runtime import from_dlpack -import cutlass.utils.blockscaled_layout as blockscaled_utils - -mma_tiler_mnk = (128, 1, 64) -ab_dtype = cutlass.Float4E2M1FN -sf_dtype = cutlass.Float8E8M0FNU -c_dtype = cutlass.Float16 -sf_vec_size = 16 - -""" -Below code gives a reference for NVFP4 block-scaled GEMV (General Matrix-Vector Multiplication): - -Given: - - A: a matrix of shape (l, m, k), where l is the batch size, m is the number of rows, k is the number of columns. The data type is Float4E2M1FN - - SFA: a matrix of shape (l, m, k//scaling_factor_vector), where l is the batch size, m is the number of rows, k is the number of columns, and scaling factor vector size means these elements will share the same scaling factor. The data type is Float8E8M0FNU. The layout matches definition here https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout. - - b: a batched vector of shape (l, k) and the data type is Float4E2M1FN. - - SFB: a matrix of shape (l, k//scaling_factor_vector, 128) and the data type is Float8E8M0FNU. The layout matches definition here https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout. - - c: the output batched vector of shape (l, m) and the data type is Float16. - -Operation: - c = A * b - -Assumptions: - - The matrix A is stored in memory such that the k (column) dimension is contiguous - - The m dimension is a multiple of 128 - - The k dimension is a multiple of 64 - -""" - - -class Sm100BlockScaledDenseGemvKernel: - def __init__(self): - self.threads_per_cta = 128 - - @cute.jit - def __call__( - self, - a_tensor: cute.Tensor, - b_tensor: cute.Tensor, - sfa_tensor: cute.Tensor, - sfb_tensor: cute.Tensor, - c_tensor: cute.Tensor, - stream: cuda.CUstream, - epilogue_op: cutlass.Constexpr = lambda x: x, - ): - # (((32, 4), REST_M), ((SF_K, 4), REST_K), (1, REST_L)) - sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( - a_tensor.shape, sf_vec_size - ) - sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout) - # (((32, 4), REST_M), ((SF_K, 4), REST_K), (1, REST_L)) - sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( - b_tensor.shape, sf_vec_size - ) - sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout) - # Compute grid size - grid = ( - cute.ceil_div(c_tensor.shape[0], 128), - 1, - c_tensor.shape[2], - ) - # Launch the kernel synchronously - self.kernel(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor).launch( - grid=grid, - block=[self.threads_per_cta, 1, 1], - cluster=(1, 1, 1), - stream=stream, - ) - return - - # GPU device kernel - @cute.kernel - def kernel( - self, - mA_mkl: cute.Tensor, - mB_nkl: cute.Tensor, - mSFA_mkl: cute.Tensor, - mSFB_nkl: cute.Tensor, - mC_mnl: cute.Tensor, - ): - bidx, bidy, bidz = cute.arch.block_idx() - tidx, _, _ = cute.arch.thread_idx() - # mma_coord_mnk = (bidx, bidy, bidz) - - # (bM, bK, RestM, RestK, RestL) - gA_mkl = cute.local_tile( - mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) - ) - # (bM, bK, RestM, RestK, RestL) - # bM = (32, 4) - # bK = (16, 4) - gSFA_mkl = cute.local_tile( - mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) - ) - # (bN, bK, RestN, RestK, RestL) - gB_nkl = cute.local_tile( - mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - # (bN, bK, RestN, RestK, RestL) - gSFB_nkl = cute.local_tile( - mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - # (bM, bN, RestM, RestN, RestL) - gC_mnl = cute.local_tile( - mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) - ) - - tCgC = gC_mnl[tidx, None, bidx, bidy, bidz] - tCgC = cute.make_tensor(tCgC.iterator, 1) - res = cute.zeros_like(tCgC, cutlass.Float32) - - k_tile_cnt = gA_mkl.layout[3].shape - for k_tile in range(k_tile_cnt): - tAgA = gA_mkl[tidx, None, bidx, k_tile, bidz] - tBgB = gB_nkl[None, None, bidy, k_tile, bidz] - tAgSFA = gSFA_mkl[tidx, None, bidx, k_tile, bidz] - tBgSFB = gSFB_nkl[None, None, bidy, k_tile, bidz] - - # Create Tensor for A/B/SFA/SFB tile - tAgA = cute.make_tensor(tAgA.iterator, mma_tiler_mnk[2]) - tBgB = cute.make_tensor(tBgB.iterator, mma_tiler_mnk[2]) - tAgSFA = cute.make_tensor(tAgSFA.iterator, 4) - tBgSFB = cute.make_tensor(tBgSFB.iterator, 4) - - # Load A/B/SFA/SFB tile from global memory - a_val_nvfp4 = tAgA.load() - b_val_nvfp4 = tBgB.load() - sfa_val_fp8 = tAgSFA.load() - sfb_val_fp8 = tBgSFB.load() - - # Convert to f32 for FFMA computation - a_val = a_val_nvfp4.to(cutlass.Float32) - b_val = b_val_nvfp4.to(cutlass.Float32) - sfa_val = sfa_val_fp8.to(cutlass.Float32) - sfb_val = sfb_val_fp8.to(cutlass.Float32) - - for i in cutlass.range_constexpr(mma_tiler_mnk[2] // sf_vec_size): - for j in cutlass.range_constexpr(sf_vec_size): - res += ( - a_val[i * sf_vec_size + j] - * sfa_val[i] - * b_val[i * sf_vec_size + j] - * sfb_val[i] - ) - tCgC.store(res.to(cutlass.Float16)) - return - - -@cute.jit -def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - sf_ref_tensor: cute.Tensor, - sf_mma_tensor: cute.Tensor, -): - """Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout""" - # sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l) - # group to ((32, 4, rest_m), (4, rest_k), l) - sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) - sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) - for i in cutlass.range(cute.size(sf_ref_tensor)): - mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) - sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] - - -def run_gemv( - m: int, - k: int, - l: int, - tolerance: float, -): - """ - Prepare A/B/SFA/SFB/C tensors, launch GPU kernel, and reference checking. - """ - print("=" * 60) - print("Launching Blackwell NVFP4 GEMV Test") - print("-" * 60) - print("Input dimensions:") - print(f" A: ({l}, {m}, {k}) [l: batch size, m: rows, k: cols]") - print(f" b: ({l}, {k}) [l: batch size, k: length]") - print(f" c: ({l}, {m}) [l: batch size, m: length]") - print("Data types:") - print(f" A/b dtype: {ab_dtype}") - print(f" Scaling factor dtype: {sf_dtype} (vector size: {sf_vec_size})") - print(f" Output C dtype: {c_dtype}") - print(f"Validation tolerance: {tolerance}") - print("=" * 60) - - if not torch.cuda.is_available(): - raise RuntimeError("GPU is required to run this example!") - - torch.manual_seed(1111) - - # GEMV, N must be 1 - n = 1 - - # Create tensor A/B/C - a_ref = cutlass_torch.matrix(l, m, k, False, cutlass.Float32) - b_ref = cutlass_torch.matrix(l, n, k, False, cutlass.Float32) - c_ref = cutlass_torch.matrix(l, m, n, True, cutlass.Float32) - a_tensor, a_torch = cutlass_torch.cute_tensor_like( - a_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 - ) - b_tensor, b_torch = cutlass_torch.cute_tensor_like( - b_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 - ) - c_tensor, c_torch = cutlass_torch.cute_tensor_like( - c_ref, c_dtype, is_dynamic_layout=True, assumed_align=16 - ) - - # Mark tensor with element divisibility for 16B alignment - a_tensor.mark_compact_shape_dynamic( - mode=1, - stride_order=(2, 0, 1), - divisibility=32, - ) - b_tensor.mark_compact_shape_dynamic( - mode=1, - stride_order=(2, 0, 1), - divisibility=32, - ) - c_tensor.mark_compact_shape_dynamic( - 0, - (2, 1, 0), - divisibility=16, - ) - - # - # Helper function to create scale factor tensor SFA/SFB - # for 1x16 block scaled wise use case and follow the layout requirement - # defined in https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout - # - def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype): - def ceil_div(a, b): - return (a + b - 1) // b - - sf_k = ceil_div(k, sf_vec_size) - ref_shape = (l, mn, sf_k) - - atom_m = (32, 4) - atom_k = 4 - mma_shape = ( - l, # batch size - ceil_div(mn, atom_m[0] * atom_m[1]), - ceil_div(sf_k, atom_k), - atom_m[0], - atom_m[1], - atom_k, - ) - - ref_permute_order = (1, 2, 0) - mma_permute_order = (3, 4, 1, 5, 2, 0) - - # Create f32 ref torch tensor (cpu) - ref_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( - ref_shape, - torch.float32, - permute_order=ref_permute_order, - init_type=cutlass_torch.TensorInitType.RANDOM, - init_config=cutlass_torch.RandomInitConfig( - min_val=1, - max_val=3, - ), - ) - - # Create f32 cute torch tensor (cpu) - cute_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( - mma_shape, - torch.float32, - permute_order=mma_permute_order, - init_type=cutlass_torch.TensorInitType.RANDOM, - init_config=cutlass_torch.RandomInitConfig( - min_val=0, - max_val=1, - ), - ) - - # convert ref f32 tensor to cute f32 tensor - cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - from_dlpack(ref_f32_torch_tensor_cpu), - from_dlpack(cute_f32_torch_tensor_cpu), - ) - cute_f32_torch_tensor = cute_f32_torch_tensor_cpu.cuda() - - # reshape makes memory contiguous - ref_f32_torch_tensor_cpu = ( - ref_f32_torch_tensor_cpu.permute(2, 0, 1) - .unsqueeze(-1) - .expand(l, mn, sf_k, sf_vec_size) - .reshape(l, mn, sf_k * sf_vec_size) - .permute(*ref_permute_order) - ) - # prune to mkl for reference check. - ref_f32_torch_tensor_cpu = ref_f32_torch_tensor_cpu[:, :k, :] - - # Create dtype cute torch tensor (cpu) - cute_tensor, cute_torch_tensor = cutlass_torch.cute_tensor_like( - cute_f32_torch_tensor_cpu, - dtype, - is_dynamic_layout=True, - assumed_align=16, - ) - - # Convert f32 cute tensor to dtype cute tensor - cute_tensor = cutlass_torch.convert_cute_tensor( - cute_f32_torch_tensor, - cute_tensor, - dtype, - is_dynamic_layout=True, - ) - return ref_f32_torch_tensor_cpu, cute_tensor, cute_torch_tensor - - sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor( - l, m, k, sf_vec_size, sf_dtype - ) - sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor( - l, 1, k, sf_vec_size, sf_dtype - ) - - # Configure gemv kernel - gemv = Sm100BlockScaledDenseGemvKernel() - # Initialize Stream - current_stream = cutlass_torch.default_stream() - # Compile gemv kernel - compiled_gemv = cute.compile( - gemv, - a_tensor, - b_tensor, - sfa_tensor, - sfb_tensor, - c_tensor, - current_stream, - ) - - # Launch GPU kernel - compiled_gemv(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, current_stream) - - # Compute reference result, simulate NVFP4 GEMV via 2 FFMA based elementwise multiplication and 1 FFMA based matmul computations - res_a = torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref) - res_b = torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref) - ref = torch.einsum("mkl,nkl->mnl", res_a, res_b) - - # Convert c back to f32 for comparison. - c_ref_device = c_ref.cuda() - cute.testing.convert( - c_tensor, - from_dlpack(c_ref_device, assumed_align=16).mark_layout_dynamic(leading_dim=0), - ) - c_ref = c_ref_device.cpu() - torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02) - - -if __name__ == "__main__": - - parser = argparse.ArgumentParser( - description="Example of Sm100 Dense BlockScaled GEMV." - ) - parser.add_argument( - "--m", - type=int, - default=512, - help="m dimensions", - ) - parser.add_argument( - "--k", - type=int, - default=256, - help="m dimensions", - ) - parser.add_argument( - "--l", - type=int, - default=1, - help="l dimension", - ) - parser.add_argument( - "--tolerance", type=float, default=1e-01, help="Tolerance for validation" - ) - args = parser.parse_args() - - if args.k % mma_tiler_mnk[2] != 0: - raise ValueError("K must be a multiple of 64 for this GEMV kernel.") - if args.m % mma_tiler_mnk[0] != 0: - raise ValueError("M must be a multiple of 128 for this GEMV kernel.") - - run_gemv( - args.m, - args.k, - args.l, - args.tolerance, - ) - print("PASS") From 174ffdfd8c4e4ee61581d02b0459be7786f1cb5b Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Wed, 15 Oct 2025 08:44:14 -0700 Subject: [PATCH 08/29] simplify nvfp4 gemv code --- problems/nvidia/nvfp4_gemv/reference.py | 155 ++++++------ problems/nvidia/nvfp4_gemv/submission.py | 310 +++++++++++------------ problems/nvidia/nvfp4_gemv/task.yml | 13 +- problems/nvidia/nvfp4_gemv/template.py | 16 +- 4 files changed, 238 insertions(+), 256 deletions(-) diff --git a/problems/nvidia/nvfp4_gemv/reference.py b/problems/nvidia/nvfp4_gemv/reference.py index b96a724..a6101a6 100644 --- a/problems/nvidia/nvfp4_gemv/reference.py +++ b/problems/nvidia/nvfp4_gemv/reference.py @@ -2,81 +2,55 @@ from task import input_t, output_t from utils import make_match_reference -block_size = 16 +# Scaling factor vector size +sf_vec_size = 16 +# Helper function for ceiling division def ceil_div(a, b): - """Helper function for ceiling division""" return (a + b - 1) // b +# Helper function to convert scale factor tensor to blocked format +def to_blocked(input_matrix): + rows, cols = input_matrix.shape + + # Please ensure rows and cols are multiples of 128 and 4 respectively + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + padded = input_matrix + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + return rearranged.flatten() + def ref_kernel( data: input_t, ) -> output_t: """ PyTorch reference implementation of NVFP4 block-scaled GEMV. - This is a very slow reference implementation to show the computation details - of a block-scaled GEMV. - - This simulates the GEMV operation: C = A @ b - where A and b are block-scaled with FP4 values and FP8 scale factors. """ - a, b, scale_a, scale_b, c = data + a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, c_ref = data # Get dimensions from MxKxL layout - m, k, l = a.shape - n = 1 # GEMV: N dimension is always 1 + _, _, l = a_ref.shape - scale_k = ceil_div(k, block_size) + # Call torch._scaled_mm to compute the GEMV result + for l_idx in range(l): + # Convert the scale factor tensor to blocked format + scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx]) + scale_b = to_blocked(sfb_ref_cpu[:, :, l_idx]) + # (m, k) @ (n, k).T -> (m, n) + res = torch._scaled_mm( + a_ref[:, :, l_idx], + b_ref[:, :, l_idx].transpose(0, 1), + scale_a.cuda(), + scale_b.cuda(), + bias=None, + out_dtype=torch.float16, + ) + c_ref[:, 0, l_idx] = res[:, 0] + return c_ref - # Extend scale factor tensor from [m, scale_k, l] to [m, k, l] - ref_permute_order = (1, 2, 0) - scale_a = ( - scale_a.permute(2, 0, 1) - .unsqueeze(-1) - .expand(l, m, scale_k, block_size) - .reshape(l, m, scale_k * block_size) - .permute(*ref_permute_order) - ) - # prune to mkl for reference check. - scale_a = scale_a[:, :k, :] - - scale_b = ( - scale_b.permute(2, 0, 1) - .unsqueeze(-1) - .expand(l, n, scale_k, block_size) - .reshape(l, n, scale_k * block_size) - .permute(*ref_permute_order) - ) - # prune to mkl for reference check. - scale_b = scale_b[:, :k, :] - - # Convert to f32 for computation - # Apply blockwise scaling: elementwise multiplication - # This simulates NVFP4 GEMV via 2 FFMA based elementwise multiplication - # and 1 FFMA based matmul computations - res_a = a.to(torch.float32) * scale_a.cuda() # [m, k, l] - res_b = b.to(torch.float32) * scale_b.cuda() # [1, k, l] - - # Compute batched GEMV: C[m, n, l] = A[m, k, l] @ B[n, k, l] - for i in range(c.shape[2]): - # matmul gives [m], convert to c.dtype then assign to [m, 1] - acc = res_a[:, :, i] @ res_b[0, :, i] - c[:, 0, i] = acc.to(torch.float16) - return c - - -# Helper function to create reference scale factor tensor SFA/SFB -# for 1x16 block scaled wise use case and follow the layout requirement -# defined in https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout -def create_scale_factor_tensor(l, mn, k, block_size): - scale_k = ceil_div(k, block_size) - ref_shape = (l, mn, scale_k) - ref_permute_order = (1, 2, 0) - - # Create f32 ref torch tensor (cpu) - ref_f32_torch_tensor_cpu = torch.randint( - 1, 3, ref_shape, dtype=torch.float32 - ).permute(*ref_permute_order) - return ref_f32_torch_tensor_cpu def generate_input( m: int, @@ -97,25 +71,54 @@ def generate_input( Returns: Tuple of (a, b, scale_a, scale_b, c) where: - a: [m, k, l] - Input matrix in FP4 (simulated with uint8) - b: [1, k, l] - Input vector in FP4 (simulated with uint8) - scale_a: [m, k, l] - Expanded scale factors for a in FP32 - scale_b: [1, k, l] - Expanded scale factors for b in FP32 - c: [m, 1, l] - Output vector in FP16 + a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type + b: [1, k, l] - Input vector in torch.float4e2m1fn_x2 data type + scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b: [1, k, l] - Input scale factors in torch.float8e4m3fn data type + c: [m, 1, l] - Output vector in torch.float16 data type """ torch.manual_seed(seed) - n = 1 # GEMV: N dimension is always 1 - - # Generate random FP32 values, then convert to uint8 (FP4 placeholder) - a = torch.randint(0, 2, (l, m, k), dtype=torch.uint8, device="cuda").permute(1, 2, 0) - b = torch.randint(1, 3, (l, n, k), dtype=torch.uint8, device="cuda").permute(1, 2, 0) - c = torch.randn((l, n, m), dtype=torch.float16, device="cuda").permute(2, 1, 0) + + # GEMV N dimension is always 1 + n = 1 + # Scaling factor needs to pad the N size to 128 + n_padded_128 = 128 - # Create scale factors with FP32 data type - scale_a = create_scale_factor_tensor(l, m, k, block_size) - scale_b = create_scale_factor_tensor(l, 1, k, block_size) + # Generate uint8 tensor, then convert to float4e2m1fn_x2 data type + a_ref = torch.randint( + 0, 2, (l, m, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + # Pad b tensor's N dimension to 128 to call torch._scaled_mm for nvfp4 dot product computation + b_ref = torch.randint( + 0, 2, (l, n_padded_128, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + a_ref = a_ref.view(torch.float4_e2m1fn_x2) + b_ref = b_ref.view(torch.float4_e2m1fn_x2) + + # Create float16 output tensor + c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute( + 1, 2, 0 + ) - return (a, b, scale_a, scale_b, c) + # Helper function to prepare the scale factor tensors + def create_scale_factor_tensors(l, mn, sf_k): + # Create the reference scale factor tensor (mn, l, sf_k) on CPU. + ref_shape = (l, mn, sf_k) + ref_permute_order = (1, 2, 0) + # Init with uint8 tensor, then convert to float8_e4m3fn + ref_f8_random_int = torch.randint(1, 3, ref_shape, dtype=torch.int8) + ref_f8_torch_tensor_cpu = ref_f8_random_int.to(dtype=torch.float8_e4m3fn) + # permute to match ref_permute_order + ref_f8_torch_tensor_cpu_permuted = ref_f8_torch_tensor_cpu.permute( + *ref_permute_order + ) + return ref_f8_torch_tensor_cpu_permuted + sf_k = ceil_div(k, sf_vec_size) + sfa_ref_cpu = create_scale_factor_tensors(l, m, sf_k) + sfb_ref_cpu = create_scale_factor_tensors(l, n_padded_128, sf_k) + + return (a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, c_ref) + check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) diff --git a/problems/nvidia/nvfp4_gemv/submission.py b/problems/nvidia/nvfp4_gemv/submission.py index 8f5a87c..03d498f 100644 --- a/problems/nvidia/nvfp4_gemv/submission.py +++ b/problems/nvidia/nvfp4_gemv/submission.py @@ -6,45 +6,49 @@ import cutlass import cutlass.cute as cute -import cutlass.torch as cutlass_torch -from cutlass.cute.runtime import from_dlpack +from cutlass.cute.runtime import make_ptr import cutlass.utils.blockscaled_layout as blockscaled_utils # Kernel configuration parameters mma_tiler_mnk = (128, 1, 64) # Tile sizes for M, N, K dimensions ab_dtype = cutlass.Float4E2M1FN # FP4 data type for A and B -sf_dtype = cutlass.Float8E8M0FNU # FP8 data type for scale factors +sf_dtype = cutlass.Float8E4M3FN # FP8 data type for scale factors c_dtype = cutlass.Float16 # FP16 output type -block_size = 16 # Scale factor block size (16 elements share one scale) +sf_vec_size = 16 # Scale factor block size (16 elements share one scale) threads_per_cta = 128 # Number of threads per CUDA thread block +# Helper function for ceiling division def ceil_div(a, b): - """Helper function for ceiling division""" return (a + b - 1) // b +# Helper function to reorder the scale factor tensor to match the layout defined in +# https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout @cute.jit def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - sf_ref_tensor: cute.Tensor, - sf_mma_tensor: cute.Tensor, + sf_ref_ptr: cute.Pointer, + sf_mma_ptr: cute.Pointer, + mn: int, + sf_k: int, + l: int, + mma_shape: tuple, ): - """ - Convert scale factor tensor from reference MxKxL layout to MMA layout. - - This follows the cuBLAS block-scaling factors layout specification: - https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout + mma_permute_order = (3, 4, 1, 5, 2, 0) + permuted_shape = tuple(mma_shape[i] for i in mma_permute_order) + cute_layout = cute.make_ordered_layout(permuted_shape, order=(2, 1, 4, 0, 3, 5)) - """ - # sf_mma_tensor has flattened shape (32, 4, rest_m, 4, rest_k, l) - # Group modes to ((32, 4, rest_m), (4, rest_k), l) for hierarchical indexing + sf_ref_tensor = cute.make_tensor( + sf_ref_ptr, cute.make_layout((mn, sf_k, l), stride=(sf_k, 1, mn * sf_k)) + ) + sf_mma_tensor = cute.make_tensor(sf_mma_ptr, cute_layout) sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) - # Copy data from reference layout to MMA layout for i in cutlass.range(cute.size(sf_ref_tensor)): mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] +# The CuTe reference implementation for NVFP4 block-scaled GEMV @cute.kernel def kernel( mA_mkl: cute.Tensor, @@ -53,36 +57,38 @@ def kernel( mSFB_nkl: cute.Tensor, mC_mnl: cute.Tensor, ): + # Get CUDA block and thread indices bidx, bidy, bidz = cute.arch.block_idx() tidx, _, _ = cute.arch.thread_idx() - # (bM, bK, RestM, RestK, RestL) + # Extract the local tile for input matrix A (shape: [block_M, block_K, rest_M, rest_K, rest_L]) gA_mkl = cute.local_tile( mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) ) - # (bM, bK, RestM, RestK, RestL) - # bM = (32, 4) - # bK = (16, 4) + # Extract the local tile for scale factor tensor for A (same shape as gA_mkl) + # Here, block_M = (32, 4); block_K = (16, 4) gSFA_mkl = cute.local_tile( mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) ) - # (bN, bK, RestN, RestK, RestL) + # Extract the local tile for input matrix B (shape: [block_N, block_K, rest_N, rest_K, rest_L]) gB_nkl = cute.local_tile( mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) ) - # (bN, bK, RestN, RestK, RestL) + # Extract the local tile for scale factor tensor for B (same shape as gB_nkl) gSFB_nkl = cute.local_tile( mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) ) - # (bM, bN, RestM, RestN, RestL) + # Extract the local tile for output matrix C (shape: [block_M, block_N, rest_M, rest_N, rest_L]) gC_mnl = cute.local_tile( mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) ) + # Select output element corresponding to this thread and block indices tCgC = gC_mnl[tidx, None, bidx, bidy, bidz] tCgC = cute.make_tensor(tCgC.iterator, 1) res = cute.zeros_like(tCgC, cutlass.Float32) + # Get the number of k tiles (depth dimension) for the reduction loop k_tile_cnt = gA_mkl.layout[3].shape for k_tile in range(k_tile_cnt): tAgA = gA_mkl[tidx, None, bidx, k_tile, bidz] @@ -90,64 +96,79 @@ def kernel( tAgSFA = gSFA_mkl[tidx, None, bidx, k_tile, bidz] tBgSFB = gSFB_nkl[None, None, bidy, k_tile, bidz] - # Load A/B/SFA/SFB tile from global memory + # Load NVFP4 or FP8 values from global memory a_val_nvfp4 = tAgA.load() b_val_nvfp4 = tBgB.load() sfa_val_fp8 = tAgSFA.load() sfb_val_fp8 = tBgSFB.load() - # Convert to f32 for FFMA computation + # Convert loaded values to float32 for computation (FFMA) a_val = a_val_nvfp4.to(cutlass.Float32) b_val = b_val_nvfp4.to(cutlass.Float32) sfa_val = sfa_val_fp8.to(cutlass.Float32) sfb_val = sfb_val_fp8.to(cutlass.Float32) - for i in cutlass.range_constexpr(mma_tiler_mnk[2] // block_size): - for j in cutlass.range_constexpr(block_size): + # Iterate over SF vector tiles and compute the scale&matmul accumulation + for i in cutlass.range_constexpr(mma_tiler_mnk[2] // sf_vec_size): + for j in cutlass.range_constexpr(sf_vec_size): + # Accumulate: (A * scaleA * B * scaleB), where scaling is per-vector res += ( - a_val[i * block_size + j] + a_val[i * sf_vec_size + j] * sfa_val[i] - * b_val[i * block_size + j] + * b_val[i * sf_vec_size + j] * sfb_val[i] ) + # Store the final float16 result back to global memory tCgC.store(res.to(cutlass.Float16)) return @cute.jit def my_kernel( - a_tensor: cute.Tensor, - b_tensor: cute.Tensor, - sfa_tensor: cute.Tensor, - sfb_tensor: cute.Tensor, - c_tensor: cute.Tensor, + a_ptr: cute.Pointer, + b_ptr: cute.Pointer, + sfa_ptr: cute.Pointer, + sfb_ptr: cute.Pointer, + c_ptr: cute.Pointer, + problem_size: tuple, ): """ Host-side JIT function to prepare tensors and launch GPU kernel. - - This function: - 1. Converts scale factor tensors to the correct MMA layout - 2. Computes grid dimensions based on tensor shapes - 3. Launches the CUDA kernel - - Args: - a_tensor: Input matrix A (CuTe tensor) - b_tensor: Input vector b (CuTe tensor) - sfa_tensor: Scale factors for A (CuTe tensor) - sfb_tensor: Scale factors for B (CuTe tensor) - c_tensor: Output vector c (CuTe tensor) """ + m, _, k, l = problem_size + # Create CuTe Tensor via pointer and problem size. + a_tensor = cute.make_tensor( + a_ptr, + cute.make_layout( + (m, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(m * k, 32)), + ), + ) + # We use n=128 to create the torch tensor to do fp4 computation via torch._scaled_mm + # then copy torch tensor to cute tensor for cute customize kernel computation + # therefore we need to ensure b_tensor has the right stride with this 128 padded size on n. + n_padded_128 = 128 + b_tensor = cute.make_tensor( + b_ptr, + cute.make_layout( + (n_padded_128, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(n_padded_128 * k, 32)), + ), + ) + c_tensor = cute.make_tensor( + c_ptr, cute.make_layout((cute.assume(m, 32), 1, l), stride=(1, 1, m)) + ) # Convert scale factor tensors to MMA layout # The layout matches Tensor Core requirements: (((32, 4), REST_M), ((SF_K, 4), REST_K), (1, REST_L)) sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( - a_tensor.shape, block_size + a_tensor.shape, sf_vec_size ) - sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout) + sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout) sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( - b_tensor.shape, block_size + b_tensor.shape, sf_vec_size ) - sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout) + sfb_tensor = cute.make_tensor(sfb_ptr, sfb_layout) # Compute grid dimensions # Grid is (M_blocks, 1, L) where: @@ -165,66 +186,56 @@ def my_kernel( block=[threads_per_cta, 1, 1], cluster=(1, 1, 1), ) + return -# Helper function for ceiling division -def ceil_div(a, b): - return (a + b - 1) // b - - -# Helper function to convert reference tensor to cute tensor -def create_scale_factor_cute_tensor(ref_tensor, l, mn, k, block_size, dtype): - - scale_k = ceil_div(k, block_size) - +# Reorder scale factor from (mn, l, sf_k) to (32, 4, rest_m, 4, rest_k, l) layout +def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor): + sf_k = ceil_div(k, sf_vec_size) atom_m = (32, 4) atom_k = 4 mma_shape = ( l, # batch size ceil_div(mn, atom_m[0] * atom_m[1]), - ceil_div(scale_k, atom_k), + ceil_div(sf_k, atom_k), atom_m[0], atom_m[1], atom_k, ) - + # Create the reordered scale factor tensor (32, 4, rest_m, 4, rest_k, l) on CPU. mma_permute_order = (3, 4, 1, 5, 2, 0) + # Generate a random int8 tensor, then convert to float8_e4m3fn + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) + reordered_f8_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) + # Permute according to mma_permute_order + reordered_f8_tensor = reordered_f8_tensor.permute(*mma_permute_order) - # Create f32 cute torch tensor (cpu) - cute_f32_torch_tensor_cpu = torch.randint( - 1, 3, mma_shape, dtype=torch.float32 - ).permute(*mma_permute_order) - - # Copy reference tensor to cute tensor in the customized data layout + # Helper function to convert scale factor tensor to CUTE-format scale factor tensor cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - from_dlpack(ref_tensor), - from_dlpack(cute_f32_torch_tensor_cpu), - ) - cute_f32_torch_tensor = cute_f32_torch_tensor_cpu.cuda() - - # Create the desired data type cute torch tensor (cpu) - cute_tensor, cute_torch_tensor = cutlass_torch.cute_tensor_like( - cute_f32_torch_tensor_cpu, - dtype, - is_dynamic_layout=True, - assumed_align=16, + make_ptr( + cutlass.Float8E4M3FN, + ref_f8_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ), + make_ptr( + cutlass.Float8E4M3FN, + reordered_f8_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ), + mn, + sf_k, + l, + mma_shape, ) - - # Convert f32 cute tensor to the desired data type cute tensor - cute_tensor = cutlass_torch.convert_cute_tensor( - cute_f32_torch_tensor, - cute_tensor, - dtype, - is_dynamic_layout=True, - ) - return cute_tensor, cute_torch_tensor + return reordered_f8_tensor.cuda() # Global cache for compiled kernel _compiled_kernel_cache = None - -def compile_kernel(data: input_t): +def compile_kernel(): """ Compile the kernel once and cache it. This should be called before any timing measurements. @@ -237,51 +248,29 @@ def compile_kernel(data: input_t): """ global _compiled_kernel_cache - a, b, scale_a, scale_b, c = data if _compiled_kernel_cache is not None: return _compiled_kernel_cache - # Get dimensions from MxKxL layout - m, k, l = a.shape - # Create CuTe tensors for A, B, C - a_tensor, a_torch = cutlass_torch.cute_tensor_like( - a, ab_dtype, is_dynamic_layout=True, assumed_align=16 + # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer + a_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 ) - b_tensor, b_torch = cutlass_torch.cute_tensor_like( - b, ab_dtype, is_dynamic_layout=True, assumed_align=16 + b_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 ) - c_tensor, c_torch = cutlass_torch.cute_tensor_like( - c, c_dtype, is_dynamic_layout=True, assumed_align=16 + c_ptr = make_ptr( + c_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 ) - - # Mark tensor with element divisibility for 16B alignment - a_tensor.mark_compact_shape_dynamic( - mode=1, - stride_order=(2, 0, 1), - divisibility=32, + sfa_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 ) - b_tensor.mark_compact_shape_dynamic( - mode=1, - stride_order=(2, 0, 1), - divisibility=32, - ) - c_tensor.mark_compact_shape_dynamic( - 0, - (2, 1, 0), - divisibility=16, - ) - - # Create cute tensors from reference tensors - sfa_tensor, sfa_torch = create_scale_factor_cute_tensor( - scale_a, l, m, k, block_size, sf_dtype - ) - sfb_tensor, sfb_torch = create_scale_factor_cute_tensor( - scale_b, l, 1, k, block_size, sf_dtype + sfb_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 ) # Compile the kernel - _compiled_kernel_cache = cute.compile(my_kernel, a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor) + _compiled_kernel_cache = cute.compile(my_kernel, a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (0, 0, 0, 0)) return _compiled_kernel_cache @@ -295,60 +284,51 @@ def custom_kernel(data: input_t) -> output_t: and returns the result. Args: - data: Tuple of (a, b, scale_a, scale_b, c) PyTorch tensors - a: [m, k, l] - Input matrix in float4e2m1fn (simulated with uint8) - b: [1, k, l] - Input vector in float4e2m1fn (simulated with uint8) - scale_a: [m, k, l] - Scale factors in float8_e8m0fnu (simulated with FP32) - scale_b: [1, k, l] - Scale factors in float8_e8m0fnu (simulated with FP32) + data: Tuple of (a, b, sfa_cpu, sfb_cpu, c) PyTorch tensors + a: [m, k, l] - Input matrix in float4e2m1fn + b: [1, k, l] - Input vector in float4e2m1fn + sfa_cpu: [m, k, l] - Scale factors in float8_e4m3fn + sfb_cpu: [1, k, l] - Scale factors in float8_e4m3fn c: [m, 1, l] - Output vector in float16 Returns: Output tensor c with computed GEMV results """ - a, b, scale_a, scale_b, c = data + a, b, sfa_cpu, sfb_cpu, c = data # Ensure kernel is compiled (will use cached version if available) - compiled_func = compile_kernel(data) - + compiled_func = compile_kernel() # Get dimensions from MxKxL layout m, k, l = a.shape - - # Create CuTe tensors for A, B, C - a_tensor, a_torch = cutlass_torch.cute_tensor_like( - a, ab_dtype, is_dynamic_layout=True, assumed_align=16 - ) - b_tensor, b_torch = cutlass_torch.cute_tensor_like( - b, ab_dtype, is_dynamic_layout=True, assumed_align=16 + # Torch use e2m1_x2 data type, thus k is halved + k = k * 2 + # GEMV N dimension is always 1 + n = 1 + # Scaling factor needs to pad the N size to 128 + n_padded_128 = 128 + + # Create the reordered scale factor tensors from the reference scale factor tensors via CuTe function. + sfa_reordered = create_reordered_scale_factor_tensor(l, m, k, sfa_cpu) + sfb_reordered = create_reordered_scale_factor_tensor(l, n_padded_128, k, sfb_cpu) + + # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer + a_ptr = make_ptr( + ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 ) - c_tensor, c_torch = cutlass_torch.cute_tensor_like( - c, c_dtype, is_dynamic_layout=True, assumed_align=16 + b_ptr = make_ptr( + ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 ) - # Mark tensor with element divisibility for 16B alignment - a_tensor.mark_compact_shape_dynamic( - mode=1, - stride_order=(2, 0, 1), - divisibility=32, + c_ptr = make_ptr( + c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 ) - b_tensor.mark_compact_shape_dynamic( - mode=1, - stride_order=(2, 0, 1), - divisibility=32, + sfa_ptr = make_ptr( + sf_dtype, sfa_reordered.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 ) - c_tensor.mark_compact_shape_dynamic( - 0, - (2, 1, 0), - divisibility=16, - ) - - # Create cute tensors from reference tensors - sfa_tensor, sfa_torch = create_scale_factor_cute_tensor( - scale_a, l, m, k, block_size, sf_dtype - ) - sfb_tensor, sfb_torch = create_scale_factor_cute_tensor( - scale_b, l, 1, k, block_size, sf_dtype + sfb_ptr = make_ptr( + sf_dtype, sfb_reordered.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 ) # Execute the compiled kernel - compiled_func(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor) - - return c_torch + compiled_func(a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l)) + + return c diff --git a/problems/nvidia/nvfp4_gemv/task.yml b/problems/nvidia/nvfp4_gemv/task.yml index 84b0665..dc0e4e5 100644 --- a/problems/nvidia/nvfp4_gemv/task.yml +++ b/problems/nvidia/nvfp4_gemv/task.yml @@ -14,17 +14,16 @@ description: | You will implement a batched matrix-vector multiplication kernel optimized for NVIDIA B200. To be explicit, you will be given a tuple of tensors: ``` - (a, scale_a, b, scale_b, c) + (a, b, sfa, sfb, c) ``` where: - * `a` is L x M x K in row-major order in nvfp4(e2m1) - * `b` is L x 1 x K in nvfp4(e2m1) - * `scale_a` is L x M x K // 16 in row-major order in fp8(e4m3fnuz) - * `scale_b` is L x 1 x K // 16 in fp8(e4m3fnuz) - * `c` is L x M x 1 in fp16 + * `a` is M x K x L in K-major order in nvfp4(e2m1) + * `b` is 1 x K x L in K-major order in nvfp4(e2m1) + * `sfa` is M x (K // 16) x L in K-major order in fp8(e4m3fnuz) + * `sfb` is 1 x (K // 16) x L in K-major order in fp8(e4m3fnuz) + * `c` is M x 1 x L in fp16 Matrix sizes `M` is divisible by mma_tiler_mn[0] defined in the kernel, `K` is divisible by 64. - The computation is using FFMA instructions to simulate NVFP4 block-scaled GEMV computation and block_size is 16. The ranking criteria is the geometric mean of the benchmark results. For the grand price, your kernel will be evaluated against the speed of light analysis and the solution closest to the speed of light will be awarded the grand price. diff --git a/problems/nvidia/nvfp4_gemv/template.py b/problems/nvidia/nvfp4_gemv/template.py index bdc1a22..2cf273e 100644 --- a/problems/nvidia/nvfp4_gemv/template.py +++ b/problems/nvidia/nvfp4_gemv/template.py @@ -3,20 +3,20 @@ def custom_kernel(data: input_t) -> output_t: """ - Reference implementation of block-scale fp8 gemm + Reference implementation of block-scale fp8 gemv Args: data: Tuple that expands to: - a: torch.Tensor[float4e2m1fn] of shape [l, m, k], - b: torch.Tensor[float4e2m1fn] of shape [l, 1, k], - scale_a: torch.Tensor[float8_e4m3fnuz] of shape [l, m, k // 16], - scale_b: torch.Tensor[float8_e4m3fnuz] of shape [l, 1, k // 16], - c: torch.Tensor[float16] of shape [l, m, 1] + a: torch.Tensor[float4e2m1fn] of shape [m, k, l], + b: torch.Tensor[float4e2m1fn] of shape [1, k, l], + sfa: torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l], + sfb: torch.Tensor[float8_e4m3fnuz] of shape [1, k // 16, l], + c: torch.Tensor[float16] of shape [m, 1, l] Returns: Tensor containing output in float16 - c: torch.Tensor[float16] of shape [l, m, 1] + c: torch.Tensor[float16] of shape [m, 1, l] """ # c: [l, m, 1] is pre-allocated memory to avoid timing allocation overhead. - a, b, scale_a, scale_b, c = data + a, b, sfa, sfb, c = data # Your implementation here From 4d3bd27fb8171e77a3f6e7417eaadad166f96c72 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Wed, 15 Oct 2025 08:44:54 -0700 Subject: [PATCH 09/29] add nvfp4 gemm code. --- problems/nvidia/nvfp4_gemm/reference.py | 120 ++++ problems/nvidia/nvfp4_gemm/submission.py | 820 +++++++++++++++++++++++ problems/nvidia/nvfp4_gemm/task.py | 11 + problems/nvidia/nvfp4_gemm/task.yml | 60 ++ problems/nvidia/nvfp4_gemm/template.py | 23 + 5 files changed, 1034 insertions(+) create mode 100644 problems/nvidia/nvfp4_gemm/reference.py create mode 100644 problems/nvidia/nvfp4_gemm/submission.py create mode 100644 problems/nvidia/nvfp4_gemm/task.py create mode 100644 problems/nvidia/nvfp4_gemm/task.yml create mode 100644 problems/nvidia/nvfp4_gemm/template.py diff --git a/problems/nvidia/nvfp4_gemm/reference.py b/problems/nvidia/nvfp4_gemm/reference.py new file mode 100644 index 0000000..18ffb7c --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/reference.py @@ -0,0 +1,120 @@ +import torch +from task import input_t, output_t +from utils import make_match_reference + +# Scaling factor vector size +sf_vec_size = 16 + +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + +# Helper function to convert scale factor tensor to blocked format +def to_blocked(input_matrix): + rows, cols = input_matrix.shape + + # Please ensure rows and cols are multiples of 128 and 4 respectively + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + padded = input_matrix + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + return rearranged.flatten() + +def ref_kernel( + data: input_t, +) -> output_t: + """ + PyTorch reference implementation of NVFP4 block-scaled GEMM. + """ + a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, c_ref = data + + # Get dimensions from MxKxL layout + _, _, l = a_ref.shape + + # Call torch._scaled_mm to compute the GEMV result + for l_idx in range(l): + # Convert the scale factor tensor to blocked format + scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx]) + scale_b = to_blocked(sfb_ref_cpu[:, :, l_idx]) + # (m, k) @ (n, k).T -> (m, n) + res = torch._scaled_mm( + a_ref[:, :, l_idx], + b_ref[:, :, l_idx].transpose(0, 1), + scale_a.cuda(), + scale_b.cuda(), + bias=None, + out_dtype=torch.float16, + ) + c_ref[:, :, l_idx] = res + return c_ref + + +def generate_input( + m: int, + n: int, + k: int, + l: int, + seed: int, +): + """ + Generate input tensors for NVFP4 block-scaled GEMV. + + This follows the pattern from nvfp4_gemv_cute_layout.py for tensor preparation. + + Args: + m: Number of rows in matrix A + k: Number of columns in A (and length of vector b) + l: Batch size + seed: Random seed for reproducibility + + Returns: + Tuple of (a, b, scale_a, scale_b, c) where: + a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type + b: [1, k, l] - Input vector in torch.float4e2m1fn_x2 data type + scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b: [1, k, l] - Input scale factors in torch.float8e4m3fn data type + c: [m, 1, l] - Output vector in torch.float16 data type + """ + torch.manual_seed(seed) + + # Generate uint8 tensor, then convert to float4e2m1fn_x2 data type + a_ref = torch.randint( + 0, 2, (l, m, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + # Pad b tensor's N dimension to 128 to call torch._scaled_mm for nvfp4 dot product computation + b_ref = torch.randint( + 0, 2, (l, n, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + a_ref = a_ref.view(torch.float4_e2m1fn_x2) + b_ref = b_ref.view(torch.float4_e2m1fn_x2) + + # Create float16 output tensor + c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute( + 1, 2, 0 + ) + + # Helper function to prepare the scale factor tensors + def create_scale_factor_tensors(l, mn, sf_k): + # Create the reference scale factor tensor (mn, l, sf_k) on CPU. + ref_shape = (l, mn, sf_k) + ref_permute_order = (1, 2, 0) + # Init with uint8 tensor, then convert to float8_e4m3fn + ref_f8_random_int = torch.randint(1, 3, ref_shape, dtype=torch.int8) + ref_f8_torch_tensor_cpu = ref_f8_random_int.to(dtype=torch.float8_e4m3fn) + # permute to match ref_permute_order + ref_f8_torch_tensor_cpu_permuted = ref_f8_torch_tensor_cpu.permute( + *ref_permute_order + ) + return ref_f8_torch_tensor_cpu_permuted + + + sf_k = ceil_div(k, sf_vec_size) + sfa_ref_cpu = create_scale_factor_tensors(l, m, sf_k) + sfb_ref_cpu = create_scale_factor_tensors(l, n, sf_k) + + return (a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, c_ref) + +check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) diff --git a/problems/nvidia/nvfp4_gemm/submission.py b/problems/nvidia/nvfp4_gemm/submission.py new file mode 100644 index 0000000..23c61e3 --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/submission.py @@ -0,0 +1,820 @@ +from torch._higher_order_ops.torchbind import call_torchbind_fake +import cuda.bindings.driver as cuda + +import torch +from task import input_t, output_t + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.runtime import make_ptr + +# Kernel configuration parameters +mma_tiler_mnk = (128, 128, 256) # Tile sizes for M, N, K dimensions +mma_inst_shape_k = 64 +ab_dtype = cutlass.Float4E2M1FN # FP4 data type for A and B +sf_dtype = cutlass.Float8E4M3FN # FP8 data type for scale factors +c_dtype = cutlass.Float16 # FP16 output type +sf_vec_size = 16 # Scale factor block size (16 elements share one scale) +threads_per_cta = 128 # Number of threads per CUDA thread block +# stage numbers of shared memory and tmem +num_acc_stage = 1 +num_ab_stage = 1 +num_tmem_alloc_cols = 512 + +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + + +# Helper function to reorder the scale factor tensor to match the layout defined in +# https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout +@cute.jit +def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + sf_ref_ptr: cute.Pointer, + sf_mma_ptr: cute.Pointer, + mn: int, + sf_k: int, + l: int, + mma_shape: tuple, +): + mma_permute_order = (3, 4, 1, 5, 2, 0) + permuted_shape = tuple(mma_shape[i] for i in mma_permute_order) + cute_layout = cute.make_ordered_layout(permuted_shape, order=(2, 1, 4, 0, 3, 5)) + + sf_ref_tensor = cute.make_tensor( + sf_ref_ptr, cute.make_layout((mn, sf_k, l), stride=(sf_k, 1, mn * sf_k)) + ) + sf_mma_tensor = cute.make_tensor(sf_mma_ptr, cute_layout) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) + for i in cutlass.range(cute.size(sf_ref_tensor)): + mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) + sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] + + +# The CuTe reference implementation for NVFP4 block-scaled GEMV +@cute.kernel +def kernel( + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + mC_mnl: cute.Tensor, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + num_tma_load_bytes: cutlass.Constexpr[int], +): + """ + GPU device kernel performing the batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + tidx = cute.arch.thread_idx() + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + + # Coords outside cluster + cta_coord = (bidx, bidy, bidz) + mma_tile_coord_mnl = ( + cta_coord[0] // cute.size(tiled_mma.thr_id.shape), + cta_coord[1], + cta_coord[2], + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Define shared storage for kernel + # + @cute.struct + class SharedStorage: + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_ab_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_acc_stage * 2] + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + # (MMA, MMA_M, MMA_K, STAGE) + sA = smem.allocate_tensor( + element_type=ab_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = smem.allocate_tensor( + element_type=ab_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, + ) + # (MMA, MMA_M, MMA_K, STAGE) + sSFA = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfa_smem_layout_staged, + byte_alignment=128, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfb_smem_layout_staged, + byte_alignment=128, + ) + + # + # Initialize mainloop ab_pipeline, acc_pipeline and their states + # + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=num_tma_load_bytes, + ).make_participants() + acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=num_acc_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, + threads_per_cta, + ), + ).make_participants() + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + gSFB_nkl = cute.local_tile( + mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/SFA/SFB/C + # + # (MMA, MMA_M, MMA_K, RestK) + thr_mma = tiled_mma.get_slice(0) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB = thr_mma.partition_B(gSFB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B/SFA/SFB + # + # TMA load A partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + 0, + cute.make_layout(1), + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + 0, + cute.make_layout(1), + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # TMALDG_SFA partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cpasync.tma_partition( + tma_atom_sfa, + 0, + cute.make_layout(1), + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + + # TMALDG_SFB partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cpasync.tma_partition( + tma_atom_sfb, + 0, + cute.make_layout(1), + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) + # (MMA, MMA_M, MMA_N) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + + # + # Alloc tensor memory buffer + # + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=threads_per_cta, + ) + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + ) + tmem.allocate(num_tmem_alloc_cols) + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(cutlass.Float32) + tCtAcc = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # + # Make SFA/SFB tmem tensor + # + # Get SFA tmem ptr + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc), + dtype=sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + # Get SFB tmem ptr + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + + # + # Partition for S2T copy of SFA/SFB + # + # Make S2T CopyAtom + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), + sf_dtype, + ) + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFA_compact = cute.filter_zeros(sSFA) + # (MMA, MMA_MN, MMA_K) + tCtSFA_compact = cute.filter_zeros(tCtSFA) + tiled_copy_s2t_sfa = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFA_compact) + thr_copy_s2t_sfa = tiled_copy_s2t_sfa.get_slice(0) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFA_compact_s2t_ = thr_copy_s2t_sfa.partition_S(tCsSFA_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFA_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfa, tCsSFA_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFA_compact_s2t = thr_copy_s2t_sfa.partition_D(tCtSFA_compact) + + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFB_compact = cute.filter_zeros(sSFB) + # (MMA, MMA_MN, MMA_K) + tCtSFB_compact = cute.filter_zeros(tCtSFB) + tiled_copy_s2t_sfb = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFB_compact) + thr_copy_s2t_sfb = tiled_copy_s2t_sfb.get_slice(0) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfb, tCsSFB_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFB_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB_compact) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tAgSFA = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgSFB = tBgSFB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + + # + # Execute Data copy and Math computation in the k_tile loop + # + if warp_idx == 0: + # Wait for accumulator buffer empty + acc_empty = acc_producer.acquire_and_advance() + # Set ACCUMULATE field to False for the first k_tile iteration + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + # Execute k_tile loop + for k_tile in range(k_tile_cnt): + # Wait for AB buffer empty + ab_empty = ab_producer.acquire_and_advance() + + # TMALDG A/B/SFA/SFB + cute.copy( + tma_atom_a, + tAgA[(None, k_tile)], + tAsA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_b, + tBgB[(None, k_tile)], + tBsB[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_sfa, + tAgSFA[(None, k_tile)], + tAsSFA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_sfb, + tBgSFB[(None, k_tile)], + tBsSFB[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + + # Wait for AB buffer full + ab_full = ab_consumer.wait_and_advance() + + # Copy SFA/SFB to tmem + s2t_stage_coord = (None, None, None, None, ab_full.index) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) + + # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_full.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kblock_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB[sf_kblock_coord].iterator, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_full.release() + acc_empty.commit() + + # + # Epilogue + # Partition for epilogue + # + op = tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE) + copy_atom_t2r = cute.make_copy_atom(op, cutlass.Float32) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc) + # (T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(tCgC) + # (T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32 + ) + # (T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rC = cute.make_fragment( + tTR_gC[None, None, None, None, 0, 0, 0].shape, c_dtype + ) + # STG Atom + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), c_dtype) + tTR_gC = tTR_gC[(None, None, None, None, *mma_tile_coord_mnl)] + + # Release TMEM allocation lock + tmem.relinquish_alloc_permit() + + # Wait for accumulator buffer full + acc_full = acc_consumer.wait_and_advance() + + # Copy accumulator to register + cute.copy(tiled_copy_t2r, tTR_tAcc, tTR_rAcc) + acc_vec = tTR_rAcc.load().to(c_dtype) + tTR_rC.store(acc_vec) + # Store C to global memory + cute.copy(simt_atom, tTR_rC, tTR_gC) + + acc_full.release() + + # Deallocate TMEM + cute.arch.barrier() + tmem.free(acc_tmem_ptr) + + return + + +@cute.jit +def my_kernel( + a_ptr: cute.Pointer, + b_ptr: cute.Pointer, + sfa_ptr: cute.Pointer, + sfb_ptr: cute.Pointer, + c_ptr: cute.Pointer, + problem_size: tuple, +): + """ + Host-side JIT function to prepare tensors and launch GPU kernel. + """ + m, n, k, l = problem_size + + # Setup attributes that depend on gemm inputs + cta_tile_shape_mnk = ( + mma_tiler_mnk[0], + mma_tiler_mnk[1], + mma_tiler_mnk[2], + ) + a_tensor = cute.make_tensor( + a_ptr, + cute.make_layout( + (m, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(m * k, 32)), + ), + ) + b_tensor = cute.make_tensor( + b_ptr, + cute.make_layout( + (n, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(n * k, 32)), + ), + ) + c_tensor = cute.make_tensor( + c_ptr, cute.make_layout((cute.assume(m, 32), n, l), stride=(n, 1, m * n)) + ) + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + a_tensor.shape, sf_vec_size + ) + sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + b_tensor.shape, sf_vec_size + ) + sfb_tensor = cute.make_tensor(sfb_ptr, sfb_layout) + + mma_op = tcgen05.MmaMXF4NVF4Op( + sf_dtype, + (mma_tiler_mnk[0], mma_tiler_mnk[1], mma_inst_shape_k), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + tiled_mma = cute.make_tiled_mma(mma_op) + + cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((1, 1, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute A/B/SFA/SFB/C shared memory layout + a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + ab_dtype, + num_ab_stage, + ) + b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + ab_dtype, + num_ab_stage, + ) + sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + num_ab_stage, + ) + sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + num_ab_stage, + ) + + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # TMA load for A + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + a_tensor, + a_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + ) + # TMA load for B + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + b_tensor, + b_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + ) + + # TMA load for SFA + sfa_smem_layout = cute.slice_( + sfa_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + sfa_tensor, + sfa_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # TMA load for SFB + sfb_smem_layout = cute.slice_( + sfb_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + sfb_tensor, + sfb_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # Compute TMA load bytes + a_copy_size = cute.size_in_bytes(ab_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(ab_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(sf_dtype, sfb_smem_layout) + num_tma_load_bytes = ( + a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size + ) * atom_thr_size + + # Compute grid size + grid = ( + cute.ceil_div(c_tensor.shape[0], cta_tile_shape_mnk[0]), + cute.ceil_div(c_tensor.shape[1], cta_tile_shape_mnk[1]), + c_tensor.shape[2], + ) + + # Launch the kernel synchronously + kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + c_tensor, + a_smem_layout_staged, + b_smem_layout_staged, + sfa_smem_layout_staged, + sfb_smem_layout_staged, + num_tma_load_bytes, + ).launch( + grid=grid, + block=[threads_per_cta, 1, 1], + cluster=(1, 1, 1), + ) + return + + +# Reorder scale factor from (mn, l, sf_k) to (32, 4, rest_m, 4, rest_k, l) layout +def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor): + sf_k = ceil_div(k, sf_vec_size) + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, # batch size + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + # Create the reordered scale factor tensor (32, 4, rest_m, 4, rest_k, l) on CPU. + mma_permute_order = (3, 4, 1, 5, 2, 0) + # Generate a random int8 tensor, then convert to float8_e4m3fn + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) + reordered_f8_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) + # Permute according to mma_permute_order + reordered_f8_tensor = reordered_f8_tensor.permute(*mma_permute_order) + + # Helper function to convert scale factor tensor to CUTE-format scale factor tensor + cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + make_ptr( + cutlass.Float8E4M3FN, + ref_f8_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ), + make_ptr( + cutlass.Float8E4M3FN, + reordered_f8_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ), + mn, + sf_k, + l, + mma_shape, + ) + return reordered_f8_tensor.cuda() + + +# Global cache for compiled kernel +_compiled_kernel_cache = None + +def compile_kernel(): + """ + Compile the kernel once and cache it. + This should be called before any timing measurements. + + Args: + a, b, scale_a, scale_b, c: Sample tensors with the expected shapes and types + + Returns: + The compiled kernel function + """ + global _compiled_kernel_cache + + if _compiled_kernel_cache is not None: + return _compiled_kernel_cache + + + # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer + a_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + b_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + c_ptr = make_ptr( + c_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + sfa_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 + ) + sfb_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 + ) + + # Compile the kernel + _compiled_kernel_cache = cute.compile(my_kernel, a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (0, 0, 0, 0)) + + return _compiled_kernel_cache + + +def custom_kernel(data: input_t) -> output_t: + """ + Execute the block-scaled GEMM kernel. + + This is the main entry point called by the evaluation framework. + It converts PyTorch tensors to CuTe tensors, launches the kernel, + and returns the result. + + Args: + data: Tuple of (a, b, sfa_cpu, sfb_cpu, c) PyTorch tensors + a: [m, k, l] - Input matrix in float4e2m1fn + b: [n, k, l] - Input vector in float4e2m1fn + sfa_cpu: [m, k, l] - Scale factors in float8_e4m3fn + sfb_cpu: [n, k, l] - Scale factors in float8_e4m3fn + c: [m, n, l] - Output vector in float16 + + Returns: + Output tensor c with computed GEMV results + """ + a, b, sfa_cpu, sfb_cpu, c = data + + # Ensure kernel is compiled (will use cached version if available) + compiled_func = compile_kernel() + # Get dimensions from MxKxL layout + m, k, l = a.shape + n, _, _ = b.shape + # Torch use e2m1_x2 data type, thus k is halved + k = k * 2 + + # Create the reordered scale factor tensors from the reference scale factor tensors via CuTe function. + sfa_reordered = create_reordered_scale_factor_tensor(l, m, k, sfa_cpu) + sfb_reordered = create_reordered_scale_factor_tensor(l, n, k, sfb_cpu) + + # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer + a_ptr = make_ptr( + ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + b_ptr = make_ptr( + ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + c_ptr = make_ptr( + c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + sfa_ptr = make_ptr( + sf_dtype, sfa_reordered.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + sfb_ptr = make_ptr( + sf_dtype, sfb_reordered.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + + # Execute the compiled kernel + compiled_func(a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l)) + + return c diff --git a/problems/nvidia/nvfp4_gemm/task.py b/problems/nvidia/nvfp4_gemm/task.py new file mode 100644 index 0000000..4ebbe88 --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/task.py @@ -0,0 +1,11 @@ +import torch +from typing import TypedDict, TypeVar + +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) +output_t = TypeVar("output_t", bound=torch.Tensor) +class TestSpec(TypedDict): + m: int + n: int + k: int + l: int + seed: int \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemm/task.yml b/problems/nvidia/nvfp4_gemm/task.yml new file mode 100644 index 0000000..5aa355e --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/task.yml @@ -0,0 +1,60 @@ +# name: nvfp4-ffma-gemm + +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + + You will implement a batched matrix-matrix multiplication kernel optimized for NVIDIA B200. + To be explicit, you will be given a tuple of tensors: + ``` + (a, b, sfa, sfb, c) + ``` + where: + * `a` is M x K x L in K-major order in nvfp4(e2m1) + * `b` is N x K x L in K-major order in nvfp4(e2m1) + * `sfa` is M x (K // 16) x L in K-major order in fp8(e4m3fnuz) + * `sfb` is N x (K // 16) x L in K-major order in fp8(e4m3fnuz) + * `c` is M x N x L in fp16 + + Matrix sizes `M` is divisible by mma_tiler_mn[0], `N` is divisible by mma_tiler_mn[1], `K` is divisible by 256. + The ranking criteria is the geometric mean of the benchmark results. + For the grand price, your kernel will be evaluated against the speed of light analysis + and the solution closest to the speed of light will be awarded the grand price. + ``` + The speed of light analysis is (using 1.5Ghz clock): + M N K L time[us] + 7168 128 16384 1 8.71 + 4096 128 7168 1 2.18 + 7168 128 2048 1 1.09 + ``` +config: + main: "eval.py" + +templates: + Python: "template.py" + +tests: + - {"m": 128, "n": 256, "k": 256, "l": 1, "seed": 1111} + - {"m": 128, "n": 1536, "k": 7168, "l": 1, "seed": 1111} + - {"m": 128, "n": 3072, "k": 1536, "l": 1, "seed": 1111} + - {"m": 256, "n": 7168, "k": 256, "l": 1, "seed": 1111} + - {"m": 256, "n": 7168, "k": 2048, "l": 1, "seed": 1111} + - {"m": 2304, "n": 4608, "k": 7168, "l": 1, "seed": 1111} + - {"m": 384, "n": 7168, "k": 2304, "l": 1, "seed": 1111} + - {"m": 512, "n": 512, "k": 7168, "l": 1, "seed": 1111} + - {"m": 512, "n": 4096, "k": 512, "l": 1, "seed": 1111} + - {"m": 512, "n": 1536, "k": 7168, "l": 1, "seed": 1111} + +benchmarks: + - {"m": 7168, "n": 128, "k": 16384, "l": 1, "seed": 1111} + - {"m": 4096, "n": 128, "k": 7168, "l": 1, "seed": 1111} + - {"m": 7168, "n": 128, "k": 2048, "l": 1, "seed": 1111} + +ranking_by: "geom" \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemm/template.py b/problems/nvidia/nvfp4_gemm/template.py new file mode 100644 index 0000000..aa337e9 --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/template.py @@ -0,0 +1,23 @@ +from task import input_t, output_t + + +def custom_kernel(data: input_t) -> output_t: + """ + Reference implementation of block-scale fp8 gemm + Args: + data: Tuple that expands to: + a: torch.Tensor[float4e2m1fn] of shape [m, k, l], + b: torch.Tensor[float4e2m1fn] of shape [n, k, l], + sfa: torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l], + sfb: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], + c: torch.Tensor[float16] of shape [m, n, l] + Returns: + Tensor containing output in float16 + c: torch.Tensor[float16] of shape [m, n, l] + """ + # c: [m, n, l] is pre-allocated memory to avoid timing allocation overhead. + a, b, sfa, sfb, c = data + + # Your implementation here + + return c \ No newline at end of file From 8410579c202decb14b9ce2607df1d020c88275cb Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Wed, 15 Oct 2025 17:09:57 -0700 Subject: [PATCH 10/29] fix typo in comments. --- problems/nvidia/nvfp4_gemm/reference.py | 20 ++++++++---------- problems/nvidia/nvfp4_gemm/submission.py | 27 ++++++++++++++++-------- problems/nvidia/nvfp4_gemm/task.yml | 4 ++-- problems/nvidia/nvfp4_gemm/template.py | 2 +- problems/nvidia/nvfp4_gemv/reference.py | 6 ++---- 5 files changed, 32 insertions(+), 27 deletions(-) diff --git a/problems/nvidia/nvfp4_gemm/reference.py b/problems/nvidia/nvfp4_gemm/reference.py index 18ffb7c..12db1a6 100644 --- a/problems/nvidia/nvfp4_gemm/reference.py +++ b/problems/nvidia/nvfp4_gemm/reference.py @@ -31,10 +31,10 @@ def ref_kernel( """ a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, c_ref = data - # Get dimensions from MxKxL layout - _, _, l = a_ref.shape + # Get dimensions from MxNxL layout + _, _, l = c_ref.shape - # Call torch._scaled_mm to compute the GEMV result + # Call torch._scaled_mm to compute the GEMM result for l_idx in range(l): # Convert the scale factor tensor to blocked format scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx]) @@ -60,23 +60,22 @@ def generate_input( seed: int, ): """ - Generate input tensors for NVFP4 block-scaled GEMV. - - This follows the pattern from nvfp4_gemv_cute_layout.py for tensor preparation. + Generate input tensors for NVFP4 block-scaled GEMM. Args: m: Number of rows in matrix A - k: Number of columns in A (and length of vector b) + n: Number of columns in matrix B + k: Number of columns in A and rows of B l: Batch size seed: Random seed for reproducibility Returns: Tuple of (a, b, scale_a, scale_b, c) where: a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type - b: [1, k, l] - Input vector in torch.float4e2m1fn_x2 data type + b: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type - scale_b: [1, k, l] - Input scale factors in torch.float8e4m3fn data type - c: [m, 1, l] - Output vector in torch.float16 data type + scale_b: [n, k, l] - Input scale factors in torch.float8e4m3fn data type + c: [m, n, l] - Output matrix in torch.float16 data type """ torch.manual_seed(seed) @@ -84,7 +83,6 @@ def generate_input( a_ref = torch.randint( 0, 2, (l, m, k // 2), dtype=torch.uint8, device="cuda" ).permute(1, 2, 0) - # Pad b tensor's N dimension to 128 to call torch._scaled_mm for nvfp4 dot product computation b_ref = torch.randint( 0, 2, (l, n, k // 2), dtype=torch.uint8, device="cuda" ).permute(1, 2, 0) diff --git a/problems/nvidia/nvfp4_gemm/submission.py b/problems/nvidia/nvfp4_gemm/submission.py index 23c61e3..b40e206 100644 --- a/problems/nvidia/nvfp4_gemm/submission.py +++ b/problems/nvidia/nvfp4_gemm/submission.py @@ -15,18 +15,27 @@ from cutlass.cute.runtime import make_ptr # Kernel configuration parameters -mma_tiler_mnk = (128, 128, 256) # Tile sizes for M, N, K dimensions +# Tile sizes for M, N, K dimensions +mma_tiler_mnk = (128, 128, 256) +# Shape of the K dimension for the MMA instruction mma_inst_shape_k = 64 -ab_dtype = cutlass.Float4E2M1FN # FP4 data type for A and B -sf_dtype = cutlass.Float8E4M3FN # FP8 data type for scale factors -c_dtype = cutlass.Float16 # FP16 output type -sf_vec_size = 16 # Scale factor block size (16 elements share one scale) -threads_per_cta = 128 # Number of threads per CUDA thread block -# stage numbers of shared memory and tmem +# FP4 data type for A and B +ab_dtype = cutlass.Float4E2M1FN +# FP8 data type for scale factors +sf_dtype = cutlass.Float8E4M3FN +# FP16 output type +c_dtype = cutlass.Float16 +# Scale factor block size (16 elements share one scale) +sf_vec_size = 16 +# Number of threads per CUDA thread block +threads_per_cta = 128 +# Stage numbers of shared memory and tmem num_acc_stage = 1 num_ab_stage = 1 +# Total number of columns in tmem num_tmem_alloc_cols = 512 + # Helper function for ceiling division def ceil_div(a, b): return (a + b - 1) // b @@ -58,7 +67,7 @@ def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] -# The CuTe reference implementation for NVFP4 block-scaled GEMV +# The CuTe reference implementation for NVFP4 block-scaled GEMM @cute.kernel def kernel( tiled_mma: cute.TiledMma, @@ -781,7 +790,7 @@ def custom_kernel(data: input_t) -> output_t: c: [m, n, l] - Output vector in float16 Returns: - Output tensor c with computed GEMV results + Output tensor c with computed results """ a, b, sfa_cpu, sfb_cpu, c = data diff --git a/problems/nvidia/nvfp4_gemm/task.yml b/problems/nvidia/nvfp4_gemm/task.yml index 5aa355e..a35b7d8 100644 --- a/problems/nvidia/nvfp4_gemm/task.yml +++ b/problems/nvidia/nvfp4_gemm/task.yml @@ -1,4 +1,4 @@ -# name: nvfp4-ffma-gemm +# name: nvfp4-block-scaled-gemm files: - {"name": "submission.py", "source": "@SUBMISSION@"} @@ -11,7 +11,7 @@ lang: "py" description: | - You will implement a batched matrix-matrix multiplication kernel optimized for NVIDIA B200. + You will implement a block scaled matrix-matrix multiplication kernel optimized for NVIDIA B200. To be explicit, you will be given a tuple of tensors: ``` (a, b, sfa, sfb, c) diff --git a/problems/nvidia/nvfp4_gemm/template.py b/problems/nvidia/nvfp4_gemm/template.py index aa337e9..17d6347 100644 --- a/problems/nvidia/nvfp4_gemm/template.py +++ b/problems/nvidia/nvfp4_gemm/template.py @@ -3,7 +3,7 @@ def custom_kernel(data: input_t) -> output_t: """ - Reference implementation of block-scale fp8 gemm + Reference implementation of block-scale fp4 gemm Args: data: Tuple that expands to: a: torch.Tensor[float4e2m1fn] of shape [m, k, l], diff --git a/problems/nvidia/nvfp4_gemv/reference.py b/problems/nvidia/nvfp4_gemv/reference.py index a6101a6..4b236af 100644 --- a/problems/nvidia/nvfp4_gemv/reference.py +++ b/problems/nvidia/nvfp4_gemv/reference.py @@ -31,8 +31,8 @@ def ref_kernel( """ a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, c_ref = data - # Get dimensions from MxKxL layout - _, _, l = a_ref.shape + # Get dimensions from MxNxL layout + _, _, l = c_ref.shape # Call torch._scaled_mm to compute the GEMV result for l_idx in range(l): @@ -61,8 +61,6 @@ def generate_input( """ Generate input tensors for NVFP4 block-scaled GEMV. - This follows the pattern from nvfp4_gemv_cute_layout.py for tensor preparation. - Args: m: Number of rows in matrix A k: Number of columns in A (and length of vector b) From 82e09121702a51169131e10b730a661abe870def Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Wed, 15 Oct 2025 17:10:35 -0700 Subject: [PATCH 11/29] add dual gemm example --- problems/nvidia/nvfp4_dual_gemm/reference.py | 148 +++ problems/nvidia/nvfp4_dual_gemm/submission.py | 1017 +++++++++++++++++ problems/nvidia/nvfp4_dual_gemm/task.py | 11 + problems/nvidia/nvfp4_dual_gemm/task.yml | 62 + problems/nvidia/nvfp4_dual_gemm/template.py | 25 + 5 files changed, 1263 insertions(+) create mode 100644 problems/nvidia/nvfp4_dual_gemm/reference.py create mode 100644 problems/nvidia/nvfp4_dual_gemm/submission.py create mode 100644 problems/nvidia/nvfp4_dual_gemm/task.py create mode 100644 problems/nvidia/nvfp4_dual_gemm/task.yml create mode 100644 problems/nvidia/nvfp4_dual_gemm/template.py diff --git a/problems/nvidia/nvfp4_dual_gemm/reference.py b/problems/nvidia/nvfp4_dual_gemm/reference.py new file mode 100644 index 0000000..ffa56ef --- /dev/null +++ b/problems/nvidia/nvfp4_dual_gemm/reference.py @@ -0,0 +1,148 @@ +import torch +from task import input_t, output_t +from utils import make_match_reference + +# Scaling factor vector size +sf_vec_size = 16 + +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + +# Helper function to convert scale factor tensor to blocked format +def to_blocked(input_matrix): + rows, cols = input_matrix.shape + + # Please ensure rows and cols are multiples of 128 and 4 respectively + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + padded = input_matrix + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + return rearranged.flatten() + +def ref_kernel( + data: input_t, +) -> output_t: + """ + PyTorch reference implementation of NVFP4 block-scaled GEMM. + """ + a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, c_ref = data + + # Get dimensions from MxNxL layout + m, n, l = c_ref.shape + + # Call torch._scaled_mm to compute the GEMV result + ref1 = torch.empty( + (l, m, n), + dtype=torch.float32, + device="cuda", + ).permute(1, 2, 0) + ref2 = torch.empty( + (l, m, n), + dtype=torch.float32, + device="cuda", + ).permute(1, 2, 0) + for l_idx in range(l): + # Convert the scale factor tensor to blocked format + scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx]) + scale_b1 = to_blocked(sfb1_ref_cpu[:, :, l_idx]) + scale_b2 = to_blocked(sfb2_ref_cpu[:, :, l_idx]) + # (m, k) @ (n, k).T -> (m, n) + res1 = torch._scaled_mm( + a_ref[:, :, l_idx], + b1_ref[:, :, l_idx].transpose(0, 1), + scale_a.cuda(), + scale_b1.cuda(), + bias=None, + out_dtype=torch.float32, + ) + ref1[:, :, l_idx] = res1 + + res2 = torch._scaled_mm( + a_ref[:, :, l_idx], + b2_ref[:, :, l_idx].transpose(0, 1), + scale_a.cuda(), + scale_b2.cuda(), + bias=None, + out_dtype=torch.float32, + ) + ref2[:, :, l_idx] = res2 + # Do silu on the first GEMM result and multiply with the second GEMM result + c_ref = (torch.nn.functional.silu(ref1) * ref2).to(torch.float16) + return c_ref + + +def generate_input( + m: int, + n: int, + k: int, + l: int, + seed: int, +): + """ + Generate input tensors for NVFP4 block-scaled dual GEMM with silu activation, + C = silu(A @ B1) * (A @ B2). + + Args: + m: Number of rows in matrix A + n: Number of columns in matrix B1 and B2 + k: Number of columns in A and rows of B1 and B2 + l: Batch size + seed: Random seed for reproducibility + + Returns: + Tuple of (a, b, scale_a, scale_b, c) where: + a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type + b1: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type + b2: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type + scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b1: [n, k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b2: [n, k, l] - Input scale factors in torch.float8e4m3fn data type + c: [m, n, l] - Output matrix in torch.float16 data type + """ + torch.manual_seed(seed) + + # Generate uint8 tensor, then convert to float4e2m1fn_x2 data type + a_ref = torch.randint( + 0, 2, (l, m, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + b1_ref = torch.randint( + 0, 2, (l, n, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + b2_ref = torch.randint( + 0, 2, (l, n, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + a_ref = a_ref.view(torch.float4_e2m1fn_x2) + b1_ref = b1_ref.view(torch.float4_e2m1fn_x2) + b2_ref = b2_ref.view(torch.float4_e2m1fn_x2) + + # Create float16 output tensor + c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute( + 1, 2, 0 + ) + + # Helper function to prepare the scale factor tensors + def create_scale_factor_tensors(l, mn, sf_k): + # Create the reference scale factor tensor (mn, l, sf_k) on CPU. + ref_shape = (l, mn, sf_k) + ref_permute_order = (1, 2, 0) + # Init with uint8 tensor, then convert to float8_e4m3fn + ref_f8_random_int = torch.randint(1, 3, ref_shape, dtype=torch.int8) + ref_f8_torch_tensor_cpu = ref_f8_random_int.to(dtype=torch.float8_e4m3fn) + # permute to match ref_permute_order + ref_f8_torch_tensor_cpu_permuted = ref_f8_torch_tensor_cpu.permute( + *ref_permute_order + ) + return ref_f8_torch_tensor_cpu_permuted + + sf_k = ceil_div(k, sf_vec_size) + sfa_ref_cpu = create_scale_factor_tensors(l, m, sf_k) + sfb1_ref_cpu = create_scale_factor_tensors(l, n, sf_k) + sfb2_ref_cpu = create_scale_factor_tensors(l, n, sf_k) + + return (a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, c_ref) + +check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) diff --git a/problems/nvidia/nvfp4_dual_gemm/submission.py b/problems/nvidia/nvfp4_dual_gemm/submission.py new file mode 100644 index 0000000..6ecddbc --- /dev/null +++ b/problems/nvidia/nvfp4_dual_gemm/submission.py @@ -0,0 +1,1017 @@ +from torch._higher_order_ops.torchbind import call_torchbind_fake +import cuda.bindings.driver as cuda + +import torch +from task import input_t, output_t + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.runtime import make_ptr + +# Kernel configuration parameters +# Tile sizes for M, N, K dimensions +mma_tiler_mnk= (128, 128, 256) +# Shape of the K dimension for the MMA instruction +mma_inst_shape_k = 64 +# FP4 data type for A and B +ab_dtype = cutlass.Float4E2M1FN +# FP8 data type for scale factors +sf_dtype = cutlass.Float8E4M3FN +# FP16 output type +c_dtype = cutlass.Float16 +# Scale factor block size (16 elements share one scale) +sf_vec_size = 16 +# Number of threads per CUDA thread block +threads_per_cta = 128 +# Stage numbers of shared memory and tmem +num_acc_stage = 1 +num_ab_stage = 1 +# Total number of columns in tmem +num_tmem_alloc_cols = 512 + + +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + + +# Helper function to reorder the scale factor tensor to match the layout defined in +# https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout +@cute.jit +def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + sf_ref_ptr: cute.Pointer, + sf_mma_ptr: cute.Pointer, + mn: int, + sf_k: int, + l: int, + mma_shape: tuple, +): + mma_permute_order = (3, 4, 1, 5, 2, 0) + permuted_shape = tuple(mma_shape[i] for i in mma_permute_order) + cute_layout = cute.make_ordered_layout(permuted_shape, order=(2, 1, 4, 0, 3, 5)) + + sf_ref_tensor = cute.make_tensor( + sf_ref_ptr, cute.make_layout((mn, sf_k, l), stride=(sf_k, 1, mn * sf_k)) + ) + sf_mma_tensor = cute.make_tensor(sf_mma_ptr, cute_layout) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) + for i in cutlass.range(cute.size(sf_ref_tensor)): + mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) + sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] + + +# GPU device kernel +@cute.kernel +def kernel( + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b1: cute.CopyAtom, + mB_nkl1: cute.Tensor, + tma_atom_b2: cute.CopyAtom, + mB_nkl2: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb1: cute.CopyAtom, + mSFB_nkl1: cute.Tensor, + tma_atom_sfb2: cute.CopyAtom, + mSFB_nkl2: cute.Tensor, + mC_mnl: cute.Tensor, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + num_tma_load_bytes: cutlass.Constexpr[int], + epilogue_op: cutlass.Constexpr = lambda x: x + * (1.0 / (1.0 + cute.math.exp(-x, fastmath=True))), +): + """ + GPU device kernel performing the batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + tidx = cute.arch.thread_idx() + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + + # Coords outside cluster + cta_coord = (bidx, bidy, bidz) + mma_tile_coord_mnl = ( + cta_coord[0] // cute.size(tiled_mma.thr_id.shape), + cta_coord[1], + cta_coord[2], + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Define shared storage for kernel + # + @cute.struct + class SharedStorage: + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_ab_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_acc_stage * 2] + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + # (MMA, MMA_M, MMA_K, STAGE) + sA = smem.allocate_tensor( + element_type=ab_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB1 = smem.allocate_tensor( + element_type=ab_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB2 = smem.allocate_tensor( + element_type=ab_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, + ) + # (MMA, MMA_M, MMA_K, STAGE) + sSFA = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfa_smem_layout_staged, + byte_alignment=128, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB1 = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfb_smem_layout_staged, + byte_alignment=128, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB2 = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfb_smem_layout_staged, + byte_alignment=128, + ) + + # + # Initialize mainloop ab_pipeline, acc_pipeline and their states + # + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=num_tma_load_bytes, + ).make_participants() + acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=num_acc_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, + threads_per_cta, + ), + ).make_participants() + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl1 = cute.local_tile( + mB_nkl1, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl2 = cute.local_tile( + mB_nkl2, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + gSFB_nkl1 = cute.local_tile( + mSFB_nkl1, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl2 = cute.local_tile( + mSFB_nkl2, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/SFA/SFB/C + # + # (MMA, MMA_M, MMA_K, RestK) + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB1 = thr_mma.partition_B(gB_nkl1) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB2 = thr_mma.partition_B(gB_nkl2) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB1 = thr_mma.partition_B(gSFB_nkl1) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB2 = thr_mma.partition_B(gSFB_nkl2) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B/SFA/SFB + # + # TMA load A partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + 0, + cute.make_layout(1), + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B1 partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB1, tBgB1 = cpasync.tma_partition( + tma_atom_b1, + 0, + cute.make_layout(1), + cute.group_modes(sB1, 0, 3), + cute.group_modes(tCgB1, 0, 3), + ) + # TMA load B2 partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB2, tBgB2 = cpasync.tma_partition( + tma_atom_b2, + 0, + cute.make_layout(1), + cute.group_modes(sB2, 0, 3), + cute.group_modes(tCgB2, 0, 3), + ) + + # TMALDG_SFA partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cpasync.tma_partition( + tma_atom_sfa, + 0, + cute.make_layout(1), + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + + # TMALDG SFB1 partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB1, tBgSFB1 = cpasync.tma_partition( + tma_atom_sfb1, + 0, + cute.make_layout(1), + cute.group_modes(sSFB1, 0, 3), + cute.group_modes(tCgSFB1, 0, 3), + ) + tBsSFB1 = cute.filter_zeros(tBsSFB1) + tBgSFB1 = cute.filter_zeros(tBgSFB1) + # TMALDG SFB2 partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB2, tBgSFB2 = cpasync.tma_partition( + tma_atom_sfb2, + 0, + cute.make_layout(1), + cute.group_modes(sSFB2, 0, 3), + cute.group_modes(tCgSFB2, 0, 3), + ) + tBsSFB2 = cute.filter_zeros(tBsSFB2) + tBgSFB2 = cute.filter_zeros(tBgSFB2) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB1 = tiled_mma.make_fragment_B(sB1) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB2 = tiled_mma.make_fragment_B(sB2) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) + # (MMA, MMA_M, MMA_N) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + + # + # Alloc tensor memory buffer + # Make ACC1 and ACC2 tmem tensor + # ACC1 += A @ B1 + # ACC2 += A @ B2 + # + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=threads_per_cta, + ) + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + ) + tmem.allocate(num_tmem_alloc_cols) + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(cutlass.Float32) + tCtAcc1 = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + acc_tmem_ptr1 = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc1), + dtype=cutlass.Float32, + ) + tCtAcc2 = cute.make_tensor(acc_tmem_ptr1, tCtAcc_fake.layout) + + # + # Make SFA/SFB1/SFB2 tmem tensor + # + # SFA tmem layout: (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + # Get SFA tmem ptr + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc1) + + tcgen05.find_tmem_tensor_col_offset(tCtAcc2), + dtype=sf_dtype, + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # SFB1, SFB2 tmem layout: (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + # Get SFB1 tmem ptr + sfb_tmem_ptr1 = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc1) + + tcgen05.find_tmem_tensor_col_offset(tCtAcc2) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=sf_dtype, + ) + tCtSFB1 = cute.make_tensor(sfb_tmem_ptr1, tCtSFB_layout) + # Get SFB2 tmem ptr + sfb_tmem_ptr2 = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc1) + + tcgen05.find_tmem_tensor_col_offset(tCtAcc2) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + + tcgen05.find_tmem_tensor_col_offset(tCtSFB1), + dtype=sf_dtype, + ) + tCtSFB2 = cute.make_tensor(sfb_tmem_ptr2, tCtSFB_layout) + + # + # Partition for S2T copy of SFA/SFB1/SFB2 + # + # Make S2T CopyAtom + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), + sf_dtype, + ) + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFA_compact = cute.filter_zeros(sSFA) + # (MMA, MMA_MN, MMA_K) + tCtSFA_compact = cute.filter_zeros(tCtSFA) + tiled_copy_s2t_sfa = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFA_compact) + thr_copy_s2t_sfa = tiled_copy_s2t_sfa.get_slice(0) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFA_compact_s2t_ = thr_copy_s2t_sfa.partition_S(tCsSFA_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFA_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfa, tCsSFA_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFA_compact_s2t = thr_copy_s2t_sfa.partition_D(tCtSFA_compact) + + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFB1_compact = cute.filter_zeros(sSFB1) + # (MMA, MMA_MN, MMA_K) + tCtSFB1_compact = cute.filter_zeros(tCtSFB1) + tiled_copy_s2t_sfb = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFB1_compact) + thr_copy_s2t_sfb = tiled_copy_s2t_sfb.get_slice(0) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB1_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB1_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB1_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfb, tCsSFB1_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFB1_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB1_compact) + + # SFB2 S2T copy and partition + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFB2_compact = cute.filter_zeros(sSFB2) + # (MMA, MMA_MN, MMA_K) + tCtSFB2_compact = cute.filter_zeros(tCtSFB2) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB2_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB2_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB2_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfb, tCsSFB2_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFB2_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB2_compact) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB1 = tBgB1[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB2 = tBgB2[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tAgSFA = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgSFB1 = tBgSFB1[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgSFB2 = tBgSFB2[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + + # + # Execute Data copy and Math computation in the k_tile loop + # + if warp_idx == 0: + # Wait for accumulator buffer empty + acc_empty = acc_producer.acquire_and_advance() + # Set ACCUMULATE field to False for the first k_tile iteration + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + # Execute k_tile loop + for k_tile in range(k_tile_cnt): + # Wait for AB buffer empty + ab_empty = ab_producer.acquire_and_advance() + + # TMALDG A/B1/B2/SFA/SFB1/SFB2 + cute.copy( + tma_atom_a, + tAgA[(None, ab_empty.count)], + tAsA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_b1, + tBgB1[(None, ab_empty.count)], + tBsB1[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_b2, + tBgB2[(None, ab_empty.count)], + tBsB2[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_sfa, + tAgSFA[(None, ab_empty.count)], + tAsSFA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_sfb1, + tBgSFB1[(None, ab_empty.count)], + tBsSFB1[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_sfb2, + tBgSFB2[(None, ab_empty.count)], + tBsSFB2[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + + # Wait for AB buffer full + ab_full = ab_consumer.wait_and_advance() + + # Copy SFA/SFB1/SFB2 to tmem + s2t_stage_coord = (None, None, None, None, ab_full.index) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB1_compact_s2t_staged = tCsSFB1_compact_s2t[s2t_stage_coord] + tCsSFB2_compact_s2t_staged = tCsSFB2_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB1_compact_s2t_staged, + tCtSFB1_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB2_compact_s2t_staged, + tCtSFB2_compact_s2t, + ) + + # tCtAcc1 += tCrA * tCrSFA * tCrB1 * tCrSFB1 + # tCtAcc2 += tCrA * tCrSFA * tCrB2 * tCrSFB2 + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_full.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kblock_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB1[sf_kblock_coord].iterator, + ) + cute.gemm( + tiled_mma, + tCtAcc1, + tCrA[kblock_coord], + tCrB1[kblock_coord], + tCtAcc1, + ) + + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB2[sf_kblock_coord].iterator, + ) + cute.gemm( + tiled_mma, + tCtAcc2, + tCrA[kblock_coord], + tCrB2[kblock_coord], + tCtAcc2, + ) + + # Enable accumulate on tCtAcc1/tCtAcc2 after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_full.release() + acc_empty.commit() + + # + # Epilogue + # Partition for epilogue + # + op = tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE) + copy_atom_t2r = cute.make_copy_atom(op, cutlass.Float32) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc1) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc1 = thr_copy_t2r.partition_S(tCtAcc1) + # (T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc2 = thr_copy_t2r.partition_S(tCtAcc2) + # (T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(tCgC) + # (T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rAcc1 = cute.make_fragment( + tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32 + ) + # (T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rAcc2 = cute.make_fragment( + tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32 + ) + # (T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rC = cute.make_fragment( + tTR_gC[None, None, None, None, 0, 0, 0].shape, c_dtype + ) + # STG Atom + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), c_dtype) + tTR_gC = tTR_gC[(None, None, None, None, *mma_tile_coord_mnl)] + + # Release tensor memory allocation lock + if warp_idx == 0: + cute.arch.relinquish_tmem_alloc_permit() + # Wait for accumulator buffer full + acc_full = acc_consumer.wait_and_advance() + + # Copy accumulator to register + cute.copy(tiled_copy_t2r, tTR_tAcc1, tTR_rAcc1) + cute.copy(tiled_copy_t2r, tTR_tAcc2, tTR_rAcc2) + + # Silu activation on acc1 and multiply with acc2 + acc_vec1 = epilogue_op(tTR_rAcc1.load()) + acc_vec2 = tTR_rAcc2.load() + acc_vec = acc_vec1 * acc_vec2 + + tTR_rC.store(acc_vec.to(c_dtype)) + # Store C to global memory + cute.copy(simt_atom, tTR_rC, tTR_gC) + + acc_full.release() + # Deallocate TMEM + cute.arch.barrier() + tmem.free(acc_tmem_ptr) + return + + +@cute.jit +def my_kernel( + a_ptr: cute.Pointer, + b1_ptr: cute.Pointer, + b2_ptr: cute.Pointer, + sfa_ptr: cute.Pointer, + sfb1_ptr: cute.Pointer, + sfb2_ptr: cute.Pointer, + c_ptr: cute.Pointer, + problem_size: tuple, + epilogue_op: cutlass.Constexpr = lambda x: x + * (1.0 / (1.0 + cute.math.exp(-x, fastmath=True))), +): + """ + Host-side JIT function to prepare tensors and launch GPU kernel. + """ + m, n, k, l = problem_size + + # Setup attributes that depend on gemm inputs + cta_tile_shape_mnk = ( + mma_tiler_mnk[0], + mma_tiler_mnk[1], + mma_tiler_mnk[2], + ) + a_tensor = cute.make_tensor( + a_ptr, + cute.make_layout( + (m, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(m * k, 32)), + ), + ) + b_tensor1 = cute.make_tensor( + b1_ptr, + cute.make_layout( + (n, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(n * k, 32)), + ), + ) + b_tensor2 = cute.make_tensor( + b2_ptr, + cute.make_layout( + (n, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(n * k, 32)), + ), + ) + c_tensor = cute.make_tensor( + c_ptr, cute.make_layout((cute.assume(m, 32), n, l), stride=(n, 1, m * n)) + ) + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + a_tensor.shape, sf_vec_size + ) + sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + b_tensor1.shape, sf_vec_size + ) + sfb_tensor1 = cute.make_tensor(sfb1_ptr, sfb_layout) + sfb_tensor2 = cute.make_tensor(sfb2_ptr, sfb_layout) + + mma_op = tcgen05.MmaMXF4NVF4Op( + sf_dtype, + (mma_tiler_mnk[0], mma_tiler_mnk[1], mma_inst_shape_k), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + tiled_mma = cute.make_tiled_mma(mma_op) + + cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((1, 1, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute A/B/SFA/SFB/C shared memory layout + a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + ab_dtype, + num_ab_stage, + ) + # B1 and B2 have the same size thus share the same smem layout + b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + ab_dtype, + num_ab_stage, + ) + sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + num_ab_stage, + ) + # SFB1 and SFB2 have the same size thus share the same smem layout + sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + num_ab_stage, + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # TMA load for A + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + a_tensor, + a_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk .shape, + ) + # TMA load for B1 + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b1, tma_tensor_b1 = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + b_tensor1, + b_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk .shape, + ) + # TMA load for B2 + tma_atom_b2, tma_tensor_b2 = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + b_tensor2, + b_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk .shape, + ) + # TMA load for SFA + sfa_smem_layout = cute.slice_( + sfa_smem_layout_staged , (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + sfa_tensor, + sfa_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk .shape, + internal_type=cutlass.Int16, + ) + # TMA load for SFB1 + sfb_smem_layout = cute.slice_( + sfb_smem_layout_staged , (None, None, None, 0) + ) + tma_atom_sfb1, tma_tensor_sfb1 = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + sfb_tensor1, + sfb_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk .shape, + internal_type=cutlass.Int16, + ) + # TMA load for SFB2 + tma_atom_sfb2, tma_tensor_sfb2 = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + sfb_tensor2, + sfb_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk .shape, + internal_type=cutlass.Int16, + ) + + # Compute TMA load bytes + a_copy_size = cute.size_in_bytes(ab_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(ab_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(sf_dtype, sfb_smem_layout) + num_tma_load_bytes = ( + a_copy_size + b_copy_size * 2 + sfa_copy_size + sfb_copy_size * 2 + ) * atom_thr_size + + # Compute grid size + grid = ( + cute.ceil_div(c_tensor.shape[0], cta_tile_shape_mnk[0]), + cute.ceil_div(c_tensor.shape[1], cta_tile_shape_mnk[1]), + c_tensor.shape[2], + ) + + # Launch the kernel synchronously + kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b1, + tma_tensor_b1, + tma_atom_b2, + tma_tensor_b2, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb1, + tma_tensor_sfb1, + tma_atom_sfb2, + tma_tensor_sfb2, + c_tensor, + a_smem_layout_staged, + b_smem_layout_staged, + sfa_smem_layout_staged, + sfb_smem_layout_staged, + num_tma_load_bytes, + epilogue_op, + ).launch( + grid=grid, + block=[threads_per_cta, 1, 1], + cluster=(1, 1, 1), + ) + return + + +# Reorder scale factor from (mn, l, sf_k) to (32, 4, rest_m, 4, rest_k, l) layout +def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor): + sf_k = ceil_div(k, sf_vec_size) + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, # batch size + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + # Create the reordered scale factor tensor (32, 4, rest_m, 4, rest_k, l) on CPU. + mma_permute_order = (3, 4, 1, 5, 2, 0) + # Generate a random int8 tensor, then convert to float8_e4m3fn + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) + reordered_f8_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) + # Permute according to mma_permute_order + reordered_f8_tensor = reordered_f8_tensor.permute(*mma_permute_order) + + # Helper function to convert scale factor tensor to CUTE-format scale factor tensor + cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + make_ptr( + cutlass.Float8E4M3FN, + ref_f8_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ), + make_ptr( + cutlass.Float8E4M3FN, + reordered_f8_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ), + mn, + sf_k, + l, + mma_shape, + ) + return reordered_f8_tensor.cuda() + + +# Global cache for compiled kernel +_compiled_kernel_cache = None + +def compile_kernel(): + """ + Compile the kernel once and cache it. + This should be called before any timing measurements. + + Args: + a, b1, b2, sfa, sfb1, sfb2, c: Sample tensors with the expected shapes and types + + Returns: + The compiled kernel function + """ + global _compiled_kernel_cache + + if _compiled_kernel_cache is not None: + return _compiled_kernel_cache + + + # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer + a_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + b1_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + b2_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + c_ptr = make_ptr( + c_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + sfa_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 + ) + sfb1_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 + ) + sfb2_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 + ) + + # Compile the kernel + _compiled_kernel_cache = cute.compile(my_kernel, a_ptr, b1_ptr, b2_ptr, sfa_ptr, sfb1_ptr, sfb2_ptr, c_ptr, (0, 0, 0, 0)) + + return _compiled_kernel_cache + + +def custom_kernel(data: input_t) -> output_t: + """ + Execute the block-scaled dual GEMM kernel with silu activation, + C = silu(A @ B1) * (A @ B2). + + This is the main entry point called by the evaluation framework. + It converts PyTorch tensors to CuTe tensors, launches the kernel, + and returns the result. + + Args: + data: Tuple of (a, b1, b2, sfa_cpu, sfb1_cpu, sfb2_cpu, c) PyTorch tensors + a: [m, k, l] - Input matrix in float4e2m1fn + b1: [n, k, l] - Input matrix in float4e2m1fn + b2: [n, k, l] - Input matrix in float4e2m1fn + sfa_cpu: [m, k, l] - Scale factors in float8_e4m3fn + sfb1_cpu: [n, k, l] - Scale factors in float8_e4m3fn + sfb2_cpu: [n, k, l] - Scale factors in float8_e4m3fn + c: [m, n, l] - Output vector in float16 + + Returns: + Output tensor c with computed results + """ + a, b1, b2, sfa_cpu, sfb1_cpu, sfb2_cpu, c = data + + # Ensure kernel is compiled (will use cached version if available) + compiled_func = compile_kernel() + # Get dimensions from MxKxL layout + _, k, _ = a.shape + m, n, l = c.shape + # Torch use e2m1_x2 data type, thus k is halved + k = k * 2 + + # Create the reordered scale factor tensors from the reference scale factor tensors via CuTe function. + sfa_reordered = create_reordered_scale_factor_tensor(l, m, k, sfa_cpu) + sfb1_reordered = create_reordered_scale_factor_tensor(l, n, k, sfb1_cpu) + sfb2_reordered = create_reordered_scale_factor_tensor(l, n, k, sfb2_cpu) + + # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer + a_ptr = make_ptr( + ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + b1_ptr = make_ptr( + ab_dtype, b1.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + b2_ptr = make_ptr( + ab_dtype, b2.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + c_ptr = make_ptr( + c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + sfa_ptr = make_ptr( + sf_dtype, sfa_reordered.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + sfb1_ptr = make_ptr( + sf_dtype, sfb1_reordered.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + sfb2_ptr = make_ptr( + sf_dtype, sfb2_reordered.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + + # Execute the compiled kernel + compiled_func(a_ptr, b1_ptr, b2_ptr, sfa_ptr, sfb1_ptr, sfb2_ptr, c_ptr, (m, n, k, l)) + + return c diff --git a/problems/nvidia/nvfp4_dual_gemm/task.py b/problems/nvidia/nvfp4_dual_gemm/task.py new file mode 100644 index 0000000..66db735 --- /dev/null +++ b/problems/nvidia/nvfp4_dual_gemm/task.py @@ -0,0 +1,11 @@ +import torch +from typing import TypedDict, TypeVar + +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) +output_t = TypeVar("output_t", bound=torch.Tensor) +class TestSpec(TypedDict): + m: int + n: int + k: int + l: int + seed: int \ No newline at end of file diff --git a/problems/nvidia/nvfp4_dual_gemm/task.yml b/problems/nvidia/nvfp4_dual_gemm/task.yml new file mode 100644 index 0000000..6d59274 --- /dev/null +++ b/problems/nvidia/nvfp4_dual_gemm/task.yml @@ -0,0 +1,62 @@ +# name: nvfp4-dual-gemm + +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + + You will implement a block scaled dual matrix-matrix multiplication kernel with silu activation optimized for NVIDIA B200. + To be explicit, you will be given a tuple of tensors: + ``` + (a, b1, b2, sfa, sfb1, sfb2, c) + ``` + where: + * `a` is M x K x L in K-major order in nvfp4(e2m1) + * `b1` is N x K x L in K-major order in nvfp4(e2m1) + * `b2` is N x K x L in K-major order in nvfp4(e2m1) + * `sfa` is M x (K // 16) x L in K-major order in fp8(e4m3fnuz) + * `sfb1` is N x (K // 16) x L in K-major order in fp8(e4m3fnuz) + * `sfb2` is N x (K // 16) x L in K-major order in fp8(e4m3fnuz) + * `c` is M x N x L in fp16 + + Matrix sizes `M` is divisible by mma_tiler_mn[0], `N` is divisible by mma_tiler_mn[1], `K` is divisible by 256. + The ranking criteria is the geometric mean of the benchmark results. + For the grand price, your kernel will be evaluated against the speed of light analysis + and the solution closest to the speed of light will be awarded the grand price. + ``` + The speed of light analysis is (using 1.5Ghz clock): + M N K L time[us] + 7168 128 16384 1 8.71 + 4096 128 7168 1 2.18 + 7168 128 2048 1 1.09 + ``` +config: + main: "eval.py" + +templates: + Python: "template.py" + +tests: + - {"m": 128, "n": 256, "k": 256, "l": 1, "seed": 1111} + - {"m": 128, "n": 1536, "k": 7168, "l": 1, "seed": 1111} + - {"m": 128, "n": 3072, "k": 1536, "l": 1, "seed": 1111} + - {"m": 256, "n": 7168, "k": 256, "l": 1, "seed": 1111} + - {"m": 256, "n": 7168, "k": 2048, "l": 1, "seed": 1111} + - {"m": 2304, "n": 4608, "k": 7168, "l": 1, "seed": 1111} + - {"m": 384, "n": 7168, "k": 2304, "l": 1, "seed": 1111} + - {"m": 512, "n": 512, "k": 7168, "l": 1, "seed": 1111} + - {"m": 512, "n": 4096, "k": 512, "l": 1, "seed": 1111} + - {"m": 512, "n": 1536, "k": 7168, "l": 1, "seed": 1111} + +benchmarks: + - {"m": 7168, "n": 128, "k": 16384, "l": 1, "seed": 1111} + - {"m": 4096, "n": 128, "k": 7168, "l": 1, "seed": 1111} + - {"m": 7168, "n": 128, "k": 2048, "l": 1, "seed": 1111} + +ranking_by: "geom" \ No newline at end of file diff --git a/problems/nvidia/nvfp4_dual_gemm/template.py b/problems/nvidia/nvfp4_dual_gemm/template.py new file mode 100644 index 0000000..2509200 --- /dev/null +++ b/problems/nvidia/nvfp4_dual_gemm/template.py @@ -0,0 +1,25 @@ +from task import input_t, output_t + + +def custom_kernel(data: input_t) -> output_t: + """ + Reference implementation of block-scale fp4 dual gemm with silu activation + Args: + data: Tuple that expands to: + a: torch.Tensor[float4e2m1fn] of shape [m, k, l], + b1: torch.Tensor[float4e2m1fn] of shape [n, k, l], + b2: torch.Tensor[float4e2m1fn] of shape [n, k, l], + sfa: torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l], + sfb1: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], + sfb2: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], + c: torch.Tensor[float16] of shape [m, n, l] + Returns: + Tensor containing output in float16 + c: torch.Tensor[float16] of shape [m, n, l] + """ + # c: [m, n, l] is pre-allocated memory to avoid timing allocation overhead. + a, b, sfa, sfb, c = data + + # Your implementation here + + return c \ No newline at end of file From 349297242a3f5da3c962bc9ad6ee121cdc6471c2 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Thu, 16 Oct 2025 01:57:01 -0700 Subject: [PATCH 12/29] add group nvfp4 example --- problems/nvidia/nvfp4_dual_gemm/task.yml | 9 +- problems/nvidia/nvfp4_gemm/task.yml | 8 +- problems/nvidia/nvfp4_gemv/task.yml | 6 +- problems/nvidia/nvfp4_group_gemm/reference.py | 126 ++ .../nvidia/nvfp4_group_gemm/submission.py | 1104 +++++++++++++++++ problems/nvidia/nvfp4_group_gemm/task.py | 8 + problems/nvidia/nvfp4_group_gemm/task.yml | 65 + problems/nvidia/nvfp4_group_gemm/template.py | 28 + problems/nvidia/nvfp4_group_gemm/utils.py | 176 +++ 9 files changed, 1519 insertions(+), 11 deletions(-) create mode 100644 problems/nvidia/nvfp4_group_gemm/reference.py create mode 100644 problems/nvidia/nvfp4_group_gemm/submission.py create mode 100644 problems/nvidia/nvfp4_group_gemm/task.py create mode 100644 problems/nvidia/nvfp4_group_gemm/task.yml create mode 100644 problems/nvidia/nvfp4_group_gemm/template.py create mode 100644 problems/nvidia/nvfp4_group_gemm/utils.py diff --git a/problems/nvidia/nvfp4_dual_gemm/task.yml b/problems/nvidia/nvfp4_dual_gemm/task.yml index 6d59274..33ceff4 100644 --- a/problems/nvidia/nvfp4_dual_gemm/task.yml +++ b/problems/nvidia/nvfp4_dual_gemm/task.yml @@ -31,10 +31,11 @@ description: | and the solution closest to the speed of light will be awarded the grand price. ``` The speed of light analysis is (using 1.5Ghz clock): - M N K L time[us] - 7168 128 16384 1 8.71 - 4096 128 7168 1 2.18 - 7168 128 2048 1 1.09 + M N K L time[us] + 128 4096 7168 1 1.09 + 512 4096 7168 1 4.36 + 128 3072 4096 1 0.47 + 512 3072 7168 1 3.27 ``` config: main: "eval.py" diff --git a/problems/nvidia/nvfp4_gemm/task.yml b/problems/nvidia/nvfp4_gemm/task.yml index a35b7d8..94334f5 100644 --- a/problems/nvidia/nvfp4_gemm/task.yml +++ b/problems/nvidia/nvfp4_gemm/task.yml @@ -29,10 +29,10 @@ description: | and the solution closest to the speed of light will be awarded the grand price. ``` The speed of light analysis is (using 1.5Ghz clock): - M N K L time[us] - 7168 128 16384 1 8.71 - 4096 128 7168 1 2.18 - 7168 128 2048 1 1.09 + M N K L time[us] + 128 7168 16384 1 4.36 + 128 4096 7168 1 1.09 + 128 7168 2048 1 0.55 ``` config: main: "eval.py" diff --git a/problems/nvidia/nvfp4_gemv/task.yml b/problems/nvidia/nvfp4_gemv/task.yml index dc0e4e5..01f425b 100644 --- a/problems/nvidia/nvfp4_gemv/task.yml +++ b/problems/nvidia/nvfp4_gemv/task.yml @@ -30,9 +30,9 @@ description: | ``` The speed of light analysis is (using 1.5Ghz clock): M K L time[us] - 7168 16384 1 8.71 - 4096 7168 1 2.18 - 7168 2048 1 1.09 + 7168 16384 1 7.65 + 4096 7168 1 1.91 + 7168 2048 1 0.96 ``` config: main: "eval.py" diff --git a/problems/nvidia/nvfp4_group_gemm/reference.py b/problems/nvidia/nvfp4_group_gemm/reference.py new file mode 100644 index 0000000..4b9d8dd --- /dev/null +++ b/problems/nvidia/nvfp4_group_gemm/reference.py @@ -0,0 +1,126 @@ +import torch +from task import input_t, output_t +from utils import make_match_reference + +# Scaling factor vector size +sf_vec_size = 16 + +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + +# Helper function to convert scale factor tensor to blocked format +def to_blocked(input_matrix): + rows, cols = input_matrix.shape + + # Please ensure rows and cols are multiples of 128 and 4 respectively + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + padded = input_matrix + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + return rearranged.flatten() + +def ref_kernel( + data: input_t, +) -> output_t: + """ + PyTorch reference implementation of NVFP4 block-scaled group GEMM. + """ + abc_tensors, sfasfb_tensors, problem_sizes = data + + result_tensors = [] + for i, ( + (a_ref, b_ref, c_ref), + (sfa_ref, sfb_ref), + (m, n, k, l), + ) in enumerate( + zip( + abc_tensors, + sfasfb_tensors, + problem_sizes, + ) + ): + for l_idx in range(l): + # Convert the scale factor tensor to blocked format + scale_a = to_blocked(sfa_ref[:, :, l_idx]) + scale_b = to_blocked(sfb_ref[:, :, l_idx]) + # (m, k) @ (n, k).T -> (m, n) + res = torch._scaled_mm( + a_ref[:, :, l_idx].view(torch.float4_e2m1fn_x2), + b_ref[:, :, l_idx].transpose(0, 1).view(torch.float4_e2m1fn_x2), + scale_a.cuda(), + scale_b.cuda(), + bias=None, + out_dtype=torch.float16, + ) + c_ref[:, :, l_idx] = res + result_tensors.append((c_ref)) + return result_tensors + +def generate_input( + m: int, + n: int, + k: int, + g: int, + seed: int, +): + """ + Generate input tensors for NVFP4 block-scaled group GEMM. + Each group can have different m, n, k, l. + + Args: + problem_sizes: List of tuples (m, n, k, l) for each problem + m: Number of rows in matrix A + n: Number of columns in matrix B + k: Number of columns in A and rows of B + l: Batch size, always is 1 + groups: Number of groups + seed: Random seed for reproducibility + + Returns: + Tuple of (list(tuple(a, b, c)), list(tuple(sfa, sfb)), list(tuple(m, n, k, l))) where each group has its own a, b, c, sfa, sfb. + a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type + b: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type + scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b: [n, k, l] - Input scale factors in torch.float8e4m3fn data type + c: [m, n, l] - Output matrix in torch.float16 data type + """ + torch.manual_seed(seed) + + abc_tensors = [] + sfasfb_tensors = [] + problem_sizes = [] + l = 1 + # Generate a, b, c, sfa, sfb tensors for all groups + for group_idx in range(g): + a_ref = torch.randint( + 0, 2, (l, m, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + b_ref = torch.randint( + 0, 2, (l, n, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + a_ref = a_ref.view(torch.float4_e2m1fn_x2) + b_ref = b_ref.view(torch.float4_e2m1fn_x2) + + c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute( + 1, 2, 0 + ) + + sf_k = ceil_div(k, sf_vec_size) + sfa_ref_cpu = torch.randint( + 1, 3, (l, m, sf_k), dtype=torch.int8 + ).to(dtype=torch.float8_e4m3fn).permute(1, 2, 0) + sfb_ref_cpu = torch.randint( + 1, 3, (l, n, sf_k), dtype=torch.int8 + ).to(dtype=torch.float8_e4m3fn).permute(1, 2, 0) + + abc_tensors.append((a_ref, b_ref, c_ref)) + sfasfb_tensors.append((sfa_ref_cpu, sfb_ref_cpu)) + problem_sizes.append((m, n, k, l)) + + return (abc_tensors, sfasfb_tensors, problem_sizes) + +check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) diff --git a/problems/nvidia/nvfp4_group_gemm/submission.py b/problems/nvidia/nvfp4_group_gemm/submission.py new file mode 100644 index 0000000..c7affe3 --- /dev/null +++ b/problems/nvidia/nvfp4_group_gemm/submission.py @@ -0,0 +1,1104 @@ +from torch._higher_order_ops.torchbind import call_torchbind_fake +import cuda.bindings.driver as cuda +import functools +from typing import Tuple, List + +import torch +from task import input_t, output_t + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.runtime import make_ptr + +# Kernel configuration parameters + +# Size of tma descriptor in bytes +bytes_per_tensormap = 128 +# Number of tensormaps: a, b, sfa, sfb +num_tensormaps = 4 +# Tile sizes for M, N, K dimensions +mma_tiler_mnk = (128, 128, 256) +# Shape of the K dimension for the MMA instruction +mma_inst_shape_k = 64 +# FP4 data type for A and B +ab_dtype = cutlass.Float4E2M1FN +# FP8 data type for scale factors +sf_dtype = cutlass.Float8E4M3FN +# FP16 output type +c_dtype = cutlass.Float16 +# Scale factor block size (16 elements share one scale) +sf_vec_size = 16 +# Number of threads per CUDA thread block +threads_per_cta = 128 +# Stage numbers of shared memory and tmem +num_acc_stage = 1 +num_ab_stage = 1 +# Total number of columns in tmem +num_tmem_alloc_cols = 512 + + +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + + +# Helper function to reorder the scale factor tensor to match the layout defined in +# https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout +@cute.jit +def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + sf_ref_ptr: cute.Pointer, + sf_mma_ptr: cute.Pointer, + mn: int, + sf_k: int, + l: int, + mma_shape: tuple, +): + mma_permute_order = (3, 4, 1, 5, 2, 0) + permuted_shape = tuple(mma_shape[i] for i in mma_permute_order) + cute_layout = cute.make_ordered_layout(permuted_shape, order=(2, 1, 4, 0, 3, 5)) + + sf_ref_tensor = cute.make_tensor( + sf_ref_ptr, cute.make_layout((mn, sf_k, l), stride=(sf_k, 1, mn * sf_k)) + ) + sf_mma_tensor = cute.make_tensor(sf_mma_ptr, cute_layout) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) + for i in cutlass.range(cute.size(sf_ref_tensor)): + mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) + sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] + + +# The CuTe reference implementation for NVFP4 block-scaled GEMM +@cute.kernel +def kernel( + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tensor_of_abc_ptrs: cute.Tensor, + tensor_of_sfasfb_ptrs: cute.Tensor, + tensormaps: cute.Tensor, + tensor_of_problem_sizes: cute.Tensor, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + cta_mn_list: List[Tuple[int, int]], + num_tma_load_bytes: cutlass.Constexpr[int], +): + """ + GPU device kernel performing the Group GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + tidx, _, _ = cute.arch.thread_idx() + + # + # Delinearize bidz to coord_x, coord_y and group_idx for each CTA + # + bidx, bidy, bidz = cute.arch.block_idx() + group_idx = 0 + find = False + coord_x = 0 + coord_y = 0 + cta_rest = bidz + for _, (cta_m, cta_n) in enumerate(cta_mn_list): + if cta_rest >= (cta_m * cta_n): + group_idx += 1 + cta_rest -= cta_m * cta_n + else: + if not find: + coord_y = cta_rest // cta_m + coord_x = cta_rest % cta_m + cta_rest -= cta_m * cta_n + find = True + + # + # Construct C Tensor for each CTA + # + mC_mnl_iter = cute.make_ptr( + c_dtype, tensor_of_abc_ptrs[group_idx, 2], cute.AddressSpace.gmem + ).align(32) + m = tensor_of_problem_sizes[group_idx, 0] + n = tensor_of_problem_sizes[group_idx, 1] + k = tensor_of_problem_sizes[group_idx, 2] + l = tensor_of_problem_sizes[group_idx, 3] + + mC_mnl_layout = cute.make_layout( + ( + m, + n, + l, + ), + stride=( + cute.assume(n, 32), + 1, + m * n, + ), + ) + mC_mnl = cute.make_tensor(mC_mnl_iter, mC_mnl_layout) + # Local partition for global C Tensor + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, 0) + ) + + # + # Define shared storage for kernel + # + size_tensormap_in_i64 = ( + num_tensormaps * bytes_per_tensormap // 8 + ) + @cute.struct + class SharedStorage: + tensormap_buffer: cute.struct.MemRange[ + cutlass.Int64, size_tensormap_in_i64 + ] + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_ab_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_acc_stage * 2] + tmem_holding_buf: cutlass.Int32 + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + tensormap_smem_ptr = storage.tensormap_buffer.data_ptr() + tensormap_a_smem_ptr = tensormap_smem_ptr + tensormap_b_smem_ptr = ( + tensormap_a_smem_ptr + + bytes_per_tensormap // 8 + ) + tensormap_sfa_smem_ptr = ( + tensormap_b_smem_ptr + + bytes_per_tensormap // 8 + ) + tensormap_sfb_smem_ptr = ( + tensormap_sfa_smem_ptr + + bytes_per_tensormap // 8 + ) + # Setup smem tensor for A, B, SFA, SFB + # (MMA, MMA_M, MMA_K, STAGE) + sA = smem.allocate_tensor( + element_type=ab_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = smem.allocate_tensor( + element_type=ab_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, + ) + # (MMA, MMA_M, MMA_K, STAGE) + sSFA = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfa_smem_layout_staged, + byte_alignment=128, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfb_smem_layout_staged, + byte_alignment=128, + ) + + # Initialize mainloop ab_pipeline, acc_pipeline and their states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=num_tma_load_bytes, + ).make_participants() + acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, + threads_per_cta, + ), + ).make_participants() + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # (bM, bK, RestM, RestK, RestL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(0) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB = thr_mma.partition_B(gSFB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # Update tma descriptor with the correct shapes and strides + tensormap_manager = utils.TensorMapManager( + utils.TensorMapUpdateMode.SMEM, + 128, + ) + tensormap_a_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(bidz, 0, None)].iterator + ) + tensormap_b_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(bidz, 1, None)].iterator + ) + tensormap_sfa_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(bidz, 2, None)].iterator + ) + tensormap_sfb_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(bidz, 3, None)].iterator + ) + + mA_mkl_iter = cute.make_ptr( + ab_dtype, tensor_of_abc_ptrs[group_idx, 0], cute.AddressSpace.gmem + ).align(32) + mB_nkl_iter = cute.make_ptr( + ab_dtype, tensor_of_abc_ptrs[group_idx, 1], cute.AddressSpace.gmem + ).align(32) + sfa_mkl_iter = cute.make_ptr( + sf_dtype, tensor_of_sfasfb_ptrs[group_idx, 0], cute.AddressSpace.gmem + ).align(32) + sfb_nkl_iter = cute.make_ptr( + sf_dtype, tensor_of_sfasfb_ptrs[group_idx, 1], cute.AddressSpace.gmem + ).align(32) + mA_mkl_layout = cute.make_layout( + (m, k, l), + stride=( + cute.assume(k, 32), + 1, + cute.assume(m * k, 32), + ), + ) + mB_nkl_layout = cute.make_layout( + (n, k, l), + stride=( + cute.assume(k, 32), + 1, + cute.assume(n * k, 32), + ), + ) + # SFA, SFB follows specialized layout defined + # here: TODO add linke + atom_shape = ((32, 4), (sf_vec_size, 4)) + atom_stride = ((16, 4), (0, 1)) + sfa_layout = cute.tile_to_shape( + cute.make_layout(atom_shape, stride=atom_stride), + mA_mkl_layout.shape, + (2, 1, 3), + ) + sfb_layout = cute.tile_to_shape( + cute.make_layout(atom_shape, stride=atom_stride), + mB_nkl_layout.shape, + (2, 1, 3), + ) + real_tensor_a = cute.make_tensor(mA_mkl_iter, mA_mkl_layout) + real_tensor_b = cute.make_tensor(mB_nkl_iter, mB_nkl_layout) + real_tensor_sfa = cute.make_tensor(sfa_mkl_iter, sfa_layout) + real_tensor_sfb = cute.make_tensor(sfb_nkl_iter, sfb_layout) + + # Let warp 0 initialize tensormap + if warp_idx == 0: + tensormap_manager.init_tensormap_from_atom( + tma_atom_a, tensormap_a_smem_ptr, 0 + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_b, tensormap_b_smem_ptr, 0 + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_sfa, tensormap_sfa_smem_ptr, 0 + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_sfb, tensormap_sfb_smem_ptr, 0 + ) + tensormap_manager.update_tensormap( + ( + real_tensor_a, + real_tensor_b, + real_tensor_sfa, + real_tensor_sfb, + ), + (tma_atom_a, tma_atom_b, tma_atom_sfa, tma_atom_sfb), + ( + tensormap_a_gmem_ptr, + tensormap_b_gmem_ptr, + tensormap_sfa_gmem_ptr, + tensormap_sfb_gmem_ptr, + ), + 0, # tma warp id + ( + tensormap_a_smem_ptr, + tensormap_b_smem_ptr, + tensormap_sfa_smem_ptr, + tensormap_sfb_smem_ptr, + ), + ) + + tensormap_manager.fence_tensormap_update(tensormap_a_gmem_ptr) + tensormap_manager.fence_tensormap_update(tensormap_b_gmem_ptr) + tensormap_manager.fence_tensormap_update(tensormap_sfa_gmem_ptr) + tensormap_manager.fence_tensormap_update(tensormap_sfb_gmem_ptr) + + cute.arch.barrier() + + # + # Partition global/shared tensor for TMA load A/B/SFA/SFB + # + # TMA load A partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + 0, + cute.make_layout(1), + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + 0, + cute.make_layout(1), + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # TMALDG_SFA partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cpasync.tma_partition( + tma_atom_sfa, + 0, + cute.make_layout(1), + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + + # TMALDG_SFB partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cpasync.tma_partition( + tma_atom_sfb, + 0, + cute.make_layout(1), + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) + # (MMA, MMA_M, MMA_N) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + # + # Alloc tensor memory buffer + # + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=threads_per_cta, + ) + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + ) + tmem.allocate(num_tmem_alloc_cols) + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(cutlass.Float32) + tCtAcc = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # + # Make SFA/SFB tmem tensor + # + # Get SFA tmem ptr + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc), + dtype=sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + # Get SFB tmem ptr + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + + # + # Partition for S2T copy of SFA/SFB + # + # Make S2T CopyAtom + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), + sf_dtype, + ) + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFA_compact = cute.filter_zeros(sSFA) + tCtSFA_compact = cute.filter_zeros(tCtSFA) + tiled_copy_s2t_sfa = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFA_compact) + thr_copy_s2t_sfa = tiled_copy_s2t_sfa.get_slice(0) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFA_compact_s2t_ = thr_copy_s2t_sfa.partition_S(tCsSFA_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFA_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfa, tCsSFA_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFA_compact_s2t = thr_copy_s2t_sfa.partition_D(tCtSFA_compact) + + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFB_compact = cute.filter_zeros(sSFB) + # (MMA, MMA_MN, MMA_K) + tCtSFB_compact = cute.filter_zeros(tCtSFB) + tiled_copy_s2t_sfb = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFB_compact) + thr_copy_s2t_sfb = tiled_copy_s2t_sfb.get_slice(0) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfb, tCsSFB_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFB_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB_compact) + + # Number of K loops + k_tile_cnt = cute.ceil_div(real_tensor_a.shape[1], mma_tiler_mnk[2]) + + # + # Slice to per mma tile index + # + mma_tile_coord_mnl = (coord_x, coord_y, 0) + # ((atom_v, rest_v), RestK) + tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tAgSFA = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgSFB = tBgSFB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + + # + # Main loop + # + if warp_idx == 0: + # Wait for accumulator buffer empty + acc_empty = acc_producer.acquire_and_advance() + # Set ACCUMULATE field to False for the first k_tile iteration + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + # Execute k_tile loop + for k_tile in range(k_tile_cnt): + # Wait for AB buffer empty + ab_empty = ab_producer.acquire_and_advance() + + # TMALDG A/B/SFA/SFB + cute.copy( + tma_atom_a, + tAgA[(None, k_tile)], + tAsA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_a_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_b, + tBgB[(None, k_tile)], + tBsB[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_b_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_sfa, + tAgSFA[(None, k_tile)], + tAsSFA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_sfa_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_sfb, + tBgSFB[(None, k_tile)], + tBsSFB[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_sfb_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + + # Wait for AB buffer full + ab_full = ab_consumer.wait_and_advance() + + # Copy SFA/SFB to tmem + s2t_stage_coord = (None, None, None, None, ab_full.index) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) + + # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_full.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kblock_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB[sf_kblock_coord].iterator, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_full.release() + acc_empty.commit() + + # + # Epilogue + # Partition for epilogue + # + op = tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE) + copy_atom_t2r = cute.make_copy_atom(op, cutlass.Float32) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc) + # (T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(tCgC) + # (T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[None, None, None, None, 0, 0].shape, cutlass.Float32 + ) + # (T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rC = cute.make_fragment(tTR_gC[None, None, None, None, 0, 0].shape, c_dtype) + # STG Atom + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), c_dtype) + tTR_gC = tTR_gC[(None, None, None, None, coord_x, coord_y)] + + # Release TMEM allocation lock + tmem.relinquish_alloc_permit() + + # Wait for accumulator buffer full + acc_full = acc_consumer.wait_and_advance() + + # Copy accumulator to register + cute.copy(tiled_copy_t2r, tTR_tAcc, tTR_rAcc) + acc_vec = tTR_rAcc.load() + tTR_rC.store(acc_vec.to(c_dtype)) + # Store C to global memory + cute.copy(simt_atom, tTR_rC, tTR_gC) + + acc_full.release() + + # Deallocate TMEM + cute.arch.barrier() + tmem.free(acc_tmem_ptr) + pass + + +# Host-side JIT function to prepare tensors and launch GPU kernel. +@cute.jit +def my_kernel( + initial_abc_ptrs: Tuple[cute.Pointer, cute.Pointer, cute.Pointer], + initial_sfasfb_ptrs: Tuple[cute.Pointer, cute.Pointer], + initial_idx: Tuple[int, int, int], + ptr_of_tensor_of_problem_sizes: cute.Pointer, + ptr_of_tensor_of_abc_ptrs: cute.Pointer, + ptr_of_tensor_of_sfasfb_ptrs: cute.Pointer, + total_num_clusters: cutlass.Constexpr[int], + problem_sizes: List[ + Tuple[int, int, int, int] + ], # Problem sizes for each group + tensor_of_tensormap, + num_groups: cutlass.Constexpr[int], +): + + tensor_of_abc_ptrs = cute.make_tensor( + ptr_of_tensor_of_abc_ptrs, cute.make_layout((num_groups, 3), stride=(3, 1)) + ) + tensor_of_sfasfb_ptrs = cute.make_tensor( + ptr_of_tensor_of_sfasfb_ptrs, cute.make_layout((num_groups, 2), stride=(2, 1)) + ) + tensor_of_problem_sizes = cute.make_tensor( + ptr_of_tensor_of_problem_sizes, cute.make_layout((num_groups, 4), stride=(4, 1)) + ) + + a_ptr, b_ptr, _ = initial_abc_ptrs + sfa_ptr, sfb_ptr = initial_sfasfb_ptrs + min_a_idx, min_b_idx, _ = initial_idx + min_a_shape = problem_sizes[0] + min_b_shape = problem_sizes[0] + for group_idx, shape in enumerate(problem_sizes): + if group_idx == min_a_idx: + min_a_shape = shape + if group_idx == min_b_idx: + min_b_shape = shape + + initial_a = cute.make_tensor( + a_ptr, + cute.make_layout( + (min_a_shape[0], cute.assume(min_a_shape[2], 32), min_a_shape[3]), + stride=( + cute.assume(min_a_shape[2], 32), + 1, + cute.assume(min_a_shape[0] * min_a_shape[2], 32), + ), + ), + ) + min_b_shape = problem_sizes[0] + initial_b = cute.make_tensor( + b_ptr, + cute.make_layout( + (min_b_shape[1], cute.assume(min_b_shape[2], 32), min_b_shape[3]), + stride=( + cute.assume(min_b_shape[2], 32), + 1, + cute.assume(min_b_shape[1] * min_b_shape[2], 32), + ), + ), + ) + + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + initial_a.shape, sf_vec_size + ) + initial_sfa = cute.make_tensor(sfa_ptr, sfa_layout) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + initial_b.shape, sf_vec_size + ) + initial_sfb = cute.make_tensor(sfb_ptr, sfb_layout) + + # Select MMA operation + mma_op = tcgen05.MmaMXF4NVF4Op( + sf_dtype, + (mma_tiler_mnk[0], mma_tiler_mnk[1], mma_inst_shape_k), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + tiled_mma = cute.make_tiled_mma(mma_op) + + cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((1, 1, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute A/B/SFA/SFB/C shared memory layout + a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + ab_dtype, + num_ab_stage, + ) + b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + ab_dtype, + num_ab_stage, + ) + sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + num_ab_stage, + ) + sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + num_ab_stage, + ) + + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # TMA load for A + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + initial_a, + a_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + ) + # TMA load for B + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + initial_b, + b_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + ) + + # TMA load for SFA + sfa_smem_layout = cute.slice_( + sfa_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + initial_sfa, + sfa_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # TMA load for SFB + sfb_smem_layout = cute.slice_( + sfb_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + initial_sfb, + sfb_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + # Compute TMA load bytes + a_copy_size = cute.size_in_bytes(ab_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(ab_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(sf_dtype, sfb_smem_layout) + num_tma_load_bytes = ( + a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size + ) * atom_thr_size + + # Store CTA shape information for each Group in a List + cta_mn_list = [] + for group_idx, (m, n, k, l) in enumerate(problem_sizes): + x, y = cute.ceil_div(problem_sizes[group_idx][:2], mma_tiler_mnk[0:2]) + cta_mn_list.append((x, y)) + + # Compute grid size + grid = (1, 1, total_num_clusters) + + # Launch the kernel synchronously + kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tensor_of_abc_ptrs, + tensor_of_sfasfb_ptrs, + tensor_of_tensormap, + tensor_of_problem_sizes, + a_smem_layout_staged, + b_smem_layout_staged, + sfa_smem_layout_staged, + sfb_smem_layout_staged, + cta_mn_list, + num_tma_load_bytes, + ).launch( + grid=grid, + block=[threads_per_cta, 1, 1], + cluster=(1, 1, 1), + ) + return + + +# Reorder scale factor from (mn, l, sf_k) to (32, 4, rest_m, 4, rest_k, l) layout +def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor): + sf_k = ceil_div(k, sf_vec_size) + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, # batch size + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + # Create the reordered scale factor tensor (32, 4, rest_m, 4, rest_k, l) on CPU. + mma_permute_order = (3, 4, 1, 5, 2, 0) + # Generate a random int8 tensor, then convert to float8_e4m3fn + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) + reordered_f8_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) + # Permute according to mma_permute_order + reordered_f8_tensor = reordered_f8_tensor.permute(*mma_permute_order) + + # Helper function to convert scale factor tensor to CUTE-format scale factor tensor + cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + make_ptr( + cutlass.Float8E4M3FN, + ref_f8_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ), + make_ptr( + cutlass.Float8E4M3FN, + reordered_f8_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ), + mn, + sf_k, + l, + mma_shape, + ) + return reordered_f8_tensor.cuda() + +_compiled_kernel_cache = None + +def compile_kernel(): + pass + +def custom_kernel(data: input_t) -> output_t: + """ + Execute the block-scaled group GEMM kernel. + + This is the main entry point called by the evaluation framework. + It converts PyTorch tensors to CuTe tensors, launches the kernel, + and returns the result. + + Args: + data: Tuple of (abc_tensors, sfasfb_tensors, problem_sizes) where: + abc_tensors: list of tuples (a, b, c) where + a is torch.Tensor[float4e2m1fn_x2] of shape [m, k // 2, l] + b is torch.Tensor[float4e2m1fn_x2] of shape [n, k // 2, l] + c is torch.Tensor[float16] of shape [m, n, l] + sfasfb_tensors: list of tuples (sfa, sfb) where + sfa is torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l] + sfb is torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l] + problem_sizes: list of tuples (m, n, k, l) + each group has its own a, b, c, sfa, sfb with different m, n, k, l problem sizes + l should always be 1 for each group. + list size is the number of groups. + + Returns: + list of c tensors where c is torch.Tensor[float16] of shape [m, n, l] for each group + """ + abc_tensors, sfasfb_tensors, problem_sizes = data + + # Choose A, B, C, SFA, SFB with the smallest size to create initial tensormaps + key_size_a = lambda item: item[1][0] * item[1][2] + key_size_b = lambda item: item[1][1] * item[1][2] + key_size_c = lambda item: item[1][0] * item[1][1] + # Find the indices of the groups with the smallest tensor sizes + min_a_idx, _ = min(enumerate(problem_sizes), key=key_size_a) + min_b_idx, _ = min(enumerate(problem_sizes), key=key_size_b) + min_c_idx, _ = min(enumerate(problem_sizes), key=key_size_c) + + sfasfb_reordered_tensors = [] + abc_ptrs = [] + sfasfb_ptrs = [] + for i, ((a, b, c), (sfa_cpu, sfb_cpu), (m, n, k, l)) in enumerate(zip(abc_tensors, sfasfb_tensors, problem_sizes)): + sf_k = ceil_div(k, sf_vec_size) + sfa_reordered = create_reordered_scale_factor_tensor(l, m, k, sfa_cpu) + sfb_reordered = create_reordered_scale_factor_tensor(l, n, k, sfb_cpu) + sfasfb_reordered_tensors.append((sfa_reordered, sfb_reordered)) + abc_ptrs.append((a.data_ptr(), b.data_ptr(), c.data_ptr())) + sfasfb_ptrs.append((sfa_reordered.data_ptr(), sfb_reordered.data_ptr())) + + # Pick the tensor with the smallest size to create initial tensormaps + initial_cute_abc_ptrs = ( + make_ptr( + ab_dtype, + abc_tensors[min_a_idx][0].data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ), + make_ptr( + ab_dtype, + abc_tensors[min_b_idx][1].data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ), + make_ptr( + c_dtype, + abc_tensors[min_c_idx][2].data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ), + ) + initial_cute_sfasfb_ptrs = ( + make_ptr( + sf_dtype, + sfasfb_reordered_tensors[min_a_idx][0].data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ), + make_ptr( + sf_dtype, + sfasfb_reordered_tensors[min_b_idx][1].data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ), + ) + + # Create torch tensor to store problem sizes + # layout (num_groups, 4):(4, 1) + tensor_of_problem_sizes = torch.tensor( + problem_sizes, dtype=torch.int32, device="cuda" + ) + + # Create torch tensors to store abc_ptrs and sfasfb_ptrs + # layout (num_groups,3):(3, 1) + tensor_of_abc_ptrs = torch.tensor(abc_ptrs, dtype=torch.int64, device="cuda") + tensor_of_sfasfb_ptrs = torch.tensor(sfasfb_ptrs, dtype=torch.int64, device="cuda") + + # Compute cluster tile shape + cta_tile_shape_mn = [128, mma_tiler_mnk[1]] + cluster_tile_shape_mn = tuple( + x * y for x, y in zip(cta_tile_shape_mn, (1, 1)) + ) + # Compute total number of cluster tiles we need to compute for given grouped GEMM problem + total_num_clusters = 0 + num_groups = len(problem_sizes) + for m, n, _, _ in problem_sizes: + num_clusters_mn = tuple( + (x + y - 1) // y for x, y in zip((m, n), cluster_tile_shape_mn) + ) + total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) + + # Preserved buffers for each cluster to update its tma descriptor in device memory + tensormap_shape = ( + total_num_clusters, + num_tensormaps, + bytes_per_tensormap // 8, + ) + tensor_of_tensormap = torch.empty(tensormap_shape, dtype=torch.int64, device="cuda") + + cute_ptr_of_tensor_of_abc_ptrs = make_ptr( + cutlass.Int64, + tensor_of_abc_ptrs.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + cute_ptr_of_tensor_of_sfasfb_ptrs = make_ptr( + cutlass.Int64, + tensor_of_sfasfb_ptrs.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + cute_ptr_of_tensor_of_problem_sizes = make_ptr( + cutlass.Int32, + tensor_of_problem_sizes.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + + # Execute the compiled kernel + my_kernel( + initial_cute_abc_ptrs, + initial_cute_sfasfb_ptrs, + (min_a_idx, min_b_idx, min_c_idx), + cute_ptr_of_tensor_of_problem_sizes, + cute_ptr_of_tensor_of_abc_ptrs, + cute_ptr_of_tensor_of_sfasfb_ptrs, + total_num_clusters, + problem_sizes, + tensor_of_tensormap, + num_groups, + ) + + res = [] + for i in range(num_groups): + res.append(abc_tensors[i][2]) + return res \ No newline at end of file diff --git a/problems/nvidia/nvfp4_group_gemm/task.py b/problems/nvidia/nvfp4_group_gemm/task.py new file mode 100644 index 0000000..6e0961f --- /dev/null +++ b/problems/nvidia/nvfp4_group_gemm/task.py @@ -0,0 +1,8 @@ +import torch +from typing import TypedDict, TypeVar + +input_t = TypeVar("input_t", bound=tuple[list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]], list[tuple[torch.Tensor, torch.Tensor]], list[tuple[int, int, int, int]]]) +output_t = TypeVar("output_t", bound=list[torch.Tensor]) +class TestSpec(TypedDict): + problem_sizes: list[tuple[int, int, int, int]] + seed: int \ No newline at end of file diff --git a/problems/nvidia/nvfp4_group_gemm/task.yml b/problems/nvidia/nvfp4_group_gemm/task.yml new file mode 100644 index 0000000..82a0640 --- /dev/null +++ b/problems/nvidia/nvfp4_group_gemm/task.yml @@ -0,0 +1,65 @@ +# name: nvfp4-block-scaled-gemm + +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + + You will implement a block scaled group matrix-matrix multiplication kernel optimized for NVIDIA B200. + To be explicit, you will be given a tuple of tensors: + ``` + (abc_tensors, sfasfb_tensors, problem_sizes) + ``` + where: + * `abc_tensors` is list of tuples (a, b, c) where + a is torch.Tensor[float4e2m1fn_x2] of shape [M, K // 2, L] + b is torch.Tensor[float4e2m1fn_x2] of shape [N, K // 2, L] + c is torch.Tensor[float16] of shape [M, N, L] + * `sfasfb_tensors` is list of tuples (sfa, sfb) where + sfa is torch.Tensor[float8_e4m3fnuz] of shape [M, K // 16, L] + sfb is torch.Tensor[float8_e4m3fnuz] of shape [N, K // 16, L] + * `problem_sizes` is list of tuples (M, N, K, L) + + Each group's matrix sizes `M` is divisible by mma_tiler_mn[0], `N` is divisible by mma_tiler_mn[1], `K` is divisible by 256. + The ranking criteria is the geometric mean of the benchmark results. + For the grand price, your kernel will be evaluated against the speed of light analysis + and the solution closest to the speed of light will be awarded the grand price. + ``` + The speed of light analysis is (using 1.5Ghz clock): + G M N K L time[us] + 8 128 4096 7168 1 8.71 + 8 128 7168 2048 1 4.36 + 2 256 3072 4096 1 1.87 + 2 256 4096 1536 1 0.93 + ``` +config: + main: "eval.py" + +templates: + Python: "template.py" + +tests: + - {"m": 128, "n": 256, "k": 512, "g": 8, "seed": 1111} + - {"m": 128, "n": 256, "k": 512, "g": 2, "seed": 1111} + - {"m": 128, "n": 384, "k": 640, "g": 3, "seed": 1111} + - {"m": 256, "n": 384, "k": 640, "g": 4, "seed": 1111} + - {"m": 256, "n": 512, "k": 384, "g": 2, "seed": 1111} + - {"m": 384, "n": 512, "k": 384, "g": 2, "seed": 1111} + - {"m": 384, "n": 640, "k": 512, "g": 2, "seed": 1111} + - {"m": 256, "n": 640, "k": 128, "g": 8, "seed": 1111} + - {"m": 512, "n": 768, "k": 256, "g": 5, "seed": 1111} + - {"m": 512, "n": 768, "k": 768, "g": 3, "seed": 1111} + +benchmarks: + - {"m": 4096, "n": 128, "k": 7168, "g": 8, "seed": 1111} + - {"m": 7168, "n": 128, "k": 2048, "g": 8, "seed": 1111} + - {"m": 3072, "n": 256, "k": 4096, "g": 2, "seed": 1111} + - {"m": 4096, "n": 256, "k": 1536, "g": 2, "seed": 1111} + +ranking_by: "geom" diff --git a/problems/nvidia/nvfp4_group_gemm/template.py b/problems/nvidia/nvfp4_group_gemm/template.py new file mode 100644 index 0000000..ea034a9 --- /dev/null +++ b/problems/nvidia/nvfp4_group_gemm/template.py @@ -0,0 +1,28 @@ +from task import input_t, output_t + + +def custom_kernel(data: input_t) -> output_t: + """ + Reference implementation of block-scale fp4 group gemm + Args: + data: list of tuples (abc_tensors, sfasfb_tensors, problem_sizes) where: + abc_tensors: list of tuples (a, b, c) where + a is torch.Tensor[float4e2m1fn_x2] of shape [m, k // 2, l] + b is torch.Tensor[float4e2m1fn_x2] of shape [n, k // 2, l] + c is torch.Tensor[float16] of shape [m, n, l] + sfasfb_tensors: list of tuples (sfa, sfb) where + sfa is torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l] + sfb is torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l] + problem_sizes: list of tuples (m, n, k, l) + each group has its own a, b, c, sfa, sfb with different m, n, k, l problem sizes + l should always be 1 for each group. + Returns: + list of tuples (c) where c is torch.Tensor[float16] of shape [m, n, l] + """ + abc_tensors, sfasfb_tensors, problem_sizes = data + result_tensors = [] + for i, (a, b, c) in enumerate(abc_tensors): + # add you implementation here + result_tensors.append(c) + + return result_tensors \ No newline at end of file diff --git a/problems/nvidia/nvfp4_group_gemm/utils.py b/problems/nvidia/nvfp4_group_gemm/utils.py new file mode 100644 index 0000000..486116b --- /dev/null +++ b/problems/nvidia/nvfp4_group_gemm/utils.py @@ -0,0 +1,176 @@ +import os +import random +import numpy as np +import torch + + +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_device(use_cuda: bool = True) -> torch.device: + """Get the appropriate device (GPU or CPU).""" + if use_cuda: + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + print("No compatible GPU found. Falling back to CPU.") + return torch.device("cpu") + + +# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py +@torch.no_grad() +def verbose_allclose( + received: torch.Tensor, + expected: torch.Tensor, + rtol=1e-05, + atol=1e-08, + max_print=5 +) -> list[str]: + """ + Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + rtol (float): Relative tolerance; relative to expected + atol (float): Absolute tolerance. + max_print (int): Maximum number of mismatched elements to print. + + Raises: + AssertionError: If the tensors are not all close within the given tolerance. + """ + # Check if the shapes of the tensors match + if received.shape != expected.shape: + return ["SIZE MISMATCH"] + + # Calculate the difference between the tensors + diff = torch.abs(received - expected) + + # Determine the tolerance + tolerance = atol + rtol * torch.abs(expected) + + # Find tolerance mismatched elements + tol_mismatched = diff > tolerance + + # Find nan mismatched elements + nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) + + # Find +inf mismatched elements + posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) + # Find -inf mismatched elements + neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) + + # Find all mismatched elements + mismatched = torch.logical_or( + torch.logical_or(tol_mismatched, nan_mismatched), + torch.logical_or(posinf_mismatched, neginf_mismatched), + ) + + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +@torch.no_grad() +def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5): + """ + Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + max_print (int): Maximum number of mismatched elements to print. + + Returns: + Empty string if tensors are equal, otherwise detailed error information + """ + mismatched = torch.not_equal(received, expected) + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]: + """ + Convenient "default" implementation for tasks' `check_implementation` function. + """ + expected = reference(data) + for i, (output_i, expected_i) in enumerate(zip(output, expected)): + reasons = verbose_allclose(output_i, expected_i, rtol=rtol, atol=atol) + if len(reasons) > 0: + return False, f"mismatch found! custom implementation doesn't match reference: {i} {reasons}" + + return True, '' + + +def make_match_reference(reference: callable, **kwargs): + def wrapped(data, output): + return match_reference(data, output, reference=reference, **kwargs) + return wrapped + + +class DeterministicContext: + def __init__(self): + self.allow_tf32 = None + self.deterministic = None + self.cublas = None + + def __enter__(self): + self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '') + self.allow_tf32 = torch.backends.cudnn.allow_tf32 + self.deterministic = torch.backends.cudnn.deterministic + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + torch.use_deterministic_algorithms(True) + return self + + def __exit__(self, exc_type, exc_value, traceback): + torch.backends.cudnn.allow_tf32 = self.allow_tf32 + torch.backends.cudnn.deterministic = self.deterministic + torch.use_deterministic_algorithms(False) + os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas + +def clear_l2_cache(): + # import cupy as cp + # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) + # create a large dummy tensor + dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") + # write stuff to + dummy.fill_(42) + del dummy \ No newline at end of file From cf6425521491d3772e11cfc80ad753d56e554fe5 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Mon, 20 Oct 2025 22:17:04 -0700 Subject: [PATCH 13/29] move scale factor reorder operation to host. --- problems/nvidia/nvfp4_gemm/eval.py | 437 +++++++++++++++++ problems/nvidia/nvfp4_gemm/eval_vicki.py | 500 ++++++++++++++++++++ problems/nvidia/nvfp4_gemm/reference.py | 47 +- problems/nvidia/nvfp4_gemm/submission.py | 63 +-- problems/nvidia/nvfp4_gemm/task.py | 2 +- problems/nvidia/nvfp4_gemm/template.py | 4 +- problems/nvidia/nvfp4_gemm/test_python_1.sh | 87 ++++ problems/nvidia/nvfp4_gemm/utils.py | 176 +++++++ 8 files changed, 1254 insertions(+), 62 deletions(-) create mode 100644 problems/nvidia/nvfp4_gemm/eval.py create mode 100644 problems/nvidia/nvfp4_gemm/eval_vicki.py create mode 100644 problems/nvidia/nvfp4_gemm/test_python_1.sh create mode 100644 problems/nvidia/nvfp4_gemm/utils.py diff --git a/problems/nvidia/nvfp4_gemm/eval.py b/problems/nvidia/nvfp4_gemm/eval.py new file mode 100644 index 0000000..072b176 --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/eval.py @@ -0,0 +1,437 @@ +import base64 +import dataclasses +import multiprocessing +import re +import time +import os +import sys +import math +from pathlib import Path +from typing import Any, Optional +import tempfile + +import torch.cuda +from cutlass.cute.nvgpu.common import OpError + +from utils import set_seed, clear_l2_cache + +try: + from task import TestSpec +except ImportError: + TestSpec = dict + +from reference import check_implementation, generate_input + + +class PopcornOutput: + def __init__(self, fd: int): + self.file = os.fdopen(fd, "w") + os.set_inheritable(fd, False) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def print(self, *args, **kwargs): + print(*args, **kwargs, file=self.file, flush=True) + + def log(self, key, value): + self.print(f"{key}: {value}") + + +@dataclasses.dataclass +class TestCase: + args: dict + spec: str + + +def _combine(a: int, b: int) -> int: + # combine two integers into one: + # we need this to generate a secret seed based on the test-level seed and + # the global secret seed. + # the test-level seeds are public knowledge, and typically relatively small numbers, + # so we need to make sure they don't provide any useful info for the full seed. + # This Cantor construction ensures that if the secret seed is a large number, + # then so is the overall seed. + return int(a + (a + b) * (a + b + 1) // 2) + + +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: + try: + content = Path(file_name).read_text() + except Exception as E: + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) + exit(113) + + tests = [] + lines = content.splitlines() + match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" + for line in lines: + parts = line.split(";") + case = {} + for part in parts: + matched = re.match(match, part) + if not re.fullmatch(match, part): + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) + exit(113) + key = matched[1] + val = matched[2] + try: + val = int(val) + except ValueError: + pass + + case[key] = val + tests.append(TestCase(spec=line, args=case)) + + if seed is not None: + for test in tests: + if "seed" in test.args: + test.args["seed"] = _combine(test.args["seed"], seed) + + return tests + + +@dataclasses.dataclass +class Stats: + runs: int + mean: float + std: float + err: float + best: float + worst: float + + +def calculate_stats(durations: list[int]): + """ + Calculate statistical data from a list of durations. + @param durations: A list of durations in nanoseconds. + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. + """ + runs = len(durations) + total = sum(durations) + best = min(durations) + worst = max(durations) + + avg = total / runs + variance = sum(map(lambda x: (x - avg) ** 2, durations)) + std = math.sqrt(variance / (runs - 1)) + err = std / math.sqrt(runs) + + return Stats( + runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) + ) + + +def _clone_data(data): + """ + Recursively goes through data and clones all tensors. + """ + if isinstance(data, tuple): + return tuple(_clone_data(x) for x in data) + elif isinstance(data, list): + return [_clone_data(x) for x in data] + elif isinstance(data, dict): + return {k: _clone_data(v) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + return data.clone() + else: + return data + + +def _run_single_test(test: TestCase): + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + + data = generate_input(**test.args) + torch.cuda.synchronize() + try: + submission_output = custom_kernel(_clone_data(data)) + + except OpError as E: + print(f"Encountered {E}", file=sys.stderr) + return False, str(E) + torch.cuda.synchronize() + return check_implementation(data, submission_output) + + +def run_single_test(pool: multiprocessing.Pool, test: TestCase): + """ + Runs a single test in another process. + """ + return pool.apply(_run_single_test, (test,)) + + +def run_testing( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes the actual test case code and checks for correctness. + @param logger: A PopcornOutput object used for logging test results. + @param tests: A list of TestCase objects representing the test cases to be executed. + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. + """ + passed = True + logger.log("test-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"test.{idx}.spec", test.spec) + good, message = run_single_test(pool, test) + if not good: + logger.log(f"test.{idx}.status", "fail") + logger.log(f"test.{idx}.error", message) + passed = False + else: + logger.log(f"test.{idx}.status", "pass") + if message: + logger.log(f"test.{idx}.message", message) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def _run_single_benchmark( + test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float +) -> Stats | Any: + """ + Runs one benchmark. Do not call directly. + """ + from submission import custom_kernel + + durations = [] + # generate input data once + data = generate_input(**test.args) + check_copy = _clone_data(data) + # first, one obligatory correctness check + try: + output = custom_kernel(_clone_data(data)) + except OpError as E: + return f"Encountered {E}" + good, message = check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs without further correctness testing + # there is an upper bound of 100 runs, and a lower bound of 3 runs; + # otherwise, we repeat until we either measure at least 10 full seconds, + # or the relative error of the mean is below 1%. + + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args) + check_copy = _clone_data(data) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + clear_l2_cache() + + start_event.record() + output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + return message + + del output + durations.append(duration) + + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + if ( + stats.err / stats.mean < 0.001 + or stats.mean * stats.runs > max_time_ns + or total_bm_duration > 120e9 + ): + break + + return calculate_stats(durations) + + +def run_single_benchmark( + pool: multiprocessing.Pool, + test: TestCase, + recheck: bool, + max_repeats: int, + max_time_ns: float, +): + """ + For a particular test case, check correctness (if applicable) and grab runtime results. + @param pool: Process on which the benchmark will be launched. + @param test: TestCase object. + @param recheck: Flag for whether to explicitly check functional correctness. + @param max_repeats: Number of trials to repeat. + @param max_time_ns: Timeout time in nanoseconds. + @return: A Stats object for this particular benchmark case or an error if the test fails. + """ + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + + +def run_benchmarking( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes benchmarking code for a CUDA Kernel and logs runtimes. + @param logger: A PopcornOutput object used for logging benchmark results. + @param pool: Process on which the benchmarks will be launched. + @param tests: A list of TestCase objects representing the test cases to be benchmarked. + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. + """ + # warm up + run_single_benchmark(pool, tests[0], False, 100, 10e7) + + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(pool, test, False, 100, 10e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log( + f"benchmark.{idx}.report", + base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), + ) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + seed = int(seed) if seed else None + set_seed(seed or 42) + + filename = None + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + + def build_test_string(tests: list[dict]): + as_str = "" + for test in tests: + kvs = [] + for k, v in test.items(): + kvs.append(f"{k}: {v}") + as_str += "; ".join(kvs) + "\n" + return as_str + + import yaml + + yaml_content = yaml.safe_load(open(sys.argv[2], "r")) + if mode == "test": + tests_str = build_test_string(yaml_content.get("tests", [])) + elif mode in ("benchmark", "leaderboard", "profile"): + tests_str = build_test_string(yaml_content.get("benchmarks", [])) + + tmp.write(tests_str.encode("utf-8")) + tmp.flush() + filename = tmp.name + + tests = get_test_cases(filename, seed) + + os.unlink(filename) + + with PopcornOutput(int(fd)) as logger: + import multiprocessing + + mp_context = multiprocessing.get_context("spawn") + with mp_context.Pool(1) as pool: + if mode == "test": + return run_testing(logger, pool, tests) + if mode == "benchmark": + return run_benchmarking(logger, pool, tests) + + if mode == "leaderboard": + # warmup + run_single_benchmark(pool, tests[0], False, 100, 1e7) + logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(pool, tests[i], True, 100, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log( + f"benchmark.{i}.{field.name}", + getattr(result, field.name), + ) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log( + f"benchmark.{i}.error", str(result) + ) # TODO: Make sure result implements __str__? + break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # TODO: Implement script mode + return 2 + + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemm/eval_vicki.py b/problems/nvidia/nvfp4_gemm/eval_vicki.py new file mode 100644 index 0000000..2441b7d --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/eval_vicki.py @@ -0,0 +1,500 @@ +import base64 +import dataclasses +import multiprocessing +import re +import time +import os +import sys +import math +from pathlib import Path +from typing import Any, Optional +import tempfile + +import torch.cuda +from cutlass.cute.nvgpu.common import OpError + +from utils import set_seed, clear_l2_cache + +try: + from task import TestSpec +except ImportError: + TestSpec = dict + +from reference import check_implementation, generate_input + + +class PopcornOutput: + def __init__(self, fd: int): + self.file = os.fdopen(fd, "w") + os.set_inheritable(fd, False) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def print(self, *args, **kwargs): + print(*args, **kwargs, file=self.file, flush=True) + + def log(self, key, value): + self.print(f"{key}: {value}") + + +@dataclasses.dataclass +class TestCase: + args: dict + spec: str + + +def _combine(a: int, b: int) -> int: + # combine two integers into one: + # we need this to generate a secret seed based on the test-level seed and + # the global secret seed. + # the test-level seeds are public knowledge, and typically relatively small numbers, + # so we need to make sure they don't provide any useful info for the full seed. + # This Cantor construction ensures that if the secret seed is a large number, + # then so is the overall seed. + return int(a + (a + b) * (a + b + 1) // 2) + + +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: + try: + content = Path(file_name).read_text() + except Exception as E: + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) + exit(113) + + tests = [] + lines = content.splitlines() + match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" + for line in lines: + parts = line.split(";") + case = {} + for part in parts: + matched = re.match(match, part) + if not re.fullmatch(match, part): + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) + exit(113) + key = matched[1] + val = matched[2] + try: + val = int(val) + except ValueError: + pass + + case[key] = val + tests.append(TestCase(spec=line, args=case)) + + if seed is not None: + for test in tests: + if "seed" in test.args: + test.args["seed"] = _combine(test.args["seed"], seed) + + return tests + + +@dataclasses.dataclass +class Stats: + runs: int + mean: float + std: float + err: float + best: float + worst: float + + +def calculate_stats(durations: list[int]): + """ + Calculate statistical data from a list of durations. + + @param durations: A list of durations in nanoseconds. + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. + """ + runs = len(durations) + total = sum(durations) + best = min(durations) + worst = max(durations) + + avg = total / runs + variance = sum(map(lambda x: (x - avg) ** 2, durations)) + std = math.sqrt(variance / (runs - 1)) + err = std / math.sqrt(runs) + + return Stats( + runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) + ) + + +def _clone_data(data): + """ + Recursively goes through data and clones all tensors. + """ + if isinstance(data, tuple): + return tuple(_clone_data(x) for x in data) + elif isinstance(data, list): + return [_clone_data(x) for x in data] + elif isinstance(data, dict): + return {k: _clone_data(v) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + return data.clone() + else: + return data + + +def _run_single_test(test: TestCase): + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + + data = generate_input(**test.args) + torch.cuda.synchronize() + try: + submission_output = custom_kernel(_clone_data(data)) + + except OpError as E: + print(f"Encountered {E}", file=sys.stderr) + return False, str(E) + torch.cuda.synchronize() + return check_implementation(data, submission_output) + + +def run_single_test(pool: multiprocessing.Pool, test: TestCase): + """ + Runs a single test in another process. + """ + return pool.apply(_run_single_test, (test,)) + + +def run_testing( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes the actual test case code and checks for correctness. + + @param logger: A PopcornOutput object used for logging test results. + @param tests: A list of TestCase objects representing the test cases to be executed. + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. + """ + # Step 1: Compile kernel once before running tests + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Run all tests with compiled kernel + passed = True + logger.log("test-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"test.{idx}.spec", test.spec) + good, message = run_single_test(pool, test) + if not good: + logger.log(f"test.{idx}.status", "fail") + logger.log(f"test.{idx}.error", message) + passed = False + else: + logger.log(f"test.{idx}.status", "pass") + if message: + logger.log(f"test.{idx}.message", message) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def _compile_kernel_once(): + """ + Compile the kernel once before any benchmarking. + This ensures compilation time is not included in benchmark results. + """ + from submission import compile_kernel + + try: + # Trigger compilation (will be cached) + compile_kernel() + torch.cuda.synchronize() + return True, None + except OpError as E: + return False, f"Compilation failed: {E}" + except Exception as E: + return False, f"Compilation failed: {E}" + + +def _run_single_benchmark( + test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float +) -> Stats | Any: + """ + Runs one benchmark. Do not call directly. + """ + from submission import custom_kernel, compile_kernel + + durations = [] + # generate input data once + data = generate_input(**test.args) + check_copy = _clone_data(data) + + # Ensure kernel is compiled before any timing (compilation is cached) + try: + compile_kernel() + torch.cuda.synchronize() + except OpError as E: + return f"Compilation failed: {E}" + except Exception as E: + return f"Compilation failed: {E}" + + # first, one obligatory correctness check + try: + output = custom_kernel(_clone_data(data)) + except OpError as E: + return f"Encountered {E}" + good, message = check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs without further correctness testing + # there is an upper bound of 200 runs, and a lower bound of 3 runs; + # otherwise, we repeat until we either measure at least 10 full seconds, + # or the relative error of the mean is below 1%. + + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args) + check_copy = _clone_data(data) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + clear_l2_cache() + + start_event.record() + output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + return message + + del output + durations.append(duration) + + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + if ( + stats.err / stats.mean < 0.001 + or stats.mean * stats.runs > max_time_ns + or total_bm_duration > 120e9 + ): + break + + return calculate_stats(durations) + + +def run_single_benchmark( + pool: multiprocessing.Pool, + test: TestCase, + recheck: bool, + max_repeats: int, + max_time_ns: float, +): + """ + For a particular test case, check correctness (if applicable) and grab runtime results. + + @param pool: Process on which the benchmark will be launched. + @param test: TestCase object. + @param recheck: Flag for whether to explicitly check functional correctness. + @param max_repeats: Number of trials to repeat. + @param max_time_ns: Timeout time in nanoseconds. + @return: A Stats object for this particular benchmark case or an error if the test fails. + """ + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + + +def run_benchmarking( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes benchmarking code for a CUDA Kernel and logs runtimes. + + @param logger: A PopcornOutput object used for logging benchmark results. + @param pool: Process on which the benchmarks will be launched. + @param tests: A list of TestCase objects representing the test cases to be benchmarked. + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. + """ + # Step 1: Compile kernel once (outside of timing) + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Warm up with compiled kernel + run_single_benchmark(pool, tests[0], False, 2, 10e7) + + # Step 3: Run benchmarks (compilation time excluded) + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(pool, test, False, 2, 10e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log( + f"benchmark.{idx}.report", + base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), + ) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + seed = int(seed) if seed else None + set_seed(seed or 42) + + filename = None + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + + def build_test_string(tests: list[dict]): + as_str = "" + for test in tests: + kvs = [] + for k, v in test.items(): + kvs.append(f"{k}: {v}") + as_str += "; ".join(kvs) + "\n" + return as_str + + import yaml + + yaml_content = yaml.safe_load(open(sys.argv[2], "r")) + if mode == "test": + tests_str = build_test_string(yaml_content.get("tests", [])) + elif mode in ("benchmark", "leaderboard", "profile"): + tests_str = build_test_string(yaml_content.get("benchmarks", [])) + + tmp.write(tests_str.encode("utf-8")) + tmp.flush() + filename = tmp.name + + tests = get_test_cases(filename, seed) + + os.unlink(filename) + + with PopcornOutput(int(fd)) as logger: + import multiprocessing + + mp_context = multiprocessing.get_context("spawn") + with mp_context.Pool(1) as pool: + if mode == "test": + return run_testing(logger, pool, tests) + if mode == "benchmark": + return run_benchmarking(logger, pool, tests) + + if mode == "leaderboard": + # Step 1: Compile kernel once (outside of timing) + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Warmup with compiled kernel + run_single_benchmark(pool, tests[0], False, 2, 1e7) + + # Step 3: Run leaderboard benchmarks (compilation time excluded) + logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(pool, tests[i], True, 2, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log( + f"benchmark.{i}.{field.name}", + getattr(result, field.name), + ) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log( + f"benchmark.{i}.error", str(result) + ) # TODO: Make sure result implements __str__? + break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # TODO: Implement script mode + return 2 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/problems/nvidia/nvfp4_gemm/reference.py b/problems/nvidia/nvfp4_gemm/reference.py index 12db1a6..ae56d1f 100644 --- a/problems/nvidia/nvfp4_gemm/reference.py +++ b/problems/nvidia/nvfp4_gemm/reference.py @@ -23,13 +23,14 @@ def to_blocked(input_matrix): return rearranged.flatten() + def ref_kernel( data: input_t, ) -> output_t: """ PyTorch reference implementation of NVFP4 block-scaled GEMM. """ - a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, c_ref = data + a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, _, _, c_ref = data # Get dimensions from MxNxL layout _, _, l = c_ref.shape @@ -94,7 +95,9 @@ def generate_input( 1, 2, 0 ) - # Helper function to prepare the scale factor tensors + # Helper function to prepare the scale factor tensors for both reference + # kernel and customize kernel. Please note this data reordering function + # is very slow. def create_scale_factor_tensors(l, mn, sf_k): # Create the reference scale factor tensor (mn, l, sf_k) on CPU. ref_shape = (l, mn, sf_k) @@ -106,13 +109,45 @@ def create_scale_factor_tensors(l, mn, sf_k): ref_f8_torch_tensor_cpu_permuted = ref_f8_torch_tensor_cpu.permute( *ref_permute_order ) - return ref_f8_torch_tensor_cpu_permuted + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, # batch size + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + + # Reorder scale factor tensor to (32, 4, rest_m, 4, rest_k, l) layout + # Which is needed by the CuTe customized kernel + mma_permute_order = (3, 4, 1, 5, 2, 0) + # Generate a random int8 tensor, then convert to float8_e4m3fn + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) + reordered_f8_torch_tensor_cpu = rand_int_tensor.to(dtype=torch.float8_e4m3fn) + # Permute according to mma_permute_order + reordered_f8_torch_tensor_cpu = reordered_f8_torch_tensor_cpu.permute( + *mma_permute_order + ) + + for i in range(mn): + for j in range(sf_k): + for b in range(l): + # Calculate the location in MMA shape + mm = i // (atom_m[0] * atom_m[1]) + mm32 = i % atom_m[0] + mm4 = (i % 128) // atom_m[0] + kk = j // atom_k + kk4 = j % atom_k + reordered_f8_torch_tensor_cpu[mm32, mm4, mm, kk4, kk, b] = ref_f8_torch_tensor_cpu_permuted[i, j, b] + return ref_f8_torch_tensor_cpu_permuted, reordered_f8_torch_tensor_cpu.cuda() sf_k = ceil_div(k, sf_vec_size) - sfa_ref_cpu = create_scale_factor_tensors(l, m, sf_k) - sfb_ref_cpu = create_scale_factor_tensors(l, n, sf_k) + sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k) + sfb_ref_cpu, sfb_ref_permuted = create_scale_factor_tensors(l, n, sf_k) - return (a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, c_ref) + return (a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, sfa_ref_permuted, sfb_ref_permuted, c_ref) check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) diff --git a/problems/nvidia/nvfp4_gemm/submission.py b/problems/nvidia/nvfp4_gemm/submission.py index b40e206..8e25fcc 100644 --- a/problems/nvidia/nvfp4_gemm/submission.py +++ b/problems/nvidia/nvfp4_gemm/submission.py @@ -687,49 +687,6 @@ def my_kernel( return -# Reorder scale factor from (mn, l, sf_k) to (32, 4, rest_m, 4, rest_k, l) layout -def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor): - sf_k = ceil_div(k, sf_vec_size) - atom_m = (32, 4) - atom_k = 4 - mma_shape = ( - l, # batch size - ceil_div(mn, atom_m[0] * atom_m[1]), - ceil_div(sf_k, atom_k), - atom_m[0], - atom_m[1], - atom_k, - ) - # Create the reordered scale factor tensor (32, 4, rest_m, 4, rest_k, l) on CPU. - mma_permute_order = (3, 4, 1, 5, 2, 0) - # Generate a random int8 tensor, then convert to float8_e4m3fn - rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) - reordered_f8_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) - # Permute according to mma_permute_order - reordered_f8_tensor = reordered_f8_tensor.permute(*mma_permute_order) - - # Helper function to convert scale factor tensor to CUTE-format scale factor tensor - cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - make_ptr( - cutlass.Float8E4M3FN, - ref_f8_tensor.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32, - ), - make_ptr( - cutlass.Float8E4M3FN, - reordered_f8_tensor.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32, - ), - mn, - sf_k, - l, - mma_shape, - ) - return reordered_f8_tensor.cuda() - - # Global cache for compiled kernel _compiled_kernel_cache = None @@ -782,17 +739,19 @@ def custom_kernel(data: input_t) -> output_t: and returns the result. Args: - data: Tuple of (a, b, sfa_cpu, sfb_cpu, c) PyTorch tensors + data: Tuple of (a, b, sfa_ref, sfb_ref, sfa_permuted, sfb_permuted, c) PyTorch tensors a: [m, k, l] - Input matrix in float4e2m1fn b: [n, k, l] - Input vector in float4e2m1fn - sfa_cpu: [m, k, l] - Scale factors in float8_e4m3fn - sfb_cpu: [n, k, l] - Scale factors in float8_e4m3fn + sfa_ref: [m, k, l] - Scale factors in float8_e4m3fn, used by reference implementation + sfb_ref: [n, k, l] - Scale factors in float8_e4m3fn, used by reference implementation + sfa_permuted: [32, 4, rest_m, 4, rest_k, l] - Scale factors in float8_e4m3fn + sfb_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn c: [m, n, l] - Output vector in float16 Returns: Output tensor c with computed results """ - a, b, sfa_cpu, sfb_cpu, c = data + a, b, _, _, sfa_permuted, sfb_permuted, c = data # Ensure kernel is compiled (will use cached version if available) compiled_func = compile_kernel() @@ -802,10 +761,6 @@ def custom_kernel(data: input_t) -> output_t: # Torch use e2m1_x2 data type, thus k is halved k = k * 2 - # Create the reordered scale factor tensors from the reference scale factor tensors via CuTe function. - sfa_reordered = create_reordered_scale_factor_tensor(l, m, k, sfa_cpu) - sfb_reordered = create_reordered_scale_factor_tensor(l, n, k, sfb_cpu) - # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer a_ptr = make_ptr( ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 @@ -817,13 +772,13 @@ def custom_kernel(data: input_t) -> output_t: c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 ) sfa_ptr = make_ptr( - sf_dtype, sfa_reordered.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + sf_dtype, sfa_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 ) sfb_ptr = make_ptr( - sf_dtype, sfb_reordered.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + sf_dtype, sfb_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 ) # Execute the compiled kernel compiled_func(a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l)) - return c + return c \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemm/task.py b/problems/nvidia/nvfp4_gemm/task.py index 4ebbe88..66db735 100644 --- a/problems/nvidia/nvfp4_gemm/task.py +++ b/problems/nvidia/nvfp4_gemm/task.py @@ -1,7 +1,7 @@ import torch from typing import TypedDict, TypeVar -input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) output_t = TypeVar("output_t", bound=torch.Tensor) class TestSpec(TypedDict): m: int diff --git a/problems/nvidia/nvfp4_gemm/template.py b/problems/nvidia/nvfp4_gemm/template.py index 17d6347..3855d69 100644 --- a/problems/nvidia/nvfp4_gemm/template.py +++ b/problems/nvidia/nvfp4_gemm/template.py @@ -10,13 +10,15 @@ def custom_kernel(data: input_t) -> output_t: b: torch.Tensor[float4e2m1fn] of shape [n, k, l], sfa: torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l], sfb: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], + sfa_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_m, 4, rest_k, l], + sfb_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l], c: torch.Tensor[float16] of shape [m, n, l] Returns: Tensor containing output in float16 c: torch.Tensor[float16] of shape [m, n, l] """ # c: [m, n, l] is pre-allocated memory to avoid timing allocation overhead. - a, b, sfa, sfb, c = data + a, b, sfa, sfb, sfa_permuted, sfb_permuted, c = data # Your implementation here diff --git a/problems/nvidia/nvfp4_gemm/test_python_1.sh b/problems/nvidia/nvfp4_gemm/test_python_1.sh new file mode 100644 index 0000000..ab7a8c9 --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/test_python_1.sh @@ -0,0 +1,87 @@ +# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build +BUILD_DIR=/home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/build_python +LLVM_DIR=$BUILD_DIR/llvm-prebuilt +# # BUILD_DIR=/home/scratch.ftse_gpu/workspace/dkg/build +# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build +# #BUILD_DIR=/home/yanchengz/scratch_1/dynamic-kernel-generator/build_debug2 +# # sudo /home/scratch.computelab/utils/driver/install_driver.py --installer=/home/builds/daily/display/x86_64/rel/gpu_drv/r580/r580_00/20250527_36037303/NVIDIA-Linux-x86_64-rel_gpu_drv_r580_r580_00-20250527_36037303-internal.run --reason="Change to tot driver" + + +# # BUILD_DIR=/home/scratch.nbommi_gpu/warp-phase-trace/dynamic-kernel-generator/build_main + +export PYTHONPATH=$BUILD_DIR/cutlass_ir/python_packages +#export PYTHONPATH=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/scripts +export CUDA_TOOLKIT_PATH=$BUILD_DIR/compiler_next +MLIR_CUDA_RUNTIME="$LLVM_DIR/lib/libmlir_cuda_runtime.so" +MLIR_C_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_c_runner_utils.so" +MLIR_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_runner_utils.so" +CUDA_DIALECT_RUNTIME="$BUILD_DIR/lib/libcuda_dialect_runtime.so" +export CUTE_DSL_LIBS="$MLIR_CUDA_RUNTIME:$MLIR_C_RUNNER_UTILS:$MLIR_RUNNER_UTILS:$CUDA_DIALECT_RUNTIME" + + +#export CUTE_DSL_PREPROCESSOR=True + +# export CUTE_DSL_PRINT_IR=1 +# just compile the IR but not execute it +# export CUTE_DSL_DRYRUN=1 +# export CUTE_DSL_JIT_TIME_PROFILING=ON +# export CUTE_DSL_KEEP_IR=True +# export CUTE_DSL_PRINT_IR=1 +# export CUTE_DSL_KEEP_CUBIN=1 +# export CUTE_DSL_LINEINFO=True +# export CUTE_DSL_LOG_TO_CONSOLE=1 +# export PYTHONUNBUFFERED=1 +# export CUTE_DSL_KEEP_SASS=1 +# whether to show detailed log in preprocessing +# export CUTE_DSL_FILTER_STACKTRACE=10 +export CUTE_DSL_ARCH=sm_100a + +# +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/reference-kernels/problems/nvidia/nvfp4_gemm/submission.py +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration -- +/home/scratch.vickiw_gpu/env/bin/python3 eval_vicki.py benchmark task.yml +/home/scratch.vickiw_gpu/env/bin/python3 eval_vicki.py test task.yml +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/cuda-gdb --args + +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py +# # /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gated_dual_gemm.py + +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemm/nvfp4_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemv/nvfp4_gemv.py +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool memcheck \ +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,16384 #135us +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 4096,128,7168 #62 + +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,2048 #26 + + +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gated_dual_gemm/nvfp4_gated_dual_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_naive.py + + + +# print out ncu time +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# python3 vicki/tutorial_fp16_gemm_0__.py --mnk 7168,8,512 + +# use sanitizer to check race contention and memref error +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck|memcheck +# cutlass_ir/compiler/test/python/examples/sm_100a/test_nvfp4_gemv.py + +# capture ncu report +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --check-exit-code 0 -f --set full --import-source yes --target-processes all --clock-control base --cache-control none -o gemv_4.1 \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv.py --m 128 --k 128 --l 2 + +# regular run python example +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/min_latency_hmma.py --mnkl 7168,8,512,1 + +# run pytest +# pytest cutlass_ir/compiler/test/python/examples/sm_80/test_sgemm.py diff --git a/problems/nvidia/nvfp4_gemm/utils.py b/problems/nvidia/nvfp4_gemm/utils.py new file mode 100644 index 0000000..e8a9082 --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/utils.py @@ -0,0 +1,176 @@ +import os +import random +import numpy as np +import torch + + +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_device(use_cuda: bool = True) -> torch.device: + """Get the appropriate device (GPU or CPU).""" + if use_cuda: + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + print("No compatible GPU found. Falling back to CPU.") + return torch.device("cpu") + + +# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py +@torch.no_grad() +def verbose_allclose( + received: torch.Tensor, + expected: torch.Tensor, + rtol=1e-05, + atol=1e-08, + max_print=5 +) -> list[str]: + """ + Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + rtol (float): Relative tolerance; relative to expected + atol (float): Absolute tolerance. + max_print (int): Maximum number of mismatched elements to print. + + Raises: + AssertionError: If the tensors are not all close within the given tolerance. + """ + # Check if the shapes of the tensors match + if received.shape != expected.shape: + return ["SIZE MISMATCH"] + + # Calculate the difference between the tensors + diff = torch.abs(received - expected) + + # Determine the tolerance + tolerance = atol + rtol * torch.abs(expected) + + # Find tolerance mismatched elements + tol_mismatched = diff > tolerance + + # Find nan mismatched elements + nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) + + # Find +inf mismatched elements + posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) + # Find -inf mismatched elements + neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) + + # Find all mismatched elements + mismatched = torch.logical_or( + torch.logical_or(tol_mismatched, nan_mismatched), + torch.logical_or(posinf_mismatched, neginf_mismatched), + ) + + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +@torch.no_grad() +def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5): + """ + Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + max_print (int): Maximum number of mismatched elements to print. + + Returns: + Empty string if tensors are equal, otherwise detailed error information + """ + mismatched = torch.not_equal(received, expected) + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]: + """ + Convenient "default" implementation for tasks' `check_implementation` function. + """ + expected = reference(data) + reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) + + if len(reasons) > 0: + return False, "mismatch found! custom implementation doesn't match reference: " + " ".join(reasons) + + return True, '' + + +def make_match_reference(reference: callable, **kwargs): + def wrapped(data, output): + return match_reference(data, output, reference=reference, **kwargs) + return wrapped + + +class DeterministicContext: + def __init__(self): + self.allow_tf32 = None + self.deterministic = None + self.cublas = None + + def __enter__(self): + self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '') + self.allow_tf32 = torch.backends.cudnn.allow_tf32 + self.deterministic = torch.backends.cudnn.deterministic + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + torch.use_deterministic_algorithms(True) + return self + + def __exit__(self, exc_type, exc_value, traceback): + torch.backends.cudnn.allow_tf32 = self.allow_tf32 + torch.backends.cudnn.deterministic = self.deterministic + torch.use_deterministic_algorithms(False) + os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas + +def clear_l2_cache(): + # import cupy as cp + # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) + # create a large dummy tensor + dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") + # write stuff to + dummy.fill_(42) + del dummy \ No newline at end of file From d229c9091de34e34ac7dc8a85a80db677e881b9b Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Mon, 20 Oct 2025 22:39:33 -0700 Subject: [PATCH 14/29] move scale factor initialization function to reference. --- problems/nvidia/nvfp4_dual_gemm/eval.py | 500 ++++++++++++++++++ problems/nvidia/nvfp4_dual_gemm/reference.py | 53 +- problems/nvidia/nvfp4_dual_gemm/submission.py | 67 +-- problems/nvidia/nvfp4_dual_gemm/task.py | 2 +- problems/nvidia/nvfp4_dual_gemm/template.py | 11 +- .../nvidia/nvfp4_dual_gemm/test_python_1.sh | 87 +++ problems/nvidia/nvfp4_dual_gemm/utils.py | 176 ++++++ 7 files changed, 828 insertions(+), 68 deletions(-) create mode 100644 problems/nvidia/nvfp4_dual_gemm/eval.py create mode 100644 problems/nvidia/nvfp4_dual_gemm/test_python_1.sh create mode 100644 problems/nvidia/nvfp4_dual_gemm/utils.py diff --git a/problems/nvidia/nvfp4_dual_gemm/eval.py b/problems/nvidia/nvfp4_dual_gemm/eval.py new file mode 100644 index 0000000..e8bb5b2 --- /dev/null +++ b/problems/nvidia/nvfp4_dual_gemm/eval.py @@ -0,0 +1,500 @@ +import base64 +import dataclasses +import multiprocessing +import re +import time +import os +import sys +import math +from pathlib import Path +from typing import Any, Optional +import tempfile + +import torch.cuda +from cutlass.cute.nvgpu.common import OpError + +from utils import set_seed, clear_l2_cache + +try: + from task import TestSpec +except ImportError: + TestSpec = dict + +from reference import check_implementation, generate_input + + +class PopcornOutput: + def __init__(self, fd: int): + self.file = os.fdopen(fd, "w") + os.set_inheritable(fd, False) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def print(self, *args, **kwargs): + print(*args, **kwargs, file=self.file, flush=True) + + def log(self, key, value): + self.print(f"{key}: {value}") + + +@dataclasses.dataclass +class TestCase: + args: dict + spec: str + + +def _combine(a: int, b: int) -> int: + # combine two integers into one: + # we need this to generate a secret seed based on the test-level seed and + # the global secret seed. + # the test-level seeds are public knowledge, and typically relatively small numbers, + # so we need to make sure they don't provide any useful info for the full seed. + # This Cantor construction ensures that if the secret seed is a large number, + # then so is the overall seed. + return int(a + (a + b) * (a + b + 1) // 2) + + +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: + try: + content = Path(file_name).read_text() + except Exception as E: + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) + exit(113) + + tests = [] + lines = content.splitlines() + match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" + for line in lines: + parts = line.split(";") + case = {} + for part in parts: + matched = re.match(match, part) + if not re.fullmatch(match, part): + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) + exit(113) + key = matched[1] + val = matched[2] + try: + val = int(val) + except ValueError: + pass + + case[key] = val + tests.append(TestCase(spec=line, args=case)) + + if seed is not None: + for test in tests: + if "seed" in test.args: + test.args["seed"] = _combine(test.args["seed"], seed) + + return tests + + +@dataclasses.dataclass +class Stats: + runs: int + mean: float + std: float + err: float + best: float + worst: float + + +def calculate_stats(durations: list[int]): + """ + Calculate statistical data from a list of durations. + + @param durations: A list of durations in nanoseconds. + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. + """ + runs = len(durations) + total = sum(durations) + best = min(durations) + worst = max(durations) + + avg = total / runs + variance = sum(map(lambda x: (x - avg) ** 2, durations)) + std = math.sqrt(variance / (runs - 1)) + err = std / math.sqrt(runs) + + return Stats( + runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) + ) + + +def _clone_data(data): + """ + Recursively goes through data and clones all tensors. + """ + if isinstance(data, tuple): + return tuple(_clone_data(x) for x in data) + elif isinstance(data, list): + return [_clone_data(x) for x in data] + elif isinstance(data, dict): + return {k: _clone_data(v) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + return data.clone() + else: + return data + + +def _run_single_test(test: TestCase): + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + + data = generate_input(**test.args) + torch.cuda.synchronize() + try: + submission_output = custom_kernel(_clone_data(data)) + + except OpError as E: + print(f"Encountered {E}", file=sys.stderr) + return False, str(E) + torch.cuda.synchronize() + return check_implementation(data, submission_output) + + +def run_single_test(pool: multiprocessing.Pool, test: TestCase): + """ + Runs a single test in another process. + """ + return pool.apply(_run_single_test, (test,)) + + +def run_testing( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes the actual test case code and checks for correctness. + + @param logger: A PopcornOutput object used for logging test results. + @param tests: A list of TestCase objects representing the test cases to be executed. + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. + """ + # Step 1: Compile kernel once before running tests + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Run all tests with compiled kernel + passed = True + logger.log("test-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"test.{idx}.spec", test.spec) + good, message = run_single_test(pool, test) + if not good: + logger.log(f"test.{idx}.status", "fail") + logger.log(f"test.{idx}.error", message) + passed = False + else: + logger.log(f"test.{idx}.status", "pass") + if message: + logger.log(f"test.{idx}.message", message) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def _compile_kernel_once(): + """ + Compile the kernel once before any benchmarking. + This ensures compilation time is not included in benchmark results. + """ + from submission import compile_kernel + + try: + # Trigger compilation (will be cached) + compile_kernel() + torch.cuda.synchronize() + return True, None + except OpError as E: + return False, f"Compilation failed: {E}" + except Exception as E: + return False, f"Compilation failed: {E}" + + +def _run_single_benchmark( + test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float +) -> Stats | Any: + """ + Runs one benchmark. Do not call directly. + """ + from submission import custom_kernel, compile_kernel + + durations = [] + # generate input data once + data = generate_input(**test.args) + check_copy = _clone_data(data) + + # Ensure kernel is compiled before any timing (compilation is cached) + try: + compile_kernel() + torch.cuda.synchronize() + except OpError as E: + return f"Compilation failed: {E}" + except Exception as E: + return f"Compilation failed: {E}" + + # first, one obligatory correctness check + try: + output = custom_kernel(_clone_data(data)) + except OpError as E: + return f"Encountered {E}" + good, message = check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs without further correctness testing + # there is an upper bound of 200 runs, and a lower bound of 3 runs; + # otherwise, we repeat until we either measure at least 10 full seconds, + # or the relative error of the mean is below 1%. + + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args) + check_copy = _clone_data(data) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + clear_l2_cache() + + start_event.record() + output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + return message + + del output + durations.append(duration) + + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + if ( + stats.err / stats.mean < 0.001 + or stats.mean * stats.runs > max_time_ns + or total_bm_duration > 120e9 + ): + break + + return calculate_stats(durations) + + +def run_single_benchmark( + pool: multiprocessing.Pool, + test: TestCase, + recheck: bool, + max_repeats: int, + max_time_ns: float, +): + """ + For a particular test case, check correctness (if applicable) and grab runtime results. + + @param pool: Process on which the benchmark will be launched. + @param test: TestCase object. + @param recheck: Flag for whether to explicitly check functional correctness. + @param max_repeats: Number of trials to repeat. + @param max_time_ns: Timeout time in nanoseconds. + @return: A Stats object for this particular benchmark case or an error if the test fails. + """ + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + + +def run_benchmarking( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes benchmarking code for a CUDA Kernel and logs runtimes. + + @param logger: A PopcornOutput object used for logging benchmark results. + @param pool: Process on which the benchmarks will be launched. + @param tests: A list of TestCase objects representing the test cases to be benchmarked. + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. + """ + # Step 1: Compile kernel once (outside of timing) + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Warm up with compiled kernel + run_single_benchmark(pool, tests[0], False, 200, 10e7) + + # Step 3: Run benchmarks (compilation time excluded) + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(pool, test, False, 200, 10e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log( + f"benchmark.{idx}.report", + base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), + ) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + seed = int(seed) if seed else None + set_seed(seed or 42) + + filename = None + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + + def build_test_string(tests: list[dict]): + as_str = "" + for test in tests: + kvs = [] + for k, v in test.items(): + kvs.append(f"{k}: {v}") + as_str += "; ".join(kvs) + "\n" + return as_str + + import yaml + + yaml_content = yaml.safe_load(open(sys.argv[2], "r")) + if mode == "test": + tests_str = build_test_string(yaml_content.get("tests", [])) + elif mode in ("benchmark", "leaderboard", "profile"): + tests_str = build_test_string(yaml_content.get("benchmarks", [])) + + tmp.write(tests_str.encode("utf-8")) + tmp.flush() + filename = tmp.name + + tests = get_test_cases(filename, seed) + + os.unlink(filename) + + with PopcornOutput(int(fd)) as logger: + import multiprocessing + + mp_context = multiprocessing.get_context("spawn") + with mp_context.Pool(1) as pool: + if mode == "test": + return run_testing(logger, pool, tests) + if mode == "benchmark": + return run_benchmarking(logger, pool, tests) + + if mode == "leaderboard": + # Step 1: Compile kernel once (outside of timing) + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Warmup with compiled kernel + run_single_benchmark(pool, tests[0], False, 200, 1e7) + + # Step 3: Run leaderboard benchmarks (compilation time excluded) + logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(pool, tests[i], True, 200, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log( + f"benchmark.{i}.{field.name}", + getattr(result, field.name), + ) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log( + f"benchmark.{i}.error", str(result) + ) # TODO: Make sure result implements __str__? + break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # TODO: Implement script mode + return 2 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/problems/nvidia/nvfp4_dual_gemm/reference.py b/problems/nvidia/nvfp4_dual_gemm/reference.py index ffa56ef..c64927c 100644 --- a/problems/nvidia/nvfp4_dual_gemm/reference.py +++ b/problems/nvidia/nvfp4_dual_gemm/reference.py @@ -23,13 +23,14 @@ def to_blocked(input_matrix): return rearranged.flatten() + def ref_kernel( data: input_t, ) -> output_t: """ PyTorch reference implementation of NVFP4 block-scaled GEMM. """ - a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, c_ref = data + a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, _, _, _, c_ref = data # Get dimensions from MxNxL layout m, n, l = c_ref.shape @@ -101,6 +102,9 @@ def generate_input( scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type scale_b1: [n, k, l] - Input scale factors in torch.float8e4m3fn data type scale_b2: [n, k, l] - Input scale factors in torch.float8e4m3fn data type + scale_a_permuted: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b1_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b2_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type c: [m, n, l] - Output matrix in torch.float16 data type """ torch.manual_seed(seed) @@ -124,7 +128,9 @@ def generate_input( 1, 2, 0 ) - # Helper function to prepare the scale factor tensors + # Helper function to prepare the scale factor tensors for both reference + # kernel and customize kernel. Please note this data reordering function + # is very slow. def create_scale_factor_tensors(l, mn, sf_k): # Create the reference scale factor tensor (mn, l, sf_k) on CPU. ref_shape = (l, mn, sf_k) @@ -136,13 +142,46 @@ def create_scale_factor_tensors(l, mn, sf_k): ref_f8_torch_tensor_cpu_permuted = ref_f8_torch_tensor_cpu.permute( *ref_permute_order ) - return ref_f8_torch_tensor_cpu_permuted + + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, # batch size + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + + # Reorder scale factor tensor to (32, 4, rest_m, 4, rest_k, l) layout + # Which is needed by the CuTe customized kernel + mma_permute_order = (3, 4, 1, 5, 2, 0) + # Generate a random int8 tensor, then convert to float8_e4m3fn + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) + reordered_f8_torch_tensor_cpu = rand_int_tensor.to(dtype=torch.float8_e4m3fn) + # Permute according to mma_permute_order + reordered_f8_torch_tensor_cpu = reordered_f8_torch_tensor_cpu.permute( + *mma_permute_order + ) + + for i in range(mn): + for j in range(sf_k): + for b in range(l): + # Calculate the location in MMA shape + mm = i // (atom_m[0] * atom_m[1]) + mm32 = i % atom_m[0] + mm4 = (i % 128) // atom_m[0] + kk = j // atom_k + kk4 = j % atom_k + reordered_f8_torch_tensor_cpu[mm32, mm4, mm, kk4, kk, b] = ref_f8_torch_tensor_cpu_permuted[i, j, b] + return ref_f8_torch_tensor_cpu_permuted, reordered_f8_torch_tensor_cpu.cuda() sf_k = ceil_div(k, sf_vec_size) - sfa_ref_cpu = create_scale_factor_tensors(l, m, sf_k) - sfb1_ref_cpu = create_scale_factor_tensors(l, n, sf_k) - sfb2_ref_cpu = create_scale_factor_tensors(l, n, sf_k) + sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k) + sfb1_ref_cpu, sfb1_ref_permuted = create_scale_factor_tensors(l, n, sf_k) + sfb2_ref_cpu, sfb2_ref_permuted = create_scale_factor_tensors(l, n, sf_k) - return (a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, c_ref) + return (a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, sfa_ref_permuted, sfb1_ref_permuted, sfb2_ref_permuted, c_ref) check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) diff --git a/problems/nvidia/nvfp4_dual_gemm/submission.py b/problems/nvidia/nvfp4_dual_gemm/submission.py index 6ecddbc..c767142 100644 --- a/problems/nvidia/nvfp4_dual_gemm/submission.py +++ b/problems/nvidia/nvfp4_dual_gemm/submission.py @@ -859,49 +859,6 @@ def my_kernel( return -# Reorder scale factor from (mn, l, sf_k) to (32, 4, rest_m, 4, rest_k, l) layout -def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor): - sf_k = ceil_div(k, sf_vec_size) - atom_m = (32, 4) - atom_k = 4 - mma_shape = ( - l, # batch size - ceil_div(mn, atom_m[0] * atom_m[1]), - ceil_div(sf_k, atom_k), - atom_m[0], - atom_m[1], - atom_k, - ) - # Create the reordered scale factor tensor (32, 4, rest_m, 4, rest_k, l) on CPU. - mma_permute_order = (3, 4, 1, 5, 2, 0) - # Generate a random int8 tensor, then convert to float8_e4m3fn - rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) - reordered_f8_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) - # Permute according to mma_permute_order - reordered_f8_tensor = reordered_f8_tensor.permute(*mma_permute_order) - - # Helper function to convert scale factor tensor to CUTE-format scale factor tensor - cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - make_ptr( - cutlass.Float8E4M3FN, - ref_f8_tensor.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32, - ), - make_ptr( - cutlass.Float8E4M3FN, - reordered_f8_tensor.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32, - ), - mn, - sf_k, - l, - mma_shape, - ) - return reordered_f8_tensor.cuda() - - # Global cache for compiled kernel _compiled_kernel_cache = None @@ -965,15 +922,18 @@ def custom_kernel(data: input_t) -> output_t: a: [m, k, l] - Input matrix in float4e2m1fn b1: [n, k, l] - Input matrix in float4e2m1fn b2: [n, k, l] - Input matrix in float4e2m1fn - sfa_cpu: [m, k, l] - Scale factors in float8_e4m3fn - sfb1_cpu: [n, k, l] - Scale factors in float8_e4m3fn - sfb2_cpu: [n, k, l] - Scale factors in float8_e4m3fn + sfa_cpu: [m, k, l] - Scale factors in float8_e4m3fn, used by reference implementation + sfb1_cpu: [n, k, l] - Scale factors in float8_e4m3fn, used by reference implementation + sfb2_cpu: [n, k, l] - Scale factors in float8_e4m3fn, used by reference implementation + sfa_permuted: [32, 4, rest_m, 4, rest_k, l] - Scale factors in float8_e4m3fn + sfb1_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn + sfb2_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn c: [m, n, l] - Output vector in float16 Returns: Output tensor c with computed results """ - a, b1, b2, sfa_cpu, sfb1_cpu, sfb2_cpu, c = data + a, b1, b2, _, _, _, sfa_permuted, sfb1_permuted, sfb2_permuted, c = data # Ensure kernel is compiled (will use cached version if available) compiled_func = compile_kernel() @@ -983,11 +943,6 @@ def custom_kernel(data: input_t) -> output_t: # Torch use e2m1_x2 data type, thus k is halved k = k * 2 - # Create the reordered scale factor tensors from the reference scale factor tensors via CuTe function. - sfa_reordered = create_reordered_scale_factor_tensor(l, m, k, sfa_cpu) - sfb1_reordered = create_reordered_scale_factor_tensor(l, n, k, sfb1_cpu) - sfb2_reordered = create_reordered_scale_factor_tensor(l, n, k, sfb2_cpu) - # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer a_ptr = make_ptr( ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 @@ -1002,16 +957,16 @@ def custom_kernel(data: input_t) -> output_t: c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 ) sfa_ptr = make_ptr( - sf_dtype, sfa_reordered.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + sf_dtype, sfa_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 ) sfb1_ptr = make_ptr( - sf_dtype, sfb1_reordered.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + sf_dtype, sfb1_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 ) sfb2_ptr = make_ptr( - sf_dtype, sfb2_reordered.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + sf_dtype, sfb2_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 ) # Execute the compiled kernel compiled_func(a_ptr, b1_ptr, b2_ptr, sfa_ptr, sfb1_ptr, sfb2_ptr, c_ptr, (m, n, k, l)) - return c + return c \ No newline at end of file diff --git a/problems/nvidia/nvfp4_dual_gemm/task.py b/problems/nvidia/nvfp4_dual_gemm/task.py index 66db735..8facfb0 100644 --- a/problems/nvidia/nvfp4_dual_gemm/task.py +++ b/problems/nvidia/nvfp4_dual_gemm/task.py @@ -1,7 +1,7 @@ import torch from typing import TypedDict, TypeVar -input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) output_t = TypeVar("output_t", bound=torch.Tensor) class TestSpec(TypedDict): m: int diff --git a/problems/nvidia/nvfp4_dual_gemm/template.py b/problems/nvidia/nvfp4_dual_gemm/template.py index 2509200..d8985df 100644 --- a/problems/nvidia/nvfp4_dual_gemm/template.py +++ b/problems/nvidia/nvfp4_dual_gemm/template.py @@ -9,16 +9,19 @@ def custom_kernel(data: input_t) -> output_t: a: torch.Tensor[float4e2m1fn] of shape [m, k, l], b1: torch.Tensor[float4e2m1fn] of shape [n, k, l], b2: torch.Tensor[float4e2m1fn] of shape [n, k, l], - sfa: torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l], - sfb1: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], - sfb2: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], + sfa: torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l], used by reference implementation + sfb1: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], used by reference implementation + sfb2: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], used by reference implementation + sfa_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_m, 4, rest_k, l], + sfb1_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l], + sfb2_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l], c: torch.Tensor[float16] of shape [m, n, l] Returns: Tensor containing output in float16 c: torch.Tensor[float16] of shape [m, n, l] """ # c: [m, n, l] is pre-allocated memory to avoid timing allocation overhead. - a, b, sfa, sfb, c = data + a, b1, b2, sfa, sfb1, sfb2, sfa_permuted, sfb1_permuted, sfb2_permuted, c = data # Your implementation here diff --git a/problems/nvidia/nvfp4_dual_gemm/test_python_1.sh b/problems/nvidia/nvfp4_dual_gemm/test_python_1.sh new file mode 100644 index 0000000..8648bdb --- /dev/null +++ b/problems/nvidia/nvfp4_dual_gemm/test_python_1.sh @@ -0,0 +1,87 @@ +# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build +BUILD_DIR=/home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/build_python +LLVM_DIR=$BUILD_DIR/llvm-prebuilt +# # BUILD_DIR=/home/scratch.ftse_gpu/workspace/dkg/build +# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build +# #BUILD_DIR=/home/yanchengz/scratch_1/dynamic-kernel-generator/build_debug2 +# # sudo /home/scratch.computelab/utils/driver/install_driver.py --installer=/home/builds/daily/display/x86_64/rel/gpu_drv/r580/r580_00/20250527_36037303/NVIDIA-Linux-x86_64-rel_gpu_drv_r580_r580_00-20250527_36037303-internal.run --reason="Change to tot driver" + + +# # BUILD_DIR=/home/scratch.nbommi_gpu/warp-phase-trace/dynamic-kernel-generator/build_main + +export PYTHONPATH=$BUILD_DIR/cutlass_ir/python_packages +#export PYTHONPATH=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/scripts +export CUDA_TOOLKIT_PATH=$BUILD_DIR/compiler_next +MLIR_CUDA_RUNTIME="$LLVM_DIR/lib/libmlir_cuda_runtime.so" +MLIR_C_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_c_runner_utils.so" +MLIR_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_runner_utils.so" +CUDA_DIALECT_RUNTIME="$BUILD_DIR/lib/libcuda_dialect_runtime.so" +export CUTE_DSL_LIBS="$MLIR_CUDA_RUNTIME:$MLIR_C_RUNNER_UTILS:$MLIR_RUNNER_UTILS:$CUDA_DIALECT_RUNTIME" + + +#export CUTE_DSL_PREPROCESSOR=True + +# export CUTE_DSL_PRINT_IR=1 +# just compile the IR but not execute it +# export CUTE_DSL_DRYRUN=1 +# export CUTE_DSL_JIT_TIME_PROFILING=ON +# export CUTE_DSL_KEEP_IR=True +# export CUTE_DSL_PRINT_IR=1 +# export CUTE_DSL_KEEP_CUBIN=1 +# export CUTE_DSL_LINEINFO=True +# export CUTE_DSL_LOG_TO_CONSOLE=1 +# export PYTHONUNBUFFERED=1 +# export CUTE_DSL_KEEP_SASS=1 +# whether to show detailed log in preprocessing +# export CUTE_DSL_FILTER_STACKTRACE=10 +export CUTE_DSL_ARCH=sm_100a + +# +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py +/home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/reference-kernels/problems/nvidia/nvfp4_dual_gemm/submission.py +/home/scratch.vickiw_gpu/env/bin/python3 eval.py test task.yml +/home/scratch.vickiw_gpu/env/bin/python3 eval.py benchmark task.yml +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/cuda-gdb --args + +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py +# # /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gated_dual_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gecccccbkvnjtrvtfreufijlfglnudnvuggvdfucidbnhk +# mm.py +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemm/nvfp4_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemv/nvfp4_gemv.py +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool memcheck \ +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,16384 #135us +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 4096,128,7168 #62 + +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,2048 #26 + + +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gated_dual_gemm/nvfp4_gated_dual_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_naive.py + + + +# print out ncu time +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# python3 vicki/tutorial_fp16_gemm_0__.py --mnk 7168,8,512 + +# use sanitizer to check race contention and memref error +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck|memcheck +# cutlass_ir/compiler/test/python/examples/sm_100a/test_nvfp4_gemv.py + +# capture ncu report +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --check-exit-code 0 -f --set full --import-source yes --target-processes all --clock-control base --cache-control none -o gemv_4.1 \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv.py --m 128 --k 128 --l 2 + +# regular run python example +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/min_latency_hmma.py --mnkl 7168,8,512,1 + +# run pytest +# pytest cutlass_ir/compiler/test/python/examples/sm_80/test_sgemm.py diff --git a/problems/nvidia/nvfp4_dual_gemm/utils.py b/problems/nvidia/nvfp4_dual_gemm/utils.py new file mode 100644 index 0000000..e8a9082 --- /dev/null +++ b/problems/nvidia/nvfp4_dual_gemm/utils.py @@ -0,0 +1,176 @@ +import os +import random +import numpy as np +import torch + + +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_device(use_cuda: bool = True) -> torch.device: + """Get the appropriate device (GPU or CPU).""" + if use_cuda: + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + print("No compatible GPU found. Falling back to CPU.") + return torch.device("cpu") + + +# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py +@torch.no_grad() +def verbose_allclose( + received: torch.Tensor, + expected: torch.Tensor, + rtol=1e-05, + atol=1e-08, + max_print=5 +) -> list[str]: + """ + Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + rtol (float): Relative tolerance; relative to expected + atol (float): Absolute tolerance. + max_print (int): Maximum number of mismatched elements to print. + + Raises: + AssertionError: If the tensors are not all close within the given tolerance. + """ + # Check if the shapes of the tensors match + if received.shape != expected.shape: + return ["SIZE MISMATCH"] + + # Calculate the difference between the tensors + diff = torch.abs(received - expected) + + # Determine the tolerance + tolerance = atol + rtol * torch.abs(expected) + + # Find tolerance mismatched elements + tol_mismatched = diff > tolerance + + # Find nan mismatched elements + nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) + + # Find +inf mismatched elements + posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) + # Find -inf mismatched elements + neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) + + # Find all mismatched elements + mismatched = torch.logical_or( + torch.logical_or(tol_mismatched, nan_mismatched), + torch.logical_or(posinf_mismatched, neginf_mismatched), + ) + + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +@torch.no_grad() +def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5): + """ + Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + max_print (int): Maximum number of mismatched elements to print. + + Returns: + Empty string if tensors are equal, otherwise detailed error information + """ + mismatched = torch.not_equal(received, expected) + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]: + """ + Convenient "default" implementation for tasks' `check_implementation` function. + """ + expected = reference(data) + reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) + + if len(reasons) > 0: + return False, "mismatch found! custom implementation doesn't match reference: " + " ".join(reasons) + + return True, '' + + +def make_match_reference(reference: callable, **kwargs): + def wrapped(data, output): + return match_reference(data, output, reference=reference, **kwargs) + return wrapped + + +class DeterministicContext: + def __init__(self): + self.allow_tf32 = None + self.deterministic = None + self.cublas = None + + def __enter__(self): + self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '') + self.allow_tf32 = torch.backends.cudnn.allow_tf32 + self.deterministic = torch.backends.cudnn.deterministic + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + torch.use_deterministic_algorithms(True) + return self + + def __exit__(self, exc_type, exc_value, traceback): + torch.backends.cudnn.allow_tf32 = self.allow_tf32 + torch.backends.cudnn.deterministic = self.deterministic + torch.use_deterministic_algorithms(False) + os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas + +def clear_l2_cache(): + # import cupy as cp + # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) + # create a large dummy tensor + dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") + # write stuff to + dummy.fill_(42) + del dummy \ No newline at end of file From a9e20d4a82286d14d8017c5c3a3caf829eedc302 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Tue, 21 Oct 2025 00:45:21 -0700 Subject: [PATCH 15/29] simplify code --- problems/nvidia/nvfp4_dual_gemm/submission.py | 26 - problems/nvidia/nvfp4_gemm/reference.py | 2 + problems/nvidia/nvfp4_gemm/submission.py | 26 - problems/nvidia/nvfp4_gemv/eval.py | 500 ++++ problems/nvidia/nvfp4_gemv/log | 2332 +++++++++++++++++ problems/nvidia/nvfp4_gemv/reference.py | 46 +- problems/nvidia/nvfp4_gemv/submission.py | 87 +- problems/nvidia/nvfp4_gemv/task.py | 2 +- problems/nvidia/nvfp4_gemv/template.py | 8 +- problems/nvidia/nvfp4_gemv/test_python_1.sh | 87 + problems/nvidia/nvfp4_gemv/utils.py | 176 ++ 11 files changed, 3151 insertions(+), 141 deletions(-) create mode 100644 problems/nvidia/nvfp4_gemv/eval.py create mode 100644 problems/nvidia/nvfp4_gemv/log create mode 100644 problems/nvidia/nvfp4_gemv/test_python_1.sh create mode 100644 problems/nvidia/nvfp4_gemv/utils.py diff --git a/problems/nvidia/nvfp4_dual_gemm/submission.py b/problems/nvidia/nvfp4_dual_gemm/submission.py index c767142..e9f6a6d 100644 --- a/problems/nvidia/nvfp4_dual_gemm/submission.py +++ b/problems/nvidia/nvfp4_dual_gemm/submission.py @@ -41,32 +41,6 @@ def ceil_div(a, b): return (a + b - 1) // b -# Helper function to reorder the scale factor tensor to match the layout defined in -# https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout -@cute.jit -def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - sf_ref_ptr: cute.Pointer, - sf_mma_ptr: cute.Pointer, - mn: int, - sf_k: int, - l: int, - mma_shape: tuple, -): - mma_permute_order = (3, 4, 1, 5, 2, 0) - permuted_shape = tuple(mma_shape[i] for i in mma_permute_order) - cute_layout = cute.make_ordered_layout(permuted_shape, order=(2, 1, 4, 0, 3, 5)) - - sf_ref_tensor = cute.make_tensor( - sf_ref_ptr, cute.make_layout((mn, sf_k, l), stride=(sf_k, 1, mn * sf_k)) - ) - sf_mma_tensor = cute.make_tensor(sf_mma_ptr, cute_layout) - sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) - sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) - for i in cutlass.range(cute.size(sf_ref_tensor)): - mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) - sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] - - # GPU device kernel @cute.kernel def kernel( diff --git a/problems/nvidia/nvfp4_gemm/reference.py b/problems/nvidia/nvfp4_gemm/reference.py index ae56d1f..fecb7b6 100644 --- a/problems/nvidia/nvfp4_gemm/reference.py +++ b/problems/nvidia/nvfp4_gemm/reference.py @@ -76,6 +76,8 @@ def generate_input( b: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type scale_b: [n, k, l] - Input scale factors in torch.float8e4m3fn data type + scale_a_permuted: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type c: [m, n, l] - Output matrix in torch.float16 data type """ torch.manual_seed(seed) diff --git a/problems/nvidia/nvfp4_gemm/submission.py b/problems/nvidia/nvfp4_gemm/submission.py index 8e25fcc..0089c25 100644 --- a/problems/nvidia/nvfp4_gemm/submission.py +++ b/problems/nvidia/nvfp4_gemm/submission.py @@ -41,32 +41,6 @@ def ceil_div(a, b): return (a + b - 1) // b -# Helper function to reorder the scale factor tensor to match the layout defined in -# https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout -@cute.jit -def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - sf_ref_ptr: cute.Pointer, - sf_mma_ptr: cute.Pointer, - mn: int, - sf_k: int, - l: int, - mma_shape: tuple, -): - mma_permute_order = (3, 4, 1, 5, 2, 0) - permuted_shape = tuple(mma_shape[i] for i in mma_permute_order) - cute_layout = cute.make_ordered_layout(permuted_shape, order=(2, 1, 4, 0, 3, 5)) - - sf_ref_tensor = cute.make_tensor( - sf_ref_ptr, cute.make_layout((mn, sf_k, l), stride=(sf_k, 1, mn * sf_k)) - ) - sf_mma_tensor = cute.make_tensor(sf_mma_ptr, cute_layout) - sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) - sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) - for i in cutlass.range(cute.size(sf_ref_tensor)): - mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) - sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] - - # The CuTe reference implementation for NVFP4 block-scaled GEMM @cute.kernel def kernel( diff --git a/problems/nvidia/nvfp4_gemv/eval.py b/problems/nvidia/nvfp4_gemv/eval.py new file mode 100644 index 0000000..e8bb5b2 --- /dev/null +++ b/problems/nvidia/nvfp4_gemv/eval.py @@ -0,0 +1,500 @@ +import base64 +import dataclasses +import multiprocessing +import re +import time +import os +import sys +import math +from pathlib import Path +from typing import Any, Optional +import tempfile + +import torch.cuda +from cutlass.cute.nvgpu.common import OpError + +from utils import set_seed, clear_l2_cache + +try: + from task import TestSpec +except ImportError: + TestSpec = dict + +from reference import check_implementation, generate_input + + +class PopcornOutput: + def __init__(self, fd: int): + self.file = os.fdopen(fd, "w") + os.set_inheritable(fd, False) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def print(self, *args, **kwargs): + print(*args, **kwargs, file=self.file, flush=True) + + def log(self, key, value): + self.print(f"{key}: {value}") + + +@dataclasses.dataclass +class TestCase: + args: dict + spec: str + + +def _combine(a: int, b: int) -> int: + # combine two integers into one: + # we need this to generate a secret seed based on the test-level seed and + # the global secret seed. + # the test-level seeds are public knowledge, and typically relatively small numbers, + # so we need to make sure they don't provide any useful info for the full seed. + # This Cantor construction ensures that if the secret seed is a large number, + # then so is the overall seed. + return int(a + (a + b) * (a + b + 1) // 2) + + +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: + try: + content = Path(file_name).read_text() + except Exception as E: + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) + exit(113) + + tests = [] + lines = content.splitlines() + match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" + for line in lines: + parts = line.split(";") + case = {} + for part in parts: + matched = re.match(match, part) + if not re.fullmatch(match, part): + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) + exit(113) + key = matched[1] + val = matched[2] + try: + val = int(val) + except ValueError: + pass + + case[key] = val + tests.append(TestCase(spec=line, args=case)) + + if seed is not None: + for test in tests: + if "seed" in test.args: + test.args["seed"] = _combine(test.args["seed"], seed) + + return tests + + +@dataclasses.dataclass +class Stats: + runs: int + mean: float + std: float + err: float + best: float + worst: float + + +def calculate_stats(durations: list[int]): + """ + Calculate statistical data from a list of durations. + + @param durations: A list of durations in nanoseconds. + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. + """ + runs = len(durations) + total = sum(durations) + best = min(durations) + worst = max(durations) + + avg = total / runs + variance = sum(map(lambda x: (x - avg) ** 2, durations)) + std = math.sqrt(variance / (runs - 1)) + err = std / math.sqrt(runs) + + return Stats( + runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) + ) + + +def _clone_data(data): + """ + Recursively goes through data and clones all tensors. + """ + if isinstance(data, tuple): + return tuple(_clone_data(x) for x in data) + elif isinstance(data, list): + return [_clone_data(x) for x in data] + elif isinstance(data, dict): + return {k: _clone_data(v) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + return data.clone() + else: + return data + + +def _run_single_test(test: TestCase): + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + + data = generate_input(**test.args) + torch.cuda.synchronize() + try: + submission_output = custom_kernel(_clone_data(data)) + + except OpError as E: + print(f"Encountered {E}", file=sys.stderr) + return False, str(E) + torch.cuda.synchronize() + return check_implementation(data, submission_output) + + +def run_single_test(pool: multiprocessing.Pool, test: TestCase): + """ + Runs a single test in another process. + """ + return pool.apply(_run_single_test, (test,)) + + +def run_testing( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes the actual test case code and checks for correctness. + + @param logger: A PopcornOutput object used for logging test results. + @param tests: A list of TestCase objects representing the test cases to be executed. + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. + """ + # Step 1: Compile kernel once before running tests + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Run all tests with compiled kernel + passed = True + logger.log("test-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"test.{idx}.spec", test.spec) + good, message = run_single_test(pool, test) + if not good: + logger.log(f"test.{idx}.status", "fail") + logger.log(f"test.{idx}.error", message) + passed = False + else: + logger.log(f"test.{idx}.status", "pass") + if message: + logger.log(f"test.{idx}.message", message) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def _compile_kernel_once(): + """ + Compile the kernel once before any benchmarking. + This ensures compilation time is not included in benchmark results. + """ + from submission import compile_kernel + + try: + # Trigger compilation (will be cached) + compile_kernel() + torch.cuda.synchronize() + return True, None + except OpError as E: + return False, f"Compilation failed: {E}" + except Exception as E: + return False, f"Compilation failed: {E}" + + +def _run_single_benchmark( + test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float +) -> Stats | Any: + """ + Runs one benchmark. Do not call directly. + """ + from submission import custom_kernel, compile_kernel + + durations = [] + # generate input data once + data = generate_input(**test.args) + check_copy = _clone_data(data) + + # Ensure kernel is compiled before any timing (compilation is cached) + try: + compile_kernel() + torch.cuda.synchronize() + except OpError as E: + return f"Compilation failed: {E}" + except Exception as E: + return f"Compilation failed: {E}" + + # first, one obligatory correctness check + try: + output = custom_kernel(_clone_data(data)) + except OpError as E: + return f"Encountered {E}" + good, message = check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs without further correctness testing + # there is an upper bound of 200 runs, and a lower bound of 3 runs; + # otherwise, we repeat until we either measure at least 10 full seconds, + # or the relative error of the mean is below 1%. + + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args) + check_copy = _clone_data(data) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + clear_l2_cache() + + start_event.record() + output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + return message + + del output + durations.append(duration) + + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + if ( + stats.err / stats.mean < 0.001 + or stats.mean * stats.runs > max_time_ns + or total_bm_duration > 120e9 + ): + break + + return calculate_stats(durations) + + +def run_single_benchmark( + pool: multiprocessing.Pool, + test: TestCase, + recheck: bool, + max_repeats: int, + max_time_ns: float, +): + """ + For a particular test case, check correctness (if applicable) and grab runtime results. + + @param pool: Process on which the benchmark will be launched. + @param test: TestCase object. + @param recheck: Flag for whether to explicitly check functional correctness. + @param max_repeats: Number of trials to repeat. + @param max_time_ns: Timeout time in nanoseconds. + @return: A Stats object for this particular benchmark case or an error if the test fails. + """ + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + + +def run_benchmarking( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes benchmarking code for a CUDA Kernel and logs runtimes. + + @param logger: A PopcornOutput object used for logging benchmark results. + @param pool: Process on which the benchmarks will be launched. + @param tests: A list of TestCase objects representing the test cases to be benchmarked. + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. + """ + # Step 1: Compile kernel once (outside of timing) + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Warm up with compiled kernel + run_single_benchmark(pool, tests[0], False, 200, 10e7) + + # Step 3: Run benchmarks (compilation time excluded) + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(pool, test, False, 200, 10e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log( + f"benchmark.{idx}.report", + base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), + ) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + seed = int(seed) if seed else None + set_seed(seed or 42) + + filename = None + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + + def build_test_string(tests: list[dict]): + as_str = "" + for test in tests: + kvs = [] + for k, v in test.items(): + kvs.append(f"{k}: {v}") + as_str += "; ".join(kvs) + "\n" + return as_str + + import yaml + + yaml_content = yaml.safe_load(open(sys.argv[2], "r")) + if mode == "test": + tests_str = build_test_string(yaml_content.get("tests", [])) + elif mode in ("benchmark", "leaderboard", "profile"): + tests_str = build_test_string(yaml_content.get("benchmarks", [])) + + tmp.write(tests_str.encode("utf-8")) + tmp.flush() + filename = tmp.name + + tests = get_test_cases(filename, seed) + + os.unlink(filename) + + with PopcornOutput(int(fd)) as logger: + import multiprocessing + + mp_context = multiprocessing.get_context("spawn") + with mp_context.Pool(1) as pool: + if mode == "test": + return run_testing(logger, pool, tests) + if mode == "benchmark": + return run_benchmarking(logger, pool, tests) + + if mode == "leaderboard": + # Step 1: Compile kernel once (outside of timing) + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Warmup with compiled kernel + run_single_benchmark(pool, tests[0], False, 200, 1e7) + + # Step 3: Run leaderboard benchmarks (compilation time excluded) + logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(pool, tests[i], True, 200, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log( + f"benchmark.{i}.{field.name}", + getattr(result, field.name), + ) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log( + f"benchmark.{i}.error", str(result) + ) # TODO: Make sure result implements __str__? + break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # TODO: Implement script mode + return 2 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/problems/nvidia/nvfp4_gemv/log b/problems/nvidia/nvfp4_gemv/log new file mode 100644 index 0000000..901787e --- /dev/null +++ b/problems/nvidia/nvfp4_gemv/log @@ -0,0 +1,2332 @@ +a_ptr : raw_ptr(0x00007f417f608200: f4E2M1FN, gmem, align<16>) +b_ptr : raw_ptr(0x00007f417f60c200: f4E2M1FN, gmem, align<16>) +sfa_ptr : raw_ptr(0x00007f417f610400: f8E4M3FN, gmem, align<32>) +sfb_ptr : raw_ptr(0x00007f417f610c00: f8E4M3FN, gmem, align<32>) +c_ptr : raw_ptr(0x00007f417f610200: f16, gmem, align<16>) +problem_size : (128,1,256,1) +res[0] = 3.000000 + +res[0] = 4.250000 + +res[0] = 4.500000 + +res[0] = 7.500000 + +a.shape : torch.Size([128, 128, 1]) +b.shape : torch.Size([128, 128, 1]) +sfa_cpu.shape : torch.Size([128, 16, 1]) +sfb_cpu.shape : torch.Size([128, 16, 1]) +sfa_reordered_cpu.shape : torch.Size([32, 4, 1, 4, 4, 1]) +sfb_reordered_cpu.shape : torch.Size([32, 4, 1, 4, 4, 1]) +c.shape : torch.Size([128, 1, 1]) +a_ptr : 139919286632960 +b_ptr : 139919286649344 +sfa_ptr : 139919286666240 +sfb_ptr : 139919286668288 +c_ptr : 139919286665728 +problem_size : (128, 1, 256, 1) +c_cute[0, 0, 0] = 7.5 +c_cute[1, 0, 0] = 10.25 +c_cute[2, 0, 0] = 12.25 +c_cute[3, 0, 0] = 15.25 +c_cute[4, 0, 0] = 13.25 +c_cute[5, 0, 0] = 17.25 +c_cute[6, 0, 0] = 15.25 +c_cute[7, 0, 0] = 15.5 +c_cute[8, 0, 0] = 18.0 +c_cute[9, 0, 0] = 12.25 +c_cute[10, 0, 0] = 14.25 +c_cute[11, 0, 0] = 11.5 +c_cute[12, 0, 0] = 15.0 +c_cute[13, 0, 0] = 14.0 +c_cute[14, 0, 0] = 17.0 +c_cute[15, 0, 0] = 13.25 +c_cute[16, 0, 0] = 19.25 +c_cute[17, 0, 0] = 12.75 +c_cute[18, 0, 0] = 12.5 +c_cute[19, 0, 0] = 17.0 +c_cute[20, 0, 0] = 14.25 +c_cute[21, 0, 0] = 16.25 +c_cute[22, 0, 0] = 18.5 +c_cute[23, 0, 0] = 12.0 +c_cute[24, 0, 0] = 17.25 +c_cute[25, 0, 0] = 13.0 +c_cute[26, 0, 0] = 18.25 +c_cute[27, 0, 0] = 17.0 +c_cute[28, 0, 0] = 10.25 +c_cute[29, 0, 0] = 12.75 +c_cute[30, 0, 0] = 17.5 +c_cute[31, 0, 0] = 19.0 +c_cute[32, 0, 0] = 13.5 +c_cute[33, 0, 0] = 14.75 +c_cute[34, 0, 0] = 14.75 +c_cute[35, 0, 0] = 17.25 +c_cute[36, 0, 0] = 15.25 +c_cute[37, 0, 0] = 18.0 +c_cute[38, 0, 0] = 19.25 +c_cute[39, 0, 0] = 13.75 +c_cute[40, 0, 0] = 15.75 +c_cute[41, 0, 0] = 13.5 +c_cute[42, 0, 0] = 12.0 +c_cute[43, 0, 0] = 16.75 +c_cute[44, 0, 0] = 18.75 +c_cute[45, 0, 0] = 12.75 +c_cute[46, 0, 0] = 10.5 +c_cute[47, 0, 0] = 9.25 +c_cute[48, 0, 0] = 12.5 +c_cute[49, 0, 0] = 14.5 +c_cute[50, 0, 0] = 13.25 +c_cute[51, 0, 0] = 17.25 +c_cute[52, 0, 0] = 14.75 +c_cute[53, 0, 0] = 13.75 +c_cute[54, 0, 0] = 13.5 +c_cute[55, 0, 0] = 12.5 +c_cute[56, 0, 0] = 9.75 +c_cute[57, 0, 0] = 11.0 +c_cute[58, 0, 0] = 16.75 +c_cute[59, 0, 0] = 14.0 +c_cute[60, 0, 0] = 16.0 +c_cute[61, 0, 0] = 13.0 +c_cute[62, 0, 0] = 14.75 +c_cute[63, 0, 0] = 14.75 +c_cute[64, 0, 0] = 13.25 +c_cute[65, 0, 0] = 18.0 +c_cute[66, 0, 0] = 15.0 +c_cute[67, 0, 0] = 13.75 +c_cute[68, 0, 0] = 12.5 +c_cute[69, 0, 0] = 15.75 +c_cute[70, 0, 0] = 10.5 +c_cute[71, 0, 0] = 16.25 +c_cute[72, 0, 0] = 16.25 +c_cute[73, 0, 0] = 14.5 +c_cute[74, 0, 0] = 16.0 +c_cute[75, 0, 0] = 17.0 +c_cute[76, 0, 0] = 17.25 +c_cute[77, 0, 0] = 10.5 +c_cute[78, 0, 0] = 12.5 +c_cute[79, 0, 0] = 13.0 +c_cute[80, 0, 0] = 12.5 +c_cute[81, 0, 0] = 11.0 +c_cute[82, 0, 0] = 15.0 +c_cute[83, 0, 0] = 13.75 +c_cute[84, 0, 0] = 12.25 +c_cute[85, 0, 0] = 13.25 +c_cute[86, 0, 0] = 13.75 +c_cute[87, 0, 0] = 17.0 +c_cute[88, 0, 0] = 14.0 +c_cute[89, 0, 0] = 13.0 +c_cute[90, 0, 0] = 14.25 +c_cute[91, 0, 0] = 15.75 +c_cute[92, 0, 0] = 9.5 +c_cute[93, 0, 0] = 13.0 +c_cute[94, 0, 0] = 11.0 +c_cute[95, 0, 0] = 13.75 +c_cute[96, 0, 0] = 15.25 +c_cute[97, 0, 0] = 12.75 +c_cute[98, 0, 0] = 14.5 +c_cute[99, 0, 0] = 13.0 +c_cute[100, 0, 0] = 11.75 +c_cute[101, 0, 0] = 12.0 +c_cute[102, 0, 0] = 18.0 +c_cute[103, 0, 0] = 15.5 +c_cute[104, 0, 0] = 12.75 +c_cute[105, 0, 0] = 12.5 +c_cute[106, 0, 0] = 14.75 +c_cute[107, 0, 0] = 16.75 +c_cute[108, 0, 0] = 13.5 +c_cute[109, 0, 0] = 15.25 +c_cute[110, 0, 0] = 13.5 +c_cute[111, 0, 0] = 11.75 +c_cute[112, 0, 0] = 17.25 +c_cute[113, 0, 0] = 16.25 +c_cute[114, 0, 0] = 11.25 +c_cute[115, 0, 0] = 10.75 +c_cute[116, 0, 0] = 13.5 +c_cute[117, 0, 0] = 11.5 +c_cute[118, 0, 0] = 15.5 +c_cute[119, 0, 0] = 17.25 +c_cute[120, 0, 0] = 14.75 +c_cute[121, 0, 0] = 17.0 +c_cute[122, 0, 0] = 15.5 +c_cute[123, 0, 0] = 14.75 +c_cute[124, 0, 0] = 18.0 +c_cute[125, 0, 0] = 13.0 +c_cute[126, 0, 0] = 15.5 +c_cute[127, 0, 0] = 14.75 +-------------------------------- +sfa_ref_cpu[0] = 1.0 +sfa_ref_cpu[1] = 2.0 +sfa_ref_cpu[2] = 2.0 +sfa_ref_cpu[3] = 2.0 +sfa_ref_cpu[4] = 1.0 +sfa_ref_cpu[5] = 1.0 +sfa_ref_cpu[6] = 1.0 +sfa_ref_cpu[7] = 1.0 +sfa_ref_cpu[8] = 1.0 +sfa_ref_cpu[9] = 1.0 +sfa_ref_cpu[10] = 2.0 +sfa_ref_cpu[11] = 1.0 +sfa_ref_cpu[12] = 1.0 +sfa_ref_cpu[13] = 1.0 +sfa_ref_cpu[14] = 1.0 +sfa_ref_cpu[15] = 1.0 +sfa_ref_cpu[16] = 1.0 +sfa_ref_cpu[17] = 1.0 +sfa_ref_cpu[18] = 1.0 +sfa_ref_cpu[19] = 2.0 +sfa_ref_cpu[20] = 1.0 +sfa_ref_cpu[21] = 2.0 +sfa_ref_cpu[22] = 2.0 +sfa_ref_cpu[23] = 1.0 +sfa_ref_cpu[24] = 1.0 +sfa_ref_cpu[25] = 1.0 +sfa_ref_cpu[26] = 1.0 +sfa_ref_cpu[27] = 1.0 +sfa_ref_cpu[28] = 1.0 +sfa_ref_cpu[29] = 2.0 +sfa_ref_cpu[30] = 1.0 +sfa_ref_cpu[31] = 2.0 +sfa_ref_cpu[32] = 2.0 +sfa_ref_cpu[33] = 2.0 +sfa_ref_cpu[34] = 1.0 +sfa_ref_cpu[35] = 1.0 +sfa_ref_cpu[36] = 1.0 +sfa_ref_cpu[37] = 1.0 +sfa_ref_cpu[38] = 2.0 +sfa_ref_cpu[39] = 2.0 +sfa_ref_cpu[40] = 2.0 +sfa_ref_cpu[41] = 2.0 +sfa_ref_cpu[42] = 2.0 +sfa_ref_cpu[43] = 1.0 +sfa_ref_cpu[44] = 1.0 +sfa_ref_cpu[45] = 1.0 +sfa_ref_cpu[46] = 2.0 +sfa_ref_cpu[47] = 2.0 +sfa_ref_cpu[48] = 2.0 +sfa_ref_cpu[49] = 2.0 +sfa_ref_cpu[50] = 2.0 +sfa_ref_cpu[51] = 2.0 +sfa_ref_cpu[52] = 2.0 +sfa_ref_cpu[53] = 1.0 +sfa_ref_cpu[54] = 2.0 +sfa_ref_cpu[55] = 1.0 +sfa_ref_cpu[56] = 1.0 +sfa_ref_cpu[57] = 1.0 +sfa_ref_cpu[58] = 1.0 +sfa_ref_cpu[59] = 1.0 +sfa_ref_cpu[60] = 2.0 +sfa_ref_cpu[61] = 2.0 +sfa_ref_cpu[62] = 2.0 +sfa_ref_cpu[63] = 2.0 +sfa_ref_cpu[64] = 1.0 +sfa_ref_cpu[65] = 1.0 +sfa_ref_cpu[66] = 1.0 +sfa_ref_cpu[67] = 2.0 +sfa_ref_cpu[68] = 2.0 +sfa_ref_cpu[69] = 2.0 +sfa_ref_cpu[70] = 2.0 +sfa_ref_cpu[71] = 1.0 +sfa_ref_cpu[72] = 2.0 +sfa_ref_cpu[73] = 1.0 +sfa_ref_cpu[74] = 2.0 +sfa_ref_cpu[75] = 2.0 +sfa_ref_cpu[76] = 1.0 +sfa_ref_cpu[77] = 2.0 +sfa_ref_cpu[78] = 1.0 +sfa_ref_cpu[79] = 2.0 +sfa_ref_cpu[80] = 2.0 +sfa_ref_cpu[81] = 2.0 +sfa_ref_cpu[82] = 2.0 +sfa_ref_cpu[83] = 1.0 +sfa_ref_cpu[84] = 2.0 +sfa_ref_cpu[85] = 1.0 +sfa_ref_cpu[86] = 1.0 +sfa_ref_cpu[87] = 1.0 +sfa_ref_cpu[88] = 2.0 +sfa_ref_cpu[89] = 2.0 +sfa_ref_cpu[90] = 2.0 +sfa_ref_cpu[91] = 1.0 +sfa_ref_cpu[92] = 2.0 +sfa_ref_cpu[93] = 2.0 +sfa_ref_cpu[94] = 1.0 +sfa_ref_cpu[95] = 2.0 +sfa_ref_cpu[96] = 2.0 +sfa_ref_cpu[97] = 2.0 +sfa_ref_cpu[98] = 2.0 +sfa_ref_cpu[99] = 2.0 +sfa_ref_cpu[100] = 2.0 +sfa_ref_cpu[101] = 1.0 +sfa_ref_cpu[102] = 1.0 +sfa_ref_cpu[103] = 1.0 +sfa_ref_cpu[104] = 1.0 +sfa_ref_cpu[105] = 2.0 +sfa_ref_cpu[106] = 1.0 +sfa_ref_cpu[107] = 2.0 +sfa_ref_cpu[108] = 1.0 +sfa_ref_cpu[109] = 2.0 +sfa_ref_cpu[110] = 2.0 +sfa_ref_cpu[111] = 2.0 +sfa_ref_cpu[112] = 1.0 +sfa_ref_cpu[113] = 1.0 +sfa_ref_cpu[114] = 1.0 +sfa_ref_cpu[115] = 2.0 +sfa_ref_cpu[116] = 1.0 +sfa_ref_cpu[117] = 1.0 +sfa_ref_cpu[118] = 2.0 +sfa_ref_cpu[119] = 1.0 +sfa_ref_cpu[120] = 2.0 +sfa_ref_cpu[121] = 1.0 +sfa_ref_cpu[122] = 1.0 +sfa_ref_cpu[123] = 2.0 +sfa_ref_cpu[124] = 2.0 +sfa_ref_cpu[125] = 2.0 +sfa_ref_cpu[126] = 2.0 +sfa_ref_cpu[127] = 1.0 +sfa_ref_cpu[128] = 2.0 +sfa_ref_cpu[129] = 2.0 +sfa_ref_cpu[130] = 2.0 +sfa_ref_cpu[131] = 1.0 +sfa_ref_cpu[132] = 2.0 +sfa_ref_cpu[133] = 2.0 +sfa_ref_cpu[134] = 1.0 +sfa_ref_cpu[135] = 2.0 +sfa_ref_cpu[136] = 1.0 +sfa_ref_cpu[137] = 1.0 +sfa_ref_cpu[138] = 2.0 +sfa_ref_cpu[139] = 1.0 +sfa_ref_cpu[140] = 1.0 +sfa_ref_cpu[141] = 2.0 +sfa_ref_cpu[142] = 1.0 +sfa_ref_cpu[143] = 1.0 +sfa_ref_cpu[144] = 1.0 +sfa_ref_cpu[145] = 2.0 +sfa_ref_cpu[146] = 1.0 +sfa_ref_cpu[147] = 2.0 +sfa_ref_cpu[148] = 2.0 +sfa_ref_cpu[149] = 2.0 +sfa_ref_cpu[150] = 2.0 +sfa_ref_cpu[151] = 1.0 +sfa_ref_cpu[152] = 2.0 +sfa_ref_cpu[153] = 2.0 +sfa_ref_cpu[154] = 2.0 +sfa_ref_cpu[155] = 1.0 +sfa_ref_cpu[156] = 1.0 +sfa_ref_cpu[157] = 1.0 +sfa_ref_cpu[158] = 1.0 +sfa_ref_cpu[159] = 1.0 +sfa_ref_cpu[160] = 1.0 +sfa_ref_cpu[161] = 1.0 +sfa_ref_cpu[162] = 1.0 +sfa_ref_cpu[163] = 2.0 +sfa_ref_cpu[164] = 2.0 +sfa_ref_cpu[165] = 2.0 +sfa_ref_cpu[166] = 1.0 +sfa_ref_cpu[167] = 1.0 +sfa_ref_cpu[168] = 2.0 +sfa_ref_cpu[169] = 1.0 +sfa_ref_cpu[170] = 1.0 +sfa_ref_cpu[171] = 2.0 +sfa_ref_cpu[172] = 1.0 +sfa_ref_cpu[173] = 1.0 +sfa_ref_cpu[174] = 2.0 +sfa_ref_cpu[175] = 2.0 +sfa_ref_cpu[176] = 1.0 +sfa_ref_cpu[177] = 1.0 +sfa_ref_cpu[178] = 1.0 +sfa_ref_cpu[179] = 2.0 +sfa_ref_cpu[180] = 1.0 +sfa_ref_cpu[181] = 1.0 +sfa_ref_cpu[182] = 1.0 +sfa_ref_cpu[183] = 1.0 +sfa_ref_cpu[184] = 2.0 +sfa_ref_cpu[185] = 1.0 +sfa_ref_cpu[186] = 1.0 +sfa_ref_cpu[187] = 1.0 +sfa_ref_cpu[188] = 2.0 +sfa_ref_cpu[189] = 1.0 +sfa_ref_cpu[190] = 2.0 +sfa_ref_cpu[191] = 2.0 +sfa_ref_cpu[192] = 2.0 +sfa_ref_cpu[193] = 1.0 +sfa_ref_cpu[194] = 2.0 +sfa_ref_cpu[195] = 2.0 +sfa_ref_cpu[196] = 1.0 +sfa_ref_cpu[197] = 2.0 +sfa_ref_cpu[198] = 2.0 +sfa_ref_cpu[199] = 2.0 +sfa_ref_cpu[200] = 1.0 +sfa_ref_cpu[201] = 2.0 +sfa_ref_cpu[202] = 2.0 +sfa_ref_cpu[203] = 2.0 +sfa_ref_cpu[204] = 1.0 +sfa_ref_cpu[205] = 1.0 +sfa_ref_cpu[206] = 2.0 +sfa_ref_cpu[207] = 2.0 +sfa_ref_cpu[208] = 2.0 +sfa_ref_cpu[209] = 2.0 +sfa_ref_cpu[210] = 1.0 +sfa_ref_cpu[211] = 2.0 +sfa_ref_cpu[212] = 2.0 +sfa_ref_cpu[213] = 1.0 +sfa_ref_cpu[214] = 2.0 +sfa_ref_cpu[215] = 1.0 +sfa_ref_cpu[216] = 2.0 +sfa_ref_cpu[217] = 2.0 +sfa_ref_cpu[218] = 1.0 +sfa_ref_cpu[219] = 1.0 +sfa_ref_cpu[220] = 1.0 +sfa_ref_cpu[221] = 2.0 +sfa_ref_cpu[222] = 1.0 +sfa_ref_cpu[223] = 1.0 +sfa_ref_cpu[224] = 2.0 +sfa_ref_cpu[225] = 1.0 +sfa_ref_cpu[226] = 1.0 +sfa_ref_cpu[227] = 2.0 +sfa_ref_cpu[228] = 1.0 +sfa_ref_cpu[229] = 1.0 +sfa_ref_cpu[230] = 1.0 +sfa_ref_cpu[231] = 1.0 +sfa_ref_cpu[232] = 2.0 +sfa_ref_cpu[233] = 2.0 +sfa_ref_cpu[234] = 2.0 +sfa_ref_cpu[235] = 2.0 +sfa_ref_cpu[236] = 2.0 +sfa_ref_cpu[237] = 2.0 +sfa_ref_cpu[238] = 2.0 +sfa_ref_cpu[239] = 2.0 +sfa_ref_cpu[240] = 2.0 +sfa_ref_cpu[241] = 1.0 +sfa_ref_cpu[242] = 2.0 +sfa_ref_cpu[243] = 1.0 +sfa_ref_cpu[244] = 2.0 +sfa_ref_cpu[245] = 1.0 +sfa_ref_cpu[246] = 1.0 +sfa_ref_cpu[247] = 2.0 +sfa_ref_cpu[248] = 1.0 +sfa_ref_cpu[249] = 1.0 +sfa_ref_cpu[250] = 1.0 +sfa_ref_cpu[251] = 2.0 +sfa_ref_cpu[252] = 1.0 +sfa_ref_cpu[253] = 2.0 +sfa_ref_cpu[254] = 2.0 +sfa_ref_cpu[255] = 1.0 +sfa_ref_cpu[256] = 1.0 +sfa_ref_cpu[257] = 2.0 +sfa_ref_cpu[258] = 2.0 +sfa_ref_cpu[259] = 1.0 +sfa_ref_cpu[260] = 2.0 +sfa_ref_cpu[261] = 1.0 +sfa_ref_cpu[262] = 2.0 +sfa_ref_cpu[263] = 2.0 +sfa_ref_cpu[264] = 1.0 +sfa_ref_cpu[265] = 2.0 +sfa_ref_cpu[266] = 2.0 +sfa_ref_cpu[267] = 2.0 +sfa_ref_cpu[268] = 1.0 +sfa_ref_cpu[269] = 2.0 +sfa_ref_cpu[270] = 2.0 +sfa_ref_cpu[271] = 2.0 +sfa_ref_cpu[272] = 1.0 +sfa_ref_cpu[273] = 1.0 +sfa_ref_cpu[274] = 1.0 +sfa_ref_cpu[275] = 1.0 +sfa_ref_cpu[276] = 1.0 +sfa_ref_cpu[277] = 2.0 +sfa_ref_cpu[278] = 2.0 +sfa_ref_cpu[279] = 2.0 +sfa_ref_cpu[280] = 1.0 +sfa_ref_cpu[281] = 1.0 +sfa_ref_cpu[282] = 2.0 +sfa_ref_cpu[283] = 1.0 +sfa_ref_cpu[284] = 1.0 +sfa_ref_cpu[285] = 2.0 +sfa_ref_cpu[286] = 2.0 +sfa_ref_cpu[287] = 1.0 +sfa_ref_cpu[288] = 1.0 +sfa_ref_cpu[289] = 2.0 +sfa_ref_cpu[290] = 1.0 +sfa_ref_cpu[291] = 2.0 +sfa_ref_cpu[292] = 1.0 +sfa_ref_cpu[293] = 1.0 +sfa_ref_cpu[294] = 1.0 +sfa_ref_cpu[295] = 1.0 +sfa_ref_cpu[296] = 2.0 +sfa_ref_cpu[297] = 2.0 +sfa_ref_cpu[298] = 2.0 +sfa_ref_cpu[299] = 1.0 +sfa_ref_cpu[300] = 1.0 +sfa_ref_cpu[301] = 2.0 +sfa_ref_cpu[302] = 1.0 +sfa_ref_cpu[303] = 1.0 +sfa_ref_cpu[304] = 2.0 +sfa_ref_cpu[305] = 1.0 +sfa_ref_cpu[306] = 2.0 +sfa_ref_cpu[307] = 1.0 +sfa_ref_cpu[308] = 1.0 +sfa_ref_cpu[309] = 2.0 +sfa_ref_cpu[310] = 1.0 +sfa_ref_cpu[311] = 2.0 +sfa_ref_cpu[312] = 1.0 +sfa_ref_cpu[313] = 2.0 +sfa_ref_cpu[314] = 2.0 +sfa_ref_cpu[315] = 2.0 +sfa_ref_cpu[316] = 2.0 +sfa_ref_cpu[317] = 1.0 +sfa_ref_cpu[318] = 1.0 +sfa_ref_cpu[319] = 2.0 +sfa_ref_cpu[320] = 1.0 +sfa_ref_cpu[321] = 2.0 +sfa_ref_cpu[322] = 1.0 +sfa_ref_cpu[323] = 1.0 +sfa_ref_cpu[324] = 2.0 +sfa_ref_cpu[325] = 1.0 +sfa_ref_cpu[326] = 2.0 +sfa_ref_cpu[327] = 1.0 +sfa_ref_cpu[328] = 2.0 +sfa_ref_cpu[329] = 2.0 +sfa_ref_cpu[330] = 1.0 +sfa_ref_cpu[331] = 2.0 +sfa_ref_cpu[332] = 2.0 +sfa_ref_cpu[333] = 1.0 +sfa_ref_cpu[334] = 1.0 +sfa_ref_cpu[335] = 1.0 +sfa_ref_cpu[336] = 1.0 +sfa_ref_cpu[337] = 1.0 +sfa_ref_cpu[338] = 1.0 +sfa_ref_cpu[339] = 2.0 +sfa_ref_cpu[340] = 2.0 +sfa_ref_cpu[341] = 2.0 +sfa_ref_cpu[342] = 2.0 +sfa_ref_cpu[343] = 2.0 +sfa_ref_cpu[344] = 1.0 +sfa_ref_cpu[345] = 1.0 +sfa_ref_cpu[346] = 1.0 +sfa_ref_cpu[347] = 2.0 +sfa_ref_cpu[348] = 2.0 +sfa_ref_cpu[349] = 2.0 +sfa_ref_cpu[350] = 1.0 +sfa_ref_cpu[351] = 1.0 +sfa_ref_cpu[352] = 2.0 +sfa_ref_cpu[353] = 2.0 +sfa_ref_cpu[354] = 2.0 +sfa_ref_cpu[355] = 2.0 +sfa_ref_cpu[356] = 2.0 +sfa_ref_cpu[357] = 1.0 +sfa_ref_cpu[358] = 1.0 +sfa_ref_cpu[359] = 2.0 +sfa_ref_cpu[360] = 1.0 +sfa_ref_cpu[361] = 1.0 +sfa_ref_cpu[362] = 2.0 +sfa_ref_cpu[363] = 1.0 +sfa_ref_cpu[364] = 2.0 +sfa_ref_cpu[365] = 2.0 +sfa_ref_cpu[366] = 2.0 +sfa_ref_cpu[367] = 1.0 +sfa_ref_cpu[368] = 1.0 +sfa_ref_cpu[369] = 1.0 +sfa_ref_cpu[370] = 2.0 +sfa_ref_cpu[371] = 1.0 +sfa_ref_cpu[372] = 1.0 +sfa_ref_cpu[373] = 2.0 +sfa_ref_cpu[374] = 2.0 +sfa_ref_cpu[375] = 2.0 +sfa_ref_cpu[376] = 1.0 +sfa_ref_cpu[377] = 2.0 +sfa_ref_cpu[378] = 1.0 +sfa_ref_cpu[379] = 1.0 +sfa_ref_cpu[380] = 2.0 +sfa_ref_cpu[381] = 1.0 +sfa_ref_cpu[382] = 2.0 +sfa_ref_cpu[383] = 2.0 +sfa_ref_cpu[384] = 2.0 +sfa_ref_cpu[385] = 2.0 +sfa_ref_cpu[386] = 2.0 +sfa_ref_cpu[387] = 2.0 +sfa_ref_cpu[388] = 2.0 +sfa_ref_cpu[389] = 1.0 +sfa_ref_cpu[390] = 1.0 +sfa_ref_cpu[391] = 1.0 +sfa_ref_cpu[392] = 1.0 +sfa_ref_cpu[393] = 1.0 +sfa_ref_cpu[394] = 1.0 +sfa_ref_cpu[395] = 2.0 +sfa_ref_cpu[396] = 2.0 +sfa_ref_cpu[397] = 2.0 +sfa_ref_cpu[398] = 2.0 +sfa_ref_cpu[399] = 1.0 +sfa_ref_cpu[400] = 2.0 +sfa_ref_cpu[401] = 1.0 +sfa_ref_cpu[402] = 1.0 +sfa_ref_cpu[403] = 1.0 +sfa_ref_cpu[404] = 1.0 +sfa_ref_cpu[405] = 1.0 +sfa_ref_cpu[406] = 2.0 +sfa_ref_cpu[407] = 1.0 +sfa_ref_cpu[408] = 1.0 +sfa_ref_cpu[409] = 1.0 +sfa_ref_cpu[410] = 1.0 +sfa_ref_cpu[411] = 1.0 +sfa_ref_cpu[412] = 1.0 +sfa_ref_cpu[413] = 2.0 +sfa_ref_cpu[414] = 2.0 +sfa_ref_cpu[415] = 1.0 +sfa_ref_cpu[416] = 2.0 +sfa_ref_cpu[417] = 2.0 +sfa_ref_cpu[418] = 2.0 +sfa_ref_cpu[419] = 1.0 +sfa_ref_cpu[420] = 1.0 +sfa_ref_cpu[421] = 1.0 +sfa_ref_cpu[422] = 2.0 +sfa_ref_cpu[423] = 2.0 +sfa_ref_cpu[424] = 2.0 +sfa_ref_cpu[425] = 2.0 +sfa_ref_cpu[426] = 2.0 +sfa_ref_cpu[427] = 1.0 +sfa_ref_cpu[428] = 2.0 +sfa_ref_cpu[429] = 1.0 +sfa_ref_cpu[430] = 1.0 +sfa_ref_cpu[431] = 1.0 +sfa_ref_cpu[432] = 1.0 +sfa_ref_cpu[433] = 2.0 +sfa_ref_cpu[434] = 1.0 +sfa_ref_cpu[435] = 2.0 +sfa_ref_cpu[436] = 2.0 +sfa_ref_cpu[437] = 1.0 +sfa_ref_cpu[438] = 1.0 +sfa_ref_cpu[439] = 1.0 +sfa_ref_cpu[440] = 2.0 +sfa_ref_cpu[441] = 2.0 +sfa_ref_cpu[442] = 2.0 +sfa_ref_cpu[443] = 2.0 +sfa_ref_cpu[444] = 2.0 +sfa_ref_cpu[445] = 2.0 +sfa_ref_cpu[446] = 2.0 +sfa_ref_cpu[447] = 2.0 +sfa_ref_cpu[448] = 1.0 +sfa_ref_cpu[449] = 2.0 +sfa_ref_cpu[450] = 1.0 +sfa_ref_cpu[451] = 1.0 +sfa_ref_cpu[452] = 2.0 +sfa_ref_cpu[453] = 1.0 +sfa_ref_cpu[454] = 2.0 +sfa_ref_cpu[455] = 1.0 +sfa_ref_cpu[456] = 1.0 +sfa_ref_cpu[457] = 2.0 +sfa_ref_cpu[458] = 1.0 +sfa_ref_cpu[459] = 2.0 +sfa_ref_cpu[460] = 1.0 +sfa_ref_cpu[461] = 2.0 +sfa_ref_cpu[462] = 2.0 +sfa_ref_cpu[463] = 1.0 +sfa_ref_cpu[464] = 1.0 +sfa_ref_cpu[465] = 1.0 +sfa_ref_cpu[466] = 1.0 +sfa_ref_cpu[467] = 1.0 +sfa_ref_cpu[468] = 1.0 +sfa_ref_cpu[469] = 2.0 +sfa_ref_cpu[470] = 1.0 +sfa_ref_cpu[471] = 2.0 +sfa_ref_cpu[472] = 2.0 +sfa_ref_cpu[473] = 1.0 +sfa_ref_cpu[474] = 2.0 +sfa_ref_cpu[475] = 2.0 +sfa_ref_cpu[476] = 1.0 +sfa_ref_cpu[477] = 2.0 +sfa_ref_cpu[478] = 2.0 +sfa_ref_cpu[479] = 1.0 +sfa_ref_cpu[480] = 1.0 +sfa_ref_cpu[481] = 1.0 +sfa_ref_cpu[482] = 1.0 +sfa_ref_cpu[483] = 2.0 +sfa_ref_cpu[484] = 2.0 +sfa_ref_cpu[485] = 1.0 +sfa_ref_cpu[486] = 1.0 +sfa_ref_cpu[487] = 1.0 +sfa_ref_cpu[488] = 2.0 +sfa_ref_cpu[489] = 2.0 +sfa_ref_cpu[490] = 1.0 +sfa_ref_cpu[491] = 2.0 +sfa_ref_cpu[492] = 2.0 +sfa_ref_cpu[493] = 1.0 +sfa_ref_cpu[494] = 1.0 +sfa_ref_cpu[495] = 1.0 +sfa_ref_cpu[496] = 2.0 +sfa_ref_cpu[497] = 2.0 +sfa_ref_cpu[498] = 1.0 +sfa_ref_cpu[499] = 2.0 +sfa_ref_cpu[500] = 1.0 +sfa_ref_cpu[501] = 1.0 +sfa_ref_cpu[502] = 2.0 +sfa_ref_cpu[503] = 1.0 +sfa_ref_cpu[504] = 2.0 +sfa_ref_cpu[505] = 2.0 +sfa_ref_cpu[506] = 2.0 +sfa_ref_cpu[507] = 1.0 +sfa_ref_cpu[508] = 2.0 +sfa_ref_cpu[509] = 2.0 +sfa_ref_cpu[510] = 2.0 +sfa_ref_cpu[511] = 1.0 +sfa_ref_cpu[512] = 2.0 +sfa_ref_cpu[513] = 2.0 +sfa_ref_cpu[514] = 2.0 +sfa_ref_cpu[515] = 1.0 +sfa_ref_cpu[516] = 1.0 +sfa_ref_cpu[517] = 1.0 +sfa_ref_cpu[518] = 2.0 +sfa_ref_cpu[519] = 2.0 +sfa_ref_cpu[520] = 1.0 +sfa_ref_cpu[521] = 1.0 +sfa_ref_cpu[522] = 2.0 +sfa_ref_cpu[523] = 1.0 +sfa_ref_cpu[524] = 2.0 +sfa_ref_cpu[525] = 2.0 +sfa_ref_cpu[526] = 2.0 +sfa_ref_cpu[527] = 2.0 +sfa_ref_cpu[528] = 1.0 +sfa_ref_cpu[529] = 2.0 +sfa_ref_cpu[530] = 2.0 +sfa_ref_cpu[531] = 2.0 +sfa_ref_cpu[532] = 1.0 +sfa_ref_cpu[533] = 2.0 +sfa_ref_cpu[534] = 2.0 +sfa_ref_cpu[535] = 2.0 +sfa_ref_cpu[536] = 1.0 +sfa_ref_cpu[537] = 2.0 +sfa_ref_cpu[538] = 2.0 +sfa_ref_cpu[539] = 1.0 +sfa_ref_cpu[540] = 1.0 +sfa_ref_cpu[541] = 1.0 +sfa_ref_cpu[542] = 1.0 +sfa_ref_cpu[543] = 2.0 +sfa_ref_cpu[544] = 1.0 +sfa_ref_cpu[545] = 2.0 +sfa_ref_cpu[546] = 2.0 +sfa_ref_cpu[547] = 2.0 +sfa_ref_cpu[548] = 1.0 +sfa_ref_cpu[549] = 1.0 +sfa_ref_cpu[550] = 1.0 +sfa_ref_cpu[551] = 1.0 +sfa_ref_cpu[552] = 1.0 +sfa_ref_cpu[553] = 2.0 +sfa_ref_cpu[554] = 2.0 +sfa_ref_cpu[555] = 2.0 +sfa_ref_cpu[556] = 2.0 +sfa_ref_cpu[557] = 1.0 +sfa_ref_cpu[558] = 1.0 +sfa_ref_cpu[559] = 2.0 +sfa_ref_cpu[560] = 1.0 +sfa_ref_cpu[561] = 2.0 +sfa_ref_cpu[562] = 1.0 +sfa_ref_cpu[563] = 1.0 +sfa_ref_cpu[564] = 2.0 +sfa_ref_cpu[565] = 1.0 +sfa_ref_cpu[566] = 2.0 +sfa_ref_cpu[567] = 2.0 +sfa_ref_cpu[568] = 1.0 +sfa_ref_cpu[569] = 2.0 +sfa_ref_cpu[570] = 1.0 +sfa_ref_cpu[571] = 2.0 +sfa_ref_cpu[572] = 1.0 +sfa_ref_cpu[573] = 1.0 +sfa_ref_cpu[574] = 1.0 +sfa_ref_cpu[575] = 2.0 +sfa_ref_cpu[576] = 2.0 +sfa_ref_cpu[577] = 1.0 +sfa_ref_cpu[578] = 1.0 +sfa_ref_cpu[579] = 2.0 +sfa_ref_cpu[580] = 1.0 +sfa_ref_cpu[581] = 1.0 +sfa_ref_cpu[582] = 2.0 +sfa_ref_cpu[583] = 1.0 +sfa_ref_cpu[584] = 2.0 +sfa_ref_cpu[585] = 2.0 +sfa_ref_cpu[586] = 1.0 +sfa_ref_cpu[587] = 1.0 +sfa_ref_cpu[588] = 2.0 +sfa_ref_cpu[589] = 2.0 +sfa_ref_cpu[590] = 2.0 +sfa_ref_cpu[591] = 1.0 +sfa_ref_cpu[592] = 1.0 +sfa_ref_cpu[593] = 1.0 +sfa_ref_cpu[594] = 1.0 +sfa_ref_cpu[595] = 1.0 +sfa_ref_cpu[596] = 2.0 +sfa_ref_cpu[597] = 2.0 +sfa_ref_cpu[598] = 2.0 +sfa_ref_cpu[599] = 2.0 +sfa_ref_cpu[600] = 2.0 +sfa_ref_cpu[601] = 2.0 +sfa_ref_cpu[602] = 2.0 +sfa_ref_cpu[603] = 2.0 +sfa_ref_cpu[604] = 1.0 +sfa_ref_cpu[605] = 2.0 +sfa_ref_cpu[606] = 2.0 +sfa_ref_cpu[607] = 1.0 +sfa_ref_cpu[608] = 2.0 +sfa_ref_cpu[609] = 2.0 +sfa_ref_cpu[610] = 2.0 +sfa_ref_cpu[611] = 1.0 +sfa_ref_cpu[612] = 1.0 +sfa_ref_cpu[613] = 1.0 +sfa_ref_cpu[614] = 2.0 +sfa_ref_cpu[615] = 1.0 +sfa_ref_cpu[616] = 2.0 +sfa_ref_cpu[617] = 2.0 +sfa_ref_cpu[618] = 2.0 +sfa_ref_cpu[619] = 2.0 +sfa_ref_cpu[620] = 1.0 +sfa_ref_cpu[621] = 2.0 +sfa_ref_cpu[622] = 2.0 +sfa_ref_cpu[623] = 2.0 +sfa_ref_cpu[624] = 2.0 +sfa_ref_cpu[625] = 1.0 +sfa_ref_cpu[626] = 1.0 +sfa_ref_cpu[627] = 1.0 +sfa_ref_cpu[628] = 2.0 +sfa_ref_cpu[629] = 1.0 +sfa_ref_cpu[630] = 1.0 +sfa_ref_cpu[631] = 1.0 +sfa_ref_cpu[632] = 1.0 +sfa_ref_cpu[633] = 2.0 +sfa_ref_cpu[634] = 1.0 +sfa_ref_cpu[635] = 2.0 +sfa_ref_cpu[636] = 2.0 +sfa_ref_cpu[637] = 2.0 +sfa_ref_cpu[638] = 1.0 +sfa_ref_cpu[639] = 2.0 +sfa_ref_cpu[640] = 2.0 +sfa_ref_cpu[641] = 1.0 +sfa_ref_cpu[642] = 2.0 +sfa_ref_cpu[643] = 1.0 +sfa_ref_cpu[644] = 1.0 +sfa_ref_cpu[645] = 1.0 +sfa_ref_cpu[646] = 2.0 +sfa_ref_cpu[647] = 2.0 +sfa_ref_cpu[648] = 1.0 +sfa_ref_cpu[649] = 1.0 +sfa_ref_cpu[650] = 2.0 +sfa_ref_cpu[651] = 1.0 +sfa_ref_cpu[652] = 1.0 +sfa_ref_cpu[653] = 1.0 +sfa_ref_cpu[654] = 1.0 +sfa_ref_cpu[655] = 1.0 +sfa_ref_cpu[656] = 1.0 +sfa_ref_cpu[657] = 1.0 +sfa_ref_cpu[658] = 1.0 +sfa_ref_cpu[659] = 1.0 +sfa_ref_cpu[660] = 2.0 +sfa_ref_cpu[661] = 1.0 +sfa_ref_cpu[662] = 2.0 +sfa_ref_cpu[663] = 1.0 +sfa_ref_cpu[664] = 2.0 +sfa_ref_cpu[665] = 1.0 +sfa_ref_cpu[666] = 1.0 +sfa_ref_cpu[667] = 1.0 +sfa_ref_cpu[668] = 1.0 +sfa_ref_cpu[669] = 2.0 +sfa_ref_cpu[670] = 1.0 +sfa_ref_cpu[671] = 1.0 +sfa_ref_cpu[672] = 1.0 +sfa_ref_cpu[673] = 1.0 +sfa_ref_cpu[674] = 2.0 +sfa_ref_cpu[675] = 1.0 +sfa_ref_cpu[676] = 1.0 +sfa_ref_cpu[677] = 1.0 +sfa_ref_cpu[678] = 1.0 +sfa_ref_cpu[679] = 2.0 +sfa_ref_cpu[680] = 2.0 +sfa_ref_cpu[681] = 1.0 +sfa_ref_cpu[682] = 1.0 +sfa_ref_cpu[683] = 1.0 +sfa_ref_cpu[684] = 2.0 +sfa_ref_cpu[685] = 2.0 +sfa_ref_cpu[686] = 2.0 +sfa_ref_cpu[687] = 1.0 +sfa_ref_cpu[688] = 1.0 +sfa_ref_cpu[689] = 2.0 +sfa_ref_cpu[690] = 2.0 +sfa_ref_cpu[691] = 1.0 +sfa_ref_cpu[692] = 2.0 +sfa_ref_cpu[693] = 1.0 +sfa_ref_cpu[694] = 2.0 +sfa_ref_cpu[695] = 1.0 +sfa_ref_cpu[696] = 1.0 +sfa_ref_cpu[697] = 2.0 +sfa_ref_cpu[698] = 1.0 +sfa_ref_cpu[699] = 1.0 +sfa_ref_cpu[700] = 2.0 +sfa_ref_cpu[701] = 2.0 +sfa_ref_cpu[702] = 1.0 +sfa_ref_cpu[703] = 2.0 +sfa_ref_cpu[704] = 2.0 +sfa_ref_cpu[705] = 2.0 +sfa_ref_cpu[706] = 1.0 +sfa_ref_cpu[707] = 2.0 +sfa_ref_cpu[708] = 2.0 +sfa_ref_cpu[709] = 2.0 +sfa_ref_cpu[710] = 1.0 +sfa_ref_cpu[711] = 2.0 +sfa_ref_cpu[712] = 2.0 +sfa_ref_cpu[713] = 2.0 +sfa_ref_cpu[714] = 2.0 +sfa_ref_cpu[715] = 2.0 +sfa_ref_cpu[716] = 2.0 +sfa_ref_cpu[717] = 2.0 +sfa_ref_cpu[718] = 1.0 +sfa_ref_cpu[719] = 2.0 +sfa_ref_cpu[720] = 1.0 +sfa_ref_cpu[721] = 1.0 +sfa_ref_cpu[722] = 1.0 +sfa_ref_cpu[723] = 1.0 +sfa_ref_cpu[724] = 1.0 +sfa_ref_cpu[725] = 1.0 +sfa_ref_cpu[726] = 1.0 +sfa_ref_cpu[727] = 2.0 +sfa_ref_cpu[728] = 2.0 +sfa_ref_cpu[729] = 2.0 +sfa_ref_cpu[730] = 2.0 +sfa_ref_cpu[731] = 1.0 +sfa_ref_cpu[732] = 1.0 +sfa_ref_cpu[733] = 2.0 +sfa_ref_cpu[734] = 2.0 +sfa_ref_cpu[735] = 1.0 +sfa_ref_cpu[736] = 1.0 +sfa_ref_cpu[737] = 2.0 +sfa_ref_cpu[738] = 2.0 +sfa_ref_cpu[739] = 2.0 +sfa_ref_cpu[740] = 2.0 +sfa_ref_cpu[741] = 2.0 +sfa_ref_cpu[742] = 1.0 +sfa_ref_cpu[743] = 2.0 +sfa_ref_cpu[744] = 2.0 +sfa_ref_cpu[745] = 2.0 +sfa_ref_cpu[746] = 1.0 +sfa_ref_cpu[747] = 1.0 +sfa_ref_cpu[748] = 1.0 +sfa_ref_cpu[749] = 1.0 +sfa_ref_cpu[750] = 2.0 +sfa_ref_cpu[751] = 2.0 +sfa_ref_cpu[752] = 1.0 +sfa_ref_cpu[753] = 1.0 +sfa_ref_cpu[754] = 1.0 +sfa_ref_cpu[755] = 1.0 +sfa_ref_cpu[756] = 1.0 +sfa_ref_cpu[757] = 1.0 +sfa_ref_cpu[758] = 1.0 +sfa_ref_cpu[759] = 1.0 +sfa_ref_cpu[760] = 2.0 +sfa_ref_cpu[761] = 2.0 +sfa_ref_cpu[762] = 2.0 +sfa_ref_cpu[763] = 1.0 +sfa_ref_cpu[764] = 1.0 +sfa_ref_cpu[765] = 2.0 +sfa_ref_cpu[766] = 1.0 +sfa_ref_cpu[767] = 1.0 +sfa_ref_cpu[768] = 2.0 +sfa_ref_cpu[769] = 1.0 +sfa_ref_cpu[770] = 2.0 +sfa_ref_cpu[771] = 2.0 +sfa_ref_cpu[772] = 2.0 +sfa_ref_cpu[773] = 2.0 +sfa_ref_cpu[774] = 2.0 +sfa_ref_cpu[775] = 2.0 +sfa_ref_cpu[776] = 2.0 +sfa_ref_cpu[777] = 2.0 +sfa_ref_cpu[778] = 2.0 +sfa_ref_cpu[779] = 1.0 +sfa_ref_cpu[780] = 1.0 +sfa_ref_cpu[781] = 2.0 +sfa_ref_cpu[782] = 2.0 +sfa_ref_cpu[783] = 1.0 +sfa_ref_cpu[784] = 1.0 +sfa_ref_cpu[785] = 2.0 +sfa_ref_cpu[786] = 2.0 +sfa_ref_cpu[787] = 1.0 +sfa_ref_cpu[788] = 2.0 +sfa_ref_cpu[789] = 2.0 +sfa_ref_cpu[790] = 2.0 +sfa_ref_cpu[791] = 1.0 +sfa_ref_cpu[792] = 1.0 +sfa_ref_cpu[793] = 2.0 +sfa_ref_cpu[794] = 1.0 +sfa_ref_cpu[795] = 1.0 +sfa_ref_cpu[796] = 2.0 +sfa_ref_cpu[797] = 1.0 +sfa_ref_cpu[798] = 2.0 +sfa_ref_cpu[799] = 1.0 +sfa_ref_cpu[800] = 1.0 +sfa_ref_cpu[801] = 2.0 +sfa_ref_cpu[802] = 2.0 +sfa_ref_cpu[803] = 2.0 +sfa_ref_cpu[804] = 1.0 +sfa_ref_cpu[805] = 2.0 +sfa_ref_cpu[806] = 1.0 +sfa_ref_cpu[807] = 2.0 +sfa_ref_cpu[808] = 1.0 +sfa_ref_cpu[809] = 2.0 +sfa_ref_cpu[810] = 2.0 +sfa_ref_cpu[811] = 1.0 +sfa_ref_cpu[812] = 2.0 +sfa_ref_cpu[813] = 2.0 +sfa_ref_cpu[814] = 1.0 +sfa_ref_cpu[815] = 1.0 +sfa_ref_cpu[816] = 1.0 +sfa_ref_cpu[817] = 1.0 +sfa_ref_cpu[818] = 1.0 +sfa_ref_cpu[819] = 2.0 +sfa_ref_cpu[820] = 2.0 +sfa_ref_cpu[821] = 2.0 +sfa_ref_cpu[822] = 1.0 +sfa_ref_cpu[823] = 2.0 +sfa_ref_cpu[824] = 1.0 +sfa_ref_cpu[825] = 1.0 +sfa_ref_cpu[826] = 2.0 +sfa_ref_cpu[827] = 1.0 +sfa_ref_cpu[828] = 2.0 +sfa_ref_cpu[829] = 2.0 +sfa_ref_cpu[830] = 1.0 +sfa_ref_cpu[831] = 2.0 +sfa_ref_cpu[832] = 2.0 +sfa_ref_cpu[833] = 1.0 +sfa_ref_cpu[834] = 2.0 +sfa_ref_cpu[835] = 1.0 +sfa_ref_cpu[836] = 2.0 +sfa_ref_cpu[837] = 2.0 +sfa_ref_cpu[838] = 1.0 +sfa_ref_cpu[839] = 1.0 +sfa_ref_cpu[840] = 1.0 +sfa_ref_cpu[841] = 2.0 +sfa_ref_cpu[842] = 2.0 +sfa_ref_cpu[843] = 1.0 +sfa_ref_cpu[844] = 1.0 +sfa_ref_cpu[845] = 1.0 +sfa_ref_cpu[846] = 2.0 +sfa_ref_cpu[847] = 2.0 +sfa_ref_cpu[848] = 1.0 +sfa_ref_cpu[849] = 1.0 +sfa_ref_cpu[850] = 1.0 +sfa_ref_cpu[851] = 1.0 +sfa_ref_cpu[852] = 1.0 +sfa_ref_cpu[853] = 2.0 +sfa_ref_cpu[854] = 2.0 +sfa_ref_cpu[855] = 1.0 +sfa_ref_cpu[856] = 2.0 +sfa_ref_cpu[857] = 1.0 +sfa_ref_cpu[858] = 1.0 +sfa_ref_cpu[859] = 1.0 +sfa_ref_cpu[860] = 2.0 +sfa_ref_cpu[861] = 1.0 +sfa_ref_cpu[862] = 1.0 +sfa_ref_cpu[863] = 1.0 +sfa_ref_cpu[864] = 2.0 +sfa_ref_cpu[865] = 2.0 +sfa_ref_cpu[866] = 1.0 +sfa_ref_cpu[867] = 2.0 +sfa_ref_cpu[868] = 2.0 +sfa_ref_cpu[869] = 1.0 +sfa_ref_cpu[870] = 1.0 +sfa_ref_cpu[871] = 1.0 +sfa_ref_cpu[872] = 2.0 +sfa_ref_cpu[873] = 2.0 +sfa_ref_cpu[874] = 2.0 +sfa_ref_cpu[875] = 1.0 +sfa_ref_cpu[876] = 1.0 +sfa_ref_cpu[877] = 2.0 +sfa_ref_cpu[878] = 1.0 +sfa_ref_cpu[879] = 2.0 +sfa_ref_cpu[880] = 1.0 +sfa_ref_cpu[881] = 1.0 +sfa_ref_cpu[882] = 1.0 +sfa_ref_cpu[883] = 2.0 +sfa_ref_cpu[884] = 2.0 +sfa_ref_cpu[885] = 2.0 +sfa_ref_cpu[886] = 2.0 +sfa_ref_cpu[887] = 1.0 +sfa_ref_cpu[888] = 2.0 +sfa_ref_cpu[889] = 2.0 +sfa_ref_cpu[890] = 1.0 +sfa_ref_cpu[891] = 1.0 +sfa_ref_cpu[892] = 1.0 +sfa_ref_cpu[893] = 1.0 +sfa_ref_cpu[894] = 2.0 +sfa_ref_cpu[895] = 2.0 +sfa_ref_cpu[896] = 1.0 +sfa_ref_cpu[897] = 2.0 +sfa_ref_cpu[898] = 1.0 +sfa_ref_cpu[899] = 1.0 +sfa_ref_cpu[900] = 2.0 +sfa_ref_cpu[901] = 2.0 +sfa_ref_cpu[902] = 1.0 +sfa_ref_cpu[903] = 1.0 +sfa_ref_cpu[904] = 1.0 +sfa_ref_cpu[905] = 1.0 +sfa_ref_cpu[906] = 1.0 +sfa_ref_cpu[907] = 2.0 +sfa_ref_cpu[908] = 1.0 +sfa_ref_cpu[909] = 1.0 +sfa_ref_cpu[910] = 2.0 +sfa_ref_cpu[911] = 1.0 +sfa_ref_cpu[912] = 1.0 +sfa_ref_cpu[913] = 2.0 +sfa_ref_cpu[914] = 2.0 +sfa_ref_cpu[915] = 2.0 +sfa_ref_cpu[916] = 2.0 +sfa_ref_cpu[917] = 2.0 +sfa_ref_cpu[918] = 1.0 +sfa_ref_cpu[919] = 1.0 +sfa_ref_cpu[920] = 2.0 +sfa_ref_cpu[921] = 1.0 +sfa_ref_cpu[922] = 1.0 +sfa_ref_cpu[923] = 1.0 +sfa_ref_cpu[924] = 2.0 +sfa_ref_cpu[925] = 2.0 +sfa_ref_cpu[926] = 2.0 +sfa_ref_cpu[927] = 1.0 +sfa_ref_cpu[928] = 2.0 +sfa_ref_cpu[929] = 2.0 +sfa_ref_cpu[930] = 2.0 +sfa_ref_cpu[931] = 1.0 +sfa_ref_cpu[932] = 1.0 +sfa_ref_cpu[933] = 1.0 +sfa_ref_cpu[934] = 1.0 +sfa_ref_cpu[935] = 1.0 +sfa_ref_cpu[936] = 2.0 +sfa_ref_cpu[937] = 2.0 +sfa_ref_cpu[938] = 1.0 +sfa_ref_cpu[939] = 2.0 +sfa_ref_cpu[940] = 1.0 +sfa_ref_cpu[941] = 1.0 +sfa_ref_cpu[942] = 2.0 +sfa_ref_cpu[943] = 2.0 +sfa_ref_cpu[944] = 1.0 +sfa_ref_cpu[945] = 2.0 +sfa_ref_cpu[946] = 1.0 +sfa_ref_cpu[947] = 1.0 +sfa_ref_cpu[948] = 1.0 +sfa_ref_cpu[949] = 2.0 +sfa_ref_cpu[950] = 2.0 +sfa_ref_cpu[951] = 2.0 +sfa_ref_cpu[952] = 1.0 +sfa_ref_cpu[953] = 2.0 +sfa_ref_cpu[954] = 1.0 +sfa_ref_cpu[955] = 1.0 +sfa_ref_cpu[956] = 2.0 +sfa_ref_cpu[957] = 2.0 +sfa_ref_cpu[958] = 1.0 +sfa_ref_cpu[959] = 2.0 +sfa_ref_cpu[960] = 1.0 +sfa_ref_cpu[961] = 1.0 +sfa_ref_cpu[962] = 2.0 +sfa_ref_cpu[963] = 2.0 +sfa_ref_cpu[964] = 2.0 +sfa_ref_cpu[965] = 1.0 +sfa_ref_cpu[966] = 2.0 +sfa_ref_cpu[967] = 1.0 +sfa_ref_cpu[968] = 1.0 +sfa_ref_cpu[969] = 1.0 +sfa_ref_cpu[970] = 2.0 +sfa_ref_cpu[971] = 2.0 +sfa_ref_cpu[972] = 1.0 +sfa_ref_cpu[973] = 1.0 +sfa_ref_cpu[974] = 1.0 +sfa_ref_cpu[975] = 2.0 +sfa_ref_cpu[976] = 2.0 +sfa_ref_cpu[977] = 2.0 +sfa_ref_cpu[978] = 1.0 +sfa_ref_cpu[979] = 1.0 +sfa_ref_cpu[980] = 1.0 +sfa_ref_cpu[981] = 2.0 +sfa_ref_cpu[982] = 1.0 +sfa_ref_cpu[983] = 2.0 +sfa_ref_cpu[984] = 2.0 +sfa_ref_cpu[985] = 1.0 +sfa_ref_cpu[986] = 2.0 +sfa_ref_cpu[987] = 2.0 +sfa_ref_cpu[988] = 1.0 +sfa_ref_cpu[989] = 1.0 +sfa_ref_cpu[990] = 1.0 +sfa_ref_cpu[991] = 2.0 +sfa_ref_cpu[992] = 1.0 +sfa_ref_cpu[993] = 1.0 +sfa_ref_cpu[994] = 2.0 +sfa_ref_cpu[995] = 1.0 +sfa_ref_cpu[996] = 2.0 +sfa_ref_cpu[997] = 2.0 +sfa_ref_cpu[998] = 1.0 +sfa_ref_cpu[999] = 2.0 +sfa_ref_cpu[1000] = 1.0 +sfa_ref_cpu[1001] = 1.0 +sfa_ref_cpu[1002] = 2.0 +sfa_ref_cpu[1003] = 2.0 +sfa_ref_cpu[1004] = 1.0 +sfa_ref_cpu[1005] = 2.0 +sfa_ref_cpu[1006] = 2.0 +sfa_ref_cpu[1007] = 2.0 +sfa_ref_cpu[1008] = 2.0 +sfa_ref_cpu[1009] = 2.0 +sfa_ref_cpu[1010] = 2.0 +sfa_ref_cpu[1011] = 1.0 +sfa_ref_cpu[1012] = 2.0 +sfa_ref_cpu[1013] = 2.0 +sfa_ref_cpu[1014] = 2.0 +sfa_ref_cpu[1015] = 1.0 +sfa_ref_cpu[1016] = 2.0 +sfa_ref_cpu[1017] = 2.0 +sfa_ref_cpu[1018] = 1.0 +sfa_ref_cpu[1019] = 2.0 +sfa_ref_cpu[1020] = 2.0 +sfa_ref_cpu[1021] = 1.0 +sfa_ref_cpu[1022] = 2.0 +sfa_ref_cpu[1023] = 1.0 +sfa_ref_cpu[1024] = 1.0 +sfa_ref_cpu[1025] = 1.0 +sfa_ref_cpu[1026] = 1.0 +sfa_ref_cpu[1027] = 2.0 +sfa_ref_cpu[1028] = 2.0 +sfa_ref_cpu[1029] = 2.0 +sfa_ref_cpu[1030] = 2.0 +sfa_ref_cpu[1031] = 1.0 +sfa_ref_cpu[1032] = 1.0 +sfa_ref_cpu[1033] = 1.0 +sfa_ref_cpu[1034] = 1.0 +sfa_ref_cpu[1035] = 2.0 +sfa_ref_cpu[1036] = 1.0 +sfa_ref_cpu[1037] = 2.0 +sfa_ref_cpu[1038] = 2.0 +sfa_ref_cpu[1039] = 2.0 +sfa_ref_cpu[1040] = 2.0 +sfa_ref_cpu[1041] = 1.0 +sfa_ref_cpu[1042] = 1.0 +sfa_ref_cpu[1043] = 1.0 +sfa_ref_cpu[1044] = 1.0 +sfa_ref_cpu[1045] = 1.0 +sfa_ref_cpu[1046] = 1.0 +sfa_ref_cpu[1047] = 2.0 +sfa_ref_cpu[1048] = 1.0 +sfa_ref_cpu[1049] = 1.0 +sfa_ref_cpu[1050] = 2.0 +sfa_ref_cpu[1051] = 2.0 +sfa_ref_cpu[1052] = 2.0 +sfa_ref_cpu[1053] = 2.0 +sfa_ref_cpu[1054] = 1.0 +sfa_ref_cpu[1055] = 2.0 +sfa_ref_cpu[1056] = 2.0 +sfa_ref_cpu[1057] = 2.0 +sfa_ref_cpu[1058] = 1.0 +sfa_ref_cpu[1059] = 2.0 +sfa_ref_cpu[1060] = 2.0 +sfa_ref_cpu[1061] = 2.0 +sfa_ref_cpu[1062] = 2.0 +sfa_ref_cpu[1063] = 2.0 +sfa_ref_cpu[1064] = 2.0 +sfa_ref_cpu[1065] = 2.0 +sfa_ref_cpu[1066] = 1.0 +sfa_ref_cpu[1067] = 1.0 +sfa_ref_cpu[1068] = 1.0 +sfa_ref_cpu[1069] = 1.0 +sfa_ref_cpu[1070] = 2.0 +sfa_ref_cpu[1071] = 1.0 +sfa_ref_cpu[1072] = 2.0 +sfa_ref_cpu[1073] = 1.0 +sfa_ref_cpu[1074] = 2.0 +sfa_ref_cpu[1075] = 1.0 +sfa_ref_cpu[1076] = 1.0 +sfa_ref_cpu[1077] = 1.0 +sfa_ref_cpu[1078] = 2.0 +sfa_ref_cpu[1079] = 1.0 +sfa_ref_cpu[1080] = 1.0 +sfa_ref_cpu[1081] = 2.0 +sfa_ref_cpu[1082] = 1.0 +sfa_ref_cpu[1083] = 2.0 +sfa_ref_cpu[1084] = 2.0 +sfa_ref_cpu[1085] = 1.0 +sfa_ref_cpu[1086] = 1.0 +sfa_ref_cpu[1087] = 2.0 +sfa_ref_cpu[1088] = 1.0 +sfa_ref_cpu[1089] = 2.0 +sfa_ref_cpu[1090] = 2.0 +sfa_ref_cpu[1091] = 2.0 +sfa_ref_cpu[1092] = 2.0 +sfa_ref_cpu[1093] = 2.0 +sfa_ref_cpu[1094] = 2.0 +sfa_ref_cpu[1095] = 2.0 +sfa_ref_cpu[1096] = 1.0 +sfa_ref_cpu[1097] = 1.0 +sfa_ref_cpu[1098] = 1.0 +sfa_ref_cpu[1099] = 1.0 +sfa_ref_cpu[1100] = 1.0 +sfa_ref_cpu[1101] = 2.0 +sfa_ref_cpu[1102] = 1.0 +sfa_ref_cpu[1103] = 2.0 +sfa_ref_cpu[1104] = 1.0 +sfa_ref_cpu[1105] = 2.0 +sfa_ref_cpu[1106] = 1.0 +sfa_ref_cpu[1107] = 2.0 +sfa_ref_cpu[1108] = 2.0 +sfa_ref_cpu[1109] = 2.0 +sfa_ref_cpu[1110] = 1.0 +sfa_ref_cpu[1111] = 1.0 +sfa_ref_cpu[1112] = 2.0 +sfa_ref_cpu[1113] = 1.0 +sfa_ref_cpu[1114] = 1.0 +sfa_ref_cpu[1115] = 1.0 +sfa_ref_cpu[1116] = 1.0 +sfa_ref_cpu[1117] = 2.0 +sfa_ref_cpu[1118] = 2.0 +sfa_ref_cpu[1119] = 1.0 +sfa_ref_cpu[1120] = 1.0 +sfa_ref_cpu[1121] = 2.0 +sfa_ref_cpu[1122] = 1.0 +sfa_ref_cpu[1123] = 1.0 +sfa_ref_cpu[1124] = 2.0 +sfa_ref_cpu[1125] = 2.0 +sfa_ref_cpu[1126] = 2.0 +sfa_ref_cpu[1127] = 2.0 +sfa_ref_cpu[1128] = 1.0 +sfa_ref_cpu[1129] = 2.0 +sfa_ref_cpu[1130] = 1.0 +sfa_ref_cpu[1131] = 1.0 +sfa_ref_cpu[1132] = 1.0 +sfa_ref_cpu[1133] = 2.0 +sfa_ref_cpu[1134] = 2.0 +sfa_ref_cpu[1135] = 1.0 +sfa_ref_cpu[1136] = 1.0 +sfa_ref_cpu[1137] = 2.0 +sfa_ref_cpu[1138] = 2.0 +sfa_ref_cpu[1139] = 1.0 +sfa_ref_cpu[1140] = 2.0 +sfa_ref_cpu[1141] = 1.0 +sfa_ref_cpu[1142] = 2.0 +sfa_ref_cpu[1143] = 1.0 +sfa_ref_cpu[1144] = 2.0 +sfa_ref_cpu[1145] = 2.0 +sfa_ref_cpu[1146] = 2.0 +sfa_ref_cpu[1147] = 1.0 +sfa_ref_cpu[1148] = 2.0 +sfa_ref_cpu[1149] = 1.0 +sfa_ref_cpu[1150] = 1.0 +sfa_ref_cpu[1151] = 2.0 +sfa_ref_cpu[1152] = 2.0 +sfa_ref_cpu[1153] = 1.0 +sfa_ref_cpu[1154] = 1.0 +sfa_ref_cpu[1155] = 1.0 +sfa_ref_cpu[1156] = 1.0 +sfa_ref_cpu[1157] = 1.0 +sfa_ref_cpu[1158] = 2.0 +sfa_ref_cpu[1159] = 1.0 +sfa_ref_cpu[1160] = 2.0 +sfa_ref_cpu[1161] = 2.0 +sfa_ref_cpu[1162] = 2.0 +sfa_ref_cpu[1163] = 1.0 +sfa_ref_cpu[1164] = 2.0 +sfa_ref_cpu[1165] = 1.0 +sfa_ref_cpu[1166] = 2.0 +sfa_ref_cpu[1167] = 2.0 +sfa_ref_cpu[1168] = 2.0 +sfa_ref_cpu[1169] = 2.0 +sfa_ref_cpu[1170] = 2.0 +sfa_ref_cpu[1171] = 1.0 +sfa_ref_cpu[1172] = 2.0 +sfa_ref_cpu[1173] = 1.0 +sfa_ref_cpu[1174] = 1.0 +sfa_ref_cpu[1175] = 1.0 +sfa_ref_cpu[1176] = 2.0 +sfa_ref_cpu[1177] = 1.0 +sfa_ref_cpu[1178] = 1.0 +sfa_ref_cpu[1179] = 2.0 +sfa_ref_cpu[1180] = 2.0 +sfa_ref_cpu[1181] = 2.0 +sfa_ref_cpu[1182] = 2.0 +sfa_ref_cpu[1183] = 2.0 +sfa_ref_cpu[1184] = 1.0 +sfa_ref_cpu[1185] = 1.0 +sfa_ref_cpu[1186] = 2.0 +sfa_ref_cpu[1187] = 1.0 +sfa_ref_cpu[1188] = 2.0 +sfa_ref_cpu[1189] = 2.0 +sfa_ref_cpu[1190] = 2.0 +sfa_ref_cpu[1191] = 1.0 +sfa_ref_cpu[1192] = 1.0 +sfa_ref_cpu[1193] = 1.0 +sfa_ref_cpu[1194] = 2.0 +sfa_ref_cpu[1195] = 2.0 +sfa_ref_cpu[1196] = 2.0 +sfa_ref_cpu[1197] = 1.0 +sfa_ref_cpu[1198] = 1.0 +sfa_ref_cpu[1199] = 2.0 +sfa_ref_cpu[1200] = 1.0 +sfa_ref_cpu[1201] = 1.0 +sfa_ref_cpu[1202] = 2.0 +sfa_ref_cpu[1203] = 2.0 +sfa_ref_cpu[1204] = 2.0 +sfa_ref_cpu[1205] = 2.0 +sfa_ref_cpu[1206] = 1.0 +sfa_ref_cpu[1207] = 1.0 +sfa_ref_cpu[1208] = 1.0 +sfa_ref_cpu[1209] = 2.0 +sfa_ref_cpu[1210] = 2.0 +sfa_ref_cpu[1211] = 2.0 +sfa_ref_cpu[1212] = 2.0 +sfa_ref_cpu[1213] = 2.0 +sfa_ref_cpu[1214] = 2.0 +sfa_ref_cpu[1215] = 2.0 +sfa_ref_cpu[1216] = 2.0 +sfa_ref_cpu[1217] = 2.0 +sfa_ref_cpu[1218] = 1.0 +sfa_ref_cpu[1219] = 2.0 +sfa_ref_cpu[1220] = 1.0 +sfa_ref_cpu[1221] = 2.0 +sfa_ref_cpu[1222] = 2.0 +sfa_ref_cpu[1223] = 2.0 +sfa_ref_cpu[1224] = 1.0 +sfa_ref_cpu[1225] = 2.0 +sfa_ref_cpu[1226] = 1.0 +sfa_ref_cpu[1227] = 2.0 +sfa_ref_cpu[1228] = 1.0 +sfa_ref_cpu[1229] = 1.0 +sfa_ref_cpu[1230] = 1.0 +sfa_ref_cpu[1231] = 1.0 +sfa_ref_cpu[1232] = 2.0 +sfa_ref_cpu[1233] = 2.0 +sfa_ref_cpu[1234] = 1.0 +sfa_ref_cpu[1235] = 1.0 +sfa_ref_cpu[1236] = 2.0 +sfa_ref_cpu[1237] = 2.0 +sfa_ref_cpu[1238] = 1.0 +sfa_ref_cpu[1239] = 2.0 +sfa_ref_cpu[1240] = 2.0 +sfa_ref_cpu[1241] = 2.0 +sfa_ref_cpu[1242] = 2.0 +sfa_ref_cpu[1243] = 1.0 +sfa_ref_cpu[1244] = 1.0 +sfa_ref_cpu[1245] = 1.0 +sfa_ref_cpu[1246] = 1.0 +sfa_ref_cpu[1247] = 2.0 +sfa_ref_cpu[1248] = 2.0 +sfa_ref_cpu[1249] = 1.0 +sfa_ref_cpu[1250] = 2.0 +sfa_ref_cpu[1251] = 1.0 +sfa_ref_cpu[1252] = 2.0 +sfa_ref_cpu[1253] = 1.0 +sfa_ref_cpu[1254] = 1.0 +sfa_ref_cpu[1255] = 1.0 +sfa_ref_cpu[1256] = 2.0 +sfa_ref_cpu[1257] = 1.0 +sfa_ref_cpu[1258] = 2.0 +sfa_ref_cpu[1259] = 2.0 +sfa_ref_cpu[1260] = 1.0 +sfa_ref_cpu[1261] = 1.0 +sfa_ref_cpu[1262] = 1.0 +sfa_ref_cpu[1263] = 2.0 +sfa_ref_cpu[1264] = 1.0 +sfa_ref_cpu[1265] = 2.0 +sfa_ref_cpu[1266] = 2.0 +sfa_ref_cpu[1267] = 1.0 +sfa_ref_cpu[1268] = 1.0 +sfa_ref_cpu[1269] = 2.0 +sfa_ref_cpu[1270] = 2.0 +sfa_ref_cpu[1271] = 1.0 +sfa_ref_cpu[1272] = 2.0 +sfa_ref_cpu[1273] = 1.0 +sfa_ref_cpu[1274] = 2.0 +sfa_ref_cpu[1275] = 2.0 +sfa_ref_cpu[1276] = 1.0 +sfa_ref_cpu[1277] = 2.0 +sfa_ref_cpu[1278] = 1.0 +sfa_ref_cpu[1279] = 1.0 +sfa_ref_cpu[1280] = 1.0 +sfa_ref_cpu[1281] = 2.0 +sfa_ref_cpu[1282] = 1.0 +sfa_ref_cpu[1283] = 1.0 +sfa_ref_cpu[1284] = 2.0 +sfa_ref_cpu[1285] = 1.0 +sfa_ref_cpu[1286] = 1.0 +sfa_ref_cpu[1287] = 1.0 +sfa_ref_cpu[1288] = 2.0 +sfa_ref_cpu[1289] = 2.0 +sfa_ref_cpu[1290] = 2.0 +sfa_ref_cpu[1291] = 1.0 +sfa_ref_cpu[1292] = 2.0 +sfa_ref_cpu[1293] = 1.0 +sfa_ref_cpu[1294] = 1.0 +sfa_ref_cpu[1295] = 2.0 +sfa_ref_cpu[1296] = 1.0 +sfa_ref_cpu[1297] = 2.0 +sfa_ref_cpu[1298] = 2.0 +sfa_ref_cpu[1299] = 1.0 +sfa_ref_cpu[1300] = 1.0 +sfa_ref_cpu[1301] = 1.0 +sfa_ref_cpu[1302] = 1.0 +sfa_ref_cpu[1303] = 1.0 +sfa_ref_cpu[1304] = 1.0 +sfa_ref_cpu[1305] = 2.0 +sfa_ref_cpu[1306] = 2.0 +sfa_ref_cpu[1307] = 1.0 +sfa_ref_cpu[1308] = 2.0 +sfa_ref_cpu[1309] = 1.0 +sfa_ref_cpu[1310] = 2.0 +sfa_ref_cpu[1311] = 1.0 +sfa_ref_cpu[1312] = 2.0 +sfa_ref_cpu[1313] = 1.0 +sfa_ref_cpu[1314] = 1.0 +sfa_ref_cpu[1315] = 1.0 +sfa_ref_cpu[1316] = 1.0 +sfa_ref_cpu[1317] = 2.0 +sfa_ref_cpu[1318] = 1.0 +sfa_ref_cpu[1319] = 1.0 +sfa_ref_cpu[1320] = 2.0 +sfa_ref_cpu[1321] = 1.0 +sfa_ref_cpu[1322] = 2.0 +sfa_ref_cpu[1323] = 1.0 +sfa_ref_cpu[1324] = 1.0 +sfa_ref_cpu[1325] = 2.0 +sfa_ref_cpu[1326] = 2.0 +sfa_ref_cpu[1327] = 1.0 +sfa_ref_cpu[1328] = 2.0 +sfa_ref_cpu[1329] = 1.0 +sfa_ref_cpu[1330] = 1.0 +sfa_ref_cpu[1331] = 1.0 +sfa_ref_cpu[1332] = 2.0 +sfa_ref_cpu[1333] = 2.0 +sfa_ref_cpu[1334] = 1.0 +sfa_ref_cpu[1335] = 2.0 +sfa_ref_cpu[1336] = 2.0 +sfa_ref_cpu[1337] = 2.0 +sfa_ref_cpu[1338] = 1.0 +sfa_ref_cpu[1339] = 1.0 +sfa_ref_cpu[1340] = 1.0 +sfa_ref_cpu[1341] = 2.0 +sfa_ref_cpu[1342] = 2.0 +sfa_ref_cpu[1343] = 2.0 +sfa_ref_cpu[1344] = 1.0 +sfa_ref_cpu[1345] = 2.0 +sfa_ref_cpu[1346] = 1.0 +sfa_ref_cpu[1347] = 2.0 +sfa_ref_cpu[1348] = 2.0 +sfa_ref_cpu[1349] = 2.0 +sfa_ref_cpu[1350] = 1.0 +sfa_ref_cpu[1351] = 2.0 +sfa_ref_cpu[1352] = 2.0 +sfa_ref_cpu[1353] = 1.0 +sfa_ref_cpu[1354] = 1.0 +sfa_ref_cpu[1355] = 1.0 +sfa_ref_cpu[1356] = 2.0 +sfa_ref_cpu[1357] = 1.0 +sfa_ref_cpu[1358] = 2.0 +sfa_ref_cpu[1359] = 1.0 +sfa_ref_cpu[1360] = 1.0 +sfa_ref_cpu[1361] = 1.0 +sfa_ref_cpu[1362] = 1.0 +sfa_ref_cpu[1363] = 2.0 +sfa_ref_cpu[1364] = 1.0 +sfa_ref_cpu[1365] = 2.0 +sfa_ref_cpu[1366] = 1.0 +sfa_ref_cpu[1367] = 2.0 +sfa_ref_cpu[1368] = 2.0 +sfa_ref_cpu[1369] = 2.0 +sfa_ref_cpu[1370] = 1.0 +sfa_ref_cpu[1371] = 1.0 +sfa_ref_cpu[1372] = 2.0 +sfa_ref_cpu[1373] = 1.0 +sfa_ref_cpu[1374] = 1.0 +sfa_ref_cpu[1375] = 2.0 +sfa_ref_cpu[1376] = 1.0 +sfa_ref_cpu[1377] = 1.0 +sfa_ref_cpu[1378] = 1.0 +sfa_ref_cpu[1379] = 1.0 +sfa_ref_cpu[1380] = 1.0 +sfa_ref_cpu[1381] = 1.0 +sfa_ref_cpu[1382] = 2.0 +sfa_ref_cpu[1383] = 1.0 +sfa_ref_cpu[1384] = 2.0 +sfa_ref_cpu[1385] = 2.0 +sfa_ref_cpu[1386] = 2.0 +sfa_ref_cpu[1387] = 1.0 +sfa_ref_cpu[1388] = 1.0 +sfa_ref_cpu[1389] = 2.0 +sfa_ref_cpu[1390] = 2.0 +sfa_ref_cpu[1391] = 1.0 +sfa_ref_cpu[1392] = 2.0 +sfa_ref_cpu[1393] = 1.0 +sfa_ref_cpu[1394] = 1.0 +sfa_ref_cpu[1395] = 2.0 +sfa_ref_cpu[1396] = 2.0 +sfa_ref_cpu[1397] = 2.0 +sfa_ref_cpu[1398] = 2.0 +sfa_ref_cpu[1399] = 2.0 +sfa_ref_cpu[1400] = 1.0 +sfa_ref_cpu[1401] = 1.0 +sfa_ref_cpu[1402] = 2.0 +sfa_ref_cpu[1403] = 1.0 +sfa_ref_cpu[1404] = 1.0 +sfa_ref_cpu[1405] = 1.0 +sfa_ref_cpu[1406] = 2.0 +sfa_ref_cpu[1407] = 1.0 +sfa_ref_cpu[1408] = 2.0 +sfa_ref_cpu[1409] = 2.0 +sfa_ref_cpu[1410] = 1.0 +sfa_ref_cpu[1411] = 1.0 +sfa_ref_cpu[1412] = 1.0 +sfa_ref_cpu[1413] = 2.0 +sfa_ref_cpu[1414] = 1.0 +sfa_ref_cpu[1415] = 2.0 +sfa_ref_cpu[1416] = 2.0 +sfa_ref_cpu[1417] = 1.0 +sfa_ref_cpu[1418] = 1.0 +sfa_ref_cpu[1419] = 1.0 +sfa_ref_cpu[1420] = 2.0 +sfa_ref_cpu[1421] = 1.0 +sfa_ref_cpu[1422] = 2.0 +sfa_ref_cpu[1423] = 1.0 +sfa_ref_cpu[1424] = 1.0 +sfa_ref_cpu[1425] = 2.0 +sfa_ref_cpu[1426] = 2.0 +sfa_ref_cpu[1427] = 1.0 +sfa_ref_cpu[1428] = 1.0 +sfa_ref_cpu[1429] = 1.0 +sfa_ref_cpu[1430] = 1.0 +sfa_ref_cpu[1431] = 1.0 +sfa_ref_cpu[1432] = 1.0 +sfa_ref_cpu[1433] = 1.0 +sfa_ref_cpu[1434] = 1.0 +sfa_ref_cpu[1435] = 2.0 +sfa_ref_cpu[1436] = 2.0 +sfa_ref_cpu[1437] = 1.0 +sfa_ref_cpu[1438] = 2.0 +sfa_ref_cpu[1439] = 1.0 +sfa_ref_cpu[1440] = 1.0 +sfa_ref_cpu[1441] = 2.0 +sfa_ref_cpu[1442] = 1.0 +sfa_ref_cpu[1443] = 2.0 +sfa_ref_cpu[1444] = 1.0 +sfa_ref_cpu[1445] = 1.0 +sfa_ref_cpu[1446] = 2.0 +sfa_ref_cpu[1447] = 1.0 +sfa_ref_cpu[1448] = 1.0 +sfa_ref_cpu[1449] = 1.0 +sfa_ref_cpu[1450] = 1.0 +sfa_ref_cpu[1451] = 1.0 +sfa_ref_cpu[1452] = 2.0 +sfa_ref_cpu[1453] = 2.0 +sfa_ref_cpu[1454] = 1.0 +sfa_ref_cpu[1455] = 2.0 +sfa_ref_cpu[1456] = 2.0 +sfa_ref_cpu[1457] = 1.0 +sfa_ref_cpu[1458] = 1.0 +sfa_ref_cpu[1459] = 2.0 +sfa_ref_cpu[1460] = 2.0 +sfa_ref_cpu[1461] = 1.0 +sfa_ref_cpu[1462] = 1.0 +sfa_ref_cpu[1463] = 1.0 +sfa_ref_cpu[1464] = 2.0 +sfa_ref_cpu[1465] = 1.0 +sfa_ref_cpu[1466] = 1.0 +sfa_ref_cpu[1467] = 2.0 +sfa_ref_cpu[1468] = 1.0 +sfa_ref_cpu[1469] = 2.0 +sfa_ref_cpu[1470] = 2.0 +sfa_ref_cpu[1471] = 2.0 +sfa_ref_cpu[1472] = 1.0 +sfa_ref_cpu[1473] = 1.0 +sfa_ref_cpu[1474] = 1.0 +sfa_ref_cpu[1475] = 1.0 +sfa_ref_cpu[1476] = 1.0 +sfa_ref_cpu[1477] = 1.0 +sfa_ref_cpu[1478] = 2.0 +sfa_ref_cpu[1479] = 2.0 +sfa_ref_cpu[1480] = 1.0 +sfa_ref_cpu[1481] = 1.0 +sfa_ref_cpu[1482] = 2.0 +sfa_ref_cpu[1483] = 1.0 +sfa_ref_cpu[1484] = 1.0 +sfa_ref_cpu[1485] = 1.0 +sfa_ref_cpu[1486] = 2.0 +sfa_ref_cpu[1487] = 2.0 +sfa_ref_cpu[1488] = 2.0 +sfa_ref_cpu[1489] = 2.0 +sfa_ref_cpu[1490] = 2.0 +sfa_ref_cpu[1491] = 1.0 +sfa_ref_cpu[1492] = 1.0 +sfa_ref_cpu[1493] = 2.0 +sfa_ref_cpu[1494] = 1.0 +sfa_ref_cpu[1495] = 2.0 +sfa_ref_cpu[1496] = 1.0 +sfa_ref_cpu[1497] = 2.0 +sfa_ref_cpu[1498] = 1.0 +sfa_ref_cpu[1499] = 1.0 +sfa_ref_cpu[1500] = 2.0 +sfa_ref_cpu[1501] = 2.0 +sfa_ref_cpu[1502] = 2.0 +sfa_ref_cpu[1503] = 2.0 +sfa_ref_cpu[1504] = 2.0 +sfa_ref_cpu[1505] = 2.0 +sfa_ref_cpu[1506] = 2.0 +sfa_ref_cpu[1507] = 2.0 +sfa_ref_cpu[1508] = 2.0 +sfa_ref_cpu[1509] = 2.0 +sfa_ref_cpu[1510] = 1.0 +sfa_ref_cpu[1511] = 2.0 +sfa_ref_cpu[1512] = 1.0 +sfa_ref_cpu[1513] = 2.0 +sfa_ref_cpu[1514] = 2.0 +sfa_ref_cpu[1515] = 1.0 +sfa_ref_cpu[1516] = 1.0 +sfa_ref_cpu[1517] = 2.0 +sfa_ref_cpu[1518] = 1.0 +sfa_ref_cpu[1519] = 1.0 +sfa_ref_cpu[1520] = 2.0 +sfa_ref_cpu[1521] = 2.0 +sfa_ref_cpu[1522] = 2.0 +sfa_ref_cpu[1523] = 2.0 +sfa_ref_cpu[1524] = 1.0 +sfa_ref_cpu[1525] = 2.0 +sfa_ref_cpu[1526] = 2.0 +sfa_ref_cpu[1527] = 1.0 +sfa_ref_cpu[1528] = 1.0 +sfa_ref_cpu[1529] = 2.0 +sfa_ref_cpu[1530] = 1.0 +sfa_ref_cpu[1531] = 2.0 +sfa_ref_cpu[1532] = 1.0 +sfa_ref_cpu[1533] = 2.0 +sfa_ref_cpu[1534] = 1.0 +sfa_ref_cpu[1535] = 1.0 +sfa_ref_cpu[1536] = 1.0 +sfa_ref_cpu[1537] = 2.0 +sfa_ref_cpu[1538] = 2.0 +sfa_ref_cpu[1539] = 1.0 +sfa_ref_cpu[1540] = 2.0 +sfa_ref_cpu[1541] = 1.0 +sfa_ref_cpu[1542] = 1.0 +sfa_ref_cpu[1543] = 2.0 +sfa_ref_cpu[1544] = 1.0 +sfa_ref_cpu[1545] = 2.0 +sfa_ref_cpu[1546] = 1.0 +sfa_ref_cpu[1547] = 2.0 +sfa_ref_cpu[1548] = 1.0 +sfa_ref_cpu[1549] = 2.0 +sfa_ref_cpu[1550] = 2.0 +sfa_ref_cpu[1551] = 1.0 +sfa_ref_cpu[1552] = 1.0 +sfa_ref_cpu[1553] = 1.0 +sfa_ref_cpu[1554] = 2.0 +sfa_ref_cpu[1555] = 1.0 +sfa_ref_cpu[1556] = 2.0 +sfa_ref_cpu[1557] = 2.0 +sfa_ref_cpu[1558] = 1.0 +sfa_ref_cpu[1559] = 2.0 +sfa_ref_cpu[1560] = 2.0 +sfa_ref_cpu[1561] = 2.0 +sfa_ref_cpu[1562] = 2.0 +sfa_ref_cpu[1563] = 1.0 +sfa_ref_cpu[1564] = 2.0 +sfa_ref_cpu[1565] = 1.0 +sfa_ref_cpu[1566] = 1.0 +sfa_ref_cpu[1567] = 2.0 +sfa_ref_cpu[1568] = 1.0 +sfa_ref_cpu[1569] = 1.0 +sfa_ref_cpu[1570] = 2.0 +sfa_ref_cpu[1571] = 1.0 +sfa_ref_cpu[1572] = 2.0 +sfa_ref_cpu[1573] = 2.0 +sfa_ref_cpu[1574] = 1.0 +sfa_ref_cpu[1575] = 1.0 +sfa_ref_cpu[1576] = 2.0 +sfa_ref_cpu[1577] = 1.0 +sfa_ref_cpu[1578] = 2.0 +sfa_ref_cpu[1579] = 1.0 +sfa_ref_cpu[1580] = 2.0 +sfa_ref_cpu[1581] = 2.0 +sfa_ref_cpu[1582] = 1.0 +sfa_ref_cpu[1583] = 2.0 +sfa_ref_cpu[1584] = 2.0 +sfa_ref_cpu[1585] = 1.0 +sfa_ref_cpu[1586] = 1.0 +sfa_ref_cpu[1587] = 1.0 +sfa_ref_cpu[1588] = 2.0 +sfa_ref_cpu[1589] = 2.0 +sfa_ref_cpu[1590] = 2.0 +sfa_ref_cpu[1591] = 2.0 +sfa_ref_cpu[1592] = 1.0 +sfa_ref_cpu[1593] = 1.0 +sfa_ref_cpu[1594] = 1.0 +sfa_ref_cpu[1595] = 2.0 +sfa_ref_cpu[1596] = 2.0 +sfa_ref_cpu[1597] = 2.0 +sfa_ref_cpu[1598] = 2.0 +sfa_ref_cpu[1599] = 1.0 +sfa_ref_cpu[1600] = 1.0 +sfa_ref_cpu[1601] = 1.0 +sfa_ref_cpu[1602] = 2.0 +sfa_ref_cpu[1603] = 2.0 +sfa_ref_cpu[1604] = 1.0 +sfa_ref_cpu[1605] = 1.0 +sfa_ref_cpu[1606] = 1.0 +sfa_ref_cpu[1607] = 2.0 +sfa_ref_cpu[1608] = 2.0 +sfa_ref_cpu[1609] = 1.0 +sfa_ref_cpu[1610] = 2.0 +sfa_ref_cpu[1611] = 1.0 +sfa_ref_cpu[1612] = 1.0 +sfa_ref_cpu[1613] = 1.0 +sfa_ref_cpu[1614] = 1.0 +sfa_ref_cpu[1615] = 2.0 +sfa_ref_cpu[1616] = 1.0 +sfa_ref_cpu[1617] = 1.0 +sfa_ref_cpu[1618] = 2.0 +sfa_ref_cpu[1619] = 1.0 +sfa_ref_cpu[1620] = 2.0 +sfa_ref_cpu[1621] = 2.0 +sfa_ref_cpu[1622] = 1.0 +sfa_ref_cpu[1623] = 2.0 +sfa_ref_cpu[1624] = 1.0 +sfa_ref_cpu[1625] = 1.0 +sfa_ref_cpu[1626] = 2.0 +sfa_ref_cpu[1627] = 1.0 +sfa_ref_cpu[1628] = 2.0 +sfa_ref_cpu[1629] = 1.0 +sfa_ref_cpu[1630] = 1.0 +sfa_ref_cpu[1631] = 2.0 +sfa_ref_cpu[1632] = 2.0 +sfa_ref_cpu[1633] = 2.0 +sfa_ref_cpu[1634] = 2.0 +sfa_ref_cpu[1635] = 1.0 +sfa_ref_cpu[1636] = 2.0 +sfa_ref_cpu[1637] = 1.0 +sfa_ref_cpu[1638] = 2.0 +sfa_ref_cpu[1639] = 1.0 +sfa_ref_cpu[1640] = 2.0 +sfa_ref_cpu[1641] = 1.0 +sfa_ref_cpu[1642] = 2.0 +sfa_ref_cpu[1643] = 2.0 +sfa_ref_cpu[1644] = 2.0 +sfa_ref_cpu[1645] = 2.0 +sfa_ref_cpu[1646] = 2.0 +sfa_ref_cpu[1647] = 2.0 +sfa_ref_cpu[1648] = 2.0 +sfa_ref_cpu[1649] = 1.0 +sfa_ref_cpu[1650] = 2.0 +sfa_ref_cpu[1651] = 2.0 +sfa_ref_cpu[1652] = 1.0 +sfa_ref_cpu[1653] = 2.0 +sfa_ref_cpu[1654] = 2.0 +sfa_ref_cpu[1655] = 2.0 +sfa_ref_cpu[1656] = 2.0 +sfa_ref_cpu[1657] = 2.0 +sfa_ref_cpu[1658] = 2.0 +sfa_ref_cpu[1659] = 2.0 +sfa_ref_cpu[1660] = 1.0 +sfa_ref_cpu[1661] = 1.0 +sfa_ref_cpu[1662] = 2.0 +sfa_ref_cpu[1663] = 1.0 +sfa_ref_cpu[1664] = 2.0 +sfa_ref_cpu[1665] = 2.0 +sfa_ref_cpu[1666] = 2.0 +sfa_ref_cpu[1667] = 1.0 +sfa_ref_cpu[1668] = 1.0 +sfa_ref_cpu[1669] = 2.0 +sfa_ref_cpu[1670] = 1.0 +sfa_ref_cpu[1671] = 2.0 +sfa_ref_cpu[1672] = 1.0 +sfa_ref_cpu[1673] = 1.0 +sfa_ref_cpu[1674] = 2.0 +sfa_ref_cpu[1675] = 2.0 +sfa_ref_cpu[1676] = 2.0 +sfa_ref_cpu[1677] = 1.0 +sfa_ref_cpu[1678] = 1.0 +sfa_ref_cpu[1679] = 1.0 +sfa_ref_cpu[1680] = 2.0 +sfa_ref_cpu[1681] = 2.0 +sfa_ref_cpu[1682] = 2.0 +sfa_ref_cpu[1683] = 1.0 +sfa_ref_cpu[1684] = 2.0 +sfa_ref_cpu[1685] = 2.0 +sfa_ref_cpu[1686] = 2.0 +sfa_ref_cpu[1687] = 1.0 +sfa_ref_cpu[1688] = 1.0 +sfa_ref_cpu[1689] = 2.0 +sfa_ref_cpu[1690] = 2.0 +sfa_ref_cpu[1691] = 1.0 +sfa_ref_cpu[1692] = 2.0 +sfa_ref_cpu[1693] = 2.0 +sfa_ref_cpu[1694] = 1.0 +sfa_ref_cpu[1695] = 1.0 +sfa_ref_cpu[1696] = 2.0 +sfa_ref_cpu[1697] = 2.0 +sfa_ref_cpu[1698] = 2.0 +sfa_ref_cpu[1699] = 1.0 +sfa_ref_cpu[1700] = 2.0 +sfa_ref_cpu[1701] = 2.0 +sfa_ref_cpu[1702] = 1.0 +sfa_ref_cpu[1703] = 2.0 +sfa_ref_cpu[1704] = 1.0 +sfa_ref_cpu[1705] = 1.0 +sfa_ref_cpu[1706] = 1.0 +sfa_ref_cpu[1707] = 1.0 +sfa_ref_cpu[1708] = 2.0 +sfa_ref_cpu[1709] = 2.0 +sfa_ref_cpu[1710] = 2.0 +sfa_ref_cpu[1711] = 1.0 +sfa_ref_cpu[1712] = 1.0 +sfa_ref_cpu[1713] = 2.0 +sfa_ref_cpu[1714] = 1.0 +sfa_ref_cpu[1715] = 2.0 +sfa_ref_cpu[1716] = 1.0 +sfa_ref_cpu[1717] = 1.0 +sfa_ref_cpu[1718] = 2.0 +sfa_ref_cpu[1719] = 2.0 +sfa_ref_cpu[1720] = 2.0 +sfa_ref_cpu[1721] = 2.0 +sfa_ref_cpu[1722] = 1.0 +sfa_ref_cpu[1723] = 1.0 +sfa_ref_cpu[1724] = 2.0 +sfa_ref_cpu[1725] = 2.0 +sfa_ref_cpu[1726] = 2.0 +sfa_ref_cpu[1727] = 2.0 +sfa_ref_cpu[1728] = 2.0 +sfa_ref_cpu[1729] = 2.0 +sfa_ref_cpu[1730] = 1.0 +sfa_ref_cpu[1731] = 2.0 +sfa_ref_cpu[1732] = 1.0 +sfa_ref_cpu[1733] = 1.0 +sfa_ref_cpu[1734] = 1.0 +sfa_ref_cpu[1735] = 2.0 +sfa_ref_cpu[1736] = 1.0 +sfa_ref_cpu[1737] = 1.0 +sfa_ref_cpu[1738] = 1.0 +sfa_ref_cpu[1739] = 1.0 +sfa_ref_cpu[1740] = 2.0 +sfa_ref_cpu[1741] = 2.0 +sfa_ref_cpu[1742] = 1.0 +sfa_ref_cpu[1743] = 2.0 +sfa_ref_cpu[1744] = 1.0 +sfa_ref_cpu[1745] = 2.0 +sfa_ref_cpu[1746] = 2.0 +sfa_ref_cpu[1747] = 1.0 +sfa_ref_cpu[1748] = 1.0 +sfa_ref_cpu[1749] = 2.0 +sfa_ref_cpu[1750] = 1.0 +sfa_ref_cpu[1751] = 2.0 +sfa_ref_cpu[1752] = 1.0 +sfa_ref_cpu[1753] = 1.0 +sfa_ref_cpu[1754] = 2.0 +sfa_ref_cpu[1755] = 1.0 +sfa_ref_cpu[1756] = 2.0 +sfa_ref_cpu[1757] = 2.0 +sfa_ref_cpu[1758] = 2.0 +sfa_ref_cpu[1759] = 1.0 +sfa_ref_cpu[1760] = 1.0 +sfa_ref_cpu[1761] = 2.0 +sfa_ref_cpu[1762] = 1.0 +sfa_ref_cpu[1763] = 1.0 +sfa_ref_cpu[1764] = 1.0 +sfa_ref_cpu[1765] = 2.0 +sfa_ref_cpu[1766] = 2.0 +sfa_ref_cpu[1767] = 2.0 +sfa_ref_cpu[1768] = 1.0 +sfa_ref_cpu[1769] = 2.0 +sfa_ref_cpu[1770] = 1.0 +sfa_ref_cpu[1771] = 1.0 +sfa_ref_cpu[1772] = 1.0 +sfa_ref_cpu[1773] = 1.0 +sfa_ref_cpu[1774] = 2.0 +sfa_ref_cpu[1775] = 1.0 +sfa_ref_cpu[1776] = 1.0 +sfa_ref_cpu[1777] = 1.0 +sfa_ref_cpu[1778] = 1.0 +sfa_ref_cpu[1779] = 1.0 +sfa_ref_cpu[1780] = 2.0 +sfa_ref_cpu[1781] = 2.0 +sfa_ref_cpu[1782] = 1.0 +sfa_ref_cpu[1783] = 2.0 +sfa_ref_cpu[1784] = 2.0 +sfa_ref_cpu[1785] = 1.0 +sfa_ref_cpu[1786] = 1.0 +sfa_ref_cpu[1787] = 1.0 +sfa_ref_cpu[1788] = 1.0 +sfa_ref_cpu[1789] = 1.0 +sfa_ref_cpu[1790] = 2.0 +sfa_ref_cpu[1791] = 2.0 +sfa_ref_cpu[1792] = 2.0 +sfa_ref_cpu[1793] = 1.0 +sfa_ref_cpu[1794] = 2.0 +sfa_ref_cpu[1795] = 1.0 +sfa_ref_cpu[1796] = 1.0 +sfa_ref_cpu[1797] = 1.0 +sfa_ref_cpu[1798] = 2.0 +sfa_ref_cpu[1799] = 2.0 +sfa_ref_cpu[1800] = 2.0 +sfa_ref_cpu[1801] = 1.0 +sfa_ref_cpu[1802] = 1.0 +sfa_ref_cpu[1803] = 2.0 +sfa_ref_cpu[1804] = 2.0 +sfa_ref_cpu[1805] = 2.0 +sfa_ref_cpu[1806] = 1.0 +sfa_ref_cpu[1807] = 1.0 +sfa_ref_cpu[1808] = 2.0 +sfa_ref_cpu[1809] = 1.0 +sfa_ref_cpu[1810] = 1.0 +sfa_ref_cpu[1811] = 2.0 +sfa_ref_cpu[1812] = 1.0 +sfa_ref_cpu[1813] = 2.0 +sfa_ref_cpu[1814] = 2.0 +sfa_ref_cpu[1815] = 1.0 +sfa_ref_cpu[1816] = 2.0 +sfa_ref_cpu[1817] = 2.0 +sfa_ref_cpu[1818] = 2.0 +sfa_ref_cpu[1819] = 1.0 +sfa_ref_cpu[1820] = 2.0 +sfa_ref_cpu[1821] = 2.0 +sfa_ref_cpu[1822] = 2.0 +sfa_ref_cpu[1823] = 1.0 +sfa_ref_cpu[1824] = 1.0 +sfa_ref_cpu[1825] = 1.0 +sfa_ref_cpu[1826] = 1.0 +sfa_ref_cpu[1827] = 2.0 +sfa_ref_cpu[1828] = 2.0 +sfa_ref_cpu[1829] = 1.0 +sfa_ref_cpu[1830] = 1.0 +sfa_ref_cpu[1831] = 1.0 +sfa_ref_cpu[1832] = 1.0 +sfa_ref_cpu[1833] = 1.0 +sfa_ref_cpu[1834] = 1.0 +sfa_ref_cpu[1835] = 1.0 +sfa_ref_cpu[1836] = 2.0 +sfa_ref_cpu[1837] = 2.0 +sfa_ref_cpu[1838] = 2.0 +sfa_ref_cpu[1839] = 1.0 +sfa_ref_cpu[1840] = 2.0 +sfa_ref_cpu[1841] = 2.0 +sfa_ref_cpu[1842] = 2.0 +sfa_ref_cpu[1843] = 1.0 +sfa_ref_cpu[1844] = 2.0 +sfa_ref_cpu[1845] = 1.0 +sfa_ref_cpu[1846] = 1.0 +sfa_ref_cpu[1847] = 1.0 +sfa_ref_cpu[1848] = 1.0 +sfa_ref_cpu[1849] = 1.0 +sfa_ref_cpu[1850] = 2.0 +sfa_ref_cpu[1851] = 1.0 +sfa_ref_cpu[1852] = 1.0 +sfa_ref_cpu[1853] = 2.0 +sfa_ref_cpu[1854] = 1.0 +sfa_ref_cpu[1855] = 1.0 +sfa_ref_cpu[1856] = 1.0 +sfa_ref_cpu[1857] = 1.0 +sfa_ref_cpu[1858] = 2.0 +sfa_ref_cpu[1859] = 2.0 +sfa_ref_cpu[1860] = 2.0 +sfa_ref_cpu[1861] = 2.0 +sfa_ref_cpu[1862] = 1.0 +sfa_ref_cpu[1863] = 2.0 +sfa_ref_cpu[1864] = 2.0 +sfa_ref_cpu[1865] = 2.0 +sfa_ref_cpu[1866] = 1.0 +sfa_ref_cpu[1867] = 2.0 +sfa_ref_cpu[1868] = 2.0 +sfa_ref_cpu[1869] = 1.0 +sfa_ref_cpu[1870] = 2.0 +sfa_ref_cpu[1871] = 2.0 +sfa_ref_cpu[1872] = 2.0 +sfa_ref_cpu[1873] = 2.0 +sfa_ref_cpu[1874] = 1.0 +sfa_ref_cpu[1875] = 2.0 +sfa_ref_cpu[1876] = 1.0 +sfa_ref_cpu[1877] = 1.0 +sfa_ref_cpu[1878] = 1.0 +sfa_ref_cpu[1879] = 2.0 +sfa_ref_cpu[1880] = 1.0 +sfa_ref_cpu[1881] = 2.0 +sfa_ref_cpu[1882] = 1.0 +sfa_ref_cpu[1883] = 2.0 +sfa_ref_cpu[1884] = 1.0 +sfa_ref_cpu[1885] = 2.0 +sfa_ref_cpu[1886] = 1.0 +sfa_ref_cpu[1887] = 2.0 +sfa_ref_cpu[1888] = 2.0 +sfa_ref_cpu[1889] = 2.0 +sfa_ref_cpu[1890] = 2.0 +sfa_ref_cpu[1891] = 2.0 +sfa_ref_cpu[1892] = 2.0 +sfa_ref_cpu[1893] = 2.0 +sfa_ref_cpu[1894] = 1.0 +sfa_ref_cpu[1895] = 1.0 +sfa_ref_cpu[1896] = 2.0 +sfa_ref_cpu[1897] = 1.0 +sfa_ref_cpu[1898] = 2.0 +sfa_ref_cpu[1899] = 2.0 +sfa_ref_cpu[1900] = 2.0 +sfa_ref_cpu[1901] = 2.0 +sfa_ref_cpu[1902] = 2.0 +sfa_ref_cpu[1903] = 1.0 +sfa_ref_cpu[1904] = 1.0 +sfa_ref_cpu[1905] = 2.0 +sfa_ref_cpu[1906] = 2.0 +sfa_ref_cpu[1907] = 1.0 +sfa_ref_cpu[1908] = 1.0 +sfa_ref_cpu[1909] = 2.0 +sfa_ref_cpu[1910] = 2.0 +sfa_ref_cpu[1911] = 1.0 +sfa_ref_cpu[1912] = 2.0 +sfa_ref_cpu[1913] = 2.0 +sfa_ref_cpu[1914] = 2.0 +sfa_ref_cpu[1915] = 1.0 +sfa_ref_cpu[1916] = 2.0 +sfa_ref_cpu[1917] = 2.0 +sfa_ref_cpu[1918] = 2.0 +sfa_ref_cpu[1919] = 2.0 +sfa_ref_cpu[1920] = 1.0 +sfa_ref_cpu[1921] = 1.0 +sfa_ref_cpu[1922] = 1.0 +sfa_ref_cpu[1923] = 1.0 +sfa_ref_cpu[1924] = 2.0 +sfa_ref_cpu[1925] = 1.0 +sfa_ref_cpu[1926] = 2.0 +sfa_ref_cpu[1927] = 1.0 +sfa_ref_cpu[1928] = 1.0 +sfa_ref_cpu[1929] = 2.0 +sfa_ref_cpu[1930] = 1.0 +sfa_ref_cpu[1931] = 2.0 +sfa_ref_cpu[1932] = 2.0 +sfa_ref_cpu[1933] = 2.0 +sfa_ref_cpu[1934] = 2.0 +sfa_ref_cpu[1935] = 1.0 +sfa_ref_cpu[1936] = 1.0 +sfa_ref_cpu[1937] = 2.0 +sfa_ref_cpu[1938] = 2.0 +sfa_ref_cpu[1939] = 1.0 +sfa_ref_cpu[1940] = 1.0 +sfa_ref_cpu[1941] = 2.0 +sfa_ref_cpu[1942] = 2.0 +sfa_ref_cpu[1943] = 2.0 +sfa_ref_cpu[1944] = 2.0 +sfa_ref_cpu[1945] = 2.0 +sfa_ref_cpu[1946] = 2.0 +sfa_ref_cpu[1947] = 1.0 +sfa_ref_cpu[1948] = 1.0 +sfa_ref_cpu[1949] = 2.0 +sfa_ref_cpu[1950] = 1.0 +sfa_ref_cpu[1951] = 2.0 +sfa_ref_cpu[1952] = 1.0 +sfa_ref_cpu[1953] = 1.0 +sfa_ref_cpu[1954] = 1.0 +sfa_ref_cpu[1955] = 2.0 +sfa_ref_cpu[1956] = 2.0 +sfa_ref_cpu[1957] = 1.0 +sfa_ref_cpu[1958] = 2.0 +sfa_ref_cpu[1959] = 1.0 +sfa_ref_cpu[1960] = 1.0 +sfa_ref_cpu[1961] = 2.0 +sfa_ref_cpu[1962] = 2.0 +sfa_ref_cpu[1963] = 2.0 +sfa_ref_cpu[1964] = 1.0 +sfa_ref_cpu[1965] = 2.0 +sfa_ref_cpu[1966] = 2.0 +sfa_ref_cpu[1967] = 1.0 +sfa_ref_cpu[1968] = 1.0 +sfa_ref_cpu[1969] = 2.0 +sfa_ref_cpu[1970] = 1.0 +sfa_ref_cpu[1971] = 1.0 +sfa_ref_cpu[1972] = 2.0 +sfa_ref_cpu[1973] = 1.0 +sfa_ref_cpu[1974] = 2.0 +sfa_ref_cpu[1975] = 1.0 +sfa_ref_cpu[1976] = 1.0 +sfa_ref_cpu[1977] = 2.0 +sfa_ref_cpu[1978] = 2.0 +sfa_ref_cpu[1979] = 1.0 +sfa_ref_cpu[1980] = 1.0 +sfa_ref_cpu[1981] = 2.0 +sfa_ref_cpu[1982] = 1.0 +sfa_ref_cpu[1983] = 2.0 +sfa_ref_cpu[1984] = 1.0 +sfa_ref_cpu[1985] = 2.0 +sfa_ref_cpu[1986] = 2.0 +sfa_ref_cpu[1987] = 1.0 +sfa_ref_cpu[1988] = 2.0 +sfa_ref_cpu[1989] = 1.0 +sfa_ref_cpu[1990] = 2.0 +sfa_ref_cpu[1991] = 2.0 +sfa_ref_cpu[1992] = 1.0 +sfa_ref_cpu[1993] = 2.0 +sfa_ref_cpu[1994] = 1.0 +sfa_ref_cpu[1995] = 2.0 +sfa_ref_cpu[1996] = 2.0 +sfa_ref_cpu[1997] = 1.0 +sfa_ref_cpu[1998] = 1.0 +sfa_ref_cpu[1999] = 2.0 +sfa_ref_cpu[2000] = 2.0 +sfa_ref_cpu[2001] = 2.0 +sfa_ref_cpu[2002] = 2.0 +sfa_ref_cpu[2003] = 2.0 +sfa_ref_cpu[2004] = 2.0 +sfa_ref_cpu[2005] = 1.0 +sfa_ref_cpu[2006] = 2.0 +sfa_ref_cpu[2007] = 1.0 +sfa_ref_cpu[2008] = 1.0 +sfa_ref_cpu[2009] = 1.0 +sfa_ref_cpu[2010] = 2.0 +sfa_ref_cpu[2011] = 1.0 +sfa_ref_cpu[2012] = 2.0 +sfa_ref_cpu[2013] = 1.0 +sfa_ref_cpu[2014] = 2.0 +sfa_ref_cpu[2015] = 1.0 +sfa_ref_cpu[2016] = 2.0 +sfa_ref_cpu[2017] = 2.0 +sfa_ref_cpu[2018] = 1.0 +sfa_ref_cpu[2019] = 2.0 +sfa_ref_cpu[2020] = 2.0 +sfa_ref_cpu[2021] = 2.0 +sfa_ref_cpu[2022] = 2.0 +sfa_ref_cpu[2023] = 1.0 +sfa_ref_cpu[2024] = 2.0 +sfa_ref_cpu[2025] = 2.0 +sfa_ref_cpu[2026] = 1.0 +sfa_ref_cpu[2027] = 1.0 +sfa_ref_cpu[2028] = 2.0 +sfa_ref_cpu[2029] = 1.0 +sfa_ref_cpu[2030] = 2.0 +sfa_ref_cpu[2031] = 2.0 +sfa_ref_cpu[2032] = 1.0 +sfa_ref_cpu[2033] = 1.0 +sfa_ref_cpu[2034] = 2.0 +sfa_ref_cpu[2035] = 1.0 +sfa_ref_cpu[2036] = 2.0 +sfa_ref_cpu[2037] = 2.0 +sfa_ref_cpu[2038] = 2.0 +sfa_ref_cpu[2039] = 2.0 +sfa_ref_cpu[2040] = 1.0 +sfa_ref_cpu[2041] = 2.0 +sfa_ref_cpu[2042] = 1.0 +sfa_ref_cpu[2043] = 2.0 +sfa_ref_cpu[2044] = 2.0 +sfa_ref_cpu[2045] = 1.0 +sfa_ref_cpu[2046] = 1.0 +sfa_ref_cpu[2047] = 2.0 +c_ref[0, 0, 0] = 7.5 +c_ref[1, 0, 0] = 10.25 +c_ref[2, 0, 0] = 12.25 +c_ref[3, 0, 0] = 15.25 +c_ref[4, 0, 0] = 13.25 +c_ref[5, 0, 0] = 17.25 +c_ref[6, 0, 0] = 15.25 +c_ref[7, 0, 0] = 15.5 +c_ref[8, 0, 0] = 18.0 +c_ref[9, 0, 0] = 12.25 +c_ref[10, 0, 0] = 14.25 +c_ref[11, 0, 0] = 11.5 +c_ref[12, 0, 0] = 15.0 +c_ref[13, 0, 0] = 14.0 +c_ref[14, 0, 0] = 17.0 +c_ref[15, 0, 0] = 13.25 +c_ref[16, 0, 0] = 19.25 +c_ref[17, 0, 0] = 12.75 +c_ref[18, 0, 0] = 12.5 +c_ref[19, 0, 0] = 17.0 +c_ref[20, 0, 0] = 14.25 +c_ref[21, 0, 0] = 16.25 +c_ref[22, 0, 0] = 18.5 +c_ref[23, 0, 0] = 12.0 +c_ref[24, 0, 0] = 17.25 +c_ref[25, 0, 0] = 13.0 +c_ref[26, 0, 0] = 18.25 +c_ref[27, 0, 0] = 17.0 +c_ref[28, 0, 0] = 10.25 +c_ref[29, 0, 0] = 12.75 +c_ref[30, 0, 0] = 17.5 +c_ref[31, 0, 0] = 19.0 +c_ref[32, 0, 0] = 13.5 +c_ref[33, 0, 0] = 14.75 +c_ref[34, 0, 0] = 14.75 +c_ref[35, 0, 0] = 17.25 +c_ref[36, 0, 0] = 15.25 +c_ref[37, 0, 0] = 18.0 +c_ref[38, 0, 0] = 19.25 +c_ref[39, 0, 0] = 13.75 +c_ref[40, 0, 0] = 15.75 +c_ref[41, 0, 0] = 13.5 +c_ref[42, 0, 0] = 12.0 +c_ref[43, 0, 0] = 16.75 +c_ref[44, 0, 0] = 18.75 +c_ref[45, 0, 0] = 12.75 +c_ref[46, 0, 0] = 10.5 +c_ref[47, 0, 0] = 9.25 +c_ref[48, 0, 0] = 12.5 +c_ref[49, 0, 0] = 14.5 +c_ref[50, 0, 0] = 13.25 +c_ref[51, 0, 0] = 17.25 +c_ref[52, 0, 0] = 14.75 +c_ref[53, 0, 0] = 13.75 +c_ref[54, 0, 0] = 13.5 +c_ref[55, 0, 0] = 12.5 +c_ref[56, 0, 0] = 9.75 +c_ref[57, 0, 0] = 11.0 +c_ref[58, 0, 0] = 16.75 +c_ref[59, 0, 0] = 14.0 +c_ref[60, 0, 0] = 16.0 +c_ref[61, 0, 0] = 13.0 +c_ref[62, 0, 0] = 14.75 +c_ref[63, 0, 0] = 14.75 +c_ref[64, 0, 0] = 13.25 +c_ref[65, 0, 0] = 18.0 +c_ref[66, 0, 0] = 15.0 +c_ref[67, 0, 0] = 13.75 +c_ref[68, 0, 0] = 12.5 +c_ref[69, 0, 0] = 15.75 +c_ref[70, 0, 0] = 10.5 +c_ref[71, 0, 0] = 16.25 +c_ref[72, 0, 0] = 16.25 +c_ref[73, 0, 0] = 14.5 +c_ref[74, 0, 0] = 16.0 +c_ref[75, 0, 0] = 17.0 +c_ref[76, 0, 0] = 17.25 +c_ref[77, 0, 0] = 10.5 +c_ref[78, 0, 0] = 12.5 +c_ref[79, 0, 0] = 13.0 +c_ref[80, 0, 0] = 12.5 +c_ref[81, 0, 0] = 11.0 +c_ref[82, 0, 0] = 15.0 +c_ref[83, 0, 0] = 13.75 +c_ref[84, 0, 0] = 12.25 +c_ref[85, 0, 0] = 13.25 +c_ref[86, 0, 0] = 13.75 +c_ref[87, 0, 0] = 17.0 +c_ref[88, 0, 0] = 14.0 +c_ref[89, 0, 0] = 13.0 +c_ref[90, 0, 0] = 14.25 +c_ref[91, 0, 0] = 15.75 +c_ref[92, 0, 0] = 9.5 +c_ref[93, 0, 0] = 13.0 +c_ref[94, 0, 0] = 11.0 +c_ref[95, 0, 0] = 13.75 +c_ref[96, 0, 0] = 15.25 +c_ref[97, 0, 0] = 12.75 +c_ref[98, 0, 0] = 14.5 +c_ref[99, 0, 0] = 13.0 +c_ref[100, 0, 0] = 11.75 +c_ref[101, 0, 0] = 12.0 +c_ref[102, 0, 0] = 18.0 +c_ref[103, 0, 0] = 15.5 +c_ref[104, 0, 0] = 12.75 +c_ref[105, 0, 0] = 12.5 +c_ref[106, 0, 0] = 14.75 +c_ref[107, 0, 0] = 16.75 +c_ref[108, 0, 0] = 13.5 +c_ref[109, 0, 0] = 15.25 +c_ref[110, 0, 0] = 13.5 +c_ref[111, 0, 0] = 11.75 +c_ref[112, 0, 0] = 17.25 +c_ref[113, 0, 0] = 16.25 +c_ref[114, 0, 0] = 11.25 +c_ref[115, 0, 0] = 10.75 +c_ref[116, 0, 0] = 13.5 +c_ref[117, 0, 0] = 11.5 +c_ref[118, 0, 0] = 15.5 +c_ref[119, 0, 0] = 17.25 +c_ref[120, 0, 0] = 14.75 +c_ref[121, 0, 0] = 17.0 +c_ref[122, 0, 0] = 15.5 +c_ref[123, 0, 0] = 14.75 +c_ref[124, 0, 0] = 18.0 +c_ref[125, 0, 0] = 13.0 +c_ref[126, 0, 0] = 15.5 +c_ref[127, 0, 0] = 14.75 diff --git a/problems/nvidia/nvfp4_gemv/reference.py b/problems/nvidia/nvfp4_gemv/reference.py index 4b236af..ee06e24 100644 --- a/problems/nvidia/nvfp4_gemv/reference.py +++ b/problems/nvidia/nvfp4_gemv/reference.py @@ -9,6 +9,7 @@ def ceil_div(a, b): return (a + b - 1) // b + # Helper function to convert scale factor tensor to blocked format def to_blocked(input_matrix): rows, cols = input_matrix.shape @@ -23,13 +24,14 @@ def to_blocked(input_matrix): return rearranged.flatten() + def ref_kernel( data: input_t, ) -> output_t: """ PyTorch reference implementation of NVFP4 block-scaled GEMV. """ - a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, c_ref = data + a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, _, _, c_ref = data # Get dimensions from MxNxL layout _, _, l = c_ref.shape @@ -73,6 +75,8 @@ def generate_input( b: [1, k, l] - Input vector in torch.float4e2m1fn_x2 data type scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type scale_b: [1, k, l] - Input scale factors in torch.float8e4m3fn data type + scale_a_permuted: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type c: [m, 1, l] - Output vector in torch.float16 data type """ torch.manual_seed(seed) @@ -110,13 +114,45 @@ def create_scale_factor_tensors(l, mn, sf_k): ref_f8_torch_tensor_cpu_permuted = ref_f8_torch_tensor_cpu.permute( *ref_permute_order ) - return ref_f8_torch_tensor_cpu_permuted + + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, # batch size + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + + # Reorder scale factor tensor to (32, 4, rest_m, 4, rest_k, l) layout + # Which is needed by the CuTe customized kernel + mma_permute_order = (3, 4, 1, 5, 2, 0) + # Generate a random int8 tensor, then convert to float8_e4m3fn + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) + reordered_f8_torch_tensor_cpu = rand_int_tensor.to(dtype=torch.float8_e4m3fn) + # Permute according to mma_permute_order + reordered_f8_torch_tensor_cpu = reordered_f8_torch_tensor_cpu.permute( + *mma_permute_order + ) + for i in range(mn): + for j in range(sf_k): + for b in range(l): + # Calculate the location in MMA shape + mm = i // (atom_m[0] * atom_m[1]) + mm32 = i % atom_m[0] + mm4 = (i % 128) // atom_m[0] + kk = j // atom_k + kk4 = j % atom_k + reordered_f8_torch_tensor_cpu[mm32, mm4, mm, kk4, kk, b] = ref_f8_torch_tensor_cpu_permuted[i, j, b] + return ref_f8_torch_tensor_cpu_permuted, reordered_f8_torch_tensor_cpu.cuda() sf_k = ceil_div(k, sf_vec_size) - sfa_ref_cpu = create_scale_factor_tensors(l, m, sf_k) - sfb_ref_cpu = create_scale_factor_tensors(l, n_padded_128, sf_k) + sfa_ref_cpu, sfa_permuted = create_scale_factor_tensors(l, m, sf_k) + sfb_ref_cpu, sfb_permuted = create_scale_factor_tensors(l, n_padded_128, sf_k) - return (a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, c_ref) + return (a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, sfa_permuted, sfb_permuted, c_ref) check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) diff --git a/problems/nvidia/nvfp4_gemv/submission.py b/problems/nvidia/nvfp4_gemv/submission.py index 03d498f..8ee6814 100644 --- a/problems/nvidia/nvfp4_gemv/submission.py +++ b/problems/nvidia/nvfp4_gemv/submission.py @@ -22,32 +22,6 @@ def ceil_div(a, b): return (a + b - 1) // b -# Helper function to reorder the scale factor tensor to match the layout defined in -# https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout -@cute.jit -def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - sf_ref_ptr: cute.Pointer, - sf_mma_ptr: cute.Pointer, - mn: int, - sf_k: int, - l: int, - mma_shape: tuple, -): - mma_permute_order = (3, 4, 1, 5, 2, 0) - permuted_shape = tuple(mma_shape[i] for i in mma_permute_order) - cute_layout = cute.make_ordered_layout(permuted_shape, order=(2, 1, 4, 0, 3, 5)) - - sf_ref_tensor = cute.make_tensor( - sf_ref_ptr, cute.make_layout((mn, sf_k, l), stride=(sf_k, 1, mn * sf_k)) - ) - sf_mma_tensor = cute.make_tensor(sf_mma_ptr, cute_layout) - sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) - sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) - for i in cutlass.range(cute.size(sf_ref_tensor)): - mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) - sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] - - # The CuTe reference implementation for NVFP4 block-scaled GEMV @cute.kernel def kernel( @@ -189,49 +163,6 @@ def my_kernel( return -# Reorder scale factor from (mn, l, sf_k) to (32, 4, rest_m, 4, rest_k, l) layout -def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor): - sf_k = ceil_div(k, sf_vec_size) - atom_m = (32, 4) - atom_k = 4 - mma_shape = ( - l, # batch size - ceil_div(mn, atom_m[0] * atom_m[1]), - ceil_div(sf_k, atom_k), - atom_m[0], - atom_m[1], - atom_k, - ) - # Create the reordered scale factor tensor (32, 4, rest_m, 4, rest_k, l) on CPU. - mma_permute_order = (3, 4, 1, 5, 2, 0) - # Generate a random int8 tensor, then convert to float8_e4m3fn - rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) - reordered_f8_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) - # Permute according to mma_permute_order - reordered_f8_tensor = reordered_f8_tensor.permute(*mma_permute_order) - - # Helper function to convert scale factor tensor to CUTE-format scale factor tensor - cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - make_ptr( - cutlass.Float8E4M3FN, - ref_f8_tensor.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32, - ), - make_ptr( - cutlass.Float8E4M3FN, - reordered_f8_tensor.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32, - ), - mn, - sf_k, - l, - mma_shape, - ) - return reordered_f8_tensor.cuda() - - # Global cache for compiled kernel _compiled_kernel_cache = None @@ -289,12 +220,14 @@ def custom_kernel(data: input_t) -> output_t: b: [1, k, l] - Input vector in float4e2m1fn sfa_cpu: [m, k, l] - Scale factors in float8_e4m3fn sfb_cpu: [1, k, l] - Scale factors in float8_e4m3fn + sfa_permuted: [32, 4, rest_m, 4, rest_k, l] - Scale factors in float8_e4m3fn + sfb_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn c: [m, 1, l] - Output vector in float16 - + Returns: Output tensor c with computed GEMV results """ - a, b, sfa_cpu, sfb_cpu, c = data + a, b, _, _, sfa_permuted, sfb_permuted, c = data # Ensure kernel is compiled (will use cached version if available) compiled_func = compile_kernel() @@ -304,12 +237,6 @@ def custom_kernel(data: input_t) -> output_t: k = k * 2 # GEMV N dimension is always 1 n = 1 - # Scaling factor needs to pad the N size to 128 - n_padded_128 = 128 - - # Create the reordered scale factor tensors from the reference scale factor tensors via CuTe function. - sfa_reordered = create_reordered_scale_factor_tensor(l, m, k, sfa_cpu) - sfb_reordered = create_reordered_scale_factor_tensor(l, n_padded_128, k, sfb_cpu) # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer a_ptr = make_ptr( @@ -322,13 +249,13 @@ def custom_kernel(data: input_t) -> output_t: c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 ) sfa_ptr = make_ptr( - sf_dtype, sfa_reordered.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + sf_dtype, sfa_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 ) sfb_ptr = make_ptr( - sf_dtype, sfb_reordered.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + sf_dtype, sfb_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 ) # Execute the compiled kernel compiled_func(a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l)) - return c + return c \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemv/task.py b/problems/nvidia/nvfp4_gemv/task.py index 6005f59..c1a06e3 100644 --- a/problems/nvidia/nvfp4_gemv/task.py +++ b/problems/nvidia/nvfp4_gemv/task.py @@ -1,7 +1,7 @@ import torch from typing import TypedDict, TypeVar -input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) output_t = TypeVar("output_t", bound=torch.Tensor) class TestSpec(TypedDict): m: int diff --git a/problems/nvidia/nvfp4_gemv/template.py b/problems/nvidia/nvfp4_gemv/template.py index 2cf273e..acb8228 100644 --- a/problems/nvidia/nvfp4_gemv/template.py +++ b/problems/nvidia/nvfp4_gemv/template.py @@ -8,15 +8,17 @@ def custom_kernel(data: input_t) -> output_t: data: Tuple that expands to: a: torch.Tensor[float4e2m1fn] of shape [m, k, l], b: torch.Tensor[float4e2m1fn] of shape [1, k, l], - sfa: torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l], - sfb: torch.Tensor[float8_e4m3fnuz] of shape [1, k // 16, l], + sfa: torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l], used by reference implementation + sfb: torch.Tensor[float8_e4m3fnuz] of shape [1, k // 16, l], used by reference implementation + sfa_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_m, 4, rest_k, l], + sfb_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l], c: torch.Tensor[float16] of shape [m, 1, l] Returns: Tensor containing output in float16 c: torch.Tensor[float16] of shape [m, 1, l] """ # c: [l, m, 1] is pre-allocated memory to avoid timing allocation overhead. - a, b, sfa, sfb, c = data + a, b, sfa, sfb, sfa_permuted, sfb_permuted, c = data # Your implementation here diff --git a/problems/nvidia/nvfp4_gemv/test_python_1.sh b/problems/nvidia/nvfp4_gemv/test_python_1.sh new file mode 100644 index 0000000..ae91eba --- /dev/null +++ b/problems/nvidia/nvfp4_gemv/test_python_1.sh @@ -0,0 +1,87 @@ +# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build +BUILD_DIR=/home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/build_python +LLVM_DIR=$BUILD_DIR/llvm-prebuilt +# # BUILD_DIR=/home/scratch.ftse_gpu/workspace/dkg/build +# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build +# #BUILD_DIR=/home/yanchengz/scratch_1/dynamic-kernel-generator/build_debug2 +# # sudo /home/scratch.computelab/utils/driver/install_driver.py --installer=/home/builds/daily/display/x86_64/rel/gpu_drv/r580/r580_00/20250527_36037303/NVIDIA-Linux-x86_64-rel_gpu_drv_r580_r580_00-20250527_36037303-internal.run --reason="Change to tot driver" + + +# # BUILD_DIR=/home/scratch.nbommi_gpu/warp-phase-trace/dynamic-kernel-generator/build_main + +export PYTHONPATH=$BUILD_DIR/cutlass_ir/python_packages +#export PYTHONPATH=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/scripts +export CUDA_TOOLKIT_PATH=$BUILD_DIR/compiler_next +MLIR_CUDA_RUNTIME="$LLVM_DIR/lib/libmlir_cuda_runtime.so" +MLIR_C_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_c_runner_utils.so" +MLIR_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_runner_utils.so" +CUDA_DIALECT_RUNTIME="$BUILD_DIR/lib/libcuda_dialect_runtime.so" +export CUTE_DSL_LIBS="$MLIR_CUDA_RUNTIME:$MLIR_C_RUNNER_UTILS:$MLIR_RUNNER_UTILS:$CUDA_DIALECT_RUNTIME" + + +#export CUTE_DSL_PREPROCESSOR=True + +# export CUTE_DSL_PRINT_IR=1 +# just compile the IR but not execute it +# export CUTE_DSL_DRYRUN=1 +# export CUTE_DSL_JIT_TIME_PROFILING=ON +# export CUTE_DSL_KEEP_IR=True +# export CUTE_DSL_PRINT_IR=1 +# export CUTE_DSL_KEEP_CUBIN=1 +# export CUTE_DSL_LINEINFO=True +# export CUTE_DSL_LOG_TO_CONSOLE=1 +# export PYTHONUNBUFFERED=1 +# export CUTE_DSL_KEEP_SASS=1 +# whether to show detailed log in preprocessing +# export CUTE_DSL_FILTER_STACKTRACE=10 +export CUTE_DSL_ARCH=sm_100a + +# +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/reference-kernels/problems/nvidia/nvfp4_gemv/submission.py +/home/scratch.vickiw_gpu/env/bin/python3 eval.py test task.yml +/home/scratch.vickiw_gpu/env/bin/python3 eval.py benchmark task.yml +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/cuda-gdb --args + +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py +# # /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gated_dual_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gecccccbkvnjtrvtfreufijlfglnudnvuggvdfucidbnhk +# mm.py +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemm/nvfp4_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemv/nvfp4_gemv.py +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool memcheck \ +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,16384 #135us +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 4096,128,7168 #62 + +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,2048 #26 + + +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gated_dual_gemm/nvfp4_gated_dual_gemm.py +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_naive.py + + + +# print out ncu time +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ +# python3 vicki/tutorial_fp16_gemm_0__.py --mnk 7168,8,512 + +# use sanitizer to check race contention and memref error +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck|memcheck +# cutlass_ir/compiler/test/python/examples/sm_100a/test_nvfp4_gemv.py + +# capture ncu report +# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --check-exit-code 0 -f --set full --import-source yes --target-processes all --clock-control base --cache-control none -o gemv_4.1 \ +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv.py --m 128 --k 128 --l 2 + +# regular run python example +# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/min_latency_hmma.py --mnkl 7168,8,512,1 + +# run pytest +# pytest cutlass_ir/compiler/test/python/examples/sm_80/test_sgemm.py diff --git a/problems/nvidia/nvfp4_gemv/utils.py b/problems/nvidia/nvfp4_gemv/utils.py new file mode 100644 index 0000000..e8a9082 --- /dev/null +++ b/problems/nvidia/nvfp4_gemv/utils.py @@ -0,0 +1,176 @@ +import os +import random +import numpy as np +import torch + + +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_device(use_cuda: bool = True) -> torch.device: + """Get the appropriate device (GPU or CPU).""" + if use_cuda: + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + print("No compatible GPU found. Falling back to CPU.") + return torch.device("cpu") + + +# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py +@torch.no_grad() +def verbose_allclose( + received: torch.Tensor, + expected: torch.Tensor, + rtol=1e-05, + atol=1e-08, + max_print=5 +) -> list[str]: + """ + Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + rtol (float): Relative tolerance; relative to expected + atol (float): Absolute tolerance. + max_print (int): Maximum number of mismatched elements to print. + + Raises: + AssertionError: If the tensors are not all close within the given tolerance. + """ + # Check if the shapes of the tensors match + if received.shape != expected.shape: + return ["SIZE MISMATCH"] + + # Calculate the difference between the tensors + diff = torch.abs(received - expected) + + # Determine the tolerance + tolerance = atol + rtol * torch.abs(expected) + + # Find tolerance mismatched elements + tol_mismatched = diff > tolerance + + # Find nan mismatched elements + nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) + + # Find +inf mismatched elements + posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) + # Find -inf mismatched elements + neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) + + # Find all mismatched elements + mismatched = torch.logical_or( + torch.logical_or(tol_mismatched, nan_mismatched), + torch.logical_or(posinf_mismatched, neginf_mismatched), + ) + + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +@torch.no_grad() +def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5): + """ + Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + max_print (int): Maximum number of mismatched elements to print. + + Returns: + Empty string if tensors are equal, otherwise detailed error information + """ + mismatched = torch.not_equal(received, expected) + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]: + """ + Convenient "default" implementation for tasks' `check_implementation` function. + """ + expected = reference(data) + reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) + + if len(reasons) > 0: + return False, "mismatch found! custom implementation doesn't match reference: " + " ".join(reasons) + + return True, '' + + +def make_match_reference(reference: callable, **kwargs): + def wrapped(data, output): + return match_reference(data, output, reference=reference, **kwargs) + return wrapped + + +class DeterministicContext: + def __init__(self): + self.allow_tf32 = None + self.deterministic = None + self.cublas = None + + def __enter__(self): + self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '') + self.allow_tf32 = torch.backends.cudnn.allow_tf32 + self.deterministic = torch.backends.cudnn.deterministic + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + torch.use_deterministic_algorithms(True) + return self + + def __exit__(self, exc_type, exc_value, traceback): + torch.backends.cudnn.allow_tf32 = self.allow_tf32 + torch.backends.cudnn.deterministic = self.deterministic + torch.use_deterministic_algorithms(False) + os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas + +def clear_l2_cache(): + # import cupy as cp + # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) + # create a large dummy tensor + dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") + # write stuff to + dummy.fill_(42) + del dummy \ No newline at end of file From 20c0bb05e3fddc5c4fd60f8b90a848e6709ed33d Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Tue, 21 Oct 2025 01:28:47 -0700 Subject: [PATCH 16/29] remove useless files. --- problems/nvidia/nvfp4_dual_gemm/eval.py | 500 ---- .../nvidia/nvfp4_dual_gemm/test_python_1.sh | 87 - problems/nvidia/nvfp4_gemm/eval.py | 437 --- problems/nvidia/nvfp4_gemm/eval_vicki.py | 500 ---- problems/nvidia/nvfp4_gemm/test_python_1.sh | 87 - problems/nvidia/nvfp4_gemv/eval.py | 500 ---- problems/nvidia/nvfp4_gemv/log | 2332 ----------------- problems/nvidia/nvfp4_gemv/test_python_1.sh | 87 - 8 files changed, 4530 deletions(-) delete mode 100644 problems/nvidia/nvfp4_dual_gemm/eval.py delete mode 100644 problems/nvidia/nvfp4_dual_gemm/test_python_1.sh delete mode 100644 problems/nvidia/nvfp4_gemm/eval.py delete mode 100644 problems/nvidia/nvfp4_gemm/eval_vicki.py delete mode 100644 problems/nvidia/nvfp4_gemm/test_python_1.sh delete mode 100644 problems/nvidia/nvfp4_gemv/eval.py delete mode 100644 problems/nvidia/nvfp4_gemv/log delete mode 100644 problems/nvidia/nvfp4_gemv/test_python_1.sh diff --git a/problems/nvidia/nvfp4_dual_gemm/eval.py b/problems/nvidia/nvfp4_dual_gemm/eval.py deleted file mode 100644 index e8bb5b2..0000000 --- a/problems/nvidia/nvfp4_dual_gemm/eval.py +++ /dev/null @@ -1,500 +0,0 @@ -import base64 -import dataclasses -import multiprocessing -import re -import time -import os -import sys -import math -from pathlib import Path -from typing import Any, Optional -import tempfile - -import torch.cuda -from cutlass.cute.nvgpu.common import OpError - -from utils import set_seed, clear_l2_cache - -try: - from task import TestSpec -except ImportError: - TestSpec = dict - -from reference import check_implementation, generate_input - - -class PopcornOutput: - def __init__(self, fd: int): - self.file = os.fdopen(fd, "w") - os.set_inheritable(fd, False) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.file.close() - - def print(self, *args, **kwargs): - print(*args, **kwargs, file=self.file, flush=True) - - def log(self, key, value): - self.print(f"{key}: {value}") - - -@dataclasses.dataclass -class TestCase: - args: dict - spec: str - - -def _combine(a: int, b: int) -> int: - # combine two integers into one: - # we need this to generate a secret seed based on the test-level seed and - # the global secret seed. - # the test-level seeds are public knowledge, and typically relatively small numbers, - # so we need to make sure they don't provide any useful info for the full seed. - # This Cantor construction ensures that if the secret seed is a large number, - # then so is the overall seed. - return int(a + (a + b) * (a + b + 1) // 2) - - -def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: - try: - content = Path(file_name).read_text() - except Exception as E: - print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) - exit(113) - - tests = [] - lines = content.splitlines() - match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" - for line in lines: - parts = line.split(";") - case = {} - for part in parts: - matched = re.match(match, part) - if not re.fullmatch(match, part): - print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) - exit(113) - key = matched[1] - val = matched[2] - try: - val = int(val) - except ValueError: - pass - - case[key] = val - tests.append(TestCase(spec=line, args=case)) - - if seed is not None: - for test in tests: - if "seed" in test.args: - test.args["seed"] = _combine(test.args["seed"], seed) - - return tests - - -@dataclasses.dataclass -class Stats: - runs: int - mean: float - std: float - err: float - best: float - worst: float - - -def calculate_stats(durations: list[int]): - """ - Calculate statistical data from a list of durations. - - @param durations: A list of durations in nanoseconds. - @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. - """ - runs = len(durations) - total = sum(durations) - best = min(durations) - worst = max(durations) - - avg = total / runs - variance = sum(map(lambda x: (x - avg) ** 2, durations)) - std = math.sqrt(variance / (runs - 1)) - err = std / math.sqrt(runs) - - return Stats( - runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) - ) - - -def _clone_data(data): - """ - Recursively goes through data and clones all tensors. - """ - if isinstance(data, tuple): - return tuple(_clone_data(x) for x in data) - elif isinstance(data, list): - return [_clone_data(x) for x in data] - elif isinstance(data, dict): - return {k: _clone_data(v) for k, v in data.items()} - elif isinstance(data, torch.Tensor): - return data.clone() - else: - return data - - -def _run_single_test(test: TestCase): - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - - data = generate_input(**test.args) - torch.cuda.synchronize() - try: - submission_output = custom_kernel(_clone_data(data)) - - except OpError as E: - print(f"Encountered {E}", file=sys.stderr) - return False, str(E) - torch.cuda.synchronize() - return check_implementation(data, submission_output) - - -def run_single_test(pool: multiprocessing.Pool, test: TestCase): - """ - Runs a single test in another process. - """ - return pool.apply(_run_single_test, (test,)) - - -def run_testing( - logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] -): - """ - Executes the actual test case code and checks for correctness. - - @param logger: A PopcornOutput object used for logging test results. - @param tests: A list of TestCase objects representing the test cases to be executed. - @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. - """ - # Step 1: Compile kernel once before running tests - logger.log("compile", "start") - compile_success, compile_error = pool.apply(_compile_kernel_once) - if not compile_success: - logger.log("compile", "fail") - logger.log("compile.error", compile_error) - return 112 - logger.log("compile", "pass") - - # Step 2: Run all tests with compiled kernel - passed = True - logger.log("test-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"test.{idx}.spec", test.spec) - good, message = run_single_test(pool, test) - if not good: - logger.log(f"test.{idx}.status", "fail") - logger.log(f"test.{idx}.error", message) - passed = False - else: - logger.log(f"test.{idx}.status", "pass") - if message: - logger.log(f"test.{idx}.message", message) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def _compile_kernel_once(): - """ - Compile the kernel once before any benchmarking. - This ensures compilation time is not included in benchmark results. - """ - from submission import compile_kernel - - try: - # Trigger compilation (will be cached) - compile_kernel() - torch.cuda.synchronize() - return True, None - except OpError as E: - return False, f"Compilation failed: {E}" - except Exception as E: - return False, f"Compilation failed: {E}" - - -def _run_single_benchmark( - test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float -) -> Stats | Any: - """ - Runs one benchmark. Do not call directly. - """ - from submission import custom_kernel, compile_kernel - - durations = [] - # generate input data once - data = generate_input(**test.args) - check_copy = _clone_data(data) - - # Ensure kernel is compiled before any timing (compilation is cached) - try: - compile_kernel() - torch.cuda.synchronize() - except OpError as E: - return f"Compilation failed: {E}" - except Exception as E: - return f"Compilation failed: {E}" - - # first, one obligatory correctness check - try: - output = custom_kernel(_clone_data(data)) - except OpError as E: - return f"Encountered {E}" - good, message = check_implementation(check_copy, output) - if not good: - return message - - # now, do multiple timing runs without further correctness testing - # there is an upper bound of 200 runs, and a lower bound of 3 runs; - # otherwise, we repeat until we either measure at least 10 full seconds, - # or the relative error of the mean is below 1%. - - bm_start_time = time.perf_counter_ns() - for i in range(max_repeats): - if recheck: - # ensure we use a different seed for every benchmark - if "seed" in test.args: - test.args["seed"] += 13 - - data = generate_input(**test.args) - check_copy = _clone_data(data) - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - clear_l2_cache() - - start_event.record() - output = custom_kernel(data) - end_event.record() - torch.cuda.synchronize() - duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns - - if recheck: - good, message = check_implementation(check_copy, output) - if not good: - return message - - del output - durations.append(duration) - - if i > 1: - total_bm_duration = time.perf_counter_ns() - bm_start_time - stats = calculate_stats(durations) - # stop if either - # a) relative error dips below 0.1% - # b) we exceed the total time limit for benchmarking the kernel - # c) we exceed 2 minutes of total wallclock time. - if ( - stats.err / stats.mean < 0.001 - or stats.mean * stats.runs > max_time_ns - or total_bm_duration > 120e9 - ): - break - - return calculate_stats(durations) - - -def run_single_benchmark( - pool: multiprocessing.Pool, - test: TestCase, - recheck: bool, - max_repeats: int, - max_time_ns: float, -): - """ - For a particular test case, check correctness (if applicable) and grab runtime results. - - @param pool: Process on which the benchmark will be launched. - @param test: TestCase object. - @param recheck: Flag for whether to explicitly check functional correctness. - @param max_repeats: Number of trials to repeat. - @param max_time_ns: Timeout time in nanoseconds. - @return: A Stats object for this particular benchmark case or an error if the test fails. - """ - return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) - - -def run_benchmarking( - logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] -): - """ - Executes benchmarking code for a CUDA Kernel and logs runtimes. - - @param logger: A PopcornOutput object used for logging benchmark results. - @param pool: Process on which the benchmarks will be launched. - @param tests: A list of TestCase objects representing the test cases to be benchmarked. - @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. - """ - # Step 1: Compile kernel once (outside of timing) - logger.log("compile", "start") - compile_success, compile_error = pool.apply(_compile_kernel_once) - if not compile_success: - logger.log("compile", "fail") - logger.log("compile.error", compile_error) - return 112 - logger.log("compile", "pass") - - # Step 2: Warm up with compiled kernel - run_single_benchmark(pool, tests[0], False, 200, 10e7) - - # Step 3: Run benchmarks (compilation time excluded) - passed = True - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - result = run_single_benchmark(pool, test, False, 200, 10e9) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) - else: - passed = False - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", result) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def run_single_profile(test: TestCase) -> str: - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - from torch.profiler import profile, record_function, ProfilerActivity - - data = generate_input(**test.args) - torch.cuda.synchronize() - - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: - submission_output = custom_kernel(_clone_data(data)) - torch.cuda.synchronize() - return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) - - -def run_profiling(logger: PopcornOutput, tests: list[TestCase]): - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - report = run_single_profile(test) - logger.log( - f"benchmark.{idx}.report", - base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), - ) - logger.log("check", "pass") - return 0 - - -def main(): - fd = os.getenv("POPCORN_FD") - if not fd: - return 111 - - if len(sys.argv) < 3: - return 2 - - mode = sys.argv[1] - seed = os.getenv("POPCORN_SEED") - os.unsetenv("POPCORN_SEED") - seed = int(seed) if seed else None - set_seed(seed or 42) - - filename = None - - with tempfile.NamedTemporaryFile(delete=False) as tmp: - - def build_test_string(tests: list[dict]): - as_str = "" - for test in tests: - kvs = [] - for k, v in test.items(): - kvs.append(f"{k}: {v}") - as_str += "; ".join(kvs) + "\n" - return as_str - - import yaml - - yaml_content = yaml.safe_load(open(sys.argv[2], "r")) - if mode == "test": - tests_str = build_test_string(yaml_content.get("tests", [])) - elif mode in ("benchmark", "leaderboard", "profile"): - tests_str = build_test_string(yaml_content.get("benchmarks", [])) - - tmp.write(tests_str.encode("utf-8")) - tmp.flush() - filename = tmp.name - - tests = get_test_cases(filename, seed) - - os.unlink(filename) - - with PopcornOutput(int(fd)) as logger: - import multiprocessing - - mp_context = multiprocessing.get_context("spawn") - with mp_context.Pool(1) as pool: - if mode == "test": - return run_testing(logger, pool, tests) - if mode == "benchmark": - return run_benchmarking(logger, pool, tests) - - if mode == "leaderboard": - # Step 1: Compile kernel once (outside of timing) - logger.log("compile", "start") - compile_success, compile_error = pool.apply(_compile_kernel_once) - if not compile_success: - logger.log("compile", "fail") - logger.log("compile.error", compile_error) - return 112 - logger.log("compile", "pass") - - # Step 2: Warmup with compiled kernel - run_single_benchmark(pool, tests[0], False, 200, 1e7) - - # Step 3: Run leaderboard benchmarks (compilation time excluded) - logger.log("benchmark-count", len(tests)) - passed = True - for i in range(len(tests)): - result = run_single_benchmark(pool, tests[i], True, 200, 30e9) - logger.log(f"benchmark.{i}.spec", tests[i].spec) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log( - f"benchmark.{i}.{field.name}", - getattr(result, field.name), - ) - else: - passed = False - logger.log(f"benchmark.{i}.status", "fail") - logger.log( - f"benchmark.{i}.error", str(result) - ) # TODO: Make sure result implements __str__? - break - - logger.log("check", "pass" if passed else "fail") - elif mode == "profile": - run_profiling(logger, tests) - else: - # TODO: Implement script mode - return 2 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/problems/nvidia/nvfp4_dual_gemm/test_python_1.sh b/problems/nvidia/nvfp4_dual_gemm/test_python_1.sh deleted file mode 100644 index 8648bdb..0000000 --- a/problems/nvidia/nvfp4_dual_gemm/test_python_1.sh +++ /dev/null @@ -1,87 +0,0 @@ -# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build -BUILD_DIR=/home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/build_python -LLVM_DIR=$BUILD_DIR/llvm-prebuilt -# # BUILD_DIR=/home/scratch.ftse_gpu/workspace/dkg/build -# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build -# #BUILD_DIR=/home/yanchengz/scratch_1/dynamic-kernel-generator/build_debug2 -# # sudo /home/scratch.computelab/utils/driver/install_driver.py --installer=/home/builds/daily/display/x86_64/rel/gpu_drv/r580/r580_00/20250527_36037303/NVIDIA-Linux-x86_64-rel_gpu_drv_r580_r580_00-20250527_36037303-internal.run --reason="Change to tot driver" - - -# # BUILD_DIR=/home/scratch.nbommi_gpu/warp-phase-trace/dynamic-kernel-generator/build_main - -export PYTHONPATH=$BUILD_DIR/cutlass_ir/python_packages -#export PYTHONPATH=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/scripts -export CUDA_TOOLKIT_PATH=$BUILD_DIR/compiler_next -MLIR_CUDA_RUNTIME="$LLVM_DIR/lib/libmlir_cuda_runtime.so" -MLIR_C_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_c_runner_utils.so" -MLIR_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_runner_utils.so" -CUDA_DIALECT_RUNTIME="$BUILD_DIR/lib/libcuda_dialect_runtime.so" -export CUTE_DSL_LIBS="$MLIR_CUDA_RUNTIME:$MLIR_C_RUNNER_UTILS:$MLIR_RUNNER_UTILS:$CUDA_DIALECT_RUNTIME" - - -#export CUTE_DSL_PREPROCESSOR=True - -# export CUTE_DSL_PRINT_IR=1 -# just compile the IR but not execute it -# export CUTE_DSL_DRYRUN=1 -# export CUTE_DSL_JIT_TIME_PROFILING=ON -# export CUTE_DSL_KEEP_IR=True -# export CUTE_DSL_PRINT_IR=1 -# export CUTE_DSL_KEEP_CUBIN=1 -# export CUTE_DSL_LINEINFO=True -# export CUTE_DSL_LOG_TO_CONSOLE=1 -# export PYTHONUNBUFFERED=1 -# export CUTE_DSL_KEEP_SASS=1 -# whether to show detailed log in preprocessing -# export CUTE_DSL_FILTER_STACKTRACE=10 -export CUTE_DSL_ARCH=sm_100a - -# -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py -/home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/reference-kernels/problems/nvidia/nvfp4_dual_gemm/submission.py -/home/scratch.vickiw_gpu/env/bin/python3 eval.py test task.yml -/home/scratch.vickiw_gpu/env/bin/python3 eval.py benchmark task.yml -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/cuda-gdb --args - -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py -# # /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gated_dual_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gecccccbkvnjtrvtfreufijlfglnudnvuggvdfucidbnhk -# mm.py -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemm/nvfp4_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemv/nvfp4_gemv.py -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool memcheck \ -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,16384 #135us -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 4096,128,7168 #62 - -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,2048 #26 - - -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gated_dual_gemm/nvfp4_gated_dual_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_naive.py - - - -# print out ncu time -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# python3 vicki/tutorial_fp16_gemm_0__.py --mnk 7168,8,512 - -# use sanitizer to check race contention and memref error -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck|memcheck -# cutlass_ir/compiler/test/python/examples/sm_100a/test_nvfp4_gemv.py - -# capture ncu report -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --check-exit-code 0 -f --set full --import-source yes --target-processes all --clock-control base --cache-control none -o gemv_4.1 \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv.py --m 128 --k 128 --l 2 - -# regular run python example -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/min_latency_hmma.py --mnkl 7168,8,512,1 - -# run pytest -# pytest cutlass_ir/compiler/test/python/examples/sm_80/test_sgemm.py diff --git a/problems/nvidia/nvfp4_gemm/eval.py b/problems/nvidia/nvfp4_gemm/eval.py deleted file mode 100644 index 072b176..0000000 --- a/problems/nvidia/nvfp4_gemm/eval.py +++ /dev/null @@ -1,437 +0,0 @@ -import base64 -import dataclasses -import multiprocessing -import re -import time -import os -import sys -import math -from pathlib import Path -from typing import Any, Optional -import tempfile - -import torch.cuda -from cutlass.cute.nvgpu.common import OpError - -from utils import set_seed, clear_l2_cache - -try: - from task import TestSpec -except ImportError: - TestSpec = dict - -from reference import check_implementation, generate_input - - -class PopcornOutput: - def __init__(self, fd: int): - self.file = os.fdopen(fd, "w") - os.set_inheritable(fd, False) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.file.close() - - def print(self, *args, **kwargs): - print(*args, **kwargs, file=self.file, flush=True) - - def log(self, key, value): - self.print(f"{key}: {value}") - - -@dataclasses.dataclass -class TestCase: - args: dict - spec: str - - -def _combine(a: int, b: int) -> int: - # combine two integers into one: - # we need this to generate a secret seed based on the test-level seed and - # the global secret seed. - # the test-level seeds are public knowledge, and typically relatively small numbers, - # so we need to make sure they don't provide any useful info for the full seed. - # This Cantor construction ensures that if the secret seed is a large number, - # then so is the overall seed. - return int(a + (a + b) * (a + b + 1) // 2) - - -def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: - try: - content = Path(file_name).read_text() - except Exception as E: - print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) - exit(113) - - tests = [] - lines = content.splitlines() - match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" - for line in lines: - parts = line.split(";") - case = {} - for part in parts: - matched = re.match(match, part) - if not re.fullmatch(match, part): - print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) - exit(113) - key = matched[1] - val = matched[2] - try: - val = int(val) - except ValueError: - pass - - case[key] = val - tests.append(TestCase(spec=line, args=case)) - - if seed is not None: - for test in tests: - if "seed" in test.args: - test.args["seed"] = _combine(test.args["seed"], seed) - - return tests - - -@dataclasses.dataclass -class Stats: - runs: int - mean: float - std: float - err: float - best: float - worst: float - - -def calculate_stats(durations: list[int]): - """ - Calculate statistical data from a list of durations. - @param durations: A list of durations in nanoseconds. - @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. - """ - runs = len(durations) - total = sum(durations) - best = min(durations) - worst = max(durations) - - avg = total / runs - variance = sum(map(lambda x: (x - avg) ** 2, durations)) - std = math.sqrt(variance / (runs - 1)) - err = std / math.sqrt(runs) - - return Stats( - runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) - ) - - -def _clone_data(data): - """ - Recursively goes through data and clones all tensors. - """ - if isinstance(data, tuple): - return tuple(_clone_data(x) for x in data) - elif isinstance(data, list): - return [_clone_data(x) for x in data] - elif isinstance(data, dict): - return {k: _clone_data(v) for k, v in data.items()} - elif isinstance(data, torch.Tensor): - return data.clone() - else: - return data - - -def _run_single_test(test: TestCase): - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - - data = generate_input(**test.args) - torch.cuda.synchronize() - try: - submission_output = custom_kernel(_clone_data(data)) - - except OpError as E: - print(f"Encountered {E}", file=sys.stderr) - return False, str(E) - torch.cuda.synchronize() - return check_implementation(data, submission_output) - - -def run_single_test(pool: multiprocessing.Pool, test: TestCase): - """ - Runs a single test in another process. - """ - return pool.apply(_run_single_test, (test,)) - - -def run_testing( - logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] -): - """ - Executes the actual test case code and checks for correctness. - @param logger: A PopcornOutput object used for logging test results. - @param tests: A list of TestCase objects representing the test cases to be executed. - @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. - """ - passed = True - logger.log("test-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"test.{idx}.spec", test.spec) - good, message = run_single_test(pool, test) - if not good: - logger.log(f"test.{idx}.status", "fail") - logger.log(f"test.{idx}.error", message) - passed = False - else: - logger.log(f"test.{idx}.status", "pass") - if message: - logger.log(f"test.{idx}.message", message) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def _run_single_benchmark( - test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float -) -> Stats | Any: - """ - Runs one benchmark. Do not call directly. - """ - from submission import custom_kernel - - durations = [] - # generate input data once - data = generate_input(**test.args) - check_copy = _clone_data(data) - # first, one obligatory correctness check - try: - output = custom_kernel(_clone_data(data)) - except OpError as E: - return f"Encountered {E}" - good, message = check_implementation(check_copy, output) - if not good: - return message - - # now, do multiple timing runs without further correctness testing - # there is an upper bound of 100 runs, and a lower bound of 3 runs; - # otherwise, we repeat until we either measure at least 10 full seconds, - # or the relative error of the mean is below 1%. - - bm_start_time = time.perf_counter_ns() - for i in range(max_repeats): - if recheck: - # ensure we use a different seed for every benchmark - if "seed" in test.args: - test.args["seed"] += 13 - - data = generate_input(**test.args) - check_copy = _clone_data(data) - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - clear_l2_cache() - - start_event.record() - output = custom_kernel(data) - end_event.record() - torch.cuda.synchronize() - duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns - - if recheck: - good, message = check_implementation(check_copy, output) - if not good: - return message - - del output - durations.append(duration) - - if i > 1: - total_bm_duration = time.perf_counter_ns() - bm_start_time - stats = calculate_stats(durations) - # stop if either - # a) relative error dips below 0.1% - # b) we exceed the total time limit for benchmarking the kernel - # c) we exceed 2 minutes of total wallclock time. - if ( - stats.err / stats.mean < 0.001 - or stats.mean * stats.runs > max_time_ns - or total_bm_duration > 120e9 - ): - break - - return calculate_stats(durations) - - -def run_single_benchmark( - pool: multiprocessing.Pool, - test: TestCase, - recheck: bool, - max_repeats: int, - max_time_ns: float, -): - """ - For a particular test case, check correctness (if applicable) and grab runtime results. - @param pool: Process on which the benchmark will be launched. - @param test: TestCase object. - @param recheck: Flag for whether to explicitly check functional correctness. - @param max_repeats: Number of trials to repeat. - @param max_time_ns: Timeout time in nanoseconds. - @return: A Stats object for this particular benchmark case or an error if the test fails. - """ - return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) - - -def run_benchmarking( - logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] -): - """ - Executes benchmarking code for a CUDA Kernel and logs runtimes. - @param logger: A PopcornOutput object used for logging benchmark results. - @param pool: Process on which the benchmarks will be launched. - @param tests: A list of TestCase objects representing the test cases to be benchmarked. - @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. - """ - # warm up - run_single_benchmark(pool, tests[0], False, 100, 10e7) - - passed = True - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - result = run_single_benchmark(pool, test, False, 100, 10e9) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) - else: - passed = False - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", result) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def run_single_profile(test: TestCase) -> str: - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - from torch.profiler import profile, record_function, ProfilerActivity - - data = generate_input(**test.args) - torch.cuda.synchronize() - - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: - submission_output = custom_kernel(_clone_data(data)) - torch.cuda.synchronize() - return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) - - -def run_profiling(logger: PopcornOutput, tests: list[TestCase]): - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - report = run_single_profile(test) - logger.log( - f"benchmark.{idx}.report", - base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), - ) - logger.log("check", "pass") - return 0 - - -def main(): - fd = os.getenv("POPCORN_FD") - if not fd: - return 111 - - if len(sys.argv) < 3: - return 2 - - mode = sys.argv[1] - seed = os.getenv("POPCORN_SEED") - os.unsetenv("POPCORN_SEED") - seed = int(seed) if seed else None - set_seed(seed or 42) - - filename = None - - with tempfile.NamedTemporaryFile(delete=False) as tmp: - - def build_test_string(tests: list[dict]): - as_str = "" - for test in tests: - kvs = [] - for k, v in test.items(): - kvs.append(f"{k}: {v}") - as_str += "; ".join(kvs) + "\n" - return as_str - - import yaml - - yaml_content = yaml.safe_load(open(sys.argv[2], "r")) - if mode == "test": - tests_str = build_test_string(yaml_content.get("tests", [])) - elif mode in ("benchmark", "leaderboard", "profile"): - tests_str = build_test_string(yaml_content.get("benchmarks", [])) - - tmp.write(tests_str.encode("utf-8")) - tmp.flush() - filename = tmp.name - - tests = get_test_cases(filename, seed) - - os.unlink(filename) - - with PopcornOutput(int(fd)) as logger: - import multiprocessing - - mp_context = multiprocessing.get_context("spawn") - with mp_context.Pool(1) as pool: - if mode == "test": - return run_testing(logger, pool, tests) - if mode == "benchmark": - return run_benchmarking(logger, pool, tests) - - if mode == "leaderboard": - # warmup - run_single_benchmark(pool, tests[0], False, 100, 1e7) - logger.log("benchmark-count", len(tests)) - passed = True - for i in range(len(tests)): - result = run_single_benchmark(pool, tests[i], True, 100, 30e9) - logger.log(f"benchmark.{i}.spec", tests[i].spec) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log( - f"benchmark.{i}.{field.name}", - getattr(result, field.name), - ) - else: - passed = False - logger.log(f"benchmark.{i}.status", "fail") - logger.log( - f"benchmark.{i}.error", str(result) - ) # TODO: Make sure result implements __str__? - break - - logger.log("check", "pass" if passed else "fail") - elif mode == "profile": - run_profiling(logger, tests) - else: - # TODO: Implement script mode - return 2 - - -if __name__ == "__main__": - sys.exit(main()) \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemm/eval_vicki.py b/problems/nvidia/nvfp4_gemm/eval_vicki.py deleted file mode 100644 index 2441b7d..0000000 --- a/problems/nvidia/nvfp4_gemm/eval_vicki.py +++ /dev/null @@ -1,500 +0,0 @@ -import base64 -import dataclasses -import multiprocessing -import re -import time -import os -import sys -import math -from pathlib import Path -from typing import Any, Optional -import tempfile - -import torch.cuda -from cutlass.cute.nvgpu.common import OpError - -from utils import set_seed, clear_l2_cache - -try: - from task import TestSpec -except ImportError: - TestSpec = dict - -from reference import check_implementation, generate_input - - -class PopcornOutput: - def __init__(self, fd: int): - self.file = os.fdopen(fd, "w") - os.set_inheritable(fd, False) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.file.close() - - def print(self, *args, **kwargs): - print(*args, **kwargs, file=self.file, flush=True) - - def log(self, key, value): - self.print(f"{key}: {value}") - - -@dataclasses.dataclass -class TestCase: - args: dict - spec: str - - -def _combine(a: int, b: int) -> int: - # combine two integers into one: - # we need this to generate a secret seed based on the test-level seed and - # the global secret seed. - # the test-level seeds are public knowledge, and typically relatively small numbers, - # so we need to make sure they don't provide any useful info for the full seed. - # This Cantor construction ensures that if the secret seed is a large number, - # then so is the overall seed. - return int(a + (a + b) * (a + b + 1) // 2) - - -def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: - try: - content = Path(file_name).read_text() - except Exception as E: - print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) - exit(113) - - tests = [] - lines = content.splitlines() - match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" - for line in lines: - parts = line.split(";") - case = {} - for part in parts: - matched = re.match(match, part) - if not re.fullmatch(match, part): - print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) - exit(113) - key = matched[1] - val = matched[2] - try: - val = int(val) - except ValueError: - pass - - case[key] = val - tests.append(TestCase(spec=line, args=case)) - - if seed is not None: - for test in tests: - if "seed" in test.args: - test.args["seed"] = _combine(test.args["seed"], seed) - - return tests - - -@dataclasses.dataclass -class Stats: - runs: int - mean: float - std: float - err: float - best: float - worst: float - - -def calculate_stats(durations: list[int]): - """ - Calculate statistical data from a list of durations. - - @param durations: A list of durations in nanoseconds. - @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. - """ - runs = len(durations) - total = sum(durations) - best = min(durations) - worst = max(durations) - - avg = total / runs - variance = sum(map(lambda x: (x - avg) ** 2, durations)) - std = math.sqrt(variance / (runs - 1)) - err = std / math.sqrt(runs) - - return Stats( - runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) - ) - - -def _clone_data(data): - """ - Recursively goes through data and clones all tensors. - """ - if isinstance(data, tuple): - return tuple(_clone_data(x) for x in data) - elif isinstance(data, list): - return [_clone_data(x) for x in data] - elif isinstance(data, dict): - return {k: _clone_data(v) for k, v in data.items()} - elif isinstance(data, torch.Tensor): - return data.clone() - else: - return data - - -def _run_single_test(test: TestCase): - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - - data = generate_input(**test.args) - torch.cuda.synchronize() - try: - submission_output = custom_kernel(_clone_data(data)) - - except OpError as E: - print(f"Encountered {E}", file=sys.stderr) - return False, str(E) - torch.cuda.synchronize() - return check_implementation(data, submission_output) - - -def run_single_test(pool: multiprocessing.Pool, test: TestCase): - """ - Runs a single test in another process. - """ - return pool.apply(_run_single_test, (test,)) - - -def run_testing( - logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] -): - """ - Executes the actual test case code and checks for correctness. - - @param logger: A PopcornOutput object used for logging test results. - @param tests: A list of TestCase objects representing the test cases to be executed. - @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. - """ - # Step 1: Compile kernel once before running tests - logger.log("compile", "start") - compile_success, compile_error = pool.apply(_compile_kernel_once) - if not compile_success: - logger.log("compile", "fail") - logger.log("compile.error", compile_error) - return 112 - logger.log("compile", "pass") - - # Step 2: Run all tests with compiled kernel - passed = True - logger.log("test-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"test.{idx}.spec", test.spec) - good, message = run_single_test(pool, test) - if not good: - logger.log(f"test.{idx}.status", "fail") - logger.log(f"test.{idx}.error", message) - passed = False - else: - logger.log(f"test.{idx}.status", "pass") - if message: - logger.log(f"test.{idx}.message", message) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def _compile_kernel_once(): - """ - Compile the kernel once before any benchmarking. - This ensures compilation time is not included in benchmark results. - """ - from submission import compile_kernel - - try: - # Trigger compilation (will be cached) - compile_kernel() - torch.cuda.synchronize() - return True, None - except OpError as E: - return False, f"Compilation failed: {E}" - except Exception as E: - return False, f"Compilation failed: {E}" - - -def _run_single_benchmark( - test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float -) -> Stats | Any: - """ - Runs one benchmark. Do not call directly. - """ - from submission import custom_kernel, compile_kernel - - durations = [] - # generate input data once - data = generate_input(**test.args) - check_copy = _clone_data(data) - - # Ensure kernel is compiled before any timing (compilation is cached) - try: - compile_kernel() - torch.cuda.synchronize() - except OpError as E: - return f"Compilation failed: {E}" - except Exception as E: - return f"Compilation failed: {E}" - - # first, one obligatory correctness check - try: - output = custom_kernel(_clone_data(data)) - except OpError as E: - return f"Encountered {E}" - good, message = check_implementation(check_copy, output) - if not good: - return message - - # now, do multiple timing runs without further correctness testing - # there is an upper bound of 200 runs, and a lower bound of 3 runs; - # otherwise, we repeat until we either measure at least 10 full seconds, - # or the relative error of the mean is below 1%. - - bm_start_time = time.perf_counter_ns() - for i in range(max_repeats): - if recheck: - # ensure we use a different seed for every benchmark - if "seed" in test.args: - test.args["seed"] += 13 - - data = generate_input(**test.args) - check_copy = _clone_data(data) - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - clear_l2_cache() - - start_event.record() - output = custom_kernel(data) - end_event.record() - torch.cuda.synchronize() - duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns - - if recheck: - good, message = check_implementation(check_copy, output) - if not good: - return message - - del output - durations.append(duration) - - if i > 1: - total_bm_duration = time.perf_counter_ns() - bm_start_time - stats = calculate_stats(durations) - # stop if either - # a) relative error dips below 0.1% - # b) we exceed the total time limit for benchmarking the kernel - # c) we exceed 2 minutes of total wallclock time. - if ( - stats.err / stats.mean < 0.001 - or stats.mean * stats.runs > max_time_ns - or total_bm_duration > 120e9 - ): - break - - return calculate_stats(durations) - - -def run_single_benchmark( - pool: multiprocessing.Pool, - test: TestCase, - recheck: bool, - max_repeats: int, - max_time_ns: float, -): - """ - For a particular test case, check correctness (if applicable) and grab runtime results. - - @param pool: Process on which the benchmark will be launched. - @param test: TestCase object. - @param recheck: Flag for whether to explicitly check functional correctness. - @param max_repeats: Number of trials to repeat. - @param max_time_ns: Timeout time in nanoseconds. - @return: A Stats object for this particular benchmark case or an error if the test fails. - """ - return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) - - -def run_benchmarking( - logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] -): - """ - Executes benchmarking code for a CUDA Kernel and logs runtimes. - - @param logger: A PopcornOutput object used for logging benchmark results. - @param pool: Process on which the benchmarks will be launched. - @param tests: A list of TestCase objects representing the test cases to be benchmarked. - @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. - """ - # Step 1: Compile kernel once (outside of timing) - logger.log("compile", "start") - compile_success, compile_error = pool.apply(_compile_kernel_once) - if not compile_success: - logger.log("compile", "fail") - logger.log("compile.error", compile_error) - return 112 - logger.log("compile", "pass") - - # Step 2: Warm up with compiled kernel - run_single_benchmark(pool, tests[0], False, 2, 10e7) - - # Step 3: Run benchmarks (compilation time excluded) - passed = True - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - result = run_single_benchmark(pool, test, False, 2, 10e9) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) - else: - passed = False - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", result) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def run_single_profile(test: TestCase) -> str: - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - from torch.profiler import profile, record_function, ProfilerActivity - - data = generate_input(**test.args) - torch.cuda.synchronize() - - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: - submission_output = custom_kernel(_clone_data(data)) - torch.cuda.synchronize() - return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) - - -def run_profiling(logger: PopcornOutput, tests: list[TestCase]): - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - report = run_single_profile(test) - logger.log( - f"benchmark.{idx}.report", - base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), - ) - logger.log("check", "pass") - return 0 - - -def main(): - fd = os.getenv("POPCORN_FD") - if not fd: - return 111 - - if len(sys.argv) < 3: - return 2 - - mode = sys.argv[1] - seed = os.getenv("POPCORN_SEED") - os.unsetenv("POPCORN_SEED") - seed = int(seed) if seed else None - set_seed(seed or 42) - - filename = None - - with tempfile.NamedTemporaryFile(delete=False) as tmp: - - def build_test_string(tests: list[dict]): - as_str = "" - for test in tests: - kvs = [] - for k, v in test.items(): - kvs.append(f"{k}: {v}") - as_str += "; ".join(kvs) + "\n" - return as_str - - import yaml - - yaml_content = yaml.safe_load(open(sys.argv[2], "r")) - if mode == "test": - tests_str = build_test_string(yaml_content.get("tests", [])) - elif mode in ("benchmark", "leaderboard", "profile"): - tests_str = build_test_string(yaml_content.get("benchmarks", [])) - - tmp.write(tests_str.encode("utf-8")) - tmp.flush() - filename = tmp.name - - tests = get_test_cases(filename, seed) - - os.unlink(filename) - - with PopcornOutput(int(fd)) as logger: - import multiprocessing - - mp_context = multiprocessing.get_context("spawn") - with mp_context.Pool(1) as pool: - if mode == "test": - return run_testing(logger, pool, tests) - if mode == "benchmark": - return run_benchmarking(logger, pool, tests) - - if mode == "leaderboard": - # Step 1: Compile kernel once (outside of timing) - logger.log("compile", "start") - compile_success, compile_error = pool.apply(_compile_kernel_once) - if not compile_success: - logger.log("compile", "fail") - logger.log("compile.error", compile_error) - return 112 - logger.log("compile", "pass") - - # Step 2: Warmup with compiled kernel - run_single_benchmark(pool, tests[0], False, 2, 1e7) - - # Step 3: Run leaderboard benchmarks (compilation time excluded) - logger.log("benchmark-count", len(tests)) - passed = True - for i in range(len(tests)): - result = run_single_benchmark(pool, tests[i], True, 2, 30e9) - logger.log(f"benchmark.{i}.spec", tests[i].spec) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log( - f"benchmark.{i}.{field.name}", - getattr(result, field.name), - ) - else: - passed = False - logger.log(f"benchmark.{i}.status", "fail") - logger.log( - f"benchmark.{i}.error", str(result) - ) # TODO: Make sure result implements __str__? - break - - logger.log("check", "pass" if passed else "fail") - elif mode == "profile": - run_profiling(logger, tests) - else: - # TODO: Implement script mode - return 2 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/problems/nvidia/nvfp4_gemm/test_python_1.sh b/problems/nvidia/nvfp4_gemm/test_python_1.sh deleted file mode 100644 index ab7a8c9..0000000 --- a/problems/nvidia/nvfp4_gemm/test_python_1.sh +++ /dev/null @@ -1,87 +0,0 @@ -# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build -BUILD_DIR=/home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/build_python -LLVM_DIR=$BUILD_DIR/llvm-prebuilt -# # BUILD_DIR=/home/scratch.ftse_gpu/workspace/dkg/build -# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build -# #BUILD_DIR=/home/yanchengz/scratch_1/dynamic-kernel-generator/build_debug2 -# # sudo /home/scratch.computelab/utils/driver/install_driver.py --installer=/home/builds/daily/display/x86_64/rel/gpu_drv/r580/r580_00/20250527_36037303/NVIDIA-Linux-x86_64-rel_gpu_drv_r580_r580_00-20250527_36037303-internal.run --reason="Change to tot driver" - - -# # BUILD_DIR=/home/scratch.nbommi_gpu/warp-phase-trace/dynamic-kernel-generator/build_main - -export PYTHONPATH=$BUILD_DIR/cutlass_ir/python_packages -#export PYTHONPATH=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/scripts -export CUDA_TOOLKIT_PATH=$BUILD_DIR/compiler_next -MLIR_CUDA_RUNTIME="$LLVM_DIR/lib/libmlir_cuda_runtime.so" -MLIR_C_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_c_runner_utils.so" -MLIR_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_runner_utils.so" -CUDA_DIALECT_RUNTIME="$BUILD_DIR/lib/libcuda_dialect_runtime.so" -export CUTE_DSL_LIBS="$MLIR_CUDA_RUNTIME:$MLIR_C_RUNNER_UTILS:$MLIR_RUNNER_UTILS:$CUDA_DIALECT_RUNTIME" - - -#export CUTE_DSL_PREPROCESSOR=True - -# export CUTE_DSL_PRINT_IR=1 -# just compile the IR but not execute it -# export CUTE_DSL_DRYRUN=1 -# export CUTE_DSL_JIT_TIME_PROFILING=ON -# export CUTE_DSL_KEEP_IR=True -# export CUTE_DSL_PRINT_IR=1 -# export CUTE_DSL_KEEP_CUBIN=1 -# export CUTE_DSL_LINEINFO=True -# export CUTE_DSL_LOG_TO_CONSOLE=1 -# export PYTHONUNBUFFERED=1 -# export CUTE_DSL_KEEP_SASS=1 -# whether to show detailed log in preprocessing -# export CUTE_DSL_FILTER_STACKTRACE=10 -export CUTE_DSL_ARCH=sm_100a - -# -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/reference-kernels/problems/nvidia/nvfp4_gemm/submission.py -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration -- -/home/scratch.vickiw_gpu/env/bin/python3 eval_vicki.py benchmark task.yml -/home/scratch.vickiw_gpu/env/bin/python3 eval_vicki.py test task.yml -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/cuda-gdb --args - -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py -# # /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gated_dual_gemm.py - -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemm/nvfp4_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemv/nvfp4_gemv.py -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool memcheck \ -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,16384 #135us -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 4096,128,7168 #62 - -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,2048 #26 - - -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gated_dual_gemm/nvfp4_gated_dual_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_naive.py - - - -# print out ncu time -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# python3 vicki/tutorial_fp16_gemm_0__.py --mnk 7168,8,512 - -# use sanitizer to check race contention and memref error -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck|memcheck -# cutlass_ir/compiler/test/python/examples/sm_100a/test_nvfp4_gemv.py - -# capture ncu report -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --check-exit-code 0 -f --set full --import-source yes --target-processes all --clock-control base --cache-control none -o gemv_4.1 \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv.py --m 128 --k 128 --l 2 - -# regular run python example -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/min_latency_hmma.py --mnkl 7168,8,512,1 - -# run pytest -# pytest cutlass_ir/compiler/test/python/examples/sm_80/test_sgemm.py diff --git a/problems/nvidia/nvfp4_gemv/eval.py b/problems/nvidia/nvfp4_gemv/eval.py deleted file mode 100644 index e8bb5b2..0000000 --- a/problems/nvidia/nvfp4_gemv/eval.py +++ /dev/null @@ -1,500 +0,0 @@ -import base64 -import dataclasses -import multiprocessing -import re -import time -import os -import sys -import math -from pathlib import Path -from typing import Any, Optional -import tempfile - -import torch.cuda -from cutlass.cute.nvgpu.common import OpError - -from utils import set_seed, clear_l2_cache - -try: - from task import TestSpec -except ImportError: - TestSpec = dict - -from reference import check_implementation, generate_input - - -class PopcornOutput: - def __init__(self, fd: int): - self.file = os.fdopen(fd, "w") - os.set_inheritable(fd, False) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.file.close() - - def print(self, *args, **kwargs): - print(*args, **kwargs, file=self.file, flush=True) - - def log(self, key, value): - self.print(f"{key}: {value}") - - -@dataclasses.dataclass -class TestCase: - args: dict - spec: str - - -def _combine(a: int, b: int) -> int: - # combine two integers into one: - # we need this to generate a secret seed based on the test-level seed and - # the global secret seed. - # the test-level seeds are public knowledge, and typically relatively small numbers, - # so we need to make sure they don't provide any useful info for the full seed. - # This Cantor construction ensures that if the secret seed is a large number, - # then so is the overall seed. - return int(a + (a + b) * (a + b + 1) // 2) - - -def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: - try: - content = Path(file_name).read_text() - except Exception as E: - print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) - exit(113) - - tests = [] - lines = content.splitlines() - match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" - for line in lines: - parts = line.split(";") - case = {} - for part in parts: - matched = re.match(match, part) - if not re.fullmatch(match, part): - print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) - exit(113) - key = matched[1] - val = matched[2] - try: - val = int(val) - except ValueError: - pass - - case[key] = val - tests.append(TestCase(spec=line, args=case)) - - if seed is not None: - for test in tests: - if "seed" in test.args: - test.args["seed"] = _combine(test.args["seed"], seed) - - return tests - - -@dataclasses.dataclass -class Stats: - runs: int - mean: float - std: float - err: float - best: float - worst: float - - -def calculate_stats(durations: list[int]): - """ - Calculate statistical data from a list of durations. - - @param durations: A list of durations in nanoseconds. - @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. - """ - runs = len(durations) - total = sum(durations) - best = min(durations) - worst = max(durations) - - avg = total / runs - variance = sum(map(lambda x: (x - avg) ** 2, durations)) - std = math.sqrt(variance / (runs - 1)) - err = std / math.sqrt(runs) - - return Stats( - runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) - ) - - -def _clone_data(data): - """ - Recursively goes through data and clones all tensors. - """ - if isinstance(data, tuple): - return tuple(_clone_data(x) for x in data) - elif isinstance(data, list): - return [_clone_data(x) for x in data] - elif isinstance(data, dict): - return {k: _clone_data(v) for k, v in data.items()} - elif isinstance(data, torch.Tensor): - return data.clone() - else: - return data - - -def _run_single_test(test: TestCase): - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - - data = generate_input(**test.args) - torch.cuda.synchronize() - try: - submission_output = custom_kernel(_clone_data(data)) - - except OpError as E: - print(f"Encountered {E}", file=sys.stderr) - return False, str(E) - torch.cuda.synchronize() - return check_implementation(data, submission_output) - - -def run_single_test(pool: multiprocessing.Pool, test: TestCase): - """ - Runs a single test in another process. - """ - return pool.apply(_run_single_test, (test,)) - - -def run_testing( - logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] -): - """ - Executes the actual test case code and checks for correctness. - - @param logger: A PopcornOutput object used for logging test results. - @param tests: A list of TestCase objects representing the test cases to be executed. - @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. - """ - # Step 1: Compile kernel once before running tests - logger.log("compile", "start") - compile_success, compile_error = pool.apply(_compile_kernel_once) - if not compile_success: - logger.log("compile", "fail") - logger.log("compile.error", compile_error) - return 112 - logger.log("compile", "pass") - - # Step 2: Run all tests with compiled kernel - passed = True - logger.log("test-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"test.{idx}.spec", test.spec) - good, message = run_single_test(pool, test) - if not good: - logger.log(f"test.{idx}.status", "fail") - logger.log(f"test.{idx}.error", message) - passed = False - else: - logger.log(f"test.{idx}.status", "pass") - if message: - logger.log(f"test.{idx}.message", message) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def _compile_kernel_once(): - """ - Compile the kernel once before any benchmarking. - This ensures compilation time is not included in benchmark results. - """ - from submission import compile_kernel - - try: - # Trigger compilation (will be cached) - compile_kernel() - torch.cuda.synchronize() - return True, None - except OpError as E: - return False, f"Compilation failed: {E}" - except Exception as E: - return False, f"Compilation failed: {E}" - - -def _run_single_benchmark( - test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float -) -> Stats | Any: - """ - Runs one benchmark. Do not call directly. - """ - from submission import custom_kernel, compile_kernel - - durations = [] - # generate input data once - data = generate_input(**test.args) - check_copy = _clone_data(data) - - # Ensure kernel is compiled before any timing (compilation is cached) - try: - compile_kernel() - torch.cuda.synchronize() - except OpError as E: - return f"Compilation failed: {E}" - except Exception as E: - return f"Compilation failed: {E}" - - # first, one obligatory correctness check - try: - output = custom_kernel(_clone_data(data)) - except OpError as E: - return f"Encountered {E}" - good, message = check_implementation(check_copy, output) - if not good: - return message - - # now, do multiple timing runs without further correctness testing - # there is an upper bound of 200 runs, and a lower bound of 3 runs; - # otherwise, we repeat until we either measure at least 10 full seconds, - # or the relative error of the mean is below 1%. - - bm_start_time = time.perf_counter_ns() - for i in range(max_repeats): - if recheck: - # ensure we use a different seed for every benchmark - if "seed" in test.args: - test.args["seed"] += 13 - - data = generate_input(**test.args) - check_copy = _clone_data(data) - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - clear_l2_cache() - - start_event.record() - output = custom_kernel(data) - end_event.record() - torch.cuda.synchronize() - duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns - - if recheck: - good, message = check_implementation(check_copy, output) - if not good: - return message - - del output - durations.append(duration) - - if i > 1: - total_bm_duration = time.perf_counter_ns() - bm_start_time - stats = calculate_stats(durations) - # stop if either - # a) relative error dips below 0.1% - # b) we exceed the total time limit for benchmarking the kernel - # c) we exceed 2 minutes of total wallclock time. - if ( - stats.err / stats.mean < 0.001 - or stats.mean * stats.runs > max_time_ns - or total_bm_duration > 120e9 - ): - break - - return calculate_stats(durations) - - -def run_single_benchmark( - pool: multiprocessing.Pool, - test: TestCase, - recheck: bool, - max_repeats: int, - max_time_ns: float, -): - """ - For a particular test case, check correctness (if applicable) and grab runtime results. - - @param pool: Process on which the benchmark will be launched. - @param test: TestCase object. - @param recheck: Flag for whether to explicitly check functional correctness. - @param max_repeats: Number of trials to repeat. - @param max_time_ns: Timeout time in nanoseconds. - @return: A Stats object for this particular benchmark case or an error if the test fails. - """ - return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) - - -def run_benchmarking( - logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] -): - """ - Executes benchmarking code for a CUDA Kernel and logs runtimes. - - @param logger: A PopcornOutput object used for logging benchmark results. - @param pool: Process on which the benchmarks will be launched. - @param tests: A list of TestCase objects representing the test cases to be benchmarked. - @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. - """ - # Step 1: Compile kernel once (outside of timing) - logger.log("compile", "start") - compile_success, compile_error = pool.apply(_compile_kernel_once) - if not compile_success: - logger.log("compile", "fail") - logger.log("compile.error", compile_error) - return 112 - logger.log("compile", "pass") - - # Step 2: Warm up with compiled kernel - run_single_benchmark(pool, tests[0], False, 200, 10e7) - - # Step 3: Run benchmarks (compilation time excluded) - passed = True - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - result = run_single_benchmark(pool, test, False, 200, 10e9) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) - else: - passed = False - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", result) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def run_single_profile(test: TestCase) -> str: - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - from torch.profiler import profile, record_function, ProfilerActivity - - data = generate_input(**test.args) - torch.cuda.synchronize() - - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: - submission_output = custom_kernel(_clone_data(data)) - torch.cuda.synchronize() - return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) - - -def run_profiling(logger: PopcornOutput, tests: list[TestCase]): - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - report = run_single_profile(test) - logger.log( - f"benchmark.{idx}.report", - base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), - ) - logger.log("check", "pass") - return 0 - - -def main(): - fd = os.getenv("POPCORN_FD") - if not fd: - return 111 - - if len(sys.argv) < 3: - return 2 - - mode = sys.argv[1] - seed = os.getenv("POPCORN_SEED") - os.unsetenv("POPCORN_SEED") - seed = int(seed) if seed else None - set_seed(seed or 42) - - filename = None - - with tempfile.NamedTemporaryFile(delete=False) as tmp: - - def build_test_string(tests: list[dict]): - as_str = "" - for test in tests: - kvs = [] - for k, v in test.items(): - kvs.append(f"{k}: {v}") - as_str += "; ".join(kvs) + "\n" - return as_str - - import yaml - - yaml_content = yaml.safe_load(open(sys.argv[2], "r")) - if mode == "test": - tests_str = build_test_string(yaml_content.get("tests", [])) - elif mode in ("benchmark", "leaderboard", "profile"): - tests_str = build_test_string(yaml_content.get("benchmarks", [])) - - tmp.write(tests_str.encode("utf-8")) - tmp.flush() - filename = tmp.name - - tests = get_test_cases(filename, seed) - - os.unlink(filename) - - with PopcornOutput(int(fd)) as logger: - import multiprocessing - - mp_context = multiprocessing.get_context("spawn") - with mp_context.Pool(1) as pool: - if mode == "test": - return run_testing(logger, pool, tests) - if mode == "benchmark": - return run_benchmarking(logger, pool, tests) - - if mode == "leaderboard": - # Step 1: Compile kernel once (outside of timing) - logger.log("compile", "start") - compile_success, compile_error = pool.apply(_compile_kernel_once) - if not compile_success: - logger.log("compile", "fail") - logger.log("compile.error", compile_error) - return 112 - logger.log("compile", "pass") - - # Step 2: Warmup with compiled kernel - run_single_benchmark(pool, tests[0], False, 200, 1e7) - - # Step 3: Run leaderboard benchmarks (compilation time excluded) - logger.log("benchmark-count", len(tests)) - passed = True - for i in range(len(tests)): - result = run_single_benchmark(pool, tests[i], True, 200, 30e9) - logger.log(f"benchmark.{i}.spec", tests[i].spec) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log( - f"benchmark.{i}.{field.name}", - getattr(result, field.name), - ) - else: - passed = False - logger.log(f"benchmark.{i}.status", "fail") - logger.log( - f"benchmark.{i}.error", str(result) - ) # TODO: Make sure result implements __str__? - break - - logger.log("check", "pass" if passed else "fail") - elif mode == "profile": - run_profiling(logger, tests) - else: - # TODO: Implement script mode - return 2 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/problems/nvidia/nvfp4_gemv/log b/problems/nvidia/nvfp4_gemv/log deleted file mode 100644 index 901787e..0000000 --- a/problems/nvidia/nvfp4_gemv/log +++ /dev/null @@ -1,2332 +0,0 @@ -a_ptr : raw_ptr(0x00007f417f608200: f4E2M1FN, gmem, align<16>) -b_ptr : raw_ptr(0x00007f417f60c200: f4E2M1FN, gmem, align<16>) -sfa_ptr : raw_ptr(0x00007f417f610400: f8E4M3FN, gmem, align<32>) -sfb_ptr : raw_ptr(0x00007f417f610c00: f8E4M3FN, gmem, align<32>) -c_ptr : raw_ptr(0x00007f417f610200: f16, gmem, align<16>) -problem_size : (128,1,256,1) -res[0] = 3.000000 - -res[0] = 4.250000 - -res[0] = 4.500000 - -res[0] = 7.500000 - -a.shape : torch.Size([128, 128, 1]) -b.shape : torch.Size([128, 128, 1]) -sfa_cpu.shape : torch.Size([128, 16, 1]) -sfb_cpu.shape : torch.Size([128, 16, 1]) -sfa_reordered_cpu.shape : torch.Size([32, 4, 1, 4, 4, 1]) -sfb_reordered_cpu.shape : torch.Size([32, 4, 1, 4, 4, 1]) -c.shape : torch.Size([128, 1, 1]) -a_ptr : 139919286632960 -b_ptr : 139919286649344 -sfa_ptr : 139919286666240 -sfb_ptr : 139919286668288 -c_ptr : 139919286665728 -problem_size : (128, 1, 256, 1) -c_cute[0, 0, 0] = 7.5 -c_cute[1, 0, 0] = 10.25 -c_cute[2, 0, 0] = 12.25 -c_cute[3, 0, 0] = 15.25 -c_cute[4, 0, 0] = 13.25 -c_cute[5, 0, 0] = 17.25 -c_cute[6, 0, 0] = 15.25 -c_cute[7, 0, 0] = 15.5 -c_cute[8, 0, 0] = 18.0 -c_cute[9, 0, 0] = 12.25 -c_cute[10, 0, 0] = 14.25 -c_cute[11, 0, 0] = 11.5 -c_cute[12, 0, 0] = 15.0 -c_cute[13, 0, 0] = 14.0 -c_cute[14, 0, 0] = 17.0 -c_cute[15, 0, 0] = 13.25 -c_cute[16, 0, 0] = 19.25 -c_cute[17, 0, 0] = 12.75 -c_cute[18, 0, 0] = 12.5 -c_cute[19, 0, 0] = 17.0 -c_cute[20, 0, 0] = 14.25 -c_cute[21, 0, 0] = 16.25 -c_cute[22, 0, 0] = 18.5 -c_cute[23, 0, 0] = 12.0 -c_cute[24, 0, 0] = 17.25 -c_cute[25, 0, 0] = 13.0 -c_cute[26, 0, 0] = 18.25 -c_cute[27, 0, 0] = 17.0 -c_cute[28, 0, 0] = 10.25 -c_cute[29, 0, 0] = 12.75 -c_cute[30, 0, 0] = 17.5 -c_cute[31, 0, 0] = 19.0 -c_cute[32, 0, 0] = 13.5 -c_cute[33, 0, 0] = 14.75 -c_cute[34, 0, 0] = 14.75 -c_cute[35, 0, 0] = 17.25 -c_cute[36, 0, 0] = 15.25 -c_cute[37, 0, 0] = 18.0 -c_cute[38, 0, 0] = 19.25 -c_cute[39, 0, 0] = 13.75 -c_cute[40, 0, 0] = 15.75 -c_cute[41, 0, 0] = 13.5 -c_cute[42, 0, 0] = 12.0 -c_cute[43, 0, 0] = 16.75 -c_cute[44, 0, 0] = 18.75 -c_cute[45, 0, 0] = 12.75 -c_cute[46, 0, 0] = 10.5 -c_cute[47, 0, 0] = 9.25 -c_cute[48, 0, 0] = 12.5 -c_cute[49, 0, 0] = 14.5 -c_cute[50, 0, 0] = 13.25 -c_cute[51, 0, 0] = 17.25 -c_cute[52, 0, 0] = 14.75 -c_cute[53, 0, 0] = 13.75 -c_cute[54, 0, 0] = 13.5 -c_cute[55, 0, 0] = 12.5 -c_cute[56, 0, 0] = 9.75 -c_cute[57, 0, 0] = 11.0 -c_cute[58, 0, 0] = 16.75 -c_cute[59, 0, 0] = 14.0 -c_cute[60, 0, 0] = 16.0 -c_cute[61, 0, 0] = 13.0 -c_cute[62, 0, 0] = 14.75 -c_cute[63, 0, 0] = 14.75 -c_cute[64, 0, 0] = 13.25 -c_cute[65, 0, 0] = 18.0 -c_cute[66, 0, 0] = 15.0 -c_cute[67, 0, 0] = 13.75 -c_cute[68, 0, 0] = 12.5 -c_cute[69, 0, 0] = 15.75 -c_cute[70, 0, 0] = 10.5 -c_cute[71, 0, 0] = 16.25 -c_cute[72, 0, 0] = 16.25 -c_cute[73, 0, 0] = 14.5 -c_cute[74, 0, 0] = 16.0 -c_cute[75, 0, 0] = 17.0 -c_cute[76, 0, 0] = 17.25 -c_cute[77, 0, 0] = 10.5 -c_cute[78, 0, 0] = 12.5 -c_cute[79, 0, 0] = 13.0 -c_cute[80, 0, 0] = 12.5 -c_cute[81, 0, 0] = 11.0 -c_cute[82, 0, 0] = 15.0 -c_cute[83, 0, 0] = 13.75 -c_cute[84, 0, 0] = 12.25 -c_cute[85, 0, 0] = 13.25 -c_cute[86, 0, 0] = 13.75 -c_cute[87, 0, 0] = 17.0 -c_cute[88, 0, 0] = 14.0 -c_cute[89, 0, 0] = 13.0 -c_cute[90, 0, 0] = 14.25 -c_cute[91, 0, 0] = 15.75 -c_cute[92, 0, 0] = 9.5 -c_cute[93, 0, 0] = 13.0 -c_cute[94, 0, 0] = 11.0 -c_cute[95, 0, 0] = 13.75 -c_cute[96, 0, 0] = 15.25 -c_cute[97, 0, 0] = 12.75 -c_cute[98, 0, 0] = 14.5 -c_cute[99, 0, 0] = 13.0 -c_cute[100, 0, 0] = 11.75 -c_cute[101, 0, 0] = 12.0 -c_cute[102, 0, 0] = 18.0 -c_cute[103, 0, 0] = 15.5 -c_cute[104, 0, 0] = 12.75 -c_cute[105, 0, 0] = 12.5 -c_cute[106, 0, 0] = 14.75 -c_cute[107, 0, 0] = 16.75 -c_cute[108, 0, 0] = 13.5 -c_cute[109, 0, 0] = 15.25 -c_cute[110, 0, 0] = 13.5 -c_cute[111, 0, 0] = 11.75 -c_cute[112, 0, 0] = 17.25 -c_cute[113, 0, 0] = 16.25 -c_cute[114, 0, 0] = 11.25 -c_cute[115, 0, 0] = 10.75 -c_cute[116, 0, 0] = 13.5 -c_cute[117, 0, 0] = 11.5 -c_cute[118, 0, 0] = 15.5 -c_cute[119, 0, 0] = 17.25 -c_cute[120, 0, 0] = 14.75 -c_cute[121, 0, 0] = 17.0 -c_cute[122, 0, 0] = 15.5 -c_cute[123, 0, 0] = 14.75 -c_cute[124, 0, 0] = 18.0 -c_cute[125, 0, 0] = 13.0 -c_cute[126, 0, 0] = 15.5 -c_cute[127, 0, 0] = 14.75 --------------------------------- -sfa_ref_cpu[0] = 1.0 -sfa_ref_cpu[1] = 2.0 -sfa_ref_cpu[2] = 2.0 -sfa_ref_cpu[3] = 2.0 -sfa_ref_cpu[4] = 1.0 -sfa_ref_cpu[5] = 1.0 -sfa_ref_cpu[6] = 1.0 -sfa_ref_cpu[7] = 1.0 -sfa_ref_cpu[8] = 1.0 -sfa_ref_cpu[9] = 1.0 -sfa_ref_cpu[10] = 2.0 -sfa_ref_cpu[11] = 1.0 -sfa_ref_cpu[12] = 1.0 -sfa_ref_cpu[13] = 1.0 -sfa_ref_cpu[14] = 1.0 -sfa_ref_cpu[15] = 1.0 -sfa_ref_cpu[16] = 1.0 -sfa_ref_cpu[17] = 1.0 -sfa_ref_cpu[18] = 1.0 -sfa_ref_cpu[19] = 2.0 -sfa_ref_cpu[20] = 1.0 -sfa_ref_cpu[21] = 2.0 -sfa_ref_cpu[22] = 2.0 -sfa_ref_cpu[23] = 1.0 -sfa_ref_cpu[24] = 1.0 -sfa_ref_cpu[25] = 1.0 -sfa_ref_cpu[26] = 1.0 -sfa_ref_cpu[27] = 1.0 -sfa_ref_cpu[28] = 1.0 -sfa_ref_cpu[29] = 2.0 -sfa_ref_cpu[30] = 1.0 -sfa_ref_cpu[31] = 2.0 -sfa_ref_cpu[32] = 2.0 -sfa_ref_cpu[33] = 2.0 -sfa_ref_cpu[34] = 1.0 -sfa_ref_cpu[35] = 1.0 -sfa_ref_cpu[36] = 1.0 -sfa_ref_cpu[37] = 1.0 -sfa_ref_cpu[38] = 2.0 -sfa_ref_cpu[39] = 2.0 -sfa_ref_cpu[40] = 2.0 -sfa_ref_cpu[41] = 2.0 -sfa_ref_cpu[42] = 2.0 -sfa_ref_cpu[43] = 1.0 -sfa_ref_cpu[44] = 1.0 -sfa_ref_cpu[45] = 1.0 -sfa_ref_cpu[46] = 2.0 -sfa_ref_cpu[47] = 2.0 -sfa_ref_cpu[48] = 2.0 -sfa_ref_cpu[49] = 2.0 -sfa_ref_cpu[50] = 2.0 -sfa_ref_cpu[51] = 2.0 -sfa_ref_cpu[52] = 2.0 -sfa_ref_cpu[53] = 1.0 -sfa_ref_cpu[54] = 2.0 -sfa_ref_cpu[55] = 1.0 -sfa_ref_cpu[56] = 1.0 -sfa_ref_cpu[57] = 1.0 -sfa_ref_cpu[58] = 1.0 -sfa_ref_cpu[59] = 1.0 -sfa_ref_cpu[60] = 2.0 -sfa_ref_cpu[61] = 2.0 -sfa_ref_cpu[62] = 2.0 -sfa_ref_cpu[63] = 2.0 -sfa_ref_cpu[64] = 1.0 -sfa_ref_cpu[65] = 1.0 -sfa_ref_cpu[66] = 1.0 -sfa_ref_cpu[67] = 2.0 -sfa_ref_cpu[68] = 2.0 -sfa_ref_cpu[69] = 2.0 -sfa_ref_cpu[70] = 2.0 -sfa_ref_cpu[71] = 1.0 -sfa_ref_cpu[72] = 2.0 -sfa_ref_cpu[73] = 1.0 -sfa_ref_cpu[74] = 2.0 -sfa_ref_cpu[75] = 2.0 -sfa_ref_cpu[76] = 1.0 -sfa_ref_cpu[77] = 2.0 -sfa_ref_cpu[78] = 1.0 -sfa_ref_cpu[79] = 2.0 -sfa_ref_cpu[80] = 2.0 -sfa_ref_cpu[81] = 2.0 -sfa_ref_cpu[82] = 2.0 -sfa_ref_cpu[83] = 1.0 -sfa_ref_cpu[84] = 2.0 -sfa_ref_cpu[85] = 1.0 -sfa_ref_cpu[86] = 1.0 -sfa_ref_cpu[87] = 1.0 -sfa_ref_cpu[88] = 2.0 -sfa_ref_cpu[89] = 2.0 -sfa_ref_cpu[90] = 2.0 -sfa_ref_cpu[91] = 1.0 -sfa_ref_cpu[92] = 2.0 -sfa_ref_cpu[93] = 2.0 -sfa_ref_cpu[94] = 1.0 -sfa_ref_cpu[95] = 2.0 -sfa_ref_cpu[96] = 2.0 -sfa_ref_cpu[97] = 2.0 -sfa_ref_cpu[98] = 2.0 -sfa_ref_cpu[99] = 2.0 -sfa_ref_cpu[100] = 2.0 -sfa_ref_cpu[101] = 1.0 -sfa_ref_cpu[102] = 1.0 -sfa_ref_cpu[103] = 1.0 -sfa_ref_cpu[104] = 1.0 -sfa_ref_cpu[105] = 2.0 -sfa_ref_cpu[106] = 1.0 -sfa_ref_cpu[107] = 2.0 -sfa_ref_cpu[108] = 1.0 -sfa_ref_cpu[109] = 2.0 -sfa_ref_cpu[110] = 2.0 -sfa_ref_cpu[111] = 2.0 -sfa_ref_cpu[112] = 1.0 -sfa_ref_cpu[113] = 1.0 -sfa_ref_cpu[114] = 1.0 -sfa_ref_cpu[115] = 2.0 -sfa_ref_cpu[116] = 1.0 -sfa_ref_cpu[117] = 1.0 -sfa_ref_cpu[118] = 2.0 -sfa_ref_cpu[119] = 1.0 -sfa_ref_cpu[120] = 2.0 -sfa_ref_cpu[121] = 1.0 -sfa_ref_cpu[122] = 1.0 -sfa_ref_cpu[123] = 2.0 -sfa_ref_cpu[124] = 2.0 -sfa_ref_cpu[125] = 2.0 -sfa_ref_cpu[126] = 2.0 -sfa_ref_cpu[127] = 1.0 -sfa_ref_cpu[128] = 2.0 -sfa_ref_cpu[129] = 2.0 -sfa_ref_cpu[130] = 2.0 -sfa_ref_cpu[131] = 1.0 -sfa_ref_cpu[132] = 2.0 -sfa_ref_cpu[133] = 2.0 -sfa_ref_cpu[134] = 1.0 -sfa_ref_cpu[135] = 2.0 -sfa_ref_cpu[136] = 1.0 -sfa_ref_cpu[137] = 1.0 -sfa_ref_cpu[138] = 2.0 -sfa_ref_cpu[139] = 1.0 -sfa_ref_cpu[140] = 1.0 -sfa_ref_cpu[141] = 2.0 -sfa_ref_cpu[142] = 1.0 -sfa_ref_cpu[143] = 1.0 -sfa_ref_cpu[144] = 1.0 -sfa_ref_cpu[145] = 2.0 -sfa_ref_cpu[146] = 1.0 -sfa_ref_cpu[147] = 2.0 -sfa_ref_cpu[148] = 2.0 -sfa_ref_cpu[149] = 2.0 -sfa_ref_cpu[150] = 2.0 -sfa_ref_cpu[151] = 1.0 -sfa_ref_cpu[152] = 2.0 -sfa_ref_cpu[153] = 2.0 -sfa_ref_cpu[154] = 2.0 -sfa_ref_cpu[155] = 1.0 -sfa_ref_cpu[156] = 1.0 -sfa_ref_cpu[157] = 1.0 -sfa_ref_cpu[158] = 1.0 -sfa_ref_cpu[159] = 1.0 -sfa_ref_cpu[160] = 1.0 -sfa_ref_cpu[161] = 1.0 -sfa_ref_cpu[162] = 1.0 -sfa_ref_cpu[163] = 2.0 -sfa_ref_cpu[164] = 2.0 -sfa_ref_cpu[165] = 2.0 -sfa_ref_cpu[166] = 1.0 -sfa_ref_cpu[167] = 1.0 -sfa_ref_cpu[168] = 2.0 -sfa_ref_cpu[169] = 1.0 -sfa_ref_cpu[170] = 1.0 -sfa_ref_cpu[171] = 2.0 -sfa_ref_cpu[172] = 1.0 -sfa_ref_cpu[173] = 1.0 -sfa_ref_cpu[174] = 2.0 -sfa_ref_cpu[175] = 2.0 -sfa_ref_cpu[176] = 1.0 -sfa_ref_cpu[177] = 1.0 -sfa_ref_cpu[178] = 1.0 -sfa_ref_cpu[179] = 2.0 -sfa_ref_cpu[180] = 1.0 -sfa_ref_cpu[181] = 1.0 -sfa_ref_cpu[182] = 1.0 -sfa_ref_cpu[183] = 1.0 -sfa_ref_cpu[184] = 2.0 -sfa_ref_cpu[185] = 1.0 -sfa_ref_cpu[186] = 1.0 -sfa_ref_cpu[187] = 1.0 -sfa_ref_cpu[188] = 2.0 -sfa_ref_cpu[189] = 1.0 -sfa_ref_cpu[190] = 2.0 -sfa_ref_cpu[191] = 2.0 -sfa_ref_cpu[192] = 2.0 -sfa_ref_cpu[193] = 1.0 -sfa_ref_cpu[194] = 2.0 -sfa_ref_cpu[195] = 2.0 -sfa_ref_cpu[196] = 1.0 -sfa_ref_cpu[197] = 2.0 -sfa_ref_cpu[198] = 2.0 -sfa_ref_cpu[199] = 2.0 -sfa_ref_cpu[200] = 1.0 -sfa_ref_cpu[201] = 2.0 -sfa_ref_cpu[202] = 2.0 -sfa_ref_cpu[203] = 2.0 -sfa_ref_cpu[204] = 1.0 -sfa_ref_cpu[205] = 1.0 -sfa_ref_cpu[206] = 2.0 -sfa_ref_cpu[207] = 2.0 -sfa_ref_cpu[208] = 2.0 -sfa_ref_cpu[209] = 2.0 -sfa_ref_cpu[210] = 1.0 -sfa_ref_cpu[211] = 2.0 -sfa_ref_cpu[212] = 2.0 -sfa_ref_cpu[213] = 1.0 -sfa_ref_cpu[214] = 2.0 -sfa_ref_cpu[215] = 1.0 -sfa_ref_cpu[216] = 2.0 -sfa_ref_cpu[217] = 2.0 -sfa_ref_cpu[218] = 1.0 -sfa_ref_cpu[219] = 1.0 -sfa_ref_cpu[220] = 1.0 -sfa_ref_cpu[221] = 2.0 -sfa_ref_cpu[222] = 1.0 -sfa_ref_cpu[223] = 1.0 -sfa_ref_cpu[224] = 2.0 -sfa_ref_cpu[225] = 1.0 -sfa_ref_cpu[226] = 1.0 -sfa_ref_cpu[227] = 2.0 -sfa_ref_cpu[228] = 1.0 -sfa_ref_cpu[229] = 1.0 -sfa_ref_cpu[230] = 1.0 -sfa_ref_cpu[231] = 1.0 -sfa_ref_cpu[232] = 2.0 -sfa_ref_cpu[233] = 2.0 -sfa_ref_cpu[234] = 2.0 -sfa_ref_cpu[235] = 2.0 -sfa_ref_cpu[236] = 2.0 -sfa_ref_cpu[237] = 2.0 -sfa_ref_cpu[238] = 2.0 -sfa_ref_cpu[239] = 2.0 -sfa_ref_cpu[240] = 2.0 -sfa_ref_cpu[241] = 1.0 -sfa_ref_cpu[242] = 2.0 -sfa_ref_cpu[243] = 1.0 -sfa_ref_cpu[244] = 2.0 -sfa_ref_cpu[245] = 1.0 -sfa_ref_cpu[246] = 1.0 -sfa_ref_cpu[247] = 2.0 -sfa_ref_cpu[248] = 1.0 -sfa_ref_cpu[249] = 1.0 -sfa_ref_cpu[250] = 1.0 -sfa_ref_cpu[251] = 2.0 -sfa_ref_cpu[252] = 1.0 -sfa_ref_cpu[253] = 2.0 -sfa_ref_cpu[254] = 2.0 -sfa_ref_cpu[255] = 1.0 -sfa_ref_cpu[256] = 1.0 -sfa_ref_cpu[257] = 2.0 -sfa_ref_cpu[258] = 2.0 -sfa_ref_cpu[259] = 1.0 -sfa_ref_cpu[260] = 2.0 -sfa_ref_cpu[261] = 1.0 -sfa_ref_cpu[262] = 2.0 -sfa_ref_cpu[263] = 2.0 -sfa_ref_cpu[264] = 1.0 -sfa_ref_cpu[265] = 2.0 -sfa_ref_cpu[266] = 2.0 -sfa_ref_cpu[267] = 2.0 -sfa_ref_cpu[268] = 1.0 -sfa_ref_cpu[269] = 2.0 -sfa_ref_cpu[270] = 2.0 -sfa_ref_cpu[271] = 2.0 -sfa_ref_cpu[272] = 1.0 -sfa_ref_cpu[273] = 1.0 -sfa_ref_cpu[274] = 1.0 -sfa_ref_cpu[275] = 1.0 -sfa_ref_cpu[276] = 1.0 -sfa_ref_cpu[277] = 2.0 -sfa_ref_cpu[278] = 2.0 -sfa_ref_cpu[279] = 2.0 -sfa_ref_cpu[280] = 1.0 -sfa_ref_cpu[281] = 1.0 -sfa_ref_cpu[282] = 2.0 -sfa_ref_cpu[283] = 1.0 -sfa_ref_cpu[284] = 1.0 -sfa_ref_cpu[285] = 2.0 -sfa_ref_cpu[286] = 2.0 -sfa_ref_cpu[287] = 1.0 -sfa_ref_cpu[288] = 1.0 -sfa_ref_cpu[289] = 2.0 -sfa_ref_cpu[290] = 1.0 -sfa_ref_cpu[291] = 2.0 -sfa_ref_cpu[292] = 1.0 -sfa_ref_cpu[293] = 1.0 -sfa_ref_cpu[294] = 1.0 -sfa_ref_cpu[295] = 1.0 -sfa_ref_cpu[296] = 2.0 -sfa_ref_cpu[297] = 2.0 -sfa_ref_cpu[298] = 2.0 -sfa_ref_cpu[299] = 1.0 -sfa_ref_cpu[300] = 1.0 -sfa_ref_cpu[301] = 2.0 -sfa_ref_cpu[302] = 1.0 -sfa_ref_cpu[303] = 1.0 -sfa_ref_cpu[304] = 2.0 -sfa_ref_cpu[305] = 1.0 -sfa_ref_cpu[306] = 2.0 -sfa_ref_cpu[307] = 1.0 -sfa_ref_cpu[308] = 1.0 -sfa_ref_cpu[309] = 2.0 -sfa_ref_cpu[310] = 1.0 -sfa_ref_cpu[311] = 2.0 -sfa_ref_cpu[312] = 1.0 -sfa_ref_cpu[313] = 2.0 -sfa_ref_cpu[314] = 2.0 -sfa_ref_cpu[315] = 2.0 -sfa_ref_cpu[316] = 2.0 -sfa_ref_cpu[317] = 1.0 -sfa_ref_cpu[318] = 1.0 -sfa_ref_cpu[319] = 2.0 -sfa_ref_cpu[320] = 1.0 -sfa_ref_cpu[321] = 2.0 -sfa_ref_cpu[322] = 1.0 -sfa_ref_cpu[323] = 1.0 -sfa_ref_cpu[324] = 2.0 -sfa_ref_cpu[325] = 1.0 -sfa_ref_cpu[326] = 2.0 -sfa_ref_cpu[327] = 1.0 -sfa_ref_cpu[328] = 2.0 -sfa_ref_cpu[329] = 2.0 -sfa_ref_cpu[330] = 1.0 -sfa_ref_cpu[331] = 2.0 -sfa_ref_cpu[332] = 2.0 -sfa_ref_cpu[333] = 1.0 -sfa_ref_cpu[334] = 1.0 -sfa_ref_cpu[335] = 1.0 -sfa_ref_cpu[336] = 1.0 -sfa_ref_cpu[337] = 1.0 -sfa_ref_cpu[338] = 1.0 -sfa_ref_cpu[339] = 2.0 -sfa_ref_cpu[340] = 2.0 -sfa_ref_cpu[341] = 2.0 -sfa_ref_cpu[342] = 2.0 -sfa_ref_cpu[343] = 2.0 -sfa_ref_cpu[344] = 1.0 -sfa_ref_cpu[345] = 1.0 -sfa_ref_cpu[346] = 1.0 -sfa_ref_cpu[347] = 2.0 -sfa_ref_cpu[348] = 2.0 -sfa_ref_cpu[349] = 2.0 -sfa_ref_cpu[350] = 1.0 -sfa_ref_cpu[351] = 1.0 -sfa_ref_cpu[352] = 2.0 -sfa_ref_cpu[353] = 2.0 -sfa_ref_cpu[354] = 2.0 -sfa_ref_cpu[355] = 2.0 -sfa_ref_cpu[356] = 2.0 -sfa_ref_cpu[357] = 1.0 -sfa_ref_cpu[358] = 1.0 -sfa_ref_cpu[359] = 2.0 -sfa_ref_cpu[360] = 1.0 -sfa_ref_cpu[361] = 1.0 -sfa_ref_cpu[362] = 2.0 -sfa_ref_cpu[363] = 1.0 -sfa_ref_cpu[364] = 2.0 -sfa_ref_cpu[365] = 2.0 -sfa_ref_cpu[366] = 2.0 -sfa_ref_cpu[367] = 1.0 -sfa_ref_cpu[368] = 1.0 -sfa_ref_cpu[369] = 1.0 -sfa_ref_cpu[370] = 2.0 -sfa_ref_cpu[371] = 1.0 -sfa_ref_cpu[372] = 1.0 -sfa_ref_cpu[373] = 2.0 -sfa_ref_cpu[374] = 2.0 -sfa_ref_cpu[375] = 2.0 -sfa_ref_cpu[376] = 1.0 -sfa_ref_cpu[377] = 2.0 -sfa_ref_cpu[378] = 1.0 -sfa_ref_cpu[379] = 1.0 -sfa_ref_cpu[380] = 2.0 -sfa_ref_cpu[381] = 1.0 -sfa_ref_cpu[382] = 2.0 -sfa_ref_cpu[383] = 2.0 -sfa_ref_cpu[384] = 2.0 -sfa_ref_cpu[385] = 2.0 -sfa_ref_cpu[386] = 2.0 -sfa_ref_cpu[387] = 2.0 -sfa_ref_cpu[388] = 2.0 -sfa_ref_cpu[389] = 1.0 -sfa_ref_cpu[390] = 1.0 -sfa_ref_cpu[391] = 1.0 -sfa_ref_cpu[392] = 1.0 -sfa_ref_cpu[393] = 1.0 -sfa_ref_cpu[394] = 1.0 -sfa_ref_cpu[395] = 2.0 -sfa_ref_cpu[396] = 2.0 -sfa_ref_cpu[397] = 2.0 -sfa_ref_cpu[398] = 2.0 -sfa_ref_cpu[399] = 1.0 -sfa_ref_cpu[400] = 2.0 -sfa_ref_cpu[401] = 1.0 -sfa_ref_cpu[402] = 1.0 -sfa_ref_cpu[403] = 1.0 -sfa_ref_cpu[404] = 1.0 -sfa_ref_cpu[405] = 1.0 -sfa_ref_cpu[406] = 2.0 -sfa_ref_cpu[407] = 1.0 -sfa_ref_cpu[408] = 1.0 -sfa_ref_cpu[409] = 1.0 -sfa_ref_cpu[410] = 1.0 -sfa_ref_cpu[411] = 1.0 -sfa_ref_cpu[412] = 1.0 -sfa_ref_cpu[413] = 2.0 -sfa_ref_cpu[414] = 2.0 -sfa_ref_cpu[415] = 1.0 -sfa_ref_cpu[416] = 2.0 -sfa_ref_cpu[417] = 2.0 -sfa_ref_cpu[418] = 2.0 -sfa_ref_cpu[419] = 1.0 -sfa_ref_cpu[420] = 1.0 -sfa_ref_cpu[421] = 1.0 -sfa_ref_cpu[422] = 2.0 -sfa_ref_cpu[423] = 2.0 -sfa_ref_cpu[424] = 2.0 -sfa_ref_cpu[425] = 2.0 -sfa_ref_cpu[426] = 2.0 -sfa_ref_cpu[427] = 1.0 -sfa_ref_cpu[428] = 2.0 -sfa_ref_cpu[429] = 1.0 -sfa_ref_cpu[430] = 1.0 -sfa_ref_cpu[431] = 1.0 -sfa_ref_cpu[432] = 1.0 -sfa_ref_cpu[433] = 2.0 -sfa_ref_cpu[434] = 1.0 -sfa_ref_cpu[435] = 2.0 -sfa_ref_cpu[436] = 2.0 -sfa_ref_cpu[437] = 1.0 -sfa_ref_cpu[438] = 1.0 -sfa_ref_cpu[439] = 1.0 -sfa_ref_cpu[440] = 2.0 -sfa_ref_cpu[441] = 2.0 -sfa_ref_cpu[442] = 2.0 -sfa_ref_cpu[443] = 2.0 -sfa_ref_cpu[444] = 2.0 -sfa_ref_cpu[445] = 2.0 -sfa_ref_cpu[446] = 2.0 -sfa_ref_cpu[447] = 2.0 -sfa_ref_cpu[448] = 1.0 -sfa_ref_cpu[449] = 2.0 -sfa_ref_cpu[450] = 1.0 -sfa_ref_cpu[451] = 1.0 -sfa_ref_cpu[452] = 2.0 -sfa_ref_cpu[453] = 1.0 -sfa_ref_cpu[454] = 2.0 -sfa_ref_cpu[455] = 1.0 -sfa_ref_cpu[456] = 1.0 -sfa_ref_cpu[457] = 2.0 -sfa_ref_cpu[458] = 1.0 -sfa_ref_cpu[459] = 2.0 -sfa_ref_cpu[460] = 1.0 -sfa_ref_cpu[461] = 2.0 -sfa_ref_cpu[462] = 2.0 -sfa_ref_cpu[463] = 1.0 -sfa_ref_cpu[464] = 1.0 -sfa_ref_cpu[465] = 1.0 -sfa_ref_cpu[466] = 1.0 -sfa_ref_cpu[467] = 1.0 -sfa_ref_cpu[468] = 1.0 -sfa_ref_cpu[469] = 2.0 -sfa_ref_cpu[470] = 1.0 -sfa_ref_cpu[471] = 2.0 -sfa_ref_cpu[472] = 2.0 -sfa_ref_cpu[473] = 1.0 -sfa_ref_cpu[474] = 2.0 -sfa_ref_cpu[475] = 2.0 -sfa_ref_cpu[476] = 1.0 -sfa_ref_cpu[477] = 2.0 -sfa_ref_cpu[478] = 2.0 -sfa_ref_cpu[479] = 1.0 -sfa_ref_cpu[480] = 1.0 -sfa_ref_cpu[481] = 1.0 -sfa_ref_cpu[482] = 1.0 -sfa_ref_cpu[483] = 2.0 -sfa_ref_cpu[484] = 2.0 -sfa_ref_cpu[485] = 1.0 -sfa_ref_cpu[486] = 1.0 -sfa_ref_cpu[487] = 1.0 -sfa_ref_cpu[488] = 2.0 -sfa_ref_cpu[489] = 2.0 -sfa_ref_cpu[490] = 1.0 -sfa_ref_cpu[491] = 2.0 -sfa_ref_cpu[492] = 2.0 -sfa_ref_cpu[493] = 1.0 -sfa_ref_cpu[494] = 1.0 -sfa_ref_cpu[495] = 1.0 -sfa_ref_cpu[496] = 2.0 -sfa_ref_cpu[497] = 2.0 -sfa_ref_cpu[498] = 1.0 -sfa_ref_cpu[499] = 2.0 -sfa_ref_cpu[500] = 1.0 -sfa_ref_cpu[501] = 1.0 -sfa_ref_cpu[502] = 2.0 -sfa_ref_cpu[503] = 1.0 -sfa_ref_cpu[504] = 2.0 -sfa_ref_cpu[505] = 2.0 -sfa_ref_cpu[506] = 2.0 -sfa_ref_cpu[507] = 1.0 -sfa_ref_cpu[508] = 2.0 -sfa_ref_cpu[509] = 2.0 -sfa_ref_cpu[510] = 2.0 -sfa_ref_cpu[511] = 1.0 -sfa_ref_cpu[512] = 2.0 -sfa_ref_cpu[513] = 2.0 -sfa_ref_cpu[514] = 2.0 -sfa_ref_cpu[515] = 1.0 -sfa_ref_cpu[516] = 1.0 -sfa_ref_cpu[517] = 1.0 -sfa_ref_cpu[518] = 2.0 -sfa_ref_cpu[519] = 2.0 -sfa_ref_cpu[520] = 1.0 -sfa_ref_cpu[521] = 1.0 -sfa_ref_cpu[522] = 2.0 -sfa_ref_cpu[523] = 1.0 -sfa_ref_cpu[524] = 2.0 -sfa_ref_cpu[525] = 2.0 -sfa_ref_cpu[526] = 2.0 -sfa_ref_cpu[527] = 2.0 -sfa_ref_cpu[528] = 1.0 -sfa_ref_cpu[529] = 2.0 -sfa_ref_cpu[530] = 2.0 -sfa_ref_cpu[531] = 2.0 -sfa_ref_cpu[532] = 1.0 -sfa_ref_cpu[533] = 2.0 -sfa_ref_cpu[534] = 2.0 -sfa_ref_cpu[535] = 2.0 -sfa_ref_cpu[536] = 1.0 -sfa_ref_cpu[537] = 2.0 -sfa_ref_cpu[538] = 2.0 -sfa_ref_cpu[539] = 1.0 -sfa_ref_cpu[540] = 1.0 -sfa_ref_cpu[541] = 1.0 -sfa_ref_cpu[542] = 1.0 -sfa_ref_cpu[543] = 2.0 -sfa_ref_cpu[544] = 1.0 -sfa_ref_cpu[545] = 2.0 -sfa_ref_cpu[546] = 2.0 -sfa_ref_cpu[547] = 2.0 -sfa_ref_cpu[548] = 1.0 -sfa_ref_cpu[549] = 1.0 -sfa_ref_cpu[550] = 1.0 -sfa_ref_cpu[551] = 1.0 -sfa_ref_cpu[552] = 1.0 -sfa_ref_cpu[553] = 2.0 -sfa_ref_cpu[554] = 2.0 -sfa_ref_cpu[555] = 2.0 -sfa_ref_cpu[556] = 2.0 -sfa_ref_cpu[557] = 1.0 -sfa_ref_cpu[558] = 1.0 -sfa_ref_cpu[559] = 2.0 -sfa_ref_cpu[560] = 1.0 -sfa_ref_cpu[561] = 2.0 -sfa_ref_cpu[562] = 1.0 -sfa_ref_cpu[563] = 1.0 -sfa_ref_cpu[564] = 2.0 -sfa_ref_cpu[565] = 1.0 -sfa_ref_cpu[566] = 2.0 -sfa_ref_cpu[567] = 2.0 -sfa_ref_cpu[568] = 1.0 -sfa_ref_cpu[569] = 2.0 -sfa_ref_cpu[570] = 1.0 -sfa_ref_cpu[571] = 2.0 -sfa_ref_cpu[572] = 1.0 -sfa_ref_cpu[573] = 1.0 -sfa_ref_cpu[574] = 1.0 -sfa_ref_cpu[575] = 2.0 -sfa_ref_cpu[576] = 2.0 -sfa_ref_cpu[577] = 1.0 -sfa_ref_cpu[578] = 1.0 -sfa_ref_cpu[579] = 2.0 -sfa_ref_cpu[580] = 1.0 -sfa_ref_cpu[581] = 1.0 -sfa_ref_cpu[582] = 2.0 -sfa_ref_cpu[583] = 1.0 -sfa_ref_cpu[584] = 2.0 -sfa_ref_cpu[585] = 2.0 -sfa_ref_cpu[586] = 1.0 -sfa_ref_cpu[587] = 1.0 -sfa_ref_cpu[588] = 2.0 -sfa_ref_cpu[589] = 2.0 -sfa_ref_cpu[590] = 2.0 -sfa_ref_cpu[591] = 1.0 -sfa_ref_cpu[592] = 1.0 -sfa_ref_cpu[593] = 1.0 -sfa_ref_cpu[594] = 1.0 -sfa_ref_cpu[595] = 1.0 -sfa_ref_cpu[596] = 2.0 -sfa_ref_cpu[597] = 2.0 -sfa_ref_cpu[598] = 2.0 -sfa_ref_cpu[599] = 2.0 -sfa_ref_cpu[600] = 2.0 -sfa_ref_cpu[601] = 2.0 -sfa_ref_cpu[602] = 2.0 -sfa_ref_cpu[603] = 2.0 -sfa_ref_cpu[604] = 1.0 -sfa_ref_cpu[605] = 2.0 -sfa_ref_cpu[606] = 2.0 -sfa_ref_cpu[607] = 1.0 -sfa_ref_cpu[608] = 2.0 -sfa_ref_cpu[609] = 2.0 -sfa_ref_cpu[610] = 2.0 -sfa_ref_cpu[611] = 1.0 -sfa_ref_cpu[612] = 1.0 -sfa_ref_cpu[613] = 1.0 -sfa_ref_cpu[614] = 2.0 -sfa_ref_cpu[615] = 1.0 -sfa_ref_cpu[616] = 2.0 -sfa_ref_cpu[617] = 2.0 -sfa_ref_cpu[618] = 2.0 -sfa_ref_cpu[619] = 2.0 -sfa_ref_cpu[620] = 1.0 -sfa_ref_cpu[621] = 2.0 -sfa_ref_cpu[622] = 2.0 -sfa_ref_cpu[623] = 2.0 -sfa_ref_cpu[624] = 2.0 -sfa_ref_cpu[625] = 1.0 -sfa_ref_cpu[626] = 1.0 -sfa_ref_cpu[627] = 1.0 -sfa_ref_cpu[628] = 2.0 -sfa_ref_cpu[629] = 1.0 -sfa_ref_cpu[630] = 1.0 -sfa_ref_cpu[631] = 1.0 -sfa_ref_cpu[632] = 1.0 -sfa_ref_cpu[633] = 2.0 -sfa_ref_cpu[634] = 1.0 -sfa_ref_cpu[635] = 2.0 -sfa_ref_cpu[636] = 2.0 -sfa_ref_cpu[637] = 2.0 -sfa_ref_cpu[638] = 1.0 -sfa_ref_cpu[639] = 2.0 -sfa_ref_cpu[640] = 2.0 -sfa_ref_cpu[641] = 1.0 -sfa_ref_cpu[642] = 2.0 -sfa_ref_cpu[643] = 1.0 -sfa_ref_cpu[644] = 1.0 -sfa_ref_cpu[645] = 1.0 -sfa_ref_cpu[646] = 2.0 -sfa_ref_cpu[647] = 2.0 -sfa_ref_cpu[648] = 1.0 -sfa_ref_cpu[649] = 1.0 -sfa_ref_cpu[650] = 2.0 -sfa_ref_cpu[651] = 1.0 -sfa_ref_cpu[652] = 1.0 -sfa_ref_cpu[653] = 1.0 -sfa_ref_cpu[654] = 1.0 -sfa_ref_cpu[655] = 1.0 -sfa_ref_cpu[656] = 1.0 -sfa_ref_cpu[657] = 1.0 -sfa_ref_cpu[658] = 1.0 -sfa_ref_cpu[659] = 1.0 -sfa_ref_cpu[660] = 2.0 -sfa_ref_cpu[661] = 1.0 -sfa_ref_cpu[662] = 2.0 -sfa_ref_cpu[663] = 1.0 -sfa_ref_cpu[664] = 2.0 -sfa_ref_cpu[665] = 1.0 -sfa_ref_cpu[666] = 1.0 -sfa_ref_cpu[667] = 1.0 -sfa_ref_cpu[668] = 1.0 -sfa_ref_cpu[669] = 2.0 -sfa_ref_cpu[670] = 1.0 -sfa_ref_cpu[671] = 1.0 -sfa_ref_cpu[672] = 1.0 -sfa_ref_cpu[673] = 1.0 -sfa_ref_cpu[674] = 2.0 -sfa_ref_cpu[675] = 1.0 -sfa_ref_cpu[676] = 1.0 -sfa_ref_cpu[677] = 1.0 -sfa_ref_cpu[678] = 1.0 -sfa_ref_cpu[679] = 2.0 -sfa_ref_cpu[680] = 2.0 -sfa_ref_cpu[681] = 1.0 -sfa_ref_cpu[682] = 1.0 -sfa_ref_cpu[683] = 1.0 -sfa_ref_cpu[684] = 2.0 -sfa_ref_cpu[685] = 2.0 -sfa_ref_cpu[686] = 2.0 -sfa_ref_cpu[687] = 1.0 -sfa_ref_cpu[688] = 1.0 -sfa_ref_cpu[689] = 2.0 -sfa_ref_cpu[690] = 2.0 -sfa_ref_cpu[691] = 1.0 -sfa_ref_cpu[692] = 2.0 -sfa_ref_cpu[693] = 1.0 -sfa_ref_cpu[694] = 2.0 -sfa_ref_cpu[695] = 1.0 -sfa_ref_cpu[696] = 1.0 -sfa_ref_cpu[697] = 2.0 -sfa_ref_cpu[698] = 1.0 -sfa_ref_cpu[699] = 1.0 -sfa_ref_cpu[700] = 2.0 -sfa_ref_cpu[701] = 2.0 -sfa_ref_cpu[702] = 1.0 -sfa_ref_cpu[703] = 2.0 -sfa_ref_cpu[704] = 2.0 -sfa_ref_cpu[705] = 2.0 -sfa_ref_cpu[706] = 1.0 -sfa_ref_cpu[707] = 2.0 -sfa_ref_cpu[708] = 2.0 -sfa_ref_cpu[709] = 2.0 -sfa_ref_cpu[710] = 1.0 -sfa_ref_cpu[711] = 2.0 -sfa_ref_cpu[712] = 2.0 -sfa_ref_cpu[713] = 2.0 -sfa_ref_cpu[714] = 2.0 -sfa_ref_cpu[715] = 2.0 -sfa_ref_cpu[716] = 2.0 -sfa_ref_cpu[717] = 2.0 -sfa_ref_cpu[718] = 1.0 -sfa_ref_cpu[719] = 2.0 -sfa_ref_cpu[720] = 1.0 -sfa_ref_cpu[721] = 1.0 -sfa_ref_cpu[722] = 1.0 -sfa_ref_cpu[723] = 1.0 -sfa_ref_cpu[724] = 1.0 -sfa_ref_cpu[725] = 1.0 -sfa_ref_cpu[726] = 1.0 -sfa_ref_cpu[727] = 2.0 -sfa_ref_cpu[728] = 2.0 -sfa_ref_cpu[729] = 2.0 -sfa_ref_cpu[730] = 2.0 -sfa_ref_cpu[731] = 1.0 -sfa_ref_cpu[732] = 1.0 -sfa_ref_cpu[733] = 2.0 -sfa_ref_cpu[734] = 2.0 -sfa_ref_cpu[735] = 1.0 -sfa_ref_cpu[736] = 1.0 -sfa_ref_cpu[737] = 2.0 -sfa_ref_cpu[738] = 2.0 -sfa_ref_cpu[739] = 2.0 -sfa_ref_cpu[740] = 2.0 -sfa_ref_cpu[741] = 2.0 -sfa_ref_cpu[742] = 1.0 -sfa_ref_cpu[743] = 2.0 -sfa_ref_cpu[744] = 2.0 -sfa_ref_cpu[745] = 2.0 -sfa_ref_cpu[746] = 1.0 -sfa_ref_cpu[747] = 1.0 -sfa_ref_cpu[748] = 1.0 -sfa_ref_cpu[749] = 1.0 -sfa_ref_cpu[750] = 2.0 -sfa_ref_cpu[751] = 2.0 -sfa_ref_cpu[752] = 1.0 -sfa_ref_cpu[753] = 1.0 -sfa_ref_cpu[754] = 1.0 -sfa_ref_cpu[755] = 1.0 -sfa_ref_cpu[756] = 1.0 -sfa_ref_cpu[757] = 1.0 -sfa_ref_cpu[758] = 1.0 -sfa_ref_cpu[759] = 1.0 -sfa_ref_cpu[760] = 2.0 -sfa_ref_cpu[761] = 2.0 -sfa_ref_cpu[762] = 2.0 -sfa_ref_cpu[763] = 1.0 -sfa_ref_cpu[764] = 1.0 -sfa_ref_cpu[765] = 2.0 -sfa_ref_cpu[766] = 1.0 -sfa_ref_cpu[767] = 1.0 -sfa_ref_cpu[768] = 2.0 -sfa_ref_cpu[769] = 1.0 -sfa_ref_cpu[770] = 2.0 -sfa_ref_cpu[771] = 2.0 -sfa_ref_cpu[772] = 2.0 -sfa_ref_cpu[773] = 2.0 -sfa_ref_cpu[774] = 2.0 -sfa_ref_cpu[775] = 2.0 -sfa_ref_cpu[776] = 2.0 -sfa_ref_cpu[777] = 2.0 -sfa_ref_cpu[778] = 2.0 -sfa_ref_cpu[779] = 1.0 -sfa_ref_cpu[780] = 1.0 -sfa_ref_cpu[781] = 2.0 -sfa_ref_cpu[782] = 2.0 -sfa_ref_cpu[783] = 1.0 -sfa_ref_cpu[784] = 1.0 -sfa_ref_cpu[785] = 2.0 -sfa_ref_cpu[786] = 2.0 -sfa_ref_cpu[787] = 1.0 -sfa_ref_cpu[788] = 2.0 -sfa_ref_cpu[789] = 2.0 -sfa_ref_cpu[790] = 2.0 -sfa_ref_cpu[791] = 1.0 -sfa_ref_cpu[792] = 1.0 -sfa_ref_cpu[793] = 2.0 -sfa_ref_cpu[794] = 1.0 -sfa_ref_cpu[795] = 1.0 -sfa_ref_cpu[796] = 2.0 -sfa_ref_cpu[797] = 1.0 -sfa_ref_cpu[798] = 2.0 -sfa_ref_cpu[799] = 1.0 -sfa_ref_cpu[800] = 1.0 -sfa_ref_cpu[801] = 2.0 -sfa_ref_cpu[802] = 2.0 -sfa_ref_cpu[803] = 2.0 -sfa_ref_cpu[804] = 1.0 -sfa_ref_cpu[805] = 2.0 -sfa_ref_cpu[806] = 1.0 -sfa_ref_cpu[807] = 2.0 -sfa_ref_cpu[808] = 1.0 -sfa_ref_cpu[809] = 2.0 -sfa_ref_cpu[810] = 2.0 -sfa_ref_cpu[811] = 1.0 -sfa_ref_cpu[812] = 2.0 -sfa_ref_cpu[813] = 2.0 -sfa_ref_cpu[814] = 1.0 -sfa_ref_cpu[815] = 1.0 -sfa_ref_cpu[816] = 1.0 -sfa_ref_cpu[817] = 1.0 -sfa_ref_cpu[818] = 1.0 -sfa_ref_cpu[819] = 2.0 -sfa_ref_cpu[820] = 2.0 -sfa_ref_cpu[821] = 2.0 -sfa_ref_cpu[822] = 1.0 -sfa_ref_cpu[823] = 2.0 -sfa_ref_cpu[824] = 1.0 -sfa_ref_cpu[825] = 1.0 -sfa_ref_cpu[826] = 2.0 -sfa_ref_cpu[827] = 1.0 -sfa_ref_cpu[828] = 2.0 -sfa_ref_cpu[829] = 2.0 -sfa_ref_cpu[830] = 1.0 -sfa_ref_cpu[831] = 2.0 -sfa_ref_cpu[832] = 2.0 -sfa_ref_cpu[833] = 1.0 -sfa_ref_cpu[834] = 2.0 -sfa_ref_cpu[835] = 1.0 -sfa_ref_cpu[836] = 2.0 -sfa_ref_cpu[837] = 2.0 -sfa_ref_cpu[838] = 1.0 -sfa_ref_cpu[839] = 1.0 -sfa_ref_cpu[840] = 1.0 -sfa_ref_cpu[841] = 2.0 -sfa_ref_cpu[842] = 2.0 -sfa_ref_cpu[843] = 1.0 -sfa_ref_cpu[844] = 1.0 -sfa_ref_cpu[845] = 1.0 -sfa_ref_cpu[846] = 2.0 -sfa_ref_cpu[847] = 2.0 -sfa_ref_cpu[848] = 1.0 -sfa_ref_cpu[849] = 1.0 -sfa_ref_cpu[850] = 1.0 -sfa_ref_cpu[851] = 1.0 -sfa_ref_cpu[852] = 1.0 -sfa_ref_cpu[853] = 2.0 -sfa_ref_cpu[854] = 2.0 -sfa_ref_cpu[855] = 1.0 -sfa_ref_cpu[856] = 2.0 -sfa_ref_cpu[857] = 1.0 -sfa_ref_cpu[858] = 1.0 -sfa_ref_cpu[859] = 1.0 -sfa_ref_cpu[860] = 2.0 -sfa_ref_cpu[861] = 1.0 -sfa_ref_cpu[862] = 1.0 -sfa_ref_cpu[863] = 1.0 -sfa_ref_cpu[864] = 2.0 -sfa_ref_cpu[865] = 2.0 -sfa_ref_cpu[866] = 1.0 -sfa_ref_cpu[867] = 2.0 -sfa_ref_cpu[868] = 2.0 -sfa_ref_cpu[869] = 1.0 -sfa_ref_cpu[870] = 1.0 -sfa_ref_cpu[871] = 1.0 -sfa_ref_cpu[872] = 2.0 -sfa_ref_cpu[873] = 2.0 -sfa_ref_cpu[874] = 2.0 -sfa_ref_cpu[875] = 1.0 -sfa_ref_cpu[876] = 1.0 -sfa_ref_cpu[877] = 2.0 -sfa_ref_cpu[878] = 1.0 -sfa_ref_cpu[879] = 2.0 -sfa_ref_cpu[880] = 1.0 -sfa_ref_cpu[881] = 1.0 -sfa_ref_cpu[882] = 1.0 -sfa_ref_cpu[883] = 2.0 -sfa_ref_cpu[884] = 2.0 -sfa_ref_cpu[885] = 2.0 -sfa_ref_cpu[886] = 2.0 -sfa_ref_cpu[887] = 1.0 -sfa_ref_cpu[888] = 2.0 -sfa_ref_cpu[889] = 2.0 -sfa_ref_cpu[890] = 1.0 -sfa_ref_cpu[891] = 1.0 -sfa_ref_cpu[892] = 1.0 -sfa_ref_cpu[893] = 1.0 -sfa_ref_cpu[894] = 2.0 -sfa_ref_cpu[895] = 2.0 -sfa_ref_cpu[896] = 1.0 -sfa_ref_cpu[897] = 2.0 -sfa_ref_cpu[898] = 1.0 -sfa_ref_cpu[899] = 1.0 -sfa_ref_cpu[900] = 2.0 -sfa_ref_cpu[901] = 2.0 -sfa_ref_cpu[902] = 1.0 -sfa_ref_cpu[903] = 1.0 -sfa_ref_cpu[904] = 1.0 -sfa_ref_cpu[905] = 1.0 -sfa_ref_cpu[906] = 1.0 -sfa_ref_cpu[907] = 2.0 -sfa_ref_cpu[908] = 1.0 -sfa_ref_cpu[909] = 1.0 -sfa_ref_cpu[910] = 2.0 -sfa_ref_cpu[911] = 1.0 -sfa_ref_cpu[912] = 1.0 -sfa_ref_cpu[913] = 2.0 -sfa_ref_cpu[914] = 2.0 -sfa_ref_cpu[915] = 2.0 -sfa_ref_cpu[916] = 2.0 -sfa_ref_cpu[917] = 2.0 -sfa_ref_cpu[918] = 1.0 -sfa_ref_cpu[919] = 1.0 -sfa_ref_cpu[920] = 2.0 -sfa_ref_cpu[921] = 1.0 -sfa_ref_cpu[922] = 1.0 -sfa_ref_cpu[923] = 1.0 -sfa_ref_cpu[924] = 2.0 -sfa_ref_cpu[925] = 2.0 -sfa_ref_cpu[926] = 2.0 -sfa_ref_cpu[927] = 1.0 -sfa_ref_cpu[928] = 2.0 -sfa_ref_cpu[929] = 2.0 -sfa_ref_cpu[930] = 2.0 -sfa_ref_cpu[931] = 1.0 -sfa_ref_cpu[932] = 1.0 -sfa_ref_cpu[933] = 1.0 -sfa_ref_cpu[934] = 1.0 -sfa_ref_cpu[935] = 1.0 -sfa_ref_cpu[936] = 2.0 -sfa_ref_cpu[937] = 2.0 -sfa_ref_cpu[938] = 1.0 -sfa_ref_cpu[939] = 2.0 -sfa_ref_cpu[940] = 1.0 -sfa_ref_cpu[941] = 1.0 -sfa_ref_cpu[942] = 2.0 -sfa_ref_cpu[943] = 2.0 -sfa_ref_cpu[944] = 1.0 -sfa_ref_cpu[945] = 2.0 -sfa_ref_cpu[946] = 1.0 -sfa_ref_cpu[947] = 1.0 -sfa_ref_cpu[948] = 1.0 -sfa_ref_cpu[949] = 2.0 -sfa_ref_cpu[950] = 2.0 -sfa_ref_cpu[951] = 2.0 -sfa_ref_cpu[952] = 1.0 -sfa_ref_cpu[953] = 2.0 -sfa_ref_cpu[954] = 1.0 -sfa_ref_cpu[955] = 1.0 -sfa_ref_cpu[956] = 2.0 -sfa_ref_cpu[957] = 2.0 -sfa_ref_cpu[958] = 1.0 -sfa_ref_cpu[959] = 2.0 -sfa_ref_cpu[960] = 1.0 -sfa_ref_cpu[961] = 1.0 -sfa_ref_cpu[962] = 2.0 -sfa_ref_cpu[963] = 2.0 -sfa_ref_cpu[964] = 2.0 -sfa_ref_cpu[965] = 1.0 -sfa_ref_cpu[966] = 2.0 -sfa_ref_cpu[967] = 1.0 -sfa_ref_cpu[968] = 1.0 -sfa_ref_cpu[969] = 1.0 -sfa_ref_cpu[970] = 2.0 -sfa_ref_cpu[971] = 2.0 -sfa_ref_cpu[972] = 1.0 -sfa_ref_cpu[973] = 1.0 -sfa_ref_cpu[974] = 1.0 -sfa_ref_cpu[975] = 2.0 -sfa_ref_cpu[976] = 2.0 -sfa_ref_cpu[977] = 2.0 -sfa_ref_cpu[978] = 1.0 -sfa_ref_cpu[979] = 1.0 -sfa_ref_cpu[980] = 1.0 -sfa_ref_cpu[981] = 2.0 -sfa_ref_cpu[982] = 1.0 -sfa_ref_cpu[983] = 2.0 -sfa_ref_cpu[984] = 2.0 -sfa_ref_cpu[985] = 1.0 -sfa_ref_cpu[986] = 2.0 -sfa_ref_cpu[987] = 2.0 -sfa_ref_cpu[988] = 1.0 -sfa_ref_cpu[989] = 1.0 -sfa_ref_cpu[990] = 1.0 -sfa_ref_cpu[991] = 2.0 -sfa_ref_cpu[992] = 1.0 -sfa_ref_cpu[993] = 1.0 -sfa_ref_cpu[994] = 2.0 -sfa_ref_cpu[995] = 1.0 -sfa_ref_cpu[996] = 2.0 -sfa_ref_cpu[997] = 2.0 -sfa_ref_cpu[998] = 1.0 -sfa_ref_cpu[999] = 2.0 -sfa_ref_cpu[1000] = 1.0 -sfa_ref_cpu[1001] = 1.0 -sfa_ref_cpu[1002] = 2.0 -sfa_ref_cpu[1003] = 2.0 -sfa_ref_cpu[1004] = 1.0 -sfa_ref_cpu[1005] = 2.0 -sfa_ref_cpu[1006] = 2.0 -sfa_ref_cpu[1007] = 2.0 -sfa_ref_cpu[1008] = 2.0 -sfa_ref_cpu[1009] = 2.0 -sfa_ref_cpu[1010] = 2.0 -sfa_ref_cpu[1011] = 1.0 -sfa_ref_cpu[1012] = 2.0 -sfa_ref_cpu[1013] = 2.0 -sfa_ref_cpu[1014] = 2.0 -sfa_ref_cpu[1015] = 1.0 -sfa_ref_cpu[1016] = 2.0 -sfa_ref_cpu[1017] = 2.0 -sfa_ref_cpu[1018] = 1.0 -sfa_ref_cpu[1019] = 2.0 -sfa_ref_cpu[1020] = 2.0 -sfa_ref_cpu[1021] = 1.0 -sfa_ref_cpu[1022] = 2.0 -sfa_ref_cpu[1023] = 1.0 -sfa_ref_cpu[1024] = 1.0 -sfa_ref_cpu[1025] = 1.0 -sfa_ref_cpu[1026] = 1.0 -sfa_ref_cpu[1027] = 2.0 -sfa_ref_cpu[1028] = 2.0 -sfa_ref_cpu[1029] = 2.0 -sfa_ref_cpu[1030] = 2.0 -sfa_ref_cpu[1031] = 1.0 -sfa_ref_cpu[1032] = 1.0 -sfa_ref_cpu[1033] = 1.0 -sfa_ref_cpu[1034] = 1.0 -sfa_ref_cpu[1035] = 2.0 -sfa_ref_cpu[1036] = 1.0 -sfa_ref_cpu[1037] = 2.0 -sfa_ref_cpu[1038] = 2.0 -sfa_ref_cpu[1039] = 2.0 -sfa_ref_cpu[1040] = 2.0 -sfa_ref_cpu[1041] = 1.0 -sfa_ref_cpu[1042] = 1.0 -sfa_ref_cpu[1043] = 1.0 -sfa_ref_cpu[1044] = 1.0 -sfa_ref_cpu[1045] = 1.0 -sfa_ref_cpu[1046] = 1.0 -sfa_ref_cpu[1047] = 2.0 -sfa_ref_cpu[1048] = 1.0 -sfa_ref_cpu[1049] = 1.0 -sfa_ref_cpu[1050] = 2.0 -sfa_ref_cpu[1051] = 2.0 -sfa_ref_cpu[1052] = 2.0 -sfa_ref_cpu[1053] = 2.0 -sfa_ref_cpu[1054] = 1.0 -sfa_ref_cpu[1055] = 2.0 -sfa_ref_cpu[1056] = 2.0 -sfa_ref_cpu[1057] = 2.0 -sfa_ref_cpu[1058] = 1.0 -sfa_ref_cpu[1059] = 2.0 -sfa_ref_cpu[1060] = 2.0 -sfa_ref_cpu[1061] = 2.0 -sfa_ref_cpu[1062] = 2.0 -sfa_ref_cpu[1063] = 2.0 -sfa_ref_cpu[1064] = 2.0 -sfa_ref_cpu[1065] = 2.0 -sfa_ref_cpu[1066] = 1.0 -sfa_ref_cpu[1067] = 1.0 -sfa_ref_cpu[1068] = 1.0 -sfa_ref_cpu[1069] = 1.0 -sfa_ref_cpu[1070] = 2.0 -sfa_ref_cpu[1071] = 1.0 -sfa_ref_cpu[1072] = 2.0 -sfa_ref_cpu[1073] = 1.0 -sfa_ref_cpu[1074] = 2.0 -sfa_ref_cpu[1075] = 1.0 -sfa_ref_cpu[1076] = 1.0 -sfa_ref_cpu[1077] = 1.0 -sfa_ref_cpu[1078] = 2.0 -sfa_ref_cpu[1079] = 1.0 -sfa_ref_cpu[1080] = 1.0 -sfa_ref_cpu[1081] = 2.0 -sfa_ref_cpu[1082] = 1.0 -sfa_ref_cpu[1083] = 2.0 -sfa_ref_cpu[1084] = 2.0 -sfa_ref_cpu[1085] = 1.0 -sfa_ref_cpu[1086] = 1.0 -sfa_ref_cpu[1087] = 2.0 -sfa_ref_cpu[1088] = 1.0 -sfa_ref_cpu[1089] = 2.0 -sfa_ref_cpu[1090] = 2.0 -sfa_ref_cpu[1091] = 2.0 -sfa_ref_cpu[1092] = 2.0 -sfa_ref_cpu[1093] = 2.0 -sfa_ref_cpu[1094] = 2.0 -sfa_ref_cpu[1095] = 2.0 -sfa_ref_cpu[1096] = 1.0 -sfa_ref_cpu[1097] = 1.0 -sfa_ref_cpu[1098] = 1.0 -sfa_ref_cpu[1099] = 1.0 -sfa_ref_cpu[1100] = 1.0 -sfa_ref_cpu[1101] = 2.0 -sfa_ref_cpu[1102] = 1.0 -sfa_ref_cpu[1103] = 2.0 -sfa_ref_cpu[1104] = 1.0 -sfa_ref_cpu[1105] = 2.0 -sfa_ref_cpu[1106] = 1.0 -sfa_ref_cpu[1107] = 2.0 -sfa_ref_cpu[1108] = 2.0 -sfa_ref_cpu[1109] = 2.0 -sfa_ref_cpu[1110] = 1.0 -sfa_ref_cpu[1111] = 1.0 -sfa_ref_cpu[1112] = 2.0 -sfa_ref_cpu[1113] = 1.0 -sfa_ref_cpu[1114] = 1.0 -sfa_ref_cpu[1115] = 1.0 -sfa_ref_cpu[1116] = 1.0 -sfa_ref_cpu[1117] = 2.0 -sfa_ref_cpu[1118] = 2.0 -sfa_ref_cpu[1119] = 1.0 -sfa_ref_cpu[1120] = 1.0 -sfa_ref_cpu[1121] = 2.0 -sfa_ref_cpu[1122] = 1.0 -sfa_ref_cpu[1123] = 1.0 -sfa_ref_cpu[1124] = 2.0 -sfa_ref_cpu[1125] = 2.0 -sfa_ref_cpu[1126] = 2.0 -sfa_ref_cpu[1127] = 2.0 -sfa_ref_cpu[1128] = 1.0 -sfa_ref_cpu[1129] = 2.0 -sfa_ref_cpu[1130] = 1.0 -sfa_ref_cpu[1131] = 1.0 -sfa_ref_cpu[1132] = 1.0 -sfa_ref_cpu[1133] = 2.0 -sfa_ref_cpu[1134] = 2.0 -sfa_ref_cpu[1135] = 1.0 -sfa_ref_cpu[1136] = 1.0 -sfa_ref_cpu[1137] = 2.0 -sfa_ref_cpu[1138] = 2.0 -sfa_ref_cpu[1139] = 1.0 -sfa_ref_cpu[1140] = 2.0 -sfa_ref_cpu[1141] = 1.0 -sfa_ref_cpu[1142] = 2.0 -sfa_ref_cpu[1143] = 1.0 -sfa_ref_cpu[1144] = 2.0 -sfa_ref_cpu[1145] = 2.0 -sfa_ref_cpu[1146] = 2.0 -sfa_ref_cpu[1147] = 1.0 -sfa_ref_cpu[1148] = 2.0 -sfa_ref_cpu[1149] = 1.0 -sfa_ref_cpu[1150] = 1.0 -sfa_ref_cpu[1151] = 2.0 -sfa_ref_cpu[1152] = 2.0 -sfa_ref_cpu[1153] = 1.0 -sfa_ref_cpu[1154] = 1.0 -sfa_ref_cpu[1155] = 1.0 -sfa_ref_cpu[1156] = 1.0 -sfa_ref_cpu[1157] = 1.0 -sfa_ref_cpu[1158] = 2.0 -sfa_ref_cpu[1159] = 1.0 -sfa_ref_cpu[1160] = 2.0 -sfa_ref_cpu[1161] = 2.0 -sfa_ref_cpu[1162] = 2.0 -sfa_ref_cpu[1163] = 1.0 -sfa_ref_cpu[1164] = 2.0 -sfa_ref_cpu[1165] = 1.0 -sfa_ref_cpu[1166] = 2.0 -sfa_ref_cpu[1167] = 2.0 -sfa_ref_cpu[1168] = 2.0 -sfa_ref_cpu[1169] = 2.0 -sfa_ref_cpu[1170] = 2.0 -sfa_ref_cpu[1171] = 1.0 -sfa_ref_cpu[1172] = 2.0 -sfa_ref_cpu[1173] = 1.0 -sfa_ref_cpu[1174] = 1.0 -sfa_ref_cpu[1175] = 1.0 -sfa_ref_cpu[1176] = 2.0 -sfa_ref_cpu[1177] = 1.0 -sfa_ref_cpu[1178] = 1.0 -sfa_ref_cpu[1179] = 2.0 -sfa_ref_cpu[1180] = 2.0 -sfa_ref_cpu[1181] = 2.0 -sfa_ref_cpu[1182] = 2.0 -sfa_ref_cpu[1183] = 2.0 -sfa_ref_cpu[1184] = 1.0 -sfa_ref_cpu[1185] = 1.0 -sfa_ref_cpu[1186] = 2.0 -sfa_ref_cpu[1187] = 1.0 -sfa_ref_cpu[1188] = 2.0 -sfa_ref_cpu[1189] = 2.0 -sfa_ref_cpu[1190] = 2.0 -sfa_ref_cpu[1191] = 1.0 -sfa_ref_cpu[1192] = 1.0 -sfa_ref_cpu[1193] = 1.0 -sfa_ref_cpu[1194] = 2.0 -sfa_ref_cpu[1195] = 2.0 -sfa_ref_cpu[1196] = 2.0 -sfa_ref_cpu[1197] = 1.0 -sfa_ref_cpu[1198] = 1.0 -sfa_ref_cpu[1199] = 2.0 -sfa_ref_cpu[1200] = 1.0 -sfa_ref_cpu[1201] = 1.0 -sfa_ref_cpu[1202] = 2.0 -sfa_ref_cpu[1203] = 2.0 -sfa_ref_cpu[1204] = 2.0 -sfa_ref_cpu[1205] = 2.0 -sfa_ref_cpu[1206] = 1.0 -sfa_ref_cpu[1207] = 1.0 -sfa_ref_cpu[1208] = 1.0 -sfa_ref_cpu[1209] = 2.0 -sfa_ref_cpu[1210] = 2.0 -sfa_ref_cpu[1211] = 2.0 -sfa_ref_cpu[1212] = 2.0 -sfa_ref_cpu[1213] = 2.0 -sfa_ref_cpu[1214] = 2.0 -sfa_ref_cpu[1215] = 2.0 -sfa_ref_cpu[1216] = 2.0 -sfa_ref_cpu[1217] = 2.0 -sfa_ref_cpu[1218] = 1.0 -sfa_ref_cpu[1219] = 2.0 -sfa_ref_cpu[1220] = 1.0 -sfa_ref_cpu[1221] = 2.0 -sfa_ref_cpu[1222] = 2.0 -sfa_ref_cpu[1223] = 2.0 -sfa_ref_cpu[1224] = 1.0 -sfa_ref_cpu[1225] = 2.0 -sfa_ref_cpu[1226] = 1.0 -sfa_ref_cpu[1227] = 2.0 -sfa_ref_cpu[1228] = 1.0 -sfa_ref_cpu[1229] = 1.0 -sfa_ref_cpu[1230] = 1.0 -sfa_ref_cpu[1231] = 1.0 -sfa_ref_cpu[1232] = 2.0 -sfa_ref_cpu[1233] = 2.0 -sfa_ref_cpu[1234] = 1.0 -sfa_ref_cpu[1235] = 1.0 -sfa_ref_cpu[1236] = 2.0 -sfa_ref_cpu[1237] = 2.0 -sfa_ref_cpu[1238] = 1.0 -sfa_ref_cpu[1239] = 2.0 -sfa_ref_cpu[1240] = 2.0 -sfa_ref_cpu[1241] = 2.0 -sfa_ref_cpu[1242] = 2.0 -sfa_ref_cpu[1243] = 1.0 -sfa_ref_cpu[1244] = 1.0 -sfa_ref_cpu[1245] = 1.0 -sfa_ref_cpu[1246] = 1.0 -sfa_ref_cpu[1247] = 2.0 -sfa_ref_cpu[1248] = 2.0 -sfa_ref_cpu[1249] = 1.0 -sfa_ref_cpu[1250] = 2.0 -sfa_ref_cpu[1251] = 1.0 -sfa_ref_cpu[1252] = 2.0 -sfa_ref_cpu[1253] = 1.0 -sfa_ref_cpu[1254] = 1.0 -sfa_ref_cpu[1255] = 1.0 -sfa_ref_cpu[1256] = 2.0 -sfa_ref_cpu[1257] = 1.0 -sfa_ref_cpu[1258] = 2.0 -sfa_ref_cpu[1259] = 2.0 -sfa_ref_cpu[1260] = 1.0 -sfa_ref_cpu[1261] = 1.0 -sfa_ref_cpu[1262] = 1.0 -sfa_ref_cpu[1263] = 2.0 -sfa_ref_cpu[1264] = 1.0 -sfa_ref_cpu[1265] = 2.0 -sfa_ref_cpu[1266] = 2.0 -sfa_ref_cpu[1267] = 1.0 -sfa_ref_cpu[1268] = 1.0 -sfa_ref_cpu[1269] = 2.0 -sfa_ref_cpu[1270] = 2.0 -sfa_ref_cpu[1271] = 1.0 -sfa_ref_cpu[1272] = 2.0 -sfa_ref_cpu[1273] = 1.0 -sfa_ref_cpu[1274] = 2.0 -sfa_ref_cpu[1275] = 2.0 -sfa_ref_cpu[1276] = 1.0 -sfa_ref_cpu[1277] = 2.0 -sfa_ref_cpu[1278] = 1.0 -sfa_ref_cpu[1279] = 1.0 -sfa_ref_cpu[1280] = 1.0 -sfa_ref_cpu[1281] = 2.0 -sfa_ref_cpu[1282] = 1.0 -sfa_ref_cpu[1283] = 1.0 -sfa_ref_cpu[1284] = 2.0 -sfa_ref_cpu[1285] = 1.0 -sfa_ref_cpu[1286] = 1.0 -sfa_ref_cpu[1287] = 1.0 -sfa_ref_cpu[1288] = 2.0 -sfa_ref_cpu[1289] = 2.0 -sfa_ref_cpu[1290] = 2.0 -sfa_ref_cpu[1291] = 1.0 -sfa_ref_cpu[1292] = 2.0 -sfa_ref_cpu[1293] = 1.0 -sfa_ref_cpu[1294] = 1.0 -sfa_ref_cpu[1295] = 2.0 -sfa_ref_cpu[1296] = 1.0 -sfa_ref_cpu[1297] = 2.0 -sfa_ref_cpu[1298] = 2.0 -sfa_ref_cpu[1299] = 1.0 -sfa_ref_cpu[1300] = 1.0 -sfa_ref_cpu[1301] = 1.0 -sfa_ref_cpu[1302] = 1.0 -sfa_ref_cpu[1303] = 1.0 -sfa_ref_cpu[1304] = 1.0 -sfa_ref_cpu[1305] = 2.0 -sfa_ref_cpu[1306] = 2.0 -sfa_ref_cpu[1307] = 1.0 -sfa_ref_cpu[1308] = 2.0 -sfa_ref_cpu[1309] = 1.0 -sfa_ref_cpu[1310] = 2.0 -sfa_ref_cpu[1311] = 1.0 -sfa_ref_cpu[1312] = 2.0 -sfa_ref_cpu[1313] = 1.0 -sfa_ref_cpu[1314] = 1.0 -sfa_ref_cpu[1315] = 1.0 -sfa_ref_cpu[1316] = 1.0 -sfa_ref_cpu[1317] = 2.0 -sfa_ref_cpu[1318] = 1.0 -sfa_ref_cpu[1319] = 1.0 -sfa_ref_cpu[1320] = 2.0 -sfa_ref_cpu[1321] = 1.0 -sfa_ref_cpu[1322] = 2.0 -sfa_ref_cpu[1323] = 1.0 -sfa_ref_cpu[1324] = 1.0 -sfa_ref_cpu[1325] = 2.0 -sfa_ref_cpu[1326] = 2.0 -sfa_ref_cpu[1327] = 1.0 -sfa_ref_cpu[1328] = 2.0 -sfa_ref_cpu[1329] = 1.0 -sfa_ref_cpu[1330] = 1.0 -sfa_ref_cpu[1331] = 1.0 -sfa_ref_cpu[1332] = 2.0 -sfa_ref_cpu[1333] = 2.0 -sfa_ref_cpu[1334] = 1.0 -sfa_ref_cpu[1335] = 2.0 -sfa_ref_cpu[1336] = 2.0 -sfa_ref_cpu[1337] = 2.0 -sfa_ref_cpu[1338] = 1.0 -sfa_ref_cpu[1339] = 1.0 -sfa_ref_cpu[1340] = 1.0 -sfa_ref_cpu[1341] = 2.0 -sfa_ref_cpu[1342] = 2.0 -sfa_ref_cpu[1343] = 2.0 -sfa_ref_cpu[1344] = 1.0 -sfa_ref_cpu[1345] = 2.0 -sfa_ref_cpu[1346] = 1.0 -sfa_ref_cpu[1347] = 2.0 -sfa_ref_cpu[1348] = 2.0 -sfa_ref_cpu[1349] = 2.0 -sfa_ref_cpu[1350] = 1.0 -sfa_ref_cpu[1351] = 2.0 -sfa_ref_cpu[1352] = 2.0 -sfa_ref_cpu[1353] = 1.0 -sfa_ref_cpu[1354] = 1.0 -sfa_ref_cpu[1355] = 1.0 -sfa_ref_cpu[1356] = 2.0 -sfa_ref_cpu[1357] = 1.0 -sfa_ref_cpu[1358] = 2.0 -sfa_ref_cpu[1359] = 1.0 -sfa_ref_cpu[1360] = 1.0 -sfa_ref_cpu[1361] = 1.0 -sfa_ref_cpu[1362] = 1.0 -sfa_ref_cpu[1363] = 2.0 -sfa_ref_cpu[1364] = 1.0 -sfa_ref_cpu[1365] = 2.0 -sfa_ref_cpu[1366] = 1.0 -sfa_ref_cpu[1367] = 2.0 -sfa_ref_cpu[1368] = 2.0 -sfa_ref_cpu[1369] = 2.0 -sfa_ref_cpu[1370] = 1.0 -sfa_ref_cpu[1371] = 1.0 -sfa_ref_cpu[1372] = 2.0 -sfa_ref_cpu[1373] = 1.0 -sfa_ref_cpu[1374] = 1.0 -sfa_ref_cpu[1375] = 2.0 -sfa_ref_cpu[1376] = 1.0 -sfa_ref_cpu[1377] = 1.0 -sfa_ref_cpu[1378] = 1.0 -sfa_ref_cpu[1379] = 1.0 -sfa_ref_cpu[1380] = 1.0 -sfa_ref_cpu[1381] = 1.0 -sfa_ref_cpu[1382] = 2.0 -sfa_ref_cpu[1383] = 1.0 -sfa_ref_cpu[1384] = 2.0 -sfa_ref_cpu[1385] = 2.0 -sfa_ref_cpu[1386] = 2.0 -sfa_ref_cpu[1387] = 1.0 -sfa_ref_cpu[1388] = 1.0 -sfa_ref_cpu[1389] = 2.0 -sfa_ref_cpu[1390] = 2.0 -sfa_ref_cpu[1391] = 1.0 -sfa_ref_cpu[1392] = 2.0 -sfa_ref_cpu[1393] = 1.0 -sfa_ref_cpu[1394] = 1.0 -sfa_ref_cpu[1395] = 2.0 -sfa_ref_cpu[1396] = 2.0 -sfa_ref_cpu[1397] = 2.0 -sfa_ref_cpu[1398] = 2.0 -sfa_ref_cpu[1399] = 2.0 -sfa_ref_cpu[1400] = 1.0 -sfa_ref_cpu[1401] = 1.0 -sfa_ref_cpu[1402] = 2.0 -sfa_ref_cpu[1403] = 1.0 -sfa_ref_cpu[1404] = 1.0 -sfa_ref_cpu[1405] = 1.0 -sfa_ref_cpu[1406] = 2.0 -sfa_ref_cpu[1407] = 1.0 -sfa_ref_cpu[1408] = 2.0 -sfa_ref_cpu[1409] = 2.0 -sfa_ref_cpu[1410] = 1.0 -sfa_ref_cpu[1411] = 1.0 -sfa_ref_cpu[1412] = 1.0 -sfa_ref_cpu[1413] = 2.0 -sfa_ref_cpu[1414] = 1.0 -sfa_ref_cpu[1415] = 2.0 -sfa_ref_cpu[1416] = 2.0 -sfa_ref_cpu[1417] = 1.0 -sfa_ref_cpu[1418] = 1.0 -sfa_ref_cpu[1419] = 1.0 -sfa_ref_cpu[1420] = 2.0 -sfa_ref_cpu[1421] = 1.0 -sfa_ref_cpu[1422] = 2.0 -sfa_ref_cpu[1423] = 1.0 -sfa_ref_cpu[1424] = 1.0 -sfa_ref_cpu[1425] = 2.0 -sfa_ref_cpu[1426] = 2.0 -sfa_ref_cpu[1427] = 1.0 -sfa_ref_cpu[1428] = 1.0 -sfa_ref_cpu[1429] = 1.0 -sfa_ref_cpu[1430] = 1.0 -sfa_ref_cpu[1431] = 1.0 -sfa_ref_cpu[1432] = 1.0 -sfa_ref_cpu[1433] = 1.0 -sfa_ref_cpu[1434] = 1.0 -sfa_ref_cpu[1435] = 2.0 -sfa_ref_cpu[1436] = 2.0 -sfa_ref_cpu[1437] = 1.0 -sfa_ref_cpu[1438] = 2.0 -sfa_ref_cpu[1439] = 1.0 -sfa_ref_cpu[1440] = 1.0 -sfa_ref_cpu[1441] = 2.0 -sfa_ref_cpu[1442] = 1.0 -sfa_ref_cpu[1443] = 2.0 -sfa_ref_cpu[1444] = 1.0 -sfa_ref_cpu[1445] = 1.0 -sfa_ref_cpu[1446] = 2.0 -sfa_ref_cpu[1447] = 1.0 -sfa_ref_cpu[1448] = 1.0 -sfa_ref_cpu[1449] = 1.0 -sfa_ref_cpu[1450] = 1.0 -sfa_ref_cpu[1451] = 1.0 -sfa_ref_cpu[1452] = 2.0 -sfa_ref_cpu[1453] = 2.0 -sfa_ref_cpu[1454] = 1.0 -sfa_ref_cpu[1455] = 2.0 -sfa_ref_cpu[1456] = 2.0 -sfa_ref_cpu[1457] = 1.0 -sfa_ref_cpu[1458] = 1.0 -sfa_ref_cpu[1459] = 2.0 -sfa_ref_cpu[1460] = 2.0 -sfa_ref_cpu[1461] = 1.0 -sfa_ref_cpu[1462] = 1.0 -sfa_ref_cpu[1463] = 1.0 -sfa_ref_cpu[1464] = 2.0 -sfa_ref_cpu[1465] = 1.0 -sfa_ref_cpu[1466] = 1.0 -sfa_ref_cpu[1467] = 2.0 -sfa_ref_cpu[1468] = 1.0 -sfa_ref_cpu[1469] = 2.0 -sfa_ref_cpu[1470] = 2.0 -sfa_ref_cpu[1471] = 2.0 -sfa_ref_cpu[1472] = 1.0 -sfa_ref_cpu[1473] = 1.0 -sfa_ref_cpu[1474] = 1.0 -sfa_ref_cpu[1475] = 1.0 -sfa_ref_cpu[1476] = 1.0 -sfa_ref_cpu[1477] = 1.0 -sfa_ref_cpu[1478] = 2.0 -sfa_ref_cpu[1479] = 2.0 -sfa_ref_cpu[1480] = 1.0 -sfa_ref_cpu[1481] = 1.0 -sfa_ref_cpu[1482] = 2.0 -sfa_ref_cpu[1483] = 1.0 -sfa_ref_cpu[1484] = 1.0 -sfa_ref_cpu[1485] = 1.0 -sfa_ref_cpu[1486] = 2.0 -sfa_ref_cpu[1487] = 2.0 -sfa_ref_cpu[1488] = 2.0 -sfa_ref_cpu[1489] = 2.0 -sfa_ref_cpu[1490] = 2.0 -sfa_ref_cpu[1491] = 1.0 -sfa_ref_cpu[1492] = 1.0 -sfa_ref_cpu[1493] = 2.0 -sfa_ref_cpu[1494] = 1.0 -sfa_ref_cpu[1495] = 2.0 -sfa_ref_cpu[1496] = 1.0 -sfa_ref_cpu[1497] = 2.0 -sfa_ref_cpu[1498] = 1.0 -sfa_ref_cpu[1499] = 1.0 -sfa_ref_cpu[1500] = 2.0 -sfa_ref_cpu[1501] = 2.0 -sfa_ref_cpu[1502] = 2.0 -sfa_ref_cpu[1503] = 2.0 -sfa_ref_cpu[1504] = 2.0 -sfa_ref_cpu[1505] = 2.0 -sfa_ref_cpu[1506] = 2.0 -sfa_ref_cpu[1507] = 2.0 -sfa_ref_cpu[1508] = 2.0 -sfa_ref_cpu[1509] = 2.0 -sfa_ref_cpu[1510] = 1.0 -sfa_ref_cpu[1511] = 2.0 -sfa_ref_cpu[1512] = 1.0 -sfa_ref_cpu[1513] = 2.0 -sfa_ref_cpu[1514] = 2.0 -sfa_ref_cpu[1515] = 1.0 -sfa_ref_cpu[1516] = 1.0 -sfa_ref_cpu[1517] = 2.0 -sfa_ref_cpu[1518] = 1.0 -sfa_ref_cpu[1519] = 1.0 -sfa_ref_cpu[1520] = 2.0 -sfa_ref_cpu[1521] = 2.0 -sfa_ref_cpu[1522] = 2.0 -sfa_ref_cpu[1523] = 2.0 -sfa_ref_cpu[1524] = 1.0 -sfa_ref_cpu[1525] = 2.0 -sfa_ref_cpu[1526] = 2.0 -sfa_ref_cpu[1527] = 1.0 -sfa_ref_cpu[1528] = 1.0 -sfa_ref_cpu[1529] = 2.0 -sfa_ref_cpu[1530] = 1.0 -sfa_ref_cpu[1531] = 2.0 -sfa_ref_cpu[1532] = 1.0 -sfa_ref_cpu[1533] = 2.0 -sfa_ref_cpu[1534] = 1.0 -sfa_ref_cpu[1535] = 1.0 -sfa_ref_cpu[1536] = 1.0 -sfa_ref_cpu[1537] = 2.0 -sfa_ref_cpu[1538] = 2.0 -sfa_ref_cpu[1539] = 1.0 -sfa_ref_cpu[1540] = 2.0 -sfa_ref_cpu[1541] = 1.0 -sfa_ref_cpu[1542] = 1.0 -sfa_ref_cpu[1543] = 2.0 -sfa_ref_cpu[1544] = 1.0 -sfa_ref_cpu[1545] = 2.0 -sfa_ref_cpu[1546] = 1.0 -sfa_ref_cpu[1547] = 2.0 -sfa_ref_cpu[1548] = 1.0 -sfa_ref_cpu[1549] = 2.0 -sfa_ref_cpu[1550] = 2.0 -sfa_ref_cpu[1551] = 1.0 -sfa_ref_cpu[1552] = 1.0 -sfa_ref_cpu[1553] = 1.0 -sfa_ref_cpu[1554] = 2.0 -sfa_ref_cpu[1555] = 1.0 -sfa_ref_cpu[1556] = 2.0 -sfa_ref_cpu[1557] = 2.0 -sfa_ref_cpu[1558] = 1.0 -sfa_ref_cpu[1559] = 2.0 -sfa_ref_cpu[1560] = 2.0 -sfa_ref_cpu[1561] = 2.0 -sfa_ref_cpu[1562] = 2.0 -sfa_ref_cpu[1563] = 1.0 -sfa_ref_cpu[1564] = 2.0 -sfa_ref_cpu[1565] = 1.0 -sfa_ref_cpu[1566] = 1.0 -sfa_ref_cpu[1567] = 2.0 -sfa_ref_cpu[1568] = 1.0 -sfa_ref_cpu[1569] = 1.0 -sfa_ref_cpu[1570] = 2.0 -sfa_ref_cpu[1571] = 1.0 -sfa_ref_cpu[1572] = 2.0 -sfa_ref_cpu[1573] = 2.0 -sfa_ref_cpu[1574] = 1.0 -sfa_ref_cpu[1575] = 1.0 -sfa_ref_cpu[1576] = 2.0 -sfa_ref_cpu[1577] = 1.0 -sfa_ref_cpu[1578] = 2.0 -sfa_ref_cpu[1579] = 1.0 -sfa_ref_cpu[1580] = 2.0 -sfa_ref_cpu[1581] = 2.0 -sfa_ref_cpu[1582] = 1.0 -sfa_ref_cpu[1583] = 2.0 -sfa_ref_cpu[1584] = 2.0 -sfa_ref_cpu[1585] = 1.0 -sfa_ref_cpu[1586] = 1.0 -sfa_ref_cpu[1587] = 1.0 -sfa_ref_cpu[1588] = 2.0 -sfa_ref_cpu[1589] = 2.0 -sfa_ref_cpu[1590] = 2.0 -sfa_ref_cpu[1591] = 2.0 -sfa_ref_cpu[1592] = 1.0 -sfa_ref_cpu[1593] = 1.0 -sfa_ref_cpu[1594] = 1.0 -sfa_ref_cpu[1595] = 2.0 -sfa_ref_cpu[1596] = 2.0 -sfa_ref_cpu[1597] = 2.0 -sfa_ref_cpu[1598] = 2.0 -sfa_ref_cpu[1599] = 1.0 -sfa_ref_cpu[1600] = 1.0 -sfa_ref_cpu[1601] = 1.0 -sfa_ref_cpu[1602] = 2.0 -sfa_ref_cpu[1603] = 2.0 -sfa_ref_cpu[1604] = 1.0 -sfa_ref_cpu[1605] = 1.0 -sfa_ref_cpu[1606] = 1.0 -sfa_ref_cpu[1607] = 2.0 -sfa_ref_cpu[1608] = 2.0 -sfa_ref_cpu[1609] = 1.0 -sfa_ref_cpu[1610] = 2.0 -sfa_ref_cpu[1611] = 1.0 -sfa_ref_cpu[1612] = 1.0 -sfa_ref_cpu[1613] = 1.0 -sfa_ref_cpu[1614] = 1.0 -sfa_ref_cpu[1615] = 2.0 -sfa_ref_cpu[1616] = 1.0 -sfa_ref_cpu[1617] = 1.0 -sfa_ref_cpu[1618] = 2.0 -sfa_ref_cpu[1619] = 1.0 -sfa_ref_cpu[1620] = 2.0 -sfa_ref_cpu[1621] = 2.0 -sfa_ref_cpu[1622] = 1.0 -sfa_ref_cpu[1623] = 2.0 -sfa_ref_cpu[1624] = 1.0 -sfa_ref_cpu[1625] = 1.0 -sfa_ref_cpu[1626] = 2.0 -sfa_ref_cpu[1627] = 1.0 -sfa_ref_cpu[1628] = 2.0 -sfa_ref_cpu[1629] = 1.0 -sfa_ref_cpu[1630] = 1.0 -sfa_ref_cpu[1631] = 2.0 -sfa_ref_cpu[1632] = 2.0 -sfa_ref_cpu[1633] = 2.0 -sfa_ref_cpu[1634] = 2.0 -sfa_ref_cpu[1635] = 1.0 -sfa_ref_cpu[1636] = 2.0 -sfa_ref_cpu[1637] = 1.0 -sfa_ref_cpu[1638] = 2.0 -sfa_ref_cpu[1639] = 1.0 -sfa_ref_cpu[1640] = 2.0 -sfa_ref_cpu[1641] = 1.0 -sfa_ref_cpu[1642] = 2.0 -sfa_ref_cpu[1643] = 2.0 -sfa_ref_cpu[1644] = 2.0 -sfa_ref_cpu[1645] = 2.0 -sfa_ref_cpu[1646] = 2.0 -sfa_ref_cpu[1647] = 2.0 -sfa_ref_cpu[1648] = 2.0 -sfa_ref_cpu[1649] = 1.0 -sfa_ref_cpu[1650] = 2.0 -sfa_ref_cpu[1651] = 2.0 -sfa_ref_cpu[1652] = 1.0 -sfa_ref_cpu[1653] = 2.0 -sfa_ref_cpu[1654] = 2.0 -sfa_ref_cpu[1655] = 2.0 -sfa_ref_cpu[1656] = 2.0 -sfa_ref_cpu[1657] = 2.0 -sfa_ref_cpu[1658] = 2.0 -sfa_ref_cpu[1659] = 2.0 -sfa_ref_cpu[1660] = 1.0 -sfa_ref_cpu[1661] = 1.0 -sfa_ref_cpu[1662] = 2.0 -sfa_ref_cpu[1663] = 1.0 -sfa_ref_cpu[1664] = 2.0 -sfa_ref_cpu[1665] = 2.0 -sfa_ref_cpu[1666] = 2.0 -sfa_ref_cpu[1667] = 1.0 -sfa_ref_cpu[1668] = 1.0 -sfa_ref_cpu[1669] = 2.0 -sfa_ref_cpu[1670] = 1.0 -sfa_ref_cpu[1671] = 2.0 -sfa_ref_cpu[1672] = 1.0 -sfa_ref_cpu[1673] = 1.0 -sfa_ref_cpu[1674] = 2.0 -sfa_ref_cpu[1675] = 2.0 -sfa_ref_cpu[1676] = 2.0 -sfa_ref_cpu[1677] = 1.0 -sfa_ref_cpu[1678] = 1.0 -sfa_ref_cpu[1679] = 1.0 -sfa_ref_cpu[1680] = 2.0 -sfa_ref_cpu[1681] = 2.0 -sfa_ref_cpu[1682] = 2.0 -sfa_ref_cpu[1683] = 1.0 -sfa_ref_cpu[1684] = 2.0 -sfa_ref_cpu[1685] = 2.0 -sfa_ref_cpu[1686] = 2.0 -sfa_ref_cpu[1687] = 1.0 -sfa_ref_cpu[1688] = 1.0 -sfa_ref_cpu[1689] = 2.0 -sfa_ref_cpu[1690] = 2.0 -sfa_ref_cpu[1691] = 1.0 -sfa_ref_cpu[1692] = 2.0 -sfa_ref_cpu[1693] = 2.0 -sfa_ref_cpu[1694] = 1.0 -sfa_ref_cpu[1695] = 1.0 -sfa_ref_cpu[1696] = 2.0 -sfa_ref_cpu[1697] = 2.0 -sfa_ref_cpu[1698] = 2.0 -sfa_ref_cpu[1699] = 1.0 -sfa_ref_cpu[1700] = 2.0 -sfa_ref_cpu[1701] = 2.0 -sfa_ref_cpu[1702] = 1.0 -sfa_ref_cpu[1703] = 2.0 -sfa_ref_cpu[1704] = 1.0 -sfa_ref_cpu[1705] = 1.0 -sfa_ref_cpu[1706] = 1.0 -sfa_ref_cpu[1707] = 1.0 -sfa_ref_cpu[1708] = 2.0 -sfa_ref_cpu[1709] = 2.0 -sfa_ref_cpu[1710] = 2.0 -sfa_ref_cpu[1711] = 1.0 -sfa_ref_cpu[1712] = 1.0 -sfa_ref_cpu[1713] = 2.0 -sfa_ref_cpu[1714] = 1.0 -sfa_ref_cpu[1715] = 2.0 -sfa_ref_cpu[1716] = 1.0 -sfa_ref_cpu[1717] = 1.0 -sfa_ref_cpu[1718] = 2.0 -sfa_ref_cpu[1719] = 2.0 -sfa_ref_cpu[1720] = 2.0 -sfa_ref_cpu[1721] = 2.0 -sfa_ref_cpu[1722] = 1.0 -sfa_ref_cpu[1723] = 1.0 -sfa_ref_cpu[1724] = 2.0 -sfa_ref_cpu[1725] = 2.0 -sfa_ref_cpu[1726] = 2.0 -sfa_ref_cpu[1727] = 2.0 -sfa_ref_cpu[1728] = 2.0 -sfa_ref_cpu[1729] = 2.0 -sfa_ref_cpu[1730] = 1.0 -sfa_ref_cpu[1731] = 2.0 -sfa_ref_cpu[1732] = 1.0 -sfa_ref_cpu[1733] = 1.0 -sfa_ref_cpu[1734] = 1.0 -sfa_ref_cpu[1735] = 2.0 -sfa_ref_cpu[1736] = 1.0 -sfa_ref_cpu[1737] = 1.0 -sfa_ref_cpu[1738] = 1.0 -sfa_ref_cpu[1739] = 1.0 -sfa_ref_cpu[1740] = 2.0 -sfa_ref_cpu[1741] = 2.0 -sfa_ref_cpu[1742] = 1.0 -sfa_ref_cpu[1743] = 2.0 -sfa_ref_cpu[1744] = 1.0 -sfa_ref_cpu[1745] = 2.0 -sfa_ref_cpu[1746] = 2.0 -sfa_ref_cpu[1747] = 1.0 -sfa_ref_cpu[1748] = 1.0 -sfa_ref_cpu[1749] = 2.0 -sfa_ref_cpu[1750] = 1.0 -sfa_ref_cpu[1751] = 2.0 -sfa_ref_cpu[1752] = 1.0 -sfa_ref_cpu[1753] = 1.0 -sfa_ref_cpu[1754] = 2.0 -sfa_ref_cpu[1755] = 1.0 -sfa_ref_cpu[1756] = 2.0 -sfa_ref_cpu[1757] = 2.0 -sfa_ref_cpu[1758] = 2.0 -sfa_ref_cpu[1759] = 1.0 -sfa_ref_cpu[1760] = 1.0 -sfa_ref_cpu[1761] = 2.0 -sfa_ref_cpu[1762] = 1.0 -sfa_ref_cpu[1763] = 1.0 -sfa_ref_cpu[1764] = 1.0 -sfa_ref_cpu[1765] = 2.0 -sfa_ref_cpu[1766] = 2.0 -sfa_ref_cpu[1767] = 2.0 -sfa_ref_cpu[1768] = 1.0 -sfa_ref_cpu[1769] = 2.0 -sfa_ref_cpu[1770] = 1.0 -sfa_ref_cpu[1771] = 1.0 -sfa_ref_cpu[1772] = 1.0 -sfa_ref_cpu[1773] = 1.0 -sfa_ref_cpu[1774] = 2.0 -sfa_ref_cpu[1775] = 1.0 -sfa_ref_cpu[1776] = 1.0 -sfa_ref_cpu[1777] = 1.0 -sfa_ref_cpu[1778] = 1.0 -sfa_ref_cpu[1779] = 1.0 -sfa_ref_cpu[1780] = 2.0 -sfa_ref_cpu[1781] = 2.0 -sfa_ref_cpu[1782] = 1.0 -sfa_ref_cpu[1783] = 2.0 -sfa_ref_cpu[1784] = 2.0 -sfa_ref_cpu[1785] = 1.0 -sfa_ref_cpu[1786] = 1.0 -sfa_ref_cpu[1787] = 1.0 -sfa_ref_cpu[1788] = 1.0 -sfa_ref_cpu[1789] = 1.0 -sfa_ref_cpu[1790] = 2.0 -sfa_ref_cpu[1791] = 2.0 -sfa_ref_cpu[1792] = 2.0 -sfa_ref_cpu[1793] = 1.0 -sfa_ref_cpu[1794] = 2.0 -sfa_ref_cpu[1795] = 1.0 -sfa_ref_cpu[1796] = 1.0 -sfa_ref_cpu[1797] = 1.0 -sfa_ref_cpu[1798] = 2.0 -sfa_ref_cpu[1799] = 2.0 -sfa_ref_cpu[1800] = 2.0 -sfa_ref_cpu[1801] = 1.0 -sfa_ref_cpu[1802] = 1.0 -sfa_ref_cpu[1803] = 2.0 -sfa_ref_cpu[1804] = 2.0 -sfa_ref_cpu[1805] = 2.0 -sfa_ref_cpu[1806] = 1.0 -sfa_ref_cpu[1807] = 1.0 -sfa_ref_cpu[1808] = 2.0 -sfa_ref_cpu[1809] = 1.0 -sfa_ref_cpu[1810] = 1.0 -sfa_ref_cpu[1811] = 2.0 -sfa_ref_cpu[1812] = 1.0 -sfa_ref_cpu[1813] = 2.0 -sfa_ref_cpu[1814] = 2.0 -sfa_ref_cpu[1815] = 1.0 -sfa_ref_cpu[1816] = 2.0 -sfa_ref_cpu[1817] = 2.0 -sfa_ref_cpu[1818] = 2.0 -sfa_ref_cpu[1819] = 1.0 -sfa_ref_cpu[1820] = 2.0 -sfa_ref_cpu[1821] = 2.0 -sfa_ref_cpu[1822] = 2.0 -sfa_ref_cpu[1823] = 1.0 -sfa_ref_cpu[1824] = 1.0 -sfa_ref_cpu[1825] = 1.0 -sfa_ref_cpu[1826] = 1.0 -sfa_ref_cpu[1827] = 2.0 -sfa_ref_cpu[1828] = 2.0 -sfa_ref_cpu[1829] = 1.0 -sfa_ref_cpu[1830] = 1.0 -sfa_ref_cpu[1831] = 1.0 -sfa_ref_cpu[1832] = 1.0 -sfa_ref_cpu[1833] = 1.0 -sfa_ref_cpu[1834] = 1.0 -sfa_ref_cpu[1835] = 1.0 -sfa_ref_cpu[1836] = 2.0 -sfa_ref_cpu[1837] = 2.0 -sfa_ref_cpu[1838] = 2.0 -sfa_ref_cpu[1839] = 1.0 -sfa_ref_cpu[1840] = 2.0 -sfa_ref_cpu[1841] = 2.0 -sfa_ref_cpu[1842] = 2.0 -sfa_ref_cpu[1843] = 1.0 -sfa_ref_cpu[1844] = 2.0 -sfa_ref_cpu[1845] = 1.0 -sfa_ref_cpu[1846] = 1.0 -sfa_ref_cpu[1847] = 1.0 -sfa_ref_cpu[1848] = 1.0 -sfa_ref_cpu[1849] = 1.0 -sfa_ref_cpu[1850] = 2.0 -sfa_ref_cpu[1851] = 1.0 -sfa_ref_cpu[1852] = 1.0 -sfa_ref_cpu[1853] = 2.0 -sfa_ref_cpu[1854] = 1.0 -sfa_ref_cpu[1855] = 1.0 -sfa_ref_cpu[1856] = 1.0 -sfa_ref_cpu[1857] = 1.0 -sfa_ref_cpu[1858] = 2.0 -sfa_ref_cpu[1859] = 2.0 -sfa_ref_cpu[1860] = 2.0 -sfa_ref_cpu[1861] = 2.0 -sfa_ref_cpu[1862] = 1.0 -sfa_ref_cpu[1863] = 2.0 -sfa_ref_cpu[1864] = 2.0 -sfa_ref_cpu[1865] = 2.0 -sfa_ref_cpu[1866] = 1.0 -sfa_ref_cpu[1867] = 2.0 -sfa_ref_cpu[1868] = 2.0 -sfa_ref_cpu[1869] = 1.0 -sfa_ref_cpu[1870] = 2.0 -sfa_ref_cpu[1871] = 2.0 -sfa_ref_cpu[1872] = 2.0 -sfa_ref_cpu[1873] = 2.0 -sfa_ref_cpu[1874] = 1.0 -sfa_ref_cpu[1875] = 2.0 -sfa_ref_cpu[1876] = 1.0 -sfa_ref_cpu[1877] = 1.0 -sfa_ref_cpu[1878] = 1.0 -sfa_ref_cpu[1879] = 2.0 -sfa_ref_cpu[1880] = 1.0 -sfa_ref_cpu[1881] = 2.0 -sfa_ref_cpu[1882] = 1.0 -sfa_ref_cpu[1883] = 2.0 -sfa_ref_cpu[1884] = 1.0 -sfa_ref_cpu[1885] = 2.0 -sfa_ref_cpu[1886] = 1.0 -sfa_ref_cpu[1887] = 2.0 -sfa_ref_cpu[1888] = 2.0 -sfa_ref_cpu[1889] = 2.0 -sfa_ref_cpu[1890] = 2.0 -sfa_ref_cpu[1891] = 2.0 -sfa_ref_cpu[1892] = 2.0 -sfa_ref_cpu[1893] = 2.0 -sfa_ref_cpu[1894] = 1.0 -sfa_ref_cpu[1895] = 1.0 -sfa_ref_cpu[1896] = 2.0 -sfa_ref_cpu[1897] = 1.0 -sfa_ref_cpu[1898] = 2.0 -sfa_ref_cpu[1899] = 2.0 -sfa_ref_cpu[1900] = 2.0 -sfa_ref_cpu[1901] = 2.0 -sfa_ref_cpu[1902] = 2.0 -sfa_ref_cpu[1903] = 1.0 -sfa_ref_cpu[1904] = 1.0 -sfa_ref_cpu[1905] = 2.0 -sfa_ref_cpu[1906] = 2.0 -sfa_ref_cpu[1907] = 1.0 -sfa_ref_cpu[1908] = 1.0 -sfa_ref_cpu[1909] = 2.0 -sfa_ref_cpu[1910] = 2.0 -sfa_ref_cpu[1911] = 1.0 -sfa_ref_cpu[1912] = 2.0 -sfa_ref_cpu[1913] = 2.0 -sfa_ref_cpu[1914] = 2.0 -sfa_ref_cpu[1915] = 1.0 -sfa_ref_cpu[1916] = 2.0 -sfa_ref_cpu[1917] = 2.0 -sfa_ref_cpu[1918] = 2.0 -sfa_ref_cpu[1919] = 2.0 -sfa_ref_cpu[1920] = 1.0 -sfa_ref_cpu[1921] = 1.0 -sfa_ref_cpu[1922] = 1.0 -sfa_ref_cpu[1923] = 1.0 -sfa_ref_cpu[1924] = 2.0 -sfa_ref_cpu[1925] = 1.0 -sfa_ref_cpu[1926] = 2.0 -sfa_ref_cpu[1927] = 1.0 -sfa_ref_cpu[1928] = 1.0 -sfa_ref_cpu[1929] = 2.0 -sfa_ref_cpu[1930] = 1.0 -sfa_ref_cpu[1931] = 2.0 -sfa_ref_cpu[1932] = 2.0 -sfa_ref_cpu[1933] = 2.0 -sfa_ref_cpu[1934] = 2.0 -sfa_ref_cpu[1935] = 1.0 -sfa_ref_cpu[1936] = 1.0 -sfa_ref_cpu[1937] = 2.0 -sfa_ref_cpu[1938] = 2.0 -sfa_ref_cpu[1939] = 1.0 -sfa_ref_cpu[1940] = 1.0 -sfa_ref_cpu[1941] = 2.0 -sfa_ref_cpu[1942] = 2.0 -sfa_ref_cpu[1943] = 2.0 -sfa_ref_cpu[1944] = 2.0 -sfa_ref_cpu[1945] = 2.0 -sfa_ref_cpu[1946] = 2.0 -sfa_ref_cpu[1947] = 1.0 -sfa_ref_cpu[1948] = 1.0 -sfa_ref_cpu[1949] = 2.0 -sfa_ref_cpu[1950] = 1.0 -sfa_ref_cpu[1951] = 2.0 -sfa_ref_cpu[1952] = 1.0 -sfa_ref_cpu[1953] = 1.0 -sfa_ref_cpu[1954] = 1.0 -sfa_ref_cpu[1955] = 2.0 -sfa_ref_cpu[1956] = 2.0 -sfa_ref_cpu[1957] = 1.0 -sfa_ref_cpu[1958] = 2.0 -sfa_ref_cpu[1959] = 1.0 -sfa_ref_cpu[1960] = 1.0 -sfa_ref_cpu[1961] = 2.0 -sfa_ref_cpu[1962] = 2.0 -sfa_ref_cpu[1963] = 2.0 -sfa_ref_cpu[1964] = 1.0 -sfa_ref_cpu[1965] = 2.0 -sfa_ref_cpu[1966] = 2.0 -sfa_ref_cpu[1967] = 1.0 -sfa_ref_cpu[1968] = 1.0 -sfa_ref_cpu[1969] = 2.0 -sfa_ref_cpu[1970] = 1.0 -sfa_ref_cpu[1971] = 1.0 -sfa_ref_cpu[1972] = 2.0 -sfa_ref_cpu[1973] = 1.0 -sfa_ref_cpu[1974] = 2.0 -sfa_ref_cpu[1975] = 1.0 -sfa_ref_cpu[1976] = 1.0 -sfa_ref_cpu[1977] = 2.0 -sfa_ref_cpu[1978] = 2.0 -sfa_ref_cpu[1979] = 1.0 -sfa_ref_cpu[1980] = 1.0 -sfa_ref_cpu[1981] = 2.0 -sfa_ref_cpu[1982] = 1.0 -sfa_ref_cpu[1983] = 2.0 -sfa_ref_cpu[1984] = 1.0 -sfa_ref_cpu[1985] = 2.0 -sfa_ref_cpu[1986] = 2.0 -sfa_ref_cpu[1987] = 1.0 -sfa_ref_cpu[1988] = 2.0 -sfa_ref_cpu[1989] = 1.0 -sfa_ref_cpu[1990] = 2.0 -sfa_ref_cpu[1991] = 2.0 -sfa_ref_cpu[1992] = 1.0 -sfa_ref_cpu[1993] = 2.0 -sfa_ref_cpu[1994] = 1.0 -sfa_ref_cpu[1995] = 2.0 -sfa_ref_cpu[1996] = 2.0 -sfa_ref_cpu[1997] = 1.0 -sfa_ref_cpu[1998] = 1.0 -sfa_ref_cpu[1999] = 2.0 -sfa_ref_cpu[2000] = 2.0 -sfa_ref_cpu[2001] = 2.0 -sfa_ref_cpu[2002] = 2.0 -sfa_ref_cpu[2003] = 2.0 -sfa_ref_cpu[2004] = 2.0 -sfa_ref_cpu[2005] = 1.0 -sfa_ref_cpu[2006] = 2.0 -sfa_ref_cpu[2007] = 1.0 -sfa_ref_cpu[2008] = 1.0 -sfa_ref_cpu[2009] = 1.0 -sfa_ref_cpu[2010] = 2.0 -sfa_ref_cpu[2011] = 1.0 -sfa_ref_cpu[2012] = 2.0 -sfa_ref_cpu[2013] = 1.0 -sfa_ref_cpu[2014] = 2.0 -sfa_ref_cpu[2015] = 1.0 -sfa_ref_cpu[2016] = 2.0 -sfa_ref_cpu[2017] = 2.0 -sfa_ref_cpu[2018] = 1.0 -sfa_ref_cpu[2019] = 2.0 -sfa_ref_cpu[2020] = 2.0 -sfa_ref_cpu[2021] = 2.0 -sfa_ref_cpu[2022] = 2.0 -sfa_ref_cpu[2023] = 1.0 -sfa_ref_cpu[2024] = 2.0 -sfa_ref_cpu[2025] = 2.0 -sfa_ref_cpu[2026] = 1.0 -sfa_ref_cpu[2027] = 1.0 -sfa_ref_cpu[2028] = 2.0 -sfa_ref_cpu[2029] = 1.0 -sfa_ref_cpu[2030] = 2.0 -sfa_ref_cpu[2031] = 2.0 -sfa_ref_cpu[2032] = 1.0 -sfa_ref_cpu[2033] = 1.0 -sfa_ref_cpu[2034] = 2.0 -sfa_ref_cpu[2035] = 1.0 -sfa_ref_cpu[2036] = 2.0 -sfa_ref_cpu[2037] = 2.0 -sfa_ref_cpu[2038] = 2.0 -sfa_ref_cpu[2039] = 2.0 -sfa_ref_cpu[2040] = 1.0 -sfa_ref_cpu[2041] = 2.0 -sfa_ref_cpu[2042] = 1.0 -sfa_ref_cpu[2043] = 2.0 -sfa_ref_cpu[2044] = 2.0 -sfa_ref_cpu[2045] = 1.0 -sfa_ref_cpu[2046] = 1.0 -sfa_ref_cpu[2047] = 2.0 -c_ref[0, 0, 0] = 7.5 -c_ref[1, 0, 0] = 10.25 -c_ref[2, 0, 0] = 12.25 -c_ref[3, 0, 0] = 15.25 -c_ref[4, 0, 0] = 13.25 -c_ref[5, 0, 0] = 17.25 -c_ref[6, 0, 0] = 15.25 -c_ref[7, 0, 0] = 15.5 -c_ref[8, 0, 0] = 18.0 -c_ref[9, 0, 0] = 12.25 -c_ref[10, 0, 0] = 14.25 -c_ref[11, 0, 0] = 11.5 -c_ref[12, 0, 0] = 15.0 -c_ref[13, 0, 0] = 14.0 -c_ref[14, 0, 0] = 17.0 -c_ref[15, 0, 0] = 13.25 -c_ref[16, 0, 0] = 19.25 -c_ref[17, 0, 0] = 12.75 -c_ref[18, 0, 0] = 12.5 -c_ref[19, 0, 0] = 17.0 -c_ref[20, 0, 0] = 14.25 -c_ref[21, 0, 0] = 16.25 -c_ref[22, 0, 0] = 18.5 -c_ref[23, 0, 0] = 12.0 -c_ref[24, 0, 0] = 17.25 -c_ref[25, 0, 0] = 13.0 -c_ref[26, 0, 0] = 18.25 -c_ref[27, 0, 0] = 17.0 -c_ref[28, 0, 0] = 10.25 -c_ref[29, 0, 0] = 12.75 -c_ref[30, 0, 0] = 17.5 -c_ref[31, 0, 0] = 19.0 -c_ref[32, 0, 0] = 13.5 -c_ref[33, 0, 0] = 14.75 -c_ref[34, 0, 0] = 14.75 -c_ref[35, 0, 0] = 17.25 -c_ref[36, 0, 0] = 15.25 -c_ref[37, 0, 0] = 18.0 -c_ref[38, 0, 0] = 19.25 -c_ref[39, 0, 0] = 13.75 -c_ref[40, 0, 0] = 15.75 -c_ref[41, 0, 0] = 13.5 -c_ref[42, 0, 0] = 12.0 -c_ref[43, 0, 0] = 16.75 -c_ref[44, 0, 0] = 18.75 -c_ref[45, 0, 0] = 12.75 -c_ref[46, 0, 0] = 10.5 -c_ref[47, 0, 0] = 9.25 -c_ref[48, 0, 0] = 12.5 -c_ref[49, 0, 0] = 14.5 -c_ref[50, 0, 0] = 13.25 -c_ref[51, 0, 0] = 17.25 -c_ref[52, 0, 0] = 14.75 -c_ref[53, 0, 0] = 13.75 -c_ref[54, 0, 0] = 13.5 -c_ref[55, 0, 0] = 12.5 -c_ref[56, 0, 0] = 9.75 -c_ref[57, 0, 0] = 11.0 -c_ref[58, 0, 0] = 16.75 -c_ref[59, 0, 0] = 14.0 -c_ref[60, 0, 0] = 16.0 -c_ref[61, 0, 0] = 13.0 -c_ref[62, 0, 0] = 14.75 -c_ref[63, 0, 0] = 14.75 -c_ref[64, 0, 0] = 13.25 -c_ref[65, 0, 0] = 18.0 -c_ref[66, 0, 0] = 15.0 -c_ref[67, 0, 0] = 13.75 -c_ref[68, 0, 0] = 12.5 -c_ref[69, 0, 0] = 15.75 -c_ref[70, 0, 0] = 10.5 -c_ref[71, 0, 0] = 16.25 -c_ref[72, 0, 0] = 16.25 -c_ref[73, 0, 0] = 14.5 -c_ref[74, 0, 0] = 16.0 -c_ref[75, 0, 0] = 17.0 -c_ref[76, 0, 0] = 17.25 -c_ref[77, 0, 0] = 10.5 -c_ref[78, 0, 0] = 12.5 -c_ref[79, 0, 0] = 13.0 -c_ref[80, 0, 0] = 12.5 -c_ref[81, 0, 0] = 11.0 -c_ref[82, 0, 0] = 15.0 -c_ref[83, 0, 0] = 13.75 -c_ref[84, 0, 0] = 12.25 -c_ref[85, 0, 0] = 13.25 -c_ref[86, 0, 0] = 13.75 -c_ref[87, 0, 0] = 17.0 -c_ref[88, 0, 0] = 14.0 -c_ref[89, 0, 0] = 13.0 -c_ref[90, 0, 0] = 14.25 -c_ref[91, 0, 0] = 15.75 -c_ref[92, 0, 0] = 9.5 -c_ref[93, 0, 0] = 13.0 -c_ref[94, 0, 0] = 11.0 -c_ref[95, 0, 0] = 13.75 -c_ref[96, 0, 0] = 15.25 -c_ref[97, 0, 0] = 12.75 -c_ref[98, 0, 0] = 14.5 -c_ref[99, 0, 0] = 13.0 -c_ref[100, 0, 0] = 11.75 -c_ref[101, 0, 0] = 12.0 -c_ref[102, 0, 0] = 18.0 -c_ref[103, 0, 0] = 15.5 -c_ref[104, 0, 0] = 12.75 -c_ref[105, 0, 0] = 12.5 -c_ref[106, 0, 0] = 14.75 -c_ref[107, 0, 0] = 16.75 -c_ref[108, 0, 0] = 13.5 -c_ref[109, 0, 0] = 15.25 -c_ref[110, 0, 0] = 13.5 -c_ref[111, 0, 0] = 11.75 -c_ref[112, 0, 0] = 17.25 -c_ref[113, 0, 0] = 16.25 -c_ref[114, 0, 0] = 11.25 -c_ref[115, 0, 0] = 10.75 -c_ref[116, 0, 0] = 13.5 -c_ref[117, 0, 0] = 11.5 -c_ref[118, 0, 0] = 15.5 -c_ref[119, 0, 0] = 17.25 -c_ref[120, 0, 0] = 14.75 -c_ref[121, 0, 0] = 17.0 -c_ref[122, 0, 0] = 15.5 -c_ref[123, 0, 0] = 14.75 -c_ref[124, 0, 0] = 18.0 -c_ref[125, 0, 0] = 13.0 -c_ref[126, 0, 0] = 15.5 -c_ref[127, 0, 0] = 14.75 diff --git a/problems/nvidia/nvfp4_gemv/test_python_1.sh b/problems/nvidia/nvfp4_gemv/test_python_1.sh deleted file mode 100644 index ae91eba..0000000 --- a/problems/nvidia/nvfp4_gemv/test_python_1.sh +++ /dev/null @@ -1,87 +0,0 @@ -# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build -BUILD_DIR=/home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/build_python -LLVM_DIR=$BUILD_DIR/llvm-prebuilt -# # BUILD_DIR=/home/scratch.ftse_gpu/workspace/dkg/build -# # BUILD_DIR=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/build -# #BUILD_DIR=/home/yanchengz/scratch_1/dynamic-kernel-generator/build_debug2 -# # sudo /home/scratch.computelab/utils/driver/install_driver.py --installer=/home/builds/daily/display/x86_64/rel/gpu_drv/r580/r580_00/20250527_36037303/NVIDIA-Linux-x86_64-rel_gpu_drv_r580_r580_00-20250527_36037303-internal.run --reason="Change to tot driver" - - -# # BUILD_DIR=/home/scratch.nbommi_gpu/warp-phase-trace/dynamic-kernel-generator/build_main - -export PYTHONPATH=$BUILD_DIR/cutlass_ir/python_packages -#export PYTHONPATH=/scratch/dynamic-kernel-generator/dynamic-kernel-generator/scripts -export CUDA_TOOLKIT_PATH=$BUILD_DIR/compiler_next -MLIR_CUDA_RUNTIME="$LLVM_DIR/lib/libmlir_cuda_runtime.so" -MLIR_C_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_c_runner_utils.so" -MLIR_RUNNER_UTILS="$LLVM_DIR/lib/libmlir_runner_utils.so" -CUDA_DIALECT_RUNTIME="$BUILD_DIR/lib/libcuda_dialect_runtime.so" -export CUTE_DSL_LIBS="$MLIR_CUDA_RUNTIME:$MLIR_C_RUNNER_UTILS:$MLIR_RUNNER_UTILS:$CUDA_DIALECT_RUNTIME" - - -#export CUTE_DSL_PREPROCESSOR=True - -# export CUTE_DSL_PRINT_IR=1 -# just compile the IR but not execute it -# export CUTE_DSL_DRYRUN=1 -# export CUTE_DSL_JIT_TIME_PROFILING=ON -# export CUTE_DSL_KEEP_IR=True -# export CUTE_DSL_PRINT_IR=1 -# export CUTE_DSL_KEEP_CUBIN=1 -# export CUTE_DSL_LINEINFO=True -# export CUTE_DSL_LOG_TO_CONSOLE=1 -# export PYTHONUNBUFFERED=1 -# export CUTE_DSL_KEEP_SASS=1 -# whether to show detailed log in preprocessing -# export CUTE_DSL_FILTER_STACKTRACE=10 -export CUTE_DSL_ARCH=sm_100a - -# -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dynamic-kernel-generator/dynamic-kernel-generator/cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/reference-kernels/problems/nvidia/nvfp4_gemv/submission.py -/home/scratch.vickiw_gpu/env/bin/python3 eval.py test task.yml -/home/scratch.vickiw_gpu/env/bin/python3 eval.py benchmark task.yml -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/cuda-gdb --args - -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_cute_layout.py -# # /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gated_dual_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gecccccbkvnjtrvtfreufijlfglnudnvuggvdfucidbnhk -# mm.py -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemm/nvfp4_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gemv/nvfp4_gemv.py -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool memcheck \ -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,16384 #135us -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 4096,128,7168 #62 - -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/blackwell/tutorial_gemm/fp16_gemm_0.py --mnk 7168,128,2048 #26 - - -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_group_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 /home/scratch.vickiw_gpu/dsl-gpu-mode/gated_dual_gemm/nvfp4_gated_dual_gemm.py -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv_naive.py - - - -# print out ncu time -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --metrics gpu__time_duration \ -# python3 vicki/tutorial_fp16_gemm_0__.py --mnk 7168,8,512 - -# use sanitizer to check race contention and memref error -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/bin/compute-sanitizer --tool racecheck|memcheck -# cutlass_ir/compiler/test/python/examples/sm_100a/test_nvfp4_gemv.py - -# capture ncu report -# /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/ncu --check-exit-code 0 -f --set full --import-source yes --target-processes all --clock-control base --cache-control none -o gemv_4.1 \ -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/nvfp4_gemv.py --m 128 --k 128 --l 2 - -# regular run python example -# /home/scratch.vickiw_gpu/env/bin/python3 cutlass_ir/compiler/python/examples/internal/blackwell/min_latency_hmma.py --mnkl 7168,8,512,1 - -# run pytest -# pytest cutlass_ir/compiler/test/python/examples/sm_80/test_sgemm.py From 634a0b3aa8efa01b3354e8e9be31be3b60c865b4 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Tue, 21 Oct 2025 01:52:37 -0700 Subject: [PATCH 17/29] move some costs to host. --- problems/nvidia/nvfp4_group_gemm/reference.py | 56 ++++++++++++-- .../nvidia/nvfp4_group_gemm/submission.py | 77 +------------------ problems/nvidia/nvfp4_group_gemm/task.py | 2 +- problems/nvidia/nvfp4_group_gemm/template.py | 9 ++- 4 files changed, 60 insertions(+), 84 deletions(-) diff --git a/problems/nvidia/nvfp4_group_gemm/reference.py b/problems/nvidia/nvfp4_group_gemm/reference.py index 4b9d8dd..fa0b33c 100644 --- a/problems/nvidia/nvfp4_group_gemm/reference.py +++ b/problems/nvidia/nvfp4_group_gemm/reference.py @@ -9,6 +9,7 @@ def ceil_div(a, b): return (a + b - 1) // b + # Helper function to convert scale factor tensor to blocked format def to_blocked(input_matrix): rows, cols = input_matrix.shape @@ -23,13 +24,14 @@ def to_blocked(input_matrix): return rearranged.flatten() + def ref_kernel( data: input_t, ) -> output_t: """ PyTorch reference implementation of NVFP4 block-scaled group GEMM. """ - abc_tensors, sfasfb_tensors, problem_sizes = data + abc_tensors, sfasfb_tensors, _, problem_sizes = data result_tensors = [] for i, ( @@ -60,6 +62,42 @@ def ref_kernel( result_tensors.append((c_ref)) return result_tensors + +# Reorder scale factor from (mn, l, sf_k) to (32, 4, rest_m, 4, rest_k, l) layout +def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor): + sf_k = ceil_div(k, sf_vec_size) + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, # batch size + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + # Create the reordered scale factor tensor (32, 4, rest_m, 4, rest_k, l) on CPU. + mma_permute_order = (3, 4, 1, 5, 2, 0) + # Generate a random int8 tensor, then convert to float8_e4m3fn + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) + reordered_f8_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) + # Permute according to mma_permute_order + reordered_f8_tensor = reordered_f8_tensor.permute(*mma_permute_order) + + # Please note this movement code is very slow. + for i in range(mn): + for j in range(sf_k): + for b in range(l): + # Calculate the location in MMA shape + mm = i // (atom_m[0] * atom_m[1]) + mm32 = i % atom_m[0] + mm4 = (i % 128) // atom_m[0] + kk = j // atom_k + kk4 = j % atom_k + reordered_f8_tensor[mm32, mm4, mm, kk4, kk, b] = ref_f8_tensor[i, j, b] + return reordered_f8_tensor.cuda() + + def generate_input( m: int, n: int, @@ -81,17 +119,20 @@ def generate_input( seed: Random seed for reproducibility Returns: - Tuple of (list(tuple(a, b, c)), list(tuple(sfa, sfb)), list(tuple(m, n, k, l))) where each group has its own a, b, c, sfa, sfb. + Tuple of (list(tuple(a, b, c)), list(tuple(sfa, sfb)), list(tuple(sfa_reordered, sfb_reordered)), list(tuple(m, n, k, l))) where each group has its own a, b, c, sfa, sfb. a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type b: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type - scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type - scale_b: [n, k, l] - Input scale factors in torch.float8e4m3fn data type + sfa: [m, k // 16, l] - Input scale factors in torch.float8e4m3fn data type + sfb: [n, k // 16, l] - Input scale factors in torch.float8e4m3fn data type + sfa_reordered: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type + sfb_reordered: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type c: [m, n, l] - Output matrix in torch.float16 data type """ torch.manual_seed(seed) abc_tensors = [] sfasfb_tensors = [] + sfasfb_reordered_tensors = [] problem_sizes = [] l = 1 # Generate a, b, c, sfa, sfb tensors for all groups @@ -117,10 +158,15 @@ def generate_input( 1, 3, (l, n, sf_k), dtype=torch.int8 ).to(dtype=torch.float8_e4m3fn).permute(1, 2, 0) + sfa_reordered = create_reordered_scale_factor_tensor(l, m, k, sfa_ref_cpu) + sfb_reordered = create_reordered_scale_factor_tensor(l, n, k, sfb_ref_cpu) + abc_tensors.append((a_ref, b_ref, c_ref)) sfasfb_tensors.append((sfa_ref_cpu, sfb_ref_cpu)) + sfasfb_reordered_tensors.append((sfa_reordered, sfb_reordered)) problem_sizes.append((m, n, k, l)) - return (abc_tensors, sfasfb_tensors, problem_sizes) + return (abc_tensors, sfasfb_tensors, sfasfb_reordered_tensors, problem_sizes) + check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) diff --git a/problems/nvidia/nvfp4_group_gemm/submission.py b/problems/nvidia/nvfp4_group_gemm/submission.py index c7affe3..f2eff27 100644 --- a/problems/nvidia/nvfp4_group_gemm/submission.py +++ b/problems/nvidia/nvfp4_group_gemm/submission.py @@ -48,32 +48,6 @@ def ceil_div(a, b): return (a + b - 1) // b -# Helper function to reorder the scale factor tensor to match the layout defined in -# https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout -@cute.jit -def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - sf_ref_ptr: cute.Pointer, - sf_mma_ptr: cute.Pointer, - mn: int, - sf_k: int, - l: int, - mma_shape: tuple, -): - mma_permute_order = (3, 4, 1, 5, 2, 0) - permuted_shape = tuple(mma_shape[i] for i in mma_permute_order) - cute_layout = cute.make_ordered_layout(permuted_shape, order=(2, 1, 4, 0, 3, 5)) - - sf_ref_tensor = cute.make_tensor( - sf_ref_ptr, cute.make_layout((mn, sf_k, l), stride=(sf_k, 1, mn * sf_k)) - ) - sf_mma_tensor = cute.make_tensor(sf_mma_ptr, cute_layout) - sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) - sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) - for i in cutlass.range(cute.size(sf_ref_tensor)): - mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) - sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] - - # The CuTe reference implementation for NVFP4 block-scaled GEMM @cute.kernel def kernel( @@ -902,48 +876,6 @@ def my_kernel( return -# Reorder scale factor from (mn, l, sf_k) to (32, 4, rest_m, 4, rest_k, l) layout -def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor): - sf_k = ceil_div(k, sf_vec_size) - atom_m = (32, 4) - atom_k = 4 - mma_shape = ( - l, # batch size - ceil_div(mn, atom_m[0] * atom_m[1]), - ceil_div(sf_k, atom_k), - atom_m[0], - atom_m[1], - atom_k, - ) - # Create the reordered scale factor tensor (32, 4, rest_m, 4, rest_k, l) on CPU. - mma_permute_order = (3, 4, 1, 5, 2, 0) - # Generate a random int8 tensor, then convert to float8_e4m3fn - rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) - reordered_f8_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) - # Permute according to mma_permute_order - reordered_f8_tensor = reordered_f8_tensor.permute(*mma_permute_order) - - # Helper function to convert scale factor tensor to CUTE-format scale factor tensor - cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - make_ptr( - cutlass.Float8E4M3FN, - ref_f8_tensor.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32, - ), - make_ptr( - cutlass.Float8E4M3FN, - reordered_f8_tensor.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32, - ), - mn, - sf_k, - l, - mma_shape, - ) - return reordered_f8_tensor.cuda() - _compiled_kernel_cache = None def compile_kernel(): @@ -974,7 +906,7 @@ def custom_kernel(data: input_t) -> output_t: Returns: list of c tensors where c is torch.Tensor[float16] of shape [m, n, l] for each group """ - abc_tensors, sfasfb_tensors, problem_sizes = data + abc_tensors, _, sfasfb_reordered_tensors, problem_sizes = data # Choose A, B, C, SFA, SFB with the smallest size to create initial tensormaps key_size_a = lambda item: item[1][0] * item[1][2] @@ -985,14 +917,9 @@ def custom_kernel(data: input_t) -> output_t: min_b_idx, _ = min(enumerate(problem_sizes), key=key_size_b) min_c_idx, _ = min(enumerate(problem_sizes), key=key_size_c) - sfasfb_reordered_tensors = [] abc_ptrs = [] sfasfb_ptrs = [] - for i, ((a, b, c), (sfa_cpu, sfb_cpu), (m, n, k, l)) in enumerate(zip(abc_tensors, sfasfb_tensors, problem_sizes)): - sf_k = ceil_div(k, sf_vec_size) - sfa_reordered = create_reordered_scale_factor_tensor(l, m, k, sfa_cpu) - sfb_reordered = create_reordered_scale_factor_tensor(l, n, k, sfb_cpu) - sfasfb_reordered_tensors.append((sfa_reordered, sfb_reordered)) + for i, ((a, b, c), (sfa_reordered, sfb_reordered), (m, n, k, l)) in enumerate(zip(abc_tensors, sfasfb_reordered_tensors, problem_sizes)): abc_ptrs.append((a.data_ptr(), b.data_ptr(), c.data_ptr())) sfasfb_ptrs.append((sfa_reordered.data_ptr(), sfb_reordered.data_ptr())) diff --git a/problems/nvidia/nvfp4_group_gemm/task.py b/problems/nvidia/nvfp4_group_gemm/task.py index 6e0961f..94c1143 100644 --- a/problems/nvidia/nvfp4_group_gemm/task.py +++ b/problems/nvidia/nvfp4_group_gemm/task.py @@ -1,7 +1,7 @@ import torch from typing import TypedDict, TypeVar -input_t = TypeVar("input_t", bound=tuple[list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]], list[tuple[torch.Tensor, torch.Tensor]], list[tuple[int, int, int, int]]]) +input_t = TypeVar("input_t", bound=tuple[list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]], list[tuple[torch.Tensor, torch.Tensor]], list[tuple[torch.Tensor, torch.Tensor]], list[tuple[int, int, int, int]]]) output_t = TypeVar("output_t", bound=list[torch.Tensor]) class TestSpec(TypedDict): problem_sizes: list[tuple[int, int, int, int]] diff --git a/problems/nvidia/nvfp4_group_gemm/template.py b/problems/nvidia/nvfp4_group_gemm/template.py index ea034a9..b6005fa 100644 --- a/problems/nvidia/nvfp4_group_gemm/template.py +++ b/problems/nvidia/nvfp4_group_gemm/template.py @@ -5,7 +5,7 @@ def custom_kernel(data: input_t) -> output_t: """ Reference implementation of block-scale fp4 group gemm Args: - data: list of tuples (abc_tensors, sfasfb_tensors, problem_sizes) where: + data: list of tuples (abc_tensors, sfasfb_tensors, sfasfb_reordered_tensors, problem_sizes) where: abc_tensors: list of tuples (a, b, c) where a is torch.Tensor[float4e2m1fn_x2] of shape [m, k // 2, l] b is torch.Tensor[float4e2m1fn_x2] of shape [n, k // 2, l] @@ -13,15 +13,18 @@ def custom_kernel(data: input_t) -> output_t: sfasfb_tensors: list of tuples (sfa, sfb) where sfa is torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l] sfb is torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l] + sfasfb_reordered_tensors: list of tuples (sfa_reordered, sfb_reordered) where + sfa_reordered is torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_m, 4, rest_k, l] + sfb_reordered is torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l] problem_sizes: list of tuples (m, n, k, l) each group has its own a, b, c, sfa, sfb with different m, n, k, l problem sizes l should always be 1 for each group. Returns: list of tuples (c) where c is torch.Tensor[float16] of shape [m, n, l] """ - abc_tensors, sfasfb_tensors, problem_sizes = data + abc_tensors, sfasfb_tensors, sfasfb_reordered_tensors, problem_sizes = data result_tensors = [] - for i, (a, b, c) in enumerate(abc_tensors): + for i, ((a, b, c), (sfa_reordered, sfb_reordered), (m, n, k, l)) in enumerate(zip(abc_tensors, sfasfb_reordered_tensors, problem_sizes)): # add you implementation here result_tensors.append(c) From 3ad768000e89bf7c87c51697dcda9b15c09230de Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Wed, 22 Oct 2025 22:46:00 -0700 Subject: [PATCH 18/29] improve speed of light analysis. --- problems/nvidia/nvfp4_dual_gemm/task.yml | 10 +++++----- problems/nvidia/nvfp4_gemm/task.yml | 10 +++++----- problems/nvidia/nvfp4_gemv/task.yml | 14 +++++++------- problems/nvidia/nvfp4_group_gemm/task.yml | 10 +++++----- 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/problems/nvidia/nvfp4_dual_gemm/task.yml b/problems/nvidia/nvfp4_dual_gemm/task.yml index 33ceff4..2bac435 100644 --- a/problems/nvidia/nvfp4_dual_gemm/task.yml +++ b/problems/nvidia/nvfp4_dual_gemm/task.yml @@ -30,12 +30,12 @@ description: | For the grand price, your kernel will be evaluated against the speed of light analysis and the solution closest to the speed of light will be awarded the grand price. ``` - The speed of light analysis is (using 1.5Ghz clock): + The speed of light analysis based on the max(FP4 Tensor Core math throughput, DRAM memory throughput) of B200 and tested under 1.5Ghz clock: M N K L time[us] - 128 4096 7168 1 1.09 - 512 4096 7168 1 4.36 - 128 3072 4096 1 0.47 - 512 3072 7168 1 3.27 + 128 4096 7168 1 4.505 + 512 4096 7168 1 8.714 + 128 3072 4096 1 1.984 + 512 3072 7168 1 6.535 ``` config: main: "eval.py" diff --git a/problems/nvidia/nvfp4_gemm/task.yml b/problems/nvidia/nvfp4_gemm/task.yml index 94334f5..06388bb 100644 --- a/problems/nvidia/nvfp4_gemm/task.yml +++ b/problems/nvidia/nvfp4_gemm/task.yml @@ -28,11 +28,11 @@ description: | For the grand price, your kernel will be evaluated against the speed of light analysis and the solution closest to the speed of light will be awarded the grand price. ``` - The speed of light analysis is (using 1.5Ghz clock): - M N K L time[us] - 128 7168 16384 1 4.36 - 128 4096 7168 1 1.09 - 128 7168 2048 1 0.55 + The speed of light analysis based on the max(FP4 Tensor Core math throughput, DRAM memory throughput) of B200 and tested under 1.5Ghz clock: + M N K L time[us] + 128 7168 16384 1 8.994 + 128 4096 7168 1 2.354 + 128 7168 2048 1 1.333 ``` config: main: "eval.py" diff --git a/problems/nvidia/nvfp4_gemv/task.yml b/problems/nvidia/nvfp4_gemv/task.yml index 01f425b..adf8cab 100644 --- a/problems/nvidia/nvfp4_gemv/task.yml +++ b/problems/nvidia/nvfp4_gemv/task.yml @@ -28,11 +28,11 @@ description: | For the grand price, your kernel will be evaluated against the speed of light analysis and the solution closest to the speed of light will be awarded the grand price. ``` - The speed of light analysis is (using 1.5Ghz clock): - M K L time[us] - 7168 16384 1 7.65 - 4096 7168 1 1.91 - 7168 2048 1 0.96 + The speed of light analysis based on the max(FFMA math throughput, DRAM memory throughput) of B200 and tested under 1.5Ghz clock: + M K L time[us] + 7168 16384 1 8.622 + 4096 7168 8 17.275 + 7168 2048 4 4.317 ``` config: main: "eval.py" @@ -55,7 +55,7 @@ tests: benchmarks: - {"m": 7168, "k": 16384, "l":1, "seed": 1111} - - {"m": 4096, "k": 7168, "l":1, "seed": 1111} - - {"m": 7168, "k": 2048, "l":1, "seed": 1111} + - {"m": 4096, "k": 7168, "l":8, "seed": 1111} + - {"m": 7168, "k": 2048, "l":4, "seed": 1111} ranking_by: "geom" \ No newline at end of file diff --git a/problems/nvidia/nvfp4_group_gemm/task.yml b/problems/nvidia/nvfp4_group_gemm/task.yml index 82a0640..9f0df2b 100644 --- a/problems/nvidia/nvfp4_group_gemm/task.yml +++ b/problems/nvidia/nvfp4_group_gemm/task.yml @@ -31,12 +31,12 @@ description: | For the grand price, your kernel will be evaluated against the speed of light analysis and the solution closest to the speed of light will be awarded the grand price. ``` - The speed of light analysis is (using 1.5Ghz clock): + The speed of light analysis based on the max(FP4 Tensor Core math throughput, DRAM memory throughput) of B200 and tested under 1.5Ghz clock: G M N K L time[us] - 8 128 4096 7168 1 8.71 - 8 128 7168 2048 1 4.36 - 2 256 3072 4096 1 1.87 - 2 256 4096 1536 1 0.93 + 8 128 4096 7168 1 18.833 + 8 128 7168 2048 1 10.667 + 2 256 3072 4096 1 2.406 + 2 256 4096 1536 1 1.525 ``` config: main: "eval.py" From 0d7d037f0d2217207f15fd37e65c816b250159d3 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Wed, 22 Oct 2025 23:54:37 -0700 Subject: [PATCH 19/29] improve comments. --- problems/nvidia/nvfp4_dual_gemm/reference.py | 9 +- problems/nvidia/nvfp4_dual_gemm/submission.py | 66 +++++---- problems/nvidia/nvfp4_gemm/reference.py | 6 +- problems/nvidia/nvfp4_gemm/submission.py | 52 ++++--- problems/nvidia/nvfp4_gemv/reference.py | 9 +- problems/nvidia/nvfp4_gemv/submission.py | 5 +- problems/nvidia/nvfp4_group_gemm/reference.py | 5 +- .../nvidia/nvfp4_group_gemm/submission.py | 127 ++++++++++++------ 8 files changed, 180 insertions(+), 99 deletions(-) diff --git a/problems/nvidia/nvfp4_dual_gemm/reference.py b/problems/nvidia/nvfp4_dual_gemm/reference.py index c64927c..29ee4cb 100644 --- a/problems/nvidia/nvfp4_dual_gemm/reference.py +++ b/problems/nvidia/nvfp4_dual_gemm/reference.py @@ -28,7 +28,8 @@ def ref_kernel( data: input_t, ) -> output_t: """ - PyTorch reference implementation of NVFP4 block-scaled GEMM. + PyTorch reference implementation of NVFP4 block-scaled dual GEMM with silu activation, + C = silu(A @ B1) * (A @ B2). """ a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, _, _, _, c_ref = data @@ -130,7 +131,8 @@ def generate_input( # Helper function to prepare the scale factor tensors for both reference # kernel and customize kernel. Please note this data reordering function - # is very slow. + # is very slow, and the customized data layout can be found in the following link: + # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout def create_scale_factor_tensors(l, mn, sf_k): # Create the reference scale factor tensor (mn, l, sf_k) on CPU. ref_shape = (l, mn, sf_k) @@ -184,4 +186,5 @@ def create_scale_factor_tensors(l, mn, sf_k): return (a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, sfa_ref_permuted, sfb1_ref_permuted, sfb2_ref_permuted, c_ref) -check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) + +check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) \ No newline at end of file diff --git a/problems/nvidia/nvfp4_dual_gemm/submission.py b/problems/nvidia/nvfp4_dual_gemm/submission.py index e9f6a6d..a30af22 100644 --- a/problems/nvidia/nvfp4_dual_gemm/submission.py +++ b/problems/nvidia/nvfp4_dual_gemm/submission.py @@ -803,28 +803,49 @@ def my_kernel( c_tensor.shape[2], ) - # Launch the kernel synchronously + # Launch the kernel. kernel( - tiled_mma, - tma_atom_a, - tma_tensor_a, - tma_atom_b1, - tma_tensor_b1, - tma_atom_b2, - tma_tensor_b2, - tma_atom_sfa, - tma_tensor_sfa, - tma_atom_sfb1, - tma_tensor_sfb1, - tma_atom_sfb2, - tma_tensor_sfb2, - c_tensor, - a_smem_layout_staged, - b_smem_layout_staged, - sfa_smem_layout_staged, - sfb_smem_layout_staged, - num_tma_load_bytes, - epilogue_op, + # MMA (Matrix Multiply-Accumulate) configuration + tiled_mma, # Tiled MMA object defining NVFP4 GEMM compute pattern + + # TMA (Tensor Memory Accelerator) atoms and tensors for shared input matrix A + tma_atom_a, # TMA copy atom defining how to load A from global memory + tma_tensor_a, # Tensor descriptor for A matrix (m, k, l) - shared by both GEMMs + + # TMA atoms and tensors for first B matrix (B1) + tma_atom_b1, # TMA copy atom defining how to load B1 from global memory + tma_tensor_b1, # Tensor descriptor for B1 matrix (n, k, l) - first GEMM + + # TMA atoms and tensors for second B matrix (B2) + tma_atom_b2, # TMA copy atom defining how to load B2 from global memory + tma_tensor_b2, # Tensor descriptor for B2 matrix (n, k, l) - second GEMM + + # TMA atoms and tensors for scale factor A (shared) + tma_atom_sfa, # TMA copy atom for loading scale factors for A + tma_tensor_sfa, # Tensor descriptor for SFA (block scale factors for A) - shared + + # TMA atoms and tensors for scale factor B1 + tma_atom_sfb1, # TMA copy atom for loading scale factors for B1 + tma_tensor_sfb1, # Tensor descriptor for SFB1 (block scale factors for B1) + + # TMA atoms and tensors for scale factor B2 + tma_atom_sfb2, # TMA copy atom for loading scale factors for B2 + tma_tensor_sfb2, # Tensor descriptor for SFB2 (block scale factors for B2) + + # Output tensor C (stores both C1 and C2 results) + c_tensor, # Output tensor where both GEMM results will be stored (m, n, l) + + # Shared memory layouts with staging for pipelined execution + a_smem_layout_staged, # Staged shared memory layout for A (includes stage dimension) + b_smem_layout_staged, # Staged shared memory layout for B1/B2 (includes stage dimension) + sfa_smem_layout_staged, # Staged shared memory layout for SFA (includes stage dimension) + sfb_smem_layout_staged, # Staged shared memory layout for SFB1/SFB2 (includes stage dimension) + + # Pipeline synchronization parameter + num_tma_load_bytes, # Total bytes to load per TMA transaction (for barrier setup) + + # Epilogue operation + epilogue_op, # Epilogue operation to apply to output (e.g., element-wise ops) ).launch( grid=grid, block=[threads_per_cta, 1, 1], @@ -841,9 +862,6 @@ def compile_kernel(): Compile the kernel once and cache it. This should be called before any timing measurements. - Args: - a, b1, b2, sfa, sfb1, sfb2, c: Sample tensors with the expected shapes and types - Returns: The compiled kernel function """ diff --git a/problems/nvidia/nvfp4_gemm/reference.py b/problems/nvidia/nvfp4_gemm/reference.py index fecb7b6..cd6dfb4 100644 --- a/problems/nvidia/nvfp4_gemm/reference.py +++ b/problems/nvidia/nvfp4_gemm/reference.py @@ -99,7 +99,8 @@ def generate_input( # Helper function to prepare the scale factor tensors for both reference # kernel and customize kernel. Please note this data reordering function - # is very slow. + # is very slow, and the customized data layout can be found in the following link: + # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout def create_scale_factor_tensors(l, mn, sf_k): # Create the reference scale factor tensor (mn, l, sf_k) on CPU. ref_shape = (l, mn, sf_k) @@ -149,7 +150,8 @@ def create_scale_factor_tensors(l, mn, sf_k): sf_k = ceil_div(k, sf_vec_size) sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k) sfb_ref_cpu, sfb_ref_permuted = create_scale_factor_tensors(l, n, sf_k) - + return (a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, sfa_ref_permuted, sfb_ref_permuted, c_ref) + check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) diff --git a/problems/nvidia/nvfp4_gemm/submission.py b/problems/nvidia/nvfp4_gemm/submission.py index 0089c25..0d2a9d4 100644 --- a/problems/nvidia/nvfp4_gemm/submission.py +++ b/problems/nvidia/nvfp4_gemm/submission.py @@ -636,23 +636,38 @@ def my_kernel( c_tensor.shape[2], ) - # Launch the kernel synchronously + # Launch the kernel kernel( - tiled_mma, - tma_atom_a, - tma_tensor_a, - tma_atom_b, - tma_tensor_b, - tma_atom_sfa, - tma_tensor_sfa, - tma_atom_sfb, - tma_tensor_sfb, - c_tensor, - a_smem_layout_staged, - b_smem_layout_staged, - sfa_smem_layout_staged, - sfb_smem_layout_staged, - num_tma_load_bytes, + # MMA (Matrix Multiply-Accumulate) configuration + tiled_mma, # Tiled MMA object defining NVFP4 GEMM compute pattern + + # TMA (Tensor Memory Accelerator) atoms and tensors for input matrix A + tma_atom_a, # TMA copy atom defining how to load A from global memory + tma_tensor_a, # Tensor descriptor for A matrix (m, k, l) + + # TMA atoms and tensors for input matrix B + tma_atom_b, # TMA copy atom defining how to load B from global memory + tma_tensor_b, # Tensor descriptor for B matrix (n, k, l) + + # TMA atoms and tensors for scale factor A + tma_atom_sfa, # TMA copy atom for loading scale factors for A + tma_tensor_sfa, # Tensor descriptor for SFA (block scale factors for A) + + # TMA atoms and tensors for scale factor B + tma_atom_sfb, # TMA copy atom for loading scale factors for B + tma_tensor_sfb, # Tensor descriptor for SFB (block scale factors for B) + + # Output tensor C + c_tensor, # Output tensor C where result will be stored (m, n, l) + + # Shared memory layouts with staging for pipelined execution + a_smem_layout_staged, # Staged shared memory layout for A (includes stage dimension) + b_smem_layout_staged, # Staged shared memory layout for B (includes stage dimension) + sfa_smem_layout_staged, # Staged shared memory layout for SFA (includes stage dimension) + sfb_smem_layout_staged, # Staged shared memory layout for SFB (includes stage dimension) + + # Pipeline synchronization parameter + num_tma_load_bytes, # Total bytes to load per TMA transaction (for barrier setup) ).launch( grid=grid, block=[threads_per_cta, 1, 1], @@ -668,10 +683,7 @@ def compile_kernel(): """ Compile the kernel once and cache it. This should be called before any timing measurements. - - Args: - a, b, scale_a, scale_b, c: Sample tensors with the expected shapes and types - + Returns: The compiled kernel function """ diff --git a/problems/nvidia/nvfp4_gemv/reference.py b/problems/nvidia/nvfp4_gemv/reference.py index ee06e24..0a9c93d 100644 --- a/problems/nvidia/nvfp4_gemv/reference.py +++ b/problems/nvidia/nvfp4_gemv/reference.py @@ -102,7 +102,10 @@ def generate_input( 1, 2, 0 ) - # Helper function to prepare the scale factor tensors + # Helper function to prepare the scale factor tensors for both reference + # kernel and customize kernel. Please note this data reordering function + # is very slow, and the customized data layout can be found in the following link: + # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout def create_scale_factor_tensors(l, mn, sf_k): # Create the reference scale factor tensor (mn, l, sf_k) on CPU. ref_shape = (l, mn, sf_k) @@ -136,7 +139,6 @@ def create_scale_factor_tensors(l, mn, sf_k): reordered_f8_torch_tensor_cpu = reordered_f8_torch_tensor_cpu.permute( *mma_permute_order ) - for i in range(mn): for j in range(sf_k): for b in range(l): @@ -155,4 +157,5 @@ def create_scale_factor_tensors(l, mn, sf_k): return (a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, sfa_permuted, sfb_permuted, c_ref) -check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) + +check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemv/submission.py b/problems/nvidia/nvfp4_gemv/submission.py index 8ee6814..798afd7 100644 --- a/problems/nvidia/nvfp4_gemv/submission.py +++ b/problems/nvidia/nvfp4_gemv/submission.py @@ -170,10 +170,7 @@ def compile_kernel(): """ Compile the kernel once and cache it. This should be called before any timing measurements. - - Args: - a, b, scale_a, scale_b, c: Sample tensors with the expected shapes and types - + Returns: The compiled kernel function """ diff --git a/problems/nvidia/nvfp4_group_gemm/reference.py b/problems/nvidia/nvfp4_group_gemm/reference.py index fa0b33c..5fcb386 100644 --- a/problems/nvidia/nvfp4_group_gemm/reference.py +++ b/problems/nvidia/nvfp4_group_gemm/reference.py @@ -63,7 +63,10 @@ def ref_kernel( return result_tensors -# Reorder scale factor from (mn, l, sf_k) to (32, 4, rest_m, 4, rest_k, l) layout +# Helper function to prepare the scale factor tensors for both reference +# kernel and customize kernel. Please note this data reordering function +# is very slow, and the customized data layout can be found in the following link: +# https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor): sf_k = ceil_div(k, sf_vec_size) atom_m = (32, 4) diff --git a/problems/nvidia/nvfp4_group_gemm/submission.py b/problems/nvidia/nvfp4_group_gemm/submission.py index f2eff27..c9ab35f 100644 --- a/problems/nvidia/nvfp4_group_gemm/submission.py +++ b/problems/nvidia/nvfp4_group_gemm/submission.py @@ -17,7 +17,6 @@ from cutlass.cute.runtime import make_ptr # Kernel configuration parameters - # Size of tma descriptor in bytes bytes_per_tensormap = 128 # Number of tensormaps: a, b, sfa, sfb @@ -288,8 +287,8 @@ class SharedStorage: cute.assume(n * k, 32), ), ) - # SFA, SFB follows specialized layout defined - # here: TODO add linke + # SFA, SFB follows specialized layout defined in the following link: + # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout atom_shape = ((32, 4), (sf_vec_size, 4)) atom_stride = ((16, 4), (0, 1)) sfa_layout = cute.tile_to_shape( @@ -847,27 +846,44 @@ def my_kernel( # Compute grid size grid = (1, 1, total_num_clusters) - # Launch the kernel synchronously + # Launch the kernel kernel( - tiled_mma, - tma_atom_a, - tma_tensor_a, - tma_atom_b, - tma_tensor_b, - tma_atom_sfa, - tma_tensor_sfa, - tma_atom_sfb, - tma_tensor_sfb, - tensor_of_abc_ptrs, - tensor_of_sfasfb_ptrs, - tensor_of_tensormap, - tensor_of_problem_sizes, - a_smem_layout_staged, - b_smem_layout_staged, - sfa_smem_layout_staged, - sfb_smem_layout_staged, - cta_mn_list, - num_tma_load_bytes, + # MMA (Matrix Multiply-Accumulate) configuration + tiled_mma, # Tiled MMA object defining NVFP4 GEMM compute pattern + + # TMA (Tensor Memory Accelerator) atoms and tensors for input matrix A + tma_atom_a, # TMA copy atom defining how to load A from global memory + tma_tensor_a, # Tensor descriptor for A (created from smallest A tensor) + + # TMA atoms and tensors for input matrix B + tma_atom_b, # TMA copy atom defining how to load B from global memory + tma_tensor_b, # Tensor descriptor for B (created from smallest B tensor) + + # TMA atoms and tensors for scale factor A + tma_atom_sfa, # TMA copy atom for loading scale factors for A + tma_tensor_sfa, # Tensor descriptor for SFA (block scale factors for A) + + # TMA atoms and tensors for scale factor B + tma_atom_sfb, # TMA copy atom for loading scale factors for B + tma_tensor_sfb, # Tensor descriptor for SFB (block scale factors for B) + + # Runtime tensor metadata for dynamic group access + tensor_of_abc_ptrs, # Device tensor containing pointers to A, B, C for all groups + tensor_of_sfasfb_ptrs, # Device tensor containing pointers to SFA, SFB for all groups + tensor_of_tensormap, # Pre-allocated buffer for tensormap descriptors per CTA + tensor_of_problem_sizes, # Device tensor containing (m, n, k, l) for each group + + # Shared memory layouts with staging for pipelined execution + a_smem_layout_staged, # Staged shared memory layout for A (includes stage dimension) + b_smem_layout_staged, # Staged shared memory layout for B (includes stage dimension) + sfa_smem_layout_staged, # Staged shared memory layout for SFA (includes stage dimension) + sfb_smem_layout_staged, # Staged shared memory layout for SFB (includes stage dimension) + + # CTA grid configuration per group + cta_mn_list, # List of (M_tiles, N_tiles) for each group + + # Pipeline synchronization parameter + num_tma_load_bytes, # Total bytes to load per TMA transaction (for barrier setup) ).launch( grid=grid, block=[threads_per_cta, 1, 1], @@ -917,26 +933,35 @@ def custom_kernel(data: input_t) -> output_t: min_b_idx, _ = min(enumerate(problem_sizes), key=key_size_b) min_c_idx, _ = min(enumerate(problem_sizes), key=key_size_c) + # Extract raw data pointers from all input tensors for each group + # These will be passed to the GPU kernel to access the actual tensor data abc_ptrs = [] sfasfb_ptrs = [] for i, ((a, b, c), (sfa_reordered, sfb_reordered), (m, n, k, l)) in enumerate(zip(abc_tensors, sfasfb_reordered_tensors, problem_sizes)): + # Store pointers to A, B, and C matrices for this group abc_ptrs.append((a.data_ptr(), b.data_ptr(), c.data_ptr())) + # Store pointers to scale factor tensors for this group sfasfb_ptrs.append((sfa_reordered.data_ptr(), sfb_reordered.data_ptr())) - # Pick the tensor with the smallest size to create initial tensormaps + # Create initial CuTe pointers from the smallest tensors for tensormap initialization + # Using smallest tensors helps with efficient TMA (Tensor Memory Accelerator) setup + # These will be used as templates to create tensormaps for all other tensors initial_cute_abc_ptrs = ( + # Pointer to the smallest A matrix (FP4 type) make_ptr( ab_dtype, abc_tensors[min_a_idx][0].data_ptr(), cute.AddressSpace.gmem, assumed_align=16, ), + # Pointer to the smallest B matrix (FP4 type) make_ptr( ab_dtype, abc_tensors[min_b_idx][1].data_ptr(), cute.AddressSpace.gmem, assumed_align=16, ), + # Pointer to the smallest C matrix (FP16 type, output) make_ptr( c_dtype, abc_tensors[min_c_idx][2].data_ptr(), @@ -945,12 +970,14 @@ def custom_kernel(data: input_t) -> output_t: ), ) initial_cute_sfasfb_ptrs = ( + # Pointer to the smallest scale factor A tensor (FP8 type) make_ptr( sf_dtype, sfasfb_reordered_tensors[min_a_idx][0].data_ptr(), cute.AddressSpace.gmem, assumed_align=16, ), + # Pointer to the smallest scale factor B tensor (FP8 type) make_ptr( sf_dtype, sfasfb_reordered_tensors[min_b_idx][1].data_ptr(), @@ -959,32 +986,45 @@ def custom_kernel(data: input_t) -> output_t: ), ) - # Create torch tensor to store problem sizes - # layout (num_groups, 4):(4, 1) + # Create torch tensor to store problem sizes for all groups + # Shape: (num_groups, 4) where each row contains (m, n, k, l) for that group + # Layout: (num_groups, 4):(4, 1) means row-major storage tensor_of_problem_sizes = torch.tensor( problem_sizes, dtype=torch.int32, device="cuda" ) - # Create torch tensors to store abc_ptrs and sfasfb_ptrs - # layout (num_groups,3):(3, 1) + # Create torch tensors to store data pointers for all groups + # These allow the GPU kernel to dynamically access different tensors per group + # tensor_of_abc_ptrs: Shape (num_groups, 3) containing (a_ptr, b_ptr, c_ptr) per group + # tensor_of_sfasfb_ptrs: Shape (num_groups, 2) containing (sfa_ptr, sfb_ptr) per group tensor_of_abc_ptrs = torch.tensor(abc_ptrs, dtype=torch.int64, device="cuda") tensor_of_sfasfb_ptrs = torch.tensor(sfasfb_ptrs, dtype=torch.int64, device="cuda") - # Compute cluster tile shape + # Compute the tile shape for each CUDA Thread Block (CTA) + # cta_tile_shape_mn: [M_tile, N_tile] = [128, 128] for this kernel cta_tile_shape_mn = [128, mma_tiler_mnk[1]] + # cluster_tile_shape_mn: Total tile shape per cluster (same as CTA since cluster is 1x1) cluster_tile_shape_mn = tuple( x * y for x, y in zip(cta_tile_shape_mn, (1, 1)) ) - # Compute total number of cluster tiles we need to compute for given grouped GEMM problem + + # Compute total number of cluster tiles needed across all groups + # Each group's (m, n) dimensions are divided into tiles of size cluster_tile_shape_mn + # This determines the total grid size (bidz dimension) for kernel launch total_num_clusters = 0 num_groups = len(problem_sizes) for m, n, _, _ in problem_sizes: + # Calculate number of tiles needed in M and N dimensions for this group num_clusters_mn = tuple( (x + y - 1) // y for x, y in zip((m, n), cluster_tile_shape_mn) ) + # Multiply M_tiles * N_tiles to get total tiles for this group total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) - # Preserved buffers for each cluster to update its tma descriptor in device memory + # Allocate device memory for tensormap descriptors + # Each cluster needs its own set of tensormaps (one for A, B, SFA, SFB) + # Shape: (total_num_clusters, num_tensormaps=4, bytes_per_tensormap/8=16) + # Tensormaps are hardware descriptors used by TMA for efficient memory transfers tensormap_shape = ( total_num_clusters, num_tensormaps, @@ -992,6 +1032,8 @@ def custom_kernel(data: input_t) -> output_t: ) tensor_of_tensormap = torch.empty(tensormap_shape, dtype=torch.int64, device="cuda") + # Create CuTe pointers to the metadata tensors that will be passed to the kernel + # These allow the GPU kernel to read problem sizes and tensor pointers cute_ptr_of_tensor_of_abc_ptrs = make_ptr( cutlass.Int64, tensor_of_abc_ptrs.data_ptr(), @@ -1011,18 +1053,19 @@ def custom_kernel(data: input_t) -> output_t: assumed_align=16, ) - # Execute the compiled kernel + # Launch the JIT-compiled GPU kernel with all prepared data + # The kernel will perform block-scaled group GEMM: C = A * SFA * B * SFB for all groups my_kernel( - initial_cute_abc_ptrs, - initial_cute_sfasfb_ptrs, - (min_a_idx, min_b_idx, min_c_idx), - cute_ptr_of_tensor_of_problem_sizes, - cute_ptr_of_tensor_of_abc_ptrs, - cute_ptr_of_tensor_of_sfasfb_ptrs, - total_num_clusters, - problem_sizes, - tensor_of_tensormap, - num_groups, + initial_cute_abc_ptrs, # Template pointers for tensormap initialization + initial_cute_sfasfb_ptrs, # Template scale factor pointers + (min_a_idx, min_b_idx, min_c_idx), # Indices of smallest tensors + cute_ptr_of_tensor_of_problem_sizes, # Pointer to problem sizes array + cute_ptr_of_tensor_of_abc_ptrs, # Pointer to ABC tensor pointers array + cute_ptr_of_tensor_of_sfasfb_ptrs, # Pointer to scale factor pointers array + total_num_clusters, # Total number of CTAs to launch + problem_sizes, # Problem sizes list (for host-side processing) + tensor_of_tensormap, # Pre-allocated tensormap buffer + num_groups, # Number of groups in this batch ) res = [] From 1b762721f6124a045db19c272f815d6523cd92d1 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Tue, 4 Nov 2025 03:53:56 +0100 Subject: [PATCH 20/29] WIP: work on integrating with the platform --- problems/nvidia/eval.py | 489 ++++++++++++++++++ problems/nvidia/nvfp4_gemm/utils.py | 176 ------- problems/nvidia/nvfp4_gemv/task.yml | 4 +- problems/nvidia/nvfp4_gemv/utils.py | 176 ------- problems/nvidia/nvfp4_group_gemm/task.yml | 2 +- .../nvidia/{nvfp4_dual_gemm => }/utils.py | 0 6 files changed, 491 insertions(+), 356 deletions(-) create mode 100644 problems/nvidia/eval.py delete mode 100644 problems/nvidia/nvfp4_gemm/utils.py delete mode 100644 problems/nvidia/nvfp4_gemv/utils.py rename problems/nvidia/{nvfp4_dual_gemm => }/utils.py (100%) diff --git a/problems/nvidia/eval.py b/problems/nvidia/eval.py new file mode 100644 index 0000000..6286f7f --- /dev/null +++ b/problems/nvidia/eval.py @@ -0,0 +1,489 @@ +import base64 +import dataclasses +import multiprocessing +import re +import time +import os +import sys +import math +from pathlib import Path +from typing import Any, Optional +import tempfile + +import torch.cuda +from cutlass.cute.nvgpu.common import OpError + +from utils import set_seed, clear_l2_cache + +try: + from task import TestSpec +except ImportError: + TestSpec = dict + +from reference import check_implementation, generate_input + + +class PopcornOutput: + def __init__(self, fd: int): + self.file = os.fdopen(fd, "w") + os.set_inheritable(fd, False) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def print(self, *args, **kwargs): + print(*args, **kwargs, file=self.file, flush=True) + + def log(self, key, value): + self.print(f"{key}: {value}") + + +@dataclasses.dataclass +class TestCase: + args: dict + spec: str + + +def _combine(a: int, b: int) -> int: + # combine two integers into one: + # we need this to generate a secret seed based on the test-level seed and + # the global secret seed. + # the test-level seeds are public knowledge, and typically relatively small numbers, + # so we need to make sure they don't provide any useful info for the full seed. + # This Cantor construction ensures that if the secret seed is a large number, + # then so is the overall seed. + return int(a + (a + b) * (a + b + 1) // 2) + + +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: + try: + content = Path(file_name).read_text() + except Exception as E: + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) + exit(113) + + tests = [] + lines = content.splitlines() + match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" + for line in lines: + parts = line.split(";") + case = {} + for part in parts: + matched = re.match(match, part) + if not re.fullmatch(match, part): + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) + exit(113) + key = matched[1] + val = matched[2] + try: + val = int(val) + except ValueError: + pass + + case[key] = val + tests.append(TestCase(spec=line, args=case)) + + if seed is not None: + for test in tests: + if "seed" in test.args: + test.args["seed"] = _combine(test.args["seed"], seed) + + return tests + + +@dataclasses.dataclass +class Stats: + runs: int + mean: float + std: float + err: float + best: float + worst: float + + +def calculate_stats(durations: list[int]): + """ + Calculate statistical data from a list of durations. + + @param durations: A list of durations in nanoseconds. + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. + """ + runs = len(durations) + total = sum(durations) + best = min(durations) + worst = max(durations) + + avg = total / runs + variance = sum(map(lambda x: (x - avg) ** 2, durations)) + std = math.sqrt(variance / (runs - 1)) + err = std / math.sqrt(runs) + + return Stats( + runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) + ) + + +def _clone_data(data): + """ + Recursively goes through data and clones all tensors. + """ + if isinstance(data, tuple): + return tuple(_clone_data(x) for x in data) + elif isinstance(data, list): + return [_clone_data(x) for x in data] + elif isinstance(data, dict): + return {k: _clone_data(v) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + return data.clone() + else: + return data + + +def _run_single_test(test: TestCase): + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + + data = generate_input(**test.args) + torch.cuda.synchronize() + try: + submission_output = custom_kernel(_clone_data(data)) + + except OpError as E: + print(f"Encountered {E}", file=sys.stderr) + return False, str(E) + torch.cuda.synchronize() + return check_implementation(data, submission_output) + + +def run_single_test(pool: multiprocessing.Pool, test: TestCase): + """ + Runs a single test in another process. + """ + return pool.apply(_run_single_test, (test,)) + + +def run_testing( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes the actual test case code and checks for correctness. + + @param logger: A PopcornOutput object used for logging test results. + @param tests: A list of TestCase objects representing the test cases to be executed. + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. + """ + # Step 1: Compile kernel once before running tests + # compile_success, compile_error = pool.apply(_compile_kernel_once) + # if not compile_success: + # return 112 + + # Step 2: Run all tests with compiled kernel + passed = True + logger.log("test-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"test.{idx}.spec", test.spec) + good, message = run_single_test(pool, test) + if not good: + logger.log(f"test.{idx}.status", "fail") + logger.log(f"test.{idx}.error", message) + passed = False + else: + logger.log(f"test.{idx}.status", "pass") + if message: + logger.log(f"test.{idx}.message", message) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def _compile_kernel_once(): + """ + Compile the kernel once before any benchmarking. + This ensures compilation time is not included in benchmark results. + """ + from submission import compile_kernel + + try: + compile_kernel() + torch.cuda.synchronize() + return True, None + except OpError as E: + return False, f"Compilation failed: {E}" + except Exception as E: + return False, f"Compilation failed: {E}" + + +def _run_single_benchmark( + test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float +) -> Stats | Any: + """ + Runs one benchmark. Do not call directly. + """ + from submission import custom_kernel, compile_kernel + + durations = [] + # generate input data once + data = generate_input(**test.args) + check_copy = _clone_data(data) + + # Ensure kernel is compiled before any timing (compilation is cached) + # try: + # a, b, c = data + # compile_kernel(a, b, c) + # torch.cuda.synchronize() + # except OpError as E: + # return f"Compilation failed: {E}" + # except Exception as E: + # return f"Compilation failed: {E}" + + # first, one obligatory correctness check + try: + output = custom_kernel(_clone_data(data)) + except OpError as E: + return f"Encountered {E}" + good, message = check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs without further correctness testing + # there is an upper bound of 200 runs, and a lower bound of 3 runs; + # otherwise, we repeat until we either measure at least 10 full seconds, + # or the relative error of the mean is below 1%. + + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args) + check_copy = _clone_data(data) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + clear_l2_cache() + + start_event.record() + output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + return message + + del output + durations.append(duration) + + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + if ( + stats.err / stats.mean < 0.001 + or stats.mean * stats.runs > max_time_ns + or total_bm_duration > 120e9 + ): + break + + return calculate_stats(durations) + + +def run_single_benchmark( + pool: multiprocessing.Pool, + test: TestCase, + recheck: bool, + max_repeats: int, + max_time_ns: float, +): + """ + For a particular test case, check correctness (if applicable) and grab runtime results. + + @param pool: Process on which the benchmark will be launched. + @param test: TestCase object. + @param recheck: Flag for whether to explicitly check functional correctness. + @param max_repeats: Number of trials to repeat. + @param max_time_ns: Timeout time in nanoseconds. + @return: A Stats object for this particular benchmark case or an error if the test fails. + """ + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + + +def run_benchmarking( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes benchmarking code for a CUDA Kernel and logs runtimes. + + @param logger: A PopcornOutput object used for logging benchmark results. + @param pool: Process on which the benchmarks will be launched. + @param tests: A list of TestCase objects representing the test cases to be benchmarked. + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. + """ + # Step 1: Compile kernel once (outside of timing) + # compile_success, compile_error = pool.apply(_compile_kernel_once) + # if not compile_success: + # return 112 + + # Step 2: Warm up with compiled kernel + run_single_benchmark(pool, tests[0], False, 200, 10e7) + + # Step 3: Run benchmarks (compilation time excluded) + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(pool, test, False, 200, 10e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log( + f"benchmark.{idx}.report", + base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), + ) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + seed = int(seed) if seed else None + set_seed(seed or 42) + + # filename = None + + # with tempfile.NamedTemporaryFile(delete=False) as tmp: + + # def build_test_string(tests: list[dict]): + # as_str = "" + # for test in tests: + # kvs = [] + # for k, v in test.items(): + # kvs.append(f"{k}: {v}") + # as_str += "; ".join(kvs) + "\n" + # return as_str + + # import yaml + # print(sys.argv[2]) + # print(open(sys.argv[2], "r").read()) + + # yaml_content = yaml.safe_load(open(sys.argv[2], "r")) + # if mode == "test": + # tests_str = build_test_string(yaml_content.get("tests", [])) + # elif mode in ("benchmark", "leaderboard", "profile"): + # tests_str = build_test_string(yaml_content.get("benchmarks", [])) + + # tmp.write(tests_str.encode("utf-8")) + # tmp.flush() + # filename = tmp.name + + + tests = get_test_cases(sys.argv[2], seed) + + with PopcornOutput(int(fd)) as logger: + import multiprocessing + + mp_context = multiprocessing.get_context("spawn") + with mp_context.Pool(1) as pool: + if mode == "test": + return run_testing(logger, pool, tests) + if mode == "benchmark": + return run_benchmarking(logger, pool, tests) + + if mode == "leaderboard": + # Step 1: Compile kernel once (outside of timing) + # compile_success, compile_error = pool.apply(_compile_kernel_once) + # if not compile_success: + # return 112 + + # Step 2: Warmup with compiled kernel + run_single_benchmark(pool, tests[0], False, 200, 1e7) + + # Step 3: Run leaderboard benchmarks (compilation time excluded) + logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(pool, tests[i], True, 200, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log( + f"benchmark.{i}.{field.name}", + getattr(result, field.name), + ) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log( + f"benchmark.{i}.error", str(result) + ) # TODO: Make sure result implements __str__? + break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # TODO: Implement script mode + return 2 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/problems/nvidia/nvfp4_gemm/utils.py b/problems/nvidia/nvfp4_gemm/utils.py deleted file mode 100644 index e8a9082..0000000 --- a/problems/nvidia/nvfp4_gemm/utils.py +++ /dev/null @@ -1,176 +0,0 @@ -import os -import random -import numpy as np -import torch - - -def set_seed(seed=42): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def get_device(use_cuda: bool = True) -> torch.device: - """Get the appropriate device (GPU or CPU).""" - if use_cuda: - if torch.cuda.is_available(): - return torch.device("cuda") - elif torch.backends.mps.is_available(): - return torch.device("mps") - else: - print("No compatible GPU found. Falling back to CPU.") - return torch.device("cpu") - - -# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py -@torch.no_grad() -def verbose_allclose( - received: torch.Tensor, - expected: torch.Tensor, - rtol=1e-05, - atol=1e-08, - max_print=5 -) -> list[str]: - """ - Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. - - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - rtol (float): Relative tolerance; relative to expected - atol (float): Absolute tolerance. - max_print (int): Maximum number of mismatched elements to print. - - Raises: - AssertionError: If the tensors are not all close within the given tolerance. - """ - # Check if the shapes of the tensors match - if received.shape != expected.shape: - return ["SIZE MISMATCH"] - - # Calculate the difference between the tensors - diff = torch.abs(received - expected) - - # Determine the tolerance - tolerance = atol + rtol * torch.abs(expected) - - # Find tolerance mismatched elements - tol_mismatched = diff > tolerance - - # Find nan mismatched elements - nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) - - # Find +inf mismatched elements - posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) - # Find -inf mismatched elements - neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) - - # Find all mismatched elements - mismatched = torch.logical_or( - torch.logical_or(tol_mismatched, nan_mismatched), - torch.logical_or(posinf_mismatched, neginf_mismatched), - ) - - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") - return mismatch_details - - return [] - - -@torch.no_grad() -def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5): - """ - Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. - - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - max_print (int): Maximum number of mismatched elements to print. - - Returns: - Empty string if tensors are equal, otherwise detailed error information - """ - mismatched = torch.not_equal(received, expected) - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") - return mismatch_details - - return [] - - -def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]: - """ - Convenient "default" implementation for tasks' `check_implementation` function. - """ - expected = reference(data) - reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) - - if len(reasons) > 0: - return False, "mismatch found! custom implementation doesn't match reference: " + " ".join(reasons) - - return True, '' - - -def make_match_reference(reference: callable, **kwargs): - def wrapped(data, output): - return match_reference(data, output, reference=reference, **kwargs) - return wrapped - - -class DeterministicContext: - def __init__(self): - self.allow_tf32 = None - self.deterministic = None - self.cublas = None - - def __enter__(self): - self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '') - self.allow_tf32 = torch.backends.cudnn.allow_tf32 - self.deterministic = torch.backends.cudnn.deterministic - torch.backends.cudnn.allow_tf32 = False - torch.backends.cudnn.deterministic = True - torch.use_deterministic_algorithms(True) - return self - - def __exit__(self, exc_type, exc_value, traceback): - torch.backends.cudnn.allow_tf32 = self.allow_tf32 - torch.backends.cudnn.deterministic = self.deterministic - torch.use_deterministic_algorithms(False) - os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas - -def clear_l2_cache(): - # import cupy as cp - # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) - # create a large dummy tensor - dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") - # write stuff to - dummy.fill_(42) - del dummy \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemv/task.yml b/problems/nvidia/nvfp4_gemv/task.yml index adf8cab..756fe80 100644 --- a/problems/nvidia/nvfp4_gemv/task.yml +++ b/problems/nvidia/nvfp4_gemv/task.yml @@ -1,5 +1,3 @@ -# name: nvfp4-ffma-gemv - files: - {"name": "submission.py", "source": "@SUBMISSION@"} - {"name": "task.py", "source": "task.py"} @@ -58,4 +56,4 @@ benchmarks: - {"m": 4096, "k": 7168, "l":8, "seed": 1111} - {"m": 7168, "k": 2048, "l":4, "seed": 1111} -ranking_by: "geom" \ No newline at end of file +ranking_by: "geom" diff --git a/problems/nvidia/nvfp4_gemv/utils.py b/problems/nvidia/nvfp4_gemv/utils.py deleted file mode 100644 index e8a9082..0000000 --- a/problems/nvidia/nvfp4_gemv/utils.py +++ /dev/null @@ -1,176 +0,0 @@ -import os -import random -import numpy as np -import torch - - -def set_seed(seed=42): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def get_device(use_cuda: bool = True) -> torch.device: - """Get the appropriate device (GPU or CPU).""" - if use_cuda: - if torch.cuda.is_available(): - return torch.device("cuda") - elif torch.backends.mps.is_available(): - return torch.device("mps") - else: - print("No compatible GPU found. Falling back to CPU.") - return torch.device("cpu") - - -# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py -@torch.no_grad() -def verbose_allclose( - received: torch.Tensor, - expected: torch.Tensor, - rtol=1e-05, - atol=1e-08, - max_print=5 -) -> list[str]: - """ - Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. - - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - rtol (float): Relative tolerance; relative to expected - atol (float): Absolute tolerance. - max_print (int): Maximum number of mismatched elements to print. - - Raises: - AssertionError: If the tensors are not all close within the given tolerance. - """ - # Check if the shapes of the tensors match - if received.shape != expected.shape: - return ["SIZE MISMATCH"] - - # Calculate the difference between the tensors - diff = torch.abs(received - expected) - - # Determine the tolerance - tolerance = atol + rtol * torch.abs(expected) - - # Find tolerance mismatched elements - tol_mismatched = diff > tolerance - - # Find nan mismatched elements - nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) - - # Find +inf mismatched elements - posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) - # Find -inf mismatched elements - neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) - - # Find all mismatched elements - mismatched = torch.logical_or( - torch.logical_or(tol_mismatched, nan_mismatched), - torch.logical_or(posinf_mismatched, neginf_mismatched), - ) - - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") - return mismatch_details - - return [] - - -@torch.no_grad() -def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5): - """ - Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. - - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - max_print (int): Maximum number of mismatched elements to print. - - Returns: - Empty string if tensors are equal, otherwise detailed error information - """ - mismatched = torch.not_equal(received, expected) - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") - return mismatch_details - - return [] - - -def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]: - """ - Convenient "default" implementation for tasks' `check_implementation` function. - """ - expected = reference(data) - reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) - - if len(reasons) > 0: - return False, "mismatch found! custom implementation doesn't match reference: " + " ".join(reasons) - - return True, '' - - -def make_match_reference(reference: callable, **kwargs): - def wrapped(data, output): - return match_reference(data, output, reference=reference, **kwargs) - return wrapped - - -class DeterministicContext: - def __init__(self): - self.allow_tf32 = None - self.deterministic = None - self.cublas = None - - def __enter__(self): - self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '') - self.allow_tf32 = torch.backends.cudnn.allow_tf32 - self.deterministic = torch.backends.cudnn.deterministic - torch.backends.cudnn.allow_tf32 = False - torch.backends.cudnn.deterministic = True - torch.use_deterministic_algorithms(True) - return self - - def __exit__(self, exc_type, exc_value, traceback): - torch.backends.cudnn.allow_tf32 = self.allow_tf32 - torch.backends.cudnn.deterministic = self.deterministic - torch.use_deterministic_algorithms(False) - os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas - -def clear_l2_cache(): - # import cupy as cp - # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) - # create a large dummy tensor - dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") - # write stuff to - dummy.fill_(42) - del dummy \ No newline at end of file diff --git a/problems/nvidia/nvfp4_group_gemm/task.yml b/problems/nvidia/nvfp4_group_gemm/task.yml index 9f0df2b..6390ae9 100644 --- a/problems/nvidia/nvfp4_group_gemm/task.yml +++ b/problems/nvidia/nvfp4_group_gemm/task.yml @@ -3,7 +3,7 @@ files: - {"name": "submission.py", "source": "@SUBMISSION@"} - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../utils.py"} + - {"name": "utils.py", "source": "utils.py"} - {"name": "reference.py", "source": "reference.py"} - {"name": "eval.py", "source": "../eval.py"} diff --git a/problems/nvidia/nvfp4_dual_gemm/utils.py b/problems/nvidia/utils.py similarity index 100% rename from problems/nvidia/nvfp4_dual_gemm/utils.py rename to problems/nvidia/utils.py From f0a784d1dc9c5353d65c53ce382e22239e22c570 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Tue, 4 Nov 2025 06:57:03 -0800 Subject: [PATCH 21/29] add local eval file. --- problems/nvidia/nvfp4_gemm/eval.py | 437 +++++++++++++++++++++++++++++ 1 file changed, 437 insertions(+) create mode 100644 problems/nvidia/nvfp4_gemm/eval.py diff --git a/problems/nvidia/nvfp4_gemm/eval.py b/problems/nvidia/nvfp4_gemm/eval.py new file mode 100644 index 0000000..072b176 --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/eval.py @@ -0,0 +1,437 @@ +import base64 +import dataclasses +import multiprocessing +import re +import time +import os +import sys +import math +from pathlib import Path +from typing import Any, Optional +import tempfile + +import torch.cuda +from cutlass.cute.nvgpu.common import OpError + +from utils import set_seed, clear_l2_cache + +try: + from task import TestSpec +except ImportError: + TestSpec = dict + +from reference import check_implementation, generate_input + + +class PopcornOutput: + def __init__(self, fd: int): + self.file = os.fdopen(fd, "w") + os.set_inheritable(fd, False) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def print(self, *args, **kwargs): + print(*args, **kwargs, file=self.file, flush=True) + + def log(self, key, value): + self.print(f"{key}: {value}") + + +@dataclasses.dataclass +class TestCase: + args: dict + spec: str + + +def _combine(a: int, b: int) -> int: + # combine two integers into one: + # we need this to generate a secret seed based on the test-level seed and + # the global secret seed. + # the test-level seeds are public knowledge, and typically relatively small numbers, + # so we need to make sure they don't provide any useful info for the full seed. + # This Cantor construction ensures that if the secret seed is a large number, + # then so is the overall seed. + return int(a + (a + b) * (a + b + 1) // 2) + + +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: + try: + content = Path(file_name).read_text() + except Exception as E: + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) + exit(113) + + tests = [] + lines = content.splitlines() + match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" + for line in lines: + parts = line.split(";") + case = {} + for part in parts: + matched = re.match(match, part) + if not re.fullmatch(match, part): + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) + exit(113) + key = matched[1] + val = matched[2] + try: + val = int(val) + except ValueError: + pass + + case[key] = val + tests.append(TestCase(spec=line, args=case)) + + if seed is not None: + for test in tests: + if "seed" in test.args: + test.args["seed"] = _combine(test.args["seed"], seed) + + return tests + + +@dataclasses.dataclass +class Stats: + runs: int + mean: float + std: float + err: float + best: float + worst: float + + +def calculate_stats(durations: list[int]): + """ + Calculate statistical data from a list of durations. + @param durations: A list of durations in nanoseconds. + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. + """ + runs = len(durations) + total = sum(durations) + best = min(durations) + worst = max(durations) + + avg = total / runs + variance = sum(map(lambda x: (x - avg) ** 2, durations)) + std = math.sqrt(variance / (runs - 1)) + err = std / math.sqrt(runs) + + return Stats( + runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) + ) + + +def _clone_data(data): + """ + Recursively goes through data and clones all tensors. + """ + if isinstance(data, tuple): + return tuple(_clone_data(x) for x in data) + elif isinstance(data, list): + return [_clone_data(x) for x in data] + elif isinstance(data, dict): + return {k: _clone_data(v) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + return data.clone() + else: + return data + + +def _run_single_test(test: TestCase): + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + + data = generate_input(**test.args) + torch.cuda.synchronize() + try: + submission_output = custom_kernel(_clone_data(data)) + + except OpError as E: + print(f"Encountered {E}", file=sys.stderr) + return False, str(E) + torch.cuda.synchronize() + return check_implementation(data, submission_output) + + +def run_single_test(pool: multiprocessing.Pool, test: TestCase): + """ + Runs a single test in another process. + """ + return pool.apply(_run_single_test, (test,)) + + +def run_testing( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes the actual test case code and checks for correctness. + @param logger: A PopcornOutput object used for logging test results. + @param tests: A list of TestCase objects representing the test cases to be executed. + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. + """ + passed = True + logger.log("test-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"test.{idx}.spec", test.spec) + good, message = run_single_test(pool, test) + if not good: + logger.log(f"test.{idx}.status", "fail") + logger.log(f"test.{idx}.error", message) + passed = False + else: + logger.log(f"test.{idx}.status", "pass") + if message: + logger.log(f"test.{idx}.message", message) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def _run_single_benchmark( + test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float +) -> Stats | Any: + """ + Runs one benchmark. Do not call directly. + """ + from submission import custom_kernel + + durations = [] + # generate input data once + data = generate_input(**test.args) + check_copy = _clone_data(data) + # first, one obligatory correctness check + try: + output = custom_kernel(_clone_data(data)) + except OpError as E: + return f"Encountered {E}" + good, message = check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs without further correctness testing + # there is an upper bound of 100 runs, and a lower bound of 3 runs; + # otherwise, we repeat until we either measure at least 10 full seconds, + # or the relative error of the mean is below 1%. + + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args) + check_copy = _clone_data(data) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + clear_l2_cache() + + start_event.record() + output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + return message + + del output + durations.append(duration) + + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + if ( + stats.err / stats.mean < 0.001 + or stats.mean * stats.runs > max_time_ns + or total_bm_duration > 120e9 + ): + break + + return calculate_stats(durations) + + +def run_single_benchmark( + pool: multiprocessing.Pool, + test: TestCase, + recheck: bool, + max_repeats: int, + max_time_ns: float, +): + """ + For a particular test case, check correctness (if applicable) and grab runtime results. + @param pool: Process on which the benchmark will be launched. + @param test: TestCase object. + @param recheck: Flag for whether to explicitly check functional correctness. + @param max_repeats: Number of trials to repeat. + @param max_time_ns: Timeout time in nanoseconds. + @return: A Stats object for this particular benchmark case or an error if the test fails. + """ + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + + +def run_benchmarking( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes benchmarking code for a CUDA Kernel and logs runtimes. + @param logger: A PopcornOutput object used for logging benchmark results. + @param pool: Process on which the benchmarks will be launched. + @param tests: A list of TestCase objects representing the test cases to be benchmarked. + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. + """ + # warm up + run_single_benchmark(pool, tests[0], False, 100, 10e7) + + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(pool, test, False, 100, 10e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log( + f"benchmark.{idx}.report", + base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), + ) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + seed = int(seed) if seed else None + set_seed(seed or 42) + + filename = None + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + + def build_test_string(tests: list[dict]): + as_str = "" + for test in tests: + kvs = [] + for k, v in test.items(): + kvs.append(f"{k}: {v}") + as_str += "; ".join(kvs) + "\n" + return as_str + + import yaml + + yaml_content = yaml.safe_load(open(sys.argv[2], "r")) + if mode == "test": + tests_str = build_test_string(yaml_content.get("tests", [])) + elif mode in ("benchmark", "leaderboard", "profile"): + tests_str = build_test_string(yaml_content.get("benchmarks", [])) + + tmp.write(tests_str.encode("utf-8")) + tmp.flush() + filename = tmp.name + + tests = get_test_cases(filename, seed) + + os.unlink(filename) + + with PopcornOutput(int(fd)) as logger: + import multiprocessing + + mp_context = multiprocessing.get_context("spawn") + with mp_context.Pool(1) as pool: + if mode == "test": + return run_testing(logger, pool, tests) + if mode == "benchmark": + return run_benchmarking(logger, pool, tests) + + if mode == "leaderboard": + # warmup + run_single_benchmark(pool, tests[0], False, 100, 1e7) + logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(pool, tests[i], True, 100, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log( + f"benchmark.{i}.{field.name}", + getattr(result, field.name), + ) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log( + f"benchmark.{i}.error", str(result) + ) # TODO: Make sure result implements __str__? + break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # TODO: Implement script mode + return 2 + + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file From 724160c2b5e7d176072c0fdcb8731b5d301a1ed4 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Wed, 5 Nov 2025 15:14:28 -0800 Subject: [PATCH 22/29] tight the tolerance value --- problems/nvidia/nvfp4_dual_gemm/reference.py | 2 +- problems/nvidia/nvfp4_gemm/reference.py | 2 +- problems/nvidia/nvfp4_gemv/reference.py | 2 +- problems/nvidia/nvfp4_group_gemm/reference.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/problems/nvidia/nvfp4_dual_gemm/reference.py b/problems/nvidia/nvfp4_dual_gemm/reference.py index 29ee4cb..13880b3 100644 --- a/problems/nvidia/nvfp4_dual_gemm/reference.py +++ b/problems/nvidia/nvfp4_dual_gemm/reference.py @@ -187,4 +187,4 @@ def create_scale_factor_tensors(l, mn, sf_k): return (a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, sfa_ref_permuted, sfb1_ref_permuted, sfb2_ref_permuted, c_ref) -check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) \ No newline at end of file +check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03) \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemm/reference.py b/problems/nvidia/nvfp4_gemm/reference.py index cd6dfb4..dc55c84 100644 --- a/problems/nvidia/nvfp4_gemm/reference.py +++ b/problems/nvidia/nvfp4_gemm/reference.py @@ -154,4 +154,4 @@ def create_scale_factor_tensors(l, mn, sf_k): return (a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, sfa_ref_permuted, sfb_ref_permuted, c_ref) -check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) +check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03) diff --git a/problems/nvidia/nvfp4_gemv/reference.py b/problems/nvidia/nvfp4_gemv/reference.py index 0a9c93d..cd1b2d1 100644 --- a/problems/nvidia/nvfp4_gemv/reference.py +++ b/problems/nvidia/nvfp4_gemv/reference.py @@ -158,4 +158,4 @@ def create_scale_factor_tensors(l, mn, sf_k): return (a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, sfa_permuted, sfb_permuted, c_ref) -check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) \ No newline at end of file +check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03) \ No newline at end of file diff --git a/problems/nvidia/nvfp4_group_gemm/reference.py b/problems/nvidia/nvfp4_group_gemm/reference.py index 5fcb386..6fe5add 100644 --- a/problems/nvidia/nvfp4_group_gemm/reference.py +++ b/problems/nvidia/nvfp4_group_gemm/reference.py @@ -172,4 +172,4 @@ def generate_input( return (abc_tensors, sfasfb_tensors, sfasfb_reordered_tensors, problem_sizes) -check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-02) +check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03) From c4aecfef12b1a7a4c7cd9331a12d37ae538b3a07 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Thu, 6 Nov 2025 15:48:56 -0800 Subject: [PATCH 23/29] optimize data convert function in reference. --- problems/nvidia/nvfp4_dual_gemm/reference.py | 54 ++++++++++--------- problems/nvidia/nvfp4_gemm/reference.py | 54 ++++++++++--------- problems/nvidia/nvfp4_gemv/reference.py | 53 +++++++++--------- problems/nvidia/nvfp4_group_gemm/reference.py | 43 +++++++++------ 4 files changed, 114 insertions(+), 90 deletions(-) diff --git a/problems/nvidia/nvfp4_dual_gemm/reference.py b/problems/nvidia/nvfp4_dual_gemm/reference.py index 13880b3..4901af6 100644 --- a/problems/nvidia/nvfp4_dual_gemm/reference.py +++ b/problems/nvidia/nvfp4_dual_gemm/reference.py @@ -130,20 +130,17 @@ def generate_input( ) # Helper function to prepare the scale factor tensors for both reference - # kernel and customize kernel. Please note this data reordering function - # is very slow, and the customized data layout can be found in the following link: + # kernel and customize kernel. The customized data layout can be found in: # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout def create_scale_factor_tensors(l, mn, sf_k): - # Create the reference scale factor tensor (mn, l, sf_k) on CPU. + # Create the reference scale factor tensor (mn, sf_k, l) on CPU. ref_shape = (l, mn, sf_k) ref_permute_order = (1, 2, 0) # Init with uint8 tensor, then convert to float8_e4m3fn - ref_f8_random_int = torch.randint(1, 3, ref_shape, dtype=torch.int8) - ref_f8_torch_tensor_cpu = ref_f8_random_int.to(dtype=torch.float8_e4m3fn) + ref_f8_random_int = torch.randint(1, 3, ref_shape, dtype=torch.int8, device='cuda') + ref_f8_torch_tensor = ref_f8_random_int.to(dtype=torch.float8_e4m3fn) # permute to match ref_permute_order - ref_f8_torch_tensor_cpu_permuted = ref_f8_torch_tensor_cpu.permute( - *ref_permute_order - ) + ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order) atom_m = (32, 4) atom_k = 4 @@ -160,24 +157,31 @@ def create_scale_factor_tensors(l, mn, sf_k): # Which is needed by the CuTe customized kernel mma_permute_order = (3, 4, 1, 5, 2, 0) # Generate a random int8 tensor, then convert to float8_e4m3fn - rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) - reordered_f8_torch_tensor_cpu = rand_int_tensor.to(dtype=torch.float8_e4m3fn) + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8, device='cuda') + reordered_f8_torch_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) # Permute according to mma_permute_order - reordered_f8_torch_tensor_cpu = reordered_f8_torch_tensor_cpu.permute( - *mma_permute_order - ) - - for i in range(mn): - for j in range(sf_k): - for b in range(l): - # Calculate the location in MMA shape - mm = i // (atom_m[0] * atom_m[1]) - mm32 = i % atom_m[0] - mm4 = (i % 128) // atom_m[0] - kk = j // atom_k - kk4 = j % atom_k - reordered_f8_torch_tensor_cpu[mm32, mm4, mm, kk4, kk, b] = ref_f8_torch_tensor_cpu_permuted[i, j, b] - return ref_f8_torch_tensor_cpu_permuted, reordered_f8_torch_tensor_cpu.cuda() + reordered_f8_torch_tensor = reordered_f8_torch_tensor.permute(*mma_permute_order) + + # GPU-side vectorized reordering (replaces slow CPU nested loops) + # Create index grids for all dimensions + i_idx = torch.arange(mn, device='cuda') + j_idx = torch.arange(sf_k, device='cuda') + b_idx = torch.arange(l, device='cuda') + + # Create meshgrid for all combinations of (i, j, b) + i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij') + + # Calculate target indices in vectorized manner + mm = i_grid // (atom_m[0] * atom_m[1]) + mm32 = i_grid % atom_m[0] + mm4 = (i_grid % 128) // atom_m[0] + kk = j_grid // atom_k + kk4 = j_grid % atom_k + + # Perform the reordering with advanced indexing (all on GPU) + reordered_f8_torch_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_torch_tensor_permuted[i_grid, j_grid, b_grid] + + return ref_f8_torch_tensor_permuted.cpu(), reordered_f8_torch_tensor sf_k = ceil_div(k, sf_vec_size) sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k) diff --git a/problems/nvidia/nvfp4_gemm/reference.py b/problems/nvidia/nvfp4_gemm/reference.py index dc55c84..51dd750 100644 --- a/problems/nvidia/nvfp4_gemm/reference.py +++ b/problems/nvidia/nvfp4_gemm/reference.py @@ -98,20 +98,17 @@ def generate_input( ) # Helper function to prepare the scale factor tensors for both reference - # kernel and customize kernel. Please note this data reordering function - # is very slow, and the customized data layout can be found in the following link: + # kernel and customize kernel. The customized data layout can be found in: # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout def create_scale_factor_tensors(l, mn, sf_k): - # Create the reference scale factor tensor (mn, l, sf_k) on CPU. + # Create the reference scale factor tensor (mn, sf_k, l) on CPU. ref_shape = (l, mn, sf_k) ref_permute_order = (1, 2, 0) # Init with uint8 tensor, then convert to float8_e4m3fn - ref_f8_random_int = torch.randint(1, 3, ref_shape, dtype=torch.int8) - ref_f8_torch_tensor_cpu = ref_f8_random_int.to(dtype=torch.float8_e4m3fn) + ref_f8_random_int = torch.randint(1, 3, ref_shape, dtype=torch.int8, device='cuda') + ref_f8_torch_tensor = ref_f8_random_int.to(dtype=torch.float8_e4m3fn) # permute to match ref_permute_order - ref_f8_torch_tensor_cpu_permuted = ref_f8_torch_tensor_cpu.permute( - *ref_permute_order - ) + ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order) atom_m = (32, 4) atom_k = 4 @@ -128,24 +125,31 @@ def create_scale_factor_tensors(l, mn, sf_k): # Which is needed by the CuTe customized kernel mma_permute_order = (3, 4, 1, 5, 2, 0) # Generate a random int8 tensor, then convert to float8_e4m3fn - rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) - reordered_f8_torch_tensor_cpu = rand_int_tensor.to(dtype=torch.float8_e4m3fn) + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8, device='cuda') + reordered_f8_torch_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) # Permute according to mma_permute_order - reordered_f8_torch_tensor_cpu = reordered_f8_torch_tensor_cpu.permute( - *mma_permute_order - ) - - for i in range(mn): - for j in range(sf_k): - for b in range(l): - # Calculate the location in MMA shape - mm = i // (atom_m[0] * atom_m[1]) - mm32 = i % atom_m[0] - mm4 = (i % 128) // atom_m[0] - kk = j // atom_k - kk4 = j % atom_k - reordered_f8_torch_tensor_cpu[mm32, mm4, mm, kk4, kk, b] = ref_f8_torch_tensor_cpu_permuted[i, j, b] - return ref_f8_torch_tensor_cpu_permuted, reordered_f8_torch_tensor_cpu.cuda() + reordered_f8_torch_tensor = reordered_f8_torch_tensor.permute(*mma_permute_order) + + # GPU-side vectorized reordering (replaces slow CPU nested loops) + # Create index grids for all dimensions + i_idx = torch.arange(mn, device='cuda') + j_idx = torch.arange(sf_k, device='cuda') + b_idx = torch.arange(l, device='cuda') + + # Create meshgrid for all combinations of (i, j, b) + i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij') + + # Calculate target indices in vectorized manner + mm = i_grid // (atom_m[0] * atom_m[1]) + mm32 = i_grid % atom_m[0] + mm4 = (i_grid % 128) // atom_m[0] + kk = j_grid // atom_k + kk4 = j_grid % atom_k + + # Perform the reordering with advanced indexing (all on GPU) + reordered_f8_torch_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_torch_tensor_permuted[i_grid, j_grid, b_grid] + + return ref_f8_torch_tensor_permuted.cpu(), reordered_f8_torch_tensor sf_k = ceil_div(k, sf_vec_size) sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k) diff --git a/problems/nvidia/nvfp4_gemv/reference.py b/problems/nvidia/nvfp4_gemv/reference.py index cd1b2d1..b01ef18 100644 --- a/problems/nvidia/nvfp4_gemv/reference.py +++ b/problems/nvidia/nvfp4_gemv/reference.py @@ -103,20 +103,17 @@ def generate_input( ) # Helper function to prepare the scale factor tensors for both reference - # kernel and customize kernel. Please note this data reordering function - # is very slow, and the customized data layout can be found in the following link: + # kernel and customize kernel. The customized data layout can be found in: # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout def create_scale_factor_tensors(l, mn, sf_k): - # Create the reference scale factor tensor (mn, l, sf_k) on CPU. + # Create the reference scale factor tensor (mn, sf_k, l) on CPU. ref_shape = (l, mn, sf_k) ref_permute_order = (1, 2, 0) # Init with uint8 tensor, then convert to float8_e4m3fn - ref_f8_random_int = torch.randint(1, 3, ref_shape, dtype=torch.int8) - ref_f8_torch_tensor_cpu = ref_f8_random_int.to(dtype=torch.float8_e4m3fn) + ref_f8_random_int = torch.randint(1, 3, ref_shape, dtype=torch.int8, device='cuda') + ref_f8_torch_tensor = ref_f8_random_int.to(dtype=torch.float8_e4m3fn) # permute to match ref_permute_order - ref_f8_torch_tensor_cpu_permuted = ref_f8_torch_tensor_cpu.permute( - *ref_permute_order - ) + ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order) atom_m = (32, 4) atom_k = 4 @@ -133,23 +130,31 @@ def create_scale_factor_tensors(l, mn, sf_k): # Which is needed by the CuTe customized kernel mma_permute_order = (3, 4, 1, 5, 2, 0) # Generate a random int8 tensor, then convert to float8_e4m3fn - rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) - reordered_f8_torch_tensor_cpu = rand_int_tensor.to(dtype=torch.float8_e4m3fn) + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8, device='cuda') + reordered_f8_torch_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) # Permute according to mma_permute_order - reordered_f8_torch_tensor_cpu = reordered_f8_torch_tensor_cpu.permute( - *mma_permute_order - ) - for i in range(mn): - for j in range(sf_k): - for b in range(l): - # Calculate the location in MMA shape - mm = i // (atom_m[0] * atom_m[1]) - mm32 = i % atom_m[0] - mm4 = (i % 128) // atom_m[0] - kk = j // atom_k - kk4 = j % atom_k - reordered_f8_torch_tensor_cpu[mm32, mm4, mm, kk4, kk, b] = ref_f8_torch_tensor_cpu_permuted[i, j, b] - return ref_f8_torch_tensor_cpu_permuted, reordered_f8_torch_tensor_cpu.cuda() + reordered_f8_torch_tensor = reordered_f8_torch_tensor.permute(*mma_permute_order) + + # GPU-side vectorized reordering (replaces slow CPU nested loops) + # Create index grids for all dimensions + i_idx = torch.arange(mn, device='cuda') + j_idx = torch.arange(sf_k, device='cuda') + b_idx = torch.arange(l, device='cuda') + + # Create meshgrid for all combinations of (i, j, b) + i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij') + + # Calculate target indices in vectorized manner + mm = i_grid // (atom_m[0] * atom_m[1]) + mm32 = i_grid % atom_m[0] + mm4 = (i_grid % 128) // atom_m[0] + kk = j_grid // atom_k + kk4 = j_grid % atom_k + + # Perform the reordering with advanced indexing (all on GPU) + reordered_f8_torch_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_torch_tensor_permuted[i_grid, j_grid, b_grid] + + return ref_f8_torch_tensor_permuted.cpu(), reordered_f8_torch_tensor sf_k = ceil_div(k, sf_vec_size) sfa_ref_cpu, sfa_permuted = create_scale_factor_tensors(l, m, sf_k) diff --git a/problems/nvidia/nvfp4_group_gemm/reference.py b/problems/nvidia/nvfp4_group_gemm/reference.py index 6fe5add..f71da00 100644 --- a/problems/nvidia/nvfp4_group_gemm/reference.py +++ b/problems/nvidia/nvfp4_group_gemm/reference.py @@ -64,8 +64,7 @@ def ref_kernel( # Helper function to prepare the scale factor tensors for both reference -# kernel and customize kernel. Please note this data reordering function -# is very slow, and the customized data layout can be found in the following link: +# kernel and customize kernel. The customized data layout can be found in: # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor): sf_k = ceil_div(k, sf_vec_size) @@ -79,26 +78,38 @@ def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor): atom_m[1], atom_k, ) - # Create the reordered scale factor tensor (32, 4, rest_m, 4, rest_k, l) on CPU. + # Create the reordered scale factor tensor (32, 4, rest_m, 4, rest_k, l) on GPU. mma_permute_order = (3, 4, 1, 5, 2, 0) # Generate a random int8 tensor, then convert to float8_e4m3fn - rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8) + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8, device='cuda') reordered_f8_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) # Permute according to mma_permute_order reordered_f8_tensor = reordered_f8_tensor.permute(*mma_permute_order) - # Please note this movement code is very slow. - for i in range(mn): - for j in range(sf_k): - for b in range(l): - # Calculate the location in MMA shape - mm = i // (atom_m[0] * atom_m[1]) - mm32 = i % atom_m[0] - mm4 = (i % 128) // atom_m[0] - kk = j // atom_k - kk4 = j % atom_k - reordered_f8_tensor[mm32, mm4, mm, kk4, kk, b] = ref_f8_tensor[i, j, b] - return reordered_f8_tensor.cuda() + # Move ref_f8_tensor to GPU if not already there + if ref_f8_tensor.device.type == 'cpu': + ref_f8_tensor = ref_f8_tensor.cuda() + + # GPU-side vectorized reordering (replaces slow CPU nested loops) + # Create index grids for all dimensions + i_idx = torch.arange(mn, device='cuda') + j_idx = torch.arange(sf_k, device='cuda') + b_idx = torch.arange(l, device='cuda') + + # Create meshgrid for all combinations of (i, j, b) + i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij') + + # Calculate target indices in vectorized manner + mm = i_grid // (atom_m[0] * atom_m[1]) + mm32 = i_grid % atom_m[0] + mm4 = (i_grid % 128) // atom_m[0] + kk = j_grid // atom_k + kk4 = j_grid % atom_k + + # Perform the reordering with advanced indexing (all on GPU) + reordered_f8_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_tensor[i_grid, j_grid, b_grid] + + return reordered_f8_tensor def generate_input( From e26912c549e20228fe26071457ca64b0e7bac2b1 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Thu, 6 Nov 2025 15:57:58 -0800 Subject: [PATCH 24/29] add more explanation about why we need a seperate compile_func. --- problems/nvidia/nvfp4_dual_gemm/submission.py | 5 ++++- problems/nvidia/nvfp4_gemm/submission.py | 5 ++++- problems/nvidia/nvfp4_gemv/submission.py | 5 ++++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/problems/nvidia/nvfp4_dual_gemm/submission.py b/problems/nvidia/nvfp4_dual_gemm/submission.py index a30af22..1daf4aa 100644 --- a/problems/nvidia/nvfp4_dual_gemm/submission.py +++ b/problems/nvidia/nvfp4_dual_gemm/submission.py @@ -856,7 +856,8 @@ def my_kernel( # Global cache for compiled kernel _compiled_kernel_cache = None - +# This function is used to compile the kernel once and cache it and then allow users to +# run the kernel multiple times to get more accurate timing results. def compile_kernel(): """ Compile the kernel once and cache it. @@ -928,7 +929,9 @@ def custom_kernel(data: input_t) -> output_t: a, b1, b2, _, _, _, sfa_permuted, sfb1_permuted, sfb2_permuted, c = data # Ensure kernel is compiled (will use cached version if available) + # To avoid the compilation overhead, we compile the kernel once and cache it. compiled_func = compile_kernel() + # Get dimensions from MxKxL layout _, k, _ = a.shape m, n, l = c.shape diff --git a/problems/nvidia/nvfp4_gemm/submission.py b/problems/nvidia/nvfp4_gemm/submission.py index 0d2a9d4..f8889b0 100644 --- a/problems/nvidia/nvfp4_gemm/submission.py +++ b/problems/nvidia/nvfp4_gemm/submission.py @@ -678,7 +678,8 @@ def my_kernel( # Global cache for compiled kernel _compiled_kernel_cache = None - +# This function is used to compile the kernel once and cache it and then allow users to +# run the kernel multiple times to get more accurate timing results. def compile_kernel(): """ Compile the kernel once and cache it. @@ -740,7 +741,9 @@ def custom_kernel(data: input_t) -> output_t: a, b, _, _, sfa_permuted, sfb_permuted, c = data # Ensure kernel is compiled (will use cached version if available) + # To avoid the compilation overhead, we compile the kernel once and cache it. compiled_func = compile_kernel() + # Get dimensions from MxKxL layout m, k, l = a.shape n, _, _ = b.shape diff --git a/problems/nvidia/nvfp4_gemv/submission.py b/problems/nvidia/nvfp4_gemv/submission.py index 798afd7..8e176a2 100644 --- a/problems/nvidia/nvfp4_gemv/submission.py +++ b/problems/nvidia/nvfp4_gemv/submission.py @@ -165,7 +165,8 @@ def my_kernel( # Global cache for compiled kernel _compiled_kernel_cache = None - +# This function is used to compile the kernel once and cache it and then allow users to +# run the kernel multiple times to get more accurate timing results. def compile_kernel(): """ Compile the kernel once and cache it. @@ -227,7 +228,9 @@ def custom_kernel(data: input_t) -> output_t: a, b, _, _, sfa_permuted, sfb_permuted, c = data # Ensure kernel is compiled (will use cached version if available) + # To avoid the compilation overhead, we compile the kernel once and cache it. compiled_func = compile_kernel() + # Get dimensions from MxKxL layout m, k, l = a.shape # Torch use e2m1_x2 data type, thus k is halved From 73175a457a6ce4d905aae479ce602142bd555abd Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Fri, 7 Nov 2025 21:50:43 -0800 Subject: [PATCH 25/29] use cute tensor to do accumulation operation. --- problems/nvidia/nvfp4_gemm/submission.py | 4 ++-- problems/nvidia/nvfp4_gemv/submission.py | 27 ++++++++++++++---------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/problems/nvidia/nvfp4_gemm/submission.py b/problems/nvidia/nvfp4_gemm/submission.py index f8889b0..a289d68 100644 --- a/problems/nvidia/nvfp4_gemm/submission.py +++ b/problems/nvidia/nvfp4_gemm/submission.py @@ -447,11 +447,11 @@ class SharedStorage: # (T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) tTR_gC = thr_copy_t2r.partition_D(tCgC) # (T2R_M, T2R_N, EPI_M, EPI_N) - tTR_rAcc = cute.make_fragment( + tTR_rAcc = cute.make_rmem_tensor( tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32 ) # (T2R_M, T2R_N, EPI_M, EPI_N) - tTR_rC = cute.make_fragment( + tTR_rC = cute.make_rmem_tensor( tTR_gC[None, None, None, None, 0, 0, 0].shape, c_dtype ) # STG Atom diff --git a/problems/nvidia/nvfp4_gemv/submission.py b/problems/nvidia/nvfp4_gemv/submission.py index 8e176a2..8a8b7c5 100644 --- a/problems/nvidia/nvfp4_gemv/submission.py +++ b/problems/nvidia/nvfp4_gemv/submission.py @@ -66,9 +66,14 @@ def kernel( k_tile_cnt = gA_mkl.layout[3].shape for k_tile in range(k_tile_cnt): tAgA = gA_mkl[tidx, None, bidx, k_tile, bidz] - tBgB = gB_nkl[None, None, bidy, k_tile, bidz] + tBgB = gB_nkl[0, None, bidy, k_tile, bidz] tAgSFA = gSFA_mkl[tidx, None, bidx, k_tile, bidz] - tBgSFB = gSFB_nkl[None, None, bidy, k_tile, bidz] + tBgSFB = gSFB_nkl[0, None, bidy, k_tile, bidz] + + tArA = cute.make_rmem_tensor(tAgA, cutlass.Float32) + tBrB = cute.make_rmem_tensor(tBgB, cutlass.Float32) + tArSFA = cute.make_rmem_tensor(tAgSFA, cutlass.Float32) + tBrSFB = cute.make_rmem_tensor(tBgSFB, cutlass.Float32) # Load NVFP4 or FP8 values from global memory a_val_nvfp4 = tAgA.load() @@ -82,16 +87,16 @@ def kernel( sfa_val = sfa_val_fp8.to(cutlass.Float32) sfb_val = sfb_val_fp8.to(cutlass.Float32) + # Store the converted values to RMEM CuTe tensors + tArA.store(a_val) + tBrB.store(b_val) + tArSFA.store(sfa_val) + tBrSFB.store(sfb_val) + # Iterate over SF vector tiles and compute the scale&matmul accumulation - for i in cutlass.range_constexpr(mma_tiler_mnk[2] // sf_vec_size): - for j in cutlass.range_constexpr(sf_vec_size): - # Accumulate: (A * scaleA * B * scaleB), where scaling is per-vector - res += ( - a_val[i * sf_vec_size + j] - * sfa_val[i] - * b_val[i * sf_vec_size + j] - * sfb_val[i] - ) + for i in cutlass.range_constexpr(mma_tiler_mnk[2]): + res += tArA[i] * tArSFA[i] * tBrB[i] * tBrSFB[i] + # Store the final float16 result back to global memory tCgC.store(res.to(cutlass.Float16)) return From d6bbb9782287e1d2a8140b70fd48166cf5b6c5e9 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Fri, 7 Nov 2025 21:52:34 -0800 Subject: [PATCH 26/29] clean codes. --- problems/nvidia/nvfp4_dual_gemm/submission.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/problems/nvidia/nvfp4_dual_gemm/submission.py b/problems/nvidia/nvfp4_dual_gemm/submission.py index 1daf4aa..f5ee4c2 100644 --- a/problems/nvidia/nvfp4_dual_gemm/submission.py +++ b/problems/nvidia/nvfp4_dual_gemm/submission.py @@ -576,15 +576,15 @@ class SharedStorage: # (T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) tTR_gC = thr_copy_t2r.partition_D(tCgC) # (T2R_M, T2R_N, EPI_M, EPI_N) - tTR_rAcc1 = cute.make_fragment( + tTR_rAcc1 = cute.make_rmem_tensor( tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32 ) # (T2R_M, T2R_N, EPI_M, EPI_N) - tTR_rAcc2 = cute.make_fragment( + tTR_rAcc2 = cute.make_rmem_tensor( tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32 ) # (T2R_M, T2R_N, EPI_M, EPI_N) - tTR_rC = cute.make_fragment( + tTR_rC = cute.make_rmem_tensor( tTR_gC[None, None, None, None, 0, 0, 0].shape, c_dtype ) # STG Atom @@ -636,11 +636,6 @@ def my_kernel( m, n, k, l = problem_size # Setup attributes that depend on gemm inputs - cta_tile_shape_mnk = ( - mma_tiler_mnk[0], - mma_tiler_mnk[1], - mma_tiler_mnk[2], - ) a_tensor = cute.make_tensor( a_ptr, cute.make_layout( @@ -798,8 +793,8 @@ def my_kernel( # Compute grid size grid = ( - cute.ceil_div(c_tensor.shape[0], cta_tile_shape_mnk[0]), - cute.ceil_div(c_tensor.shape[1], cta_tile_shape_mnk[1]), + cute.ceil_div(c_tensor.shape[0], mma_tiler_mnk[0]), + cute.ceil_div(c_tensor.shape[1], mma_tiler_mnk[1]), c_tensor.shape[2], ) From 0bb660a99083fbb5fb1bd9d1c27ea00f2acbe5bd Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Fri, 7 Nov 2025 21:56:33 -0800 Subject: [PATCH 27/29] improve comments. --- problems/nvidia/nvfp4_dual_gemm/submission.py | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/problems/nvidia/nvfp4_dual_gemm/submission.py b/problems/nvidia/nvfp4_dual_gemm/submission.py index f5ee4c2..fdf4525 100644 --- a/problems/nvidia/nvfp4_dual_gemm/submission.py +++ b/problems/nvidia/nvfp4_dual_gemm/submission.py @@ -217,7 +217,7 @@ class SharedStorage: # # Partition global/shared tensor for TMA load A/B/SFA/SFB # - # TMA load A partition_S/D + # TMA Partition_S/D for A # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestM, RestK, RestL) tAsA, tAgA = cpasync.tma_partition( @@ -227,7 +227,7 @@ class SharedStorage: cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3), ) - # TMA load B1 partition_S/D + # TMA Partition_S/D for B1 # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestN, RestK, RestL) tBsB1, tBgB1 = cpasync.tma_partition( @@ -237,7 +237,7 @@ class SharedStorage: cute.group_modes(sB1, 0, 3), cute.group_modes(tCgB1, 0, 3), ) - # TMA load B2 partition_S/D + # TMA Partition_S/D for B2 # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestN, RestK, RestL) tBsB2, tBgB2 = cpasync.tma_partition( @@ -248,7 +248,7 @@ class SharedStorage: cute.group_modes(tCgB2, 0, 3), ) - # TMALDG_SFA partition_S/D + # TMA Partition_S/D for SFA # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestM, RestK, RestL) tAsSFA, tAgSFA = cpasync.tma_partition( @@ -260,8 +260,7 @@ class SharedStorage: ) tAsSFA = cute.filter_zeros(tAsSFA) tAgSFA = cute.filter_zeros(tAgSFA) - - # TMALDG SFB1 partition_S/D + # TMA Partition_S/D for SFB1 # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestN, RestK, RestL) tBsSFB1, tBgSFB1 = cpasync.tma_partition( @@ -273,7 +272,7 @@ class SharedStorage: ) tBsSFB1 = cute.filter_zeros(tBsSFB1) tBgSFB1 = cute.filter_zeros(tBgSFB1) - # TMALDG SFB2 partition_S/D + # TMA Partition_S/D for SFB2 # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestN, RestK, RestL) tBsSFB2, tBgSFB2 = cpasync.tma_partition( @@ -451,7 +450,7 @@ class SharedStorage: # Wait for AB buffer empty ab_empty = ab_producer.acquire_and_advance() - # TMALDG A/B1/B2/SFA/SFB1/SFB2 + # TMA load A/B1/B2/SFA/SFB1/SFB2 to shared memory cute.copy( tma_atom_a, tAgA[(None, ab_empty.count)], @@ -716,7 +715,7 @@ def my_kernel( ) atom_thr_size = cute.size(tiled_mma.thr_id.shape) - # TMA load for A + # Setup TMA for A a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), @@ -726,7 +725,7 @@ def my_kernel( tiled_mma, cluster_layout_vmnk .shape, ) - # TMA load for B1 + # Setup TMA for B1 b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) tma_atom_b1, tma_tensor_b1 = cute.nvgpu.make_tiled_tma_atom_B( cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), @@ -736,7 +735,7 @@ def my_kernel( tiled_mma, cluster_layout_vmnk .shape, ) - # TMA load for B2 + # Setup TMA for B2 tma_atom_b2, tma_tensor_b2 = cute.nvgpu.make_tiled_tma_atom_B( cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), b_tensor2, @@ -745,7 +744,7 @@ def my_kernel( tiled_mma, cluster_layout_vmnk .shape, ) - # TMA load for SFA + # Setup TMA for SFA sfa_smem_layout = cute.slice_( sfa_smem_layout_staged , (None, None, None, 0) ) @@ -758,7 +757,7 @@ def my_kernel( cluster_layout_vmnk .shape, internal_type=cutlass.Int16, ) - # TMA load for SFB1 + # Setup TMA for SFB1 sfb_smem_layout = cute.slice_( sfb_smem_layout_staged , (None, None, None, 0) ) @@ -771,7 +770,7 @@ def my_kernel( cluster_layout_vmnk .shape, internal_type=cutlass.Int16, ) - # TMA load for SFB2 + # Setup TMA for SFB2 tma_atom_sfb2, tma_tensor_sfb2 = cute.nvgpu.make_tiled_tma_atom_B( cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), sfb_tensor2, From 329369cede7c65fc07a932564e3aa18e179e2aaf Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Fri, 7 Nov 2025 22:01:38 -0800 Subject: [PATCH 28/29] improve comments. --- problems/nvidia/nvfp4_dual_gemm/submission.py | 4 --- problems/nvidia/nvfp4_gemm/submission.py | 36 +++++++------------ 2 files changed, 12 insertions(+), 28 deletions(-) diff --git a/problems/nvidia/nvfp4_dual_gemm/submission.py b/problems/nvidia/nvfp4_dual_gemm/submission.py index fdf4525..f733212 100644 --- a/problems/nvidia/nvfp4_dual_gemm/submission.py +++ b/problems/nvidia/nvfp4_dual_gemm/submission.py @@ -247,7 +247,6 @@ class SharedStorage: cute.group_modes(sB2, 0, 3), cute.group_modes(tCgB2, 0, 3), ) - # TMA Partition_S/D for SFA # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestM, RestK, RestL) @@ -590,9 +589,6 @@ class SharedStorage: simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), c_dtype) tTR_gC = tTR_gC[(None, None, None, None, *mma_tile_coord_mnl)] - # Release tensor memory allocation lock - if warp_idx == 0: - cute.arch.relinquish_tmem_alloc_permit() # Wait for accumulator buffer full acc_full = acc_consumer.wait_and_advance() diff --git a/problems/nvidia/nvfp4_gemm/submission.py b/problems/nvidia/nvfp4_gemm/submission.py index a289d68..c2f37d9 100644 --- a/problems/nvidia/nvfp4_gemm/submission.py +++ b/problems/nvidia/nvfp4_gemm/submission.py @@ -185,7 +185,7 @@ class SharedStorage: # # Partition global/shared tensor for TMA load A/B/SFA/SFB # - # TMA load A partition_S/D + # TMA Partition_S/D for A # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestM, RestK, RestL) tAsA, tAgA = cpasync.tma_partition( @@ -195,7 +195,7 @@ class SharedStorage: cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3), ) - # TMA load B partition_S/D + # TMA Partition_S/D for B # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestN, RestK, RestL) tBsB, tBgB = cpasync.tma_partition( @@ -205,8 +205,7 @@ class SharedStorage: cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3), ) - - # TMALDG_SFA partition_S/D + # TMA Partition_S/D for SFA # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestM, RestK, RestL) tAsSFA, tAgSFA = cpasync.tma_partition( @@ -218,8 +217,7 @@ class SharedStorage: ) tAsSFA = cute.filter_zeros(tAsSFA) tAgSFA = cute.filter_zeros(tAgSFA) - - # TMALDG_SFB partition_S/D + # TMA Partition_S/D for SFB # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestN, RestK, RestL) tBsSFB, tBgSFB = cpasync.tma_partition( @@ -355,7 +353,7 @@ class SharedStorage: # Wait for AB buffer empty ab_empty = ab_producer.acquire_and_advance() - # TMALDG A/B/SFA/SFB + # TMA load A/B/SFA/SFB to shared memory cute.copy( tma_atom_a, tAgA[(None, k_tile)], @@ -384,7 +382,7 @@ class SharedStorage: # Wait for AB buffer full ab_full = ab_consumer.wait_and_advance() - # Copy SFA/SFB to tmem + # Copy SFA/SFB from shared memory to TMEM s2t_stage_coord = (None, None, None, None, ab_full.index) tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] @@ -458,9 +456,6 @@ class SharedStorage: simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), c_dtype) tTR_gC = tTR_gC[(None, None, None, None, *mma_tile_coord_mnl)] - # Release TMEM allocation lock - tmem.relinquish_alloc_permit() - # Wait for accumulator buffer full acc_full = acc_consumer.wait_and_advance() @@ -495,11 +490,6 @@ def my_kernel( m, n, k, l = problem_size # Setup attributes that depend on gemm inputs - cta_tile_shape_mnk = ( - mma_tiler_mnk[0], - mma_tiler_mnk[1], - mma_tiler_mnk[2], - ) a_tensor = cute.make_tensor( a_ptr, cute.make_layout( @@ -571,7 +561,7 @@ def my_kernel( atom_thr_size = cute.size(tiled_mma.thr_id.shape) - # TMA load for A + # Setup TMA for A a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), @@ -581,7 +571,7 @@ def my_kernel( tiled_mma, cluster_layout_vmnk.shape, ) - # TMA load for B + # Setup TMA for B b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), @@ -591,8 +581,7 @@ def my_kernel( tiled_mma, cluster_layout_vmnk.shape, ) - - # TMA load for SFA + # Setup TMA for SFA sfa_smem_layout = cute.slice_( sfa_smem_layout_staged, (None, None, None, 0) ) @@ -605,8 +594,7 @@ def my_kernel( cluster_layout_vmnk.shape, internal_type=cutlass.Int16, ) - - # TMA load for SFB + # Setup TMA for SFB sfb_smem_layout = cute.slice_( sfb_smem_layout_staged, (None, None, None, 0) ) @@ -631,8 +619,8 @@ def my_kernel( # Compute grid size grid = ( - cute.ceil_div(c_tensor.shape[0], cta_tile_shape_mnk[0]), - cute.ceil_div(c_tensor.shape[1], cta_tile_shape_mnk[1]), + cute.ceil_div(c_tensor.shape[0], mma_tiler_mnk[0]), + cute.ceil_div(c_tensor.shape[1], mma_tiler_mnk[1]), c_tensor.shape[2], ) From 7a8d0ccd212b3a69f925da710124c430a3e30178 Mon Sep 17 00:00:00 2001 From: Vicki Wang Date: Fri, 7 Nov 2025 22:04:35 -0800 Subject: [PATCH 29/29] fix compilation error. --- problems/nvidia/nvfp4_gemv/submission.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/problems/nvidia/nvfp4_gemv/submission.py b/problems/nvidia/nvfp4_gemv/submission.py index 8a8b7c5..9cf2394 100644 --- a/problems/nvidia/nvfp4_gemv/submission.py +++ b/problems/nvidia/nvfp4_gemv/submission.py @@ -70,10 +70,10 @@ def kernel( tAgSFA = gSFA_mkl[tidx, None, bidx, k_tile, bidz] tBgSFB = gSFB_nkl[0, None, bidy, k_tile, bidz] - tArA = cute.make_rmem_tensor(tAgA, cutlass.Float32) - tBrB = cute.make_rmem_tensor(tBgB, cutlass.Float32) - tArSFA = cute.make_rmem_tensor(tAgSFA, cutlass.Float32) - tBrSFB = cute.make_rmem_tensor(tBgSFB, cutlass.Float32) + tArA = cute.make_rmem_tensor_like(tAgA, cutlass.Float32) + tBrB = cute.make_rmem_tensor_like(tBgB, cutlass.Float32) + tArSFA = cute.make_rmem_tensor_like(tAgSFA, cutlass.Float32) + tBrSFB = cute.make_rmem_tensor_like(tBgSFB, cutlass.Float32) # Load NVFP4 or FP8 values from global memory a_val_nvfp4 = tAgA.load()