Skip to content

Commit

Permalink
feat(labels): add daily vol calculation (#1)
Browse files Browse the repository at this point in the history
This PR adds daily vol function from Lopez de Prado AFML pg.44

---------

Co-authored-by: Nelson Griffiths <[email protected]>
  • Loading branch information
ngriffiths13 and Nelson Griffiths authored Sep 20, 2024
1 parent 1266bae commit f2137ea
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 175 deletions.
157 changes: 0 additions & 157 deletions .github/workflows/publish_to_pypi.yml

This file was deleted.

9 changes: 6 additions & 3 deletions polars_trading/_testing/data.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from functools import lru_cache

import polars as pl
from mimesis import Fieldset
from mimesis import Fieldset, Finance
from mimesis.locales import Locale


@lru_cache
def generate_trade_data(n_rows: int) -> pl.DataFrame:
def generate_trade_data(n_rows: int, n_companies: int = 3) -> pl.DataFrame:
fs = Fieldset(locale=Locale.EN, i=n_rows)

return pl.DataFrame(
Expand All @@ -16,6 +16,9 @@ def generate_trade_data(n_rows: int) -> pl.DataFrame:
),
"price": fs("finance.price", minimum=1, maximum=100),
"size": fs("numeric.integer_number", start=10_000, end=100_000),
"symbol": fs("choice.choice", items=["AAPL", "GOOGL", "MSFT"]),
"symbol": fs(
"choice.choice",
items=[Finance().stock_ticker() for _ in range(n_companies)],
),
}
)
14 changes: 14 additions & 0 deletions polars_trading/_testing/labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pandas as pd


def get_daily_vol(close: pd.Series, span0: int = 100) -> pd.Series:
# This function calculates returns as close to 24 hours ago as possible and then
# calculates the exponentially weighted moving standard deviation of those returns.
close = close.sort_index().copy()
df0 = close.index.searchsorted(close.index - pd.Timedelta(days=1))
df0 = df0[df0 > 0]
df0 = pd.Series(
close.index[df0], index=close.index[close.shape[0] - df0.shape[0] :]
)
df0 = close.loc[df0.index] / close.loc[df0.values].values - 1
return df0.ewm(span=span0).std().fillna(0.0)
66 changes: 66 additions & 0 deletions polars_trading/labels/dynamic_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

from datetime import timedelta
from typing import TYPE_CHECKING

import polars as pl

if TYPE_CHECKING:
from polars_trading.typing import FrameType


def daily_vol(
df: FrameType,
timestamp_col: str,
price_col: str,
symbol_col: str | None = None,
span: int = 100,
) -> FrameType:
"""This function calculates the daily volatility of a price series.
It uses an the daily volatiity by looking back at the return from the oldest price
in the last 24 hour period to the current price. It then calculates the exponential
weighted standard deviation of the returns.
Args:
----
df (DataFrame): The DataFrame containing the price series.
timestamp_col (str): The column name containing the timestamps.
price_col (str): The column name containing the prices.
symbol_col (str | None): The column name containing the symbols. If None, it is
assumed that the prices are for a single symbol. Defaults to None.
span (int): The span of the exponential weighted standard deviation. Defaults to
100.
Returns:
FrameType: The DataFrame with the daily volatility.
"""
returns = (
df.sort(timestamp_col)
.rolling(timestamp_col, period="24h", group_by=symbol_col)
.agg(pl.last(price_col).truediv(pl.first(price_col)).sub(1).alias("return"))
)
returns = returns.filter(
(pl.col(timestamp_col) - timedelta(hours=24))
> pl.col(timestamp_col).min().over(symbol_col)
)
vol_expr = (
pl.col("return")
.ewm_std(span=span)
.over(symbol_col)
.alias("daily_return_volatility")
)
return_cols = (
[
timestamp_col,
symbol_col,
vol_expr,
]
if symbol_col
else [
timestamp_col,
vol_expr,
]
)

return returns.select(*return_cols)
4 changes: 2 additions & 2 deletions polars_trading/labels/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def fixed_time_return(
pl.Expr: The fixed time return as an expression.
"""
return_expr = (
pl.col(prices)
parse_into_expr(prices)
.shift(-offset - window)
.truediv(pl.col(prices).shift(-offset))
.truediv(parse_into_expr(prices).shift(-offset))
.sub(1)
)
if symbol is not None:
Expand Down
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import pytest
from polars_trading._testing.data import generate_trade_data


@pytest.fixture
def trade_data(request):
return generate_trade_data(**request.param)
84 changes: 84 additions & 0 deletions tests/labels/test_dynamic_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from polars_trading.labels.dynamic_labels import daily_vol
import polars as pl
from polars_trading._testing.labels import get_daily_vol
import pytest
from polars.testing import assert_frame_equal


@pytest.mark.parametrize(
"trade_data",
[
{"n_rows": 10_000, "n_companies": 1},
],
indirect=True,
)
def test__daily_vol__single_security(trade_data):
pl_result = daily_vol(trade_data.lazy(), "ts_event", "price", None, 5).collect()
pd_result = get_daily_vol(trade_data.to_pandas().set_index("ts_event")["price"], 5)
pd_result = pl.from_pandas(pd_result.reset_index()).rename(
{"price": "daily_return_volatility"}
)
assert_frame_equal(pl_result, pd_result)


@pytest.mark.parametrize(
"trade_data",
[
{"n_rows": 10_000, "n_companies": 3},
],
indirect=True,
)
def test__daily_vol__multi_security(trade_data):
pl_result = (
daily_vol(trade_data.lazy(), "ts_event", "price", "symbol", 5)
.collect()
.sort("ts_event", "symbol")
)
pd_result = (
trade_data.to_pandas()
.set_index("ts_event")[["symbol", "price"]]
.groupby("symbol")["price"]
.apply(get_daily_vol, 5)
)
pd_result = (
pl.from_pandas(pd_result.reset_index())
.rename({"price": "daily_return_volatility"})
.sort("ts_event", "symbol")
)
assert_frame_equal(
pl_result, pd_result, check_row_order=False, check_column_order=False
)


@pytest.mark.benchmark(group="daily_vol")
@pytest.mark.parametrize(
"trade_data",
[
{"n_rows": 10_000, "n_companies": 3},
{"n_rows": 100_000, "n_companies": 5},
{"n_rows": 1_000_000, "n_companies": 10},
],
indirect=True,
)
def test__daily_vol__polars_benchmark(benchmark, trade_data):
benchmark(daily_vol(trade_data.lazy(), "ts_event", "price", "symbol", 100).collect)


@pytest.mark.benchmark(group="daily_vol")
@pytest.mark.parametrize(
"trade_data",
[
{"n_rows": 10_000, "n_companies": 3},
{"n_rows": 100_000, "n_companies": 5},
{"n_rows": 1_000_000, "n_companies": 10},
],
indirect=True,
)
def test__daily_vol__pandas_benchmark(benchmark, trade_data):
pd_df = (
trade_data.to_pandas()
.set_index("ts_event")[["symbol", "price"]]
)
def get_daily_vol_pd(pd_df):
return pd_df.groupby("symbol")["price"].apply(get_daily_vol).reset_index()
benchmark(get_daily_vol_pd, pd_df)
Loading

0 comments on commit f2137ea

Please sign in to comment.