[model/predict] Add support for different types of action outputs #113
[model/predict] Add support for different types of action outputs #113
Conversation
Codecov Report
@@ Coverage Diff @@
## master #113 +/- ##
==========================================
+ Coverage 91.21% 91.31% +0.10%
==========================================
Files 40 40
Lines 1991 2039 +48
Branches 425 446 +21
==========================================
+ Hits 1816 1862 +46
Misses 101 101
- Partials 74 76 +2
Flags with carried forward coverage won't be shown. Click here to find out more.
|
src/emmental/model.py
Outdated
| output_dict[action_name][output_index].cpu().detach().numpy() | ||
| ) | ||
| action_output = output_dict[action_name][output_index] | ||
| if isinstance(action_output, dict): |
There was a problem hiding this comment.
Can we move this if-else to move_to_device?
src/emmental/model.py
Outdated
| out_dict[task_name][action_name].extend( | ||
| out_bdict[task_name][action_name] | ||
| ) | ||
| if out_dict[task_name][action_name] == []: |
There was a problem hiding this comment.
Can we move this to merge_objects?
There was a problem hiding this comment.
problem is that out_dict[task_name][action_name] is always going to be a list, but out_bdict[task_name][action_name] can be any type. Thus, the merge_objects function will just raise a type error because of type mismatch
src/emmental/utils/utils.py
Outdated
|
|
||
| Args: | ||
| obj_1: first object. | ||
| obj_2: seecond object to be merged into the first object. |
| Args: | ||
| obj_1: first object. | ||
| obj_2: seecond object to be merged into the first object. | ||
|
|
There was a problem hiding this comment.
Can you add more description about this function here?
| Returns: | ||
| an object reflecting the merged output of the two inputs. | ||
| """ | ||
| if isinstance(obj_1, torch.Tensor): |
There was a problem hiding this comment.
We need check the two objects are the same type right?
src/emmental/utils/utils.py
Outdated
| return obj_1 | ||
| else: | ||
| for key, value in obj_1.items(): | ||
| if isinstance(value, torch.Tensor): |
There was a problem hiding this comment.
@senwu isn't the type of a np.array <class 'numpy.ndarray'>
| elif isinstance(obj_1, list): | ||
| obj_1.extend(obj_2) | ||
| return obj_1 | ||
| elif isinstance(obj_1, tuple): |
There was a problem hiding this comment.
Tuple might have more than 2 objects.
tests/utils/test_utils.py
Outdated
|
|
||
| assert torch.equal( | ||
| merge_objects(torch.Tensor([1, 2]), torch.Tensor([2, 3])), | ||
| torch.Tensor([1, 2, 2, 3]), |
There was a problem hiding this comment.
This should be torch.Tensor[[1, 2], [2, 3]]? We want to merge based on the first dim.
tests/utils/test_utils.py
Outdated
| torch.Tensor([1, 2, 2, 3]), | ||
| ) | ||
| assert np.array_equal( | ||
| merge_objects(np.array([1, 2]), np.array([2, 3])), np.array([1, 2, 2, 3]) |
| [3, 4, 4, 5], | ||
| ) | ||
| assert merge_objects([1, 2, 3], [2, 3, 4]) == [1, 2, 3, 2, 3, 4] | ||
|
|
There was a problem hiding this comment.
can we have more use cases about np.array, torch.Tensor?
src/emmental/model.py
Outdated
| out_dict[task_name][action_name].extend( | ||
| out_bdict[task_name][action_name] | ||
| ) | ||
| if out_dict[task_name][action_name] == []: |
|
|
||
| assert len(test4_pred["outputs"]["task3"]["input1_t3_out"]["image_pil"]) == 10 | ||
| assert isinstance( | ||
| test4_pred["outputs"]["task3"]["input1_t3_out"]["image_pil"], np.ndarray |
There was a problem hiding this comment.
for #113 (comment), we can't get rid of it because out_dict[task_name][action_name] and out_bdict[task_name][action_name] may not be of the same type for the empty check to be handled by merge_objects
This PR adds support processing action outputs of different types. The current implementation expects that all action outputs are of type tensor. Here, we expand support for outputs of type dict. This PR also adds a new utils function (
merge_objects) which merges different types of objects.