diff --git a/python/polars_ds/stats.py b/python/polars_ds/stats.py index bc7953d..85d5380 100644 --- a/python/polars_ds/stats.py +++ b/python/polars_ds/stats.py @@ -477,6 +477,7 @@ def random( lower: pl.Expr | float = 0.0, upper: pl.Expr | float = 1.0, seed: int | None = None, + length: int | pl.Expr = pl.len(), ) -> pl.Expr: """ Generate random numbers in [lower, upper) @@ -489,24 +490,27 @@ def random( The upper bound, exclusive seed The random seed. None means no seed. + length + Custom length. Note length needs to match with other columns in the context. """ lo = pl.lit(lower, pl.Float64) if isinstance(lower, float) else lower up = pl.lit(upper, pl.Float64) if isinstance(upper, float) else upper + len_ = pl.lit(length, pl.UInt32) if isinstance(length, int) else length return pl_plugin( symbol="pl_random", - args=[pl.len(), lo, up, pl.lit(seed, pl.UInt64)], + args=[len_, lo, up, pl.lit(seed, pl.UInt64)], is_elementwise=True, ) -def random_null(var: str | pl.Expr, pct: float, seed: int | None = None) -> pl.Expr: +def random_null(x: str | pl.Expr, pct: float, seed: int | None = None) -> pl.Expr: """ - Creates random null values in var. If var contains nulls originally, they + Creates random null values in the columns. If x contains nulls originally, they will stay null. Parameters ---------- - var + x Either the name of the column or a Polars expression pct Percentage of nulls to randomly generate. 
This percentage is based on the @@ -518,11 +522,15 @@ def random_null(var: str | pl.Expr, pct: float, seed: int | None = None) -> pl.E if pct <= 0.0 or pct >= 1.0: raise ValueError("Input `pct` must be > 0 and < 1") - to_null = random(0.0, 1.0, seed=seed) < pct - return pl.when(to_null).then(None).otherwise(str_to_expr(var)) + return pl.when(random(0.0, 1.0, seed=seed) < pct).then(None).otherwise(str_to_expr(x)) -def random_int(lower: int | pl.Expr, upper: int | pl.Expr, seed: int | None = None) -> pl.Expr: +def random_int( + lower: int | pl.Expr, + upper: int | pl.Expr, + seed: int | None = None, + length: int | pl.Expr = pl.len(), +) -> pl.Expr: """ Generates random integer between lower and upper. @@ -534,16 +542,19 @@ def random_int(lower: int | pl.Expr, upper: int | pl.Expr, seed: int | None = No The upper bound, exclusive seed The random seed. None means no seed. + length + Custom length. Note length needs to match with other columns in the context. """ if lower == upper: raise ValueError("Input `lower` must be smaller than `higher`") lo = pl.lit(lower, pl.Int32) if isinstance(lower, int) else lower.cast(pl.Int32) hi = pl.lit(upper, pl.Int32) if isinstance(upper, int) else upper.cast(pl.Int32) + len_ = pl.lit(length, pl.UInt32) if isinstance(length, int) else length return pl_plugin( symbol="pl_rand_int", args=[ - pl.len().cast(pl.UInt32), + len_, lo, hi, pl.lit(seed, pl.UInt64), @@ -552,7 +563,12 @@ def random_int(lower: int | pl.Expr, upper: int | pl.Expr, seed: int | None = No ) -def random_str(min_size: int, max_size: int, seed: int | None = None) -> pl.Expr: +def random_str( + min_size: int, + max_size: int, + seed: int | None = None, + length: int | pl.Expr = pl.len(), +) -> pl.Expr: """ Generates random strings of length between min_size and max_size. The characters are uniformly distributed over ASCII letters and numbers: a-z, A-Z and 0-9. 
@@ -565,6 +581,8 @@ def random_str(min_size: int, max_size: int, seed: int | None = None) -> pl.Expr The max size of the string, inclusive seed The random seed. None means no seed. + length + Custom length. Note length needs to match with other columns in the context. """ mi, ma = min_size, max_size if min_size > max_size: @@ -573,7 +591,7 @@ def random_str(min_size: int, max_size: int, seed: int | None = None) -> pl.Expr return pl_plugin( symbol="pl_rand_str", args=[ - pl.len().cast(pl.UInt32), + pl.lit(length, pl.UInt32) if isinstance(length, int) else length, pl.lit(mi, pl.UInt32), pl.lit(ma, pl.UInt32), pl.lit(seed, pl.UInt64), @@ -582,7 +600,7 @@ def random_str(min_size: int, max_size: int, seed: int | None = None) -> pl.Expr ) -def random_binomial(n: int, p: int, seed: int | None = None) -> pl.Expr: +def random_binomial(n: int, p: int, seed: int | None = None, length: int | pl.Expr = pl.len()) -> pl.Expr: """ Generates random integer following a binomial distribution. @@ -594,6 +612,8 @@ def random_binomial(n: int, p: int, seed: int | None = None) -> pl.Expr: The p in a binomial distribution seed The random seed. None means no seed. + length + Custom length. Note length needs to match with other columns in the context. """ if n < 1: raise ValueError("Input `n` must be > 1.") @@ -601,7 +621,7 @@ def random_binomial(n: int, p: int, seed: int | None = None) -> pl.Expr: return pl_plugin( symbol="pl_rand_binomial", args=[ - pl.len().cast(pl.UInt32), + pl.lit(length, pl.UInt32) if isinstance(length, int) else length, pl.lit(n, pl.Int32), pl.lit(p, pl.Float64), pl.lit(seed, pl.UInt64), @@ -610,7 +630,7 @@ def random_binomial(n: int, p: int, seed: int | None = None) -> pl.Expr: ) -def random_exp(lambda_: float, seed: int | None = None) -> pl.Expr: +def random_exp(lambda_: float, seed: int | None = None, length: int | pl.Expr = pl.len()) -> pl.Expr: """ Generates random numbers following an exponential distribution. 
@@ -620,15 +640,26 @@ def random_exp(lambda_: float, seed: int | None = None) -> pl.Expr: The lambda in an exponential distribution seed The random seed. None means no seed. + length + Custom length. Note length needs to match with other columns in the context. """ return pl_plugin( symbol="pl_rand_exp", - args=[pl.len().cast(pl.UInt32), pl.lit(lambda_, pl.Float64), pl.lit(seed, pl.UInt64)], + args=[ + pl.lit(length, pl.UInt32) if isinstance(length, int) else length, + pl.lit(lambda_, pl.Float64), + pl.lit(seed, pl.UInt64) + ], is_elementwise=True, ) -def random_normal(mean: pl.Expr | float, std: pl.Expr | float, seed: int | None = None) -> pl.Expr: +def random_normal( + mean: pl.Expr | float, + std: pl.Expr | float, + seed: int | None = None, + length: int | pl.Expr = pl.len() +) -> pl.Expr: """ Generates random number following a normal distribution. @@ -640,12 +671,17 @@ def random_normal(mean: pl.Expr | float, std: pl.Expr | float, seed: int | None The std in a normal distribution seed The random seed. None means no seed. + length + Custom length. Note length needs to match with other columns in the context. """ - m = pl.lit(mean, pl.Float64) if isinstance(mean, float) else mean - s = pl.lit(std, pl.Float64) if isinstance(std, float) else std return pl_plugin( symbol="pl_rand_normal", - args=[pl.len().cast(pl.UInt32), m, s, pl.lit(seed, pl.UInt64)], + args=[ + pl.lit(length, pl.UInt32) if isinstance(length, int) else length, + pl.lit(mean, pl.Float64) if isinstance(mean, float) else mean, + pl.lit(std, pl.Float64) if isinstance(std, float) else std, + pl.lit(seed, pl.UInt64) + ], is_elementwise=True, )