From bbda138ccb53539b1bd87b5043f4abedaa981467 Mon Sep 17 00:00:00 2001 From: njtierney Date: Thu, 30 Jan 2025 14:10:32 +1100 Subject: [PATCH 1/3] First pass at adding deterministic (degenerate, point_mass) distribution. --- NAMESPACE | 1 + R/probability_distributions.R | 30 +++++++++++++++++++++++++++++- man/distributions.Rd | 3 +++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/NAMESPACE b/NAMESPACE index ede86f97e..74e6893df 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -182,6 +182,7 @@ export(colSums) export(cov2cor) export(cpu_only) export(destroy_greta_deps) +export(deterministic) export(diag) export(dirichlet) export(dirichlet_multinomial) diff --git a/R/probability_distributions.R b/R/probability_distributions.R index 1f1c35baa..62166644f 100644 --- a/R/probability_distributions.R +++ b/R/probability_distributions.R @@ -1193,6 +1193,27 @@ lkj_correlation_distribution <- R6Class( ) ) +deterministic_distribution <- R6Class( + "deterministic_distribution", + inherit = distribution_node, + public = list( + location = NA, + initialize = function(location, dim) { + location <- as.greta_array(location) + + dim <- check_dims(location, target_dim = dim) + super$initialize("deterministic", dim) + self$add_parameter(location, "location") + }, + + tf_distrib = function(parameters, dag) { + tfp$distributions$Deterministic( + loc = parameters$location + ) + } + ) +) + # module for export via .internals distribution_classes_module <- module( uniform_distribution, @@ -1222,7 +1243,8 @@ distribution_classes_module <- module( multinomial_distribution, categorical_distribution, dirichlet_distribution, - dirichlet_multinomial_distribution + dirichlet_multinomial_distribution, + deterministic_distribution ) # export constructors @@ -1569,3 +1591,9 @@ dirichlet_multinomial <- function(size, alpha, size, alpha, n_realisations, dimension ) } + +#' @rdname distributions +#' @export +deterministic <- function(location, dimension = NULL) { + distrib("deterministic", location, dimension) +} diff --git a/man/distributions.Rd b/man/distributions.Rd index 7c53e071a..32eb8c807 100644 --- a/man/distributions.Rd +++ b/man/distributions.Rd @@ -30,6 +30,7 @@ \alias{categorical} \alias{dirichlet} \alias{dirichlet_multinomial} +\alias{deterministic} \title{probability distributions} \usage{ uniform(min, max, dim = NULL) @@ -87,6 +88,8 @@ categorical(prob, n_realisations = NULL, dimension = NULL) dirichlet(alpha, n_realisations = NULL, dimension = NULL) dirichlet_multinomial(size, alpha, n_realisations = NULL, dimension = NULL) + +deterministic(location, dimension = NULL) } \arguments{ \item{min, max}{scalar values giving optional limits to \code{uniform} From 0d1835d84f1adc9d3161e33cb9175c4a64524428 Mon Sep 17 00:00:00 2001 From: njtierney Date: Thu, 30 Jan 2025 14:10:54 +1100 Subject: [PATCH 2/3] add stubs of test for deterministic distribution --- tests/testthat/test_distributions.R | 12 ++++++++++++ tests/testthat/test_iid_samples.R | 6 ++++++ tests/testthat/test_inference.R | 12 ++++++++++++ tests/testthat/test_mixture.R | 5 +++++ 4 files changed, 35 insertions(+) diff --git a/tests/testthat/test_distributions.R b/tests/testthat/test_distributions.R index 69f5d200c..e460c5cb3 100644 --- a/tests/testthat/test_distributions.R +++ b/tests/testthat/test_distributions.R @@ -389,6 +389,18 @@ test_that("dirichlet-multinomial distribution has correct density", { ) }) +# test_that("deterministic distribution has correct density", { +# skip_if_not(check_tf_version()) +# +# compare_distribution( +# greta::deterministic, +# # stats::dunif, +# # parameters = list(location), +# # x = runif(100, -2.1, -1.2) +# ) +# }) + + test_that("scalar-valued distributions can be defined in models", { skip_if_not(check_tf_version()) diff --git a/tests/testthat/test_iid_samples.R b/tests/testthat/test_iid_samples.R index 8c2d547ab..5c3d7fa2d 100644 --- a/tests/testthat/test_iid_samples.R +++ b/tests/testthat/test_iid_samples.R @@ -104,6 +104,12 @@ test_that("univariate samples are correct", { rf, parameters = list(df1 = 4, df2 = 1) ) + + # compare_iid_samples(deterministic, + # # runif, degenerate + # # parameters = list(min = -2, max = 3) + # ) + }) test_that("truncated univariate samples are correct", { diff --git a/tests/testthat/test_inference.R b/tests/testthat/test_inference.R index b70f2f152..bf570fd0f 100644 --- a/tests/testthat/test_inference.R +++ b/tests/testthat/test_inference.R @@ -298,6 +298,18 @@ test_that("mcmc supports slice sampler with single precision models", { )) }) +# test_that("mcmc works with deterministic distribution", { +# skip_if_not(check_tf_version()) +# set.seed(5) +# x <- deterministic(0) +# m <- model(x, precision = "single") +# expect_ok(draws <- mcmc(m, +# sampler = slice(), +# n_samples = 100, warmup = 100, +# verbose = FALSE +# )) +# }) + test_that("initials works", { skip_if_not(check_tf_version()) diff --git a/tests/testthat/test_mixture.R b/tests/testthat/test_mixture.R index 5dbdbcabf..0d61b3382 100644 --- a/tests/testthat/test_mixture.R +++ b/tests/testthat/test_mixture.R @@ -283,3 +283,8 @@ test_that("mixture of normals with varying weights has correct density", { dim = dim ) }) + +# +# test_that("mixture of deterministic and continuous has correct density", { +# +# }) From 27b2ade835cade982c0cd9b143eb801738afae51 Mon Sep 17 00:00:00 2001 From: njtierney Date: Wed, 12 Feb 2025 11:28:15 +0800 Subject: [PATCH 3/3] notes from deterministic distribution --- R/mixture.R | 2 +- tests/testthat/test_distributions.R | 26 ++++++++++++++++---------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/R/mixture.R b/R/mixture.R index 55dd4e731..cd82f1bfc 100644 --- a/R/mixture.R +++ b/R/mixture.R @@ -56,7 +56,7 @@ #' # check the mixing probabilities after fitting using calculate() #' # (you could also do this within the model) #' normalized_weights <- weights / sum(weights) -#' draws_weights <- calculate(normalized_weights, draws_rates) +#' draws_weights <- calculate(normalized_weights, values = draws_rates) #' #' # get the posterior means #' summary(draws_rates)$statistics[, "Mean"] diff --git a/tests/testthat/test_distributions.R b/tests/testthat/test_distributions.R index e460c5cb3..1a1445f6c 100644 --- a/tests/testthat/test_distributions.R +++ b/tests/testthat/test_distributions.R @@ -389,16 +389,22 @@ test_that("dirichlet-multinomial distribution has correct density", { ) }) -# test_that("deterministic distribution has correct density", { -# skip_if_not(check_tf_version()) -# -# compare_distribution( -# greta::deterministic, -# # stats::dunif, -# # parameters = list(location), -# # x = runif(100, -2.1, -1.2) -# ) -# }) +ddegenerate <- function(x, location) { + ifelse(test = x == location, + yes = 1, + no = 0) +} + +test_that("deterministic distribution has correct density", { + skip_if_not(check_tf_version()) + + compare_distribution( + greta::deterministic, + ddegenerate, + parameters = list(location), + x = sample(x = 1, size = 100, replace = TRUE) + ) +}) test_that("scalar-valued distributions can be defined in models", {