Skip to content

Commit

Permalink
Merge pull request #302 from abstractqqq/add_len_in_random
Browse files Browse the repository at this point in the history
added len in random
  • Loading branch information
abstractqqq authored Dec 21, 2024
2 parents 8a10f2a + e95ab8d commit 80efce7
Showing 1 changed file with 54 additions and 18 deletions.
72 changes: 54 additions & 18 deletions python/polars_ds/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ def random(
lower: pl.Expr | float = 0.0,
upper: pl.Expr | float = 1.0,
seed: int | None = None,
length: int | pl.Expr = pl.len(),
) -> pl.Expr:
"""
Generate random numbers in [lower, upper)
Expand All @@ -489,24 +490,27 @@ def random(
The upper bound, exclusive
seed
The random seed. None means no seed.
length
Custom length of the output. Note that the length must match that of the other columns in the context.
"""
lo = pl.lit(lower, pl.Float64) if isinstance(lower, float) else lower
up = pl.lit(upper, pl.Float64) if isinstance(upper, float) else upper
len_ = pl.lit(length, pl.UInt32) if isinstance(length, int) else length
return pl_plugin(
symbol="pl_random",
args=[pl.len(), lo, up, pl.lit(seed, pl.UInt64)],
args=[len_, lo, up, pl.lit(seed, pl.UInt64)],
is_elementwise=True,
)


