From b1725d39e736b61f8235ea1e3a5a50a114c4f6a6 Mon Sep 17 00:00:00 2001 From: eamonn Date: Sat, 17 Aug 2024 16:39:35 +0800 Subject: [PATCH] Handle None in output_size --- torch2trt/converters/AdaptiveAvgPool2d.py | 39 +++++++++++++++++++---- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/torch2trt/converters/AdaptiveAvgPool2d.py b/torch2trt/converters/AdaptiveAvgPool2d.py index 41ad141d..3b7d070f 100644 --- a/torch2trt/converters/AdaptiveAvgPool2d.py +++ b/torch2trt/converters/AdaptiveAvgPool2d.py @@ -7,20 +7,37 @@ def convert_AdaptiveAvgPool2d(ctx): module = ctx.method_args[0] input = ctx.method_args[1] output = ctx.method_return - input_trt = add_missing_trt_tensors(ctx.network, [input])[0] - output_size = module.output_size - if not isinstance(output_size, tuple): - output_size = (output_size, ) * 2 - - stride = (input_trt.shape[-2] // output_size[-2], input_trt.shape[-1] // output_size[-1]) + # Determine the target output size + target_output_size = module.output_size + + if not isinstance(target_output_size, tuple): + target_output_size = (target_output_size, ) * 2 + + # Handle cases where target output size has None values + if None in target_output_size: + target_output_shape = tuple(output.shape) + target_output_shape = target_output_shape[2:] + new_output_size = [] + for size, shape in zip(target_output_size, target_output_shape): + if size is None: + new_output_size.append(shape) + else: + new_output_size.append(size) + target_output_size = tuple(new_output_size) + + # Calculate stride and kernel size + stride = (input_trt.shape[-2] // target_output_size[-2], input_trt.shape[-1] // target_output_size[-1]) kernel_size = stride + + # Create pooling layer layer = ctx.network.add_pooling( input=input_trt, type=trt.PoolingType.AVERAGE, window_size=kernel_size) layer.stride = stride + # Set _trt attribute for output output._trt = layer.get_output(0) @@ -37,3 +54,13 @@ def test_AdaptiveAvgPool2d_2x2(): @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) def test_AdaptiveAvgPool2d_3x3(): return torch.nn.AdaptiveAvgPool2d((3, 3)) + + +@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) +def test_AdaptiveAvgPool2d_None_1(): + return torch.nn.AdaptiveAvgPool2d((None, 1)) + + +@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) +def test_AdaptiveAvgPool2d_1_None(): + return torch.nn.AdaptiveAvgPool2d((1, None)) \ No newline at end of file