Skip to content

Commit 9ac0e46

Browse files
committed
feature_vis: Implement forward_hook to other models
1 parent 44a9139 commit 9ac0e46

File tree

1 file changed

+28
-10
lines changed

1 file changed

+28
-10
lines changed

feature_visualizer.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,34 @@ def hook(model, input, output):
3838

3939
# 모델의 각 계층에 특징맵을 받아오는 hook을 등록
4040
feature_maps = {}
41-
model.encode1.register_forward_hook(get_feature_maps(feature_maps, 'encode1'))
42-
model.encode2.register_forward_hook(get_feature_maps(feature_maps, 'encode2'))
43-
model.encode3.register_forward_hook(get_feature_maps(feature_maps, 'encode3'))
44-
model.encode4.register_forward_hook(get_feature_maps(feature_maps, 'encode4'))
45-
model.encode_end.register_forward_hook(get_feature_maps(feature_maps, 'encode_end'))
46-
model.decode4.register_forward_hook(get_feature_maps(feature_maps, 'decode4'))
47-
model.decode3.register_forward_hook(get_feature_maps(feature_maps, 'decode3'))
48-
model.decode2.register_forward_hook(get_feature_maps(feature_maps, 'decode2'))
49-
model.decode1.register_forward_hook(get_feature_maps(feature_maps, 'decode1'))
50-
model.classifier.register_forward_hook(get_feature_maps(feature_maps, 'classifier'))
41+
if config['model_name'] == 'UNet':
42+
model.encode1.register_forward_hook(get_feature_maps(feature_maps, 'encode1'))
43+
model.encode2.register_forward_hook(get_feature_maps(feature_maps, 'encode2'))
44+
model.encode3.register_forward_hook(get_feature_maps(feature_maps, 'encode3'))
45+
model.encode4.register_forward_hook(get_feature_maps(feature_maps, 'encode4'))
46+
model.encode_end.register_forward_hook(get_feature_maps(feature_maps, 'encode_end'))
47+
model.decode4.register_forward_hook(get_feature_maps(feature_maps, 'decode4'))
48+
model.decode3.register_forward_hook(get_feature_maps(feature_maps, 'decode3'))
49+
model.decode2.register_forward_hook(get_feature_maps(feature_maps, 'decode2'))
50+
model.decode1.register_forward_hook(get_feature_maps(feature_maps, 'decode1'))
51+
model.classifier.register_forward_hook(get_feature_maps(feature_maps, 'classifier'))
52+
elif config['model_name'] == 'Proposed':
53+
model.initial_conv.register_forward_hook(get_feature_maps(feature_maps, 'initial_conv'))
54+
model.encode1.register_forward_hook(get_feature_maps(feature_maps, 'encode1'))
55+
model.encode2.register_forward_hook(get_feature_maps(feature_maps, 'encode2'))
56+
model.encode3.register_forward_hook(get_feature_maps(feature_maps, 'encode3'))
57+
model.encode4.register_forward_hook(get_feature_maps(feature_maps, 'encode4'))
58+
model.aspp.register_forward_hook(get_feature_maps(feature_maps, 'aspp'))
59+
model.decode3.register_forward_hook(get_feature_maps(feature_maps, 'decode3'))
60+
model.decode2.register_forward_hook(get_feature_maps(feature_maps, 'decode2'))
61+
model.decode1.register_forward_hook(get_feature_maps(feature_maps, 'decode1'))
62+
model.classifier.register_forward_hook(get_feature_maps(feature_maps, 'classifier'))
63+
elif config['model_name'] == 'Backbone':
64+
model.initial_conv.register_forward_hook(get_feature_maps(feature_maps, 'initial_conv'))
65+
model.layer1.register_forward_hook(get_feature_maps(feature_maps, 'layer1'))
66+
model.layer2.register_forward_hook(get_feature_maps(feature_maps, 'layer2'))
67+
model.layer3.register_forward_hook(get_feature_maps(feature_maps, 'layer3'))
68+
model.layer4.register_forward_hook(get_feature_maps(feature_maps, 'layer4'))
5169

5270
# 예측
5371
with torch.no_grad():

0 commit comments

Comments
 (0)