@@ -697,7 +697,7 @@ def forward(self, x, bool_masked_pos, return_all_tokens=False):
 class CAERegressorDecoder(nn.Layer):
     def __init__(self,
                  patch_size=16,
-                 num_classes=8192,
+                 class_num=8192,
                  embed_dim=768,
                  depth=6,
                  num_heads=12,
@@ -760,7 +760,7 @@ def __init__(self,
         if args.num_decoder_self_attention > 0:
             self.norm2 = norm_layer(embed_dim)
         self.head = nn.Linear(
-            embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+            embed_dim, class_num) if class_num > 0 else nn.Identity()
 
         self.init_std = init_std
 
@@ -907,7 +907,7 @@ def __init__(self,
 
         self.regressor_and_decoder = CAERegressorDecoder(
             patch_size=patch_size,
-            num_classes=args.decoder_num_classes,
+            class_num=args.decoder_class_num,
             embed_dim=args.decoder_embed_dim,
             depth=args.regressor_depth,
             num_heads=args.decoder_num_heads,
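The rename above reaches outside this file: whatever populates `args` must now define `decoder_class_num` instead of `decoder_num_classes`, or the `CAERegressorDecoder(...)` call fails with an `AttributeError`. A minimal argparse sketch of the matching change; the flag name and default here are assumptions for illustration, not taken from this commit:

    import argparse

    # Hypothetical flag definition; the real parser lives elsewhere in the
    # repo and is not shown in this diff.
    parser = argparse.ArgumentParser()
    parser.add_argument('--decoder_class_num', type=int, default=8192,
                        help='decoder head size (formerly decoder_num_classes)')
    args = parser.parse_args([])
    assert args.decoder_class_num == 8192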
@@ -1083,7 +1083,7 @@ def __init__(self,
                  img_size=224,
                  patch_size=16,
                  in_chans=3,
-                 num_classes=1000,
+                 class_num=1000,
                  embed_dim=768,
                  depth=12,
                  num_heads=12,
@@ -1103,7 +1103,7 @@ def __init__(self,
                  lin_probe=False,
                  args=None):
         super().__init__()
-        self.num_classes = num_classes
+        self.class_num = class_num
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.use_mean_pooling = use_mean_pooling
 
@@ -1193,8 +1193,8 @@ def __init__(self,
         init.trunc_normal_(self.query_token, std=.02)
 
         self.head = nn.Linear(
-            embed_dim, num_classes,
-            bias_attr=True) if num_classes > 0 else nn.Identity()
+            embed_dim, class_num,
+            bias_attr=True) if class_num > 0 else nn.Identity()
 
         if self.pos_embed is not None and use_abs_pos_emb:
             init.trunc_normal_(self.pos_embed, std=.02)
@@ -1266,10 +1266,10 @@ def no_weight_decay(self):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=''):
-        self.num_classes = num_classes
+    def reset_classifier(self, class_num, global_pool=''):
+        self.class_num = class_num
         self.head = nn.Linear(
-            self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+            self.embed_dim, class_num) if class_num > 0 else nn.Identity()
 
     def forward_features(self, x, is_train=True):
         x = self.patch_embed(x)
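For callers, the visible API change is the keyword itself, e.g. `reset_classifier(class_num=...)` rather than `reset_classifier(num_classes=...)`. A stand-alone sketch of the `nn.Linear`/`nn.Identity` head idiom this file uses; the variable names are illustrative, only the pattern comes from the diff:

    import paddle
    import paddle.nn as nn

    embed_dim, class_num = 768, 10
    # class_num > 0: a real classification head, as in the code above.
    head = nn.Linear(embed_dim, class_num) if class_num > 0 else nn.Identity()
    feats = paddle.randn([2, embed_dim])   # two pooled feature vectors
    print(head(feats).shape)               # [2, 10]

    # class_num == 0 degrades to a pass-through feature extractor.
    class_num = 0
    head = nn.Linear(embed_dim, class_num) if class_num > 0 else nn.Identity()
    print(head(feats).shape)               # [2, 768]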