From 2981da56fa2b4579f11cb9621f9ce03e1a781e1d Mon Sep 17 00:00:00 2001 From: Leora Date: Thu, 11 Mar 2021 14:49:31 +0200 Subject: [PATCH] Update torchsummary.py Add support for a net outputting a list. Say you are returning intermediary feature-maps from a U-net. These feature-maps will have a different size and cannot be stacked into one tensor. Instead they can be placed in a list. --- torchsummary/torchsummary.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 1ed065f..4cd42d3 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -30,9 +30,15 @@ def hook(module, input, output): summary[m_key]["input_shape"] = list(input[0].size()) summary[m_key]["input_shape"][0] = batch_size if isinstance(output, (list, tuple)): - summary[m_key]["output_shape"] = [ - [-1] + list(o.size())[1:] for o in output - ] + output_shape = [] + for o in output: + if isinstance(o, (list, tuple)): + for item in o: + output_shape.append([-1] + list(item.size())[1:]) + else: + output_shape.append([-1] + list(o.size())[1:]) + + summary[m_key]["output_shape"] = output_shape else: summary[m_key]["output_shape"] = list(output.size()) summary[m_key]["output_shape"][0] = batch_size