Skip to content

Commit c2c0a2b

Browse files
committed
models: Simplify debug code
1 parent b9f3978 commit c2c0a2b

File tree

2 files changed

+4
-12
lines changed

2 files changed

+4
-12
lines changed

models/proposed.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,11 @@ def forward(self, x):
154154

155155
if __name__ == '__main__':
156156
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
157-
158-
model = Proposed(3, 20).to(device)
157+
model = Proposed(3, 8).to(device)
159158
model.eval()
160159

161160
torchsummary.torchsummary.summary(model, (3, 256, 512))
162161

163-
input_image = torch.rand(1, 3, 256, 512).to(device)
164-
model(input_image)
165-
166162
writer = torch.utils.tensorboard.SummaryWriter('../runs')
167-
writer.add_graph(model, input_image)
163+
writer.add_graph(model, torch.rand(1, 3, 256, 512).to(device))
168164
writer.close()

models/unet.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,11 @@ def forward(self, x):
5858

5959
if __name__ == '__main__':
6060
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
61-
62-
model = UNet(3, 20).to(device)
61+
model = UNet(3, 8).to(device)
6362
model.eval()
6463

6564
torchsummary.torchsummary.summary(model, (3, 256, 512))
6665

67-
input_image = torch.rand(1, 3, 256, 512).to(device)
68-
model(input_image)
69-
7066
writer = torch.utils.tensorboard.SummaryWriter('../runs')
71-
writer.add_graph(model, input_image)
67+
writer.add_graph(model, torch.rand(1, 3, 256, 512).to(device))
7268
writer.close()

0 commit comments

Comments
 (0)