diff --git a/torchsparse/nn/functional/conv/func/implicit_gemm.py b/torchsparse/nn/functional/conv/func/implicit_gemm.py index e91ad18..c700541 100644 --- a/torchsparse/nn/functional/conv/func/implicit_gemm.py +++ b/torchsparse/nn/functional/conv/func/implicit_gemm.py @@ -161,7 +161,7 @@ def backward(ctx, grad_output: torch.Tensor): reorder_out_in_map_bwd, reduced_sorted_mask_bwd_wgrad, reorder_loc_bwd, - 32, + 1, torchsparse.backends.allow_tf32, torchsparse.backends.allow_fp16, torchsparse.backends.allow_bf16 @@ -192,7 +192,7 @@ def backward(ctx, grad_output: torch.Tensor): grad_output, input, out_in_map_bwd, - 32, + 1, torchsparse.backends.allow_tf32, torchsparse.backends.allow_fp16, torchsparse.backends.allow_bf16