@@ -38,16 +38,34 @@ def hook(model, input, output):
3838
3939 # 모델의 각 계층에 특징맵을 받아오는 hook을 등록
4040 feature_maps = {}
41- model .encode1 .register_forward_hook (get_feature_maps (feature_maps , 'encode1' ))
42- model .encode2 .register_forward_hook (get_feature_maps (feature_maps , 'encode2' ))
43- model .encode3 .register_forward_hook (get_feature_maps (feature_maps , 'encode3' ))
44- model .encode4 .register_forward_hook (get_feature_maps (feature_maps , 'encode4' ))
45- model .encode_end .register_forward_hook (get_feature_maps (feature_maps , 'encode_end' ))
46- model .decode4 .register_forward_hook (get_feature_maps (feature_maps , 'decode4' ))
47- model .decode3 .register_forward_hook (get_feature_maps (feature_maps , 'decode3' ))
48- model .decode2 .register_forward_hook (get_feature_maps (feature_maps , 'decode2' ))
49- model .decode1 .register_forward_hook (get_feature_maps (feature_maps , 'decode1' ))
50- model .classifier .register_forward_hook (get_feature_maps (feature_maps , 'classifier' ))
41+ if config ['model_name' ] == 'UNet' :
42+ model .encode1 .register_forward_hook (get_feature_maps (feature_maps , 'encode1' ))
43+ model .encode2 .register_forward_hook (get_feature_maps (feature_maps , 'encode2' ))
44+ model .encode3 .register_forward_hook (get_feature_maps (feature_maps , 'encode3' ))
45+ model .encode4 .register_forward_hook (get_feature_maps (feature_maps , 'encode4' ))
46+ model .encode_end .register_forward_hook (get_feature_maps (feature_maps , 'encode_end' ))
47+ model .decode4 .register_forward_hook (get_feature_maps (feature_maps , 'decode4' ))
48+ model .decode3 .register_forward_hook (get_feature_maps (feature_maps , 'decode3' ))
49+ model .decode2 .register_forward_hook (get_feature_maps (feature_maps , 'decode2' ))
50+ model .decode1 .register_forward_hook (get_feature_maps (feature_maps , 'decode1' ))
51+ model .classifier .register_forward_hook (get_feature_maps (feature_maps , 'classifier' ))
52+ elif config ['model_name' ] == 'Proposed' :
53+ model .initial_conv .register_forward_hook (get_feature_maps (feature_maps , 'initial_conv' ))
54+ model .encode1 .register_forward_hook (get_feature_maps (feature_maps , 'encode1' ))
55+ model .encode2 .register_forward_hook (get_feature_maps (feature_maps , 'encode2' ))
56+ model .encode3 .register_forward_hook (get_feature_maps (feature_maps , 'encode3' ))
57+ model .encode4 .register_forward_hook (get_feature_maps (feature_maps , 'encode4' ))
58+ model .aspp .register_forward_hook (get_feature_maps (feature_maps , 'aspp' ))
59+ model .decode3 .register_forward_hook (get_feature_maps (feature_maps , 'decode3' ))
60+ model .decode2 .register_forward_hook (get_feature_maps (feature_maps , 'decode2' ))
61+ model .decode1 .register_forward_hook (get_feature_maps (feature_maps , 'decode1' ))
62+ model .classifier .register_forward_hook (get_feature_maps (feature_maps , 'classifier' ))
63+ elif config ['model_name' ] == 'Backbone' :
64+ model .initial_conv .register_forward_hook (get_feature_maps (feature_maps , 'initial_conv' ))
65+ model .layer1 .register_forward_hook (get_feature_maps (feature_maps , 'layer1' ))
66+ model .layer2 .register_forward_hook (get_feature_maps (feature_maps , 'layer2' ))
67+ model .layer3 .register_forward_hook (get_feature_maps (feature_maps , 'layer3' ))
68+ model .layer4 .register_forward_hook (get_feature_maps (feature_maps , 'layer4' ))
5169
5270 # 예측
5371 with torch .no_grad ():
0 commit comments