diff --git a/src/emmental/contrib/slicing/modules/utils.py b/src/emmental/contrib/slicing/modules/utils.py index 7f22483..e61b372 100644 --- a/src/emmental/contrib/slicing/modules/utils.py +++ b/src/emmental/contrib/slicing/modules/utils.py @@ -23,9 +23,7 @@ def ce_loss( Loss. """ return F.cross_entropy( - intermediate_output_dict[module_name][0], - Y.view(-1) - 1, - weight, + intermediate_output_dict[module_name][0], Y.view(-1) - 1, weight ) diff --git a/src/emmental/logging/checkpointer.py b/src/emmental/logging/checkpointer.py index 24285ff..c9bf02f 100644 --- a/src/emmental/logging/checkpointer.py +++ b/src/emmental/logging/checkpointer.py @@ -137,9 +137,7 @@ def checkpoint( # Save optimizer state optimizer_path = f"{self.checkpoint_path}/checkpoint_{iteration}.optimizer.pth" - optimizer_dict = { - "optimizer": optimizer.state_dict(), - } + optimizer_dict = {"optimizer": optimizer.state_dict()} torch.save(optimizer_dict, optimizer_path) # Save lr_scheduler state @@ -165,10 +163,7 @@ def checkpoint( f"{self.checkpoint_path}/best_model_" f"{metric.replace('/', '_')}.model.pth" ) - copyfile( - model_path, - best_metric_model_path, - ) + copyfile(model_path, best_metric_model_path) logger.info( f"Save best model of metric {metric} to {best_metric_model_path}" ) diff --git a/src/emmental/meta.py b/src/emmental/meta.py index c4149d3..52b91f8 100644 --- a/src/emmental/meta.py +++ b/src/emmental/meta.py @@ -251,10 +251,7 @@ def check_config() -> None: Meta.config["logging_config"]["evaluation_freq"] = new_evaluation_freq if ( - Meta.config["logging_config"]["counter_unit"] - in [ - "epoch", - ] + Meta.config["logging_config"]["counter_unit"] in ["epoch"] and isinstance(Meta.config["logging_config"]["evaluation_freq"], int) and Meta.config["logging_config"]["writer_config"]["write_loss_per_step"] ): diff --git a/src/emmental/model.py b/src/emmental/model.py index f103010..ddac8a4 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -20,6 +20,7 @@ from 
emmental.utils.utils import ( array_to_numpy, construct_identifier, + merge_objects, move_to_device, prob_to_pred, ) @@ -323,7 +324,7 @@ def forward( # type: ignore Dict[str, Tensor], Dict[str, Union[ndarray, List[ndarray]]], Dict[str, Union[ndarray, List[ndarray]]], - Dict[str, Dict[str, Union[ndarray, List]]], + Dict[str, Dict[str, Union[ndarray, List, int, float, Dict]]], ], Tuple[ Dict[str, List[str]], @@ -356,7 +357,7 @@ def forward( # type: ignore prob_dict: Dict[str, Union[ndarray, List[ndarray]]] = ( defaultdict(list) if return_probs else None ) - out_dict: Dict[str, Dict[str, Union[ndarray, List]]] = ( + out_dict: Dict[str, Dict[str, Union[ndarray, List, int, float, Dict]]] = ( defaultdict(lambda: defaultdict(list)) if return_action_outputs else None ) @@ -378,8 +379,7 @@ def forward( # type: ignore loss_dict[task_name] = self.loss_funcs[task_name]( output_dict, move_to_device( - Y_dict[label_name], - Meta.config["model_config"]["device"], + Y_dict[label_name], Meta.config["model_config"]["device"] ) if Y_dict is not None and label_name is not None else None, @@ -403,9 +403,11 @@ def forward( # type: ignore and self.action_outputs[task_name] is not None ): for action_name, output_index in self.action_outputs[task_name]: - out_dict[task_name][f"{action_name}_{output_index}"] = ( - output_dict[action_name][output_index].cpu().detach().numpy() + action_output = output_dict[action_name][output_index] + action_output = move_to_device( + action_output, -1, detach=True, convert_to_numpy=True ) + out_dict[task_name][f"{action_name}_{output_index}"] = action_output if return_action_outputs: return uid_dict, loss_dict, prob_dict, gold_dict, out_dict @@ -446,7 +448,7 @@ def predict( pred_dict: Dict[str, Union[ndarray, List[ndarray]]] = ( defaultdict(list) if return_preds else None ) - out_dict: Dict[str, Dict[str, List[Union[ndarray, int, float]]]] = ( + out_dict: Dict[str, Dict[str, Union[ndarray, List, int, float, Dict]]] = ( defaultdict(lambda: defaultdict(list)) 
if return_action_outputs else None ) loss_dict: Dict[str, Union[ndarray, float]] = ( @@ -526,8 +528,13 @@ def predict( if return_action_outputs and out_bdict: for task_name in out_bdict.keys(): for action_name in out_bdict[task_name].keys(): - out_dict[task_name][action_name].extend( + out_dict[task_name][action_name] = ( out_bdict[task_name][action_name] + if out_dict[task_name][action_name] == [] + else merge_objects( + out_dict[task_name][action_name], + out_bdict[task_name][action_name], + ) ) # Calculate average loss @@ -536,11 +543,7 @@ def predict( if not isinstance(loss_dict[task_name], list): loss_dict[task_name] /= len(uid_dict[task_name]) - res = { - "uids": uid_dict, - "golds": gold_dict, - "losses": loss_dict, - } + res = {"uids": uid_dict, "golds": gold_dict, "losses": loss_dict} if return_probs: for task_name in prob_dict.keys(): @@ -734,11 +737,7 @@ def save( if Meta.config["meta_config"]["verbose"] and verbose: logger.info(f"[{self.name}] Model saved in {model_path}") - def load( - self, - model_path: str, - verbose: bool = True, - ) -> None: + def load(self, model_path: str, verbose: bool = True) -> None: """Load model state_dict from file and reinitialize the model weights. 
Args: diff --git a/src/emmental/utils/parse_args.py b/src/emmental/utils/parse_args.py index 54b8eb0..895190e 100644 --- a/src/emmental/utils/parse_args.py +++ b/src/emmental/utils/parse_args.py @@ -819,10 +819,7 @@ def parse_args(parser: Optional[ArgumentParser] = None) -> ArgumentParser: ) logging_config.add_argument( - "--wandb_run_name", - type=nullable_string, - default=None, - help="Wandb run name", + "--wandb_run_name", type=nullable_string, default=None, help="Wandb run name" ) logging_config.add_argument( diff --git a/src/emmental/utils/utils.py b/src/emmental/utils/utils.py index 6b89ecd..17501e2 100644 --- a/src/emmental/utils/utils.py +++ b/src/emmental/utils/utils.py @@ -122,7 +122,10 @@ def pred_to_prob(preds: ndarray, n_classes: int) -> ndarray: def move_to_device( - obj: Any, device: Optional[Union[int, str, torch.device]] = -1 + obj: Union[Tensor, ndarray, dict, list, tuple], + device: Optional[Union[int, str, torch.device]] = -1, + detach: bool = False, + convert_to_numpy: bool = False, ) -> Any: """Move object to specified device. @@ -147,17 +150,98 @@ def move_to_device( device = torch.device("cpu") if isinstance(obj, torch.Tensor): - return obj.to(device) + obj = obj.to(device) + if detach: + obj = obj.detach() + if convert_to_numpy: + obj = obj.numpy() + return obj elif isinstance(obj, dict): - return {key: move_to_device(value, device) for key, value in obj.items()} + return { + key: move_to_device(value, device, detach, convert_to_numpy) + for key, value in obj.items() + } elif isinstance(obj, list): - return [move_to_device(item, device) for item in obj] + return [move_to_device(item, device, detach, convert_to_numpy) for item in obj] elif isinstance(obj, tuple): - return tuple([move_to_device(item, device) for item in obj]) + return tuple( + [move_to_device(item, device, detach, convert_to_numpy) for item in obj] + ) else: return obj +def merge_objects(obj_1: Any, obj_2: Any) -> Any: + """Merge two objects of the same type. 
+ + Given two objects of the same type and structure, merges the second object + into the first object. If either of the objects is empty, the non-empty + object is returned. Supported types include torch tensors, numpy arrays + lists, dicts and tuples. For tensors and arrays, objects are merged + along the 1st dimension: + + obj_1: torch.Tensor([1,2]), obj_2: torch.Tensor([2,3]) + merged object: torch.Tensor([[1,2],[2,3]]) + + Args: + obj_1: first object. + obj_2: second object to be merged into the first object. + + Returns: + an object reflecting the merged output of the two inputs. + """ + if type(obj_1) != type(obj_2): + raise TypeError( + f"Cannot merge object of type {type(obj_1)} " + f"with object of type {type(obj_2)}." + ) + if isinstance(obj_1, torch.Tensor): + # empty edge case + if not obj_1.size()[0]: + return obj_2 + elif not obj_2.size()[0]: + return obj_1 + + # unsqueeze of object is 1D and not empty + if len(obj_1.shape) == 1: + obj_1 = obj_1.unsqueeze(0) + if len(obj_2.shape) == 1: + obj_2 = obj_2.unsqueeze(0) + return torch.cat([obj_1, obj_2]) + elif isinstance(obj_1, np.ndarray): + # empty edge case + if not obj_1.size: + return obj_2 + elif not obj_2.size: + return obj_1 + + # expand if array has 1 dimension + if len(obj_1.shape) == 1: + obj_1 = np.expand_dims(obj_1, axis=0) + if len(obj_2.shape) == 1: + obj_2 = np.expand_dims(obj_2, axis=0) + return np.concatenate((obj_1, obj_2)) + elif isinstance(obj_1, list): + obj_1.extend(obj_2) + return obj_1 + elif isinstance(obj_1, dict): + if not obj_1: + return obj_2 + elif not obj_2: + return obj_1 + + for key, value in obj_1.items(): + obj_1[key] = merge_objects(value, obj_2[key]) + return obj_1 + elif isinstance(obj_1, tuple): + merged_tuple_vals = [] + for idx in range(len(obj_1)): + merged_tuple_vals.append(merge_objects(obj_1[idx], obj_2[idx])) + return tuple(merged_tuple_vals) + else: + return obj_1 + + def array_to_numpy( array: Union[ndarray, List[Any], Tensor], flatten: bool = False ) -> 
ndarray: diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 8591aa8..f07f397 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -102,6 +102,11 @@ def grouped_parameters(model): torch.tensor(Y2[int(0.8 * N) : int(0.9 * N)]), torch.tensor(Y2[int(0.9 * N) :]), ) + Y3_train, Y3_dev, Y3_test = ( + torch.tensor(Y2[: int(0.8 * N)]), + torch.tensor(Y2[int(0.8 * N) : int(0.9 * N)]), + torch.tensor(Y2[int(0.9 * N) :]), + ) train_dataset1 = EmmentalDataset( name="synthetic", X_dict={"data": X_train}, Y_dict={"label1": Y1_train} @@ -110,6 +115,9 @@ def grouped_parameters(model): train_dataset2 = EmmentalDataset( name="synthetic", X_dict={"data": X_train}, Y_dict={"label2": Y2_train} ) + train_dataset3 = EmmentalDataset( + name="synthetic", X_dict={"data": X_train}, Y_dict={"label3": Y3_train} + ) dev_dataset1 = EmmentalDataset( name="synthetic", X_dict={"data": X_dev}, Y_dict={"label1": Y1_dev} @@ -118,6 +126,9 @@ def grouped_parameters(model): dev_dataset2 = EmmentalDataset( name="synthetic", X_dict={"data": X_dev}, Y_dict={"label2": Y2_dev} ) + dev_dataset3 = EmmentalDataset( + name="synthetic", X_dict={"data": X_dev}, Y_dict={"label3": Y3_dev} + ) test_dataset1 = EmmentalDataset( name="synthetic", X_dict={"data": X_test}, Y_dict={"label1": Y1_test} @@ -126,6 +137,9 @@ def grouped_parameters(model): test_dataset2 = EmmentalDataset( name="synthetic", X_dict={"data": X_test}, Y_dict={"label2": Y2_test} ) + test_dataset4 = EmmentalDataset( + name="synthetic", X_dict={"data": X_test}, Y_dict={"label3": Y3_test} + ) test_dataset3 = EmmentalDataset(name="synthetic", X_dict={"data": X_test}) @@ -177,6 +191,26 @@ def grouped_parameters(model): split="test", batch_size=10, ) + task_to_label_dict = {"task3": "label3"} + + train_dataloader3 = EmmentalDataLoader( + task_to_label_dict=task_to_label_dict, + dataset=train_dataset3, + split="train", + batch_size=10, + ) + dev_dataloader3 = EmmentalDataLoader( + task_to_label_dict=task_to_label_dict, + 
dataset=dev_dataset3, + split="valid", + batch_size=10, + ) + test_dataloader4 = EmmentalDataLoader( + task_to_label_dict=task_to_label_dict, + dataset=test_dataset4, + split="test", + batch_size=10, + ) # Create task def ce_loss(task_name, immediate_output_dict, Y): @@ -187,7 +221,11 @@ def output(task_name, immediate_output_dict): module_name = f"{task_name}_pred_head" return F.softmax(immediate_output_dict[module_name][0], dim=1) - task_metrics = {"task1": ["accuracy"], "task2": ["accuracy", "roc_auc"]} + task_metrics = { + "task1": ["accuracy"], + "task2": ["accuracy", "roc_auc"], + "task3": ["accuracy"], + } class IdentityModule(nn.Module): def __init__(self): @@ -197,6 +235,32 @@ def __init__(self): def forward(self, input): return {"out": input} + class IdentityDictModule(nn.Module): + def __init__(self): + """Initialize IdentityModule.""" + super().__init__() + + def forward(self, input): + return {"out": {"image_pil": input}} + + class LinearLayerDictModule(nn.Module): + def __init__(self): + """Initialize IdentityModule.""" + super().__init__() + self.linear = nn.Linear(2, 8) + + def forward(self, input): + return {"out": {"image_pil": self.linear(input["image_pil"])}} + + class PredictHeadDictModule(nn.Module): + def __init__(self): + """Initialize IdentityModule.""" + super().__init__() + self.linear = nn.Linear(8, 2) + + def forward(self, input): + return self.linear(input["image_pil"]) + tasks = [ EmmentalTask( name=task_name, @@ -240,6 +304,45 @@ def forward(self, input): ) for task_name in ["task1", "task2"] ] + tasks.append( + EmmentalTask( + name="task3", + module_pool=nn.ModuleDict( + { + "task3_input_module": IdentityDictModule(), + "task3_input_module1": LinearLayerDictModule(), + "task3_pred_head": PredictHeadDictModule(), + } + ), + task_flow=[ + { + "name": "input_t3", + "module": "task3_input_module", + "inputs": [("_input_", "data")], + }, + { + "name": "input1_t3", + "module": "task3_input_module1", + "inputs": [("input_t3", "out")], + }, 
+ { + "name": "task3_pred_head", + "module": "task3_pred_head", + "inputs": [("input1_t3", "out")], + }, + ], + module_device={"task3_input_module": -1}, + loss_func=partial(ce_loss, "task3"), + output_func=partial(output, "task3"), + action_outputs=[ + ("task3_pred_head", 0), + ("_input_", "data"), + ("input1_t3", "out"), + ], + scorer=Scorer(metrics=task_metrics["task3"]), + require_prob_for_eval=True, + ) + ) # Build model mtl_model = EmmentalModel(name="all", tasks=tasks) @@ -252,7 +355,14 @@ def forward(self, input): # Learning emmental_learner.learn( mtl_model, - [train_dataloader1, train_dataloader2, dev_dataloader1, dev_dataloader2], + [ + train_dataloader1, + train_dataloader2, + train_dataloader3, + dev_dataloader1, + dev_dataloader2, + dev_dataloader3, + ], ) test1_score = mtl_model.score(test_dataloader1) @@ -268,9 +378,10 @@ def forward(self, input): test2_pred = mtl_model.predict(test_dataloader2, return_action_outputs=True) test3_pred = mtl_model.predict( - test_dataloader3, - return_action_outputs=True, - return_loss=False, + test_dataloader3, return_action_outputs=True, return_loss=False + ) + test4_pred = mtl_model.predict( + test_dataloader4, return_action_outputs=True, return_loss=False ) assert test2_pred["uids"] == test3_pred["uids"] @@ -296,6 +407,10 @@ def forward(self, input): ) for idx in range(len(test2_pred["outputs"]["task2"]["_input__data"])) ] + assert test4_pred["outputs"]["task3"]["input1_t3_out"]["image_pil"].shape == (10, 8) + assert isinstance( + test4_pred["outputs"]["task3"]["input1_t3_out"]["image_pil"], np.ndarray + ) test4_pred = mtl_model.predict(test_dataloader2, return_action_outputs=False) assert "outputs" not in test4_pred diff --git a/tests/e2e/test_e2e_mixed.py b/tests/e2e/test_e2e_mixed.py index 57ddccb..263d40f 100644 --- a/tests/e2e/test_e2e_mixed.py +++ b/tests/e2e/test_e2e_mixed.py @@ -117,10 +117,7 @@ def test_e2e_mixed(caplog): def ave_scorer(metric_score_dict): logger.info(metric_score_dict) - metric_names = [ 
- "task1/synthetic/test/loss", - "task2/synthetic/test/loss", - ] + metric_names = ["task1/synthetic/test/loss", "task2/synthetic/test/loss"] total = 0.0 cnt = 0 @@ -219,10 +216,7 @@ def forward(self, input): emmental_learner = EmmentalLearner() # Learning - emmental_learner.learn( - mtl_model, - [train_dataloader, dev_dataloader], - ) + emmental_learner.learn(mtl_model, [train_dataloader, dev_dataloader]) test_score = mtl_model.score(test_dataloader) diff --git a/tests/e2e/test_e2e_no_y_dict.py b/tests/e2e/test_e2e_no_y_dict.py index f1e9603..35a87c9 100644 --- a/tests/e2e/test_e2e_no_y_dict.py +++ b/tests/e2e/test_e2e_no_y_dict.py @@ -76,18 +76,15 @@ def test_e2e_no_y_dict(caplog): ) train_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_train, "label1": Y_train}, + name="synthetic", X_dict={"data": X_train, "label1": Y_train} ) dev_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_dev, "label1": Y_dev}, + name="synthetic", X_dict={"data": X_dev, "label1": Y_dev} ) test_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_test, "label1": Y_test}, + name="synthetic", X_dict={"data": X_test, "label1": Y_test} ) task_name = "task1" @@ -147,11 +144,7 @@ def forward(self, input): "module": "input_module0", "inputs": [("_input_", "data")], }, - { - "name": "input1", - "module": "input_module1", - "inputs": [("input", "out")], - }, + {"name": "input1", "module": "input_module1", "inputs": [("input", "out")]}, { "name": f"{task_name}_pred_head", "module": f"{task_name}_pred_head", @@ -173,10 +166,7 @@ def forward(self, input): emmental_learner = EmmentalLearner() # Learning - emmental_learner.learn( - mtl_model, - [train_dataloader, dev_dataloader], - ) + emmental_learner.learn(mtl_model, [train_dataloader, dev_dataloader]) test_score = mtl_model.score(test_dataloader) diff --git a/tests/e2e/test_e2e_skip_trained.py b/tests/e2e/test_e2e_skip_trained.py index 51bed23..b083cc2 100644 --- a/tests/e2e/test_e2e_skip_trained.py 
+++ b/tests/e2e/test_e2e_skip_trained.py @@ -51,21 +51,15 @@ def test_e2e_skip_trained_step(caplog): ) train_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_train}, - Y_dict={"label1": Y_train}, + name="synthetic", X_dict={"data": X_train}, Y_dict={"label1": Y_train} ) dev_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_dev}, - Y_dict={"label1": Y_dev}, + name="synthetic", X_dict={"data": X_dev}, Y_dict={"label1": Y_dev} ) test_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_test}, - Y_dict={"label1": Y_test}, + name="synthetic", X_dict={"data": X_test}, Y_dict={"label1": Y_test} ) task_to_label_dict = {"task1": "label1"} @@ -182,10 +176,7 @@ def forward(self, input): Meta.update_config(config) # Learning - emmental_learner.learn( - model, - [train_dataloader, dev_dataloader], - ) + emmental_learner.learn(model, [train_dataloader, dev_dataloader]) test_score = model.score(test_dataloader) @@ -236,10 +227,7 @@ def forward(self, input): model.load(Meta.config["model_config"]["model_path"]) # Learning - emmental_learner.learn( - model, - [train_dataloader, dev_dataloader], - ) + emmental_learner.learn(model, [train_dataloader, dev_dataloader]) test_score = model.score(test_dataloader) @@ -277,21 +265,15 @@ def test_e2e_skip_trained_epoch(caplog): ) train_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_train}, - Y_dict={"label1": Y_train}, + name="synthetic", X_dict={"data": X_train}, Y_dict={"label1": Y_train} ) dev_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_dev}, - Y_dict={"label1": Y_dev}, + name="synthetic", X_dict={"data": X_dev}, Y_dict={"label1": Y_dev} ) test_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_test}, - Y_dict={"label1": Y_test}, + name="synthetic", X_dict={"data": X_test}, Y_dict={"label1": Y_test} ) task_to_label_dict = {"task1": "label1"} @@ -412,10 +394,7 @@ def forward(self, input): Meta.update_config(config) # Learning - 
emmental_learner.learn( - model, - [train_dataloader, dev_dataloader], - ) + emmental_learner.learn(model, [train_dataloader, dev_dataloader]) test_score = model.score(test_dataloader) @@ -470,10 +449,7 @@ def forward(self, input): model.load(Meta.config["model_config"]["model_path"]) # Learning - emmental_learner.learn( - model, - [train_dataloader, dev_dataloader], - ) + emmental_learner.learn(model, [train_dataloader, dev_dataloader]) test_score = model.score(test_dataloader) diff --git a/tests/task/test_task.py b/tests/task/test_task.py index 497471a..d55ce8c 100644 --- a/tests/task/test_task.py +++ b/tests/task/test_task.py @@ -42,11 +42,7 @@ def output(module_name, output_dict): "module": "input_module0", "inputs": [("_input_", "data")], }, - { - "name": "input2", - "module": "input_module1", - "inputs": [("input1", 0)], - }, + {"name": "input2", "module": "input_module1", "inputs": [("input1", 0)]}, { "name": f"{task_name}_pred_head", "module": f"{task_name}_pred_head", diff --git a/tests/test_meta.py b/tests/test_meta.py index 83641de..79a28de 100644 --- a/tests/test_meta.py +++ b/tests/test_meta.py @@ -61,33 +61,19 @@ def test_config_check_in_meta(caplog): Meta.reset() init(dirpath) - config = { - "logging_config": { - "evaluation_freq": 5.0, - }, - } + config = {"logging_config": {"evaluation_freq": 5.0}} Meta.update_config(config) assert type(Meta.config["logging_config"]["evaluation_freq"]) == int assert Meta.config["logging_config"]["evaluation_freq"] == 5 - config = { - "logging_config": { - "counter_unit": "batch", - "evaluation_freq": 2.3, - }, - } + config = {"logging_config": {"counter_unit": "batch", "evaluation_freq": 2.3}} Meta.update_config(config) assert type(Meta.config["logging_config"]["evaluation_freq"]) == int assert Meta.config["logging_config"]["evaluation_freq"] == 3 - config = { - "logging_config": { - "counter_unit": "sample", - "evaluation_freq": 0.2, - }, - } + config = {"logging_config": {"counter_unit": "sample", 
"evaluation_freq": 0.2}} Meta.update_config(config) assert type(Meta.config["logging_config"]["evaluation_freq"]) == int @@ -98,7 +84,7 @@ def test_config_check_in_meta(caplog): "counter_unit": "epoch", "evaluation_freq": 1, "writer_config": {"write_loss_per_step": True}, - }, + } } Meta.update_config(config) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 7cc4319..c8656bc 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -11,6 +11,7 @@ construct_identifier, convert_to_serializable_json, merge, + merge_objects, move_to_device, nullable_float, nullable_int, @@ -68,6 +69,76 @@ def test_move_to_device(caplog): assert move_to_device((torch.tensor([1, 2]), torch.tensor([3, 4])), -1) +def test_merge_objects(caplog): + """Unit test of merge_objects.""" + caplog.set_level(logging.INFO) + + assert torch.equal( + merge_objects(torch.Tensor([1, 2]), torch.Tensor([2, 3])), + torch.Tensor([[1, 2], [2, 3]]), + ) + assert torch.equal( + merge_objects(torch.Tensor(), torch.Tensor([2, 3])), + torch.Tensor([2, 3]), + ) + assert torch.equal( + merge_objects(torch.Tensor([2, 3]), torch.Tensor()), + torch.Tensor([2, 3]), + ) + assert merge_objects( + torch.zeros((128, 256)), torch.zeros((64, 256)) + ).shape == torch.Size([192, 256]) + + assert np.array_equal( + merge_objects(np.array([1, 2]), np.array([2, 3])), np.array([[1, 2], [2, 3]]) + ) + assert np.array_equal( + merge_objects(np.array([]), np.array([2, 3])), np.array([2, 3]) + ) + assert np.array_equal( + merge_objects(np.zeros((2, 3)), np.array([])), np.zeros((2, 3)) + ) + + assert 1 == merge_objects(1, 2) + + try: + merge_objects([], torch.Tensor(1)) + assert False + except (TypeError): + assert True + + assert merge_objects({"a": [1, 2]}, {"a": [2, 3]}) == {"a": [1, 2, 2, 3]} + assert merge_objects({"a": [1, 2]}, {}) == {"a": [1, 2]} + assert merge_objects({}, {"a": [1, 2]}) == {"a": [1, 2]} + assert merge_objects(([2, 4], [3, 4]), ([3, 4], [4, 5])) == ( + [2, 4, 3, 4], + [3, 4, 
4, 5], + ) + assert ( + torch.equal( + merge_objects( + (torch.Tensor([2]), torch.Tensor([2]), [2, 3]), + (torch.Tensor([3]), torch.Tensor([2]), [3, 4]), + )[0], + torch.Tensor([[2], [3]]), + ) + and torch.equal( + merge_objects( + (torch.Tensor([2]), torch.Tensor([2]), [2, 3]), + (torch.Tensor([3]), torch.Tensor([2]), [3, 4]), + )[1], + torch.Tensor([[2], [2]]), + ) + and merge_objects( + (torch.Tensor([2]), torch.Tensor([2]), [2, 3]), + (torch.Tensor([3]), torch.Tensor([2]), [3, 4]), + )[2] + == [2, 3, 3, 4] + ) + + assert merge_objects([1, 2, 3], [2, 3, 4]) == [1, 2, 3, 2, 3, 4] + + def test_array_to_numpy(caplog): """Unit test of array_to_numpy.""" caplog.set_level(logging.INFO)