Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ziggurat implementation could be faster... #3

Open
camel-cdr opened this issue Dec 19, 2020 · 0 comments
Open

Ziggurat implementation could be faster... #3

camel-cdr opened this issue Dec 19, 2020 · 0 comments

Comments

@camel-cdr
Copy link

camel-cdr commented Dec 19, 2020

It looks like your ziggurat implementation is from https://www.seehuhn.de/pages/ziggurat.html.
He says that he uses "an exponential distribution together with rejection for the tails of the base strip to simplify the implementation. Both changes seem to have no significant performance impact."
But I observed that my implementation based on the original ziggurat (I only condensed the three lookup tables into one) is 60% faster.
I didn't write a pull request directly, because there is already a ziggurat implementation present and I don't know if you'd like to benchmark them separately or want to use the fastest.
Anyway, here is the proposed code:

#include "test.h"
#include "lcg.h"

#include <cstdint>
#include <cmath>
#include <limits>
#include <climits>

/*
 * Ziggurat parameters. The lookup table inside ziggurat_next() is hard-coded
 * for exactly 128 blocks: its second entry equals ziggurat_R and its last
 * entry is 0. The table and the constants below therefore have to be
 * regenerated together if the block count is ever changed; the count must
 * also stay a power of two because the block index is extracted with a mask.
 * ziggurat_AREA is not referenced by ziggurat_next() itself.
 */
#define ziggurat_BLOCK_COUNT 128
#define ziggurat_R     3.442619855899
#define ziggurat_AREA  0.00991256303526217

/*
 * Draws one sample with the ziggurat method.
 *
 * Tflt  floating point type of the result (and of the lookup table).
 * Tint  unsigned integer type produced by the generator; it must have at
 *       least as many value bits as Tflt has mantissa digits.
 * rng   callback returning one random word per call. The low bits of a word
 *       select the block, the high bits are turned into a uniform value.
 *
 * Returns a finite sample; the sign is taken from the uniform value.
 */
template<class Tflt, class Tint>
Tflt ziggurat_next(Tint (*rng)(void))
{
	static_assert(ziggurat_BLOCK_COUNT == 128,
	              "xtbl is generated for exactly 128 blocks");
	static_assert(std::numeric_limits<Tint>::is_integer &&
	              !std::numeric_limits<Tint>::is_signed,
	              "Tint must be an unsigned integer type");
	static_assert(!std::numeric_limits<Tflt>::is_integer,
	              "Tflt must be a floating point type");
	static_assert(std::numeric_limits<Tint>::digits >=
	              std::numeric_limits<Tflt>::digits,
	              "Tint needs at least as many bits as the Tflt mantissa");
	static_assert(std::numeric_limits<Tflt>::digits <
	              std::numeric_limits<unsigned long long>::digits,
	              "Tflt mantissa must fit the 1ull shift below");

	static constexpr int mantissa_bits = std::numeric_limits<Tflt>::digits;
	static constexpr int dropped_bits =
		std::numeric_limits<Tint>::digits - mantissa_bits;

	/* Maps the top mantissa_bits bits of a random word to a double in [0, 1). */
	const auto to_unit = [](Tint bits) -> double {
		return (bits >> dropped_bits) * 1.0 / (1ull << mantissa_bits);
	};

	static constexpr Tflt xtbl[ziggurat_BLOCK_COUNT + 1] = {
	3.7130862467425505002,  3.4426198558990002141,  3.2230849845811415655,
	3.0832288582168683178,  2.9786962526477802626,  2.8943440070215289417,
	2.8231253505489104505,  2.7611693723871768569,  2.7061135731218195488,
	2.6564064112613596791,  2.6109722484318473867,  2.5690336259249377804,
	2.5300096723888274575,  2.4934545220953721056,  2.4590181774118304858,
	2.4264206455337498092,  2.3954342780110624567,  2.3658713701176385946,
	2.3375752413392367757,  2.3104136836987629877,  2.2842740596774717687,
	2.2590595738691980898,  2.2346863955909785915,  2.2110814088787025256,
	2.1881804320760482874,  2.1659267937489210532,  2.1442701823603944611,
	2.1231657086739761375,  2.1025731351892376075,  2.0824562379920159572,
	2.0627822745083079781,  2.0435215366550671945,  2.02464697337738464,
	2.0061338699634712057,  1.9879595741276190335,  1.970103260854325633,
	1.9525457295535557645,  1.9352692282966217352,  1.9182573008645087409,
	1.901494653105150201,   1.8849670357077579208,  1.8686611409944875462,
	1.85256451172808978,    1.836665460258444682,   1.820952996596124418,
	1.8054167642192271437,  1.7900469825998572837,  1.7748343955860681476,
	1.759770224899592117,   1.744846128113799022,   1.730054160563729182,
	1.7153867407136660361,  1.7008366185699153039,  1.6863968467791665695,
	1.6720607540975995775,  1.6578219209540228096,  1.6436741568628672194,
	1.6296114794706331175,  1.6156280950431594068,  1.6017183802213763588,
	1.5878768648905743355,  1.5740982160229990416,  1.5603772223661671603,
	1.5467087798599086224,  1.5330878776740417546,  1.5195095847659385591,
	1.5059690368632017154,  1.492461423781352492,   1.4789819769899226198,
	1.4655259573427090736,  1.4520886428892227915,  1.4386653166845615459,
	1.4252512545140580968,  1.4118417124470556967,  1.3984319141310033174,
	1.3850170377326498361,  1.3715922024273405899,  1.3581524543301413122,
	1.3446927517535449681,  1.3312079496656250566,  1.3176927832094120774,
	1.3041418501286148324,  1.2905495919261944504,  1.2769102735601534082,
	1.2632179614546188429,  1.2494664995730662138,  1.2356494832633604375,
	1.2217602305399941631,  1.2077917504159472184,  1.1937367078331260206,
	1.1795873846639857163,  1.1653356361647499995,  1.1509728421488649719,
	1.1364898520131583304,  1.1218769225825397928,  1.1071236475340335836,
	1.092218876907274927,   1.0771506248928932603,  1.0619059636948215974,
	1.0464709007640424776,  1.0308302360681926846,  1.0149673952513273978,
	0.99886423349298048002, 0.98250080351542590229, 0.96585507940114656567,
	0.94890262551130311053, 0.93161619661514749602, 0.91396525102302894616,
	0.89591535258093435434, 0.87742742911291993213, 0.85845684319380943794,
	0.83895221429757360632, 0.81885390670035329563, 0.79809206064405291414,
	0.7765839878947558006,  0.7542306644540515137,  0.73091191064248450804,
	0.70647961133543202283, 0.68074791866914985405, 0.6534786387399702523,
	0.62435859733604526234, 0.59296294247144254452, 0.55869217840817897436,
	0.52065603876205379663, 0.47743783729668198834, 0.42654798635541407714,
	0.36287143109701985866, 0.27232086481394562893, 0, };

	for (;;) {
		const Tint u = rng();
		/* low bits pick the block, high bits give a signed position in it */
		const Tint idx = u & (ziggurat_BLOCK_COUNT - 1);
		const Tflt uf = (2.0 * to_unit(u) - 1.0) * xtbl[idx];

		/* fast path: the point is inside the next (narrower) block edge */
		if (std::fabs(uf) < xtbl[idx + 1])
			return uf;

		if (idx == 0) {
			/*
			 * Tail beyond ziggurat_R. to_unit() can return exactly 0, so
			 * both logarithms can be -inf; in that case "inf < inf" is
			 * false and the loop used to exit with x == -inf, returning an
			 * infinite sample. Retry explicitly whenever x is infinite.
			 */
			Tflt x, y;
			do {
				x = std::log(to_unit(rng()))
				        * (1.0 / ziggurat_R);
				y = std::log(to_unit(rng()));
			} while (std::isinf(x) || -(y + y) < x * x);
			/* x is <= 0 here, so both results lie outside +-ziggurat_R */
			if (uf < 0)
				return x - ziggurat_R;
			else
				return ziggurat_R - x;
		}

		/*
		 * Wedge test. f0 and f1 are the densities at the two block edges
		 * divided by the density at uf, so the comparison against 1.0 is
		 * "random height between the edges lies below the curve at uf".
		 */
		const Tflt uf_sq = uf * uf;
		const Tflt f0 = std::exp(-0.5 * (xtbl[idx]     * xtbl[idx]     - uf_sq));
		const Tflt f1 = std::exp(-0.5 * (xtbl[idx + 1] * xtbl[idx + 1] - uf_sq));
		if (f1 + to_unit(rng()) * (f0 - f1) < 1.0)
			return uf;
	}
}

