@@ -75,7 +75,7 @@ def __init__(self, params):
7575 def get_parameters_to_update (self ):
7676 if (self .update_mode == "all" ):
7777 return self .net .parameters ()
78- elif (self .update_layers == "last" ):
78+ elif (self .update_mode == "last" ):
7979 params = self .net .fc .parameters ()
8080 if (self .in_chns != 3 ):
8181 # combining the two iterables into a single one
@@ -119,7 +119,7 @@ def get_parameters_to_update(self):
119119 params = self .net .classifier [- 1 ].parameters ()
120120 if (self .in_chns != 3 ):
121121 params = itertools .chain ()
122- for pram in [self .net .classifier [- 1 ].parameters (), self .net .net . features [0 ].parameters ()]:
122+ for pram in [self .net .classifier [- 1 ].parameters (), self .net .features [0 ].parameters ()]:
123123 params = itertools .chain (params , pram )
124124 return params
125125 else :
@@ -138,7 +138,7 @@ class MobileNetV2(BuiltInNet):
138138 as well as the first layer when `input_chns` is not 3.
139139 """
140140 def __init__ (self , params ):
141- super (MobileNetV2 , self ).__init__ ()
141+ super (MobileNetV2 , self ).__init__ (params )
142142 self .net = models .mobilenet_v2 (pretrained = self .pretrain )
143143
144144 # replace the last layer
@@ -157,7 +157,7 @@ def get_parameters_to_update(self):
157157 params = self .net .classifier [- 1 ].parameters ()
158158 if (self .in_chns != 3 ):
159159 params = itertools .chain ()
160- for pram in [self .net .classifier [- 1 ].parameters (), self .net .net . features [0 ][0 ].parameters ()]:
160+ for pram in [self .net .classifier [- 1 ].parameters (), self .net .features [0 ][0 ].parameters ()]:
161161 params = itertools .chain (params , pram )
162162 return params
163163 else :
0 commit comments