File tree Expand file tree Collapse file tree
src/fp_quant/module/triton Expand file tree Collapse file tree Original file line number Diff line number Diff line change 22
33setup (
44 name = "fp_quant" ,
5- version = "0.1.5 " ,
5+ version = "0.1.6 " ,
66 packages = find_packages (where = "src" ),
77 package_dir = {"" : "src" },
88 author = "Andrei Panferov" ,
2020 install_requires = [
2121 "torch>=2.7.0" ,
2222 "scipy>=1.13.0" ,
23- "qutlass>=0.0.1" ,
2423 "triton>=3.3.0" ,
2524 ],
2625)
Original file line number Diff line number Diff line change @@ -215,18 +215,19 @@ def mxfp4_forward_kernel_wrapper(
215215 grid = lambda meta : (triton .cdiv (n_elements , meta ["BLOCK_SIZE" ]),)
216216
217217 # Launch optimized kernel
218- mxfp4_forward_kernel [grid ](
219- x_ptr = x ,
220- hadamard_matrix_ptr = hadamard_matrix ,
221- output_ptr = output ,
222- clip_mask_ptr = clip_mask ,
223- n_elements = n_elements ,
224- hadamard_dim = hadamard_matrix .shape [- 1 ],
225- group_size = 32 ,
226- gaussian_scale = gaussian_scale ,
227- stochastic_round = stochastic_round ,
228- seed = seed ,
229- quest = quest ,
230- )
218+ with torch .cuda .device (x .device ):
219+ mxfp4_forward_kernel [grid ](
220+ x_ptr = x ,
221+ hadamard_matrix_ptr = hadamard_matrix ,
222+ output_ptr = output ,
223+ clip_mask_ptr = clip_mask ,
224+ n_elements = n_elements ,
225+ hadamard_dim = hadamard_matrix .shape [- 1 ],
226+ group_size = 32 ,
227+ gaussian_scale = gaussian_scale ,
228+ stochastic_round = stochastic_round ,
229+ seed = seed ,
230+ quest = quest ,
231+ )
231232
232233 return output , clip_mask
You can’t perform that action at this time.
0 commit comments