Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ export(colSums)
export(cov2cor)
export(cpu_only)
export(destroy_greta_deps)
export(deterministic)
export(diag)
export(dirichlet)
export(dirichlet_multinomial)
Expand Down
2 changes: 1 addition & 1 deletion R/mixture.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
#' # check the mixing probabilities after fitting using calculate()
#' # (you could also do this within the model)
#' normalized_weights <- weights / sum(weights)
#' draws_weights <- calculate(normalized_weights, draws_rates)
#' draws_weights <- calculate(normalized_weights, values = draws_rates)
#'
#' # get the posterior means
#' summary(draws_rates)$statistics[, "Mean"]
Expand Down
30 changes: 29 additions & 1 deletion R/probability_distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,27 @@ lkj_correlation_distribution <- R6Class(
)
)

deterministic_distribution <- R6Class(
"deterministic_distribution",
inherit = distribution_node,
public = list(
location = NA,
initialize = function(location, dim) {
location <- as.greta_array(location)

dim <- check_dims(location, target_dim = dim)
super$initialize("deterministic", dim)
self$add_parameter(location, "location")
},

tf_distrib = function(parameters, dag) {
tfp$distributions$Deterministic(
loc = parameters$location
)
}
)
)

# module for export via .internals
distribution_classes_module <- module(
uniform_distribution,
Expand Down Expand Up @@ -1222,7 +1243,8 @@ distribution_classes_module <- module(
multinomial_distribution,
categorical_distribution,
dirichlet_distribution,
dirichlet_multinomial_distribution
dirichlet_multinomial_distribution,
deterministic_distribution
)

# export constructors
Expand Down Expand Up @@ -1569,3 +1591,9 @@ dirichlet_multinomial <- function(size, alpha,
size, alpha, n_realisations, dimension
)
}

#' @rdname distributions
#' @export
deterministic <- function(location, dimension = NULL) {
distrib("deterministic", location, dimension)
}
3 changes: 3 additions & 0 deletions man/distributions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions tests/testthat/test_distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,24 @@ test_that("dirichlet-multinomial distribution has correct density", {
)
})

ddegenerate <- function(x, location) {
ifelse(test = x == location,
yes = 1,
no = 0)
}

test_that("deterministic distribution has correct density", {
skip_if_not(check_tf_version())

compare_distribution(
greta::deterministic,
ddegenerate,
parameters = list(location),
x = sample(x = 1, size = 100, replace = TRUE)
)
})


test_that("scalar-valued distributions can be defined in models", {
skip_if_not(check_tf_version())

Expand Down
6 changes: 6 additions & 0 deletions tests/testthat/test_iid_samples.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ test_that("univariate samples are correct", {
rf,
parameters = list(df1 = 4, df2 = 1)
)

# compare_iid_samples(deterministic,
# # runif, degenerate
# # parameters = list(min = -2, max = 3)
# )

})

test_that("truncated univariate samples are correct", {
Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/test_inference.R
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,18 @@ test_that("mcmc supports slice sampler with single precision models", {
))
})

# test_that("mcmc works with deterministic distribution", {
# skip_if_not(check_tf_version())
# set.seed(5)
# x <- deterministic(0)
# m <- model(x, precision = "single")
# expect_ok(draws <- mcmc(m,
# sampler = slice(),
# n_samples = 100, warmup = 100,
# verbose = FALSE
# ))
# })

test_that("initials works", {
skip_if_not(check_tf_version())

Expand Down
5 changes: 5 additions & 0 deletions tests/testthat/test_mixture.R
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,8 @@ test_that("mixture of normals with varying weights has correct density", {
dim = dim
)
})

#
# test_that("mixture of deterministic and continuous has correct density", {
#
# })
Loading