Skip to content

Add Empirical#63

Open
gvcallen wants to merge 5 commits intolockwo:mainfrom
gvcallen:empirical
Open

Add Empirical#63
gvcallen wants to merge 5 commits intolockwo:mainfrom
gvcallen:empirical

Conversation

@gvcallen
Copy link
Copy Markdown

@gvcallen gvcallen commented Apr 4, 2026

This PR adds AbstractEmpirical, Empirical and WeightedEmpirical distributions.

These distributions encapsulate a set of observed samples of a variable. The math was mirrored off tfp.distributions.Empirical with additional support for weighted samples.

@gvcallen gvcallen changed the title Add Empirical distributions Add Empirical Apr 4, 2026
self.rtol == 0, self.atol, self.atol + self.rtol * jnp.abs(value)
)

def sample_and_log_prob(self, key: Key[Array, ""]) -> tuple[Array, Array]:
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for these functions (that do the same as the default approach), we can use the mixin approach to inheritance (e.g. inherit from AbstractSampleLogProbDistribution as well, this can be done for multiple classes)

import jax

# Must be set before any JAX arrays are initialized
jax.config.update("jax_enable_x64", True)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if memory serves, this can be spotty when enabled in pytest files, if you need it, we can put in a conftest.py

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think making sure the numerics check out for non 0 values of empirical rtol/atol should have additional tests

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants