Skip to content

Commit

Permalink
Merge pull request #33 from alexander-beedie/minor-roll-param-tweaks
Browse files Browse the repository at this point in the history
fix: don't panic on invalid `roll` strategy, and update an error message
  • Loading branch information
MarcoGorelli authored Dec 5, 2023
2 parents a637b0d + 1eae985 commit b48bdfd
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
19 changes: 17 additions & 2 deletions polars_business/polars_business/polars_business/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,20 @@
from polars.utils.udfs import _get_shared_lib_location
import re
from datetime import date
import sys

from polars_business.ranges import date_range

from polars.type_aliases import PolarsDataType
from typing import Sequence, cast, Iterable, Protocol
from typing import Iterable, Literal, Protocol, Sequence, cast, get_args

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

RollStrategy: TypeAlias = Literal["raise", "forward", "backward"]


lib = _get_shared_lib_location(__file__)

Expand Down Expand Up @@ -81,7 +90,7 @@ def offset_by(
*,
weekend: Sequence[str] = ("Sat", "Sun"),
holidays: Sequence[date] | None = None,
roll: str = "raise",
roll: RollStrategy = "raise",
) -> BExpr:
"""
Offset this date by a relative time offset.
Expand Down Expand Up @@ -173,6 +182,12 @@ def offset_by(
│ 2024-01-04 ┆ -3bd ┆ 2024-01-01 │
└────────────┴──────┴──────────────┘
"""
if roll not in (valid_roll_strategies := get_args(RollStrategy)):
allowed = ", ".join(repr(m) for m in valid_roll_strategies)
raise ValueError(
f"`roll` strategy must be one of {{{allowed}}}, got {roll!r}"
)

if (
isinstance(by, str)
and (match := re.search(r"(\d+bd)", by)) is not None
Expand Down
10 changes: 8 additions & 2 deletions polars_business/polars_business/src/business_days.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ pub(crate) fn calculate_advance(
let date = NaiveDateTime::from_timestamp_opt(date as i64 * 24 * 60 * 60, 0)
.unwrap()
.format("%Y-%m-%d");
polars_bail!(ComputeError: format!("date {} is not a business date, cannot advance. `roll` argument coming soon.", date))
polars_bail!(ComputeError:
format!("date {} is not a business date, cannot advance; set a valid `roll` strategy.", date)
)
};
}
"forward" => {
Expand All @@ -50,7 +52,11 @@ pub(crate) fn calculate_advance(
}
}
}
_ => unreachable!(),
_ => {
polars_bail!(InvalidOperation:
"`roll` must be one of 'raise', 'forward' or 'backward'; found '{}'", roll
)
}
}

if offset > 0 {
Expand Down
8 changes: 8 additions & 0 deletions polars_business/tests/test_business_offsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,11 @@ def test_within_group_by() -> None:
}
)
assert_frame_equal(result, expected)


def test_invalid_roll_strategy() -> None:
df = pl.DataFrame(
{"date": pl.date_range(dt.date(2023, 12, 1), dt.date(2023, 12, 5), eager=True)}
)
with pytest.raises(ValueError):
df.with_columns(plb.col("date").bdt.offset_by("1bd", roll="cabbage")) # type: ignore[arg-type]

0 comments on commit b48bdfd

Please sign in to comment.