From b5e0619bec388fd1a0ee06f938f0b5c709ed70a4 Mon Sep 17 00:00:00 2001 From: wang3702 Date: Mon, 10 Jun 2019 15:41:17 -0700 Subject: [PATCH 1/4] update the backbone to conventional backbone --- .idea/MixMatch-pytorch.iml | 12 + .idea/misc.xml | 4 + .idea/modules.xml | 8 + .idea/vcs.xml | 6 + .idea/workspace.xml | 325 +++++++++++++++++++++++++++ dataset/__init__.py | 0 dataset/cifar10.py | 2 +- models/Attention.py | 46 ++++ models/Classifier.py | 24 ++ models/TEBlock.py | 39 ++++ models/TE_Module.py | 131 +++++++++++ models/__init__.py | 0 train.py | 32 +-- utils/__init__.py | 3 +- utils/progress/LICENSE | 13 ++ utils/progress/MANIFEST.in | 1 + utils/progress/README.rst | 131 +++++++++++ utils/progress/__init__.py | 0 utils/progress/progress/__init__.py | 127 +++++++++++ utils/progress/progress/__init__.pyc | Bin 0 -> 5388 bytes utils/progress/progress/bar.py | 88 ++++++++ utils/progress/progress/bar.pyc | Bin 0 -> 3573 bytes utils/progress/progress/counter.py | 48 ++++ utils/progress/progress/helpers.py | 91 ++++++++ utils/progress/progress/helpers.pyc | Bin 0 -> 3882 bytes utils/progress/progress/spinner.py | 44 ++++ utils/progress/setup.py | 29 +++ utils/progress/test_progress.py | 48 ++++ 28 files changed, 1234 insertions(+), 18 deletions(-) create mode 100644 .idea/MixMatch-pytorch.iml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml create mode 100644 .idea/workspace.xml create mode 100644 dataset/__init__.py create mode 100644 models/Attention.py create mode 100644 models/Classifier.py create mode 100644 models/TEBlock.py create mode 100644 models/TE_Module.py create mode 100644 models/__init__.py create mode 100644 utils/progress/LICENSE create mode 100644 utils/progress/MANIFEST.in create mode 100644 utils/progress/README.rst create mode 100644 utils/progress/__init__.py create mode 100644 utils/progress/progress/__init__.py create mode 100644 utils/progress/progress/__init__.pyc create mode 100644 utils/progress/progress/bar.py create mode 100644 utils/progress/progress/bar.pyc create mode 100644 utils/progress/progress/counter.py create mode 100644 utils/progress/progress/helpers.py create mode 100644 utils/progress/progress/helpers.pyc create mode 100644 utils/progress/progress/spinner.py create mode 100644 utils/progress/setup.py create mode 100644 utils/progress/test_progress.py diff --git a/.idea/MixMatch-pytorch.iml b/.idea/MixMatch-pytorch.iml new file mode 100644 index 0000000..7c9d48f --- /dev/null +++ b/.idea/MixMatch-pytorch.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..a2e120d --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..a739c27 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.idea/workspace.xml b/.idea/workspace.xml new file mode 100644 index 0000000..dfa7e4c --- /dev/null +++ b/.idea/workspace.xml @@ -0,0 +1,325 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + models + progress + summarywriter + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -170,6 +144,7 @@ + @@ -265,58 +240,71 @@ - + + + + - - + + - - - - + - - + + + + + - - + + + + + + + + + + + + - - + + - + - - + + - + diff --git a/train.py b/train.py index 370c464..d560f3b 100644 --- a/train.py +++ b/train.py @@ -21,6 +21,7 @@ from models.Classifier import Classifier import dataset.cifar10 as dataset from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig +from models.wideresnet import WideResNet #from tensorboardX import SummaryWriter parser = argparse.ArgumentParser(description='PyTorch MixMatch Training') @@ -53,7 +54,7 @@ parser.add_argument('--lambda-u', default=75, type=float) parser.add_argument('--T', default=0.5, type=float) parser.add_argument('--ema-decay', default=0.999, type=float) - +parser.add_argument('--type', default=0, type=int,help='Choose different backbone: 0:wide-resnet 1:te-module 13 layer') args = parser.parse_args() state = {k: v for k, v in args._get_kwargs()} @@ -97,7 +98,10 @@ def main(): print("==> creating WRN-28-2") def create_model(ema=False): - model = Classifier(num_classes=10) + if args.type==1: + model = Classifier(num_classes=10) + else: + model=WideResNet(num_classes=10) model = model.cuda() if ema: @@ -116,7 +120,7 @@ def create_model(ema=False): criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=args.lr) - ema_optimizer= WeightEMA(model, ema_model, alpha=args.ema_decay) + ema_optimizer= WeightEMA(model, ema_model, run_type=args.type,alpha=args.ema_decay) start_epoch = 0 # Resume @@ -210,6 +214,7 @@ def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_opti labeled_train_iter = iter(labeled_trainloader) inputs_x, targets_x = labeled_train_iter.next() + try: (inputs_u, inputs_u2), _ = unlabeled_train_iter.next() except: @@ -382,11 +387,14 @@ def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch): return Lx, Lu, args.lambda_u * linear_rampup(epoch) class WeightEMA(object): - def __init__(self, model, ema_model, alpha=0.999): + def __init__(self, model, ema_model,run_type=0, alpha=0.999): self.model = model self.ema_model = ema_model self.alpha = alpha - self.tmp_model = Classifier(num_classes=10).cuda() + if run_type==1: + self.tmp_model = Classifier(num_classes=10).cuda() + else: + self.tmp_model =WideResNet(num_classes=10).cuda() self.wd = 0.02 * args.lr for param, ema_param in zip(self.model.parameters(), self.ema_model.parameters()): From 3531fe6206f6ea7556e8cb62dcedf15e63059aca Mon Sep 17 00:00:00 2001 From: wang3702 Date: Thu, 13 Jun 2019 17:17:03 -0700 Subject: [PATCH 3/4] try tensorboardX --- train.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/train.py b/train.py index d560f3b..7d64702 100644 --- a/train.py +++ b/train.py @@ -22,7 +22,7 @@ import dataset.cifar10 as dataset from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig from models.wideresnet import WideResNet -#from tensorboardX import SummaryWriter +from tensorboardX import SummaryWriter parser = argparse.ArgumentParser(description='PyTorch MixMatch Training') # Optimization options @@ -141,7 +141,7 @@ def create_model(ema=False): logger = Logger(os.path.join(args.out, 'log.txt'), title=title) logger.set_names(['Train Loss', 'Train Loss X', 'Train Loss U', 'Valid Loss', 'Valid Acc.', 'Test Loss', 'Test Acc.']) - #writer = SummaryWriter(args.out) + writer = SummaryWriter(args.out) step = 0 test_accs = [] # Train and val @@ -156,13 +156,13 @@ def 
create_model(ema=False): step = args.batch_size * args.val_iteration * (epoch + 1) - # writer.add_scalar('losses/train_loss', train_loss, step) - # writer.add_scalar('losses/valid_loss', val_loss, step) - # writer.add_scalar('losses/test_loss', test_loss, step) - # - # writer.add_scalar('accuracy/train_acc', train_acc, step) - # writer.add_scalar('accuracy/val_acc', val_acc, step) - # writer.add_scalar('accuracy/test_acc', test_acc, step) + writer.add_scalar('losses/train_loss', train_loss, step) + writer.add_scalar('losses/valid_loss', val_loss, step) + writer.add_scalar('losses/test_loss', test_loss, step) + + writer.add_scalar('accuracy/train_acc', train_acc, step) + writer.add_scalar('accuracy/val_acc', val_acc, step) + writer.add_scalar('accuracy/test_acc', test_acc, step) # scheduler.step() @@ -182,7 +182,7 @@ def create_model(ema=False): }, is_best) test_accs.append(test_acc) logger.close() - #writer.close() + writer.close() print('Best acc:') print(best_acc) From 1d3f35e3dfa3270cb8ae650aaa7028307c0422b6 Mon Sep 17 00:00:00 2001 From: wang3702 Date: Sat, 22 Jun 2019 17:16:34 -0700 Subject: [PATCH 4/4] update a new dataloader for training --- dataset/Download_Cifar.py | 109 +++++ dataset/Projective_MixMatch_Dataloader.py | 222 +++++++++ ops/__init__.py | 0 ops/os_operation.py | 17 + train_myloader.py | 543 ++++++++++++++++++++++ 5 files changed, 891 insertions(+) create mode 100644 dataset/Download_Cifar.py create mode 100644 dataset/Projective_MixMatch_Dataloader.py create mode 100644 ops/__init__.py create mode 100644 ops/os_operation.py create mode 100644 train_myloader.py diff --git a/dataset/Download_Cifar.py b/dataset/Download_Cifar.py new file mode 100644 index 0000000..cd571ff --- /dev/null +++ b/dataset/Download_Cifar.py @@ -0,0 +1,109 @@ + +from ops.os_operation import mkdir +import os +from torchvision.datasets.utils import download_url, check_integrity +import sys +if sys.version_info[0] == 2: + import cPickle as pickle +else: + import pickle +import numpy as np + +class CIFAR10(object): + """`CIFAR10 `_ Dataset. + Args: + root (string): Root directory of dataset where directory + ``cifar-10-batches-py`` exists or will be saved to if download is set to True. 
+ + download (): downloads the dataset from the internet and + puts it in root directory + """ + + def __init__(self, save_path): + self.root=save_path + self.download_init() + if not self._check_integrity(): + mkdir(save_path) + self.download() + self.final_path=os.path.join(save_path,'cifar10') + mkdir(self.final_path) + #generate npy files here + self.train_path=os.path.join(self.final_path,'trainset') + self.test_path = os.path.join(self.final_path, 'testset') + mkdir(self.train_path) + mkdir(self.test_path) + if os.path.getsize(self.train_path)<10000: + self.Process_Dataset(self.train_list,self.train_path) + if os.path.getsize(self.test_path)<10000: + self.Process_Dataset(self.test_list,self.test_path) + def download_init(self): + self.base_folder = 'cifar-10-batches-py' + self.url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" + self.filename = "cifar-10-python.tar.gz" + self.tgz_md5 = 'c58f30108f718f92721af3b95e74349a' + self.train_list = [ + ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], + ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], + ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], + ['data_batch_4', '634d18415352ddfa80567beed471001a'], + ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], + ] + + self.test_list = [ + ['test_batch', '40351d587109b95175f43aff81a1287e'], + ] + + def download(self): + import tarfile + + if self._check_integrity(): + print('Files already downloaded and verified') + return + + root = self.root + download_url(self.url, root, self.filename, self.tgz_md5) + + # extract file + cwd = os.getcwd() + tar = tarfile.open(os.path.join(root, self.filename), "r:gz") + os.chdir(root) + tar.extractall() + tar.close() + os.chdir(cwd) + + def _check_integrity(self): + + root = self.root + for fentry in (self.train_list + self.test_list): + filename, md5 = fentry[0], fentry[1] + fpath = os.path.join(root, self.base_folder, filename) + if not check_integrity(fpath, md5): + return False + return True + + def Process_Dataset(self,train_list,train_path): + train_data=[] + train_labels=[] + for fentry in train_list: + f = fentry[0] + file = os.path.join(self.root, self.base_folder, f) + with open(file, 'rb') as fo: + if sys.version_info[0] == 2: + entry = pickle.load(fo) + else: + entry = pickle.load(fo, encoding='latin1') + train_data.append(entry['data']) + if 'labels' in entry: + train_labels += entry['labels'] + else: + train_labels += entry['fine_labels'] + train_data = np.concatenate(train_data) + train_data = train_data.reshape((len(train_data), 3, 32, 32)) + train_labels=np.array(train_labels) + #following Channel,height,width format + #self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC + for i in range(len(train_data)): + tmp_train_path=os.path.join(train_path,'trainset'+str(i)+'.npy') + tmp_aim_path = os.path.join(train_path, 'aimset' + str(i) + '.npy') + np.save(tmp_train_path,train_data[i]) + np.save(tmp_aim_path,train_labels[i]) diff --git a/dataset/Projective_MixMatch_Dataloader.py b/dataset/Projective_MixMatch_Dataloader.py new file mode 100644 index 0000000..b960e61 --- /dev/null +++ b/dataset/Projective_MixMatch_Dataloader.py @@ -0,0 +1,222 @@ +import torch +import torch.utils.data as data +import numpy as np +import random +import os +from PIL import Image, PILLOW_VERSION +import numbers +from torchvision.transforms.functional import _get_inverse_affine_matrix +import math +from sklearn.model_selection import train_test_split +from collections import defaultdict + +class TransformTwice: + def 
__init__(self, transform): + self.transform = transform + + def __call__(self, inp): + out1 = self.transform(inp) + out2 = self.transform(inp) + return out1, out2 + +class Projective_MixMatch_Data_Loader(data.Dataset): + def __init__(self, dataset_dir,shift=6, train_label=True, scale=None, resample=False, + fillcolor=0,matrix_transform=None, + transform_pre=None, transform=None, target_transform=None, rand_state=888, + valid_size=0.1,uniform_label=False,num_classes=10,unlabel_Data=False): + super(Projective_MixMatch_Data_Loader, self).__init__() + self.root=os.path.abspath(dataset_dir) + self.shift=shift + self.trainsetFile = [] + self.aimsetFile = [] + listfiles = os.listdir(dataset_dir) + self.trainlist = [os.path.join(dataset_dir, x) for x in listfiles if "trainset" in x] + self.aimlist = [os.path.join(dataset_dir, x) for x in listfiles if "aimset" in x] + self.trainlist.sort() + self.aimlist.sort() + self.train_label=train_label + self.valid_size=valid_size + self.unlabel_Data=unlabel_Data + # here update this with 80% as training, 20%as validation + if valid_size>0: + if uniform_label==False: + + X_train, X_test, y_train, y_test = train_test_split(self.trainlist, self.aimlist, test_size=valid_size, + random_state=rand_state) + if train_label: + self.trainlist = X_train + self.aimlist = y_train + + else: + self.trainlist = X_test + self.aimlist = y_test + else: + #pick the uniform valid size indicated + shuffle_range=np.arange(len(self.trainlist)) + random.seed(rand_state) + random.shuffle(shuffle_range) + require_size=int(len(self.aimlist)*valid_size/num_classes) + self.trainlist,self.aimlist=self.pick_top_k_example(require_size,shuffle_range,num_classes) + if uniform_label==True and len(self.trainlist)<50000: + #to accelerate training to avoid dataloader load again and again for small data + repeat_times=int(50000/len(self.trainlist)) + self.trainlist=self.trainlist*repeat_times + self.aimlist=self.aimlist*repeat_times + self.transform_pre = transform_pre + if scale is not None: + assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ + "scale should be a list or tuple and it must be of length 2." 
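+            # `scale` arrives from prepare_Dataloader in train_myloader.py as the pair
+            # (params['shrink'], params['enlarge']); the assert above requires exactly
+            # two values and the loop below rejects non-positive ones.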
+ for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + self.resample = resample + self.fillcolor = fillcolor + self.transform = transform + self.target_transform = target_transform + self.matrix_transform=matrix_transform + def pick_top_k_example(self,img_per_cat,shuffle_range,num_class): + record_dict=defaultdict(list) + for i in range(len(shuffle_range)): + tmp_id=shuffle_range[i] + label=int(np.load(self.aimlist[tmp_id])) + if label not in record_dict: + record_dict[label].append(tmp_id) + elif len(record_dict[label]) Preparing cifar10') + transform_train = transforms.Compose([ + dataset.RandomPadandCrop(32), + dataset.RandomFlip(), + dataset.ToTensor(), + ]) + + transform_val = transforms.Compose([ + dataset.ToTensor(), + ]) + + # train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10('./data', args.n_labeled, + # transform_train=transform_train, + # transform_val=transform_val) + # labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, + # drop_last=True) + # unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=args.batch_size, shuffle=True, + # num_workers=0, drop_last=True) + # val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=0) + # test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0) + data = CIFAR10('./data') + train_dataloader, unlabel_dataloader, valid_dataloader, testloader=prepare_Dataloader(data,params) + # Model + print("==> creating WRN-28-2") + + def create_model(ema=False): + if args.type == 1: + model = Classifier(num_classes=10) + else: + model = WideResNet(num_classes=10) + model = model.cuda() + + if ema: + for param in model.parameters(): + param.detach_() + + return model + + model = create_model() + ema_model = create_model(ema=True) + + cudnn.benchmark = True + print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0)) + + train_criterion = SemiLoss() + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=args.lr) + + ema_optimizer = WeightEMA(model, ema_model, run_type=args.type, alpha=args.ema_decay) + start_epoch = 0 + + # Resume + title = 'noisy-cifar-10' + if args.resume: + # Load checkpoint. + print('==> Resuming from checkpoint..') + assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' 
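+        # Resuming redirects args.out to the checkpoint's directory, restores the student
+        # and EMA model weights, the optimizer state, best_acc and the starting epoch,
+        # and re-attaches to the existing log.txt (Logger(..., resume=True)) below.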
+ args.out = os.path.dirname(args.resume) + checkpoint = torch.load(args.resume) + best_acc = checkpoint['best_acc'] + start_epoch = checkpoint['epoch'] + model.load_state_dict(checkpoint['state_dict']) + ema_model.load_state_dict(checkpoint['ema_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + logger = Logger(os.path.join(args.resume, 'log.txt'), title=title, resume=True) + else: + logger = Logger(os.path.join(args.out, 'log.txt'), title=title) + logger.set_names( + ['Train Loss', 'Train Loss X', 'Train Loss U', 'Valid Loss', 'Valid Acc.', 'Test Loss', 'Test Acc.']) + + writer = SummaryWriter(args.out) + step = 0 + test_accs = [] + # Train and val + for epoch in range(start_epoch, args.epochs): + print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr'])) + + train_loss, train_loss_x, train_loss_u = train(train_dataloader, unlabel_dataloader, model, optimizer, + ema_optimizer, train_criterion, epoch, use_cuda) + _, train_acc = validate(train_dataloader, ema_model, criterion, epoch, use_cuda, mode='Train Stats') + val_loss, val_acc = validate(valid_dataloader, ema_model, criterion, epoch, use_cuda, mode='Valid Stats') + test_loss, test_acc = validate(testloader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ') + + step = args.batch_size * args.val_iteration * (epoch + 1) + + writer.add_scalar('losses/train_loss', train_loss, step) + writer.add_scalar('losses/valid_loss', val_loss, step) + writer.add_scalar('losses/test_loss', test_loss, step) + + writer.add_scalar('accuracy/train_acc', train_acc, step) + writer.add_scalar('accuracy/val_acc', val_acc, step) + writer.add_scalar('accuracy/test_acc', test_acc, step) + + # scheduler.step() + + # append logger file + logger.append([train_loss, train_loss_x, train_loss_u, val_loss, val_acc, test_loss, test_acc]) + + # save model + is_best = val_acc > best_acc + best_acc = max(val_acc, best_acc) + save_checkpoint({ + 'epoch': epoch + 1, + 'state_dict': model.state_dict(), + 'ema_state_dict': ema_model.state_dict(), + 'acc': val_acc, + 'best_acc': best_acc, + 'optimizer': optimizer.state_dict(), + }, is_best) + test_accs.append(test_acc) + logger.close() + writer.close() + + print('Best acc:') + print(best_acc) + + print('Mean acc:') + print(np.mean(test_accs[-20:])) + +def prepare_Dataloader(data,params): + cifar10_mean = (0.4914, 0.4822, 0.4465) # equals np.mean(train_set.train_data, axis=(0,1,2))/255 + cifar10_std = (0.2471, 0.2435, 0.2616) # equals np.std(train_set.train_data, axis=(0,1,2))/255 + #dataset=torchvision.dataset.cifar10(params['F'], train=True, download=True) + #from ops.Transform_ops import RandomFlip,RandomPadandCrop + # transform_train = transforms.Compose([ + # transforms.RandomCrop(32, padding=4), + # transforms.RandomHorizontalFlip(p=0.5), + # transforms.Normalize(cifar10_mean,cifar10_std), + # #dataset.ToTensor(), + # ]) + # transform_train = transforms.Compose([ + # transforms.RandomCrop(32, padding=4), + # transforms.RandomHorizontalFlip(), + # ]), + # transform_final=transforms.Compose([ + # + # transforms.Normalize(cifar10_mean, cifar10_std), + # ]) + #from Data_Processing.Projective_MixMatch_Dataloader import TransformTwice + #actually unlabelled training dataloader + valid_dataset = Projective_MixMatch_Data_Loader(dataset_dir=data.train_path, shift=params['shift'], train_label=True, + scale=(params['shrink'], params['enlarge']), fillcolor=(128, 128, 128), + resample=PIL.Image.BILINEAR, + matrix_transform=transforms.Compose([ + transforms.Normalize((0., 0., 16., 0., 0., 16., 
0., 0.), + (1., 1., 20., 1., 1., 20., 0.015, 0.015)), + ]), + transform_pre= transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + ]), rand_state=params['manualSeed'], + valid_size=0, + ) + valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=params['batch_size'], + shuffle=True, num_workers=int(params['num_workers'])) + + unlabel_dataset = Projective_MixMatch_Data_Loader(dataset_dir=data.train_path, shift=params['shift'], + train_label=True, + scale=(params['shrink'], params['enlarge']), + fillcolor=(128, 128, 128), + resample=PIL.Image.BILINEAR, + matrix_transform=transforms.Compose([ + transforms.Normalize((0., 0., 16., 0., 0., 16., 0., 0.), + (1., 1., 20., 1., 1., 20., 0.015, 0.015)), + ]), + transform_pre= transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + ]), rand_state=params['manualSeed'], + valid_size=0, unlabel_Data=True + ) + unlabel_dataloader = torch.utils.data.DataLoader(unlabel_dataset, batch_size=params['batch_size'], + shuffle=True, num_workers=int(params['num_workers']),drop_last=True) + train_labeled_dataset = Projective_MixMatch_Data_Loader(dataset_dir=data.train_path, shift=params['shift'], + train_label=False, + scale=(params['shrink'], params['enlarge']), + fillcolor=(128, 128, 128), resample=PIL.Image.BILINEAR, + matrix_transform=transforms.Compose([ + transforms.Normalize((0., 0., 16., 0., 0., 16., 0., 0.), + (1., 1., 20., 1., 1., 20., 0.015, 0.015)), + ]), + transform_pre= transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + ]), rand_state=666, valid_size=params['portion'], uniform_label=True, + ) + train_dataloader = torch.utils.data.DataLoader(train_labeled_dataset, batch_size=params['batch_size'], + shuffle=True, num_workers=int(params['num_workers']),drop_last=True) + test_dataset = Projective_MixMatch_Data_Loader(dataset_dir=data.test_path, shift=params['shift'], train_label=True, + scale=(params['shrink'], params['enlarge']), fillcolor=(128, 128, 128), + resample=PIL.Image.BILINEAR, + matrix_transform=transforms.Compose([ + transforms.Normalize((0., 0., 16., 0., 0., 16., 0., 0.), + (1., 1., 20., 1., 1., 20., 0.015, 0.015)), + ]), + rand_state=666, valid_size=0, + ) + testloader = torch.utils.data.DataLoader(test_dataset, batch_size=params['batch_size'], shuffle=False, + num_workers=int(params['num_workers'])) + return train_dataloader,unlabel_dataloader,valid_dataloader,testloader + + +def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, criterion, epoch, use_cuda): + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + losses_x = AverageMeter() + losses_u = AverageMeter() + ws = AverageMeter() + end = time.time() + + bar = Bar('Training', max=args.val_iteration) + labeled_train_iter = iter(labeled_trainloader) + unlabeled_train_iter = iter(unlabeled_trainloader) + + model.train() + for batch_idx in range(args.val_iteration): + try: + inputs_x, _,_,targets_x = labeled_train_iter.next() + except: + labeled_train_iter = iter(labeled_trainloader) + inputs_x, _, _, targets_x = labeled_train_iter.next() + + try: + (inputs_u, inputs_u2), _,_,_= unlabeled_train_iter.next() + except: + unlabeled_train_iter = iter(unlabeled_trainloader) + (inputs_u, inputs_u2), _,_,_= unlabeled_train_iter.next() + + # measure data loading time + data_time.update(time.time() - end) + + batch_size = inputs_x.size(0) + + # Transform label to one-hot + targets_x = 
torch.zeros(batch_size, 10).scatter_(1, targets_x.view(-1, 1), 1) + + if use_cuda: + inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True) + inputs_u = inputs_u.cuda() + inputs_u2 = inputs_u2.cuda() + + with torch.no_grad(): + # compute guessed labels of unlabel samples + outputs_u = model(inputs_u) + outputs_u2 = model(inputs_u2) + p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2 + pt = p ** (1 / args.T) + targets_u = pt / pt.sum(dim=1, keepdim=True) + targets_u = targets_u.detach() + + # mixup + all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0) + all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0) + + l = np.random.beta(args.alpha, args.alpha) + + l = max(l, 1 - l) + + idx = torch.randperm(all_inputs.size(0)) + + input_a, input_b = all_inputs, all_inputs[idx] + target_a, target_b = all_targets, all_targets[idx] + + mixed_input = l * input_a + (1 - l) * input_b + mixed_target = l * target_a + (1 - l) * target_b + + # interleave labeled and unlabed samples between batches to get correct batchnorm calculation + mixed_input = list(torch.split(mixed_input, batch_size)) + mixed_input = interleave(mixed_input, batch_size) + + logits = [model(mixed_input[0])] + for input in mixed_input[1:]: + logits.append(model(input)) + + # put interleaved samples back + logits = interleave(logits, batch_size) + logits_x = logits[0] + logits_u = torch.cat(logits[1:], dim=0) + + Lx, Lu, w = criterion(logits_x, mixed_target[:batch_size], logits_u, mixed_target[batch_size:], + epoch + batch_idx / args.val_iteration) + + loss = Lx + w * Lu + + # record loss + losses.update(loss.item(), inputs_x.size(0)) + losses_x.update(Lx.item(), inputs_x.size(0)) + losses_u.update(Lu.item(), inputs_x.size(0)) + ws.update(w, inputs_x.size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + ema_optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + # plot progress + bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Loss_x: {loss_x:.4f} | Loss_u: {loss_u:.4f} | W: {w:.4f}'.format( + batch=batch_idx + 1, + size=args.val_iteration, + data=data_time.avg, + bt=batch_time.avg, + total=bar.elapsed_td, + eta=bar.eta_td, + loss=losses.avg, + loss_x=losses_x.avg, + loss_u=losses_u.avg, + w=ws.avg, + ) + bar.next() + bar.finish() + + ema_optimizer.step(bn=True) + + return (losses.avg, losses_x.avg, losses_u.avg,) + + +def validate(valloader, model, criterion, epoch, use_cuda, mode): + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + # switch to evaluate mode + model.eval() + + end = time.time() + bar = Bar('{mode}', max=len(valloader)) + with torch.no_grad(): + for batch_idx, (inputs,_,_, targets) in enumerate(valloader): + # measure data loading time + data_time.update(time.time() - end) + + if use_cuda: + inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True) + + # compute output + outputs = model(inputs) + loss = criterion(outputs, targets) + + # measure accuracy and record loss + prec1, prec5 = accuracy(outputs, targets, topk=(1, 5)) + losses.update(loss.item(), inputs.size(0)) + top1.update(prec1.item(), inputs.size(0)) + top5.update(prec5.item(), inputs.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + # plot progress + bar.suffix = 
'({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( + batch=batch_idx + 1, + size=len(valloader), + data=data_time.avg, + bt=batch_time.avg, + total=bar.elapsed_td, + eta=bar.eta_td, + loss=losses.avg, + top1=top1.avg, + top5=top5.avg, + ) + bar.next() + bar.finish() + return (losses.avg, top1.avg) + + +def save_checkpoint(state, is_best, checkpoint=args.out, filename='checkpoint.pth.tar'): + filepath = os.path.join(checkpoint, filename) + torch.save(state, filepath) + if is_best: + shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) + + +def linear_rampup(current, rampup_length=16): + if rampup_length == 0: + return 1.0 + else: + current = np.clip(current / rampup_length, 0.0, 1.0) + return float(current) + + +class SemiLoss(object): + def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch): + probs_u = torch.softmax(outputs_u, dim=1) + + Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1)) + Lu = torch.mean((probs_u - targets_u) ** 2) + + return Lx, Lu, args.lambda_u * linear_rampup(epoch) + + +class WeightEMA(object): + def __init__(self, model, ema_model, run_type=0, alpha=0.999): + self.model = model + self.ema_model = ema_model + self.alpha = alpha + if run_type == 1: + self.tmp_model = Classifier(num_classes=10).cuda() + else: + self.tmp_model = WideResNet(num_classes=10).cuda() + self.wd = 0.02 * args.lr + + for param, ema_param in zip(self.model.parameters(), self.ema_model.parameters()): + ema_param.data.copy_(param.data) + + def step(self, bn=False): + if bn: + # copy batchnorm stats to ema model + for ema_param, tmp_param in zip(self.ema_model.parameters(), self.tmp_model.parameters()): + tmp_param.data.copy_(ema_param.data.detach()) + + self.ema_model.load_state_dict(self.model.state_dict()) + + for ema_param, tmp_param in zip(self.ema_model.parameters(), self.tmp_model.parameters()): + ema_param.data.copy_(tmp_param.data.detach()) + else: + one_minus_alpha = 1.0 - self.alpha + for param, ema_param in zip(self.model.parameters(), self.ema_model.parameters()): + ema_param.data.mul_(self.alpha) + ema_param.data.add_(param.data.detach() * one_minus_alpha) + # customized weight decay + param.data.mul_(1 - self.wd) + + +def interleave_offsets(batch, nu): + groups = [batch // (nu + 1)] * (nu + 1) + for x in range(batch - sum(groups)): + groups[-x - 1] += 1 + offsets = [0] + for g in groups: + offsets.append(offsets[-1] + g) + assert offsets[-1] == batch + return offsets + + +def interleave(xy, batch): + nu = len(xy) - 1 + offsets = interleave_offsets(batch, nu) + xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy] + for i in range(1, nu + 1): + xy[0][i], xy[i][i] = xy[i][i], xy[0][i] + return [torch.cat(v, dim=0) for v in xy] + + +if __name__ == '__main__': + main() \ No newline at end of file
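A note on the interleave()/interleave_offsets() helpers that close out train_myloader.py: they rotate equal-sized slices between the labeled batch and the two unlabeled batches so that every chunk pushed through the model, and therefore every BatchNorm statistics update, sees samples from all three sources rather than a single one. Below is a minimal standalone sketch of that behaviour; the two functions are copied from the patch, while the toy tensors are hypothetical and chosen only to make the swap visible.

import torch

def interleave_offsets(batch, nu):
    # split `batch` positions into nu + 1 nearly equal groups, returned as cumulative offsets
    groups = [batch // (nu + 1)] * (nu + 1)
    for x in range(batch - sum(groups)):
        groups[-x - 1] += 1
    offsets = [0]
    for g in groups:
        offsets.append(offsets[-1] + g)
    assert offsets[-1] == batch
    return offsets

def interleave(xy, batch):
    # swap the i-th slice of the first batch with the i-th slice of the i-th batch
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [torch.cat(v, dim=0) for v in xy]

# toy check: three batches of 6 samples each, filled with 0 (labeled), 1 and 2 (unlabeled)
batches = [torch.full((6,), float(i)) for i in range(3)]
mixed = interleave(batches, 6)
print(mixed[0])  # tensor([0., 0., 1., 1., 2., 2.]) -- the first chunk now mixes all three batches
print(mixed[1])  # tensor([1., 1., 0., 0., 1., 1.])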