From 2868e70667138fdaaaf374dd356c9272606ad124 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Thu, 28 Oct 2021 21:11:50 -0700 Subject: [PATCH 01/29] [add] support for action outputs of type dict --- src/emmental/model.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index f103010..70f91fc 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -22,6 +22,7 @@ construct_identifier, move_to_device, prob_to_pred, + merge_objects, ) if importlib.util.find_spec("ipywidgets") is not None: @@ -356,7 +357,7 @@ def forward( # type: ignore prob_dict: Dict[str, Union[ndarray, List[ndarray]]] = ( defaultdict(list) if return_probs else None ) - out_dict: Dict[str, Dict[str, Union[ndarray, List]]] = ( + out_dict: Dict[str, Dict[str, Union[ndarray, List, Dict]]] = ( defaultdict(lambda: defaultdict(list)) if return_action_outputs else None ) @@ -403,8 +404,16 @@ def forward( # type: ignore and self.action_outputs[task_name] is not None ): for action_name, output_index in self.action_outputs[task_name]: + action_output = output_dict[action_name][output_index] + if isinstance(action_output, dict): + action_output = move_to_device(action_output, -1) + for key, value in action_output.items(): + action_output[key] = [value.detach().numpy()] + else: + action_output = action_output.cpu().detach().numpy() + out_dict[task_name][f"{action_name}_{output_index}"] = ( - output_dict[action_name][output_index].cpu().detach().numpy() + action_output ) if return_action_outputs: @@ -446,7 +455,7 @@ def predict( pred_dict: Dict[str, Union[ndarray, List[ndarray]]] = ( defaultdict(list) if return_preds else None ) - out_dict: Dict[str, Dict[str, List[Union[ndarray, int, float]]]] = ( + out_dict: Dict[str, Dict[str, Union[dict, List[Union[ndarray, int, float, dict]]]]] = ( defaultdict(lambda: defaultdict(list)) if return_action_outputs else None ) loss_dict: Dict[str, Union[ndarray, float]] = ( @@ 
-526,9 +535,14 @@ def predict( if return_action_outputs and out_bdict: for task_name in out_bdict.keys(): for action_name in out_bdict[task_name].keys(): - out_dict[task_name][action_name].extend( - out_bdict[task_name][action_name] - ) + if not out_dict[task_name][action_name]: + out_dict[task_name][action_name] = out_bdict[task_name][action_name] + else: + out_dict[task_name][action_name] = merge_objects( + out_dict[task_name][action_name], + out_bdict[task_name][action_name] + ) + # Calculate average loss if return_loss: From 2a11c331a8796766ae25a7b2ce2ca0e049927d17 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Thu, 28 Oct 2021 21:12:22 -0700 Subject: [PATCH 02/29] [add] merge_objects function --- src/emmental/utils/utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/emmental/utils/utils.py b/src/emmental/utils/utils.py index 6b89ecd..0651a4a 100644 --- a/src/emmental/utils/utils.py +++ b/src/emmental/utils/utils.py @@ -157,6 +157,23 @@ def move_to_device( else: return obj +def merge_objects(obj_1: Any, obj_2: Any): + if isinstance(obj_1, torch.Tensor): + return torch.cat([obj_1, obj_2]) + elif isinstance(obj_1, dict): + if not obj_1: return obj_2 + elif not obj_2: return obj_1 + else: + for key, value in obj_1.items(): + obj_1[key] = merge_objects(value, obj_2[key]) + return obj_1 + elif isinstance(obj_1, list): + obj_1.extend(obj_2) + return obj_1 + elif isinstance(obj_1, np.ndarray): + return np.append((obj_1, obj_2)) + else: + return obj_1 def array_to_numpy( array: Union[ndarray, List[Any], Tensor], flatten: bool = False From bcb2a304aa896e2149162ac7cb6ae6fe28617445 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Thu, 28 Oct 2021 23:03:26 -0700 Subject: [PATCH 03/29] [format] imports --- src/emmental/model.py | 147 ++++++++++++++++++++++++++++++------------ 1 file changed, 107 insertions(+), 40 deletions(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index 70f91fc..a430345 100644 --- a/src/emmental/model.py +++ 
b/src/emmental/model.py @@ -95,7 +95,9 @@ def _move_to_device(self) -> None: if device != torch.device("cpu"): if torch.cuda.is_available(): if Meta.config["meta_config"]["verbose"]: - logger.info(f"Moving {module_name} module to GPU ({device}).") + logger.info( + f"Moving {module_name} module to GPU ({device})." + ) self.module_pool[module_name].to(device) else: if Meta.config["meta_config"]["verbose"]: @@ -127,7 +129,9 @@ def _to_distributed_dataparallel(self) -> None: # TODO support multiple device with DistributedDataParallel for key in self.module_pool.keys(): # Ensure there is some gradient parameter for DDP - if not any(p.requires_grad for p in self.module_pool[key].parameters()): + if not any( + p.requires_grad for p in self.module_pool[key].parameters() + ): continue self.module_pool[ key @@ -138,7 +142,9 @@ def _to_distributed_dataparallel(self) -> None: find_unused_parameters=True, ) - def add_tasks(self, tasks: Union[EmmentalTask, List[EmmentalTask]]) -> None: + def add_tasks( + self, tasks: Union[EmmentalTask, List[EmmentalTask]] + ) -> None: """Build the MTL network using all tasks. Args: @@ -234,7 +240,9 @@ def remove_task(self, task_name: str) -> None: """ if task_name not in self.task_flows: if Meta.config["meta_config"]["verbose"]: - logger.info(f"Task ({task_name}) not in the current model, skip...") + logger.info( + f"Task ({task_name}) not in the current model, skip..." + ) return # Remove task by task_name @@ -257,7 +265,9 @@ def __repr__(self) -> str: cls_name = type(self).__name__ return f"{cls_name}(name={self.name})" - def flow(self, X_dict: Dict[str, Any], task_names: List[str]) -> Dict[str, Any]: + def flow( + self, X_dict: Dict[str, Any], task_names: List[str] + ) -> Dict[str, Any]: """Forward based on input and task flow. 
Note: @@ -291,19 +301,27 @@ def flow(self, X_dict: Dict[str, Any], task_names: List[str]) -> Dict[str, Any]: input = move_to_device( [ output_dict[action_name][output_index] - for action_name, output_index in action["inputs"] + for action_name, output_index in action[ + "inputs" + ] ], action_module_device, ) except Exception: raise ValueError(f"Unrecognized action {action}.") - output = self.module_pool[action["module"]].forward(*input) + output = self.module_pool[action["module"]].forward( + *input + ) else: # TODO: Handle multiple device with not inputs case - output = self.module_pool[action["module"]].forward(output_dict) + output = self.module_pool[action["module"]].forward( + output_dict + ) if isinstance(output, tuple): output = list(output) - if not isinstance(output, list) and not isinstance(output, dict): + if not isinstance(output, list) and not isinstance( + output, dict + ): output = [output] output_dict[action["name"]] = output @@ -350,7 +368,9 @@ def forward( # type: ignore all tasks. 
""" uid_dict: Dict[str, List[str]] = defaultdict(list) - loss_dict: Dict[str, Tensor] = defaultdict(Tensor) if return_loss else None + loss_dict: Dict[str, Tensor] = ( + defaultdict(Tensor) if return_loss else None + ) gold_dict: Dict[str, Union[ndarray, List[ndarray]]] = ( defaultdict(list) if Y_dict is not None else None ) @@ -358,7 +378,9 @@ def forward( # type: ignore defaultdict(list) if return_probs else None ) out_dict: Dict[str, Dict[str, Union[ndarray, List, Dict]]] = ( - defaultdict(lambda: defaultdict(list)) if return_action_outputs else None + defaultdict(lambda: defaultdict(list)) + if return_action_outputs + else None ) output_dict = self.flow(X_dict, list(task_to_label_dict.keys())) @@ -392,7 +414,10 @@ def forward( # type: ignore and self.output_funcs[task_name] is not None ): prob_dict[task_name] = ( - self.output_funcs[task_name](output_dict).cpu().detach().numpy() + self.output_funcs[task_name](output_dict) + .cpu() + .detach() + .numpy() ) if Y_dict is not None and label_name is not None: @@ -403,7 +428,9 @@ def forward( # type: ignore and task_name in self.action_outputs and self.action_outputs[task_name] is not None ): - for action_name, output_index in self.action_outputs[task_name]: + for action_name, output_index in self.action_outputs[ + task_name + ]: action_output = output_dict[action_name][output_index] if isinstance(action_output, dict): action_output = move_to_device(action_output, -1) @@ -412,9 +439,9 @@ def forward( # type: ignore else: action_output = action_output.cpu().detach().numpy() - out_dict[task_name][f"{action_name}_{output_index}"] = ( - action_output - ) + out_dict[task_name][ + f"{action_name}_{output_index}" + ] = action_output if return_action_outputs: return uid_dict, loss_dict, prob_dict, gold_dict, out_dict @@ -455,8 +482,12 @@ def predict( pred_dict: Dict[str, Union[ndarray, List[ndarray]]] = ( defaultdict(list) if return_preds else None ) - out_dict: Dict[str, Dict[str, Union[dict, List[Union[ndarray, int, float, 
dict]]]]] = ( - defaultdict(lambda: defaultdict(list)) if return_action_outputs else None + out_dict: Dict[ + str, Dict[str, Union[dict, List[Union[ndarray, int, float, dict]]]] + ] = ( + defaultdict(lambda: defaultdict(list)) + if return_action_outputs + else None ) loss_dict: Dict[str, Union[ndarray, float]] = ( defaultdict(list) if return_loss else None # type: ignore @@ -515,9 +546,9 @@ def predict( if len(loss_bdict[task_name].size()) == 0: if loss_dict[task_name] == []: loss_dict[task_name] = 0 - loss_dict[task_name] += loss_bdict[task_name].item() * len( - uid_bdict[task_name] - ) + loss_dict[task_name] += loss_bdict[ + task_name + ].item() * len(uid_bdict[task_name]) else: loss_dict[task_name].extend( # type: ignore loss_bdict[task_name].cpu().numpy() @@ -536,13 +567,16 @@ def predict( for task_name in out_bdict.keys(): for action_name in out_bdict[task_name].keys(): if not out_dict[task_name][action_name]: - out_dict[task_name][action_name] = out_bdict[task_name][action_name] + out_dict[task_name][action_name] = out_bdict[ + task_name + ][action_name] else: - out_dict[task_name][action_name] = merge_objects( - out_dict[task_name][action_name], - out_bdict[task_name][action_name] + out_dict[task_name][ + action_name + ] = merge_objects( + out_dict[task_name][action_name], + out_bdict[task_name][action_name], ) - # Calculate average loss if return_loss: @@ -603,8 +637,12 @@ def score( return_preds = False for task_name in dataloader.task_to_label_dict: - return_probs = return_probs or self.require_prob_for_evals[task_name] - return_preds = return_preds or self.require_pred_for_evals[task_name] + return_probs = ( + return_probs or self.require_prob_for_evals[task_name] + ) + return_preds = ( + return_preds or self.require_pred_for_evals[task_name] + ) predictions = self.predict( dataloader, @@ -620,14 +658,20 @@ def score( metric_score_dict[identifier] = np.mean( # type: ignore predictions["losses"][task_name] ) - 
macro_loss_dict[dataloader.split].append(metric_score_dict[identifier]) + macro_loss_dict[dataloader.split].append( + metric_score_dict[identifier] + ) # Store the task specific metric score if self.scorers[task_name]: metric_score = self.scorers[task_name].score( predictions["golds"][task_name], - predictions["probs"][task_name] if return_probs else None, - predictions["preds"][task_name] if return_preds else None, + predictions["probs"][task_name] + if return_probs + else None, + predictions["preds"][task_name] + if return_preds + else None, predictions["uids"][task_name], ) @@ -643,7 +687,10 @@ def score( if return_average: # Collect average score identifier = construct_identifier( - task_name, dataloader.data_name, dataloader.split, "average" + task_name, + dataloader.data_name, + dataloader.split, + "average", ) metric_score_dict[identifier] = np.mean( # type: ignore list(metric_score.values()) @@ -671,7 +718,9 @@ def score( macro_score_dict[split] ) for split in macro_loss_dict.keys(): - identifier = construct_identifier("model", "all", split, "loss") + identifier = construct_identifier( + "model", "all", split, "loss" + ) metric_score_dict[identifier] = np.mean( # type: ignore macro_loss_dict[split] ) @@ -682,19 +731,31 @@ def score( "model", "all", "all", "micro_average" ) metric_score_dict[identifier] = np.mean( # type: ignore - list(itertools.chain.from_iterable(micro_score_dict.values())) + list( + itertools.chain.from_iterable( + micro_score_dict.values() + ) + ) ) if macro_score_dict: identifier = construct_identifier( "model", "all", "all", "macro_average" ) metric_score_dict[identifier] = np.mean( # type: ignore - list(itertools.chain.from_iterable(macro_score_dict.values())) + list( + itertools.chain.from_iterable( + macro_score_dict.values() + ) + ) ) if macro_loss_dict: - identifier = construct_identifier("model", "all", "all", "loss") + identifier = construct_identifier( + "model", "all", "all", "loss" + ) metric_score_dict[identifier] = np.mean( 
# type: ignore - list(itertools.chain.from_iterable(macro_loss_dict.values())) + list( + itertools.chain.from_iterable(macro_loss_dict.values()) + ) ) # TODO: have a better to handle global evaluation metric @@ -763,9 +824,13 @@ def load( logger.error("Loading failed... Model does not exist.") try: - checkpoint = torch.load(model_path, map_location=torch.device("cpu")) + checkpoint = torch.load( + model_path, map_location=torch.device("cpu") + ) except BaseException: - logger.error(f"Loading failed... Cannot load model from {model_path}") + logger.error( + f"Loading failed... Cannot load model from {model_path}" + ) raise self.load_state_dict(checkpoint["model"]["module_pool"]) @@ -801,6 +866,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: # type: ignore module_state_dict ) else: - self.module_pool[module_name].load_state_dict(module_state_dict) + self.module_pool[module_name].load_state_dict( + module_state_dict + ) else: logger.info(f"Missing {module_name} in module_pool, skip it..") From 38b30c9d67b794239a8376a4f763acd0aa097103 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 09:45:06 -0700 Subject: [PATCH 04/29] [format] imports --- src/emmental/model.py | 137 +++++++++++------------------------------- 1 file changed, 36 insertions(+), 101 deletions(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index a430345..0ab750e 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -20,9 +20,9 @@ from emmental.utils.utils import ( array_to_numpy, construct_identifier, + merge_objects, move_to_device, prob_to_pred, - merge_objects, ) if importlib.util.find_spec("ipywidgets") is not None: @@ -95,9 +95,7 @@ def _move_to_device(self) -> None: if device != torch.device("cpu"): if torch.cuda.is_available(): if Meta.config["meta_config"]["verbose"]: - logger.info( - f"Moving {module_name} module to GPU ({device})." 
- ) + logger.info(f"Moving {module_name} module to GPU ({device}).") self.module_pool[module_name].to(device) else: if Meta.config["meta_config"]["verbose"]: @@ -129,9 +127,7 @@ def _to_distributed_dataparallel(self) -> None: # TODO support multiple device with DistributedDataParallel for key in self.module_pool.keys(): # Ensure there is some gradient parameter for DDP - if not any( - p.requires_grad for p in self.module_pool[key].parameters() - ): + if not any(p.requires_grad for p in self.module_pool[key].parameters()): continue self.module_pool[ key @@ -142,9 +138,7 @@ def _to_distributed_dataparallel(self) -> None: find_unused_parameters=True, ) - def add_tasks( - self, tasks: Union[EmmentalTask, List[EmmentalTask]] - ) -> None: + def add_tasks(self, tasks: Union[EmmentalTask, List[EmmentalTask]]) -> None: """Build the MTL network using all tasks. Args: @@ -240,9 +234,7 @@ def remove_task(self, task_name: str) -> None: """ if task_name not in self.task_flows: if Meta.config["meta_config"]["verbose"]: - logger.info( - f"Task ({task_name}) not in the current model, skip..." - ) + logger.info(f"Task ({task_name}) not in the current model, skip...") return # Remove task by task_name @@ -265,9 +257,7 @@ def __repr__(self) -> str: cls_name = type(self).__name__ return f"{cls_name}(name={self.name})" - def flow( - self, X_dict: Dict[str, Any], task_names: List[str] - ) -> Dict[str, Any]: + def flow(self, X_dict: Dict[str, Any], task_names: List[str]) -> Dict[str, Any]: """Forward based on input and task flow. 
Note: @@ -301,27 +291,19 @@ def flow( input = move_to_device( [ output_dict[action_name][output_index] - for action_name, output_index in action[ - "inputs" - ] + for action_name, output_index in action["inputs"] ], action_module_device, ) except Exception: raise ValueError(f"Unrecognized action {action}.") - output = self.module_pool[action["module"]].forward( - *input - ) + output = self.module_pool[action["module"]].forward(*input) else: # TODO: Handle multiple device with not inputs case - output = self.module_pool[action["module"]].forward( - output_dict - ) + output = self.module_pool[action["module"]].forward(output_dict) if isinstance(output, tuple): output = list(output) - if not isinstance(output, list) and not isinstance( - output, dict - ): + if not isinstance(output, list) and not isinstance(output, dict): output = [output] output_dict[action["name"]] = output @@ -368,9 +350,7 @@ def forward( # type: ignore all tasks. """ uid_dict: Dict[str, List[str]] = defaultdict(list) - loss_dict: Dict[str, Tensor] = ( - defaultdict(Tensor) if return_loss else None - ) + loss_dict: Dict[str, Tensor] = defaultdict(Tensor) if return_loss else None gold_dict: Dict[str, Union[ndarray, List[ndarray]]] = ( defaultdict(list) if Y_dict is not None else None ) @@ -378,9 +358,7 @@ def forward( # type: ignore defaultdict(list) if return_probs else None ) out_dict: Dict[str, Dict[str, Union[ndarray, List, Dict]]] = ( - defaultdict(lambda: defaultdict(list)) - if return_action_outputs - else None + defaultdict(lambda: defaultdict(list)) if return_action_outputs else None ) output_dict = self.flow(X_dict, list(task_to_label_dict.keys())) @@ -414,10 +392,7 @@ def forward( # type: ignore and self.output_funcs[task_name] is not None ): prob_dict[task_name] = ( - self.output_funcs[task_name](output_dict) - .cpu() - .detach() - .numpy() + self.output_funcs[task_name](output_dict).cpu().detach().numpy() ) if Y_dict is not None and label_name is not None: @@ -428,9 +403,7 @@ def 
forward( # type: ignore and task_name in self.action_outputs and self.action_outputs[task_name] is not None ): - for action_name, output_index in self.action_outputs[ - task_name - ]: + for action_name, output_index in self.action_outputs[task_name]: action_output = output_dict[action_name][output_index] if isinstance(action_output, dict): action_output = move_to_device(action_output, -1) @@ -439,9 +412,7 @@ def forward( # type: ignore else: action_output = action_output.cpu().detach().numpy() - out_dict[task_name][ - f"{action_name}_{output_index}" - ] = action_output + out_dict[task_name][f"{action_name}_{output_index}"] = action_output if return_action_outputs: return uid_dict, loss_dict, prob_dict, gold_dict, out_dict @@ -484,11 +455,7 @@ def predict( ) out_dict: Dict[ str, Dict[str, Union[dict, List[Union[ndarray, int, float, dict]]]] - ] = ( - defaultdict(lambda: defaultdict(list)) - if return_action_outputs - else None - ) + ] = (defaultdict(lambda: defaultdict(list)) if return_action_outputs else None) loss_dict: Dict[str, Union[ndarray, float]] = ( defaultdict(list) if return_loss else None # type: ignore ) @@ -546,9 +513,9 @@ def predict( if len(loss_bdict[task_name].size()) == 0: if loss_dict[task_name] == []: loss_dict[task_name] = 0 - loss_dict[task_name] += loss_bdict[ - task_name - ].item() * len(uid_bdict[task_name]) + loss_dict[task_name] += loss_bdict[task_name].item() * len( + uid_bdict[task_name] + ) else: loss_dict[task_name].extend( # type: ignore loss_bdict[task_name].cpu().numpy() @@ -567,13 +534,11 @@ def predict( for task_name in out_bdict.keys(): for action_name in out_bdict[task_name].keys(): if not out_dict[task_name][action_name]: - out_dict[task_name][action_name] = out_bdict[ - task_name - ][action_name] - else: - out_dict[task_name][ + out_dict[task_name][action_name] = out_bdict[task_name][ action_name - ] = merge_objects( + ] + else: + out_dict[task_name][action_name] = merge_objects( out_dict[task_name][action_name], 
out_bdict[task_name][action_name], ) @@ -637,12 +602,8 @@ def score( return_preds = False for task_name in dataloader.task_to_label_dict: - return_probs = ( - return_probs or self.require_prob_for_evals[task_name] - ) - return_preds = ( - return_preds or self.require_pred_for_evals[task_name] - ) + return_probs = return_probs or self.require_prob_for_evals[task_name] + return_preds = return_preds or self.require_pred_for_evals[task_name] predictions = self.predict( dataloader, @@ -658,20 +619,14 @@ def score( metric_score_dict[identifier] = np.mean( # type: ignore predictions["losses"][task_name] ) - macro_loss_dict[dataloader.split].append( - metric_score_dict[identifier] - ) + macro_loss_dict[dataloader.split].append(metric_score_dict[identifier]) # Store the task specific metric score if self.scorers[task_name]: metric_score = self.scorers[task_name].score( predictions["golds"][task_name], - predictions["probs"][task_name] - if return_probs - else None, - predictions["preds"][task_name] - if return_preds - else None, + predictions["probs"][task_name] if return_probs else None, + predictions["preds"][task_name] if return_preds else None, predictions["uids"][task_name], ) @@ -718,9 +673,7 @@ def score( macro_score_dict[split] ) for split in macro_loss_dict.keys(): - identifier = construct_identifier( - "model", "all", split, "loss" - ) + identifier = construct_identifier("model", "all", split, "loss") metric_score_dict[identifier] = np.mean( # type: ignore macro_loss_dict[split] ) @@ -731,31 +684,19 @@ def score( "model", "all", "all", "micro_average" ) metric_score_dict[identifier] = np.mean( # type: ignore - list( - itertools.chain.from_iterable( - micro_score_dict.values() - ) - ) + list(itertools.chain.from_iterable(micro_score_dict.values())) ) if macro_score_dict: identifier = construct_identifier( "model", "all", "all", "macro_average" ) metric_score_dict[identifier] = np.mean( # type: ignore - list( - itertools.chain.from_iterable( - 
macro_score_dict.values() - ) - ) + list(itertools.chain.from_iterable(macro_score_dict.values())) ) if macro_loss_dict: - identifier = construct_identifier( - "model", "all", "all", "loss" - ) + identifier = construct_identifier("model", "all", "all", "loss") metric_score_dict[identifier] = np.mean( # type: ignore - list( - itertools.chain.from_iterable(macro_loss_dict.values()) - ) + list(itertools.chain.from_iterable(macro_loss_dict.values())) ) # TODO: have a better to handle global evaluation metric @@ -824,13 +765,9 @@ def load( logger.error("Loading failed... Model does not exist.") try: - checkpoint = torch.load( - model_path, map_location=torch.device("cpu") - ) + checkpoint = torch.load(model_path, map_location=torch.device("cpu")) except BaseException: - logger.error( - f"Loading failed... Cannot load model from {model_path}" - ) + logger.error(f"Loading failed... Cannot load model from {model_path}") raise self.load_state_dict(checkpoint["model"]["module_pool"]) @@ -866,8 +803,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: # type: ignore module_state_dict ) else: - self.module_pool[module_name].load_state_dict( - module_state_dict - ) + self.module_pool[module_name].load_state_dict(module_state_dict) else: logger.info(f"Missing {module_name} in module_pool, skip it..") From cc214a083718066abe47bff996fdd724439f1394 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 10:14:11 -0700 Subject: [PATCH 05/29] [add] tests for merge_objects --- tests/utils/test_utils.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 7cc4319..92b9263 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -11,6 +11,7 @@ construct_identifier, convert_to_serializable_json, merge, + merge_objects, move_to_device, nullable_float, nullable_int, @@ -68,6 +69,24 @@ def test_move_to_device(caplog): assert move_to_device((torch.tensor([1, 
2]), torch.tensor([3, 4])), -1) +def test_merge_objects(caplog): + caplog.set_level(logging.INFO) + + assert torch.equal( + merge_objects(torch.Tensor([1, 2]), torch.Tensor([2, 3])), + torch.Tensor([1, 2, 2, 3]), + ) + assert merge_objects(np.array([1, 2]), np.array([2, 3])) == np.array([1, 2, 2, 3]) + assert merge_objects({"a": torch.Tensor([1, 2])}, {"a": torch.Tensor([2, 3])}) == { + "a": torch.Tensor([1, 2, 2, 3]) + } + + assert merge_objects({"a": [1, 2]}, {}) == {"a": [1, 2]} + assert merge_objects({}, {"a": [1, 2]}) == {"a": [1, 2]} + + assert merge_objects([1, 2, 3], [2, 3, 4]) == [1, 2, 3, 2, 3, 4] + + def test_array_to_numpy(caplog): """Unit test of array_to_numpy.""" caplog.set_level(logging.INFO) @@ -78,7 +97,8 @@ def test_array_to_numpy(caplog): ) assert ( np.array_equal( - array_to_numpy(torch.tensor([[1, 2], [3, 4]])), np.array([[1, 2], [3, 4]]) + array_to_numpy(torch.tensor([[1, 2], [3, 4]])), + np.array([[1, 2], [3, 4]]), ) is True ) From 8ea9b2cc3ece1e908bcf8cba6406f456612c0697 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 10:20:37 -0700 Subject: [PATCH 06/29] [reformat] utils.py --- src/emmental/utils/utils.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/emmental/utils/utils.py b/src/emmental/utils/utils.py index 0651a4a..b0cfa75 100644 --- a/src/emmental/utils/utils.py +++ b/src/emmental/utils/utils.py @@ -157,12 +157,15 @@ def move_to_device( else: return obj + def merge_objects(obj_1: Any, obj_2: Any): if isinstance(obj_1, torch.Tensor): return torch.cat([obj_1, obj_2]) elif isinstance(obj_1, dict): - if not obj_1: return obj_2 - elif not obj_2: return obj_1 + if not obj_1: + return obj_2 + elif not obj_2: + return obj_1 else: for key, value in obj_1.items(): obj_1[key] = merge_objects(value, obj_2[key]) @@ -175,6 +178,7 @@ def merge_objects(obj_1: Any, obj_2: Any): else: return obj_1 + def array_to_numpy( array: Union[ndarray, List[Any], Tensor], flatten: bool = False ) -> ndarray: 
@@ -203,7 +207,9 @@ def array_to_numpy( def merge( - x: Dict[str, Any], y: Dict[str, Any], specical_keys: Union[str, List[str]] = None + x: Dict[str, Any], + y: Dict[str, Any], + specical_keys: Union[str, List[str]] = None, ) -> Dict[str, Any]: """Merge two nested dictionaries. Overwrite values in x with values in y. @@ -326,7 +332,10 @@ def nullable_string(v: str) -> Optional[str]: def construct_identifier( - task_name: str, data_name: str, split_name: str, metric_name: Optional[str] = None + task_name: str, + data_name: str, + split_name: str, + metric_name: Optional[str] = None, ) -> str: """Construct identifier. From 5134f8ad8240007daeb56ea12888c389b06b7de5 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 10:25:48 -0700 Subject: [PATCH 07/29] [add] return type for merge_objects --- src/emmental/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/emmental/utils/utils.py b/src/emmental/utils/utils.py index b0cfa75..4effd3d 100644 --- a/src/emmental/utils/utils.py +++ b/src/emmental/utils/utils.py @@ -158,7 +158,7 @@ def move_to_device( return obj -def merge_objects(obj_1: Any, obj_2: Any): +def merge_objects(obj_1: Any, obj_2: Any) -> Any: if isinstance(obj_1, torch.Tensor): return torch.cat([obj_1, obj_2]) elif isinstance(obj_1, dict): From af110256b0ada53ff97af46593f76b2ec249f2d0 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 10:36:11 -0700 Subject: [PATCH 08/29] [add] docstring to merge_objects --- src/emmental/utils/utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/emmental/utils/utils.py b/src/emmental/utils/utils.py index 4effd3d..1f961ee 100644 --- a/src/emmental/utils/utils.py +++ b/src/emmental/utils/utils.py @@ -159,6 +159,18 @@ def move_to_device( def merge_objects(obj_1: Any, obj_2: Any) -> Any: + """Merges two objects of the same type + + Given two objects of the same type and structure, merges the second object + into the first object. 
+ + Args: + obj_1: first object + obj_2: seecond object to be merged into the first object + + Returns: + an object reflecting the merged output of the two inputs + """ if isinstance(obj_1, torch.Tensor): return torch.cat([obj_1, obj_2]) elif isinstance(obj_1, dict): From aeda146b6f7e8c470152a4f533129b3c458eadab Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 10:41:30 -0700 Subject: [PATCH 09/29] [add] docstring to merge_objects --- src/emmental/utils/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/emmental/utils/utils.py b/src/emmental/utils/utils.py index 1f961ee..6bfbb05 100644 --- a/src/emmental/utils/utils.py +++ b/src/emmental/utils/utils.py @@ -159,17 +159,17 @@ def move_to_device( def merge_objects(obj_1: Any, obj_2: Any) -> Any: - """Merges two objects of the same type + """Merges two objects of the same type. Given two objects of the same type and structure, merges the second object into the first object. Args: - obj_1: first object - obj_2: seecond object to be merged into the first object + obj_1: first object. + obj_2: second object to be merged into the first object. Returns: - an object reflecting the merged output of the two inputs + an object reflecting the merged output of the two inputs. """ if isinstance(obj_1, torch.Tensor): return torch.cat([obj_1, obj_2]) From 73643eae896c851440f5b7947080da78d9f7fc68 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 10:47:11 -0700 Subject: [PATCH 10/29] [add] change docstring mood in merge_objects --- src/emmental/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/emmental/utils/utils.py b/src/emmental/utils/utils.py index 6bfbb05..c6545cd 100644 --- a/src/emmental/utils/utils.py +++ b/src/emmental/utils/utils.py @@ -159,7 +159,7 @@ def move_to_device( def merge_objects(obj_1: Any, obj_2: Any) -> Any: - """Merges two objects of the same type. + """Merge two objects of the same type.
Given two objects of the same type and structure, merges the second object into the first object. From e548b77a9ff05e8efa7b0d469016dc21a6f38345 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 10:49:03 -0700 Subject: [PATCH 11/29] [reformat] --- src/emmental/contrib/slicing/modules/utils.py | 4 +- src/emmental/logging/checkpointer.py | 9 +--- src/emmental/meta.py | 5 +-- src/emmental/model.py | 20 ++------- src/emmental/utils/parse_args.py | 5 +-- src/emmental/utils/utils.py | 9 +--- tests/e2e/test_e2e.py | 4 +- tests/e2e/test_e2e_mixed.py | 10 +---- tests/e2e/test_e2e_no_y_dict.py | 20 +++------ tests/e2e/test_e2e_skip_trained.py | 44 +++++-------------- tests/task/test_task.py | 6 +-- tests/test_meta.py | 22 ++-------- tests/utils/test_utils.py | 3 +- 13 files changed, 35 insertions(+), 126 deletions(-) diff --git a/src/emmental/contrib/slicing/modules/utils.py b/src/emmental/contrib/slicing/modules/utils.py index 7f22483..e61b372 100644 --- a/src/emmental/contrib/slicing/modules/utils.py +++ b/src/emmental/contrib/slicing/modules/utils.py @@ -23,9 +23,7 @@ def ce_loss( Loss. 
""" return F.cross_entropy( - intermediate_output_dict[module_name][0], - Y.view(-1) - 1, - weight, + intermediate_output_dict[module_name][0], Y.view(-1) - 1, weight ) diff --git a/src/emmental/logging/checkpointer.py b/src/emmental/logging/checkpointer.py index 24285ff..c9bf02f 100644 --- a/src/emmental/logging/checkpointer.py +++ b/src/emmental/logging/checkpointer.py @@ -137,9 +137,7 @@ def checkpoint( # Save optimizer state optimizer_path = f"{self.checkpoint_path}/checkpoint_{iteration}.optimizer.pth" - optimizer_dict = { - "optimizer": optimizer.state_dict(), - } + optimizer_dict = {"optimizer": optimizer.state_dict()} torch.save(optimizer_dict, optimizer_path) # Save lr_scheduler state @@ -165,10 +163,7 @@ def checkpoint( f"{self.checkpoint_path}/best_model_" f"{metric.replace('/', '_')}.model.pth" ) - copyfile( - model_path, - best_metric_model_path, - ) + copyfile(model_path, best_metric_model_path) logger.info( f"Save best model of metric {metric} to {best_metric_model_path}" ) diff --git a/src/emmental/meta.py b/src/emmental/meta.py index c4149d3..52b91f8 100644 --- a/src/emmental/meta.py +++ b/src/emmental/meta.py @@ -251,10 +251,7 @@ def check_config() -> None: Meta.config["logging_config"]["evaluation_freq"] = new_evaluation_freq if ( - Meta.config["logging_config"]["counter_unit"] - in [ - "epoch", - ] + Meta.config["logging_config"]["counter_unit"] in ["epoch"] and isinstance(Meta.config["logging_config"]["evaluation_freq"], int) and Meta.config["logging_config"]["writer_config"]["write_loss_per_step"] ): diff --git a/src/emmental/model.py b/src/emmental/model.py index 0ab750e..6e2b490 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -379,8 +379,7 @@ def forward( # type: ignore loss_dict[task_name] = self.loss_funcs[task_name]( output_dict, move_to_device( - Y_dict[label_name], - Meta.config["model_config"]["device"], + Y_dict[label_name], Meta.config["model_config"]["device"] ) if Y_dict is not None and label_name is not None else 
None, @@ -549,11 +548,7 @@ def predict( if not isinstance(loss_dict[task_name], list): loss_dict[task_name] /= len(uid_dict[task_name]) - res = { - "uids": uid_dict, - "golds": gold_dict, - "losses": loss_dict, - } + res = {"uids": uid_dict, "golds": gold_dict, "losses": loss_dict} if return_probs: for task_name in prob_dict.keys(): @@ -642,10 +637,7 @@ def score( if return_average: # Collect average score identifier = construct_identifier( - task_name, - dataloader.data_name, - dataloader.split, - "average", + task_name, dataloader.data_name, dataloader.split, "average" ) metric_score_dict[identifier] = np.mean( # type: ignore list(metric_score.values()) @@ -750,11 +742,7 @@ def save( if Meta.config["meta_config"]["verbose"] and verbose: logger.info(f"[{self.name}] Model saved in {model_path}") - def load( - self, - model_path: str, - verbose: bool = True, - ) -> None: + def load(self, model_path: str, verbose: bool = True) -> None: """Load model state_dict from file and reinitialize the model weights. Args: diff --git a/src/emmental/utils/parse_args.py b/src/emmental/utils/parse_args.py index 54b8eb0..895190e 100644 --- a/src/emmental/utils/parse_args.py +++ b/src/emmental/utils/parse_args.py @@ -819,10 +819,7 @@ def parse_args(parser: Optional[ArgumentParser] = None) -> ArgumentParser: ) logging_config.add_argument( - "--wandb_run_name", - type=nullable_string, - default=None, - help="Wandb run name", + "--wandb_run_name", type=nullable_string, default=None, help="Wandb run name" ) logging_config.add_argument( diff --git a/src/emmental/utils/utils.py b/src/emmental/utils/utils.py index c6545cd..daa0c23 100644 --- a/src/emmental/utils/utils.py +++ b/src/emmental/utils/utils.py @@ -219,9 +219,7 @@ def array_to_numpy( def merge( - x: Dict[str, Any], - y: Dict[str, Any], - specical_keys: Union[str, List[str]] = None, + x: Dict[str, Any], y: Dict[str, Any], specical_keys: Union[str, List[str]] = None ) -> Dict[str, Any]: """Merge two nested dictionaries. 
Overwrite values in x with values in y. @@ -344,10 +342,7 @@ def nullable_string(v: str) -> Optional[str]: def construct_identifier( - task_name: str, - data_name: str, - split_name: str, - metric_name: Optional[str] = None, + task_name: str, data_name: str, split_name: str, metric_name: Optional[str] = None ) -> str: """Construct identifier. diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 8591aa8..24b895f 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -268,9 +268,7 @@ def forward(self, input): test2_pred = mtl_model.predict(test_dataloader2, return_action_outputs=True) test3_pred = mtl_model.predict( - test_dataloader3, - return_action_outputs=True, - return_loss=False, + test_dataloader3, return_action_outputs=True, return_loss=False ) assert test2_pred["uids"] == test3_pred["uids"] diff --git a/tests/e2e/test_e2e_mixed.py b/tests/e2e/test_e2e_mixed.py index 57ddccb..263d40f 100644 --- a/tests/e2e/test_e2e_mixed.py +++ b/tests/e2e/test_e2e_mixed.py @@ -117,10 +117,7 @@ def test_e2e_mixed(caplog): def ave_scorer(metric_score_dict): logger.info(metric_score_dict) - metric_names = [ - "task1/synthetic/test/loss", - "task2/synthetic/test/loss", - ] + metric_names = ["task1/synthetic/test/loss", "task2/synthetic/test/loss"] total = 0.0 cnt = 0 @@ -219,10 +216,7 @@ def forward(self, input): emmental_learner = EmmentalLearner() # Learning - emmental_learner.learn( - mtl_model, - [train_dataloader, dev_dataloader], - ) + emmental_learner.learn(mtl_model, [train_dataloader, dev_dataloader]) test_score = mtl_model.score(test_dataloader) diff --git a/tests/e2e/test_e2e_no_y_dict.py b/tests/e2e/test_e2e_no_y_dict.py index f1e9603..35a87c9 100644 --- a/tests/e2e/test_e2e_no_y_dict.py +++ b/tests/e2e/test_e2e_no_y_dict.py @@ -76,18 +76,15 @@ def test_e2e_no_y_dict(caplog): ) train_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_train, "label1": Y_train}, + name="synthetic", X_dict={"data": X_train, "label1": Y_train} ) 
dev_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_dev, "label1": Y_dev}, + name="synthetic", X_dict={"data": X_dev, "label1": Y_dev} ) test_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_test, "label1": Y_test}, + name="synthetic", X_dict={"data": X_test, "label1": Y_test} ) task_name = "task1" @@ -147,11 +144,7 @@ def forward(self, input): "module": "input_module0", "inputs": [("_input_", "data")], }, - { - "name": "input1", - "module": "input_module1", - "inputs": [("input", "out")], - }, + {"name": "input1", "module": "input_module1", "inputs": [("input", "out")]}, { "name": f"{task_name}_pred_head", "module": f"{task_name}_pred_head", @@ -173,10 +166,7 @@ def forward(self, input): emmental_learner = EmmentalLearner() # Learning - emmental_learner.learn( - mtl_model, - [train_dataloader, dev_dataloader], - ) + emmental_learner.learn(mtl_model, [train_dataloader, dev_dataloader]) test_score = mtl_model.score(test_dataloader) diff --git a/tests/e2e/test_e2e_skip_trained.py b/tests/e2e/test_e2e_skip_trained.py index 51bed23..b083cc2 100644 --- a/tests/e2e/test_e2e_skip_trained.py +++ b/tests/e2e/test_e2e_skip_trained.py @@ -51,21 +51,15 @@ def test_e2e_skip_trained_step(caplog): ) train_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_train}, - Y_dict={"label1": Y_train}, + name="synthetic", X_dict={"data": X_train}, Y_dict={"label1": Y_train} ) dev_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_dev}, - Y_dict={"label1": Y_dev}, + name="synthetic", X_dict={"data": X_dev}, Y_dict={"label1": Y_dev} ) test_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_test}, - Y_dict={"label1": Y_test}, + name="synthetic", X_dict={"data": X_test}, Y_dict={"label1": Y_test} ) task_to_label_dict = {"task1": "label1"} @@ -182,10 +176,7 @@ def forward(self, input): Meta.update_config(config) # Learning - emmental_learner.learn( - model, - [train_dataloader, dev_dataloader], - ) + 
emmental_learner.learn(model, [train_dataloader, dev_dataloader]) test_score = model.score(test_dataloader) @@ -236,10 +227,7 @@ def forward(self, input): model.load(Meta.config["model_config"]["model_path"]) # Learning - emmental_learner.learn( - model, - [train_dataloader, dev_dataloader], - ) + emmental_learner.learn(model, [train_dataloader, dev_dataloader]) test_score = model.score(test_dataloader) @@ -277,21 +265,15 @@ def test_e2e_skip_trained_epoch(caplog): ) train_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_train}, - Y_dict={"label1": Y_train}, + name="synthetic", X_dict={"data": X_train}, Y_dict={"label1": Y_train} ) dev_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_dev}, - Y_dict={"label1": Y_dev}, + name="synthetic", X_dict={"data": X_dev}, Y_dict={"label1": Y_dev} ) test_dataset = EmmentalDataset( - name="synthetic", - X_dict={"data": X_test}, - Y_dict={"label1": Y_test}, + name="synthetic", X_dict={"data": X_test}, Y_dict={"label1": Y_test} ) task_to_label_dict = {"task1": "label1"} @@ -412,10 +394,7 @@ def forward(self, input): Meta.update_config(config) # Learning - emmental_learner.learn( - model, - [train_dataloader, dev_dataloader], - ) + emmental_learner.learn(model, [train_dataloader, dev_dataloader]) test_score = model.score(test_dataloader) @@ -470,10 +449,7 @@ def forward(self, input): model.load(Meta.config["model_config"]["model_path"]) # Learning - emmental_learner.learn( - model, - [train_dataloader, dev_dataloader], - ) + emmental_learner.learn(model, [train_dataloader, dev_dataloader]) test_score = model.score(test_dataloader) diff --git a/tests/task/test_task.py b/tests/task/test_task.py index 497471a..d55ce8c 100644 --- a/tests/task/test_task.py +++ b/tests/task/test_task.py @@ -42,11 +42,7 @@ def output(module_name, output_dict): "module": "input_module0", "inputs": [("_input_", "data")], }, - { - "name": "input2", - "module": "input_module1", - "inputs": [("input1", 0)], - }, + {"name": 
"input2", "module": "input_module1", "inputs": [("input1", 0)]}, { "name": f"{task_name}_pred_head", "module": f"{task_name}_pred_head", diff --git a/tests/test_meta.py b/tests/test_meta.py index 83641de..79a28de 100644 --- a/tests/test_meta.py +++ b/tests/test_meta.py @@ -61,33 +61,19 @@ def test_config_check_in_meta(caplog): Meta.reset() init(dirpath) - config = { - "logging_config": { - "evaluation_freq": 5.0, - }, - } + config = {"logging_config": {"evaluation_freq": 5.0}} Meta.update_config(config) assert type(Meta.config["logging_config"]["evaluation_freq"]) == int assert Meta.config["logging_config"]["evaluation_freq"] == 5 - config = { - "logging_config": { - "counter_unit": "batch", - "evaluation_freq": 2.3, - }, - } + config = {"logging_config": {"counter_unit": "batch", "evaluation_freq": 2.3}} Meta.update_config(config) assert type(Meta.config["logging_config"]["evaluation_freq"]) == int assert Meta.config["logging_config"]["evaluation_freq"] == 3 - config = { - "logging_config": { - "counter_unit": "sample", - "evaluation_freq": 0.2, - }, - } + config = {"logging_config": {"counter_unit": "sample", "evaluation_freq": 0.2}} Meta.update_config(config) assert type(Meta.config["logging_config"]["evaluation_freq"]) == int @@ -98,7 +84,7 @@ def test_config_check_in_meta(caplog): "counter_unit": "epoch", "evaluation_freq": 1, "writer_config": {"write_loss_per_step": True}, - }, + } } Meta.update_config(config) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 92b9263..3506e29 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -97,8 +97,7 @@ def test_array_to_numpy(caplog): ) assert ( np.array_equal( - array_to_numpy(torch.tensor([[1, 2], [3, 4]])), - np.array([[1, 2], [3, 4]]), + array_to_numpy(torch.tensor([[1, 2], [3, 4]])), np.array([[1, 2], [3, 4]]) ) is True ) From 2ffebb7f7994abed92d69e8f2e8f3e5e33387168 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 10:55:49 -0700 Subject: [PATCH 12/29] 
[reformat] and add comment --- tests/utils/test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 3506e29..0060207 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -70,6 +70,7 @@ def test_move_to_device(caplog): def test_merge_objects(caplog): + """Unit test of merge_objects.""" caplog.set_level(logging.INFO) assert torch.equal( From dc427ce8170b38c43ffd63f5b63e24038864faae Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 15:29:57 -0700 Subject: [PATCH 13/29] [fix] merge objects utils and check for empty value --- src/emmental/model.py | 6 +++--- src/emmental/utils/utils.py | 19 ++++++++++++++----- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index 70f91fc..09693f4 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -408,7 +408,7 @@ def forward( # type: ignore if isinstance(action_output, dict): action_output = move_to_device(action_output, -1) for key, value in action_output.items(): - action_output[key] = [value.detach().numpy()] + action_output[key] = value.detach().numpy() else: action_output = action_output.cpu().detach().numpy() @@ -457,7 +457,7 @@ def predict( ) out_dict: Dict[str, Dict[str, Union[dict, List[Union[ndarray, int, float, dict]]]]] = ( defaultdict(lambda: defaultdict(list)) if return_action_outputs else None - ) + ) # HOW DO WE INFER Type loss_dict: Dict[str, Union[ndarray, float]] = ( defaultdict(list) if return_loss else None # type: ignore ) @@ -535,7 +535,7 @@ def predict( if return_action_outputs and out_bdict: for task_name in out_bdict.keys(): for action_name in out_bdict[task_name].keys(): - if not out_dict[task_name][action_name]: + if out_dict[task_name][action_name] == []: out_dict[task_name][action_name] = out_bdict[task_name][action_name] else: out_dict[task_name][action_name] = merge_objects( diff --git a/src/emmental/utils/utils.py 
b/src/emmental/utils/utils.py index 0651a4a..2fe7de1 100644 --- a/src/emmental/utils/utils.py +++ b/src/emmental/utils/utils.py @@ -165,13 +165,22 @@ def merge_objects(obj_1: Any, obj_2: Any): elif not obj_2: return obj_1 else: for key, value in obj_1.items(): - obj_1[key] = merge_objects(value, obj_2[key]) + if isinstance(value, torch.Tensor): + obj_1[key] = torch.cat([value, obj_2[key]]) + elif isinstance(value, np.ndarray): + obj_1[key] = np.concatenate((value, obj_2[key])) + elif isinstance(value, list): + obj_1[key].extend(obj_2[key]) return obj_1 - elif isinstance(obj_1, list): - obj_1.extend(obj_2) - return obj_1 elif isinstance(obj_1, np.ndarray): - return np.append((obj_1, obj_2)) + return np.concatenate((obj_1, obj_2)) + elif isinstance(obj_1, list): + return obj_1.extend(obj_2) + elif isinstance(obj_1, tuple): + idx_1, idx_2 = None, None + idx_1 = merge_objects(obj_1[0], obj_2[0]) + idx_2 = merge_objects(obj_1[1], obj_2[1]) + return (idx_1, idx_2) else: return obj_1 From b873a57d373cb30266d41804ba526fe4ead60fd4 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 15:31:52 -0700 Subject: [PATCH 14/29] [merge] w/upstream master --- src/emmental/model.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index f76056b..b1d5cdc 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -452,15 +452,9 @@ def predict( pred_dict: Dict[str, Union[ndarray, List[ndarray]]] = ( defaultdict(list) if return_preds else None ) -<<<<<<< HEAD - out_dict: Dict[str, Dict[str, Union[dict, List[Union[ndarray, int, float, dict]]]]] = ( - defaultdict(lambda: defaultdict(list)) if return_action_outputs else None - ) # HOW DO WE INFER Type -======= out_dict: Dict[ str, Dict[str, Union[dict, List[Union[ndarray, int, float, dict]]]] ] = (defaultdict(lambda: defaultdict(list)) if return_action_outputs else None) ->>>>>>> origin/flexibile-action-outputs loss_dict: Dict[str, Union[ndarray, float]] = ( 
defaultdict(list) if return_loss else None # type: ignore ) From fccc45cd6eb33c7dbfa3247f19c94d690cbf1598 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 15:54:03 -0700 Subject: [PATCH 15/29] [fix] merge_objects np.array test and merge of list --- src/emmental/model.py | 4 +++- src/emmental/utils/utils.py | 3 ++- tests/utils/test_utils.py | 10 ++++++---- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index b1d5cdc..0173bae 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -533,7 +533,9 @@ def predict( for task_name in out_bdict.keys(): for action_name in out_bdict[task_name].keys(): if out_dict[task_name][action_name] == []: - out_dict[task_name][action_name] = out_bdict[task_name][action_name] + out_dict[task_name][action_name] = out_bdict[task_name][ + action_name + ] else: out_dict[task_name][action_name] = merge_objects( out_dict[task_name][action_name], diff --git a/src/emmental/utils/utils.py b/src/emmental/utils/utils.py index 6013101..e4c0233 100644 --- a/src/emmental/utils/utils.py +++ b/src/emmental/utils/utils.py @@ -190,7 +190,8 @@ def merge_objects(obj_1: Any, obj_2: Any) -> Any: elif isinstance(obj_1, np.ndarray): return np.concatenate((obj_1, obj_2)) elif isinstance(obj_1, list): - return obj_1.extend(obj_2) + obj_1.extend(obj_2) + return obj_1 elif isinstance(obj_1, tuple): idx_1, idx_2 = None, None idx_1 = merge_objects(obj_1[0], obj_2[0]) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 0060207..9e41351 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -77,10 +77,12 @@ def test_merge_objects(caplog): merge_objects(torch.Tensor([1, 2]), torch.Tensor([2, 3])), torch.Tensor([1, 2, 2, 3]), ) - assert merge_objects(np.array([1, 2]), np.array([2, 3])) == np.array([1, 2, 2, 3]) - assert merge_objects({"a": torch.Tensor([1, 2])}, {"a": torch.Tensor([2, 3])}) == { - "a": torch.Tensor([1, 2, 2, 3]) - } + + assert 
np.array_equal( + merge_objects(np.array([1, 2]), np.array([2, 3])), np.array([1, 2, 2, 3]) + ) + + assert merge_objects({"a": [1, 2]}, {"a": [2, 3]}) == {"a": [1, 2, 2, 3]} assert merge_objects({"a": [1, 2]}, {}) == {"a": [1, 2]} assert merge_objects({}, {"a": [1, 2]}) == {"a": [1, 2]} From 1a1a6a3e3024e8c7d5c33b37524da0ea68d29e7d Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 15:59:09 -0700 Subject: [PATCH 16/29] [add] test for tuple type in merge objects --- tests/utils/test_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 9e41351..1878ad6 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -77,16 +77,16 @@ def test_merge_objects(caplog): merge_objects(torch.Tensor([1, 2]), torch.Tensor([2, 3])), torch.Tensor([1, 2, 2, 3]), ) - assert np.array_equal( merge_objects(np.array([1, 2]), np.array([2, 3])), np.array([1, 2, 2, 3]) ) - assert merge_objects({"a": [1, 2]}, {"a": [2, 3]}) == {"a": [1, 2, 2, 3]} - assert merge_objects({"a": [1, 2]}, {}) == {"a": [1, 2]} assert merge_objects({}, {"a": [1, 2]}) == {"a": [1, 2]} - + assert merge_objects(([2, 4], [3, 4]), ([3, 4], [4, 5])) == ( + [2, 4, 3, 4], + [3, 4, 4, 5], + ) assert merge_objects([1, 2, 3], [2, 3, 4]) == [1, 2, 3, 2, 3, 4] From 668271d344f6d1eb7d136e1aa202ec5fdfb35948 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 16:44:41 -0700 Subject: [PATCH 17/29] [add] fix input type --- src/emmental/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index 0173bae..5650d97 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -324,7 +324,7 @@ def forward( # type: ignore Dict[str, Tensor], Dict[str, Union[ndarray, List[ndarray]]], Dict[str, Union[ndarray, List[ndarray]]], - Dict[str, Dict[str, Union[ndarray, List]]], + Dict[str, Dict[str, Union[ndarray, List, Dict]]], ], Tuple[ Dict[str, 
List[str]], From 3f2b6e72f2e6eb737bf1af655c50ac28926b44f8 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 16:58:34 -0700 Subject: [PATCH 18/29] [add] fix input type --- src/emmental/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index 5650d97..bb42079 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -452,9 +452,9 @@ def predict( pred_dict: Dict[str, Union[ndarray, List[ndarray]]] = ( defaultdict(list) if return_preds else None ) - out_dict: Dict[ - str, Dict[str, Union[dict, List[Union[ndarray, int, float, dict]]]] - ] = (defaultdict(lambda: defaultdict(list)) if return_action_outputs else None) + out_dict: Dict[str, Dict[str, List[Union[ndarray, int, float, dict]]]] = ( + defaultdict(lambda: defaultdict(list)) if return_action_outputs else None + ) loss_dict: Dict[str, Union[ndarray, float]] = ( defaultdict(list) if return_loss else None # type: ignore ) From aeb61f8495ad6a4acc8c397a7d9f05c366a60c2c Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 17:15:51 -0700 Subject: [PATCH 19/29] [add] fix input type --- src/emmental/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index bb42079..a92afb4 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -452,7 +452,7 @@ def predict( pred_dict: Dict[str, Union[ndarray, List[ndarray]]] = ( defaultdict(list) if return_preds else None ) - out_dict: Dict[str, Dict[str, List[Union[ndarray, int, float, dict]]]] = ( + out_dict: Dict[str, Dict[str, List[Union[ndarray, int, float, Dict]]]] = ( defaultdict(lambda: defaultdict(list)) if return_action_outputs else None ) loss_dict: Dict[str, Union[ndarray, float]] = ( From a9a2588af75169a15e5b9ef889ac87515b9d6a91 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 17:57:26 -0700 Subject: [PATCH 20/29] [fix] type mismatch --- src/emmental/model.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index a92afb4..179cfbf 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -452,7 +452,7 @@ def predict( pred_dict: Dict[str, Union[ndarray, List[ndarray]]] = ( defaultdict(list) if return_preds else None ) - out_dict: Dict[str, Dict[str, List[Union[ndarray, int, float, Dict]]]] = ( + out_dict: Dict[str, Dict[str, Union[ndarray, int, float, Dict]]] = ( defaultdict(lambda: defaultdict(list)) if return_action_outputs else None ) loss_dict: Dict[str, Union[ndarray, float]] = ( From 94a748fd3b4e2a3e3aaca446430d71667e1b4d80 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 18:11:56 -0700 Subject: [PATCH 21/29] [fix] type mismatch --- src/emmental/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index 179cfbf..8eb1c93 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -324,7 +324,7 @@ def forward( # type: ignore Dict[str, Tensor], Dict[str, Union[ndarray, List[ndarray]]], Dict[str, Union[ndarray, List[ndarray]]], - Dict[str, Dict[str, Union[ndarray, List, Dict]]], + Dict[str, Dict[str, Union[ndarray, List, int, float, Dict]]], ], Tuple[ Dict[str, List[str]], @@ -357,7 +357,7 @@ def forward( # type: ignore prob_dict: Dict[str, Union[ndarray, List[ndarray]]] = ( defaultdict(list) if return_probs else None ) - out_dict: Dict[str, Dict[str, Union[ndarray, List, Dict]]] = ( + out_dict: Dict[str, Dict[str, Union[ndarray, List, int, float, Dict]]] = ( defaultdict(lambda: defaultdict(list)) if return_action_outputs else None ) @@ -453,7 +453,7 @@ def predict( defaultdict(list) if return_preds else None ) out_dict: Dict[str, Dict[str, Union[ndarray, int, float, Dict]]] = ( - defaultdict(lambda: defaultdict(list)) if return_action_outputs else None + defaultdict(lambda: defaultdict(dict)) if return_action_outputs else None ) loss_dict: Dict[str, Union[ndarray, 
float]] = ( defaultdict(list) if return_loss else None # type: ignore From d2efc22f7b7a30a771f99cdb9620996fcf6dde07 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Fri, 29 Oct 2021 18:13:46 -0700 Subject: [PATCH 22/29] [fix] type mismatch --- src/emmental/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index 8eb1c93..9fb58be 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -452,8 +452,8 @@ def predict( pred_dict: Dict[str, Union[ndarray, List[ndarray]]] = ( defaultdict(list) if return_preds else None ) - out_dict: Dict[str, Dict[str, Union[ndarray, int, float, Dict]]] = ( - defaultdict(lambda: defaultdict(dict)) if return_action_outputs else None + out_dict: Dict[str, Dict[str, Union[ndarray, List, int, float, Dict]]] = ( + defaultdict(lambda: defaultdict(list)) if return_action_outputs else None ) loss_dict: Dict[str, Union[ndarray, float]] = ( defaultdict(list) if return_loss else None # type: ignore From 4162e32cd96e48e0d20ee5ec85022d70608a56d4 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Sat, 30 Oct 2021 14:03:34 -0700 Subject: [PATCH 23/29] change move_object logic out of if/else clause in forward call --- src/emmental/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index 9fb58be..9641bbb 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -404,12 +404,12 @@ def forward( # type: ignore ): for action_name, output_index in self.action_outputs[task_name]: action_output = output_dict[action_name][output_index] + action_output = move_to_device(action_output, -1) if isinstance(action_output, dict): - action_output = move_to_device(action_output, -1) for key, value in action_output.items(): action_output[key] = value.detach().numpy() else: - action_output = action_output.cpu().detach().numpy() + action_output = action_output.detach().numpy() 
out_dict[task_name][f"{action_name}_{output_index}"] = action_output From 5c7e2de833fba0124961fab9348ecfb5a6af9b5b Mon Sep 17 00:00:00 2001 From: ANarayan Date: Sat, 30 Oct 2021 14:04:41 -0700 Subject: [PATCH 24/29] [add] more tests for array and tensor, and fix 1D merge edge case --- src/emmental/utils/utils.py | 66 ++++++++++++++++++++++++++----------- tests/utils/test_utils.py | 37 +++++++++++++++++++-- 2 files changed, 82 insertions(+), 21 deletions(-) diff --git a/src/emmental/utils/utils.py b/src/emmental/utils/utils.py index e4c0233..f4043e5 100644 --- a/src/emmental/utils/utils.py +++ b/src/emmental/utils/utils.py @@ -162,41 +162,69 @@ def merge_objects(obj_1: Any, obj_2: Any) -> Any: """Merge two objects of the same type. Given two objects of the same type and structure, merges the second object - into the first object. + into the first object. If either of the objects is empty, the non-empty + object is returned. Supported types include torch tensors, numpy arrays + lists, dicts and tuples. For tensors and arrays, objects are merged + along the 1st dimension: + + obj_1: torch.Tensor([1,2]), obj_2: torch.Tensor([2,3]) + merged object: torch.Tensor([[1,2],[2,3]]) Args: obj_1: first object. - obj_2: seecond object to be merged into the first object. + obj_2: second object to be merged into the first object. Returns: an object reflecting the merged output of the two inputs. """ + if type(obj_1) != type(obj_2): + raise TypeError( + f"Cannot merge object of type {type(obj_1)} " + f"with object of type {type(obj_2)}." 
+ ) if isinstance(obj_1, torch.Tensor): - return torch.cat([obj_1, obj_2]) - elif isinstance(obj_1, dict): - if not obj_1: + # empty edge case + if not obj_1.size()[0]: return obj_2 - elif not obj_2: - return obj_1 - else: - for key, value in obj_1.items(): - if isinstance(value, torch.Tensor): - obj_1[key] = torch.cat([value, obj_2[key]]) - elif isinstance(value, np.ndarray): - obj_1[key] = np.concatenate((value, obj_2[key])) - elif isinstance(value, list): - obj_1[key].extend(obj_2[key]) + elif not obj_2.size()[0]: return obj_1 + + # unsqueeze of object is 1D and not empty + if len(obj_1.shape) == 1: + obj_1 = obj_1.unsqueeze(0) + if len(obj_2.shape) == 1: + obj_2 = obj_2.unsqueeze(0) + return torch.cat([obj_1, obj_2]) elif isinstance(obj_1, np.ndarray): + # empty edge case + if not obj_1.size: + return obj_2 + elif not obj_2.size: + return obj_1 + + # expand if array has 1 dimension + if len(obj_1.shape) == 1: + obj_1 = np.expand_dims(obj_1, axis=0) + if len(obj_2.shape) == 1: + obj_2 = np.expand_dims(obj_2, axis=0) return np.concatenate((obj_1, obj_2)) elif isinstance(obj_1, list): obj_1.extend(obj_2) return obj_1 + elif isinstance(obj_1, dict): + if not obj_1: + return obj_2 + elif not obj_2: + return obj_1 + + for key, value in obj_1.items(): + obj_1[key] = merge_objects(value, obj_2[key]) + return obj_1 elif isinstance(obj_1, tuple): - idx_1, idx_2 = None, None - idx_1 = merge_objects(obj_1[0], obj_2[0]) - idx_2 = merge_objects(obj_1[1], obj_2[1]) - return (idx_1, idx_2) + merged_tuple_vals = [] + for idx in range(len(obj_1)): + merged_tuple_vals.append(merge_objects(obj_1[idx], obj_2[idx])) + return tuple(merged_tuple_vals) else: return obj_1 diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 1878ad6..9d8b93a 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -75,10 +75,21 @@ def test_merge_objects(caplog): assert torch.equal( merge_objects(torch.Tensor([1, 2]), torch.Tensor([2, 3])), - torch.Tensor([1, 2, 2, 
3]), + torch.Tensor([[1, 2], [2, 3]]), + ) + assert torch.equal( + merge_objects(torch.Tensor(), torch.Tensor([2, 3])), + torch.Tensor([2, 3]), + ) + assert merge_objects( + torch.zeros((128, 256)), torch.zeros((64, 256)) + ).shape == torch.Size([192, 256]) + + assert np.array_equal( + merge_objects(np.array([1, 2]), np.array([2, 3])), np.array([[1, 2], [2, 3]]) ) assert np.array_equal( - merge_objects(np.array([1, 2]), np.array([2, 3])), np.array([1, 2, 2, 3]) + merge_objects(np.array([]), np.array([2, 3])), np.array([2, 3]) ) assert merge_objects({"a": [1, 2]}, {"a": [2, 3]}) == {"a": [1, 2, 2, 3]} assert merge_objects({"a": [1, 2]}, {}) == {"a": [1, 2]} @@ -87,6 +98,28 @@ def test_merge_objects(caplog): [2, 4, 3, 4], [3, 4, 4, 5], ) + assert ( + torch.equal( + merge_objects( + (torch.Tensor([2]), torch.Tensor([2]), [2, 3]), + (torch.Tensor([3]), torch.Tensor([2]), [3, 4]), + )[0], + torch.Tensor([[2], [3]]), + ) + and torch.equal( + merge_objects( + (torch.Tensor([2]), torch.Tensor([2]), [2, 3]), + (torch.Tensor([3]), torch.Tensor([2]), [3, 4]), + )[1], + torch.Tensor([[2], [2]]), + ) + and merge_objects( + (torch.Tensor([2]), torch.Tensor([2]), [2, 3]), + (torch.Tensor([3]), torch.Tensor([2]), [3, 4]), + )[2] + == [2, 3, 3, 4] + ) + assert merge_objects([1, 2, 3], [2, 3, 4]) == [1, 2, 3, 2, 3, 4] From 0a9fd2d725973394206488f18745c2694be7d03a Mon Sep 17 00:00:00 2001 From: ANarayan Date: Mon, 1 Nov 2021 16:13:06 -0700 Subject: [PATCH 25/29] [add] detach and numpy convert to move_to_device function --- src/emmental/model.py | 20 +++++++------------- src/emmental/utils/utils.py | 12 ++++++++++-- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index 9641bbb..d4eda35 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -405,12 +405,6 @@ def forward( # type: ignore for action_name, output_index in self.action_outputs[task_name]: action_output = output_dict[action_name][output_index] 
action_output = move_to_device(action_output, -1) - if isinstance(action_output, dict): - for key, value in action_output.items(): - action_output[key] = value.detach().numpy() - else: - action_output = action_output.detach().numpy() - out_dict[task_name][f"{action_name}_{output_index}"] = action_output if return_action_outputs: @@ -533,13 +527,13 @@ def predict( for task_name in out_bdict.keys(): for action_name in out_bdict[task_name].keys(): if out_dict[task_name][action_name] == []: - out_dict[task_name][action_name] = out_bdict[task_name][ - action_name - ] - else: - out_dict[task_name][action_name] = merge_objects( - out_dict[task_name][action_name], - out_bdict[task_name][action_name], + out_dict[task_name][action_name] = ( + out_bdict[task_name][action_name] + if (out_dict[task_name][action_name] == []) + else merge_objects( + out_dict[task_name][action_name], + out_bdict[task_name][action_name], + ) ) # Calculate average loss diff --git a/src/emmental/utils/utils.py b/src/emmental/utils/utils.py index f4043e5..5633fc0 100644 --- a/src/emmental/utils/utils.py +++ b/src/emmental/utils/utils.py @@ -122,7 +122,10 @@ def pred_to_prob(preds: ndarray, n_classes: int) -> ndarray: def move_to_device( - obj: Any, device: Optional[Union[int, str, torch.device]] = -1 + obj: Any, + device: Optional[Union[int, str, torch.device]] = -1, + detach: bool = False, + convert_to_numpy: bool = False, ) -> Any: """Move object to specified device. 
@@ -147,7 +150,12 @@ def move_to_device( device = torch.device("cpu") if isinstance(obj, torch.Tensor): - return obj.to(device) + obj.to(device) + if detach: + obj = obj.detach() + if convert_to_numpy: + obj = obj.numpy() + return obj elif isinstance(obj, dict): return {key: move_to_device(value, device) for key, value in obj.items()} elif isinstance(obj, list): From edd56dd947733fee2bc7b9a72f562b214b5b04f1 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Mon, 1 Nov 2021 16:13:45 -0700 Subject: [PATCH 26/29] [add] e2e test for outputs which are dicts --- tests/e2e/test_e2e.py | 123 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 121 insertions(+), 2 deletions(-) diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 24b895f..9b9ff0b 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -102,6 +102,11 @@ def grouped_parameters(model): torch.tensor(Y2[int(0.8 * N) : int(0.9 * N)]), torch.tensor(Y2[int(0.9 * N) :]), ) + Y3_train, Y3_dev, Y3_test = ( + torch.tensor(Y2[: int(0.8 * N)]), + torch.tensor(Y2[int(0.8 * N) : int(0.9 * N)]), + torch.tensor(Y2[int(0.9 * N) :]), + ) train_dataset1 = EmmentalDataset( name="synthetic", X_dict={"data": X_train}, Y_dict={"label1": Y1_train} @@ -110,6 +115,9 @@ def grouped_parameters(model): train_dataset2 = EmmentalDataset( name="synthetic", X_dict={"data": X_train}, Y_dict={"label2": Y2_train} ) + train_dataset3 = EmmentalDataset( + name="synthetic", X_dict={"data": X_train}, Y_dict={"label3": Y3_train} + ) dev_dataset1 = EmmentalDataset( name="synthetic", X_dict={"data": X_dev}, Y_dict={"label1": Y1_dev} @@ -118,6 +126,9 @@ def grouped_parameters(model): dev_dataset2 = EmmentalDataset( name="synthetic", X_dict={"data": X_dev}, Y_dict={"label2": Y2_dev} ) + dev_dataset3 = EmmentalDataset( + name="synthetic", X_dict={"data": X_dev}, Y_dict={"label3": Y3_dev} + ) test_dataset1 = EmmentalDataset( name="synthetic", X_dict={"data": X_test}, Y_dict={"label1": Y1_test} @@ -126,6 +137,9 @@ def 
grouped_parameters(model): test_dataset2 = EmmentalDataset( name="synthetic", X_dict={"data": X_test}, Y_dict={"label2": Y2_test} ) + test_dataset4 = EmmentalDataset( + name="synthetic", X_dict={"data": X_test}, Y_dict={"label3": Y3_test} + ) test_dataset3 = EmmentalDataset(name="synthetic", X_dict={"data": X_test}) @@ -177,6 +191,26 @@ def grouped_parameters(model): split="test", batch_size=10, ) + task_to_label_dict = {"task3": "label3"} + + train_dataloader3 = EmmentalDataLoader( + task_to_label_dict=task_to_label_dict, + dataset=train_dataset3, + split="train", + batch_size=10, + ) + dev_dataloader3 = EmmentalDataLoader( + task_to_label_dict=task_to_label_dict, + dataset=dev_dataset3, + split="valid", + batch_size=10, + ) + test_dataloader4 = EmmentalDataLoader( + task_to_label_dict=task_to_label_dict, + dataset=test_dataset4, + split="test", + batch_size=10, + ) # Create task def ce_loss(task_name, immediate_output_dict, Y): @@ -187,7 +221,11 @@ def output(task_name, immediate_output_dict): module_name = f"{task_name}_pred_head" return F.softmax(immediate_output_dict[module_name][0], dim=1) - task_metrics = {"task1": ["accuracy"], "task2": ["accuracy", "roc_auc"]} + task_metrics = { + "task1": ["accuracy"], + "task2": ["accuracy", "roc_auc"], + "task3": ["accuracy"], + } class IdentityModule(nn.Module): def __init__(self): @@ -197,6 +235,33 @@ def __init__(self): def forward(self, input): return {"out": input} + class IdentityDictModule(nn.Module): + def __init__(self): + """Initialize IdentityModule.""" + super().__init__() + + def forward(self, input): + return {"out": {"image_pil": input}} + + class LinearLayerDictModule(nn.Module): + def __init__(self): + """Initialize IdentityModule.""" + super().__init__() + self.linear = nn.Linear(2, 8) + + def forward(self, input): + return {"out": {"image_pil": self.linear(input["image_pil"])}} + + class PredictHeadDictModule(nn.Module): + def __init__(self): + """Initialize IdentityModule.""" + super().__init__() + 
self.linear = nn.Linear(8, 2) + + def forward(self, input): + print(input) + return self.linear(input["image_pil"]) + tasks = [ EmmentalTask( name=task_name, @@ -240,6 +305,45 @@ def forward(self, input): ) for task_name in ["task1", "task2"] ] + tasks.append( + EmmentalTask( + name="task3", + module_pool=nn.ModuleDict( + { + "task3_input_module": IdentityDictModule(), + "task3_input_module1": LinearLayerDictModule(), + "task3_pred_head": PredictHeadDictModule(), + } + ), + task_flow=[ + { + "name": "input_t3", + "module": "task3_input_module", + "inputs": [("_input_", "data")], + }, + { + "name": "input1_t3", + "module": "task3_input_module1", + "inputs": [("input_t3", "out")], + }, + { + "name": "task3_pred_head", + "module": "task3_pred_head", + "inputs": [("input1_t3", "out")], + }, + ], + module_device={"task3_input_module": -1}, + loss_func=partial(ce_loss, "task3"), + output_func=partial(output, "task3"), + action_outputs=[ + ("task3_pred_head", 0), + ("_input_", "data"), + ("input1_t3", "out"), + ], + scorer=Scorer(metrics=task_metrics["task3"]), + require_prob_for_eval=True, + ) + ) # Build model mtl_model = EmmentalModel(name="all", tasks=tasks) @@ -252,7 +356,14 @@ def forward(self, input): # Learning emmental_learner.learn( mtl_model, - [train_dataloader1, train_dataloader2, dev_dataloader1, dev_dataloader2], + [ + train_dataloader1, + train_dataloader2, + train_dataloader3, + dev_dataloader1, + dev_dataloader2, + dev_dataloader3, + ], ) test1_score = mtl_model.score(test_dataloader1) @@ -270,6 +381,9 @@ def forward(self, input): test3_pred = mtl_model.predict( test_dataloader3, return_action_outputs=True, return_loss=False ) + test4_pred = mtl_model.predict( + test_dataloader4, return_action_outputs=True, return_loss=False + ) assert test2_pred["uids"] == test3_pred["uids"] assert False not in [ @@ -295,6 +409,11 @@ def forward(self, input): for idx in range(len(test2_pred["outputs"]["task2"]["_input__data"])) ] + assert 
len(test4_pred["outputs"]["task3"]["input1_t3_out"]["image_pil"]) == 10 + assert isinstance( + test4_pred["outputs"]["task3"]["input1_t3_out"]["image_pil"], np.ndarray + ) + test4_pred = mtl_model.predict(test_dataloader2, return_action_outputs=False) assert "outputs" not in test4_pred From 2010890ba1f168346233f5fd63fd969c7e3edbf1 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Mon, 1 Nov 2021 16:53:47 -0700 Subject: [PATCH 27/29] [fix] bug in recursive move_to_device calls --- src/emmental/model.py | 4 +++- src/emmental/utils/utils.py | 11 ++++++++--- tests/e2e/test_e2e.py | 1 - 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index d4eda35..8c0bee0 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -404,7 +404,9 @@ def forward( # type: ignore ): for action_name, output_index in self.action_outputs[task_name]: action_output = output_dict[action_name][output_index] - action_output = move_to_device(action_output, -1) + action_output = move_to_device( + action_output, -1, detach=True, convert_to_numpy=True + ) out_dict[task_name][f"{action_name}_{output_index}"] = action_output if return_action_outputs: diff --git a/src/emmental/utils/utils.py b/src/emmental/utils/utils.py index 5633fc0..47da582 100644 --- a/src/emmental/utils/utils.py +++ b/src/emmental/utils/utils.py @@ -157,11 +157,16 @@ def move_to_device( obj = obj.numpy() return obj elif isinstance(obj, dict): - return {key: move_to_device(value, device) for key, value in obj.items()} + return { + key: move_to_device(value, device, detach, convert_to_numpy) + for key, value in obj.items() + } elif isinstance(obj, list): - return [move_to_device(item, device) for item in obj] + return [move_to_device(item, device, detach, convert_to_numpy) for item in obj] elif isinstance(obj, tuple): - return tuple([move_to_device(item, device) for item in obj]) + return tuple( + [move_to_device(item, device, detach, convert_to_numpy) for item in obj] + ) 
else: return obj diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 9b9ff0b..0191d2f 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -259,7 +259,6 @@ def __init__(self): self.linear = nn.Linear(8, 2) def forward(self, input): - print(input) return self.linear(input["image_pil"]) tasks = [ From 7f76dbf539444e4e92e532c87d1fd5522f114286 Mon Sep 17 00:00:00 2001 From: ANarayan Date: Mon, 1 Nov 2021 23:46:43 -0700 Subject: [PATCH 28/29] [fix] add more tests for coverage, and add types func def --- src/emmental/model.py | 2 +- src/emmental/utils/utils.py | 2 +- tests/e2e/test_e2e.py | 3 +-- tests/utils/test_utils.py | 16 ++++++++++++++++ 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index 8c0bee0..94c836f 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -531,7 +531,7 @@ def predict( if out_dict[task_name][action_name] == []: out_dict[task_name][action_name] = ( out_bdict[task_name][action_name] - if (out_dict[task_name][action_name] == []) + if out_dict[task_name][action_name] == [] else merge_objects( out_dict[task_name][action_name], out_bdict[task_name][action_name], diff --git a/src/emmental/utils/utils.py b/src/emmental/utils/utils.py index 47da582..17501e2 100644 --- a/src/emmental/utils/utils.py +++ b/src/emmental/utils/utils.py @@ -122,7 +122,7 @@ def pred_to_prob(preds: ndarray, n_classes: int) -> ndarray: def move_to_device( - obj: Any, + obj: Union[Tensor, ndarray, dict, list, tuple], device: Optional[Union[int, str, torch.device]] = -1, detach: bool = False, convert_to_numpy: bool = False, diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 0191d2f..f07f397 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -407,8 +407,7 @@ def forward(self, input): ) for idx in range(len(test2_pred["outputs"]["task2"]["_input__data"])) ] - - assert len(test4_pred["outputs"]["task3"]["input1_t3_out"]["image_pil"]) == 10 + assert 
test4_pred["outputs"]["task3"]["input1_t3_out"]["image_pil"].shape == (10, 8) assert isinstance( test4_pred["outputs"]["task3"]["input1_t3_out"]["image_pil"], np.ndarray ) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 9d8b93a..c8656bc 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -81,6 +81,10 @@ def test_merge_objects(caplog): merge_objects(torch.Tensor(), torch.Tensor([2, 3])), torch.Tensor([2, 3]), ) + assert torch.equal( + merge_objects(torch.Tensor([2, 3]), torch.Tensor()), + torch.Tensor([2, 3]), + ) assert merge_objects( torch.zeros((128, 256)), torch.zeros((64, 256)) ).shape == torch.Size([192, 256]) @@ -91,6 +95,18 @@ def test_merge_objects(caplog): assert np.array_equal( merge_objects(np.array([]), np.array([2, 3])), np.array([2, 3]) ) + assert np.array_equal( + merge_objects(np.zeros((2, 3)), np.array([])), np.zeros((2, 3)) + ) + + assert 1 == merge_objects(1, 2) + + try: + merge_objects([], torch.Tensor(1)) + assert False + except (TypeError): + assert True + assert merge_objects({"a": [1, 2]}, {"a": [2, 3]}) == {"a": [1, 2, 2, 3]} assert merge_objects({"a": [1, 2]}, {}) == {"a": [1, 2]} assert merge_objects({}, {"a": [1, 2]}) == {"a": [1, 2]} From 044ed8c9bc14e539c7beb79c890b8e0b46dec3bd Mon Sep 17 00:00:00 2001 From: ANarayan Date: Wed, 3 Nov 2021 10:27:06 -0700 Subject: [PATCH 29/29] [fix] model action output dict assignment --- src/emmental/model.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/emmental/model.py b/src/emmental/model.py index 94c836f..ddac8a4 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -528,15 +528,14 @@ def predict( if return_action_outputs and out_bdict: for task_name in out_bdict.keys(): for action_name in out_bdict[task_name].keys(): - if out_dict[task_name][action_name] == []: - out_dict[task_name][action_name] = ( - out_bdict[task_name][action_name] - if out_dict[task_name][action_name] == [] - else merge_objects( 
- out_dict[task_name][action_name], - out_bdict[task_name][action_name], - ) + out_dict[task_name][action_name] = ( + out_bdict[task_name][action_name] + if out_dict[task_name][action_name] == [] + else merge_objects( + out_dict[task_name][action_name], + out_bdict[task_name][action_name], ) + ) # Calculate average loss if return_loss: