[WIP] DomainAwareNet supports regression #338
base: main
Codecov Report ❌
@@ Coverage Diff @@
## main #338 +/- ##
==========================================
- Coverage 96.19% 94.80% -1.39%
==========================================
Files 63 51 -12
Lines 7020 6223 -797
==========================================
- Hits 6753 5900 -853
- Misses 267 323 +56
antoinecollas
left a comment
Thx for the PR! It's on the right track!
The next steps are:
- I noticed many diffs that look like autoformatting. Do you have such a tool in your editor? You should revert these modifications.
- Add tests for the regressor
| """Modified CrossEntropyLoss as described in (29) from [35]_ with label smoothing. | ||
| """Modified CrossEntropyLoss as described in (29) from [35]_ with label | ||
| smoothing. |
Why is this docstring now on two lines?
Because the line seemed too long, do you want me to revert it?
| The modified CrossEntropyLoss with label smoothing applied to predictions. | ||
| The modified CrossEntropyLoss with label smoothing applied to | ||
| predictions. |
same
| _EMPTY_INT_ = torch.tensor([],dtype=torch.int64) | ||
| _EMPTY_INT_ = torch.tensor([], dtype=torch.int64) |
same
seems like a formatting typo
| _DEFAULT_SAMPLE_DOMAIN_ = 0 | ||
| _NO_LABEL_ = -1 | ||
|
|
||
|
|
same
| pass | ||
|
|
||
|
|
||
|
|
same
| A dictionary containing 'X', 'sample_domain', and optionally 'y' and 'sample_weight' as numpy arrays. | ||
| A dictionary containing 'X', 'sample_domain', and optionally 'y' | ||
| and 'sample_weight' as numpy arrays. |
same
skada/deep/base.py
Outdated
| """ | ||
| A domain-aware neural network classifier with sample weight support. | ||
| This class extends NeuralNetClassifier to handle domain-specific input data |
skorch.classifier.NeuralNetClassifier
| def predict_proba( | ||
| self, | ||
| X: Union[Dict, torch.Tensor, np.ndarray, Dataset], | ||
| sample_domain: Union[torch.Tensor, np.ndarray] = None, | ||
| sample_weight: Union[torch.Tensor, np.ndarray] = None, | ||
| allow_source: bool = False, | ||
| **predict_params | ||
| ): | ||
| return super().predict_DA_proba( | ||
| X, | ||
| sample_domain, | ||
| sample_weight, | ||
| allow_source, | ||
| **predict_params | ||
| ) |
I think you do not want to predict probabilities with a regressor, right?
Indeed, conceptually, we don't. In the case of a regressor, predict_proba returns the same as predict. I kept it this way to be consistent with the skorch.regressor API.
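To illustrate the point about predict_proba on a regressor, here is a minimal sketch in plain Python. ToyNet, ToyRegressor and ToyClassifier are hypothetical stand-ins, not skorch or skada classes; the sketch only assumes, as stated above, that a regressor's predict_proba returns the same values as predict, while a classifier's predict reduces the scores to a label.

```python
class ToyNet:
    """Hypothetical stand-in for a network wrapper (not the skorch API)."""

    def __init__(self, forward):
        # `forward` maps one sample to the raw network output.
        self.forward = forward

    def predict_proba(self, X):
        # Raw outputs of the network, one per sample.
        return [self.forward(x) for x in X]

    def predict(self, X):
        # By default, predictions are simply the raw outputs.
        return self.predict_proba(X)


class ToyRegressor(ToyNet):
    """Inherits both methods unchanged, so they return the same values."""


class ToyClassifier(ToyNet):
    def predict(self, X):
        # Pick the index of the largest score for each sample.
        return [
            max(range(len(scores)), key=scores.__getitem__)
            for scores in self.predict_proba(X)
        ]


reg = ToyRegressor(lambda x: 2.0 * x)
clf = ToyClassifier(lambda x: [1.0 - x, x])

print(reg.predict([1.0, 2.0]) == reg.predict_proba([1.0, 2.0]))  # True
print(clf.predict([0.2, 0.9]))  # [0, 1]
```

Under this assumption, keeping predict_proba on the regressor is harmless but redundant, which is the trade-off being discussed.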
| X = check_array(X, | ||
| ensure_2d=False, | ||
| allow_nd=True, | ||
| ensure_min_samples=0, | ||
| ensure_min_features=0, | ||
| ) | ||
| X = check_array( |
In the remainder of the file, many of your modifications are due only to autoformatting.
Not sure I understand why autoformatting would need to be reverted?
|
Also, please add the You can use the decorator |
| if base_criterion is None: | ||
| base_criterion = torch.nn.CrossEntropyLoss() | ||
|
|
||
| net = DomainAwareNet( |
For next week, we should have a parameter to choose between classification, regression, or binary classification.
The reasoning was to have one class per task, and users would call one class depending on their use case. This parameter option seems nice though :)
|
|
||
|
|
||
| class DomainAwareNet(NeuralNetClassifier, _DAMetadataRequesterMixin): | ||
| class _DomainAwareNet(NeuralNet, _DAMetadataRequesterMixin): |
I think we don't need the underscore here. It could be used by users.
Not sure when the user would need to call _DomainAwareNet. The reasoning is that the user would call one of DomainAwareNetClassifier, DomainAwareNetBinaryClassifier, or DomainAwareNetRegressor, depending on their intended use. These three classes inherit from _DomainAwareNet to avoid redundancy.
skada/deep/base.py
Outdated
| return nonlin | ||
|
|
||
| def predict_proba( | ||
| def predict_DA_proba( |
Yes, I think you should remove it and add it only to DomainAwareNetClassifier
Done. I didn't know this decorator, thanks for the discovery!
Based on the comments, it seems like a couple of choices need to be made.
All options seem good, let me know what you want :)
DomainAwareNet implements a classifier. The goal is to implement regression too.
To do so, I replaced DomainAwareNet with a parent class _DomainAwareNet and two classes, DomainAwareNetRegressor and DomainNetClassifier, that inherit from _DomainAwareNet.
This is complementary to Arthur's work on Issue 249. He will add a DomainNetBinaryClassifier, which will also inherit from _DomainAwareNet.
Done:
- Marked DomainAwareNet as deprecated.
- Replaced calls to DomainAwareNet by calls to DomainAwareNetClassifier in the module.
To do:
- Add tests for DomainAwareNetRegressor.
Bonus: fixed tests in test_utils.py. They were too restrictive on the error message and caused the tests to fail, although the code had the expected behavior.
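A minimal sketch of the deprecation step listed under "Done", assuming the old name is kept as a thin subclass that warns on construction. It uses only the standard library warnings module; the PR itself may rely on a different mechanism (for instance a decorator), and the class bodies here are placeholders.

```python
import warnings


class DomainAwareNetClassifier:
    """Placeholder for the new classifier class."""

    def __init__(self, module=None):
        self.module = module


class DomainAwareNet(DomainAwareNetClassifier):
    """Old name kept for backward compatibility (illustrative only)."""

    def __init__(self, *args, **kwargs):
        # Warn at the caller's location, then behave like the new class.
        warnings.warn(
            "DomainAwareNet is deprecated; use DomainAwareNetClassifier instead.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)


# Record warnings so the deprecation can be checked explicitly.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    net = DomainAwareNet()

print(isinstance(net, DomainAwareNetClassifier))  # True
print(caught[0].category.__name__)  # FutureWarning
```

Existing code that still constructs DomainAwareNet keeps working under this scheme, while the warning points users to the replacement class.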