From d02ea08b046c5470b5eb179ec8facecbe165ad45 Mon Sep 17 00:00:00 2001 From: tuanh208 Date: Sun, 8 Nov 2020 19:41:30 +0100 Subject: [PATCH 1/3] zerospeech: update clustering_script.py, clustering_quantization.py, feature_loader.py --- .../clustering/clustering_quantization.py | 209 ++++++++++-------- cpc/criterion/clustering/clustering_script.py | 10 +- cpc/feature_loader.py | 98 +++++++- 3 files changed, 215 insertions(+), 102 deletions(-) diff --git a/cpc/criterion/clustering/clustering_quantization.py b/cpc/criterion/clustering/clustering_quantization.py index f26cc57..91b0b5f 100644 --- a/cpc/criterion/clustering/clustering_quantization.py +++ b/cpc/criterion/clustering/clustering_quantization.py @@ -9,7 +9,7 @@ import torch from cpc.dataset import findAllSeqs from cpc.feature_loader import buildFeature, FeatureModule, loadModel, buildFeature_batch -from cpc.criterion.research.clustering import kMeanCluster +from cpc.criterion.clustering import kMeanCluster def readArgs(pathArgs): print(f"Loading args from {pathArgs}") @@ -18,31 +18,57 @@ def readArgs(pathArgs): return args +def writeArgs(pathArgs, args): + with open(pathArgs, 'w') as file: + json.dump(vars(args), file, indent=2) + def loadClusterModule(pathCheckpoint): print(f"Loading ClusterModule at {pathCheckpoint}") - state_dict = torch.load(pathCheckpoint) - if "state_dict" in state_dict: #kmeans - clusterModule = kMeanCluster(torch.zeros(1, state_dict["n_clusters"], state_dict["dim"])) - clusterModule.load_state_dict(state_dict["state_dict"]) - else: #dpmeans - clusterModule = kMeanCluster(state_dict["mu"]) - clusterModule = clusterModule.cuda() + state_dict = torch.load(pathCheckpoint, map_location=torch.device('cpu')) + clusterModule = kMeanCluster(torch.zeros(1, state_dict["n_clusters"], state_dict["dim"])) + clusterModule.load_state_dict(state_dict["state_dict"]) return clusterModule +def quantize_file(file_path, cpc_feature_function, clusterModule): + # Get CPC features + cFeatures = cpc_feature_function(file_path) + if clusterModule.Ck.is_cuda: + cFeatures = cFeatures.cuda() + + nGroups = cFeatures.size(-1)//clusterModule.Ck.size(-1) # groups information + + # Quantize the output of clustering on the CPC features + cFeatures = cFeatures.view(1, -1, clusterModule.Ck.size(-1)) + if cFeatures.size(1) > 50000: # Librilight, to avoid GPU OOM, decrease when still OOM + clusterModule = clusterModule.cpu() + cFeatures = cFeatures.cpu() + qFeatures = torch.argmin(clusterModule(cFeatures), dim=-1) + if not args.cpu: + clusterModule = clusterModule.cuda() + else: + qFeatures = torch.argmin(clusterModule(cFeatures), dim=-1) + qFeatures = qFeatures[0].detach().cpu().numpy() + + # Transform to quantized line + quantLine = ",".join(["-".join([str(i) for i in item]) for item in qFeatures.reshape(-1, nGroups)]) + + return quantLine + def parseArgs(argv): # Run parameters parser = argparse.ArgumentParser(description='Quantize audio files using CPC Clustering Module.') - parser.add_argument('pathCheckpoint', type=str, + parser.add_argument('pathClusteringCheckpoint', type=str, help='Path to the clustering checkpoint.') parser.add_argument('pathDB', type=str, help='Path to the dataset that we want to quantize.') - parser.add_argument('pathOutput', type=str, + parser.add_argument('pathOutputDir', type=str, help='Path to the output directory.') - parser.add_argument('--pathSeq', type=str, - help='Path to the sequences (file names) to be included used.') + parser.add_argument('--pathSeq', type=str, + help='Path to the sequences (file names) to 
be included '
+                        '(if not specified, include all files found in pathDB).')
     parser.add_argument('--split', type=str, default=None,
-                        help="If you want to divide the dataset in small splits, specify it "
-                        "with idxSplit-numSplits (idxSplit > 0), eg. --split 1-20.")
+                        help='If you want to divide the dataset in small splits, specify it '
+                        'with idxSplit-numSplits (idxSplit > 0), e.g. --split 1-20.')
     parser.add_argument('--file_extension', type=str, default=".flac",
                         help="Extension of the audio files in the dataset (default: .flac).")
     parser.add_argument('--max_size_seq', type=int, default=10240,
@@ -62,11 +88,10 @@ def parseArgs(argv):
                         "NOTE: This can have better quantized units as we can set "
                         "model.gAR.keepHidden = True (line 162), but the quantization"
                         "will be a bit longer.")
