diff --git a/cpc/dataset.py b/cpc/dataset.py
index 0251508..639491a 100644
--- a/cpc/dataset.py
+++ b/cpc/dataset.py
@@ -27,7 +27,9 @@ def __init__(self,
                  phoneLabelsDict,
                  nSpeakers,
                  nProcessLoader=50,
-                 MAX_SIZE_LOADED=4000000000):
+                 MAX_SIZE_LOADED=4000000000,
+                 keepSameSeedForDSshuffle=False,
+                 newTorchaudio=False):
         """
         Args:
             - path (string): path to the training dataset
@@ -45,6 +47,8 @@ def __init__(self,
             - MAX_SIZE_LOADED (int): target maximal size of the floating array
                                      containing all loaded data.
         """
+        self.keepSameSeedForDSshuffle = keepSameSeedForDSshuffle
+        self.newTorchaudio = newTorchaudio
         self.MAX_SIZE_LOADED = MAX_SIZE_LOADED
         self.nProcessLoader = nProcessLoader
         self.dbPath = Path(path)
@@ -91,15 +95,24 @@ def clear(self):
             del self.seqLabel
 
     def prepare(self):
-        randomstate = random.getstate()
-        random.seed(767543) # set seed only for batching so that it is random but always same for same dataset
-        # so that capturing captures data for same audio across runs if same dataset provided
+        if self.keepSameSeedForDSshuffle:
+            print("--> setting same seed for DS seqNames shuffling")
+            randomstate = random.getstate()
+            random.seed(767543) # set seed only for batching so that it is random but always same for same dataset
+            # so that capturing captures data for same audio across runs if same dataset provided
+        else:
+            print("--> using random seed for DS seqNames shuffling")
         random.shuffle(self.seqNames)
-        random.setstate(randomstate) # restore random state so that other stuff changes with seed in args
+        if self.keepSameSeedForDSshuffle:
+            random.setstate(randomstate) # restore random state so that other stuff changes with seed in args
         start_time = time.time()
 
         print("Checking length...")
-        allLength = self.reload_pool.map(extractLength, self.seqNames)
+        if self.newTorchaudio:
+            mapFun = extractLengthNewTorchaudio
+        else:
+            mapFun = extractLength
+        allLength = self.reload_pool.map(mapFun, self.seqNames)
 
         self.packageIndex, self.totSize = [], 0
         start, packageSize = 0, 0
@@ -423,11 +436,17 @@ def __iter__(self):
         return iter(self.batches)
 
 
-def extractLength(couple):
+def extractLength(couple):  # for old torchaudio
     speaker, locPath = couple
    info = torchaudio.info(str(locPath))[0]
     return info.length
 
+def extractLengthNewTorchaudio(couple):  # for newer torchaudio (0.8.1+), e.g. on Linux machines with CUDA >= 11
+    speaker, locPath = couple
+    # https://pytorch.org/audio/stable/backend.html#torchaudio.backend.common.AudioMetaData
+    info = torchaudio.info(str(locPath))
+    return info.num_frames * info.num_channels  # matches the old default 'sox' backend's info.length
+
 
 def findAllSeqs(dirName,
                 extension='.flac',
diff --git a/cpc/train.py b/cpc/train.py
index 3ac18a4..f2e57cc 100644
--- a/cpc/train.py
+++ b/cpc/train.py
@@ -316,6 +316,7 @@ def run(trainDataset,
     for epoch in range(startEpoch, nEpoch):
 
         print(f"Starting epoch {epoch}")
+        sys.stdout.flush()
         utils.cpu_stats()
 
         trainLoader = trainDataset.getDataLoader(batchSize, samplingMode,
@@ -505,7 +506,9 @@ def main(args):
                                   phoneLabels,
                                   len(speakers),
                                   nProcessLoader=args.n_process_loader,
-                                  MAX_SIZE_LOADED=args.max_size_loaded)
+                                  MAX_SIZE_LOADED=args.max_size_loaded,
+                                  keepSameSeedForDSshuffle=args.fixedDSshuffleSeed,
+                                  newTorchaudio=args.newTorchaudio)
     print("Training dataset loaded")
     print("")
 
@@ -515,7 +518,9 @@ def main(args):
                                     seqVal,
                                     phoneLabels,
                                     len(speakers),
-                                    nProcessLoader=args.n_process_loader)
+                                    nProcessLoader=args.n_process_loader,
+                                    keepSameSeedForDSshuffle=args.fixedDSshuffleSeed,
+                                    newTorchaudio=args.newTorchaudio)
         print("Validation dataset loaded")
         print("")
     else:
@@ -538,7 +543,9 @@ def main(args):
                                         seqCapture,
                                         phoneLabelsForCapture,
                                         len(speakers),
-                                        nProcessLoader=args.n_process_loader)
+                                        nProcessLoader=args.n_process_loader,
+                                        keepSameSeedForDSshuffle=True,
+                                        newTorchaudio=args.newTorchaudio)
         print("Capture dataset loaded")
         print("")
 
@@ -713,16 +720,20 @@ def constructSpeakerCriterionAndOptimizer():
 
             return speaker_criterion, speaker_optimizer
 
-        linsep_db_train = AudioBatchData(args.pathDB, args.sizeWindow, seqTrain,
-                                         phoneLabelsData, len(speakers))
-        linsep_db_val = AudioBatchData(args.pathDB, args.sizeWindow, seqVal,
-                                       phoneLabelsData, len(speakers))
+        # loading these datasets a second time kills RAM; reuse the ones already in memory
+        # linsep_db_train = AudioBatchData(args.pathDB, args.sizeWindow, seqTrain,
+        #                                  phoneLabelsData, len(speakers), keepSameSeedForDSshuffle=args.fixedDSshuffleSeed,
+        #                                  newTorchaudio=args.newTorchaudio)
+        # linsep_db_val = AudioBatchData(args.pathDB, args.sizeWindow, seqVal,
+        #                                phoneLabelsData, len(speakers), keepSameSeedForDSshuffle=args.fixedDSshuffleSeed,
+        #                                newTorchaudio=args.newTorchaudio)
 
-        linsep_train_loader = linsep_db_train.getDataLoader(linsep_batch_size, "uniform", True,
+        linsep_train_loader = trainDataset.getDataLoader(linsep_batch_size, "uniform", True,
                                                             numWorkers=0)
-
-        linsep_val_loader = linsep_db_val.getDataLoader(linsep_batch_size, 'sequential', False,
+        print("linsep_train_loader ready")
+        linsep_val_loader = valDataset.getDataLoader(linsep_batch_size, 'sequential', False,
                                                         numWorkers=0)
+        print("linsep_val_loader ready")
 
     def runLinsepClassificationTraining(numOfEpoch, cpcMdl, cpcStateEpoch):
         log_path_for_epoch = os.path.join(args.linsep_logs_dir, str(numOfEpoch))
@@ -841,6 +852,11 @@ def parseArgs(argv):
     group_db.add_argument('--pathVal', type=str, default=None,
                           help='Path to a .txt file containing the list of the '
                           'validation sequences.')
+    group_db.add_argument('--fixedDSshuffleSeed', action='store_true',
+                          help="if set, always shuffle the train & val datasets the same way (with a fixed seed); "
+                          "otherwise the shuffle consumes the run's global random seed like everything else")
+    group_db.add_argument('--newTorchaudio', action='store_true',
+                          help="if set, use the newer metadata API of torchaudio 0.8.1+ when reading sequence lengths")
     # stuff below for capturing data
     group_db.add_argument('--onlyCapture', action='store_true',
                           help='Only capture data from learned model for one epoch, ignore training; '
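
A minimal standalone sketch of the torchaudio API difference that the two extractLength variants handle: in torchaudio < 0.8 the default sox backend's torchaudio.info(path) returns a (signal_info, encoding_info) tuple whose length field already counts samples across all channels, while from 0.8.1 on it returns an AudioMetaData object, so the same quantity is rebuilt as num_frames * num_channels. The file path below is hypothetical.

    import torchaudio

    path = "example.flac"  # hypothetical audio file

    info = torchaudio.info(path)
    if hasattr(info, "num_frames"):
        # torchaudio >= 0.8: AudioMetaData object
        length = info.num_frames * info.num_channels
    else:
        # older torchaudio: (signal_info, encoding_info) tuple;
        # signal_info.length already counts samples across all channels
        length = info[0].length

    print(length)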