Skip to content

Commit

Permalink
working frac diff
Browse files Browse the repository at this point in the history
  • Loading branch information
ngriffiths13 committed Nov 15, 2024
1 parent 8a1f927 commit 37c9eb5
Show file tree
Hide file tree
Showing 10 changed files with 340 additions and 148 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ pyo3-polars = { version = "*", features = ["derive", "dtype-struct"] }
serde = { version = "1", features = ["derive"] }
polars = { version = "*", features = ["dtype-struct"] }
num = "0.4.3"
polars-arrow = "0.44.2"
29 changes: 28 additions & 1 deletion polars_trading/_testing/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,38 @@ def frac_diff_ffd(series: pd.DataFrame, d: float, thresh: float = 1e-5) -> pd.Da
width = len(w) - 1
df = {}
for name in series.columns:
series_f = series[[name]].fillna(method="ffill").dropna()
series_f = series[[name]].ffill().dropna()
df_ = pd.Series()
for iloc1 in range(width, series_f.shape[0]):
loc0 = series_f.index[iloc1 - width]
loc1 = series_f.index[iloc1]
df_[loc1] = np.dot(w.T, series_f.loc[loc0:loc1])[0, 0]
df[name] = df_.copy(deep=True)
return pd.concat(df, axis=1)


if __name__ == "__main__":
import polars as pl

from polars_trading._testing.data import generate_trade_data
from polars_trading.features.frac_diff import frac_diff

pl.Config.set_verbose(True)
data = generate_trade_data(10_000, n_companies=3).sort("ts_event")
print(

Check failure on line 61 in polars_trading/_testing/features.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (T201)

polars_trading/_testing/features.py:61:5: T201 `print` found
data.with_columns(
frac_diff("price", 0.5, 1e-3).over("symbol").alias("frac_diff")
).sort("symbol")
)

print(

Check failure on line 67 in polars_trading/_testing/features.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (T201)

polars_trading/_testing/features.py:67:5: T201 `print` found
pl.from_pandas(
data.to_pandas()
.set_index("ts_event")
.groupby("symbol")[["price"]]
.apply(frac_diff_ffd, 0.5, 1e-3),
include_index=True,
).sort("symbol")
)

# print(frac_diff_ffd(data.to_pandas().set_index("ts_event")[["price"]], 0.5, 1e-3))

Check failure on line 77 in polars_trading/_testing/features.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (ERA001)

