Skip to content

Commit

Permalink
Merge pull request #297 from mrc-ide/mrc-4326
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz authored Jun 26, 2023
2 parents ed4242b + 08963fa commit 4f9f69e
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 1 deletion.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: odin
Title: ODE Generation and Integration
Version: 1.5.2
Version: 1.5.3
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Thibaut", "Jombart", role = "ctb"),
Expand Down
37 changes: 37 additions & 0 deletions R/differentiate-support.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,40 @@ deterministic_rules <- list(
rsignrank = function(expr) {
substitute(n * (n + 1) / 4, list(n = expr[[2]]))
})


## These are all worked out by manually taking logarithms of the
## densities - I've not been terribly exhaustive here, but have copied
## what we use in dust already...
##
## The user is going to write out:
##
## > compare(d) ~ poisson(lambda)
##
## which corresponds to writing
##
## > dpois(d, lambda, log = TRUE)
## ==> log(lambda^x * exp(-lambda) / x!)
## ==> x * log(lambda) - lambda - lfactorial(x)
##
## All the density functions will have the same form here, with the
## lhs becoming the 'x' argument (all d* functions take 'x' as the
## first argument).
log_density <- function(distribution, target, args) {
target <- as.name(target)
switch(
distribution,
## Assumption here is that sd is never zero, which might warrant
## special treatment (except that it's infinite so probably
## problematic anyway).
normal = substitute(
- (x - mu)^2 / (2 * sd^2) - log(2 * pi) / 2 - log(sd),
list(x = target, mu = args[[1]], sd = args[[2]])),
poisson = substitute(
x * log(lambda) - lambda - lfactorial(x),
list(x = target, lambda = args[[1]])),
uniform = substitute(
if (x < a || x > b) -Inf else -log(b - a),
list(x = target, a = args[[1]], b = args[[2]])),
stop(sprintf("Unsupported distribution '%s'", distribution)))
}
40 changes: 40 additions & 0 deletions tests/testthat/test-differentiate-support.R
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,43 @@ test_that("can't compute expectation of cauchy", {
make_deterministic(quote(rcauchy(x, y))),
"The Cauchy distribution has no mean, and may not be used")
})


test_that("log density of normal is correct", {
expr <- log_density("normal", quote(d), list(quote(a), quote(b)))
expect_equal(expr, quote(-(d - a)^2 / (2 * b^2) - log(2 * pi) / 2 - log(b)))
dat <- list(d = 2.341, a = 5.924, b = 4.2)
expect_equal(eval(expr, dat),
dnorm(dat$d, dat$a, dat$b, log = TRUE))
})


test_that("log density of poisson is correct", {
expr <- log_density("poisson", quote(d), list(quote(mu)))
expect_equal(expr, quote(d * log(mu) - mu - lfactorial(d)))
dat <- list(d = 3, mu = 5.234)
expect_equal(eval(expr, dat),
dpois(dat$d, dat$mu, log = TRUE))
})


test_that("log density of uniform is correct", {
expr <- log_density("uniform", quote(d), list(quote(x0), quote(x1)))
expect_equal(expr, quote(if (d < x0 || d > x1) -Inf else -log(x1 - x0)))
dat1 <- list(d = 3, x0 = 1, x1 = 75)
expect_equal(eval(expr, dat1),
dunif(dat1$d, dat1$x0, dat1$x1, log = TRUE))
dat2 <- list(d = 3, x0 = 9, x1 = 75)
expect_equal(eval(expr, dat2),
dunif(dat2$d, dat2$x0, dat2$x1, log = TRUE))
dat3 <- list(d = 3, x0 = 1, x1 = 2)
expect_equal(eval(expr, dat3),
dunif(dat3$d, dat3$x0, dat3$x1, log = TRUE))
})


test_that("disable unknown distributions", {
expect_error(
log_density("cauchy", quote(d), list(quote(a), quote(b))),
"Unsupported distribution 'cauchy'")
})

0 comments on commit 4f9f69e

Please sign in to comment.