Patch for main_linear.py.

Summary of changes:
  1. Adds a timm-registered model entry point, vit_small_mocov3_patch16_224,
     which calls timm's _create_vision_transformer for the
     'vit_small_patch16_224' variant with patch_size=16, embed_dim=384,
     depth=12 and num_heads=12; any extra keyword arguments are forwarded.
  2. Changes the initial learning-rate computation in main() from
         args.lr * int(args.batch_size / world_size) / 256
     to
         (args.lr * args.batch_size * world_size) / 256
     NOTE(review): this presumably treats args.batch_size as a per-process
     batch size rather than a global one -- confirm against the data loader.

NOTE: the line structure of this patch was restored from a copy whose
newlines had been collapsed. Leading indentation of the context and added
lines was reconstructed and should be verified against the target file; if a
hunk is rejected for whitespace reasons, try `git apply --ignore-whitespace`.

diff --git a/main_linear.py b/main_linear.py
index a5c179a..fe9ea7e 100644
--- a/main_linear.py
+++ b/main_linear.py
@@ -72,9 +72,18 @@ def get_args_parser():
                         help='path to CLIP pretrained checkpoint')
     return parser
 
+
 best_acc1 = 0
 
 
+@timm.models.registry.register_model
+def vit_small_mocov3_patch16_224(**kwargs):
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12, **kwargs)
+    model = timm.models.vision_transformer._create_vision_transformer('vit_small_patch16_224', **model_kwargs)
+
+    return model
+
+
 def main(args):
     utils.init_distributed_mode(args)
 
@@ -131,7 +140,7 @@ def main(args):
     getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01)
     getattr(model, linear_keyword).bias.data.zero_()
 
-    init_lr = args.lr * int(args.batch_size / utils.get_world_size()) / 256
+    init_lr = (args.lr * args.batch_size * utils.get_world_size()) / 256
     args.workers = int((args.workers + utils.get_world_size() - 1) / utils.get_world_size())
 
     model.cuda(args.gpu)