diff --git a/tests/converter_tests/test_converters.py b/tests/converter_tests/test_converters.py
index de9e0a55..c67fe260 100644
--- a/tests/converter_tests/test_converters.py
+++ b/tests/converter_tests/test_converters.py
@@ -203,6 +203,11 @@ def test_batch_norm_nd(nd, with_conv):
     inputs = [torch.randn(*input_size).cuda()]
     cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
 
+    if nd == 1:
+        input_size = [2, 3]
+        inputs = [torch.randn(*input_size).cuda()]
+        cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
 
 @pytest.mark.parametrize("dim", [1, -1])
 def test_cat(dim):
diff --git a/torch2trt/converters/native_converters.py b/torch2trt/converters/native_converters.py
index e1dcd145..660fedbd 100644
--- a/torch2trt/converters/native_converters.py
+++ b/torch2trt/converters/native_converters.py
@@ -156,7 +156,7 @@ def convert_batch_norm(ctx):
     bias = bias.detach().cpu().numpy() - running_mean.detach().cpu().numpy() * scale
     power = np.ones_like(scale)
 
-    if ndim == 1:
+    if ndim == 1 or ndim == 0:
         # reshape to 2D
         layer = ctx.network.add_shuffle(input_trt)
 
@@ -171,7 +171,7 @@
 
     layer = ctx.network.add_scale_nd(scale_input, trt.ScaleMode.CHANNEL, bias, scale, power, 1)
 
-    if ndim == 1:
+    if ndim == 1 or ndim == 0:
         # reshape back to 1D
         layer = ctx.network.add_shuffle(layer.get_output(0))
         if len(input.shape) == 2:
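
For context on what the converter change enables, below is a minimal sketch of the `ndim == 0` case: a `BatchNorm1d` module applied to a plain `(N, C)` input with no spatial dimension, mirroring the new `input_size = [2, 3]` test case. The direct `torch2trt` call and the max-difference check are assumptions standing in for the test suite's `cross_validate` helper; they are not part of this diff.

```python
import torch
import torch.nn as nn
from torch2trt import torch2trt

# BatchNorm1d over a 2D (N, C) input with no spatial dimension -- the
# ndim == 0 case that the converter now sends through the reshape path
# around add_scale_nd.
module = nn.BatchNorm1d(3).cuda().eval()
x = torch.randn(2, 3).cuda()  # matches input_size = [2, 3] in the new test

model_trt = torch2trt(module, [x], fp16_mode=False)

with torch.no_grad():
    y_ref = module(x)
y_trt = model_trt(x)

# Rough stand-in for cross_validate's comparison at tol=1e-1.
print(torch.max(torch.abs(y_ref - y_trt)).item())
```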