-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8a1f927
commit 37c9eb5
Showing
10 changed files
with
340 additions
and
148 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
"""Module containing functions to generate fractionally differentiated features.""" | ||
|
||
import polars as pl | ||
from polars.plugins import register_plugin_function | ||
|
||
from polars_trading._utils import LIB | ||
from polars_trading.typing import IntoExpr | ||
|
||
|
||
def frac_diff(expr: IntoExpr, d: float, threshold: float) -> pl.Expr: | ||
"""Generate expression to calculate the fractionally differentiated series. | ||
Args: | ||
---- | ||
expr: IntoExpr - The expression to calculate the fractionally differentiated | ||
series. | ||
d: float - The fractional difference. | ||
threshold: float - The threshold. | ||
Returns: | ||
------- | ||
pl.Expr: The expression to calculate the fractionally differentiated series. | ||
""" | ||
return register_plugin_function( | ||
plugin_path=LIB, | ||
args=[expr], | ||
kwargs={"d": d, "threshold": threshold}, | ||
is_elementwise=False, | ||
function_name="frac_diff", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,73 @@ | ||
from polars_trading._testing.features import get_weights_ffd | ||
from polars_trading._testing.features import get_weights_ffd, frac_diff_ffd | ||
from polars_trading._internal import get_weights_ffd_py | ||
from polars_trading.features.frac_diff import frac_diff | ||
import pytest | ||
from polars.testing import assert_frame_equal | ||
import pandas as pd | ||
import polars as pl | ||
|
||
|
||
def apply_pd_frac_diff(df: pd.DataFrame, d: float, threshold: float) -> pd.DataFrame: | ||
return ( | ||
df.set_index("ts_event") | ||
.groupby("symbol")[["price"]] | ||
.apply(frac_diff_ffd, 0.5, 1e-3) | ||
.reset_index() | ||
) | ||
|
||
|
||
def test__get_weights_ffd__matches_pandas(): | ||
out = get_weights_ffd(0.5, 1e-3).flatten().tolist() | ||
out2 = get_weights_ffd_py(0.5, 1e-3) | ||
assert out == out2 | ||
|
||
|
||
@pytest.mark.benchmark(group="get_weights_ffd") | ||
def test__get_weights_ffd__benchmark_rs(benchmark): | ||
benchmark(get_weights_ffd_py, 0.5, 1e-5) | ||
|
||
|
||
@pytest.mark.benchmark(group="get_weights_ffd") | ||
@pytest.mark.pandas | ||
def test__get_weights_ffd__benchmark_pandas(benchmark): | ||
benchmark(get_weights_ffd, 0.5, 1e-5) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"trade_data", [{"n_rows": 10_000, "n_companies": 3}], indirect=True | ||
) | ||
def test__frac_diff__matches_pandas(trade_data): | ||
trade_data = trade_data.sort("ts_event") | ||
out = trade_data.select( | ||
"ts_event", "symbol", frac_diff("price", 0.5, 1e-3).alias("frac_diff") | ||
) | ||
out2 = pl.DataFrame(apply_pd_frac_diff(trade_data.to_pandas(), 0.5, 1e-3)).rename( | ||
{"level_1": "ts_event", "price": "frac_diff"} | ||
).cast(out.schema) | ||
assert_frame_equal( | ||
out.drop_nulls().sort("ts_event", "symbol"), | ||
out2.sort("ts_event", "symbol"), | ||
check_column_order=False, | ||
) | ||
|
||
|
||
@pytest.mark.benchmark(group="frac_diff") | ||
@pytest.mark.parametrize( | ||
"trade_data", [{"n_rows": 10_000, "n_companies": 3}], indirect=True | ||
) | ||
def test__frac_diff__benchmark_polars(benchmark, trade_data): | ||
trade_data = trade_data.lazy() | ||
benchmark( | ||
trade_data.select( | ||
"ts_event", "symbol", frac_diff("price", 0.5, 1e-3).alias("frac_diff") | ||
).collect | ||
) | ||
|
||
|
||
@pytest.mark.benchmark(group="frac_diff") | ||
@pytest.mark.pandas | ||
@pytest.mark.parametrize( | ||
"trade_data", [{"n_rows": 10_000, "n_companies": 3}], indirect=True | ||
) | ||
def test__frac_diff__benchmark_pandas(benchmark, trade_data): | ||
benchmark(apply_pd_frac_diff, trade_data.to_pandas(), 0.5, 1e-5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.