Faster normal draws with the ziggurat algorithm #326
Turns out that passing cpp11::writable::doubles caused unexpected copies, which was a surprise. Something to investigate later
Codecov Report
@@ Coverage Diff @@
## master #326 +/- ##
=========================================
Coverage 100.00% 100.00%
=========================================
Files 57 59 +2
Lines 3258 3331 +73
=========================================
+ Hits 3258 3331 +73
return std::sqrt(-2 * std::log(u1)) * std::cos(two_pi * u2);
}

real_type random_normal(rng_state_type& rng_state) {
Is the algorithm chosen at run time or when compiling the object? Wondering whether we could just template rather than using this function
that is a template - am I missing something?
After discussion: it feels like one should be able to overload this, since the algorithm is known at compile time, but C++ doesn't let us do that.
// TODO: this will not work efficiently for float types because we
// don't have float tables for 'x' and 'y'; getting them is not easy
// without requiring c++14 either. The lower loop using 'x' could
How come we are able to manage this with binomial draws?
The binomial algorithm doesn't need to use the random numbers; the tables are only used in stirling_approx_tail, so we have fully specialised templates on real_type (and some ugly names k_tail_values_d and stirling_approx_tail_f). With C++14 we can make these variable templates, which removes the naming problem, but because the algorithm here needs to use one or two random numbers alongside the constants we'd end up trying to make a partially specialised template (providing real_type but leaving rng_state_type open).
There are some solutions, but I'd like to try and implement these separately and compare timings to make sure that we don't end up paying too much on the CPU case
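As a rough sketch of one such solution (an assumption, not necessarily the approach the authors have in mind, and not timed), the constants can live in a class template that is fully specialised on real_type, while the sampling function stays generic over both template parameters. Every name below (`sketch`, `toy_table`, `scaled_lookup`, `fixed_state`, `next_uniform`) is hypothetical, and the three-element table is toy data rather than real ziggurat constants. Function-local static arrays may also be unsuitable on a GPU.

```cpp
#include <cstddef>

namespace sketch {

// Primary template is declared but not defined: only the explicitly
// specialised real types are usable.
template <typename real_type> struct toy_table;

// Full specialisation for double (toy values, not ziggurat constants).
template <> struct toy_table<double> {
  static double get(std::size_t i) {
    static const double values[] = {0.0, 0.5, 1.0};
    return values[i];
  }
};

// Full specialisation for float with the same toy values.
template <> struct toy_table<float> {
  static float get(std::size_t i) {
    static const float values[] = {0.0f, 0.5f, 1.0f};
    return values[i];
  }
};

// Generic over real_type and rng_state_type. The table is chosen
// through toy_table<real_type>, so the function template itself never
// needs to be partially specialised.
template <typename real_type, typename rng_state_type>
real_type scaled_lookup(rng_state_type& rng_state, std::size_t i) {
  const real_type u = static_cast<real_type>(rng_state.next_uniform());
  return u * toy_table<real_type>::get(i);
}

// Hypothetical stand-in for an rng state: always yields the same value.
struct fixed_state {
  double value;
  double next_uniform() { return value; }
};

} // namespace sketch
```

With `sketch::fixed_state s{0.5};`, calling `sketch::scaled_lookup<double>(s, 2)` multiplies the fixed draw by the last table entry; `rng_state_type` is deduced from the argument and only `real_type` is spelled out. Whether the extra indirection costs anything on the CPU path is exactly the timing question raised above.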
real_type random_normal_ziggurat(rng_state_type& rng_state) {
  using ziggurat::x;
  using ziggurat::y;
  constexpr size_t n = 256;
Comment to explain this choice and whether it is tuneable?
My first shot at this was fully tunable using std::array<real_type, int> for all sorts of useful n! But it was a pain to work with and we can't have std::array on the GPU...
I've added a note indicating where this comes from and why now
const auto f1 = std::exp(-0.5 * (x[i + 1] * x[i + 1] - z * z));
const auto u1 = random_real<real_type>(rng_state);
if (f1 + u1 * (f0 - f1) < 1.0) {
  ret = z;
Should this break?
ugh, yes
s <- vapply(z, deparse, "", control = "digits17")
paste(
Do we know that this is enough precision? (is it based on being < 2.2e-16?)
digits17 is the full precision of the underlying number
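The same point can be checked from the C++ side. The sketch below assumes IEEE 754 binary64 doubles (the `static_assert` encodes that assumption) and uses hypothetical helper names `to_text` and `round_trips`; it is not part of the PR. It prints a double with a given number of significant digits, parses the text back, and reports whether the original value is recovered exactly.

```cpp
#include <cstdio>
#include <cstdlib>
#include <limits>
#include <string>

namespace sketch {

// Assumption: double is IEEE 754 binary64, for which 17 significant
// decimal digits are enough to recover any value exactly.
static_assert(std::numeric_limits<double>::max_digits10 == 17,
              "sketch assumes IEEE 754 binary64 doubles");

// Format a double with the requested number of significant digits.
inline std::string to_text(double value, int digits) {
  char buffer[64];
  std::snprintf(buffer, sizeof(buffer), "%.*g", digits, value);
  return std::string(buffer);
}

// True if printing with `digits` significant digits and parsing the
// text back yields exactly the same double.
inline bool round_trips(double value, int digits) {
  const std::string text = to_text(value, digits);
  return std::strtod(text.c_str(), nullptr) == value;
}

} // namespace sketch
```

For example, `sketch::round_trips(0.1 + 0.2, 17)` holds, while 15 digits collapse that value to the text `0.3`, which parses to a different double. Seventeen digits therefore lose nothing when constants are written into a generated header; any remaining error comes from how the constants were computed, not from how they were printed.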
## Helper for root polishing
uniroot2 <- function(f, bounds, ..., scal = 10) {
Does tolerance need scaling?
Tolerance comes through in the dots here
}

zig_constants <- function(n, tolerance = 1e-10) {
Enough precision here too?
I think so - this is 2 orders of magnitude tighter than the default tolerance. To do this "properly" we'd really want to use long doubles. I've seen implementations where these numbers are only accurate to 5 digits though, and ours is more accurate than Doornik (only differs in the ~6th place I think)
zig_constants <- function(n, tolerance = 1e-10) {
  ## As for intervals but with more robustness to being out of bounds
  intervals <- function(n, r, v) {
I would consider adding a reference here
I've added one, but a more thorough derivation is forthcoming (I've got most of a detailed vignette working through the process)
This PR implements the ziggurat algorithm for normally distributed numbers.
There are some follow-on bits of work to do that I am avoiding in this PR because they're more likely to be disruptive and this is large enough atm
The current version does compile on a GPU, but runs fairly slowly due to the doubles
Fixes #308