-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(labels): add daily vol calculation (#1)
This PR adds daily vol function from Lopez de Prado AFML pg.44 --------- Co-authored-by: Nelson Griffiths <[email protected]>
- Loading branch information
1 parent
1266bae
commit f2137ea
Showing
8 changed files
with
239 additions
and
175 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import pandas as pd | ||
|
||
|
||
def get_daily_vol(close: pd.Series, span0: int = 100) -> pd.Series:
    """Return the exponentially weighted volatility of roughly 24-hour returns.

    Each price is compared with the first observation found at or after the
    instant exactly one day earlier. The exponentially weighted moving
    standard deviation of those returns is then computed.

    Args:
        close: Prices indexed by timestamp. The series does not have to be
            sorted; it is sorted by index internally.
        span0: Span of the exponentially weighted window. Defaults to 100.

    Returns:
        The volatility series, indexed by the timestamps for which a return
        was computed. Rows whose look-back position resolves to the very first
        observation are dropped, and NaN results (for example the first
        element, which has no spread yet) are replaced by 0.0.
    """
    # sort_index already returns a new object, so no extra copy is needed.
    # A stable sort keeps rows that share a timestamp in their input order.
    close = close.sort_index(kind="mergesort")
    timestamps = close.index

    # Position of the first observation at or after "one day ago" per row.
    lookback_pos = timestamps.searchsorted(timestamps - pd.Timedelta(days=1))
    lookback_pos = lookback_pos[lookback_pos > 0]

    # Positions are non-decreasing on a sorted index, so the surviving entries
    # always correspond to the last ``len(lookback_pos)`` rows of the series.
    current = close.iloc[close.shape[0] - lookback_pos.shape[0] :]

    # Index by position rather than by label: label lookups fan out when a
    # timestamp is repeated, which makes numerator and denominator lengths
    # disagree. Positional access pairs every row with exactly one price.
    previous = close.to_numpy()[lookback_pos]

    returns = current / previous - 1
    return returns.ewm(span=span0).std().fillna(0.0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from __future__ import annotations | ||
|
||
from datetime import timedelta | ||
from typing import TYPE_CHECKING | ||
|
||
import polars as pl | ||
|
||
if TYPE_CHECKING: | ||
from polars_trading.typing import FrameType | ||
|
||
|
||
def daily_vol(
    df: FrameType,
    timestamp_col: str,
    price_col: str,
    symbol_col: str | None = None,
    span: int = 100,
) -> FrameType:
    """Calculate the daily volatility of a price series.

    For every row, the return is measured from the first price inside the
    trailing 24-hour window to the current (last) price in that window. The
    exponentially weighted standard deviation of those returns is reported as
    the daily volatility.

    Args:
        df (DataFrame): The DataFrame containing the price series.
        timestamp_col (str): The column name containing the timestamps.
        price_col (str): The column name containing the prices.
        symbol_col (str | None): The column name containing the symbols. If
            None, it is assumed that the prices are for a single symbol.
            Defaults to None.
        span (int): The span of the exponentially weighted standard deviation.
            Defaults to 100.

    Returns:
        FrameType: A frame with ``timestamp_col``, ``symbol_col`` (only when it
        was provided) and a ``daily_return_volatility`` column. Rows that are
        not more than 24 hours past the earliest timestamp (per symbol when a
        symbol column is given) are excluded.
    """
    by_symbol = symbol_col is not None

    returns = (
        df.sort(timestamp_col)
        .rolling(timestamp_col, period="24h", group_by=symbol_col)
        .agg(pl.last(price_col).truediv(pl.first(price_col)).sub(1).alias("return"))
    )

    first_timestamp = pl.col(timestamp_col).min()
    vol_expr = pl.col("return").ewm_std(span=span)
    if by_symbol:
        # Only attach a window when there is something to partition by.
        # NOTE(review): newer polars releases appear to reject ``over`` with
        # no partition/order expression -- confirm against the pinned version.
        first_timestamp = first_timestamp.over(symbol_col)
        vol_expr = vol_expr.over(symbol_col)
    vol_expr = vol_expr.alias("daily_return_volatility")

    # Keep only rows whose 24-hour look-back starts after the first observation.
    returns = returns.filter(
        (pl.col(timestamp_col) - timedelta(hours=24)) > first_timestamp
    )

    if by_symbol:
        return_cols = [timestamp_col, symbol_col, vol_expr]
    else:
        return_cols = [timestamp_col, vol_expr]

    return returns.select(*return_cols)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import pytest | ||
from polars_trading._testing.data import generate_trade_data | ||
|
||
|
||
@pytest.fixture
def trade_data(request):
    """Build trade data from the keyword arguments supplied via ``request.param``."""
    generator_kwargs = request.param
    return generate_trade_data(**generator_kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from polars_trading.labels.dynamic_labels import daily_vol | ||
import polars as pl | ||
from polars_trading._testing.labels import get_daily_vol | ||
import pytest | ||
from polars.testing import assert_frame_equal | ||
|
||
|
||
@pytest.mark.parametrize(
    "trade_data",
    [{"n_rows": 10_000, "n_companies": 1}],
    indirect=True,
)
def test__daily_vol__single_security(trade_data):
    """Compare the polars result with the pandas reference for one security."""
    span = 5

    # Polars implementation under test.
    actual = daily_vol(trade_data.lazy(), "ts_event", "price", None, span).collect()

    # Pandas reference, converted back to polars with the matching column name.
    prices = trade_data.to_pandas().set_index("ts_event")["price"]
    reference = get_daily_vol(prices, span)
    expected = pl.from_pandas(reference.reset_index()).rename(
        {"price": "daily_return_volatility"}
    )

    assert_frame_equal(actual, expected)
|
||
|
||
@pytest.mark.parametrize(
    "trade_data",
    [{"n_rows": 10_000, "n_companies": 3}],
    indirect=True,
)
def test__daily_vol__multi_security(trade_data):
    """Compare the polars result with the per-symbol pandas reference."""
    sort_keys = ("ts_event", "symbol")

    # Polars implementation under test, grouped by symbol.
    lazy_result = daily_vol(trade_data.lazy(), "ts_event", "price", "symbol", 5)
    actual = lazy_result.collect().sort(*sort_keys)

    # Pandas reference applied to each symbol separately.
    pandas_frame = trade_data.to_pandas().set_index("ts_event")[["symbol", "price"]]
    reference = pandas_frame.groupby("symbol")["price"].apply(get_daily_vol, 5)
    expected = pl.from_pandas(reference.reset_index())
    expected = expected.rename({"price": "daily_return_volatility"})
    expected = expected.sort(*sort_keys)

    assert_frame_equal(
        actual, expected, check_row_order=False, check_column_order=False
    )
|
||
|
||
@pytest.mark.benchmark(group="daily_vol")
@pytest.mark.parametrize(
    "trade_data",
    [
        {"n_rows": 10_000, "n_companies": 3},
        {"n_rows": 100_000, "n_companies": 5},
        {"n_rows": 1_000_000, "n_companies": 10},
    ],
    indirect=True,
)
def test__daily_vol__polars_benchmark(benchmark, trade_data):
    """Benchmark collecting the lazy polars daily-volatility query."""
    # The lazy query is built once; only ``collect`` is handed to the benchmark.
    lazy_result = daily_vol(trade_data.lazy(), "ts_event", "price", "symbol", 100)
    benchmark(lazy_result.collect)
|
||
|
||
@pytest.mark.benchmark(group="daily_vol")
@pytest.mark.parametrize(
    "trade_data",
    [
        {"n_rows": 10_000, "n_companies": 3},
        {"n_rows": 100_000, "n_companies": 5},
        {"n_rows": 1_000_000, "n_companies": 10},
    ],
    indirect=True,
)
def test__daily_vol__pandas_benchmark(benchmark, trade_data):
    """Benchmark the pandas reference implementation applied per symbol."""
    # The conversion to pandas happens up front, outside the benchmarked call.
    indexed = trade_data.to_pandas().set_index("ts_event")
    pd_df = indexed[["symbol", "price"]]

    def get_daily_vol_pd(frame):
        per_symbol = frame.groupby("symbol")["price"].apply(get_daily_vol)
        return per_symbol.reset_index()

    benchmark(get_daily_vol_pd, pd_df)
Oops, something went wrong.