def random_null(var: str | pl.Expr, pct: float, seed: int | None = None) -> pl.Expr:
def random_null(x: str | pl.Expr, pct: float, seed: int | None = None) -> pl.Expr:
"""
Creates random null values in var. If var contains nulls originally, they
Creates random null values in the columns. If x contains nulls originally, they
will stay null.
Parameters
----------
var
x
Either the name of the column or a Polars expression
pct
Percentage of nulls to randomly generate. This percentage is based on the
Expand All @@ -518,11 +522,15 @@ def random_null(var: str | pl.Expr, pct: float, seed: int | None = None) -> pl.E
if pct <= 0.0 or pct >= 1.0:
raise ValueError("Input `pct` must be > 0 and < 1")

to_null = random(0.0, 1.0, seed=seed) < pct
return pl.when(to_null).then(None).otherwise(str_to_expr(var))
return pl.when(random(0.0, 1.0, seed=seed) < pct).then(None).otherwise(str_to_expr(x))


def random_int(lower: int | pl.Expr, upper: int | pl.Expr, seed: int | None = None) -> pl.Expr:
def random_int(
lower: int | pl.Expr,
upper: int | pl.Expr,
seed: int | None = None,
length: int | pl.Expr = pl.len(),
) -> pl.Expr:
"""
Generates random integer between lower and upper.
Expand All @@ -534,16 +542,19 @@ def random_int(lower: int | pl.Expr, upper: int | pl.Expr, seed: int | None = No
The upper bound, exclusive
seed
The random seed. None means no seed.
length
Custom length of the output. Note that the length must match that of the other columns in the context.
"""
if lower == upper:
raise ValueError("Input `lower` must be smaller than `higher`")

lo = pl.lit(lower, pl.Int32) if isinstance(lower, int) else lower.cast(pl.Int32)
hi = pl.lit(upper, pl.Int32) if isinstance(upper, int) else upper.cast(pl.Int32)
len_ = pl.lit(length, pl.UInt32) if isinstance(length, int) else length
return pl_plugin(
symbol="pl_rand_int",
args=[
pl.len().cast(pl.UInt32),
len_,
lo,
hi,
pl.lit(seed, pl.UInt64),
Expand All @@ -552,7 +563,12 @@ def random_int(lower: int | pl.Expr, upper: int | pl.Expr, seed: int | None = No
)


def random_str(min_size: int, max_size: int, seed: int | None = None) -> pl.Expr:
def random_str(
min_size: int,
max_size: int,
seed: int | None = None,
length: int | pl.Expr = pl.len(),
) -> pl.Expr:
"""
Generates random strings of length between min_size and max_size. The characters are
uniformly distributed over ASCII letters and numbers: a-z, A-Z and 0-9.
Expand All @@ -565,6 +581,8 @@ def random_str(min_size: int, max_size: int, seed: int | None = None) -> pl.Expr
The max size of the string, inclusive
seed
The random seed. None means no seed.
length
Custom length of the output. Note that the length must match that of the other columns in the context.
"""
mi, ma = min_size, max_size
if min_size > max_size:
Expand All @@ -573,7 +591,7 @@ def random_str(min_size: int, max_size: int, seed: int | None = None) -> pl.Expr
return pl_plugin(
symbol="pl_rand_str",
args=[
pl.len().cast(pl.UInt32),
pl.lit(length, pl.UInt32) if isinstance(length, int) else length,
pl.lit(mi, pl.UInt32),
pl.lit(ma, pl.UInt32),
pl.lit(seed, pl.UInt64),
Expand All @@ -582,7 +600,7 @@ def random_str(min_size: int, max_size: int, seed: int | None = None) -> pl.Expr
)


def random_binomial(n: int, p: int, seed: int | None = None) -> pl.Expr:
def random_binomial(n: int, p: int, seed: int | None = None, length: int | pl.Expr = pl.len()) -> pl.Expr:
"""
Generates random integer following a binomial distribution.
Expand All @@ -594,14 +612,16 @@ def random_binomial(n: int, p: int, seed: int | None = None) -> pl.Expr:
The p in a binomial distribution
seed
The random seed. None means no seed.
length
Custom length of the output. Note that the length must match that of the other columns in the context.
"""
if n < 1:
raise ValueError("Input `n` must be > 1.")

return pl_plugin(
symbol="pl_rand_binomial",
args=[
pl.len().cast(pl.UInt32),
pl.lit(length, pl.UInt32) if isinstance(length, int) else length,
pl.lit(n, pl.Int32),
pl.lit(p, pl.Float64),
pl.lit(seed, pl.UInt64),
Expand All @@ -610,7 +630,7 @@ def random_binomial(n: int, p: int, seed: int | None = None) -> pl.Expr:
)


def random_exp(lambda_: float, seed: int | None = None) -> pl.Expr:
def random_exp(lambda_: float, seed: int | None = None, length: int | pl.Expr = pl.len()) -> pl.Expr:
"""
Generates random numbers following an exponential distribution.
Expand All @@ -620,15 +640,26 @@ def random_exp(lambda_: float, seed: int | None = None) -> pl.Expr:
The lambda in an exponential distribution
seed
The random seed. None means no seed.
length
Custom length of the output. Note that the length must match that of the other columns in the context.
"""
return pl_plugin(
symbol="pl_rand_exp",
args=[pl.len().cast(pl.UInt32), pl.lit(lambda_, pl.Float64), pl.lit(seed, pl.UInt64)],
args=[
pl.lit(length, pl.UInt32) if isinstance(length, int) else length,
pl.lit(lambda_, pl.Float64),
pl.lit(seed, pl.UInt64)
],
is_elementwise=True,
)


def random_normal(mean: pl.Expr | float, std: pl.Expr | float, seed: int | None = None) -> pl.Expr:
def random_normal(
mean: pl.Expr | float,
std: pl.Expr | float,
seed: int | None = None,
length: int | pl.Expr = pl.len()
) -> pl.Expr:
"""
Generates random number following a normal distribution.
Expand All @@ -640,12 +671,17 @@ def random_normal(mean: pl.Expr | float, std: pl.Expr | float, seed: int | None
The std in a normal distribution
seed
The random seed. None means no seed.
length
Custom length of the output. Note that the length must match that of the other columns in the context.
"""
m = pl.lit(mean, pl.Float64) if isinstance(mean, float) else mean
s = pl.lit(std, pl.Float64) if isinstance(std, float) else std
return pl_plugin(
symbol="pl_rand_normal",
args=[pl.len().cast(pl.UInt32), m, s, pl.lit(seed, pl.UInt64)],
args=[
pl.lit(length, pl.UInt32) if isinstance(length, int) else length,
pl.lit(mean, pl.Float64) if isinstance(mean, float) else mean,
pl.lit(std, pl.Float64) if isinstance(std, float) else std,
pl.lit(seed, pl.UInt64)
],
is_elementwise=True,
)

Expand Down

0 comments on commit 80efce7

Please sign in to comment.