Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
2868e70
[add] support for action outputs of type dict
ANarayan Oct 29, 2021
2a11c33
[add] merge_objects function
ANarayan Oct 29, 2021
bcb2a30
[format] imports
ANarayan Oct 29, 2021
38b30c9
[format] imports
ANarayan Oct 29, 2021
cc214a0
[add] tests for merge_objects
ANarayan Oct 29, 2021
8ea9b2c
[reformat] utils.py
ANarayan Oct 29, 2021
5134f8a
[add] return type for merge_objects
ANarayan Oct 29, 2021
af11025
[add] docstring to merge_objects
ANarayan Oct 29, 2021
aeda146
[add] docstring to merge_objects
ANarayan Oct 29, 2021
73643ea
[add] change docstring mood in merge_objects
ANarayan Oct 29, 2021
e548b77
[reformat]
ANarayan Oct 29, 2021
2ffebb7
[reformat] and add comment
ANarayan Oct 29, 2021
dc427ce
[fix] merge objects utils and check for empty value
ANarayan Oct 29, 2021
f1af3cc
[merge] w/upstream master
ANarayan Oct 29, 2021
b873a57
[merge] w/upstream master
ANarayan Oct 29, 2021
fccc45c
[fix] merge_objects np.array test and merge of list
ANarayan Oct 29, 2021
1a1a6a3
[add] test for tuple type in merge objects
ANarayan Oct 29, 2021
668271d
[add] fix input type
ANarayan Oct 29, 2021
3f2b6e7
[add] fix input type
ANarayan Oct 29, 2021
aeb61f8
[add] fix input type
ANarayan Oct 30, 2021
a9a2588
[fix] type mismatch
ANarayan Oct 30, 2021
94a748f
[fix] type mismatch
ANarayan Oct 30, 2021
d2efc22
[fix] type mismatch
ANarayan Oct 30, 2021
4162e32
change move_object logic out of if/else clause in forward call
ANarayan Oct 30, 2021
5c7e2de
[add] more tests for array and tensor, and fix 1D merge edge case
ANarayan Oct 30, 2021
0a9fd2d
[add] detach and numpy convert to move_to_device function
ANarayan Nov 1, 2021
edd56dd
[add] e2e test for outputs which are dicts
ANarayan Nov 1, 2021
2010890
[fix] bug in recursive move_to_device calls
ANarayan Nov 1, 2021
7f76dbf
[fix] add more tests for coverage, and add types func def
ANarayan Nov 2, 2021
044ed8c
[fix] model action output dict assignment
ANarayan Nov 3, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/emmental/contrib/slicing/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ def ce_loss(
Loss.
"""
return F.cross_entropy(
intermediate_output_dict[module_name][0],
Y.view(-1) - 1,
weight,
intermediate_output_dict[module_name][0], Y.view(-1) - 1, weight
)


Expand Down
9 changes: 2 additions & 7 deletions src/emmental/logging/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,7 @@ def checkpoint(

# Save optimizer state
optimizer_path = f"{self.checkpoint_path}/checkpoint_{iteration}.optimizer.pth"
optimizer_dict = {
"optimizer": optimizer.state_dict(),
}
optimizer_dict = {"optimizer": optimizer.state_dict()}
torch.save(optimizer_dict, optimizer_path)

# Save lr_scheduler state
Expand All @@ -165,10 +163,7 @@ def checkpoint(
f"{self.checkpoint_path}/best_model_"
f"{metric.replace('/', '_')}.model.pth"
)
copyfile(
model_path,
best_metric_model_path,
)
copyfile(model_path, best_metric_model_path)
logger.info(
f"Save best model of metric {metric} to {best_metric_model_path}"
)
Expand Down
5 changes: 1 addition & 4 deletions src/emmental/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,7 @@ def check_config() -> None:
Meta.config["logging_config"]["evaluation_freq"] = new_evaluation_freq

if (
Meta.config["logging_config"]["counter_unit"]
in [
"epoch",
]
Meta.config["logging_config"]["counter_unit"] in ["epoch"]
and isinstance(Meta.config["logging_config"]["evaluation_freq"], int)
and Meta.config["logging_config"]["writer_config"]["write_loss_per_step"]
):
Expand Down
35 changes: 17 additions & 18 deletions src/emmental/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from emmental.utils.utils import (
array_to_numpy,
construct_identifier,
merge_objects,
move_to_device,
prob_to_pred,
)
Expand Down Expand Up @@ -323,7 +324,7 @@ def forward( # type: ignore
Dict[str, Tensor],
Dict[str, Union[ndarray, List[ndarray]]],
Dict[str, Union[ndarray, List[ndarray]]],
Dict[str, Dict[str, Union[ndarray, List]]],
Dict[str, Dict[str, Union[ndarray, List, int, float, Dict]]],
],
Tuple[
Dict[str, List[str]],
Expand Down Expand Up @@ -356,7 +357,7 @@ def forward( # type: ignore
prob_dict: Dict[str, Union[ndarray, List[ndarray]]] = (
defaultdict(list) if return_probs else None
)
out_dict: Dict[str, Dict[str, Union[ndarray, List]]] = (
out_dict: Dict[str, Dict[str, Union[ndarray, List, int, float, Dict]]] = (
defaultdict(lambda: defaultdict(list)) if return_action_outputs else None
)

Expand All @@ -378,8 +379,7 @@ def forward( # type: ignore
loss_dict[task_name] = self.loss_funcs[task_name](
output_dict,
move_to_device(
Y_dict[label_name],
Meta.config["model_config"]["device"],
Y_dict[label_name], Meta.config["model_config"]["device"]
)
if Y_dict is not None and label_name is not None
else None,
Expand All @@ -403,9 +403,11 @@ def forward( # type: ignore
and self.action_outputs[task_name] is not None
):
for action_name, output_index in self.action_outputs[task_name]:
out_dict[task_name][f"{action_name}_{output_index}"] = (
output_dict[action_name][output_index].cpu().detach().numpy()
action_output = output_dict[action_name][output_index]
action_output = move_to_device(
action_output, -1, detach=True, convert_to_numpy=True
)
out_dict[task_name][f"{action_name}_{output_index}"] = action_output

if return_action_outputs:
return uid_dict, loss_dict, prob_dict, gold_dict, out_dict
Expand Down Expand Up @@ -446,7 +448,7 @@ def predict(
pred_dict: Dict[str, Union[ndarray, List[ndarray]]] = (
defaultdict(list) if return_preds else None
)
out_dict: Dict[str, Dict[str, List[Union[ndarray, int, float]]]] = (
out_dict: Dict[str, Dict[str, Union[ndarray, List, int, float, Dict]]] = (
defaultdict(lambda: defaultdict(list)) if return_action_outputs else None
)
loss_dict: Dict[str, Union[ndarray, float]] = (
Expand Down Expand Up @@ -526,8 +528,13 @@ def predict(
if return_action_outputs and out_bdict:
for task_name in out_bdict.keys():
for action_name in out_bdict[task_name].keys():
out_dict[task_name][action_name].extend(
out_dict[task_name][action_name] = (
out_bdict[task_name][action_name]
if out_dict[task_name][action_name] == []
else merge_objects(
out_dict[task_name][action_name],
out_bdict[task_name][action_name],
)
)

# Calculate average loss
Expand All @@ -536,11 +543,7 @@ def predict(
if not isinstance(loss_dict[task_name], list):
loss_dict[task_name] /= len(uid_dict[task_name])

res = {
"uids": uid_dict,
"golds": gold_dict,
"losses": loss_dict,
}
res = {"uids": uid_dict, "golds": gold_dict, "losses": loss_dict}

if return_probs:
for task_name in prob_dict.keys():
Expand Down Expand Up @@ -734,11 +737,7 @@ def save(
if Meta.config["meta_config"]["verbose"] and verbose:
logger.info(f"[{self.name}] Model saved in {model_path}")

def load(
self,
model_path: str,
verbose: bool = True,
) -> None:
def load(self, model_path: str, verbose: bool = True) -> None:
"""Load model state_dict from file and reinitialize the model weights.

Args:
Expand Down
5 changes: 1 addition & 4 deletions src/emmental/utils/parse_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,10 +819,7 @@ def parse_args(parser: Optional[ArgumentParser] = None) -> ArgumentParser:
)

