Enhance Variational Bayesian Last Layers implementation #3067
Conversation
Enhanced the Variational Bayesian Last Layers implementation with consistent use of torch, improved numerical stability, and added convenience helpers. Updated typing and docstrings for clarity.
Thanks for all the improvements! I left some comments, but this basically looks good to me pending tests passing.
```
Original: vbll (https://github.com/VectorInstitute/vbll), MIT license.
Paper: "Variational Bayesian Last Layers" by Harrison et al., ICLR 2024
Enhancements:
```
Are these enhancements between the new version and the vbll original, or between the new version and the old BoTorch version? Or are the vbll original and the BoTorch version the same?
```diff
 from dataclasses import dataclass
-from typing import Callable
+from typing import Callable, Union, Optional
```
Nit: In all the Python versions that BoTorch supports, you can use `x | y` instead of `Union[x, y]` and `x | None` instead of `Optional[x]`.
```diff
 def __init__(self, loc: Tensor, chol: Tensor):
-    """Normal distribution.
-    """
+    """
+    Diagonal Gaussian wrapper. 'scale' is interpreted as the std-dev vector.
```
Thanks for the docstring!
```python
# ensure shape broadcastability but keep behavior identical to torch.Normal
super().__init__(loc, scale)
```
I'm confused by this comment. It looks to me like there's no need for an init method at all since it's just calling the super class's init.
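A minimal sketch of that point, with plain classes standing in for the torch distribution (class names here are illustrative):

```python
class Base:
    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

# Redundant: this __init__ only forwards its arguments to the super class.
class WrapperVerbose(Base):
    def __init__(self, loc, scale):
        super().__init__(loc, scale)

# Equivalent and cleaner: with no __init__ defined, Base.__init__ is inherited.
class Wrapper(Base):
    """Diagonal Gaussian wrapper; 'scale' is the std-dev vector."""

w = Wrapper(0.0, 1.0)
print(w.loc, w.scale)  # 0.0 1.0
```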
| """ | ||
| Variational Bayesian Linear Regression | ||
| Parameters |
It would be nice to keep these in, although BoTorch uses google-style docstrings with an "Args:" block so that things are rendered correctly by Sphinx
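A sketch of the google-style format with an `Args:` block that Sphinx renders correctly; the function and parameter names here are illustrative, not taken from the PR:

```python
def predictive_mean(x, w_mean):
    """Compute the predictive mean of the last layer.

    Args:
        x: Input features, a sequence of floats of length `in_features`.
        w_mean: Mean of the variational weight posterior, same length as `x`.

    Returns:
        The dot product of `x` and `w_mean`.
    """
    return sum(xi * wi for xi, wi in zip(x, w_mean))

print(predictive_mean([1.0, 2.0], [0.5, 0.5]))  # 1.5
```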
```python
else:
    # DenseNormalPrec expects tril of precision; we accept passed tril as precision-cholesky
    return DenseNormalPrec(self.W_mean, tril)
elif self.W_dist is LowRankNormal:
    return LowRankNormal(self.W_mean, self.W_offdiag, cov_diag)
else:
    raise RuntimeError("Unsupported W distribution type")
```
Nit
Suggested change:
```diff
-else:
-    # DenseNormalPrec expects tril of precision; we accept passed tril as precision-cholesky
-    return DenseNormalPrec(self.W_mean, tril)
-elif self.W_dist is LowRankNormal:
-    return LowRankNormal(self.W_mean, self.W_offdiag, cov_diag)
-else:
-    raise RuntimeError("Unsupported W distribution type")
+# DenseNormalPrec expects tril of precision; we accept passed tril as precision-cholesky
+return DenseNormalPrec(self.W_mean, tril)
+elif self.W_dist is LowRankNormal:
+    return LowRankNormal(self.W_mean, self.W_offdiag, cov_diag)
+raise RuntimeError("Unsupported W distribution type")
```
```python
def sample_predictive(self, x: Tensor, num_samples: int = 1) -> Tensor:
    """
    Draw samples from the predictive posterior:
    returns tensor with shape (num_samples, batch..., out_features)
    """
    pred = self.predictive(x)
    # Distribution supports sample with given sample_shape
    samples = pred.rsample(sample_shape=(num_samples,))
    return samples
```
Is this used?
@esantorella has imported this pull request. If you are a Meta employee, you can view this in D86228576.
esantorella left a comment
Actually I'm seeing a bunch of lint issues -- could you auto-format this following the contributing guidelines? There are also test failures.
Motivation
(Write your motivation here.)
Have you read the Contributing Guidelines on pull requests?
(Write your answer here.)
Test Plan
(Write your test plan here. If you changed any code, please provide us with clear instructions on how you verified your changes work. Bonus points for screenshots and videos!)
Related PRs
(If this PR adds or changes functionality, please take some time to update the docs at https://github.com/meta-pytorch/botorch, and link to your PR here.)