-
Notifications
You must be signed in to change notification settings - Fork 17
[model/predict] Add support for different types of action outputs #113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ANarayan
wants to merge
30
commits into
main
Choose a base branch
from
flexibile-action-outputs
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
2868e70
[add] support for action outputs of type dict
ANarayan 2a11c33
[add] merge_objects function
ANarayan bcb2a30
[format] imports
ANarayan 38b30c9
[format] imports
ANarayan cc214a0
[add] tests for merge_objects
ANarayan 8ea9b2c
[reformat] utils.py
ANarayan 5134f8a
[add] return type for merge_objects
ANarayan af11025
[add] docstring to merge_objects
ANarayan aeda146
[add] docstring to merge_objects
ANarayan 73643ea
[add] change docstring mood inmerge_objects
ANarayan e548b77
[reformat]
ANarayan 2ffebb7
[reformat] and add comment
ANarayan dc427ce
[fix] merge objects utils and check for empty value
ANarayan f1af3cc
[merge] w/upstream master
ANarayan b873a57
[merge] w/upstream master
ANarayan fccc45c
[fix] merge_objects np.array test and merge of list
ANarayan 1a1a6a3
[add] test for tuple type in merge objects
ANarayan 668271d
[add] fix input type
ANarayan 3f2b6e7
[add] fix input type
ANarayan aeb61f8
[add] fix input type
ANarayan a9a2588
[fix] type mismatch
ANarayan 94a748f
[fix] type mismatch
ANarayan d2efc22
[fix] type mismatch
ANarayan 4162e32
change move_object logic out of if/else clause in forward call
ANarayan 5c7e2de
[add] more tests for array and tensor, and fix 1D merge edge case
ANarayan 0a9fd2d
[add] detach and numpy convert to move_to_device function
ANarayan edd56dd
[add] e2e test for outputs which are dicts
ANarayan 2010890
[fix] bug in recursive move_to_device calls
ANarayan 7f76dbf
[fix] add more tests for coverage, and add types func def
ANarayan 044ed8c
[fix] model action output dict assignment
ANarayan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -122,7 +122,10 @@ def pred_to_prob(preds: ndarray, n_classes: int) -> ndarray: | |
|
|
||
|
|
||
| def move_to_device( | ||
| obj: Any, device: Optional[Union[int, str, torch.device]] = -1 | ||
| obj: Union[Tensor, ndarray, dict, list, tuple], | ||
| device: Optional[Union[int, str, torch.device]] = -1, | ||
| detach: bool = False, | ||
| convert_to_numpy: bool = False, | ||
| ) -> Any: | ||
| """Move object to specified device. | ||
|
|
||
|
|
@@ -147,17 +150,98 @@ def move_to_device( | |
| device = torch.device("cpu") | ||
|
|
||
| if isinstance(obj, torch.Tensor): | ||
| return obj.to(device) | ||
| obj.to(device) | ||
| if detach: | ||
| obj = obj.detach() | ||
| if convert_to_numpy: | ||
| obj = obj.numpy() | ||
| return obj | ||
| elif isinstance(obj, dict): | ||
| return {key: move_to_device(value, device) for key, value in obj.items()} | ||
| return { | ||
| key: move_to_device(value, device, detach, convert_to_numpy) | ||
| for key, value in obj.items() | ||
| } | ||
| elif isinstance(obj, list): | ||
| return [move_to_device(item, device) for item in obj] | ||
| return [move_to_device(item, device, detach, convert_to_numpy) for item in obj] | ||
| elif isinstance(obj, tuple): | ||
| return tuple([move_to_device(item, device) for item in obj]) | ||
| return tuple( | ||
| [move_to_device(item, device, detach, convert_to_numpy) for item in obj] | ||
| ) | ||
| else: | ||
| return obj | ||
|
|
||
|
|
||
| def merge_objects(obj_1: Any, obj_2: Any) -> Any: | ||
| """Merge two objects of the same type. | ||
|
|
||
| Given two objects of the same type and structure, merges the second object | ||
| into the first object. If either of the objects is empty, the non-empty | ||
| object is returned. Supported types include torch tensors, numpy arrays | ||
| lists, dicts and tuples. For tensors and arrays, objects are merged | ||
| along the 1st dimension: | ||
|
|
||
| obj_1: torch.Tensor([1,2]), obj_2: torch.Tensor([2,3]) | ||
| merged object: torch.Tensor([[1,2],[2,3]]) | ||
|
|
||
| Args: | ||
| obj_1: first object. | ||
| obj_2: second object to be merged into the first object. | ||
|
|
||
| Returns: | ||
| an object reflecting the merged output of the two inputs. | ||
| """ | ||
| if type(obj_1) != type(obj_2): | ||
| raise TypeError( | ||
| f"Cannot merge object of type {type(obj_1)} " | ||
| f"with object of type {type(obj_2)}." | ||
| ) | ||
| if isinstance(obj_1, torch.Tensor): | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need check the two objects are the same type right? |
||
| # empty edge case | ||
| if not obj_1.size()[0]: | ||
| return obj_2 | ||
| elif not obj_2.size()[0]: | ||
| return obj_1 | ||
|
|
||
| # unsqueeze of object is 1D and not empty | ||
| if len(obj_1.shape) == 1: | ||
| obj_1 = obj_1.unsqueeze(0) | ||
| if len(obj_2.shape) == 1: | ||
| obj_2 = obj_2.unsqueeze(0) | ||
| return torch.cat([obj_1, obj_2]) | ||
| elif isinstance(obj_1, np.ndarray): | ||
| # empty edge case | ||
| if not obj_1.size: | ||
| return obj_2 | ||
| elif not obj_2.size: | ||
| return obj_1 | ||
|
|
||
| # expand if array has 1 dimension | ||
| if len(obj_1.shape) == 1: | ||
| obj_1 = np.expand_dims(obj_1, axis=0) | ||
| if len(obj_2.shape) == 1: | ||
| obj_2 = np.expand_dims(obj_2, axis=0) | ||
| return np.concatenate((obj_1, obj_2)) | ||
| elif isinstance(obj_1, list): | ||
| obj_1.extend(obj_2) | ||
| return obj_1 | ||
| elif isinstance(obj_1, dict): | ||
| if not obj_1: | ||
| return obj_2 | ||
| elif not obj_2: | ||
| return obj_1 | ||
|
|
||
| for key, value in obj_1.items(): | ||
| obj_1[key] = merge_objects(value, obj_2[key]) | ||
| return obj_1 | ||
| elif isinstance(obj_1, tuple): | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tuple might have more than 2 objects. |
||
| merged_tuple_vals = [] | ||
| for idx in range(len(obj_1)): | ||
| merged_tuple_vals.append(merge_objects(obj_1[idx], obj_2[idx])) | ||
| return tuple(merged_tuple_vals) | ||
| else: | ||
| return obj_1 | ||
|
|
||
|
|
||
| def array_to_numpy( | ||
| array: Union[ndarray, List[Any], Tensor], flatten: bool = False | ||
| ) -> ndarray: | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add more description about this function here?