-    parser.add_argument('--recursionLevel', type=int, default=1,
-                        help='Speaker level in pathDB (defaut: 1). This is only helpful'
-                        'when --separate-speaker is activated.')
-    parser.add_argument('--separate-speaker', action='store_true',
-                        help="Separate each speaker with a different output file.")
+    parser.add_argument('--cpu', action='store_true',
+                        help="Run on a cpu machine.")
+    parser.add_argument('--resume', action='store_true',
+                        help="Continue to quantize if an output file already exists.")
     return parser.parse_args(argv)
 
 def main(argv):
@@ -77,12 +102,6 @@ def main(argv):
     print(f"Quantizing data from {args.pathDB}")
     print("=============================================================")
 
-    # Check if directory exists
-    if not os.path.exists(args.pathOutput):
-        print("")
-        print(f"Creating the output directory at {args.pathOutput}")
-        Path(args.pathOutput).mkdir(parents=True, exist_ok=True)
-
     # Get splits
     if args.split:
         assert len(args.split.split("-"))==2 and int(args.split.split("-")[1]) >= int(args.split.split("-")[0]) >= 1, \
@@ -93,40 +112,45 @@ def main(argv):
 
     # Find all sequences
     print("")
-    print(f"Looking for all {args.file_extension} files in {args.pathDB} with speakerLevel {args.recursionLevel}")
-    seqNames, speakers = findAllSeqs(args.pathDB,
-                                     speaker_level=args.recursionLevel,
+    print(f"Looking for all {args.file_extension} files in {args.pathDB}")
+    seqNames, _ = findAllSeqs(args.pathDB,
+                              speaker_level=1,
                               extension=args.file_extension,
                               loadCache=True)
+    if len(seqNames) == 0 or not os.path.splitext(seqNames[0][1])[1].endswith(args.file_extension):
+        print(f"Seems like the _seq_cache.txt does not contain the correct extension, reloading the file list")
+        seqNames, _ = findAllSeqs(args.pathDB,
+                                  speaker_level=1,
+                                  extension=args.file_extension,
+                                  loadCache=False)
+    print(f"Done! Found {len(seqNames)} files!")
 
+    # Filter specific sequences
     if args.pathSeq:
-        with open(args.pathSeq, 'r') as f:
-            seqs = set([x.strip() for x in f])
-
-        filtered = []
-        for s in seqNames:
-            if s[1].split('/')[-1].split('.')[0] in seqs:
-                filtered.append(s)
+        print("")
+        print(f"Filtering seqs in {args.pathSeq}")
+        with open(args.pathSeq, 'r') as f:
+            seqs = set([x.strip() for x in f])
+        filtered = []
+        for s in seqNames:
+            if os.path.splitext(s[1].split('/')[-1])[0] in seqs:
+                filtered.append(s)
         seqNames = filtered
+        print(f"Done! {len(seqNames)} files filtered!")
 
-    print(f"Done! 
Found {len(seqNames)} files and {len(speakers)} speakers!") - if args.separate_speaker: - seqNames_by_speaker = {} - for seq in seqNames: - speaker = seq[1].split("/")[args.recursionLevel-1] - if speaker not in seqNames_by_speaker: - seqNames_by_speaker[speaker] = [] - seqNames_by_speaker[speaker].append(seq) + # Check if directory exists + if not os.path.exists(args.pathOutputDir): + print("") + print(f"Creating the output directory at {args.pathOutputDir}") + Path(args.pathOutputDir).mkdir(parents=True, exist_ok=True) + writeArgs(os.path.join(args.pathOutputDir, "_info_args.json"), args) # Check if output file exists if not args.split: nameOutput = "quantized_outputs.txt" else: nameOutput = f"quantized_outputs_split_{idx_split}-{num_splits}.txt" - if args.separate_speaker is False: - outputFile = os.path.join(args.pathOutput, nameOutput) - assert not os.path.exists(outputFile), \ - f"Output file {outputFile} already exists !!!" + outputFile = os.path.join(args.pathOutputDir, nameOutput) # Get splits if args.split: @@ -147,27 +171,50 @@ def main(argv): # shuffle(seqNames) seqNames = seqNames[:nsamples] + # Continue + addEndLine = False # to add end line (\n) to first line or not + if args.resume: + if os.path.exists(outputFile): + with open(outputFile, 'r') as f: + lines = [line for line in f] + existing_files = set([x.split()[0] for x in lines if x.split()]) + seqNames = [s for s in seqNames if os.path.splitext(s[1].split('/')[-1])[0] not in existing_files] + print(f"Found existing output file, continue to quantize {len(seqNames)} audio files left!") + if len(lines) > 0 and not lines[-1].endswith("\n"): + addEndLine = True + else: + assert not os.path.exists(outputFile), \ + f"Output file {outputFile} already exists !!! If you want to continue quantizing audio files, please check the --resume option." + + assert len(seqNames) > 0, \ + "No file to be quantized!" 
+ # Load Clustering args - assert args.pathCheckpoint[-3:] == ".pt" - if os.path.exists(args.pathCheckpoint[:-3] + "_args.json"): - pathConfig = args.pathCheckpoint[:-3] + "_args.json" - elif os.path.exists(os.path.join(os.path.dirname(args.pathCheckpoint), "checkpoint_args.json")): - pathConfig = os.path.join(os.path.dirname(args.pathCheckpoint), "checkpoint_args.json") + assert args.pathClusteringCheckpoint[-3:] == ".pt" + if os.path.exists(args.pathClusteringCheckpoint[:-3] + "_args.json"): + pathConfig = args.pathClusteringCheckpoint[:-3] + "_args.json" + elif os.path.exists(os.path.join(os.path.dirname(args.pathClusteringCheckpoint), "checkpoint_args.json")): + pathConfig = os.path.join(os.path.dirname(args.pathClusteringCheckpoint), "checkpoint_args.json") else: assert False, \ - f"Args file not found in the directory {os.path.dirname(args.pathCheckpoint)}" + f"Args file not found in the directory {os.path.dirname(args.pathClusteringCheckpoint)}" clustering_args = readArgs(pathConfig) print("") print(f"Clutering args:\n{json.dumps(vars(clustering_args), indent=4, sort_keys=True)}") print('-' * 50) # Load CluterModule - clusterModule = loadClusterModule(args.pathCheckpoint) - clusterModule.cuda() + clusterModule = loadClusterModule(args.pathClusteringCheckpoint) + if not args.cpu: + clusterModule.cuda() # Load FeatureMaker print("") print("Loading CPC FeatureMaker") + if not os.path.isabs(clustering_args.pathCheckpoint): # Maybe it's relative path + clustering_args.pathCheckpoint = os.path.join(os.path.dirname(os.path.abspath(args.pathClusteringCheckpoint)), clustering_args.pathCheckpoint) + assert os.path.exists(clustering_args.pathCheckpoint), \ + f"CPC path at {clustering_args.pathCheckpoint} does not exist!!" if 'level_gru' in vars(clustering_args) and clustering_args.level_gru is not None: updateConfig = argparse.Namespace(nLevelsGRU=clustering_args.level_gru) else: @@ -183,8 +230,9 @@ def main(argv): featureMaker = torch.nn.Sequential(featureMaker, dimRed) if not clustering_args.train_mode: featureMaker.eval() - featureMaker.cuda() - def feature_function(x): + if not args.cpu: + featureMaker.cuda() + def cpc_feature_function(x): if args.nobatch is False: return buildFeature_batch(featureMaker, x, seqNorm=False, @@ -196,11 +244,11 @@ def feature_function(x): seqNorm=False, strict=args.strict) print("CPC FeatureMaker loaded!") - + # Quantization of files print("") - print(f"Quantizing audio files...") - seqQuantLines = [] + print(f"Quantizing audio files and saving outputs to {outputFile}...") + f = open(outputFile, "a") bar = progressbar.ProgressBar(maxval=len(seqNames)) bar.start() start_time = time() @@ -210,39 +258,20 @@ def feature_function(x): file_path = vals[1] file_path = os.path.join(args.pathDB, file_path) - # Get features & quantizing - cFeatures = feature_function(file_path).cuda() - - nGroups = cFeatures.size(-1)//clusterModule.Ck.size(-1) - - cFeatures = cFeatures.view(1, -1, clusterModule.Ck.size(-1)) + # Quantizing + quantLine = quantize_file(file_path, cpc_feature_function, clusterModule) - if len(vals) > 2 and int(vals[-1]) > 9400000: # Librilight, to avoid OOM - clusterModule = clusterModule.cpu() - cFeatures = cFeatures.cpu() - qFeatures = torch.argmin(clusterModule(cFeatures), dim=-1) - clusterModule = clusterModule.cuda() + # Save the outputs + file_name = os.path.splitext(os.path.basename(file_path))[0] + outLine = "\t".join([file_name, quantLine]) + if addEndLine: + f.write("\n"+outLine) else: - qFeatures = torch.argmin(clusterModule(cFeatures), dim=-1) - 
qFeatures = qFeatures[0].detach().cpu().numpy()
-
-        # Transform to quantized line
-        quantLine = ",".join(["-".join([str(i) for i in item]) for item in qFeatures.reshape(-1, nGroups)])
-        seqQuantLines.append(quantLine)
-
+            f.write(outLine)
+            addEndLine = True
     bar.finish()
-    print(f"...done {len(seqQuantLines)} files in {time()-start_time} seconds.")
-
-    # Saving outputs
-    print("")
-    print(f"Saving outputs to {outputFile}")
-    outLines = []
-    for vals, quantln in zip(seqNames, seqQuantLines):
-        file_path = vals[1]
-        file_name = os.path.splitext(os.path.basename(file_path))[0]
-        outLines.append("\t".join([file_name, quantln]))
-    with open(outputFile, "w") as f:
-        f.write("\n".join(outLines))
+    print(f"...done {len(seqNames)} files in {time()-start_time} seconds.")
+    f.close()
 
 if __name__ == "__main__":
     args = sys.argv[1:]
diff --git a/cpc/criterion/clustering/clustering_script.py b/cpc/criterion/clustering/clustering_script.py
index 5a4eca0..8a4c42c 100644
--- a/cpc/criterion/clustering/clustering_script.py
+++ b/cpc/criterion/clustering/clustering_script.py
@@ -128,12 +128,12 @@ def parseArgs(argv):
     print(f"Length of dataLoader: {len(trainLoader)}")
     print("")
 
-    #if args.level_gru is None:
-    #    updateConfig = None
-    #else:
-    #    updateConfig = argparse.Namespace(nLevelsGRU=args.level_gru)
+    if args.level_gru is None:
+        updateConfig = None
+    else:
+        updateConfig = argparse.Namespace(nLevelsGRU=args.level_gru)
 
-    model = loadModel([args.pathCheckpoint])[0]#, updateConfig=updateConfig)[0]
+    model = loadModel([args.pathCheckpoint])[0], updateConfig=updateConfig)[0]
     featureMaker = FeatureModule(model, args.encoder_layer)
     print("Checkpoint loaded!")
     print("")
diff --git a/cpc/feature_loader.py b/cpc/feature_loader.py
index 1e9c8b3..cc1a6da 100644
--- a/cpc/feature_loader.py
+++ b/cpc/feature_loader.py
@@ -30,7 +30,10 @@ def getDownsamplingFactor(self):
 
     def forward(self, data):
         batchAudio, label = data
-        cFeature, encoded, _ = self.featureMaker(batchAudio.cuda(), label)
+        if next(self.featureMaker.parameters()).is_cuda:
+            cFeature, encoded, _ = self.featureMaker(batchAudio.cuda(), label)
+        else:
+            cFeature, encoded, _ = self.featureMaker(batchAudio, label)
         if self.get_encoded:
             cFeature = encoded
         if self.collapse:
@@ -108,8 +111,12 @@ def getCheckpointData(pathDir):
         return None
     checkpoints.sort(key=lambda x: int(os.path.splitext(x[11:])[0]))
     data = os.path.join(pathDir, checkpoints[-1])
-    with open(os.path.join(pathDir, 'checkpoint_logs.json'), 'rb') as file:
-        logs = json.load(file)
+
+    if os.path.exists(os.path.join(pathDir, 'checkpoint_logs.json')):
+        with open(os.path.join(pathDir, 'checkpoint_logs.json'), 'rb') as file:
+            logs = json.load(file)
+    else:
+        logs = None
 
     with open(os.path.join(pathDir, 'checkpoint_args.json'), 'rb') as file:
         args = json.load(file)
@@ -153,7 +160,7 @@ def getAR(args):
     return arNet
 
 
-def loadModel(pathCheckpoints, loadStateDict=True):
+def loadModel(pathCheckpoints, loadStateDict=True, updateConfig=None):
     models = []
     hiddenGar, hiddenEncoder = 0, 0
     for path in pathCheckpoints:
@@ -164,8 +171,15 @@ def loadModel(pathCheckpoints, loadStateDict=True):
             (len(locArgs.load) > 1 or
              os.path.dirname(locArgs.load[0]) != os.path.dirname(path))
 
+        if updateConfig is not None and not doLoad:
+            print(f"Updating the configuration file with ")
+            print(f'{json.dumps(vars(updateConfig), indent=4, sort_keys=True)}')
+            loadArgs(locArgs, updateConfig)
+
         if doLoad:
-            m_, hg, he = loadModel(locArgs.load, loadStateDict=False)
+            m_, hg, he = loadModel(locArgs.load,
+                                   loadStateDict=False,
+                                   
updateConfig=updateConfig) hiddenGar += hg hiddenEncoder += he else: @@ -240,6 +254,10 @@ def buildFeature(featureMaker, seqPath, strict=False, Return: a torch vector of size 1 x Seq_size x Feature_dim """ + if next(featureMaker.parameters()).is_cuda: + device = 'cuda' + else: + device = 'cpu' seq = torchaudio.load(seqPath)[0] sizeSeq = seq.size(1) start = 0 @@ -248,7 +266,7 @@ def buildFeature(featureMaker, seqPath, strict=False, if strict and start + maxSizeSeq > sizeSeq: break end = min(sizeSeq, start + maxSizeSeq) - subseq = (seq[:, start:end]).view(1, 1, -1).cuda(device=0) + subseq = (seq[:, start:end]).view(1, 1, -1).to(device) with torch.no_grad(): features = featureMaker((subseq, None)) if seqNorm: @@ -257,7 +275,7 @@ def buildFeature(featureMaker, seqPath, strict=False, start += maxSizeSeq if strict and start < sizeSeq: - subseq = (seq[:, -maxSizeSeq:]).view(1, 1, -1).cuda(device=0) + subseq = (seq[:, -maxSizeSeq:]).view(1, 1, -1).to(device) with torch.no_grad(): features = featureMaker((subseq, None)) if seqNorm: @@ -267,3 +285,69 @@ def buildFeature(featureMaker, seqPath, strict=False, out = torch.cat(out, dim=1) return out + + +def buildFeature_batch(featureMaker, seqPath, strict=False, + maxSizeSeq=8000, seqNorm=False, batch_size=8): + r""" + Apply the featureMaker to the given file. Apply batch-computation + Arguments: + - featureMaker (FeatureModule): model to apply + - seqPath (string): path of the sequence to load + - strict (bool): if True, always work with chunks of the size + maxSizeSeq + - maxSizeSeq (int): maximal size of a chunk + - seqNorm (bool): if True, normalize the output along the time + dimension to get chunks of mean zero and var 1 + Return: + a torch vector of size 1 x Seq_size x Feature_dim + """ + if next(featureMaker.parameters()).is_cuda: + device = 'cuda' + else: + device = 'cpu' + seq = torchaudio.load(seqPath)[0] + sizeSeq = seq.size(1) + + # Compute number of batches + n_chunks = sizeSeq//maxSizeSeq + n_batches = n_chunks//batch_size + if n_chunks % batch_size != 0: + n_batches += 1 + + out = [] + # Treat each batch + for batch_idx in range(n_batches): + start = batch_idx*batch_size*maxSizeSeq + end = min((batch_idx+1)*batch_size*maxSizeSeq, maxSizeSeq*n_chunks) + batch_seqs = (seq[:, start:end]).view(-1, 1, maxSizeSeq).to(device) + with torch.no_grad(): + # breakpoint() + batch_out = featureMaker((batch_seqs, None)) + for features in batch_out: + features = features.unsqueeze(0) + if seqNorm: + features = seqNormalization(features) + out.append(features.detach().cpu()) + + # Remaining frames + if sizeSeq % maxSizeSeq >= featureMaker.getDownsamplingFactor(): + remainders = sizeSeq % maxSizeSeq + if strict: + subseq = (seq[:, -maxSizeSeq:]).view(1, 1, -1).to(device) + with torch.no_grad(): + features = featureMaker((subseq, None)) + if seqNorm: + features = seqNormalization(features) + delta = remainders // featureMaker.getDownsamplingFactor() + out.append(features[:, -delta:].detach().cpu()) + else: + subseq = (seq[:, -remainders:]).view(1, 1, -1).to(device) + with torch.no_grad(): + features = featureMaker((subseq, None)) + if seqNorm: + features = seqNormalization(features) + out.append(features.detach().cpu()) + + out = torch.cat(out, dim=1) + return out \ No newline at end of file From db7cff44c9fc18eddeafca40ad34e04b51936a8a Mon Sep 17 00:00:00 2001 From: tuanh208 Date: Fri, 12 Mar 2021 17:17:33 +0100 Subject: [PATCH 2/3] modify train.py and clustering_script.py --- cpc/criterion/clustering/clustering_script.py | 2 +- cpc/train.py | 7 ++++++- 2 
files changed, 7 insertions(+), 2 deletions(-) diff --git a/cpc/criterion/clustering/clustering_script.py b/cpc/criterion/clustering/clustering_script.py index 8a4c42c..6f0dceb 100644 --- a/cpc/criterion/clustering/clustering_script.py +++ b/cpc/criterion/clustering/clustering_script.py @@ -133,7 +133,7 @@ def parseArgs(argv): else: updateConfig = argparse.Namespace(nLevelsGRU=args.level_gru) - model = loadModel([args.pathCheckpoint])[0], updateConfig=updateConfig)[0] + model = loadModel([args.pathCheckpoint][0], updateConfig=updateConfig)[0] featureMaker = FeatureModule(model, args.encoder_layer) print("Checkpoint loaded!") print("") diff --git a/cpc/train.py b/cpc/train.py index 46fa30c..4190c15 100644 --- a/cpc/train.py +++ b/cpc/train.py @@ -248,7 +248,8 @@ def main(args): seqNames, speakers = findAllSeqs(args.pathDB, extension=args.file_extension, - loadCache=not args.ignore_cache) + loadCache=not args.ignore_cache, + speaker_level=args.speaker_level) print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers') # Datasets @@ -347,6 +348,8 @@ def main(args): if not os.path.isdir(args.pathCheckpoint): os.mkdir(args.pathCheckpoint) args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint") + with open(args.pathCheckpoint + "_args.json", 'w') as file: + json.dump(vars(args), file, indent=2) scheduler = None if args.schedulerStep > 0: @@ -415,6 +418,8 @@ def parseArgs(argv): group_db.add_argument('--max_size_loaded', type=int, default=4000000000, help='Maximal amount of data (in byte) a dataset ' 'can hold in memory at any given time') + group_db.add_argument('--speaker_level', type=int, default=1, + help="Level of speaker in the training directory.") group_supervised = parser.add_argument_group( 'Supervised mode (depreciated)') group_supervised.add_argument('--supervised', action='store_true', From 09307a1f2fb21cc462acf3f9f298510219424c96 Mon Sep 17 00:00:00 2001 From: tuanh208 Date: Sat, 27 Mar 2021 02:15:59 +0100 Subject: [PATCH 3/3] add Early stopping --- cpc/train.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/cpc/train.py b/cpc/train.py index 4190c15..7ca12d6 100644 --- a/cpc/train.py +++ b/cpc/train.py @@ -165,7 +165,9 @@ def run(trainDataset, pathCheckpoint, optimizer, scheduler, - logs): + logs, + earlyStopping=False, + patience=5): print(f"Running {nEpoch} epochs") startEpoch = len(logs["epoch"]) @@ -173,6 +175,17 @@ def run(trainDataset, bestStateDict = None start_time = time.time() + def count_inverse(ls): + return sum([1 for i in range(len(ls)-1) if ls[i] > ls[i+1]]) + + if "locAcc_val" in logs and len(logs["locAcc_val"]) > 0: + valAccuracyList = [100*np.mean(ls) for ls in logs['locAcc_val']] + if count_inverse(valAccuracyList) > patience: + print(f"The patience={patience} has been reached, early stopping activated. 
Stopped!") + return + else: + valAccuracyList = [] + for epoch in range(startEpoch, nEpoch): print(f"Starting epoch {epoch}") @@ -198,6 +211,7 @@ def run(trainDataset, torch.cuda.empty_cache() currentAccuracy = float(locLogsVal["locAcc_val"].mean()) + valAccuracyList.append(100*currentAccuracy) if currentAccuracy > bestAcc: bestStateDict = fl.get_module(cpcModel).state_dict() @@ -221,6 +235,18 @@ def run(trainDataset, f"{pathCheckpoint}_{epoch}.pt") utils.save_logs(logs, pathCheckpoint + "_logs.json") + with open(os.path.join(os.path.dirname(pathCheckpoint), "valAcc_info.txt"), 'w') as file: + outLines = [f"Epoch {ep} : {acc}" for ep, acc in enumerate(valAccuracyList)] + [f"Best valAcc : checkpoint_{np.argmax(valAccuracyList)}"] + file.write("\n".join(outLines)) + + if earlyStopping: + if count_inverse(valAccuracyList) > patience: + print(f"Early stopping activated. Stopped at epoch {epoch}!") + with open(os.path.join(os.path.dirname(pathCheckpoint), "earlyStopping.txt"), 'a') as file: + file.write(f"The patience={patience} has been reached, early stopping activated. Stopped at epoch {epoch}!\n") + break + + def main(args): args = parseArgs(args) @@ -387,7 +413,9 @@ def main(args): args.pathCheckpoint, optimizer, scheduler, - logs) + logs, + args.early_stopping, + args.patience) def parseArgs(argv): @@ -440,6 +468,8 @@ def parseArgs(argv): group_save.add_argument('--save_step', type=int, default=5, help="Frequency (in epochs) at which a checkpoint " "should be saved") + group_save.add_argument('--early_stopping', action='store_true') + group_save.add_argument('--patience', type=int, default=5) group_load = parser.add_argument_group('Load') group_load.add_argument('--load', type=str, default=None, nargs='*',