Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Helpers for using rngs from other packages #333

Merged
merged 27 commits into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d44fadd
Start work on a sensible API
richfitz Nov 10, 2021
7bbbec7
Start on a nice interface for pointers
richfitz Nov 11, 2021
872eb55
Gracefully cope with pointer serialisation
richfitz Nov 11, 2021
e6f9d7f
Add some encapsulation
richfitz Nov 11, 2021
4e858cb
Start adding docs
richfitz Nov 11, 2021
0b99e97
Tidy up examples
richfitz Nov 11, 2021
ec82dcf
Tidy up interface
richfitz Nov 11, 2021
bf69eb4
Bump version and add news
richfitz Nov 11, 2021
db72276
Expose the number of streams
richfitz Nov 11, 2021
27a12f1
Expand example
richfitz Nov 11, 2021
394b0ee
Restore important error check
richfitz Nov 12, 2021
4c5cd33
Minor cleanup
richfitz Nov 12, 2021
8c58863
Move code into its own file
richfitz Nov 12, 2021
be595a3
Check that all pointer types can be synced
richfitz Nov 12, 2021
e3b687b
Check we can hit all branches in pointer code
richfitz Nov 12, 2021
34cd062
Tidy up headers
richfitz Nov 12, 2021
827ca85
Test pointer handling
richfitz Nov 12, 2021
3e68479
Make vignettes buildable
richfitz Nov 12, 2021
0c43a24
Tidy vignette
richfitz Nov 12, 2021
73920e9
Copy files from right place
richfitz Nov 12, 2021
de32262
Spelling
richfitz Nov 12, 2021
30da86a
Avoid windows pain
richfitz Nov 12, 2021
57688d3
Better default behaviour of state()
richfitz Nov 12, 2021
2c1935b
Documentation improvement from review
richfitz Nov 12, 2021
b4bf97e
Fix docs
richfitz Nov 12, 2021
369fc29
Regenerate vignette
richfitz Nov 12, 2021
621e602
Add test that state() syncs rng state
richfitz Nov 12, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ jobs:
remotes::install_cran("rcmdcheck")
shell: Rscript {0}

- name: Move real vignettes
if: runner.os != 'Windows'
run: |
cp vignettes_src/rng_package.Rmd vignettes
cp vignettes_src/rng_pi*.cpp vignettes

