
Commit 05fecac

unify num_classes to class_num
1 parent 068b30f commit 05fecac

12 files changed: +40 −40 lines

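Every change in this commit is the same mechanical rename: the constructor keyword, instance attribute, and CLI flag `num_classes` become `class_num`, with no behavioral change. For downstream callers this is a breaking keyword rename; a minimal before/after sketch (the model key "cae_base_patch16_224" is illustrative, not taken from this diff):

    from passl.models import cae as models_cae

    # Before this commit:
    # model = models_cae.__dict__["cae_base_patch16_224"](num_classes=1000)

    # After this commit. Depending on the constructor, the old keyword now
    # either raises a TypeError or is silently swallowed by **kwargs.
    model = models_cae.__dict__["cae_base_patch16_224"](class_num=1000)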

passl/models/cae.py

Lines changed: 10 additions & 10 deletions
@@ -697,7 +697,7 @@ def forward(self, x, bool_masked_pos, return_all_tokens=False):
 class CAERegressorDecoder(nn.Layer):
     def __init__(self,
                  patch_size=16,
-                 num_classes=8192,
+                 class_num=8192,
                  embed_dim=768,
                  depth=6,
                  num_heads=12,
@@ -760,7 +760,7 @@ def __init__(self,
         if args.num_decoder_self_attention > 0:
             self.norm2 = norm_layer(embed_dim)
         self.head = nn.Linear(
-            embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+            embed_dim, class_num) if class_num > 0 else nn.Identity()
 
         self.init_std = init_std
 
@@ -907,7 +907,7 @@ def __init__(self,
 
         self.regressor_and_decoder = CAERegressorDecoder(
             patch_size=patch_size,
-            num_classes=args.decoder_num_classes,
+            class_num=args.decoder_class_num,
             embed_dim=args.decoder_embed_dim,
             depth=args.regressor_depth,
             num_heads=args.decoder_num_heads,
@@ -1083,7 +1083,7 @@ def __init__(self,
                  img_size=224,
                  patch_size=16,
                  in_chans=3,
-                 num_classes=1000,
+                 class_num=1000,
                  embed_dim=768,
                  depth=12,
                  num_heads=12,
@@ -1103,7 +1103,7 @@ def __init__(self,
                  lin_probe=False,
                  args=None):
         super().__init__()
-        self.num_classes = num_classes
+        self.class_num = class_num
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.use_mean_pooling = use_mean_pooling
 
@@ -1193,8 +1193,8 @@ def __init__(self,
         init.trunc_normal_(self.query_token, std=.02)
 
         self.head = nn.Linear(
-            embed_dim, num_classes,
-            bias_attr=True) if num_classes > 0 else nn.Identity()
+            embed_dim, class_num,
+            bias_attr=True) if class_num > 0 else nn.Identity()
 
         if self.pos_embed is not None and use_abs_pos_emb:
             init.trunc_normal_(self.pos_embed, std=.02)
@@ -1266,10 +1266,10 @@ def no_weight_decay(self):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=''):
-        self.num_classes = num_classes
+    def reset_classifier(self, class_num, global_pool=''):
+        self.class_num = class_num
         self.head = nn.Linear(
-            self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+            self.embed_dim, class_num) if class_num > 0 else nn.Identity()
 
     def forward_features(self, x, is_train=True):
         x = self.patch_embed(x)
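`reset_classifier` changes its signature the same way here and in cait.py, conv_vit.py, and swin_transformer.py below. A hedged usage sketch under the new name, assuming `model` is an instance of one of the updated classes:

    # Swap the pretrained head for a 10-way classifier;
    # class_num=0 would install nn.Identity() instead.
    model.reset_classifier(class_num=10)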

passl/models/cait.py

Lines changed: 6 additions & 6 deletions
@@ -234,7 +234,7 @@ def __init__(self,
                  img_size=224,
                  patch_size=16,
                  in_chans=3,
-                 num_classes=1000,
+                 class_num=1000,
                  global_pool='token',
                  embed_dim=768,
                  depth=12,
@@ -260,7 +260,7 @@ def __init__(self,
         super().__init__()
         assert global_pool in ('', 'token', 'avg')
 
-        self.num_classes = num_classes
+        self.class_num = class_num
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim
 
@@ -319,7 +319,7 @@ def __init__(self,
                 num_chs=embed_dim, reduction=0, module='head')
         ]
         self.head = nn.Linear(
-            embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+            embed_dim, class_num) if class_num > 0 else nn.Identity()
 
         init.trunc_normal_(self.pos_embed, std=.02)
         init.trunc_normal_(self.cls_token, std=.02)
@@ -340,14 +340,14 @@ def no_weight_decay(self):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
-        self.num_classes = num_classes
+    def reset_classifier(self, class_num, global_pool=None):
+        self.class_num = class_num
         if global_pool is not None:
             assert global_pool in ('', 'token', 'avg')
             self.global_pool = global_pool
         self.head = nn.Linear(
             self.num_features,
-            num_classes) if num_classes > 0 else nn.Identity()
+            class_num) if class_num > 0 else nn.Identity()
 
     def forward_features(self, x):
         x = self.patch_embed(x)

passl/models/convmae/conv_vit.py

Lines changed: 6 additions & 6 deletions
@@ -180,7 +180,7 @@ def __init__(self,
                  img_size=224,
                  patch_size=16,
                  in_chans=3,
-                 num_classes=1000,
+                 class_num=1000,
                  embed_dim=768,
                  depth=12,
                  num_heads=12,
@@ -195,7 +195,7 @@ def __init__(self,
                  global_pool=False,
                  **kwargs):
         super().__init__()
-        self.num_classes = num_classes
+        self.class_num = class_num
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
 
         if hybrid_backbone is not None:
@@ -269,7 +269,7 @@ def __init__(self,
 
         # Classifier head
         self.head = nn.Linear(
-            embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
+            embed_dim[-1], class_num) if class_num > 0 else nn.Identity()
 
         init.trunc_normal_(self.pos_embed, std=.02)
         self.apply(self._init_weights)
@@ -294,10 +294,10 @@ def no_weight_decay(self):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=''):
-        self.num_classes = num_classes
+    def reset_classifier(self, class_num, global_pool=''):
+        self.class_num = class_num
         self.head = nn.Linear(
-            self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+            self.embed_dim, class_num) if class_num > 0 else nn.Identity()
 
     def forward_features(self, x):
         B = x.shape[0]

passl/models/convnext.py

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ class ConvNeXt(Model):
     A Paddle impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
     Args:
         in_chans (int): Number of input image channels. Default: 3
-        num_classes (int): Number of classes for classification head. Default: 1000
+        class_num (int): Number of classes for classification head. Default: 1000
         depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
         dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
         drop_path_rate (float): Stochastic depth rate. Default: 0.

passl/models/dino/dino_vit.py

Lines changed: 2 additions & 2 deletions
@@ -114,7 +114,7 @@ def forward(self, x, return_attention=False):
 class DINOVisionTransformer(nn.Layer):
     """ DINO Vision Transformer """
 
-    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
+    def __init__(self, img_size=[224], patch_size=16, in_chans=3, class_num=0, embed_dim=768, depth=12,
                  num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                  drop_path_rate=0., norm_layer=nn.LayerNorm, n_last_blocks=1, avgpool_patchtokens=False, **kwargs):
         super().__init__()
@@ -147,7 +147,7 @@ def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, emb
         self.norm = norm_layer(embed_dim)
 
         # Classifier head
-        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(embed_dim, class_num) if class_num > 0 else nn.Identity()
 
         self.n_last_blocks = n_last_blocks
         self.avgpool_patchtokens = avgpool_patchtokens

passl/models/simsiam.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ def __init__(self, base_encoder, dim=2048, pred_dim=512):
         super(SimSiamPretain, self).__init__()
 
         # create the encoder
-        # num_classes is the output fc dimension, zero-initialize last BNs
+        # class_num is the output fc dimension, zero-initialize last BNs
         self.encoder = base_encoder(class_num=dim, zero_init_residual=True)
 
         # build a 3-layer projector

passl/models/swin_transformer.py

Lines changed: 7 additions & 7 deletions
@@ -467,7 +467,7 @@ class SwinTransformer(Model):
         img_size (int | tuple(int)): Input image size. Default 224
         patch_size (int | tuple(int)): Patch size. Default: 4
         in_chans (int): Number of input image channels. Default: 3
-        num_classes (int): Number of classes for classification head. Default: 1000
+        class_num (int): Number of classes for classification head. Default: 1000
         embed_dim (int): Patch embedding dimension. Default: 96
         depths (tuple(int)): Depth of each Swin Transformer layer.
         num_heads (tuple(int)): Number of attention heads in different layers.
@@ -487,7 +487,7 @@ def __init__(self,
                  img_size=224,
                  patch_size=4,
                  in_chans=3,
-                 num_classes=1000,
+                 class_num=1000,
                  global_pool='avg',
                  embed_dim=96,
                  depths=(2, 2, 6, 2),
@@ -506,7 +506,7 @@ def __init__(self,
                  **kwargs):
         super().__init__()
         assert global_pool in ('', 'avg')
-        self.num_classes = num_classes
+        self.class_num = class_num
         self.global_pool = global_pool
         self.num_layers = len(depths)
         self.embed_dim = embed_dim
@@ -560,7 +560,7 @@ def __init__(self,
         self.norm = norm_layer(self.num_features)
         self.head = nn.Linear(
             self.num_features,
-            num_classes) if num_classes > 0 else nn.Identity()
+            class_num) if class_num > 0 else nn.Identity()
 
         if self.absolute_pos_embed is not None:
             init.trunc_normal_(self.absolute_pos_embed, std=.02)
@@ -595,14 +595,14 @@ def group_matcher(self, coarse=False):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
-        self.num_classes = num_classes
+    def reset_classifier(self, class_num, global_pool=None):
+        self.class_num = class_num
         if global_pool is not None:
             assert global_pool in ('', 'avg')
             self.global_pool = global_pool
         self.head = nn.Linear(
             self.num_features,
-            num_classes) if num_classes > 0 else nn.Identity()
+            class_num) if class_num > 0 else nn.Identity()
 
     def forward_features(self, x):
         x = self.patch_embed(x)

tasks/ssl/cae/main_finetune.py

Lines changed: 1 addition & 1 deletion
@@ -558,7 +558,7 @@ def mixup_collate_fn(batch):
         data_loader_val = None
 
     model = models_cae.__dict__[args.model](
-        num_classes=args.nb_classes,
+        class_num=args.nb_classes,
         drop_rate=args.drop,
        drop_path_rate=args.drop_path,
         attn_drop_rate=args.attn_drop_rate,

tasks/ssl/cae/main_linprobe.py

Lines changed: 1 addition & 1 deletion
@@ -345,7 +345,7 @@ def main(args):
         use_shared_memory=args.pin_mem, )
 
     model = models_cae.__dict__[args.model](
-        num_classes=args.nb_classes,
+        class_num=args.nb_classes,
         drop_rate=args.drop,
         drop_path_rate=args.drop_path,
         attn_drop_rate=args.attn_drop_rate,

tasks/ssl/cae/main_pretrain.py

Lines changed: 1 addition & 1 deletion
@@ -268,7 +268,7 @@ def get_args():
         type=int,
         help='Number of heads for decoder')
     parser.add_argument(
-        '--decoder_num_classes',
+        '--decoder_class_num',
         default=8192,
         type=int,
         help='Number of classes for decoder')
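Launch commands must track the flag rename as well; an illustrative invocation with all other arguments elided:

    # formerly: --decoder_num_classes 8192
    python tasks/ssl/cae/main_pretrain.py --decoder_class_num 8192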