uint32_t randu32()
{
	static LCG<uint32_t> r1;
	return r1();
}

uint64_t randu64()
{
	static LCG<uint32_t> r1;
	static LCG<uint32_t> r2;
	return ((uint64_t)r1()) << 32 | r2();
}

/* Fills data[0 .. count) with float samples from ziggurat_next, fed by randu32. */
static void normaldistf_ziggurat_new(float* data, size_t count) {
  float* const end = data + count;
  for (float* out = data; out != end; ++out) {
    *out = ziggurat_next<float, uint32_t>(&randu32);
  }
}

/* Fills data[0 .. count) with double samples from ziggurat_next, fed by randu64. */
static void normaldist_ziggurat_new(double* data, size_t count) {
  double* const end = data + count;
  for (double* out = data; out != end; ++out) {
    *out = ziggurat_next<double, uint64_t>(&randu64);
  }
}

/* Registers this method under the name "ziggurat_new". REGISTER_TEST is a
 * project macro (presumably from test.h, not shown here) and is assumed to
 * pick up normaldistf_ziggurat_new / normaldist_ziggurat_new by name - confirm. */
REGISTER_TEST(ziggurat_new);

And here is the output of the methods I got to compile:

normaldistf
Benchmarking     ziggurat_new         ...    4.018ns
Benchmarking     ratio                ...   10.963ns
Benchmarking     ziggurat             ...    7.289ns
Benchmarking     null                 ...    1.049ns
Benchmarking     marsagliapolar       ...    6.133ns
Benchmarking     inverse              ...    6.017ns
Benchmarking     cpp11random          ...    8.852ns
Benchmarking     clt16                ...   16.563ns
Benchmarking     clt8                 ...    8.111ns
Benchmarking     clt4                 ...    4.064ns
Benchmarking     boxmuller            ...    7.484ns

normaldist
Benchmarking     ziggurat_new         ...    4.081ns
Benchmarking     ratio                ...   11.189ns
Benchmarking     ziggurat             ...    7.817ns
Benchmarking     null                 ...    1.149ns
Benchmarking     marsagliapolar       ...    8.711ns
Benchmarking     inverse              ...    7.222ns
Benchmarking     cpp11random          ...   17.009ns
Benchmarking     clt16                ...   16.760ns
Benchmarking     clt8                 ...    8.132ns
Benchmarking     clt4                 ...    4.088ns
Benchmarking     boxmuller            ...   16.546ns

There was no built-in LCG for uint64_t, so I combined two uint32_t calls; the final patch should probably contain a specialized LCG<uint64_t>.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant