From ef705b2f0cea6393f9f59021d2eb8ad2716bfd43 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 13 Jun 2023 20:02:22 +0100 Subject: [PATCH 1/5] Add log density support --- R/differentiate-support.R | 37 +++++++++++++++++++ tests/testthat/test-differentiate-support.R | 40 +++++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/R/differentiate-support.R b/R/differentiate-support.R index 43ee8640..9a04db25 100644 --- a/R/differentiate-support.R +++ b/R/differentiate-support.R @@ -85,3 +85,40 @@ deterministic_rules <- list( rsignrank = function(expr) { substitute(n * (n + 1) / 4, list(n = expr[[2]])) }) + + +## These are all worked out by manually taking logarithms of the +## densities - I've not been terribly exhaustive here, but have copied +## what we use in dust already... +## +## The user is going to write out: +## +## compare(d) ~ poisson(lambda) +## +## which corresponds to writing +## +## dpois(d, lambda, log = TRUE) +## ==> log(lambda^x * exp(-lambda) / x!) +## ==> x * log(lambda) - lambda - lfactorial(x) +## +## All the density functions will have the same form here, with the +## lhs becoming the 'x' argument (all d* functions take 'x' as the +## first argument). +log_density <- function(distribution, target, args) { + target <- as.name(target) + switch( + distribution, + ## Assumption here is that sd is never zero, which might warrant + ## special treatment (except that it's infinite so probably + ## problematic anyway). + normal = substitute( + - (x - mu)^2 / (2 * sd^2) - log(sqrt(2 * pi)) - log(sd), + list(x = target, mu = args[[1]], sd = args[[2]])), + poisson = substitute( + x * log(lambda) - lambda - lfactorial(x), + list(x = target, lambda = args[[1]])), + uniform = substitute( + if (x < a || x > b) -Inf else -log(b - a), + list(x = target, a = args[[1]], b = args[[2]])), + stop(sprintf("Unsupported distribution '%s'", distribution))) +} diff --git a/tests/testthat/test-differentiate-support.R b/tests/testthat/test-differentiate-support.R index 342cf2bc..7b7c350f 100644 --- a/tests/testthat/test-differentiate-support.R +++ b/tests/testthat/test-differentiate-support.R @@ -206,3 +206,43 @@ test_that("can't compute expectation of cauchy", { make_deterministic(quote(rcauchy(x, y))), "The Cauchy distribution has no mean, and may not be used") }) + + +test_that("log density of normal is correct", { + expr <- log_density("normal", quote(d), list(quote(a), quote(b))) + expect_equal(expr, quote(-(d - a)^2/(2 * b^2) - log(sqrt(2 * pi)) - log(b))) + dat <- list(d = 2.341, a = 5.924, b = 4.2) + expect_equal(eval(expr, dat), + dnorm(dat$d, dat$a, dat$b, log = TRUE)) +}) + + +test_that("log density of poisson is correct", { + expr <- log_density("poisson", quote(d), list(quote(mu))) + expect_equal(expr, quote(d * log(mu) - mu - lfactorial(d))) + dat <- list(d = 3, mu = 5.234) + expect_equal(eval(expr, dat), + dpois(dat$d, dat$mu, log = TRUE)) +}) + + +test_that("log density of uniform is correct", { + expr <- log_density("uniform", quote(d), list(quote(x0), quote(x1))) + expect_equal(expr, quote(if (d < x0 || d > x1) -Inf else -log(x1 - x0))) + dat1 <- list(d = 3, x0 = 1, x1 = 75) + expect_equal(eval(expr, dat1), + dunif(dat1$d, dat1$x0, dat1$x1, log = TRUE)) + dat2 <- list(d = 3, x0 = 9, x1 = 75) + expect_equal(eval(expr, dat2), + dunif(dat2$d, dat2$x0, dat2$x1, log = TRUE)) + dat3 <- list(d = 3, x0 = 1, x1 = 2) + expect_equal(eval(expr, dat3), + dunif(dat3$d, dat3$x0, dat3$x1, log = TRUE)) +}) + + +test_that("disable unknown distributions", { + expect_error( + log_density("cauchy", quote(d), list(quote(a), quote(b))), + "Unsupported distribution 'cauchy'") +}) From 18c9458136cce7e307932b68d66029cc5b94fd8c Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 14 Jun 2023 09:48:11 +0100 Subject: [PATCH 2/5] Fix lint --- R/differentiate-support.R | 8 ++++---- tests/testthat/test-differentiate-support.R | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/differentiate-support.R b/R/differentiate-support.R index 9a04db25..4fc429ea 100644 --- a/R/differentiate-support.R +++ b/R/differentiate-support.R @@ -93,13 +93,13 @@ deterministic_rules <- list( ## ## The user is going to write out: ## -## compare(d) ~ poisson(lambda) +## > compare(d) ~ poisson(lambda) ## ## which corresponds to writing ## -## dpois(d, lambda, log = TRUE) -## ==> log(lambda^x * exp(-lambda) / x!) -## ==> x * log(lambda) - lambda - lfactorial(x) +## > dpois(d, lambda, log = TRUE) +## ==> log(lambda^x * exp(-lambda) / x!) +## ==> x * log(lambda) - lambda - lfactorial(x) ## ## All the density functions will have the same form here, with the ## lhs becoming the 'x' argument (all d* functions take 'x' as the diff --git a/tests/testthat/test-differentiate-support.R b/tests/testthat/test-differentiate-support.R index 7b7c350f..56d881b8 100644 --- a/tests/testthat/test-differentiate-support.R +++ b/tests/testthat/test-differentiate-support.R @@ -210,7 +210,7 @@ test_that("can't compute expectation of cauchy", { test_that("log density of normal is correct", { expr <- log_density("normal", quote(d), list(quote(a), quote(b))) - expect_equal(expr, quote(-(d - a)^2/(2 * b^2) - log(sqrt(2 * pi)) - log(b))) + expect_equal(expr, quote(-(d - a)^2 / (2 * b^2) - log(sqrt(2 * pi)) - log(b))) dat <- list(d = 2.341, a = 5.924, b = 4.2) expect_equal(eval(expr, dat), dnorm(dat$d, dat$a, dat$b, log = TRUE)) From 3e16cbdba9881f7bdf608a58ac5be294c211d3fd Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 26 Jun 2023 13:17:35 +0100 Subject: [PATCH 3/5] Update R/differentiate-support.R Co-authored-by: M-Kusumgar <98405247+M-Kusumgar@users.noreply.github.com> --- R/differentiate-support.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/differentiate-support.R b/R/differentiate-support.R index 4fc429ea..4793b3b5 100644 --- a/R/differentiate-support.R +++ b/R/differentiate-support.R @@ -112,7 +112,7 @@ log_density <- function(distribution, target, args) { ## special treatment (except that it's infinite so probably ## problematic anyway). normal = substitute( - - (x - mu)^2 / (2 * sd^2) - log(sqrt(2 * pi)) - log(sd), + - (x - mu)^2 / (2 * sd^2) - log(2 * pi) / 2 - log(sd), list(x = target, mu = args[[1]], sd = args[[2]])), poisson = substitute( x * log(lambda) - lambda - lfactorial(x), From 49ad95d1c08ce90606b3a33d7852890f4e0591ec Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 26 Jun 2023 14:08:31 +0100 Subject: [PATCH 4/5] Fix test --- tests/testthat/test-differentiate-support.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test-differentiate-support.R b/tests/testthat/test-differentiate-support.R index 56d881b8..fd4fa7cd 100644 --- a/tests/testthat/test-differentiate-support.R +++ b/tests/testthat/test-differentiate-support.R @@ -210,7 +210,7 @@ test_that("can't compute expectation of cauchy", { test_that("log density of normal is correct", { expr <- log_density("normal", quote(d), list(quote(a), quote(b))) - expect_equal(expr, quote(-(d - a)^2 / (2 * b^2) - log(sqrt(2 * pi)) - log(b))) + expect_equal(expr, quote(-(d - a)^2 / (2 * b^2) - log(2 * pi) / 2 - log(b))) dat <- list(d = 2.341, a = 5.924, b = 4.2) expect_equal(eval(expr, dat), dnorm(dat$d, dat$a, dat$b, log = TRUE)) From 08963fa6b2d06c01dcb0c381d66adc7d70d4563e Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 26 Jun 2023 14:57:12 +0100 Subject: [PATCH 5/5] Bump version --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 7529268b..4cb45f4d 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: odin Title: ODE Generation and Integration -Version: 1.5.2 +Version: 1.5.3 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Thibaut", "Jombart", role = "ctb"),