Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions torch2trt/converters/AdaptiveAvgPool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,37 @@ def convert_AdaptiveAvgPool2d(ctx):
module = ctx.method_args[0]
input = ctx.method_args[1]
output = ctx.method_return

input_trt = add_missing_trt_tensors(ctx.network, [input])[0]

output_size = module.output_size
if not isinstance(output_size, tuple):
output_size = (output_size, ) * 2

stride = (input_trt.shape[-2] // output_size[-2], input_trt.shape[-1] // output_size[-1])
# Determine the target output size
target_output_size = module.output_size

if not isinstance(target_output_size, tuple):
target_output_size = (target_output_size, ) * 2

# Handle cases where target output size has None values
if None in target_output_size:
target_output_shape = tuple(output.shape)
target_output_shape = target_output_shape[2:]
new_output_size = []
for size, shape in zip(target_output_size, target_output_shape):
if size is None:
new_output_size.append(shape)
else:
new_output_size.append(size)

target_output_size = tuple(new_output_size)

# Calculate stride and kernel size
stride = (input_trt.shape[-2] // target_output_size[-2], input_trt.shape[-1] // target_output_size[-1])
kernel_size = stride

# Create pooling layer
layer = ctx.network.add_pooling(
input=input_trt, type=trt.PoolingType.AVERAGE, window_size=kernel_size)
layer.stride = stride

# Set _trt attribute for output
output._trt = layer.get_output(0)


Expand All @@ -37,3 +54,13 @@ def test_AdaptiveAvgPool2d_2x2():
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
def test_AdaptiveAvgPool2d_3x3():
return torch.nn.AdaptiveAvgPool2d((3, 3))


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
def test_AdaptiveAvgPool2d_None_1():
return torch.nn.AdaptiveAvgPool2d((None, 1))


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
def test_AdaptiveAvgPool2d_1_None():
return torch.nn.AdaptiveAvgPool2d((1, None))