Skip to content

Commit 1be36aa

Browse files
committed
triton dispach fix
1 parent b4e9d6b commit 1be36aa

2 files changed

Lines changed: 15 additions & 15 deletions

File tree

inference_lib/setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="fp_quant",
5-
version="0.1.5",
5+
version="0.1.6",
66
packages=find_packages(where="src"),
77
package_dir={"": "src"},
88
author="Andrei Panferov",
@@ -20,7 +20,6 @@
2020
install_requires=[
2121
"torch>=2.7.0",
2222
"scipy>=1.13.0",
23-
"qutlass>=0.0.1",
2423
"triton>=3.3.0",
2524
],
2625
)

inference_lib/src/fp_quant/module/triton/pseudoquant.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -215,18 +215,19 @@ def mxfp4_forward_kernel_wrapper(
215215
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
216216

217217
# Launch optimized kernel
218-
mxfp4_forward_kernel[grid](
219-
x_ptr=x,
220-
hadamard_matrix_ptr=hadamard_matrix,
221-
output_ptr=output,
222-
clip_mask_ptr=clip_mask,
223-
n_elements=n_elements,
224-
hadamard_dim=hadamard_matrix.shape[-1],
225-
group_size=32,
226-
gaussian_scale=gaussian_scale,
227-
stochastic_round=stochastic_round,
228-
seed=seed,
229-
quest=quest,
230-
)
218+
with torch.cuda.device(x.device):
219+
mxfp4_forward_kernel[grid](
220+
x_ptr=x,
221+
hadamard_matrix_ptr=hadamard_matrix,
222+
output_ptr=output,
223+
clip_mask_ptr=clip_mask,
224+
n_elements=n_elements,
225+
hadamard_dim=hadamard_matrix.shape[-1],
226+
group_size=32,
227+
gaussian_scale=gaussian_scale,
228+
stochastic_round=stochastic_round,
229+
seed=seed,
230+
quest=quest,
231+
)
231232

232233
return output, clip_mask

0 commit comments

Comments
 (0)