
Conversation

@Jatkingmodern

Enhanced the Variational Bayesian Last Layers implementation with consistent use of torch, improved numerical stability, and added convenience helpers. Updated typing and docstrings for clarity.

Motivation

(Write your motivation here.)

Have you read the Contributing Guidelines on pull requests?

(Write your answer here.)

Test Plan

(Write your test plan here. If you changed any code, please provide us with clear instructions on how you verified your changes work. Bonus points for screenshots and videos!)

Related PRs

(If this PR adds or changes functionality, please take some time to update the docs at https://github.com/meta-pytorch/botorch, and link to your PR here.)

@meta-cla bot added the "CLA Signed" label on Nov 3, 2025.
@esantorella (Contributor) left a comment:

Thanks for all the improvements! I left some comments, but this basically looks good to me pending tests passing.

Original: vbll (https://github.com/VectorInstitute/vbll), MIT license.
Paper: "Variational Bayesian Last Layers" by Harrison et al., ICLR 2024
Enhancements:
@esantorella (Contributor):

Are these enhancements between the new version and the vbll original, or between the new version and the old BoTorch version? Or are the vbll original and the BoTorch version the same?


 from dataclasses import dataclass
-from typing import Callable
+from typing import Callable, Union, Optional
@esantorella (Contributor):

Nit: In all the Python versions that BoTorch supports, you can use x | y instead of Union[x, y] and x | None instead of Optional[x]
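For illustration, a minimal sketch of the PEP 604 spelling (the helper below is hypothetical, not part of the PR):

from torch import Tensor

def clamp_scale(scale: Tensor | float, eps: float | None = None) -> Tensor | float:
    # Hypothetical helper: `X | Y` replaces Union[X, Y]; `X | None` replaces Optional[X].
    if eps is not None and isinstance(scale, Tensor):
        return scale.clamp_min(eps)
    return scale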

 def __init__(self, loc: Tensor, chol: Tensor):
-    """Normal distribution."""
+    """Diagonal Gaussian wrapper. 'scale' is interpreted as the std-dev vector."""
@esantorella (Contributor):

Thanks for the docstring!

Comment on lines +44 to +45
# ensure shape broadcastability but keep behavior identical to torch.Normal
super().__init__(loc, scale)
@esantorella (Contributor):

I'm confused by this comment. It looks to me like there's no need for an init method at all since it's just calling the super class's init.
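To illustrate the point: a subclass whose __init__ only forwards its arguments can omit the method entirely and inherit the parent's (the class name below is hypothetical):

import torch
from torch.distributions import Normal

class DiagonalNormal(Normal):
    # No __init__ defined: Normal.__init__(loc, scale) is inherited unchanged,
    # which behaves identically to a pass-through override.
    pass

d = DiagonalNormal(torch.zeros(3), torch.ones(3))
print(d.sample().shape)  # torch.Size([3])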

"""
Variational Bayesian Linear Regression
Parameters
@esantorella (Contributor):

It would be nice to keep these in, although BoTorch uses google-style docstrings with an "Args:" block so that things are rendered correctly by Sphinx
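For reference, a sketch of the Google-style form that Sphinx renders correctly (the signature is illustrative; the PR's actual parameters may differ):

# Illustrative signature, not the PR's actual one.
def __init__(self, in_features: int, out_features: int) -> None:
    """Variational Bayesian Linear Regression.

    Args:
        in_features: Size of each input sample (illustrative parameter).
        out_features: Size of each output sample (illustrative parameter).
    """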

Comment on lines +424 to +430
    else:
        # DenseNormalPrec expects tril of precision; we accept passed tril as precision-cholesky
        return DenseNormalPrec(self.W_mean, tril)
elif self.W_dist is LowRankNormal:
    return LowRankNormal(self.W_mean, self.W_offdiag, cov_diag)
else:
    raise RuntimeError("Unsupported W distribution type")
@esantorella (Contributor):

Nit

Suggested change
-    else:
-        # DenseNormalPrec expects tril of precision; we accept passed tril as precision-cholesky
-        return DenseNormalPrec(self.W_mean, tril)
-elif self.W_dist is LowRankNormal:
-    return LowRankNormal(self.W_mean, self.W_offdiag, cov_diag)
-else:
-    raise RuntimeError("Unsupported W distribution type")
+    # DenseNormalPrec expects tril of precision; we accept passed tril as precision-cholesky
+    return DenseNormalPrec(self.W_mean, tril)
+elif self.W_dist is LowRankNormal:
+    return LowRankNormal(self.W_mean, self.W_offdiag, cov_diag)
+raise RuntimeError("Unsupported W distribution type")

Comment on lines +442 to +450
def sample_predictive(self, x: Tensor, num_samples: int = 1) -> Tensor:
    """
    Draw samples from the predictive posterior:
    returns tensor with shape (num_samples, batch..., out_features)
    """
    pred = self.predictive(x)
    # Distribution supports sample with given sample_shape
    samples = pred.rsample(sample_shape=(num_samples,))
    return samples
@esantorella (Contributor):

Is this used?
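For context, the pattern sample_predictive wraps can be sketched with a stand-in distribution (the Normal below is a placeholder, not the PR's actual predictive posterior):

import torch
from torch.distributions import Normal

# Stand-in for self.predictive(x): a distribution over (batch, out_features).
pred = Normal(torch.zeros(8, 1), torch.ones(8, 1))
samples = pred.rsample(sample_shape=(32,))  # shape: (32, 8, 1)
print(samples.shape)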

@esantorella self-assigned this on Nov 4, 2025.
@meta-codesync bot commented on Nov 4, 2025:

@esantorella has imported this pull request. If you are a Meta employee, you can view this in D86228576.

@esantorella (Contributor) left a comment:

Actually I'm seeing a bunch of lint issues -- could you auto-format this following the contributing guidelines? There are also test failures.


Labels: CLA Signed
Projects: none yet
2 participants