Skip to content

Commit

Permalink
More robust rand_laplace implementation (#271)
Browse files Browse the repository at this point in the history
  • Loading branch information
zeotuan authored Oct 12, 2024
1 parent f538672 commit 8955005
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
9 changes: 3 additions & 6 deletions quinn/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,9 @@ def rand_laplace(
if not isinstance(beta, Column):
beta = F.lit(beta)

u = F.rand(seed)

return (
F.when(u < F.lit(0.5), mu + beta * F.log(2 * u))
.otherwise(mu - beta * F.log(2 * (1 - u)))
.alias("laplace_random")
u = F.rand(seed) - F.lit(0.5)
return (mu - beta * F.signum(u) * F.log(F.lit(1) - (F.lit(2) * F.abs(u)))).alias(
"laplace_random"

Check failure on line 43 in quinn/math.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (COM812)

quinn/math.py:43:25: COM812 Trailing comma missing
)


Expand Down
17 changes: 13 additions & 4 deletions tests/test_math.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
import pyspark.sql.functions as F

import pytest
import quinn
import math
from .spark import spark


def test_rand_laplace():
@pytest.mark.parametrize(
"mean, scale",
[
(1.0, 2.0),
(2.0, 3.0),
(3.0, 4.0),
],
)
def test_rand_laplace(mean: float, scale: float):
stats = (
spark.range(100000)
.select(quinn.rand_laplace(0.0, 1.0, 42))
.select(quinn.rand_laplace(mean, scale, 42))
.agg(
F.mean("laplace_random").alias("mean"),
F.stddev("laplace_random").alias("std_dev"),
Expand All @@ -20,8 +29,8 @@ def test_rand_laplace():
laplace_stddev = stats["std_dev"]

# Laplace distribution with mean=0.0 and scale=1.0 has mean=0.0 and stddev=sqrt(2.0)
assert abs(laplace_mean) <= 0.1
assert abs(laplace_stddev - math.sqrt(2.0)) < 0.5
assert abs(laplace_mean - mean) <= 0.1
assert abs(laplace_stddev - scale * math.sqrt(2.0)) <= 0.1


def test_rand_range():
Expand Down

0 comments on commit 8955005

Please sign in to comment.