From 827ca85eb23e54faecb3c5aebf7dba3b80f955d8 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Fri, 12 Nov 2021 10:50:56 +0000 Subject: [PATCH] Test pointer handling --- R/cpp11.R | 4 ++++ src/cpp11.cpp | 8 ++++++++ src/dust_rng_pointer.cpp | 9 +++++++++ tests/testthat/test-rng-interface.R | 16 ++++++++++++++++ 4 files changed, 37 insertions(+) diff --git a/R/cpp11.R b/R/cpp11.R index 9e6e452ce..8320d6c16 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -32,6 +32,10 @@ dust_rng_pointer_sync <- function(obj, algorithm) { invisible(.Call(`_dust_dust_rng_pointer_sync`, obj, algorithm)) } +test_rng_pointer_get <- function(obj, n_streams) { + .Call(`_dust_test_rng_pointer_get`, obj, n_streams) +} + dust_rng_alloc <- function(r_seed, n_generators, deterministic, is_float) { .Call(`_dust_dust_rng_alloc`, r_seed, n_generators, deterministic, is_float) } diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 4e3b513db..d4735c866 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -62,6 +62,13 @@ extern "C" SEXP _dust_dust_rng_pointer_sync(SEXP obj, SEXP algorithm) { return R_NilValue; END_CPP11 } +// dust_rng_pointer.cpp +double test_rng_pointer_get(cpp11::environment obj, int n_streams); +extern "C" SEXP _dust_test_rng_pointer_get(SEXP obj, SEXP n_streams) { + BEGIN_CPP11 + return cpp11::as_sexp(test_rng_pointer_get(cpp11::as_cpp>(obj), cpp11::as_cpp>(n_streams))); + END_CPP11 +} // dust_rng.cpp SEXP dust_rng_alloc(cpp11::sexp r_seed, int n_generators, bool deterministic, bool is_float); extern "C" SEXP _dust_dust_rng_alloc(SEXP r_seed, SEXP n_generators, SEXP deterministic, SEXP is_float) { @@ -1192,6 +1199,7 @@ static const R_CallMethodDef CallEntries[] = { {"_dust_dust_walk_capabilities", (DL_FUNC) &_dust_dust_walk_capabilities, 0}, {"_dust_dust_walk_gpu_info", (DL_FUNC) &_dust_dust_walk_gpu_info, 0}, {"_dust_test_cuda_pars", (DL_FUNC) &_dust_test_cuda_pars, 9}, + {"_dust_test_rng_pointer_get", (DL_FUNC) &_dust_test_rng_pointer_get, 2}, {"_dust_test_xoshiro_run", (DL_FUNC) &_dust_test_xoshiro_run, 1}, {NULL, NULL, 0} }; diff --git a/src/dust_rng_pointer.cpp b/src/dust_rng_pointer.cpp index 7f9e6db34..a4f0c7cab 100644 --- a/src/dust_rng_pointer.cpp +++ b/src/dust_rng_pointer.cpp @@ -73,3 +73,12 @@ void dust_rng_pointer_sync(cpp11::environment obj, std::string algorithm) { r::rng_pointer_sync(obj); } } + +// This exists to check some error paths in rng_pointer_get; it is not +// for use by users. +[[cpp11::register]] +double test_rng_pointer_get(cpp11::environment obj, int n_streams) { + using namespace dust::random; + auto rng = r::rng_pointer_get(obj, n_streams); + return random_real(rng->state(0)); +} diff --git a/tests/testthat/test-rng-interface.R b/tests/testthat/test-rng-interface.R index 983ce25e0..1404301da 100644 --- a/tests/testthat/test-rng-interface.R +++ b/tests/testthat/test-rng-interface.R @@ -50,3 +50,19 @@ test_that("can't create invalid pointer types", { dust_rng_pointer$new(algorithm = "mt19937"), "Unknown algorithm 'mt19937'") }) + + +test_that("Validate pointers on fetch", { + obj <- dust_rng_pointer$new(algorithm = "xoshiro256starstar") + expect_error( + test_rng_pointer_get(obj, 1), + "Incorrect rng type: given xoshiro256starstar, expected xoshiro256plus") + obj <- dust_rng_pointer$new(algorithm = "xoshiro256plus", n_streams = 4) + expect_error( + test_rng_pointer_get(obj, 20), + "Requested a rng with 20 streams but only have 4") + expect_silent( + test_rng_pointer_get(obj, 0)) + expect_silent( + test_rng_pointer_get(obj, 1)) +})