Faster normal draws with the ziggurat algorithm #326

Merged: 17 commits merged into master from i308-normal on Nov 8, 2021

Conversation

richfitz (Member) commented Nov 8, 2021:

This PR implements the ziggurat algorithm for normally distributed numbers.

There are some follow-on bits of work that I am avoiding in this PR because they're more likely to be disruptive and this PR is large enough as it is.

The current version does compile on a GPU, but runs fairly slowly there because it uses doubles.

Fixes #308
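
For anyone wanting to sanity-check the speed claim locally, here is a hedged micro-benchmark sketch; the sampler is a standard-library stand-in, not a dust function, and the harness name `time_draws` is made up for illustration. Swap the lambda for the dust Box-Muller and ziggurat draws to compare them.

```cpp
// Hedged benchmark harness (illustrative only): times n draws from whatever
// sampler is passed in. std::normal_distribution stands in for the dust draws.
#include <chrono>
#include <cstdio>
#include <random>

template <typename Draw>
double time_draws(Draw draw, std::size_t n) {
  double sink = 0.0;                              // accumulate so the loop is not optimised away
  const auto t0 = std::chrono::steady_clock::now();
  for (std::size_t i = 0; i < n; ++i) {
    sink += draw();
  }
  const auto t1 = std::chrono::steady_clock::now();
  std::printf("sum = %g\n", sink);                // keep 'sink' observable
  return std::chrono::duration<double>(t1 - t0).count();
}

int main() {
  std::mt19937_64 rng(42);
  std::normal_distribution<double> dist(0.0, 1.0);
  const double elapsed = time_draws([&] { return dist(rng); }, 10000000);
  std::printf("elapsed: %.3f s for 1e7 draws\n", elapsed);
  return 0;
}
```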

codecov bot commented Nov 8, 2021:

Codecov Report

Merging #326 (57a17bf) into master (29b427b) will not change coverage.
The diff coverage is 100.00%.


@@            Coverage Diff            @@
##            master      #326   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files           57        59    +2     
  Lines         3258      3331   +73     
=========================================
+ Hits          3258      3331   +73     
| Impacted Files | Coverage Δ |
| --- | --- |
| inst/include/dust/random/binomial.hpp | 100.00% <ø> (ø) |
| R/rng.R | 100.00% <100.00%> (ø) |
| inst/include/dust/random/normal.hpp | 100.00% <100.00%> (ø) |
| inst/include/dust/random/normal_box_muller.hpp | 100.00% <100.00%> (ø) |
| inst/include/dust/random/normal_ziggurat.hpp | 100.00% <100.00%> (ø) |
| src/dust_rng.cpp | 100.00% <100.00%> (ø) |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

richfitz marked this pull request as ready for review November 8, 2021 11:08
richfitz requested a review from johnlees November 8, 2021 11:20
return std::sqrt(-2 * std::log(u1)) * std::cos(two_pi * u2);
}

real_type random_normal(rng_state_type& rng_state) {
Member:
Is the algorithm chosen at run time or when compiling the object? Wondering whether we could just template rather than using this function

Member Author:
that is a template - am I missing something?

Member Author:
After discussion: it feels like one could overload this as we know the algorithm is known at compile time but C++ doesn't let us do that
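
For illustration only (the enum, names and signatures below are hypothetical stand-ins, not the dust API): one pattern for fixing the algorithm at compile time while keeping a single entry point is a non-type template parameter, so the untaken branch is known dead at compile time; with C++17 the branch could become `if constexpr`.

```cpp
// Hedged sketch of compile-time algorithm selection; not dust's code.
#include <cmath>
#include <cstdio>
#include <random>

enum class normal_algorithm { box_muller, ziggurat };

// Stand-in samplers so the sketch is self-contained; in dust these would be
// the implementations in normal_box_muller.hpp and normal_ziggurat.hpp.
template <typename real_type, typename rng_state_type>
real_type normal_box_muller(rng_state_type& rng_state) {
  std::uniform_real_distribution<real_type> unif(0, 1);
  const real_type u1 = 1 - unif(rng_state);       // in (0, 1], avoids log(0)
  const real_type u2 = unif(rng_state);
  const real_type two_pi = static_cast<real_type>(2 * 3.141592653589793);
  return std::sqrt(-2 * std::log(u1)) * std::cos(two_pi * u2);
}

template <typename real_type, typename rng_state_type>
real_type normal_ziggurat(rng_state_type& rng_state) {
  // Placeholder body only; the real ziggurat draw lives in normal_ziggurat.hpp.
  return normal_box_muller<real_type>(rng_state);
}

template <typename real_type, normal_algorithm algorithm, typename rng_state_type>
real_type random_normal(rng_state_type& rng_state) {
  // 'algorithm' is a template parameter, so this choice is fixed at compile time.
  if (algorithm == normal_algorithm::ziggurat) {
    return normal_ziggurat<real_type>(rng_state);
  } else {
    return normal_box_muller<real_type>(rng_state);
  }
}

int main() {
  std::mt19937_64 rng(42);
  std::printf("%f\n", random_normal<double, normal_algorithm::ziggurat>(rng));
  return 0;
}
```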

Comment on lines +33 to +35
// TODO: this will not work efficiently for float types because we
// don't have float tables for 'x' and 'y'; getting them is not easy
// without requiring c++14 either. The lower loop using 'x' could
Member:
How come we are able to manage this with binomial draws?

Member Author:
The binomial algorithm doesn't need to use the random numbers alongside the tables; the tables are only used in stirling_approx_tail, so we have fully specialised templates on real_type (and some ugly names like k_tail_values_d and stirling_approx_tail_f). With C++14 we could make these variable templates, which removes the naming problem, but because the algorithm here needs one or two random numbers alongside the constants we would end up trying to write a partially specialised template (providing real_type but leaving rng_state_type open).

There are some solutions, but I'd like to implement these separately and compare timings, to make sure we don't end up paying too much in the CPU case.
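
A hedged sketch of the C++14 direction mentioned above (the table name and the two values are illustrative, taken from the first two Stirling tail corrections, not the dust tables): a variable template specialised on real_type gives a single name for the double and float tables, removing the _d/_f suffixes, although the partial-specialisation problem for the sampler itself remains.

```cpp
// Hedged sketch of per-precision constant tables via a C++14 variable template.
#include <cstdio>

template <typename real_type>
constexpr real_type k_tail_values[2] = {};       // primary template; never used directly

template <>
constexpr double k_tail_values<double>[2] = {
  0.08106146679532726, 0.04134069595540929       // illustrative double entries
};

template <>
constexpr float k_tail_values<float>[2] = {
  0.08106147f, 0.04134070f                       // float copies of the same values
};

// Code templated on real_type can now say k_tail_values<real_type> rather
// than picking a differently-named table per precision.
template <typename real_type>
real_type first_tail_value() {
  return k_tail_values<real_type>[0];
}

int main() {
  std::printf("%f %f\n", first_tail_value<double>(),
              static_cast<double>(first_tail_value<float>()));
  return 0;
}
```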

real_type random_normal_ziggurat(rng_state_type& rng_state) {
using ziggurat::x;
using ziggurat::y;
constexpr size_t n = 256;
Member:
Comment to explain this choice and whether it is tuneable?

Member Author:
My first shot at this was fully tunable, using std::array<real_type, n> for all sorts of useful n! But it was a pain to work with, and we can't have std::array on the GPU...

Member Author:
I've now added a note indicating where this comes from and why.
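
As a hedged aside on why 256 layers is a common choice (this may or may not be the reason the added note gives, and it is not dust's layer-selection code): with n = 256 the layer index can come straight from 8 bits of a single random draw, leaving the remaining bits for the within-layer uniform.

```cpp
// Hedged sketch: one 64-bit draw provides both an 8-bit layer index and a
// 53-bit uniform when there are 256 layers. Not dust's implementation.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <random>

int main() {
  std::mt19937_64 rng(1);
  const std::uint64_t bits = rng();
  const std::size_t layer = bits & 0xff;                               // low 8 bits: layer in [0, 255]
  const double u = std::ldexp(static_cast<double>(bits >> 11), -53);  // top 53 bits: uniform in [0, 1)
  std::printf("layer = %zu, u = %f\n", layer, u);
  return 0;
}
```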

const auto f1 = std::exp(-0.5 * (x[i + 1] * x[i + 1] - z * z));
const auto u1 = random_real<real_type>(rng_state);
if (f1 + u1 * (f0 - f1) < 1.0) {
ret = z;
Member:
should this break?

Member Author:
ugh, yes

Comment on lines +13 to +14
s <- vapply(z, deparse, "", control = "digits17")
paste(
Member:
Do we know that this is enough precision? (is it based on being < 2.2e-16?)

Member Author:
digits17 is the full precision of the underlying number
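
For anyone wondering where 17 comes from: it is the round-trip precision of an IEEE double, which a quick C++ check (illustrative, unrelated to the R helper itself) confirms.

```cpp
// 17 significant decimal digits are enough to round-trip any IEEE double.
#include <cstdio>
#include <limits>

int main() {
  std::printf("max_digits10 = %d\n", std::numeric_limits<double>::max_digits10);  // 17
  std::printf("%.17g\n", 0.1);   // 0.10000000000000001: the stored value to 17 digits
  return 0;
}
```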



## Helper for root polishing
uniroot2 <- function(f, bounds, ..., scal = 10) {
Member:
Does tolerance need scaling?

Member Author:
Tolerance comes through in the dots here

}


zig_constants <- function(n, tolerance = 1e-10) {
Member:
Enough precision here too?

Member Author:
I think so - this is two orders of magnitude tighter than we get by default. To do this "properly" we'd really want long doubles. I've seen implementations where these numbers are only good to 5 digits though, and ours is more accurate than Doornik's (it only differs in the ~6th place, I think).


zig_constants <- function(n, tolerance = 1e-10) {
## As for intervals but with more robustness to being out of bounds
intervals <- function(n, r, v) {
Member:
I would consider adding a reference here

Member Author:
I've added one, but a fuller derivation is forthcoming (I've got most of a detailed vignette working through the process).
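
Until that vignette lands, a hedged summary of what the constants solve, in the standard equal-area ziggurat construction with the unnormalised density f(x) = exp(-x^2 / 2) (the indexing here is illustrative and may not match the code):

$$
v = r\,e^{-r^2/2} + \int_r^{\infty} e^{-x^2/2}\,dx,
\qquad
e^{-x_{i+1}^2/2} = e^{-x_i^2/2} + \frac{v}{x_i}, \quad x_0 = r,
$$

with r (and hence v) found by root-finding, which is what uniroot2 and intervals are doing above, so that the topmost layer closes exactly at f(0) = 1 to within the stated tolerance.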

richfitz requested a review from johnlees November 8, 2021 14:05
johnlees merged commit 906d994 into master on Nov 8, 2021
johnlees deleted the i308-normal branch November 8, 2021 14:33
Successfully merging this pull request may close: Add alternative normal distribution algorithm (#308).

2 participants: richfitz, johnlees