logging_config.add_argument(
"--wandb_run_name",
type=nullable_string,
default=None,
help="Wandb run name",
"--wandb_run_name", type=nullable_string, default=None, help="Wandb run name"
)

logging_config.add_argument(
Expand Down
94 changes: 89 additions & 5 deletions src/emmental/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,10 @@ def pred_to_prob(preds: ndarray, n_classes: int) -> ndarray:


def move_to_device(
obj: Any, device: Optional[Union[int, str, torch.device]] = -1
obj: Union[Tensor, ndarray, dict, list, tuple],
device: Optional[Union[int, str, torch.device]] = -1,
detach: bool = False,
convert_to_numpy: bool = False,
) -> Any:
"""Move object to specified device.

Expand All @@ -147,17 +150,98 @@ def move_to_device(
device = torch.device("cpu")

if isinstance(obj, torch.Tensor):
return obj.to(device)
obj.to(device)
if detach:
obj = obj.detach()
if convert_to_numpy:
obj = obj.numpy()
return obj
elif isinstance(obj, dict):
return {key: move_to_device(value, device) for key, value in obj.items()}
return {
key: move_to_device(value, device, detach, convert_to_numpy)
for key, value in obj.items()
}
elif isinstance(obj, list):
return [move_to_device(item, device) for item in obj]
return [move_to_device(item, device, detach, convert_to_numpy) for item in obj]
elif isinstance(obj, tuple):
return tuple([move_to_device(item, device) for item in obj])
return tuple(
[move_to_device(item, device, detach, convert_to_numpy) for item in obj]
)
else:
return obj


def merge_objects(obj_1: Any, obj_2: Any) -> Any:
"""Merge two objects of the same type.

Given two objects of the same type and structure, merges the second object
into the first object. If either of the objects is empty, the non-empty
object is returned. Supported types include torch tensors, numpy arrays
lists, dicts and tuples. For tensors and arrays, objects are merged
along the 1st dimension:

obj_1: torch.Tensor([1,2]), obj_2: torch.Tensor([2,3])
merged object: torch.Tensor([[1,2],[2,3]])

Args:
obj_1: first object.
obj_2: second object to be merged into the first object.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add more description about this function here?

Returns:
an object reflecting the merged output of the two inputs.
"""
if type(obj_1) != type(obj_2):
raise TypeError(
f"Cannot merge object of type {type(obj_1)} "
f"with object of type {type(obj_2)}."
)
if isinstance(obj_1, torch.Tensor):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need check the two objects are the same type right?

# empty edge case
if not obj_1.size()[0]:
return obj_2
elif not obj_2.size()[0]:
return obj_1

# unsqueeze of object is 1D and not empty
if len(obj_1.shape) == 1:
obj_1 = obj_1.unsqueeze(0)
if len(obj_2.shape) == 1:
obj_2 = obj_2.unsqueeze(0)
return torch.cat([obj_1, obj_2])
elif isinstance(obj_1, np.ndarray):
# empty edge case
if not obj_1.size:
return obj_2
elif not obj_2.size:
return obj_1

# expand if array has 1 dimension
if len(obj_1.shape) == 1:
obj_1 = np.expand_dims(obj_1, axis=0)
if len(obj_2.shape) == 1:
obj_2 = np.expand_dims(obj_2, axis=0)
return np.concatenate((obj_1, obj_2))
elif isinstance(obj_1, list):
obj_1.extend(obj_2)
return obj_1
elif isinstance(obj_1, dict):
if not obj_1:
return obj_2
elif not obj_2:
return obj_1

for key, value in obj_1.items():
obj_1[key] = merge_objects(value, obj_2[key])
return obj_1
elif isinstance(obj_1, tuple):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tuple might have more than 2 objects.

merged_tuple_vals = []
for idx in range(len(obj_1)):
merged_tuple_vals.append(merge_objects(obj_1[idx], obj_2[idx]))
return tuple(merged_tuple_vals)
else:
return obj_1


def array_to_numpy(
array: Union[ndarray, List[Any], Tensor], flatten: bool = False
) -> ndarray:
Expand Down
Loading