- name: Check
env:
_R_CHECK_CRAN_INCOMING_REMOTE_: false
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ inst/doc
pkgdown
inst/include/cub
*.gcov
vignettes_src/gpu.md
vignettes_src/*.md

.vscode/

Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: dust
Title: Iterate Multiple Realisations of Stochastic Models
Version: 0.11.7
Version: 0.11.8
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("John", "Lees", role = "aut"),
Expand Down
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ clean:
src/*.gcov src/*.gcda src/*.gcno

vignettes/gpu.Rmd: vignettes_src/gpu.Rmd
./scripts/build_gpu_vignette
./scripts/build_vignette gpu

vignettes/rng_package.Rmd: vignettes_src/rng_package.Rmd
./scripts/build_vignette rng_package

vignettes: vignettes/dust.Rmd vignettes/rng.Rmd
${RSCRIPT} -e 'tools::buildVignettes(dir = ".")'
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export(dust_openmp_support)
export(dust_openmp_threads)
export(dust_package)
export(dust_rng)
export(dust_rng_pointer)
export(dust_rng_state_long_jump)
importFrom(stats,coef)
useDynLib(dust, .registration = TRUE)
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# dust 0.11.8

* Improved the interface for using dust's random number support from other packages (#329)

# dust 0.11.7

* New polar algorithm for normally distributed random numbers; faster than Box-Muller but slower than Ziggurat
Expand Down
16 changes: 14 additions & 2 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ density_poisson <- function(x, lambda, log) {
.Call(`_dust_density_poisson`, x, lambda, log)
}

dust_rng_pointer_init <- function(n_streams, seed, algorithm) {
.Call(`_dust_dust_rng_pointer_init`, n_streams, seed, algorithm)
}

dust_rng_pointer_sync <- function(obj, algorithm) {
invisible(.Call(`_dust_dust_rng_pointer_sync`, obj, algorithm))
}

test_rng_pointer_get <- function(obj, n_streams) {
.Call(`_dust_test_rng_pointer_get`, obj, n_streams)
}

dust_rng_alloc <- function(r_seed, n_generators, deterministic, is_float) {
.Call(`_dust_dust_rng_alloc`, r_seed, n_generators, deterministic, is_float)
}
Expand Down Expand Up @@ -288,8 +300,8 @@ test_cuda_pars <- function(r_gpu_config, n_particles, n_particles_each, n_state,
.Call(`_dust_test_cuda_pars`, r_gpu_config, n_particles, n_particles_each, n_state, n_state_full, n_shared_int, n_shared_real, data_size, shared_size)
}

test_xoshiro_run <- function(name) {
.Call(`_dust_test_xoshiro_run`, name)
test_xoshiro_run <- function(obj) {
.Call(`_dust_test_xoshiro_run`, obj)
}

cpp_scale_log_weights <- function(w) {
Expand Down
2 changes: 1 addition & 1 deletion R/dust.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
## Generated by dust (version 0.11.4) - do not edit
## Generated by dust (version 0.11.6) - do not edit
sir <- R6::R6Class(
"dust",
cloneable = FALSE,
Expand Down
76 changes: 76 additions & 0 deletions R/rng_pointer.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
##' @title Create pointer to random number generator stream
##'
##' @description This function exists to support use from other
##' packages that wish to use dust's random number support, and
##' creates an opaque pointer to a set of random number streams. It
##' is described more fully in `vignette("rng_package.Rmd")`
##'
##' @export
##' @examples
##' dust::dust_rng_pointer$new()
dust_rng_pointer <- R6::R6Class(
"dust_rng_pointer",
cloneable = FALSE,

private = list(
ptr_ = NULL,
state_ = NULL,
is_current_ = NULL
),

public = list(
##' @field algorithm The name of the generator algorithm used (read-only)
algorithm = NULL,

##' @field n_streams The number of streams of random numbers provided
##' (read-only)
n_streams = NULL,

##' @description Create a new `dust_rng_pointer` object
##'
##' @param seed The random number seed to use (see [dust::dust_rng]
##' for details)
##'
##' @param n_streams The number of independent random number streams to
##' create
##'
##' @param algorithm The random number algorithm to use. The default is
##' `xoshiro256plus` which is a good general choice
initialize = function(seed = NULL, n_streams = 1L,
algorithm = "xoshiro256plus") {
dat <- dust_rng_pointer_init(n_streams, seed, algorithm)
private$ptr_ <- dat[[1L]]
private$state_ <- dat[[2L]]
private$is_current_ <- TRUE

self$algorithm <- algorithm
self$n_streams <- n_streams
lockBinding("algorithm", self)
lockBinding("n_streams", self)
},

##' @description Synchronise the R copy of the random number state.
##' Typically this is only needed before serialisation if you have
##' ever used the object.
sync = function() {
dust_rng_pointer_sync(private, self$algorithm)
},

##' @description Return a raw vector of state. This can be used to
##' create other generators with the same state.
state = function() {
if (!private$is_current_) {
self$sync()
}
private$state_
},

##' @description Return a logical, indicating if the random number
##' state that would be returned by `state()` is "current" (i.e., the
##' same as the copy held in the pointer) or not. This is `TRUE` on
##' creation or immediately after calling `$sync()` or `$state()`
##' and `FALSE` after any use of the pointer.
is_current = function() {
private$is_current_
}
))
12 changes: 9 additions & 3 deletions extra/harness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
typedef uint64_t int_type;
constexpr size_t data_size = 4;
#elif defined(XOSHIRO128)
typedef uint64_t int_type;
typedef uint32_t int_type;
constexpr size_t data_size = 4;
#elif defined(XOROSHIRO128)
typedef uint64_t int_type;
constexpr size_t data_size = 4;
constexpr size_t data_size = 2;
#elif defined(XOSHIRO512)
typedef uint64_t int_type;
constexpr size_t data_size = 8;
Expand Down Expand Up @@ -51,10 +51,16 @@ int main() {
}
auto x = next();
std::cout <<
//std::setw(16) << std::setfill('0') << std::hex << x << " " <<
std::dec << x <<
std::endl;
}

// At the end we dump out the model state, in hex:
for (int i = 0; i < data_size; ++i) {
std::cout << std::setw(sizeof(int_type) * 2) <<
std::setfill('0') << std::hex << s[i] <<
std::endl;
}

return 0;
}
1 change: 1 addition & 0 deletions inst/WORDLIST
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Marsaglia
Mersenne
OMP
OpenMP
OpenMP's
Perez
Poisson
R's
Expand Down
4 changes: 2 additions & 2 deletions inst/include/dust/r/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ dust_inputs<T> process_inputs_single(cpp11::list r_pars, int step,
dust::r::validate_size(step, "step");
dust::r::validate_positive(n_threads, "n_threads");
std::vector<typename T::rng_state_type::int_type> seed =
dust::r::as_rng_seed<typename T::rng_state_type>(r_seed);
dust::random::r::as_rng_seed<typename T::rng_state_type>(r_seed);

std::vector<dust::pars_type<T>> pars;
pars.push_back(dust::dust_pars<T>(r_pars));
Expand All @@ -389,7 +389,7 @@ dust_inputs<T> process_inputs_multi(cpp11::list r_pars, int step,
dust::r::validate_size(step, "step");
dust::r::validate_positive(n_threads, "n_threads");
std::vector<typename T::rng_state_type::int_type> seed =
dust::r::as_rng_seed<typename T::rng_state_type>(r_seed);
dust::random::r::as_rng_seed<typename T::rng_state_type>(r_seed);

dust::r::check_pars_multi(r_pars);
std::vector<dust::pars_type<T>> pars;
Expand Down
134 changes: 127 additions & 7 deletions inst/include/dust/r/random.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,33 @@

#include <cstring> // memcpy

#include <cpp11/environment.hpp>
#include <cpp11/external_pointer.hpp>
#include <cpp11/list.hpp>
#include <cpp11/raws.hpp>

#include <R_ext/Random.h>

#include "dust/random/generator.hpp"
#include "dust/random/prng.hpp"

namespace dust {
namespace random {
namespace r {

template <typename rng_state_type>
std::vector<typename rng_state_type::int_type> raw_seed(cpp11::raws seed_data) {
using int_type = typename rng_state_type::int_type;
constexpr size_t len = sizeof(int_type) * rng_state_type::size();
if (seed_data.size() == 0 || seed_data.size() % len != 0) {
cpp11::stop("Expected raw vector of length as multiple of %d for 'seed'",
len);
}
std::vector<int_type> seed(seed_data.size() / sizeof(int_type));
std::memcpy(seed.data(), RAW(seed_data), seed_data.size());
return seed;
}

template <typename rng_state_type>
std::vector<typename rng_state_type::int_type> as_rng_seed(cpp11::sexp r_seed) {
typedef typename rng_state_type::int_type int_type;
Expand All @@ -21,13 +40,7 @@ std::vector<typename rng_state_type::int_type> as_rng_seed(cpp11::sexp r_seed) {
seed = dust::random::seed_data<rng_state_type>(seed_int);
} else if (seed_type == RAWSXP) {
cpp11::raws seed_data = cpp11::as_cpp<cpp11::raws>(r_seed);
constexpr size_t len = sizeof(int_type) * rng_state_type::size();
if (seed_data.size() == 0 || seed_data.size() % len != 0) {
cpp11::stop("Expected raw vector of length as multiple of %d for 'seed'",
len);
}
seed.resize(seed_data.size() / sizeof(int_type));
std::memcpy(seed.data(), RAW(seed_data), seed_data.size());
seed = raw_seed<rng_state_type>(seed_data);
} else if (seed_type == NILSXP) {
GetRNGstate();
size_t seed_int =
Expand All @@ -40,6 +53,113 @@ std::vector<typename rng_state_type::int_type> as_rng_seed(cpp11::sexp r_seed) {
return seed;
}

namespace {

template<typename T>
std::string algorithm_name() {
std::string ret;
if (std::is_same<T, xoshiro128plus_state>::value) {
ret = "xoshiro128plus";
} else if (std::is_same<T, xoshiro128plusplus_state>::value) {
ret = "xoshiro128plusplus";
} else if (std::is_same<T, xoshiro128starstar_state>::value) {
ret = "xoshiro128starstar";
} else if (std::is_same<T, xoroshiro128plus_state>::value) {
ret = "xoroshiro128plus";
} else if (std::is_same<T, xoroshiro128plusplus_state>::value) {
ret = "xoroshiro128plusplus";
} else if (std::is_same<T, xoroshiro128starstar_state>::value) {
ret = "xoroshiro128starstar";
} else if (std::is_same<T, xoshiro256plus_state>::value) {
ret = "xoshiro256plus";
} else if (std::is_same<T, xoshiro256plusplus_state>::value) {
ret = "xoshiro256plusplus";
} else if (std::is_same<T, xoshiro256starstar_state>::value) {
ret = "xoshiro256starstar";
} else if (std::is_same<T, xoshiro512plus_state>::value) {
ret = "xoshiro512plus";
} else if (std::is_same<T, xoshiro512plusplus_state>::value) {
ret = "xoshiro512plusplus";
} else if (std::is_same<T, xoshiro512starstar_state>::value) {
ret = "xoshiro512starstar";
}
return ret;
}

template <typename rng_state_type>
cpp11::raws rng_state_vector(prng<rng_state_type>* rng) {
auto state = rng->export_state();
size_t len = sizeof(typename rng_state_type::int_type) * state.size();
cpp11::writable::raws r_state(len);
std::memcpy(RAW(r_state), state.data(), len);
return r_state;
}

}

template <typename rng_state_type>
SEXP rng_pointer_init(int n_streams, cpp11::sexp r_seed) {
auto seed = as_rng_seed<rng_state_type>(r_seed);
auto *rng = new prng<rng_state_type>(n_streams, seed);
auto r_ptr = cpp11::external_pointer<prng<rng_state_type>>(rng);
auto r_state = rng_state_vector(rng);
return cpp11::writable::list({r_ptr, r_state});
}

// Start with the assumption that we'll pass in the R6 object, might
// write a simpler version later.
template <typename rng_state_type>
prng<rng_state_type>* rng_pointer_get(cpp11::environment obj,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possible useful reference for #271. If we wanted to add this and support saving and loading, see the serialisation examples:
https://pybind11.readthedocs.io/en/stable/advanced/classes.html#pickling-support
https://github.com/pybind/pybind11/blob/master/tests/test_pickling.cpp

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The R serialisation has no hooking functionality, it's super annoying

int n_streams = 0) {
// We could probably do this more efficiently if we store an enum
// in the object but this is probably ok.
const auto algorithm_given = cpp11::as_cpp<std::string>(obj["algorithm"]);
const auto algorithm_expected = algorithm_name<rng_state_type>();
if (algorithm_given != algorithm_expected) {
cpp11::stop("Incorrect rng type: given %s, expected %s",
algorithm_given.c_str(), algorithm_expected.c_str());
}

cpp11::environment env_enclos =
cpp11::as_cpp<cpp11::environment>(obj[".__enclos_env__"]);
cpp11::environment env =
cpp11::as_cpp<cpp11::environment>(env_enclos["private"]);

using ptr_type = cpp11::external_pointer<prng<rng_state_type>>;
auto ptr = cpp11::as_cpp<ptr_type>(env["ptr_"]);

auto * rng = ptr.get();
if (rng == nullptr) {
if (!cpp11::as_cpp<bool>(env["is_current_"])) {
cpp11::stop("Can't unserialise an rng pointer that was not synced");
}
cpp11::raws seed_data = cpp11::as_cpp<cpp11::raws>(env["state_"]);
auto seed = raw_seed<rng_state_type>(seed_data);
const auto n_streams_orig = seed.size() / rng_state_type::size();
rng = new prng<rng_state_type>(n_streams_orig, seed);
env["ptr_"] = cpp11::external_pointer<prng<rng_state_type>>(rng);
}

if (n_streams > 0 && static_cast<int>(rng->size()) < n_streams) {
cpp11::stop("Requested a rng with %d streams but only have %d",
n_streams, rng->size());
}
env["is_current_"] = cpp11::as_sexp(false);

return rng;
}

template <typename rng_state_type>
void rng_pointer_sync(cpp11::environment obj) {
using ptr_type = cpp11::external_pointer<prng<rng_state_type>>;
if (!cpp11::as_cpp<bool>(obj["is_current_"])) {
auto ptr = cpp11::as_cpp<ptr_type>(obj["ptr_"]);
obj["state_"] = rng_state_vector(ptr.get());
obj["is_current_"] = cpp11::as_sexp(true);
}
}

}
}
}

Expand Down
Loading