diff --git a/deepinterpolation/cli/schemas.py b/deepinterpolation/cli/schemas.py
index 7484316e..4cd9102f 100644
--- a/deepinterpolation/cli/schemas.py
+++ b/deepinterpolation/cli/schemas.py
@@ -179,6 +179,12 @@ class GeneratorSchema(argschema.schemas.DefaultSchema):
             inference."
     )
+    seed = argschema.fields.Int(
+        required=False,
+        default=235813,
+        description=("Seed used for random number generator used to "
+                     "shuffle samples"))
+
     total_samples = argschema.fields.Int(
         required=False,
         default=-1,
diff --git a/deepinterpolation/generator_collection.py b/deepinterpolation/generator_collection.py
index 9deedb45..0c36629e 100755
--- a/deepinterpolation/generator_collection.py
+++ b/deepinterpolation/generator_collection.py
@@ -7,6 +7,7 @@
 import nibabel as nib
 import s3fs
 import glob
+import warnings
 
 from deepinterpolation.generic import JsonLoader
@@ -439,6 +440,12 @@ def __init__(self, json_path):
         else:
             self.randomize = True
 
+        if self.randomize:
+            if "seed" in self.json_data.keys():
+                self.rng = np.random.default_rng(self.json_data["seed"])
+            else:
+                self.rng = np.random.default_rng()
+
         if "pre_post_omission" in self.json_data.keys():
             self.pre_post_omission = self.json_data["pre_post_omission"]
         else:
@@ -490,7 +497,7 @@ def _calculate_list_samples(self, total_frame_per_movie):
         self.list_samples = np.arange(self.start_sample, self.end_sample+1)
 
         if self.randomize:
-            np.random.shuffle(self.list_samples)
+            self.rng.shuffle(self.list_samples)
 
         # We cut the number of samples if asked to
         if (self.total_samples > 0
@@ -1038,6 +1045,15 @@ def __init__(self, json_path):
             self.frame_data_location = json.load(json_handle)
 
         self.lims_id = list(self.frame_data_location.keys())
+
+        self.lims_id.sort()
+        if self.json_data["randomize"]:
+            if "seed" in self.json_data:
+                rng = np.random.default_rng(self.json_data["seed"])
+            else:
+                rng = np.random.default_rng()
+            rng.shuffle(self.lims_id)
+
         self.nb_lims = len(self.lims_id)
         self.img_per_movie = len(
             self.frame_data_location[self.lims_id[0]]["frames"])
@@ -1109,17 +1125,26 @@ def __data_generation__(self, index_frame):
             # Initialization
             local_path = self.frame_data_location[local_lims]["path"]
 
-            _filenames = ["motion_corrected_video.h5", "concat_31Hz_0.h5"]
-            motion_path = []
-            for _filename in _filenames:
-                _filepath = os.path.join(local_path, "processed", _filename)
-                if os.path.exists(_filepath) and not os.path.islink(
-                    _filepath
-                ):  # Path exists and is not symbolic
-                    motion_path = _filepath
-                    break
-
-            movie_obj = h5py.File(motion_path, "r")
+            motion_path = None
+            if os.path.isfile(local_path):
+                motion_path = local_path
+            else:
+                _filenames = ["motion_corrected_video.h5", "concat_31Hz_0.h5"]
+                for _filename in _filenames:
+                    _filepath = os.path.join(local_path,
+                                             "processed",
+                                             _filename)
+
+                    if os.path.exists(_filepath) and not os.path.islink(
+                        _filepath
+                    ):  # Path exists and is not symbolic
+                        motion_path = _filepath
+                        break
+
+            if motion_path is None:
+                msg = "unable to find valid movie file for path\n"
+                msg += f"{local_path}"
+                raise RuntimeError(msg)
 
             local_frame_data = self.frame_data_location[local_lims]
             output_frame = local_frame_data["frames"][local_img]
@@ -1131,8 +1156,8 @@ def __data_generation__(self, index_frame):
             output_full = np.zeros([1, 512, 512, 1])
 
             input_index = np.arange(
-            output_frame - self.pre_frame - self.pre_post_omission,
-            output_frame + self.post_frame + self.pre_post_omission + 1,
+                output_frame - self.pre_frame - self.pre_post_omission,
+                output_frame + self.post_frame + self.pre_post_omission + 1,
             )
             input_index = input_index[input_index != output_frame]
@@ -1142,8 +1167,10 @@ def __data_generation__(self, index_frame):
                 input_index = input_index[input_index !=
                                           output_frame + index_padding]
 
-            data_img_input = movie_obj["data"][input_index, :, :]
-            data_img_output = movie_obj["data"][output_frame, :, :]
+            with h5py.File(motion_path, "r") as movie_obj:
+
+                data_img_input = movie_obj["data"][input_index, :, :]
+                data_img_output = movie_obj["data"][output_frame, :, :]
 
             data_img_input = np.swapaxes(data_img_input, 1, 2)
             data_img_input = np.swapaxes(data_img_input, 0, 2)
@@ -1159,8 +1186,11 @@ def __data_generation__(self, index_frame):
                        : img_in_shape[1], :] = data_img_input
             output_full[0, : img_out_shape[0],
                         : img_out_shape[1], 0] = data_img_output
-            movie_obj.close()
 
             return input_full, output_full
-        except Exception:
-            print("Issues with " + str(self.lims_id))
+
+        except Exception as err:
+            msg = f"Issues with {local_lims}\n"
+            msg += f"Error: {str(err)}\n"
+            msg += "moving on\n"
+            warnings.warn(msg)
diff --git a/tests/test_movie_json_generator.py b/tests/test_movie_json_generator.py
new file mode 100644
index 00000000..5dbd365e
--- /dev/null
+++ b/tests/test_movie_json_generator.py
@@ -0,0 +1,214 @@
+import pytest
+import h5py
+import json
+import tempfile
+import pathlib
+import numpy as np
+from deepinterpolation.generator_collection import MovieJSONGenerator
+
+
+@pytest.fixture(scope='session')
+def random_seed_fixture():
+    return 221
+
+
+@pytest.fixture(scope='session')
+def frame_list_fixture():
+    """
+    Indexes of frames returned by MovieJSONGenerator
+    """
+    return [4, 3, 5, 7]
+
+
+@pytest.fixture(scope='session')
+def movie_path_list_fixture(tmpdir_factory):
+    """
+    yields a list of paths to test movie files
+    """
+    path_list = []
+
+    parent_tmpdir = tmpdir_factory.mktemp('movies_for_test')
+    rng = np.random.default_rng(172312)
+    this_dir = tempfile.mkdtemp(dir=parent_tmpdir)
+    this_path = tempfile.mkstemp(dir=this_dir, suffix='.h5')[1]
+    with h5py.File(this_path, 'w') as out_file:
+        out_file.create_dataset('data',
+                                data=rng.random((12, 512, 512)))
+
+    path_list.append(this_path)
+
+    this_dir = tempfile.mkdtemp(dir=parent_tmpdir)
+    this_dir = pathlib.Path(this_dir) / 'processed'
+    this_dir.mkdir()
+    this_path = this_dir / 'concat_31Hz_0.h5'
+    with h5py.File(this_path, 'w') as out_file:
+        out_file.create_dataset('data',
+                                data=rng.random((12, 512, 512)))
+    path_list.append(str(this_path.resolve().absolute()))
+
+    this_dir = tempfile.mkdtemp(dir=parent_tmpdir)
+    this_dir = pathlib.Path(this_dir) / 'processed'
+    this_dir.mkdir()
+    this_path = this_dir / 'motion_corrected_video.h5'
+    with h5py.File(this_path, 'w') as out_file:
+        out_file.create_dataset('data',
+                                data=rng.random((12, 512, 512)))
+    path_list.append(str(this_path.resolve().absolute()))
+
+    yield path_list
+
+    for this_path in path_list:
+        this_path = pathlib.Path(this_path)
+        if this_path.is_file():
+            this_path.unlink()
+
+
+@pytest.fixture(scope='session')
+def json_frame_specification_fixture(movie_path_list_fixture,
+                                     tmpdir_factory,
+                                     frame_list_fixture,
+                                     random_seed_fixture):
+    """
+    yields a dict with the following key/value pairs
+
+    'json_path' -- path to the file specifying the
+    movies/frames for the generator
+
+    'expected_input' -- list of expected input
+    datasets returned by the generator
+
+    'expected_output' -- list of expected output
+    datasets returned by the generator
+    """
+
+    params = dict()
+
+    for ii, movie_path in enumerate(movie_path_list_fixture):
+        this_params = dict()
+        if ii > 0:
+            movie_path = str(pathlib.Path(movie_path).parent.parent)
+        this_params['path'] = movie_path
+        this_params['frames'] = frame_list_fixture
+        this_params['mean'] = (ii+1)*2.1
+        this_params['std'] = (ii+1)*3.4
+        params[str(ii)] = this_params
+
+    tmpdir = tmpdir_factory.mktemp('frame_specification')
+    json_path = tempfile.mkstemp(
+        dir=tmpdir,
+        prefix='frame_specification_params_',
+        suffix='.json')[1]
+    with open(json_path, 'w') as out_file:
+        out_file.write(json.dumps(params))
+
+    # now construct the input and output frames that
+    # we expect this generator to yield
+    expected_output_frames = []
+    expected_input_frames = []
+
+    path_to_data = dict()
+    for movie_path in movie_path_list_fixture:
+        with h5py.File(movie_path, 'r') as in_file:
+            data = in_file['data'][()]
+        path_to_data[movie_path] = data
+
+    # replicate shuffling that happens inside the generator
+    rng = np.random.default_rng(random_seed_fixture)
+    index_list = list(range(len(movie_path_list_fixture)))
+    rng.shuffle(index_list)
+
+    for i_frame in range(len(frame_list_fixture)):
+        for ii in index_list:
+            this_params = params[str(ii)]
+            mu = this_params['mean']
+            std = this_params['std']
+            movie_path = movie_path_list_fixture[ii]
+            data = path_to_data[movie_path]
+            frame = frame_list_fixture[i_frame]
+            output_data = (data[frame, :, :] - mu)/std
+
+            input_indexes = np.array([frame-2, frame-1, frame+1, frame+2])
+            input_data = (data[input_indexes, :, :]-mu)/std
+
+            expected_output_frames.append(output_data)
+            expected_input_frames.append(input_data)
+
+    yield {'json_path': json_path,
+           'expected_input': expected_input_frames,
+           'expected_output': expected_output_frames}
+
+    json_path = pathlib.Path(json_path)
+    if json_path.is_file():
+        json_path.unlink()
+
+
+@pytest.fixture(scope='session')
+def json_generator_params_fixture(
+        tmpdir_factory,
+        json_frame_specification_fixture,
+        random_seed_fixture):
+    """
+    yields the path to the JSON configuration file for the MovieJSONGenerator
+    """
+
+    tmpdir = tmpdir_factory.mktemp('json_generator_params')
+    json_path = tempfile.mkstemp(dir=tmpdir,
+                                 prefix='movie_json_generator_params_',
+                                 suffix='.json')[1]
+
+    params = dict()
+    params['pre_post_omission'] = 0
+    params['total_samples'] = -1
+    params['name'] = 'MovieJSONGenerator'
+    params['batch_size'] = 1
+    params['start_frame'] = 0
+    params['end_frame'] = -1
+    params['pre_frame'] = 2
+    params['post_frame'] = 2
+    params['randomize'] = True
+    params['data_path'] = json_frame_specification_fixture['json_path']
+    params['steps_per_epoch'] = -1
+    params['train_path'] = json_frame_specification_fixture['json_path']
+    params['type'] = 'generator'
+    params['seed'] = random_seed_fixture
+
+    with open(json_path, 'w') as out_file:
+        out_file.write(json.dumps(params, indent=2))
+
+    yield json_path
+
+    json_path = pathlib.Path(json_path)
+    if json_path.is_file():
+        json_path.unlink()
+
+
+def test_movie_json_generator(
+        movie_path_list_fixture,
+        json_frame_specification_fixture,
+        json_generator_params_fixture,
+        frame_list_fixture):
+
+    expected_input = json_frame_specification_fixture['expected_input']
+    expected_output = json_frame_specification_fixture['expected_output']
+
+    generator = MovieJSONGenerator(json_generator_params_fixture)
+    lims_id_list = generator.lims_id
+
+    n_frames = len(frame_list_fixture)
+    dataset_ct = 0
+
+    for dataset in generator:
+        # check that the dataset contains the expected input/output frames
+        expected_i = expected_input[dataset_ct]
+        expected_o = expected_output[dataset_ct]
+
+        actual_o = dataset[1][0, :, :, 0]
+        np.testing.assert_array_equal(actual_o, expected_o)
+
+        actual_i = dataset[0][0, :, :, :].transpose(2, 0, 1)
+        np.testing.assert_array_equal(actual_i, expected_i)
+
+        dataset_ct += 1
+
+    # make sure we got the expected number of datasets
+    assert dataset_ct == len(lims_id_list)*n_frames
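
Note (not part of the diff): the seeded shuffling that this patch introduces can be illustrated with plain numpy. A minimal sketch, assuming only that both generators build their RNG with np.random.default_rng(seed) and reorder their sample (or lims_id) lists with rng.shuffle, which is what the hunks above do:

import numpy as np

# Reusing the same seed reproduces the same shuffle order, which is the
# property the new test relies on when it replicates the generator's
# shuffling to predict which dataset comes out at each step.
seed = 235813  # default value added to GeneratorSchema in this diff

order_a = np.arange(100)
order_b = np.arange(100)
np.random.default_rng(seed).shuffle(order_a)
np.random.default_rng(seed).shuffle(order_b)

assert np.array_equal(order_a, order_b)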