From 402a3436d8314dd895f6c137ec654141b3771e9b Mon Sep 17 00:00:00 2001 From: Enrico Fini Date: Thu, 27 Oct 2022 15:26:12 +0200 Subject: [PATCH] fix linear eval --- main_linear.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/main_linear.py b/main_linear.py index a5c179a..fe9ea7e 100644 --- a/main_linear.py +++ b/main_linear.py @@ -72,9 +72,18 @@ def get_args_parser(): help='path to CLIP pretrained checkpoint') return parser + best_acc1 = 0 +@timm.models.registry.register_model +def vit_small_mocov3_patch16_224(**kwargs): + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12, **kwargs) + model = timm.models.vision_transformer._create_vision_transformer('vit_small_patch16_224', **model_kwargs) + + return model + + def main(args): utils.init_distributed_mode(args) @@ -131,7 +140,7 @@ def main(args): getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01) getattr(model, linear_keyword).bias.data.zero_() - init_lr = args.lr * int(args.batch_size / utils.get_world_size()) / 256 + init_lr = (args.lr * args.batch_size * utils.get_world_size()) / 256 args.workers = int((args.workers + utils.get_world_size() - 1) / utils.get_world_size()) model.cuda(args.gpu)