Commit 0cf5fb6

new fp4 kernel
Signed-off-by: mxin <mxin@nvidia.com>
1 parent 89c686c commit 0cf5fb6

1 file changed
modelopt/torch/quantization/triton/fp4_kernel.py

Lines changed: 89 additions & 62 deletions
@@ -34,60 +34,60 @@ def fp4_fake_quant_kernel(
     M,
     N,
     global_scale_ptr,
+    stride_xm,
+    stride_xn,
+    stride_ym,
+    stride_yn,
     BLOCK_SIZE: tl.constexpr,
-    TILE_SIZE: tl.constexpr,
+    TILE_M: tl.constexpr,
+    TILE_N: tl.constexpr,
     NUM_FP4_BLOCKS: tl.constexpr,
 ):
-    """Applies FP4 fake quantization on input data using per-block scaling factors.
-
-    Args:
-        x_ptr (tl.pointer): Pointer to the input tensor (BF16/FP32)
-        y_ptr (tl.pointer): Pointer to the output buffer
-        M (int): Number of rows in the matrix
-        N (int): Number of columns in the matrix
-        global_scale_ptr (tl.pointer): Pointer to the global scaling factor tensor
-        BLOCK_SIZE (tl.constexpr): Size of each FP4 quantization block
-        TILE_SIZE (tl.constexpr): Size of the processing block
-        NUM_FP4_BLOCKS (tl.constexpr): Number of FP4 blocks within TILE_SIZE
-    """
+    """Applies FP4 fake quantization using block pointers for memory addressing."""
     pid_m = tl.program_id(axis=0)
     pid_n = tl.program_id(axis=1)
 
-    # Load global scale from tensor
-    global_scale = tl.load(global_scale_ptr).to(tl.float32)
+    row_start = pid_m * TILE_M
+    col_start = pid_n * TILE_N
 
-    # Calculate offsets
-    offs_m = pid_m * TILE_SIZE + tl.arange(0, TILE_SIZE)
-    offs_n = pid_n * TILE_SIZE + tl.arange(0, TILE_SIZE)
-    offs = offs_m[:, None] * N + offs_n[None, :]
-    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    x_block_ptr = tl.make_block_ptr(
+        base=x_ptr,
+        shape=(M, N),
+        strides=(stride_xm, stride_xn),
+        offsets=(row_start, col_start),
+        block_shape=(TILE_M, TILE_N),
+        order=(1, 0),
+    )
+    y_block_ptr = tl.make_block_ptr(
+        base=y_ptr,
+        shape=(M, N),
+        strides=(stride_ym, stride_yn),
+        offsets=(row_start, col_start),
+        block_shape=(TILE_M, TILE_N),
+        order=(1, 0),
+    )
+
+    global_scale = tl.load(global_scale_ptr).to(tl.float32)
+    global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12)
 
-    # Load input data
-    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
+    tile = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
 
-    # Reshape for block processing
-    x_reshaped = tl.reshape(x, (TILE_SIZE, NUM_FP4_BLOCKS, BLOCK_SIZE))
-    x_abs = tl.abs(x_reshaped)
+    tile_reshaped = tl.reshape(tile, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE))
+    x_abs = tl.abs(tile_reshaped)
 
-    # Calculate max values for each FP4 block
     block_max = tl.max(x_abs, axis=2, keep_dims=True)
-    # global_scale = global_amax / (448 * 6)
-    block_max_quant = (
-        tl.minimum((block_max / (6.0 * global_scale)), 448.0).to(tl.float8e4nv).to(tl.float32)
-        * global_scale
-    )
 
-    # Broadcast max values
+    block_max_scaled = block_max / (6.0 * global_scale_safe)
+    block_max_scaled = tl.minimum(block_max_scaled, 448.0)
+    block_max_quant = block_max_scaled.to(tl.float8e4nv).to(tl.float32) * global_scale
+    block_max_quant = tl.where(block_max_quant >= 1e-5, block_max_quant, 1.0)
+
     block_max_quant_broadcast = tl.broadcast_to(
-        block_max_quant, (TILE_SIZE, NUM_FP4_BLOCKS, BLOCK_SIZE)
-    )
-    # Set scale to 1 if block amax is 0
-    block_max_quant_broadcast = tl.where(
-        block_max_quant_broadcast < 1e-5, 1.0, block_max_quant_broadcast
+        block_max_quant, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE)
     )
+
     abs_scaled = x_abs / block_max_quant_broadcast
 
-    # Quantize to FP4 values: {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}, following round to even
     q_val = tl.where(
         abs_scaled <= 0.25,
         0.0,
@@ -103,63 +103,90 @@ def fp4_fake_quant_kernel(
                     tl.where(
                         abs_scaled <= 2.5,
                         2.0,
-                        tl.where(abs_scaled < 3.5, 3.0, tl.where(abs_scaled <= 5.0, 4.0, 6.0)),
+                        tl.where(
+                            abs_scaled < 3.5,
+                            3.0,
+                            tl.where(abs_scaled <= 5.0, 4.0, 6.0),
+                        ),
                     ),
                 ),
             ),
         ),
     )
 
-    # Apply signs and rescale
     x_rescaled = q_val * block_max_quant_broadcast
-    x_rescaled = tl.where(x_reshaped >= 0, x_rescaled, -x_rescaled)
+    x_rescaled = tl.where(tile_reshaped >= 0, x_rescaled, -x_rescaled)
+
+    tile_quant = tl.reshape(x_rescaled, (TILE_M, TILE_N))
 
-    # Reshape back and store
-    x_rescaled = tl.reshape(x_rescaled, (TILE_SIZE, TILE_SIZE))
-    tl.store(y_ptr + offs, x_rescaled, mask=mask)
+    tl.store(y_block_ptr, tile_quant, boundary_check=(0, 1))
 
 
 def fp4_fake_quant_block(
     x: torch.Tensor,
     global_amax: torch.Tensor,
     block_size: int = 16,
-    tile_size: int = 128,
+    tile_rows: int = 16,
+    tile_cols: int = 64,
+    num_warps: int | None = None,
+    num_stages: int | None = None,
 ) -> torch.Tensor:
-    """Applies FP4 fake quantization on the input tensor.
+    """FP4 fake quantization implementation using block-pointer tiling.
 
     Args:
-        x (torch.Tensor): Input tensor of shape (M, N)
-        global_amax (torch.Tensor): Global max value of the input tensor
-            This needs to be a tensor to be cuda-graph compatible
-        block_size (int): Size of FP4 quantization blocks
-        tile_size (int): Size of processing blocks
+        x (torch.Tensor): Input tensor of shape ``(M, N)`` or higher.
+        global_amax (torch.Tensor): Global maximum value tensor for scaling.
+        block_size (int): Number of elements per FP4 block.
+        tile_rows (int, optional): Row tile size. Defaults to 16.
+        tile_cols (int, optional): Column tile size. Defaults to 64. Rounded up to
+            the nearest multiple of ``block_size`` internally.
+        num_warps (int | None, optional): Override for Triton warps. Autotuned when ``None``.
+        num_stages (int | None, optional): Override for pipeline stages. Autotuned when ``None``.
 
     Returns:
-        torch.Tensor: Quantized tensor of the same shape as input
+        torch.Tensor: Fake-quantized tensor matching the input shape and dtype.
     """
     x_shape = x.shape
     x_dtype = x.dtype
     x = x.reshape(-1, x_shape[-1]).contiguous()
 
-    M, N = x.size()
-    y = torch.empty_like(x, dtype=torch.get_default_dtype())
+    M, N = x.shape
+    y = torch.empty_like(x, dtype=torch.float32)
+
+    stride_xm, stride_xn = x.stride()
+    stride_ym, stride_yn = y.stride()
+
+    tile_cols = max(tile_cols, block_size)
+    tile_cols_aligned = ((tile_cols + block_size - 1) // block_size) * block_size
+    num_fp4_blocks = tile_cols_aligned // block_size
 
-    grid = lambda meta: (
-        triton.cdiv(M, meta["TILE_SIZE"]),
-        triton.cdiv(N, meta["TILE_SIZE"]),
-    )
     global_scale = global_amax.float() / (6.0 * 448.0)
-    num_fp4_blocks = tile_size // block_size
+
+    grid = lambda *_: (triton.cdiv(M, tile_rows), triton.cdiv(N, tile_cols_aligned))
+
+    launch_kwargs = {
+        "BLOCK_SIZE": block_size,
+        "TILE_M": tile_rows,
+        "TILE_N": tile_cols_aligned,
+        "NUM_FP4_BLOCKS": num_fp4_blocks,
+    }
+    if num_warps is not None:
+        launch_kwargs["num_warps"] = num_warps
+    if num_stages is not None:
+        launch_kwargs["num_stages"] = num_stages
     fp4_fake_quant_kernel[grid](
         x,
         y,
         M,
         N,
         global_scale,
-        TILE_SIZE=tile_size,
-        BLOCK_SIZE=block_size,
-        NUM_FP4_BLOCKS=num_fp4_blocks,
+        stride_xm,
+        stride_xn,
+        stride_ym,
+        stride_yn,
+        **launch_kwargs,
     )
+
     y = y.reshape(x_shape).contiguous().to(dtype=x_dtype)
     return y
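For reference, a minimal sketch of how the updated entry point might be called, assuming the import path matches the file location; the tile and launch values below are arbitrary illustrations, not recommendations from this commit:

```python
import torch

# Hypothetical usage of the new signature; tile_rows/tile_cols/num_warps/num_stages
# are the knobs introduced by this commit.
from modelopt.torch.quantization.triton.fp4_kernel import fp4_fake_quant_block

x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
# global_amax stays a tensor so the call remains CUDA-graph compatible.
global_amax = x.abs().amax().float()

y_default = fp4_fake_quant_block(x, global_amax)  # tile_rows=16, tile_cols=64
y_tuned = fp4_fake_quant_block(
    x,
    global_amax,
    block_size=16,
    tile_rows=32,
    tile_cols=128,  # rounded up internally to a multiple of block_size
    num_warps=4,
    num_stages=3,
)
assert y_tuned.shape == x.shape and y_tuned.dtype == x.dtype
```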
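And a rough eager-mode sketch of the arithmetic the kernel performs (per-block amax, FP8-E4M3-quantized block scale, then snapping to the FP4 grid {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}). This is an illustrative reference only, not part of the commit; it assumes a PyTorch build with ``torch.float8_e4m3fn`` and an input whose last dimension is a multiple of ``block_size``:

```python
import torch


def fp4_fake_quant_reference(x: torch.Tensor, global_amax: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    """Illustrative eager-mode mirror of fp4_fake_quant_kernel (no tiling, no padding)."""
    orig_shape, orig_dtype = x.shape, x.dtype
    xf = x.reshape(-1, block_size).float()               # one row per FP4 block
    global_scale = global_amax.float() / (6.0 * 448.0)   # same as fp4_fake_quant_block

    # Per-block scale: amax / (6 * global_scale), clamped to 448 and quantized to FP8 E4M3.
    block_max = xf.abs().amax(dim=-1, keepdim=True)
    scale = (block_max / (6.0 * global_scale)).clamp(max=448.0)
    scale = scale.to(torch.float8_e4m3fn).float() * global_scale
    scale = torch.where(scale >= 1e-5, scale, torch.ones_like(scale))  # guard all-zero blocks

    # Snap |x| / scale to {0, 0.5, 1, 1.5, 2, 3, 4, 6} using the kernel's thresholds.
    a = xf.abs() / scale
    q = torch.full_like(a, 6.0)
    q = torch.where(a <= 5.0, torch.full_like(a, 4.0), q)
    q = torch.where(a < 3.5, torch.full_like(a, 3.0), q)
    q = torch.where(a <= 2.5, torch.full_like(a, 2.0), q)
    q = torch.where(a < 1.75, torch.full_like(a, 1.5), q)
    q = torch.where(a <= 1.25, torch.full_like(a, 1.0), q)
    q = torch.where(a < 0.75, torch.full_like(a, 0.5), q)
    q = torch.where(a <= 0.25, torch.zeros_like(a), q)

    y = torch.where(xf >= 0, q * scale, -(q * scale))     # restore sign and rescale
    return y.reshape(orig_shape).to(orig_dtype)
```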
