@@ -34,60 +34,60 @@ def fp4_fake_quant_kernel(
     M,
     N,
     global_scale_ptr,
+    stride_xm,
+    stride_xn,
+    stride_ym,
+    stride_yn,
     BLOCK_SIZE: tl.constexpr,
-    TILE_SIZE: tl.constexpr,
+    TILE_M: tl.constexpr,
+    TILE_N: tl.constexpr,
     NUM_FP4_BLOCKS: tl.constexpr,
 ):
41- """Applies FP4 fake quantization on input data using per-block scaling factors.
42-
43- Args:
44- x_ptr (tl.pointer): Pointer to the input tensor (BF16/FP32)
45- y_ptr (tl.pointer): Pointer to the output buffer
46- M (int): Number of rows in the matrix
47- N (int): Number of columns in the matrix
48- global_scale_ptr (tl.pointer): Pointer to the global scaling factor tensor
49- BLOCK_SIZE (tl.constexpr): Size of each FP4 quantization block
50- TILE_SIZE (tl.constexpr): Size of the processing block
51- NUM_FP4_BLOCKS (tl.constexpr): Number of FP4 blocks within TILE_SIZE
52- """
46+ """Applies FP4 fake quantization using block pointers for memory addressing."""
     pid_m = tl.program_id(axis=0)
     pid_n = tl.program_id(axis=1)
 
-    # Load global scale from tensor
-    global_scale = tl.load(global_scale_ptr).to(tl.float32)
+    row_start = pid_m * TILE_M
+    col_start = pid_n * TILE_N
 
-    # Calculate offsets
-    offs_m = pid_m * TILE_SIZE + tl.arange(0, TILE_SIZE)
-    offs_n = pid_n * TILE_SIZE + tl.arange(0, TILE_SIZE)
-    offs = offs_m[:, None] * N + offs_n[None, :]
-    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    x_block_ptr = tl.make_block_ptr(
+        base=x_ptr,
+        shape=(M, N),
+        strides=(stride_xm, stride_xn),
+        offsets=(row_start, col_start),
+        block_shape=(TILE_M, TILE_N),
+        order=(1, 0),
+    )
+    y_block_ptr = tl.make_block_ptr(
+        base=y_ptr,
+        shape=(M, N),
+        strides=(stride_ym, stride_yn),
+        offsets=(row_start, col_start),
+        block_shape=(TILE_M, TILE_N),
+        order=(1, 0),
+    )
+
+    global_scale = tl.load(global_scale_ptr).to(tl.float32)
+    global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12)
 
-    # Load input data
-    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
+    tile = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
 
-    # Reshape for block processing
-    x_reshaped = tl.reshape(x, (TILE_SIZE, NUM_FP4_BLOCKS, BLOCK_SIZE))
-    x_abs = tl.abs(x_reshaped)
+    tile_reshaped = tl.reshape(tile, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE))
+    x_abs = tl.abs(tile_reshaped)
 
-    # Calculate max values for each FP4 block
     block_max = tl.max(x_abs, axis=2, keep_dims=True)
-    # global_scale = global_amax / (448 * 6)
-    block_max_quant = (
-        tl.minimum((block_max / (6.0 * global_scale)), 448.0).to(tl.float8e4nv).to(tl.float32)
-        * global_scale
-    )
 
-    # Broadcast max values
+    block_max_scaled = block_max / (6.0 * global_scale_safe)
+    block_max_scaled = tl.minimum(block_max_scaled, 448.0)
+    block_max_quant = block_max_scaled.to(tl.float8e4nv).to(tl.float32) * global_scale
+    block_max_quant = tl.where(block_max_quant >= 1e-5, block_max_quant, 1.0)
+
     block_max_quant_broadcast = tl.broadcast_to(
-        block_max_quant, (TILE_SIZE, NUM_FP4_BLOCKS, BLOCK_SIZE)
-    )
-    # Set scale to 1 if block amax is 0
-    block_max_quant_broadcast = tl.where(
-        block_max_quant_broadcast < 1e-5, 1.0, block_max_quant_broadcast
+        block_max_quant, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE)
     )
+
     abs_scaled = x_abs / block_max_quant_broadcast
 
-    # Quantize to FP4 values: {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}, following round to even
     q_val = tl.where(
         abs_scaled <= 0.25,
         0.0,
@@ -103,63 +103,90 @@ def fp4_fake_quant_kernel(
                     tl.where(
                         abs_scaled <= 2.5,
                         2.0,
-                        tl.where(abs_scaled < 3.5, 3.0, tl.where(abs_scaled <= 5.0, 4.0, 6.0)),
+                        tl.where(
+                            abs_scaled < 3.5,
+                            3.0,
+                            tl.where(abs_scaled <= 5.0, 4.0, 6.0),
+                        ),
                     ),
                 ),
             ),
         ),
     )
 
-    # Apply signs and rescale
     x_rescaled = q_val * block_max_quant_broadcast
-    x_rescaled = tl.where(x_reshaped >= 0, x_rescaled, -x_rescaled)
+    x_rescaled = tl.where(tile_reshaped >= 0, x_rescaled, -x_rescaled)
+
+    tile_quant = tl.reshape(x_rescaled, (TILE_M, TILE_N))
 
-    # Reshape back and store
-    x_rescaled = tl.reshape(x_rescaled, (TILE_SIZE, TILE_SIZE))
-    tl.store(y_ptr + offs, x_rescaled, mask=mask)
+    tl.store(y_block_ptr, tile_quant, boundary_check=(0, 1))
 
 
 def fp4_fake_quant_block(
     x: torch.Tensor,
     global_amax: torch.Tensor,
     block_size: int = 16,
-    tile_size: int = 128,
+    tile_rows: int = 16,
+    tile_cols: int = 64,
+    num_warps: int | None = None,
+    num_stages: int | None = None,
 ) -> torch.Tensor:
-    """Applies FP4 fake quantization on the input tensor.
+    """FP4 fake quantization implementation using block-pointer tiling.
 
     Args:
-        x (torch.Tensor): Input tensor of shape (M, N)
-        global_amax (torch.Tensor): Global max value of the input tensor
-            This needs to be a tensor to be cuda-graph compatible
-        block_size (int): Size of FP4 quantization blocks
-        tile_size (int): Size of processing blocks
+        x (torch.Tensor): Input tensor of shape ``(M, N)``; higher-rank inputs are
+            flattened along the leading dimensions.
+        global_amax (torch.Tensor): Global maximum value tensor for scaling. This must be
+            a tensor (not a Python float) to stay CUDA-graph compatible.
+        block_size (int): Number of elements per FP4 block.
+        tile_rows (int, optional): Row tile size. Defaults to 16.
+        tile_cols (int, optional): Column tile size. Defaults to 64. Rounded up to
+            the nearest multiple of ``block_size`` internally.
+        num_warps (int | None, optional): Override for the number of Triton warps.
+            Triton's default is used when ``None``.
+        num_stages (int | None, optional): Override for pipeline stages. Triton's default
+            is used when ``None``.
 
     Returns:
-        torch.Tensor: Quantized tensor of the same shape as input
+        torch.Tensor: Fake-quantized tensor matching the input shape and dtype.
     """
     x_shape = x.shape
     x_dtype = x.dtype
     x = x.reshape(-1, x_shape[-1]).contiguous()
 
-    M, N = x.size()
-    y = torch.empty_like(x, dtype=torch.get_default_dtype())
+    M, N = x.shape
+    y = torch.empty_like(x, dtype=torch.float32)
+
+    stride_xm, stride_xn = x.stride()
+    stride_ym, stride_yn = y.stride()
+
+    tile_cols = max(tile_cols, block_size)
+    tile_cols_aligned = ((tile_cols + block_size - 1) // block_size) * block_size
+    num_fp4_blocks = tile_cols_aligned // block_size
 
-    grid = lambda meta: (
-        triton.cdiv(M, meta["TILE_SIZE"]),
-        triton.cdiv(N, meta["TILE_SIZE"]),
-    )
     global_scale = global_amax.float() / (6.0 * 448.0)
-    num_fp4_blocks = tile_size // block_size
+
+    grid = lambda *_: (triton.cdiv(M, tile_rows), triton.cdiv(N, tile_cols_aligned))
+
+    launch_kwargs = {
+        "BLOCK_SIZE": block_size,
+        "TILE_M": tile_rows,
+        "TILE_N": tile_cols_aligned,
+        "NUM_FP4_BLOCKS": num_fp4_blocks,
+    }
+    if num_warps is not None:
+        launch_kwargs["num_warps"] = num_warps
+    if num_stages is not None:
+        launch_kwargs["num_stages"] = num_stages
     fp4_fake_quant_kernel[grid](
         x,
         y,
         M,
         N,
         global_scale,
-        TILE_SIZE=tile_size,
-        BLOCK_SIZE=block_size,
-        NUM_FP4_BLOCKS=num_fp4_blocks,
+        stride_xm,
+        stride_xn,
+        stride_ym,
+        stride_yn,
+        **launch_kwargs,
    )
+
     y = y.reshape(x_shape).contiguous().to(dtype=x_dtype)
     return y
 
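A quick usage sketch of the updated entry point, assuming this diff is applied. The `fp4_triton` import path, tensor shapes, and launch overrides are illustrative placeholders, not part of the change:

```python
import torch

# Hypothetical import path for illustration only; use the module this diff actually edits.
from fp4_triton import fp4_fake_quant_block

x = torch.randn(512, 1024, dtype=torch.bfloat16, device="cuda")
# global_amax stays a tensor (not a Python float) so the launch remains CUDA-graph friendly.
global_amax = x.abs().amax().float()

y = fp4_fake_quant_block(
    x,
    global_amax,
    block_size=16,   # elements per FP4 scaling block along the last dimension
    tile_rows=16,    # rows handled by each Triton program
    tile_cols=64,    # columns per program; rounded up to a multiple of block_size
    num_warps=4,     # optional launch override; Triton's default is used when omitted
)
assert y.shape == x.shape and y.dtype == x.dtype
```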
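For intuition about what the kernel computes, below is a rough eager-mode PyTorch mirror of the two-level scaling: a global FP32 scale of `global_amax / (6 * 448)` plus a per-block FP8 E4M3 scale. The helper name `fp4_fake_quant_reference` and the assumption that the input size divides evenly by `block_size` are mine, and the exact rounding at FP4 midpoints differs slightly here, since the kernel's threshold chain rounds ties to even.

```python
import torch


def fp4_fake_quant_reference(x: torch.Tensor, global_amax: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    """Rough eager-mode mirror of the kernel's scaling math (illustrative, not the shipped path)."""
    orig_shape, orig_dtype = x.shape, x.dtype
    x = x.reshape(-1, block_size).float()  # assumes numel is divisible by block_size

    # Level 1: global FP32 scale derived from the tensor-wide amax.
    global_scale = global_amax.float() / (6.0 * 448.0)

    # Level 2: per-block scale, quantized to FP8 E4M3 (clamped to its max normal value, 448).
    block_amax = x.abs().amax(dim=-1, keepdim=True)
    block_scale = torch.clamp(block_amax / (6.0 * global_scale), max=448.0)
    block_scale = block_scale.to(torch.float8_e4m3fn).float() * global_scale
    block_scale = torch.where(block_scale >= 1e-5, block_scale, torch.ones_like(block_scale))

    # Snap |x| / block_scale to the nearest representable FP4 (E2M1) magnitude.
    fp4_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=x.device)
    abs_scaled = (x.abs() / block_scale).unsqueeze(-1)
    q_val = fp4_values[(abs_scaled - fp4_values).abs().argmin(dim=-1)]

    y = torch.where(x >= 0, q_val, -q_val) * block_scale
    return y.reshape(orig_shape).to(orig_dtype)
```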