polars_trading/_testing/features.py:77:5: ERA001 Found commented-out code
31 changes: 31 additions & 0 deletions polars_trading/features/frac_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Module containing functions to generate fractionally differentiated features."""

import polars as pl
from polars.plugins import register_plugin_function

from polars_trading._utils import LIB
from polars_trading.typing import IntoExpr


def frac_diff(expr: IntoExpr, d: float, threshold: float) -> pl.Expr:
    """Build a plugin expression that fractionally differentiates ``expr``.

    The computation itself is delegated to the ``frac_diff`` function exposed by
    the compiled plugin library (``LIB``); this wrapper only wires up the call.

    Args:
    ----
        expr: IntoExpr - The input whose values are fractionally differentiated.
        d: float - The fractional difference, forwarded to the plugin as ``d``.
        threshold: float - The threshold, forwarded to the plugin as ``threshold``.

    Returns:
    -------
        pl.Expr: The expression to calculate the fractionally differentiated series.

    """
    plugin_kwargs = {"d": d, "threshold": threshold}
    return register_plugin_function(
        plugin_path=LIB,
        function_name="frac_diff",
        args=[expr],
        kwargs=plugin_kwargs,
        # Registered as non-elementwise, so the plugin receives the column as a whole.
        is_elementwise=False,
    )
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ dev-dependencies = [
"build>=1.2.1",
"commitizen>=3.29.0",
"marimo>=0.8.8",
"maturin>=1.7.1",
"maturin[patchelf]>=1.7.1",
"mkdocs-gen-files>=0.5.0",
"mkdocstrings[python]>=0.26.1",
"pip>=24.2",
Expand Down
45 changes: 45 additions & 0 deletions src/frac_diff.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
use polars::prelude::*;
use polars_arrow::bitmap::MutableBitmap;
use pyo3_polars::derive::polars_expr;

use serde::Deserialize;

pub fn get_weights_ffd(d: f64, threshold: f64) -> Vec<f64> {
let mut w = vec![1.];
let mut k = 1.0;
Expand All @@ -12,3 +18,42 @@ pub fn get_weights_ffd(d: f64, threshold: f64) -> Vec<f64> {
w.reverse();
w
}

/// Returns the dot product of `a` and `b`.
///
/// Elements are paired positionally; if the slices differ in length, the extra
/// elements of the longer one are ignored.
fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    let products = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs);
    products.sum()
}

/// Keyword arguments accepted by the `frac_diff` plugin expression.
///
/// Deserialized from the kwargs supplied by the caller of the plugin; both fields
/// are forwarded unchanged to `get_weights_ffd`.
#[derive(Deserialize)]
struct FracDiffKwargs {
/// The fractional difference, passed to `get_weights_ffd` as `d`.
d: f64,
/// The threshold, passed to `get_weights_ffd` as `threshold`.
threshold: f64,
}

/// Computes the fractionally differentiated series of the first input column.
///
/// Weights are obtained from `get_weights_ffd(kwargs.d, kwargs.threshold)`. Each
/// output row is the dot product of those weights with the window of prices that
/// ends at that row. The first `weights.len() - 1` rows do not have a full window
/// and are emitted as nulls. The resulting series is named `"frac_diff"`.
///
/// # Errors
///
/// Returns an error when the input column is not `Float64`, when it contains any
/// null value, or when no weights were produced.
#[polars_expr(output_type=Float64)]
fn frac_diff(inputs: &[Series], kwargs: FracDiffKwargs) -> PolarsResult<Series> {
    // `f64()` fails on a dtype mismatch; propagate that as an error instead of
    // panicking inside the plugin. `to_vec_null_aware` yields the left variant
    // only when the column has no nulls.
    let Some(prices) = inputs[0].f64()?.to_vec_null_aware().left() else {
        return Err(PolarsError::InvalidOperation("Null price found".into()));
    };

    let weights = get_weights_ffd(kwargs.d, kwargs.threshold);
    let n_weights = weights.len();
    if n_weights == 0 {
        // `windows(0)` would panic, so reject this case explicitly.
        return Err(PolarsError::ComputeError(
            "frac_diff requires at least one weight".into(),
        ));
    }

    let n_prices = prices.len();
    // Rows that come before the first complete window; capped for short inputs.
    let n_warmup = (n_weights - 1).min(n_prices);

    // Warm-up rows get a placeholder value and are masked out as null below.
    let mut outputs: Vec<f64> = Vec::with_capacity(n_prices);
    outputs.resize(n_warmup, 0.0);
    outputs.extend(
        prices
            .windows(n_weights)
            .map(|window| dot_product(window, &weights)),
    );

    let mut validity_mask = MutableBitmap::with_capacity(n_prices);
    validity_mask.extend_constant(n_warmup, false);
    validity_mask.extend_constant(n_prices - n_warmup, true);

    Ok(
        Float64Chunked::from_vec_validity("frac_diff".into(), outputs, validity_mask.into())
            .into_series(),
    )
}
58 changes: 57 additions & 1 deletion tests/features/test_frac_diff.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,73 @@
from polars_trading._testing.features import get_weights_ffd
from polars_trading._testing.features import get_weights_ffd, frac_diff_ffd
from polars_trading._internal import get_weights_ffd_py
from polars_trading.features.frac_diff import frac_diff
import pytest
from polars.testing import assert_frame_equal
import pandas as pd
import polars as pl


def apply_pd_frac_diff(df: pd.DataFrame, d: float, threshold: float) -> pd.DataFrame:
    """Run the pandas reference ``frac_diff_ffd`` separately for every symbol.

    Args:
        df: Frame providing the ``ts_event``, ``symbol`` and ``price`` columns.
        d: Forwarded to ``frac_diff_ffd`` as its ``d`` argument.
        threshold: Forwarded to ``frac_diff_ffd`` as its threshold argument.

    Returns:
        The per-symbol results with the group index flattened back into columns.
    """
    return (
        df.set_index("ts_event")
        .groupby("symbol")[["price"]]
        # Forward the caller's parameters rather than fixed constants so that
        # callers asking for a different threshold actually get it.
        .apply(frac_diff_ffd, d, threshold)
        .reset_index()
    )


def test__get_weights_ffd__matches_pandas():
    """The `_internal` weights must equal the flattened reference weights."""
    reference_weights = get_weights_ffd(0.5, 1e-3).flatten().tolist()
    internal_weights = get_weights_ffd_py(0.5, 1e-3)
    assert reference_weights == internal_weights


@pytest.mark.benchmark(group="get_weights_ffd")
def test__get_weights_ffd__benchmark_rs(benchmark):
    """Benchmark `get_weights_ffd_py` from the `_internal` module."""
    d, threshold = 0.5, 1e-5
    benchmark(get_weights_ffd_py, d, threshold)


@pytest.mark.benchmark(group="get_weights_ffd")
@pytest.mark.pandas
def test__get_weights_ffd__benchmark_pandas(benchmark):
    """Benchmark the reference `get_weights_ffd` from the `_testing` module."""
    d, threshold = 0.5, 1e-5
    benchmark(get_weights_ffd, d, threshold)


@pytest.mark.parametrize(
    "trade_data", [{"n_rows": 10_000, "n_companies": 3}], indirect=True
)
def test__frac_diff__matches_pandas(trade_data):
    """The plugin output must match the pandas reference, symbol by symbol."""
    trade_data = trade_data.sort("ts_event")
    out = trade_data.select(
        "ts_event",
        "symbol",
        # The pandas reference is grouped by symbol, so the plugin has to be
        # evaluated per symbol as well for the two results to be comparable.
        frac_diff("price", 0.5, 1e-3).over("symbol").alias("frac_diff"),
    )
    out2 = (
        pl.DataFrame(apply_pd_frac_diff(trade_data.to_pandas(), 0.5, 1e-3))
        .rename({"level_1": "ts_event", "price": "frac_diff"})
        .cast(out.schema)
    )
    assert_frame_equal(
        # Rows without a full window are null in the plugin output and absent
        # from the pandas output.
        out.drop_nulls().sort("ts_event", "symbol"),
        out2.sort("ts_event", "symbol"),
        check_column_order=False,
    )


@pytest.mark.benchmark(group="frac_diff")
@pytest.mark.parametrize(
    "trade_data", [{"n_rows": 10_000, "n_companies": 3}], indirect=True
)
def test__frac_diff__benchmark_polars(benchmark, trade_data):
    """Benchmark collecting a lazy query that applies the `frac_diff` plugin."""
    query = trade_data.lazy().select(
        "ts_event", "symbol", frac_diff("price", 0.5, 1e-3).alias("frac_diff")
    )
    # Only `collect` is timed; the lazy query is built once up front.
    benchmark(query.collect)


@pytest.mark.benchmark(group="frac_diff")
@pytest.mark.pandas
@pytest.mark.parametrize(
    "trade_data", [{"n_rows": 10_000, "n_companies": 3}], indirect=True
)
def test__frac_diff__benchmark_pandas(benchmark, trade_data):
    """Benchmark `apply_pd_frac_diff` on the pandas copy of the trade data."""
    # Convert once, outside the timed call.
    pandas_trades = trade_data.to_pandas()
    benchmark(apply_pd_frac_diff, pandas_trades, 0.5, 1e-5)
16 changes: 12 additions & 4 deletions tests/labels/test_dynamic_labels.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from polars_trading.labels.dynamic_labels import daily_vol, get_vertical_barrier_by_timedelta
from polars_trading.labels.dynamic_labels import (
daily_vol,
get_vertical_barrier_by_timedelta,
)
import polars as pl
from polars_trading._testing.labels import get_daily_vol
import pytest
from polars.testing import assert_frame_equal
from datetime import datetime, timedelta


@pytest.mark.parametrize(
"trade_data",
[
Expand All @@ -30,8 +34,9 @@ def test__daily_vol__single_security(trade_data):
)
def test__daily_vol__multi_security(trade_data):
pl_result = (
daily_vol(trade_data, "ts_event", "price", "symbol", 5)
.sort("ts_event", "symbol")
daily_vol(trade_data, "ts_event", "price", "symbol", 5).sort(
"ts_event", "symbol"
)
).drop_nulls()
pd_result = (
trade_data.to_pandas()
Expand Down Expand Up @@ -163,6 +168,7 @@ def test__get_vertical_barrier_by_timedelta__simple():
result = get_vertical_barrier_by_timedelta(df.lazy(), "ts_event", "2h").collect()
assert_frame_equal(result, expected)


def test__get_vertical_barrier_by_timedelta__skip_rows():
df = pl.DataFrame(
{
Expand Down Expand Up @@ -243,7 +249,9 @@ def test__get_vertical_barrier_by_timedelta__timedelta():
└─────────────────────┴─────────────────────┘
""")

result = get_vertical_barrier_by_timedelta(df.lazy(), "ts_event", timedelta(hours=2)).collect()
result = get_vertical_barrier_by_timedelta(
df.lazy(), "ts_event", timedelta(hours=2)
).collect()
assert_frame_equal(result, expected)


Expand Down
6 changes: 5 additions & 1 deletion tests/labels/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,8 @@ def test__fixed_time_return_classification__expr_threshold():
indirect=True,
)
def test__fixed_time_return_classification__benchmark(trade_data, benchmark):
benchmark(trade_data.lazy().select(fixed_time_return_classification("price", 50, 0.2, symbol="symbol")).collect)
benchmark(
trade_data.lazy()
.select(fixed_time_return_classification("price", 50, 0.2, symbol="symbol"))
.collect
)
Loading

0 comments on commit 37c9eb5

Please sign in to comment.