From f0553d41338b5481ec4325e18bac53929b0081a6 Mon Sep 17 00:00:00 2001 From: danielsf Date: Tue, 1 Feb 2022 10:11:55 -0800 Subject: [PATCH 01/10] allow users to pass an actual file path to MovieJSONGenerator This will protect us against our naming conventions for motion-corrected movies changing. It will also allow non-Allen users to use the MovieJSONGenerator, which is very useful, since it gives the user granular control over which frames are used in training and validation. --- deepinterpolation/generator_collection.py | 27 +++++++++++++++-------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/deepinterpolation/generator_collection.py b/deepinterpolation/generator_collection.py index 9deedb45..a8842669 100755 --- a/deepinterpolation/generator_collection.py +++ b/deepinterpolation/generator_collection.py @@ -1109,15 +1109,24 @@ def __data_generation__(self, index_frame): # Initialization local_path = self.frame_data_location[local_lims]["path"] - _filenames = ["motion_corrected_video.h5", "concat_31Hz_0.h5"] - motion_path = [] - for _filename in _filenames: - _filepath = os.path.join(local_path, "processed", _filename) - if os.path.exists(_filepath) and not os.path.islink( - _filepath - ): # Path exists and is not symbolic - motion_path = _filepath - break + motion_path = None + if os.path.isfile(local_path): + motion_path = local_path + else: + _filenames = ["motion_corrected_video.h5", "concat_31Hz_0.h5"] + motion_path = [] + for _filename in _filenames: + _filepath = os.path.join(local_path, "processed", _filename) + if os.path.exists(_filepath) and not os.path.islink( + _filepath + ): # Path exists and is not symbolic + motion_path = _filepath + break + + if motion_path is None: + msg = f"unable to find valid movie file for path\n" + msg += f"{local_path}" + raise RuntimeError(msg) movie_obj = h5py.File(motion_path, "r") From dcae28aa16ea7fed48a6588a5b00e1554bcc2f5e Mon Sep 17 00:00:00 2001 From: danielsf Date: Tue, 1 Feb 2022 10:17:03 -0800 
Subject: [PATCH 02/10] pass error msg from data generation try/except along as warning This will give the user the option to see what happened that prevented the generator from passing along all of their specified data. This only affects the MovieJSONGenerator. --- deepinterpolation/generator_collection.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/deepinterpolation/generator_collection.py b/deepinterpolation/generator_collection.py index a8842669..bf0a6fc2 100755 --- a/deepinterpolation/generator_collection.py +++ b/deepinterpolation/generator_collection.py @@ -7,6 +7,7 @@ import nibabel as nib import s3fs import glob +import warnings from deepinterpolation.generic import JsonLoader @@ -1171,5 +1172,8 @@ def __data_generation__(self, index_frame): movie_obj.close() return input_full, output_full - except Exception: - print("Issues with " + str(self.lims_id)) + except Exception as err: + msg = f"Issues with {local_lims}\n" + msg += f"Error: {str(err)}\n" + msg += "moving on\n" + warnings.warn(msg) From f941a9c97b022ea3296812565f4e1542e9f3bf0f Mon Sep 17 00:00:00 2001 From: danielsf Date: Tue, 1 Feb 2022 11:31:25 -0800 Subject: [PATCH 03/10] remove unused motion_path = [] --- deepinterpolation/generator_collection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepinterpolation/generator_collection.py b/deepinterpolation/generator_collection.py index bf0a6fc2..012e30e6 100755 --- a/deepinterpolation/generator_collection.py +++ b/deepinterpolation/generator_collection.py @@ -1115,7 +1115,6 @@ def __data_generation__(self, index_frame): motion_path = local_path else: _filenames = ["motion_corrected_video.h5", "concat_31Hz_0.h5"] - motion_path = [] for _filename in _filenames: _filepath = os.path.join(local_path, "processed", _filename) if os.path.exists(_filepath) and not os.path.islink( From 28cd6db10e3aad7021fd8e05972d2b83d2949eac Mon Sep 17 00:00:00 2001 From: danielsf Date: Tue, 1 Feb 2022 11:33:14 -0800 Subject: [PATCH 04/10] use 
context manager to open and close movies --- deepinterpolation/generator_collection.py | 81 +++++++++++------------ 1 file changed, 40 insertions(+), 41 deletions(-) diff --git a/deepinterpolation/generator_collection.py b/deepinterpolation/generator_collection.py index 012e30e6..bea4fd0b 100755 --- a/deepinterpolation/generator_collection.py +++ b/deepinterpolation/generator_collection.py @@ -1128,47 +1128,46 @@ def __data_generation__(self, index_frame): msg += f"{local_path}" raise RuntimeError(msg) - movie_obj = h5py.File(motion_path, "r") - - local_frame_data = self.frame_data_location[local_lims] - output_frame = local_frame_data["frames"][local_img] - local_mean = local_frame_data["mean"] - local_std = local_frame_data["std"] - - input_full = np.zeros( - [1, 512, 512, self.pre_frame + self.post_frame]) - output_full = np.zeros([1, 512, 512, 1]) - - input_index = np.arange( - output_frame - self.pre_frame - self.pre_post_omission, - output_frame + self.post_frame + self.pre_post_omission + 1, - ) - input_index = input_index[input_index != output_frame] - - for index_padding in np.arange(self.pre_post_omission + 1): - input_index = input_index[input_index != - output_frame - index_padding] - input_index = input_index[input_index != - output_frame + index_padding] - - data_img_input = movie_obj["data"][input_index, :, :] - data_img_output = movie_obj["data"][output_frame, :, :] - - data_img_input = np.swapaxes(data_img_input, 1, 2) - data_img_input = np.swapaxes(data_img_input, 0, 2) - - img_in_shape = data_img_input.shape - img_out_shape = data_img_output.shape - - data_img_input = (data_img_input.astype( - "float") - local_mean) / local_std - data_img_output = (data_img_output.astype( - "float") - local_mean) / local_std - input_full[0, : img_in_shape[0], - : img_in_shape[1], :] = data_img_input - output_full[0, : img_out_shape[0], - : img_out_shape[1], 0] = data_img_output - movie_obj.close() + with h5py.File(motion_path, "r") as movie_obj: + + 
local_frame_data = self.frame_data_location[local_lims] + output_frame = local_frame_data["frames"][local_img] + local_mean = local_frame_data["mean"] + local_std = local_frame_data["std"] + + input_full = np.zeros( + [1, 512, 512, self.pre_frame + self.post_frame]) + output_full = np.zeros([1, 512, 512, 1]) + + input_index = np.arange( + output_frame - self.pre_frame - self.pre_post_omission, + output_frame + self.post_frame + self.pre_post_omission + 1, + ) + input_index = input_index[input_index != output_frame] + + for index_padding in np.arange(self.pre_post_omission + 1): + input_index = input_index[input_index != + output_frame - index_padding] + input_index = input_index[input_index != + output_frame + index_padding] + + data_img_input = movie_obj["data"][input_index, :, :] + data_img_output = movie_obj["data"][output_frame, :, :] + + data_img_input = np.swapaxes(data_img_input, 1, 2) + data_img_input = np.swapaxes(data_img_input, 0, 2) + + img_in_shape = data_img_input.shape + img_out_shape = data_img_output.shape + + data_img_input = (data_img_input.astype( + "float") - local_mean) / local_std + data_img_output = (data_img_output.astype( + "float") - local_mean) / local_std + input_full[0, : img_in_shape[0], + : img_in_shape[1], :] = data_img_input + output_full[0, : img_out_shape[0], + : img_out_shape[1], 0] = data_img_output return input_full, output_full except Exception as err: From 02e71798b6017cb845f6e96133638ae8a29b5390 Mon Sep 17 00:00:00 2001 From: danielsf Date: Tue, 1 Feb 2022 11:36:38 -0800 Subject: [PATCH 05/10] sort the lims_id list in MovieJSONGenerator so that the order over which the movies are iterated does not depend on the way in which json deserializes the input parameter file --- deepinterpolation/generator_collection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepinterpolation/generator_collection.py b/deepinterpolation/generator_collection.py index bea4fd0b..6fc55dee 100755 --- a/deepinterpolation/generator_collection.py 
+++ b/deepinterpolation/generator_collection.py @@ -1039,6 +1039,7 @@ def __init__(self, json_path): self.frame_data_location = json.load(json_handle) self.lims_id = list(self.frame_data_location.keys()) + self.lims_id.sort() self.nb_lims = len(self.lims_id) self.img_per_movie = len( self.frame_data_location[self.lims_id[0]]["frames"]) From 6c24a4c37c364d690bcf3bf9726390fbb2bb893b Mon Sep 17 00:00:00 2001 From: danielsf Date: Tue, 1 Feb 2022 12:12:25 -0800 Subject: [PATCH 06/10] pep8 changes --- deepinterpolation/generator_collection.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/deepinterpolation/generator_collection.py b/deepinterpolation/generator_collection.py index 6fc55dee..79895023 100755 --- a/deepinterpolation/generator_collection.py +++ b/deepinterpolation/generator_collection.py @@ -1117,7 +1117,10 @@ def __data_generation__(self, index_frame): else: _filenames = ["motion_corrected_video.h5", "concat_31Hz_0.h5"] for _filename in _filenames: - _filepath = os.path.join(local_path, "processed", _filename) + _filepath = os.path.join(local_path, + "processed", + _filename) + if os.path.exists(_filepath) and not os.path.islink( _filepath ): # Path exists and is not symbolic @@ -1125,7 +1128,7 @@ def __data_generation__(self, index_frame): break if motion_path is None: - msg = f"unable to find valid movie file for path\n" + msg = "unable to find valid movie file for path\n" msg += f"{local_path}" raise RuntimeError(msg) @@ -1141,8 +1144,8 @@ def __data_generation__(self, index_frame): output_full = np.zeros([1, 512, 512, 1]) input_index = np.arange( - output_frame - self.pre_frame - self.pre_post_omission, - output_frame + self.post_frame + self.pre_post_omission + 1, + output_frame - self.pre_frame - self.pre_post_omission, + output_frame + self.post_frame + self.pre_post_omission + 1, ) input_index = input_index[input_index != output_frame] From e464e9b6187bca39591e795cb5c8ff3ab03384a8 Mon Sep 17 00:00:00 2001 From: danielsf 
Date: Tue, 1 Feb 2022 12:13:13 -0800 Subject: [PATCH 07/10] add unit test for MovieJSONGenerator --- tests/test_movie_json_generator.py | 201 +++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 tests/test_movie_json_generator.py diff --git a/tests/test_movie_json_generator.py b/tests/test_movie_json_generator.py new file mode 100644 index 00000000..d041e158 --- /dev/null +++ b/tests/test_movie_json_generator.py @@ -0,0 +1,201 @@ +import pytest +import h5py +import json +import tempfile +import pathlib +import numpy as np +from deepinterpolation.generator_collection import MovieJSONGenerator + + +@pytest.fixture(scope='session') +def frame_list_fixture(): + """ + Indexes of frames returned by MovieJSONGenerator + """ + return [4, 3, 5, 7] + + +@pytest.fixture(scope='session') +def movie_path_list_fixture(tmpdir_factory): + """ + yields a list of paths to test movie files + """ + path_list = [] + + parent_tmpdir = tmpdir_factory.mktemp('movies_for_test') + rng = np.random.default_rng(172312) + this_dir = tempfile.mkdtemp(dir=parent_tmpdir) + this_path = tempfile.mkstemp(dir=this_dir, suffix='.h5')[1] + with h5py.File(this_path, 'w') as out_file: + out_file.create_dataset('data', + data=rng.random((12, 512, 512))) + + path_list.append(this_path) + + this_dir = tempfile.mkdtemp(dir=parent_tmpdir) + this_dir = pathlib.Path(this_dir) / 'processed' + this_dir.mkdir() + this_path = this_dir / 'concat_31Hz_0.h5' + with h5py.File(this_path, 'w') as out_file: + out_file.create_dataset('data', + data=rng.random((12, 512, 512))) + path_list.append(str(this_path.resolve().absolute())) + + this_dir = tempfile.mkdtemp(dir=parent_tmpdir) + this_dir = pathlib.Path(this_dir) / 'processed' + this_dir.mkdir() + this_path = this_dir / 'motion_corrected_video.h5' + with h5py.File(this_path, 'w') as out_file: + out_file.create_dataset('data', + data=rng.random((12, 512, 512))) + path_list.append(str(this_path.resolve().absolute())) + + yield path_list + + 
for this_path in path_list: + this_path = pathlib.Path(this_path) + if this_path.is_file(): + this_path.unlink() + + +@pytest.fixture(scope='session') +def json_frame_specification_fixture(movie_path_list_fixture, + tmpdir_factory, + frame_list_fixture): + """ + yields a dict with the following key/value pairs + + 'json_path' -- path to the file specifying the + movies/frames for the generator + + 'expected_input' -- list of expected input + datasets returned by the generator + + 'expected_output' -- list of expected output + datasets returned by the generator + """ + + params = dict() + + for ii, movie_path in enumerate(movie_path_list_fixture): + this_params = dict() + if ii > 0: + movie_path = str(pathlib.Path(movie_path).parent.parent) + this_params['path'] = movie_path + this_params['frames'] = frame_list_fixture + this_params['mean'] = (ii+1)*2.1 + this_params['std'] = (ii+1)*3.4 + params[str(ii)] = this_params + + tmpdir = tmpdir_factory.mktemp('frame_specification') + json_path = tempfile.mkstemp( + dir=tmpdir, + prefix='frame_specification_params_', + suffix='.json')[1] + with open(json_path, 'w') as out_file: + out_file.write(json.dumps(params)) + + # now construct the input and output frames that + # we expect this generator to yield + expected_output_frames = [] + expected_input_frames = [] + + path_to_data = dict() + for movie_path in movie_path_list_fixture: + with h5py.File(movie_path, 'r') as in_file: + data = in_file['data'][()] + path_to_data[movie_path] = data + + for i_frame in range(len(frame_list_fixture)): + for ii in range(len(movie_path_list_fixture)): + this_params = params[str(ii)] + mu = this_params['mean'] + std = this_params['std'] + movie_path = movie_path_list_fixture[ii] + data = path_to_data[movie_path] + frame = frame_list_fixture[i_frame] + output_data = (data[frame, :, :] - mu)/std + + input_indexes = np.array([frame-2, frame-1, frame+1, frame+2]) + input_data = (data[input_indexes, :, :]-mu)/std + + 
expected_output_frames.append(output_data) + expected_input_frames.append(input_data) + + yield {'json_path': json_path, + 'expected_input': expected_input_frames, + 'expected_output': expected_output_frames} + + json_path = pathlib.Path(json_path) + if json_path.is_file(): + json_path.unlink() + + +@pytest.fixture(scope='session') +def json_generator_params_fixture( + tmpdir_factory, + json_frame_specification_fixture): + """ + yields the path to the JSON configuration file for the MovieJSONGenerator + """ + + tmpdir = tmpdir_factory.mktemp('json_generator_params') + json_path = tempfile.mkstemp(dir=tmpdir, + prefix='movie_json_generator_params_', + suffix='.json')[1] + + params = dict() + params['pre_post_omission'] = 0 + params['total_samples'] = -1 + params['name'] = 'MovieJSONGenerator' + params['batch_size'] = 1 + params['start_frame'] = 0 + params['end_frame'] = -1 + params['pre_frame'] = 2 + params['post_frame'] = 2 + params['randomize'] = True + params['data_path'] = json_frame_specification_fixture['json_path'] + params['steps_per_epoch'] = -1 + params['train_path'] = json_frame_specification_fixture['json_path'] + params['type'] = 'generator' + + with open(json_path, 'w') as out_file: + out_file.write(json.dumps(params, indent=2)) + + yield json_path + + json_path = pathlib.Path(json_path) + if json_path.is_file(): + json_path.unlink() + + +def test_movie_json_generator( + movie_path_list_fixture, + json_frame_specification_fixture, + json_generator_params_fixture, + frame_list_fixture): + + expected_input = json_frame_specification_fixture['expected_input'] + expected_output = json_frame_specification_fixture['expected_output'] + + generator = MovieJSONGenerator(json_generator_params_fixture) + lims_id_list = generator.lims_id + + n_frames = len(frame_list_fixture) + dataset_ct = 0 + + for dataset in generator: + # check that the dataset contains the expected input/output frames + expected_i = expected_input[dataset_ct] + expected_o = 
expected_output[dataset_ct] + + actual_o = dataset[1][0, :, :, 0] + np.testing.assert_array_equal(actual_o, expected_o) + + actual_i = dataset[0][0, :, :, :].transpose(2, 0, 1) + np.testing.assert_array_equal(actual_i, expected_i) + + dataset_ct += 1 + + # make sure we got the expected number of datasets + assert dataset_ct == len(lims_id_list)*n_frames From 885037f6281ea1f34d6b2eae4b6f6c196ce1dcd7 Mon Sep 17 00:00:00 2001 From: danielsf Date: Thu, 3 Feb 2022 08:06:24 -0800 Subject: [PATCH 08/10] only keep movie_obj open as long as you need to --- deepinterpolation/generator_collection.py | 65 ++++++++++++----------- 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/deepinterpolation/generator_collection.py b/deepinterpolation/generator_collection.py index 79895023..b38cc8cb 100755 --- a/deepinterpolation/generator_collection.py +++ b/deepinterpolation/generator_collection.py @@ -1132,48 +1132,49 @@ def __data_generation__(self, index_frame): msg += f"{local_path}" raise RuntimeError(msg) - with h5py.File(motion_path, "r") as movie_obj: - - local_frame_data = self.frame_data_location[local_lims] - output_frame = local_frame_data["frames"][local_img] - local_mean = local_frame_data["mean"] - local_std = local_frame_data["std"] - - input_full = np.zeros( - [1, 512, 512, self.pre_frame + self.post_frame]) - output_full = np.zeros([1, 512, 512, 1]) + local_frame_data = self.frame_data_location[local_lims] + output_frame = local_frame_data["frames"][local_img] + local_mean = local_frame_data["mean"] + local_std = local_frame_data["std"] + + input_full = np.zeros( + [1, 512, 512, self.pre_frame + self.post_frame]) + output_full = np.zeros([1, 512, 512, 1]) + + input_index = np.arange( + output_frame - self.pre_frame - self.pre_post_omission, + output_frame + self.post_frame + self.pre_post_omission + 1, + ) + input_index = input_index[input_index != output_frame] - input_index = np.arange( - output_frame - self.pre_frame - self.pre_post_omission, - 
output_frame + self.post_frame + self.pre_post_omission + 1, - ) - input_index = input_index[input_index != output_frame] + for index_padding in np.arange(self.pre_post_omission + 1): + input_index = input_index[input_index != + output_frame - index_padding] + input_index = input_index[input_index != + output_frame + index_padding] - for index_padding in np.arange(self.pre_post_omission + 1): - input_index = input_index[input_index != - output_frame - index_padding] - input_index = input_index[input_index != - output_frame + index_padding] + with h5py.File(motion_path, "r") as movie_obj: data_img_input = movie_obj["data"][input_index, :, :] data_img_output = movie_obj["data"][output_frame, :, :] - data_img_input = np.swapaxes(data_img_input, 1, 2) - data_img_input = np.swapaxes(data_img_input, 0, 2) + data_img_input = np.swapaxes(data_img_input, 1, 2) + data_img_input = np.swapaxes(data_img_input, 0, 2) - img_in_shape = data_img_input.shape - img_out_shape = data_img_output.shape + img_in_shape = data_img_input.shape + img_out_shape = data_img_output.shape - data_img_input = (data_img_input.astype( - "float") - local_mean) / local_std - data_img_output = (data_img_output.astype( - "float") - local_mean) / local_std - input_full[0, : img_in_shape[0], - : img_in_shape[1], :] = data_img_input - output_full[0, : img_out_shape[0], - : img_out_shape[1], 0] = data_img_output + data_img_input = (data_img_input.astype( + "float") - local_mean) / local_std + data_img_output = (data_img_output.astype( + "float") - local_mean) / local_std + input_full[0, : img_in_shape[0], + : img_in_shape[1], :] = data_img_input + output_full[0, : img_out_shape[0], + : img_out_shape[1], 0] = data_img_output return input_full, output_full + except Exception as err: msg = f"Issues with {local_lims}\n" msg += f"Error: {str(err)}\n" From ce7b9c3bf55e645128dd519231a2b4f500028085 Mon Sep 17 00:00:00 2001 From: danielsf Date: Mon, 7 Feb 2022 17:32:45 -0800 Subject: [PATCH 09/10] do not sort the LIMS 
ID of experiment in the generator --- deepinterpolation/generator_collection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepinterpolation/generator_collection.py b/deepinterpolation/generator_collection.py index b38cc8cb..3710ad6a 100755 --- a/deepinterpolation/generator_collection.py +++ b/deepinterpolation/generator_collection.py @@ -1039,7 +1039,6 @@ def __init__(self, json_path): self.frame_data_location = json.load(json_handle) self.lims_id = list(self.frame_data_location.keys()) - self.lims_id.sort() self.nb_lims = len(self.lims_id) self.img_per_movie = len( self.frame_data_location[self.lims_id[0]]["frames"]) From 9161114c7cfcb9a966cbdaf39594731781f0c9f2 Mon Sep 17 00:00:00 2001 From: danielsf Date: Tue, 8 Feb 2022 13:39:15 -0800 Subject: [PATCH 10/10] allow users to specify seed used in randomization --- deepinterpolation/cli/schemas.py | 6 ++++++ deepinterpolation/generator_collection.py | 17 ++++++++++++++++- tests/test_movie_json_generator.py | 19 ++++++++++++++++--- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/deepinterpolation/cli/schemas.py b/deepinterpolation/cli/schemas.py index 7484316e..4cd9102f 100644 --- a/deepinterpolation/cli/schemas.py +++ b/deepinterpolation/cli/schemas.py @@ -179,6 +179,12 @@ class GeneratorSchema(argschema.schemas.DefaultSchema): inference." 
) + seed = argschema.fields.Int( + required=False, + default=235813, + description=("Seed used for random number generator used to " + "shuffle samples")) + total_samples = argschema.fields.Int( required=False, default=-1, diff --git a/deepinterpolation/generator_collection.py b/deepinterpolation/generator_collection.py index 3710ad6a..0c36629e 100755 --- a/deepinterpolation/generator_collection.py +++ b/deepinterpolation/generator_collection.py @@ -440,6 +440,12 @@ def __init__(self, json_path): else: self.randomize = True + if self.randomize: + if "seed" in self.json_data.keys(): + self.rng = np.random.default_rng(self.json_data["seed"]) + else: + self.rng = np.random.default_rng() + if "pre_post_omission" in self.json_data.keys(): self.pre_post_omission = self.json_data["pre_post_omission"] else: @@ -491,7 +497,7 @@ def _calculate_list_samples(self, total_frame_per_movie): self.list_samples = np.arange(self.start_sample, self.end_sample+1) if self.randomize: - np.random.shuffle(self.list_samples) + self.rng.shuffle(self.list_samples) # We cut the number of samples if asked to if (self.total_samples > 0 @@ -1039,6 +1045,15 @@ def __init__(self, json_path): self.frame_data_location = json.load(json_handle) self.lims_id = list(self.frame_data_location.keys()) + + self.lims_id.sort() + if self.json_data["randomize"]: + if "seed" in self.json_data: + rng = np.random.default_rng(self.json_data["seed"]) + else: + rng = np.random.default_rng() + rng.shuffle(self.lims_id) + self.nb_lims = len(self.lims_id) self.img_per_movie = len( self.frame_data_location[self.lims_id[0]]["frames"]) diff --git a/tests/test_movie_json_generator.py b/tests/test_movie_json_generator.py index d041e158..5dbd365e 100644 --- a/tests/test_movie_json_generator.py +++ b/tests/test_movie_json_generator.py @@ -7,6 +7,11 @@ from deepinterpolation.generator_collection import MovieJSONGenerator +@pytest.fixture(scope='session') +def random_seed_fixture(): + return 221 + + 
@pytest.fixture(scope='session') def frame_list_fixture(): """ @@ -61,7 +66,8 @@ def movie_path_list_fixture(tmpdir_factory): @pytest.fixture(scope='session') def json_frame_specification_fixture(movie_path_list_fixture, tmpdir_factory, - frame_list_fixture): + frame_list_fixture, + random_seed_fixture): """ yields a dict with the following key/value pairs @@ -106,8 +112,13 @@ def json_frame_specification_fixture(movie_path_list_fixture, data = in_file['data'][()] path_to_data[movie_path] = data + # replicate shuffling that happens inside the generator + rng = np.random.default_rng(random_seed_fixture) + index_list = list(range(len(movie_path_list_fixture))) + rng.shuffle(index_list) + for i_frame in range(len(frame_list_fixture)): - for ii in range(len(movie_path_list_fixture)): + for ii in index_list: this_params = params[str(ii)] mu = this_params['mean'] std = this_params['std'] @@ -134,7 +145,8 @@ def json_frame_specification_fixture(movie_path_list_fixture, @pytest.fixture(scope='session') def json_generator_params_fixture( tmpdir_factory, - json_frame_specification_fixture): + json_frame_specification_fixture, + random_seed_fixture): """ yields the path to the JSON configuration file for the MovieJSONGenerator """ @@ -158,6 +170,7 @@ def json_generator_params_fixture( params['steps_per_epoch'] = -1 params['train_path'] = json_frame_specification_fixture['json_path'] params['type'] = 'generator' + params['seed'] = random_seed_fixture with open(json_path, 'w') as out_file: out_file.write(json.dumps(params, indent=2))