diff --git a/polars_business/polars_business/polars_business/__init__.py b/polars_business/polars_business/polars_business/__init__.py index 73464be..84acf2f 100644 --- a/polars_business/polars_business/polars_business/__init__.py +++ b/polars_business/polars_business/polars_business/__init__.py @@ -4,11 +4,20 @@ from polars.utils.udfs import _get_shared_lib_location import re from datetime import date +import sys from polars_business.ranges import date_range from polars.type_aliases import PolarsDataType -from typing import Sequence, cast, Iterable, Protocol +from typing import Iterable, Literal, Protocol, Sequence, cast, get_args + +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias + +RollStrategy: TypeAlias = Literal["raise", "forward", "backward"] + lib = _get_shared_lib_location(__file__) @@ -81,7 +90,7 @@ def offset_by( *, weekend: Sequence[str] = ("Sat", "Sun"), holidays: Sequence[date] | None = None, - roll: str = "raise", + roll: RollStrategy = "raise", ) -> BExpr: """ Offset this date by a relative time offset. @@ -173,6 +182,12 @@ def offset_by( │ 2024-01-04 ┆ -3bd ┆ 2024-01-01 │ └────────────┴──────┴──────────────┘ """ + if roll not in (valid_roll_strategies := get_args(RollStrategy)): + allowed = ", ".join(repr(m) for m in valid_roll_strategies) + raise ValueError( + f"`roll` strategy must be one of {{{allowed}}}, got {roll!r}" + ) + if ( isinstance(by, str) and (match := re.search(r"(\d+bd)", by)) is not None diff --git a/polars_business/polars_business/src/business_days.rs b/polars_business/polars_business/src/business_days.rs index aae3054..a350ee5 100644 --- a/polars_business/polars_business/src/business_days.rs +++ b/polars_business/polars_business/src/business_days.rs @@ -25,7 +25,9 @@ pub(crate) fn calculate_advance( let date = NaiveDateTime::from_timestamp_opt(date as i64 * 24 * 60 * 60, 0) .unwrap() .format("%Y-%m-%d"); - polars_bail!(ComputeError: format!("date {} is not a business date, cannot advance. `roll` argument coming soon.", date)) + polars_bail!(ComputeError: + format!("date {} is not a business date, cannot advance; set a valid `roll` strategy.", date) + ) }; } "forward" => { @@ -50,7 +52,11 @@ pub(crate) fn calculate_advance( } } } - _ => unreachable!(), + _ => { + polars_bail!(InvalidOperation: + "`roll` must be one of 'raise', 'forward' or 'backward'; found '{}'", roll + ) + } } if offset > 0 { diff --git a/polars_business/tests/test_business_offsets.py b/polars_business/tests/test_business_offsets.py index 446fe93..8e18289 100644 --- a/polars_business/tests/test_business_offsets.py +++ b/polars_business/tests/test_business_offsets.py @@ -220,3 +220,11 @@ def test_within_group_by() -> None: } ) assert_frame_equal(result, expected) + + +def test_invalid_roll_strategy() -> None: + df = pl.DataFrame( + {"date": pl.date_range(dt.date(2023, 12, 1), dt.date(2023, 12, 5), eager=True)} + ) + with pytest.raises(ValueError): + df.with_columns(plb.col("date").bdt.offset_by("1bd", roll="cabbage")) # type: ignore[arg-type]