Contributor

@vloison vloison commented Jun 25, 2025

DomainAwareNet implements a classifier. The goal is to implement regression too.

To do so, I replaced DomainAwareNet with a parent class _DomainAwareNet and two subclasses, DomainAwareNetRegressor and DomainAwareNetClassifier, that inherit from it.
This is complementary to Arthur's work on Issue 249. He will add a DomainAwareNetBinaryClassifier, which will also inherit from _DomainAwareNet.

Done:

  • Split DomainAwareNet into three classes.
  • Mark DomainAwareNet as deprecated.
  • Replace calls to DomainAwareNet with calls to DomainAwareNetClassifier in the module.

To do:

  • Add a script to try out DomainAwareNetRegressor.

Bonus: fixed tests in test_utils.py. They were too strict about the error message and failed even though the code had the expected behavior.
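
To illustrate the kind of fix mentioned in the bonus item, here is a sketch of a test that checks the exception type and a stable fragment of the message rather than its full text. The function `check_domain` and its message are hypothetical, and the standard-library `unittest` is used only to keep the example self-contained; the project's actual tests may be written differently.

```python
import unittest


def check_domain(sample_domain):
    """Hypothetical validator whose error message may be reworded over time."""
    if sample_domain is None:
        raise ValueError("sample_domain must be provided (got None); see the docs.")
    return sample_domain


class TestCheckDomain(unittest.TestCase):
    def test_missing_domain(self):
        # Match only a stable fragment of the message, not the exact wording,
        # so that harmless rephrasing does not break the test.
        with self.assertRaisesRegex(ValueError, "sample_domain"):
            check_domain(None)


# Run the test case programmatically and collect the result.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestCheckDomain)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

The test still fails if the wrong exception type is raised or if the message no longer mentions the offending argument.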


codecov bot commented Jun 25, 2025

Codecov Report

❌ Patch coverage is 84.64730% with 37 lines in your changes missing coverage. Please review.
✅ Project coverage is 94.80%. Comparing base (8ccfd75) to head (4d2b5af).
⚠️ Report is 2 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #338      +/-   ##
==========================================
- Coverage   96.19%   94.80%   -1.39%     
==========================================
  Files          63       51      -12     
  Lines        7020     6223     -797     
==========================================
- Hits         6753     5900     -853     
- Misses        267      323      +56     

@antoinecollas antoinecollas self-requested a review June 26, 2025 08:11
Collaborator

@antoinecollas antoinecollas left a comment

Thx for the PR! It's on the right track!

The next steps are:

  1. I noticed many diffs that look like autoformatting. Do you have such a tool in your editor? You should revert these modifications.
  2. Add tests for the regressor

Comment on lines -377 to +378
- """Modified CrossEntropyLoss as described in (29) from [35]_ with label smoothing.
+ """Modified CrossEntropyLoss as described in (29) from [35]_ with label
+ smoothing.
Collaborator

Why is this docstring now on two lines?

Contributor Author

Because the line seemed too long. Do you want me to revert it?

Comment on lines -404 to +406
- The modified CrossEntropyLoss with label smoothing applied to predictions.
+ The modified CrossEntropyLoss with label smoothing applied to
+ predictions.
Collaborator

same

Comment on lines -32 to +33
- _EMPTY_INT_ = torch.tensor([],dtype=torch.int64)
+ _EMPTY_INT_ = torch.tensor([], dtype=torch.int64)
Collaborator

same

Contributor Author

seems like a formatting typo

_DEFAULT_SAMPLE_DOMAIN_ = 0
_NO_LABEL_ = -1


Collaborator

same

pass



Collaborator

same

Comment on lines 827 to 780
- A dictionary containing 'X', 'sample_domain', and optionally 'y' and 'sample_weight' as numpy arrays.
+ A dictionary containing 'X', 'sample_domain', and optionally 'y'
+ and 'sample_weight' as numpy arrays.
Collaborator

same

"""
A domain-aware neural network classifier with sample weight support.
This class extends NeuralNetClassifier to handle domain-specific input data
Collaborator

skorch.classifier.NeuralNetClassifier

Comment on lines +1009 to +1023
def predict_proba(
    self,
    X: Union[Dict, torch.Tensor, np.ndarray, Dataset],
    sample_domain: Union[torch.Tensor, np.ndarray] = None,
    sample_weight: Union[torch.Tensor, np.ndarray] = None,
    allow_source: bool = False,
    **predict_params
):
    return super().predict_DA_proba(
        X,
        sample_domain,
        sample_weight,
        allow_source,
        **predict_params
    )
Collaborator

I think you do not want to predict probabilities with a regressor, right?

Contributor Author

Indeed, conceptually, we don't. In the case of a regressor, predict_proba returns the same as predict. I kept it this way to be consistent with the skorch.regressor API.

Comment on lines 1025 to 1244
- X = check_array(X,
-     ensure_2d=False,
-     allow_nd=True,
-     ensure_min_samples=0,
-     ensure_min_features=0,
- )
+ X = check_array(
Collaborator

In the rest of the file, many of the modifications are due only to autoformatting.

Contributor Author

Not sure I understand why autoformatting would need to be reverted?

@antoinecollas antoinecollas changed the title from "[ENH] DomainAwareNet supports regression" to "[WIP] DomainAwareNet supports regression" on Jun 26, 2025
@antoinecollas
Collaborator

Also, please add the DomainAwareNet class, wrapping DomainAwareNetClassifier and marking it as deprecated, for consistency with previous releases.

You can use the decorator @deprecated to do this and put a warning in the docstring.
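
A deprecated wrapper along these lines could look like the sketch below. The `deprecated` decorator defined here is a hypothetical standard-library stand-in written for illustration, since the comment does not say which decorator is meant, and the stub `DomainAwareNetClassifier` replaces the real class.

```python
import functools
import warnings


def deprecated(message):
    """Hypothetical class decorator: warn whenever the class is instantiated."""

    def decorator(cls):
        original_init = cls.__init__

        @functools.wraps(original_init)
        def wrapped_init(self, *args, **kwargs):
            # stacklevel=2 points the warning at the caller's line.
            warnings.warn(message, FutureWarning, stacklevel=2)
            original_init(self, *args, **kwargs)

        # Only the decorated class gets the warning; the parent is untouched.
        cls.__init__ = wrapped_init
        return cls

    return decorator


class DomainAwareNetClassifier:
    """Stub standing in for the real classifier."""

    def __init__(self, module=None):
        self.module = module


@deprecated("DomainAwareNet is deprecated; use DomainAwareNetClassifier instead.")
class DomainAwareNet(DomainAwareNetClassifier):
    """Deprecated: kept for backward compatibility.

    Use DomainAwareNetClassifier instead.
    """


# Instantiate the old name and record the warning it emits.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    net = DomainAwareNet(module="my_module")
```

Existing code that instantiates `DomainAwareNet` keeps working and gets a `FutureWarning`, while `DomainAwareNetClassifier` itself stays silent.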

if base_criterion is None:
base_criterion = torch.nn.CrossEntropyLoss()

net = DomainAwareNet(
Collaborator

For next week, we should have a parameter to choose between classification, regression, or binary classification.

Contributor Author

The reasoning was to have one class per task, and users would call one class depending on their use case. This parameter option seems nice though :)



- class DomainAwareNet(NeuralNetClassifier, _DAMetadataRequesterMixin):
+ class _DomainAwareNet(NeuralNet, _DAMetadataRequesterMixin):
Collaborator

I think we don't need the underscore here. It could be used by users.

Contributor Author

Not sure when the user would need to call _DomainAwareNet. The reasoning is that the user would call one of DomainAwareNetClassifier, DomainAwareNetBinaryClassifier, or DomainAwareNetRegressor, depending on their intended use. These three classes inherit from _DomainAwareNet to avoid redundancy.

return nonlin

- def predict_proba(
+ def predict_DA_proba(
Collaborator

Yes, I think you should remove it and add it only to DomainAwareNetClassifier.

Contributor Author

vloison commented Sep 2, 2025

> Also, please add the DomainAwareNet class, wrapping DomainAwareNetClassifier and marking it as deprecated, for consistency with previous releases.
>
> You can use the decorator @deprecated to do this and put a warning in the docstring.

Done. I didn't know about this decorator, thanks for pointing it out!

Contributor Author

vloison commented Sep 2, 2025

Based on the comments, it seems like a couple of choices need to be made.

  • Do we want to follow the skorch API (implemented now) or remove predict_proba from the regressor?
  • Do we want the user to call a different class based on their use case, like in skorch (implemented now), or to call a single class with a parameter for the use case? I think the second option will probably be more code-heavy, but maybe more user-friendly.

All options seem good, let me know what you want :)
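
For reference, the second option (a single entry point with a task parameter) could be sketched as a thin factory on top of the per-task classes, which would keep both calling styles available. The function name `make_domain_aware_net`, the `task` values, and the stub classes are all hypothetical; a binary classifier could be added to the mapping once it exists.

```python
class DomainAwareNetClassifier:
    """Stub standing in for the real classifier."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class DomainAwareNetRegressor:
    """Stub standing in for the real regressor."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


# Hypothetical mapping from task name to the class that handles it.
_TASK_TO_CLASS = {
    "classification": DomainAwareNetClassifier,
    "regression": DomainAwareNetRegressor,
}


def make_domain_aware_net(task="classification", **kwargs):
    """Return the net matching `task`, forwarding all other arguments."""
    try:
        net_class = _TASK_TO_CLASS[task]
    except KeyError:
        raise ValueError(
            f"Unknown task {task!r}; expected one of {sorted(_TASK_TO_CLASS)}."
        ) from None
    return net_class(**kwargs)


net = make_domain_aware_net(task="regression", max_epochs=10)
print(type(net).__name__)  # DomainAwareNetRegressor
```

With this layout the extra code for the parameter-based option stays small, because the per-task classes still do all the work.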
