diff --git a/.github/workflows/publish_to_pypi.yml b/.github/workflows/publish_to_pypi.yml deleted file mode 100644 index 231a178..0000000 --- a/.github/workflows/publish_to_pypi.yml +++ /dev/null @@ -1,157 +0,0 @@ -name: CI - -on: - push: - branches: - - main - - master - tags: - - '*' - pull_request: - workflow_dispatch: - -concurrency: - - group: ${{ github.workflow }}-${{ github.ref }} - - cancel-in-progress: true - -permissions: - contents: read - -# Make sure CI fails on all warnings, including Clippy lints -env: - RUSTFLAGS: "-Dwarnings" - -jobs: - linux_tests: - runs-on: ubuntu-latest - strategy: - matrix: - target: [x86_64] - python-version: ["3.8", "3.9", "3.10", "3.11"] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - - python-version: ${{ matrix.python-version }} - - - - name: Set up Rust - run: rustup show - - uses: mozilla-actions/sccache-action@v0.0.3 - - run: make .venv - - run: make pre-commit - - run: make install - - run: make test - - linux: - runs-on: ubuntu-latest - strategy: - matrix: - target: [x86_64, x86] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - - target: ${{ matrix.target }} - - args: --release --out dist --find-interpreter - sccache: 'true' - manylinux: auto - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - windows: - runs-on: windows-latest - strategy: - matrix: - target: [x64] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - architecture: ${{ matrix.target }} - - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - - target: ${{ matrix.target }} - - args: --release --out dist --find-interpreter - sccache: 'true' - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - macos: - runs-on: macos-latest - strategy: - 
matrix: - target: [x86_64, aarch64] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - - target: ${{ matrix.target }} - - args: --release --out dist --find-interpreter - sccache: 'true' - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - sdist: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Build sdist - uses: PyO3/maturin-action@v1 - with: - command: sdist - args: --out dist - - name: Upload sdist - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - release: - name: Release - if: "startsWith(github.ref, 'refs/tags/')" - needs: [linux, windows, macos, sdist] - runs-on: ubuntu-latest - environment: pypi - permissions: - id-token: write # IMPORTANT: mandatory for trusted publishing - steps: - - uses: actions/download-artifact@v3 - with: - name: wheels - - name: Publish to PyPI - uses: PyO3/maturin-action@v1 - with: - command: upload - args: --non-interactive --skip-existing * - diff --git a/polars_trading/_testing/data.py b/polars_trading/_testing/data.py index df00a97..b933c6e 100644 --- a/polars_trading/_testing/data.py +++ b/polars_trading/_testing/data.py @@ -1,12 +1,12 @@ from functools import lru_cache import polars as pl -from mimesis import Fieldset +from mimesis import Fieldset, Finance from mimesis.locales import Locale @lru_cache -def generate_trade_data(n_rows: int) -> pl.DataFrame: +def generate_trade_data(n_rows: int, n_companies: int = 3) -> pl.DataFrame: fs = Fieldset(locale=Locale.EN, i=n_rows) return pl.DataFrame( @@ -16,6 +16,9 @@ def generate_trade_data(n_rows: int) -> pl.DataFrame: ), "price": fs("finance.price", minimum=1, maximum=100), "size": fs("numeric.integer_number", start=10_000, end=100_000), - "symbol": fs("choice.choice", items=["AAPL", "GOOGL", "MSFT"]), + "symbol": fs( + "choice.choice", + items=[Finance().stock_ticker() 
for _ in range(n_companies)], + ), } ) diff --git a/polars_trading/_testing/labels.py b/polars_trading/_testing/labels.py new file mode 100644 index 0000000..1ff20f3 --- /dev/null +++ b/polars_trading/_testing/labels.py @@ -0,0 +1,14 @@ +import pandas as pd + + +def get_daily_vol(close: pd.Series, span0: int = 100) -> pd.Series: + # This function calculates returns as close to 24 hours ago as possible and then + # calculates the exponentially weighted moving standard deviation of those returns. + close = close.sort_index().copy() + df0 = close.index.searchsorted(close.index - pd.Timedelta(days=1)) + df0 = df0[df0 > 0] + df0 = pd.Series( + close.index[df0], index=close.index[close.shape[0] - df0.shape[0] :] + ) + df0 = close.loc[df0.index] / close.loc[df0.values].values - 1 + return df0.ewm(span=span0).std().fillna(0.0) diff --git a/polars_trading/labels/dynamic_labels.py b/polars_trading/labels/dynamic_labels.py index e69de29..7ba5e84 100644 --- a/polars_trading/labels/dynamic_labels.py +++ b/polars_trading/labels/dynamic_labels.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from datetime import timedelta +from typing import TYPE_CHECKING + +import polars as pl + +if TYPE_CHECKING: + from polars_trading.typing import FrameType + + +def daily_vol( + df: FrameType, + timestamp_col: str, + price_col: str, + symbol_col: str | None = None, + span: int = 100, +) -> FrameType: + """This function calculates the daily volatility of a price series. + + It estimates the daily volatility by looking back at the return from the oldest + price in the last 24-hour period to the current price. It then calculates the + exponentially weighted standard deviation of the returns. + + Args: + ---- + df (DataFrame): The DataFrame containing the price series. + timestamp_col (str): The column name containing the timestamps. + price_col (str): The column name containing the prices. + symbol_col (str | None): The column name containing the symbols. 
If None, it is + assumed that the prices are for a single symbol. Defaults to None. + span (int): The span of the exponential weighted standard deviation. Defaults to + 100. + + Returns: + FrameType: The DataFrame with the daily volatility. + """ + returns = ( + df.sort(timestamp_col) + .rolling(timestamp_col, period="24h", group_by=symbol_col) + .agg(pl.last(price_col).truediv(pl.first(price_col)).sub(1).alias("return")) + ) + returns = returns.filter( + (pl.col(timestamp_col) - timedelta(hours=24)) + > pl.col(timestamp_col).min().over(symbol_col) + ) + vol_expr = ( + pl.col("return") + .ewm_std(span=span) + .over(symbol_col) + .alias("daily_return_volatility") + ) + return_cols = ( + [ + timestamp_col, + symbol_col, + vol_expr, + ] + if symbol_col + else [ + timestamp_col, + vol_expr, + ] + ) + + return returns.select(*return_cols) diff --git a/polars_trading/labels/labels.py b/polars_trading/labels/labels.py index 87eb4df..44bd393 100644 --- a/polars_trading/labels/labels.py +++ b/polars_trading/labels/labels.py @@ -94,9 +94,9 @@ def fixed_time_return( pl.Expr: The fixed time return as an expression. 
""" return_expr = ( - pl.col(prices) + parse_into_expr(prices) .shift(-offset - window) - .truediv(pl.col(prices).shift(-offset)) + .truediv(parse_into_expr(prices).shift(-offset)) .sub(1) ) if symbol is not None: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..684f7bc --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,7 @@ +import pytest +from polars_trading._testing.data import generate_trade_data + + +@pytest.fixture +def trade_data(request): + return generate_trade_data(**request.param) diff --git a/tests/labels/test_dynamic_labels.py b/tests/labels/test_dynamic_labels.py new file mode 100644 index 0000000..e7ca8b1 --- /dev/null +++ b/tests/labels/test_dynamic_labels.py @@ -0,0 +1,84 @@ +from polars_trading.labels.dynamic_labels import daily_vol +import polars as pl +from polars_trading._testing.labels import get_daily_vol +import pytest +from polars.testing import assert_frame_equal + + +@pytest.mark.parametrize( + "trade_data", + [ + {"n_rows": 10_000, "n_companies": 1}, + ], + indirect=True, +) +def test__daily_vol__single_security(trade_data): + pl_result = daily_vol(trade_data.lazy(), "ts_event", "price", None, 5).collect() + pd_result = get_daily_vol(trade_data.to_pandas().set_index("ts_event")["price"], 5) + pd_result = pl.from_pandas(pd_result.reset_index()).rename( + {"price": "daily_return_volatility"} + ) + assert_frame_equal(pl_result, pd_result) + + +@pytest.mark.parametrize( + "trade_data", + [ + {"n_rows": 10_000, "n_companies": 3}, + ], + indirect=True, +) +def test__daily_vol__multi_security(trade_data): + pl_result = ( + daily_vol(trade_data.lazy(), "ts_event", "price", "symbol", 5) + .collect() + .sort("ts_event", "symbol") + ) + pd_result = ( + trade_data.to_pandas() + .set_index("ts_event")[["symbol", "price"]] + .groupby("symbol")["price"] + .apply(get_daily_vol, 5) + ) + pd_result = ( + pl.from_pandas(pd_result.reset_index()) + .rename({"price": "daily_return_volatility"}) + .sort("ts_event", "symbol") 
+ ) + assert_frame_equal( + pl_result, pd_result, check_row_order=False, check_column_order=False + ) + + +@pytest.mark.benchmark(group="daily_vol") +@pytest.mark.parametrize( + "trade_data", + [ + {"n_rows": 10_000, "n_companies": 3}, + {"n_rows": 100_000, "n_companies": 5}, + {"n_rows": 1_000_000, "n_companies": 10}, + ], + indirect=True, +) +def test__daily_vol__polars_benchmark(benchmark, trade_data): + benchmark(daily_vol(trade_data.lazy(), "ts_event", "price", "symbol", 100).collect) + + +@pytest.mark.benchmark(group="daily_vol") +@pytest.mark.parametrize( + "trade_data", + [ + {"n_rows": 10_000, "n_companies": 3}, + {"n_rows": 100_000, "n_companies": 5}, + {"n_rows": 1_000_000, "n_companies": 10}, + ], + indirect=True, +) +def test__daily_vol__pandas_benchmark(benchmark, trade_data): + pd_df = ( + trade_data.to_pandas() + .set_index("ts_event")[["symbol", "price"]] + ) + def get_daily_vol_pd(pd_df): + return pd_df.groupby("symbol")["price"].apply(get_daily_vol).reset_index() + benchmark(get_daily_vol_pd, pd_df) diff --git a/tests/test_bars.py b/tests/test_bars.py index 3513186..12e6e6b 100644 --- a/tests/test_bars.py +++ b/tests/test_bars.py @@ -8,11 +8,6 @@ from polars_trading._testing.data import generate_trade_data -@pytest.fixture -def trade_data(request): - return generate_trade_data(request.param) - - def pandas_time_bars(df: pd.DataFrame, period: str) -> pd.DataFrame: df.index = pd.to_datetime(df["ts_event"]) df["pvt"] = df["price"] * df["size"] @@ -65,7 +60,9 @@ def pandas_tick_bars(df: pd.DataFrame, n_ticks: int) -> pd.DataFrame: return resampled_df -@pytest.mark.parametrize("trade_data", [10_000], indirect=True) +@pytest.mark.parametrize( + "trade_data", [{"n_rows": 10_000, "n_companies": 3}], indirect=True +) def test__time_bars__matches_pandas(trade_data): pd_df = pandas_time_bars(trade_data.to_pandas(), "1d") pd_df.index = pd_df.index.to_timestamp() @@ -81,19 +78,37 @@ def test__time_bars__matches_pandas(trade_data): 
@pytest.mark.benchmark(group="time_bars") -@pytest.mark.parametrize("trade_data", [100, 10_000, 1_000_000], indirect=True) +@pytest.mark.parametrize( + "trade_data", + [ + {"n_rows": 1000, "n_companies": 3}, + {"n_rows": 10_000, "n_companies": 3}, + {"n_rows": 1_000_000, "n_companies": 20}, + ], + indirect=True, +) def test__time_bars__polars_benchmark(benchmark, trade_data): benchmark(time_bars, trade_data, timestamp_col="ts_event", bar_size="1d") @pytest.mark.benchmark(group="time_bars") -@pytest.mark.parametrize("trade_data", [100, 10_000, 1_000_000], indirect=True) +@pytest.mark.parametrize( + "trade_data", + [ + {"n_rows": 1000, "n_companies": 3}, + {"n_rows": 10_000, "n_companies": 3}, + {"n_rows": 1_000_000, "n_companies": 10}, + ], + indirect=True, +) def test__time_bars__pandas_benchmark(benchmark, trade_data): trade_data = trade_data.to_pandas() benchmark(pandas_time_bars, trade_data, "1d") -@pytest.mark.parametrize("trade_data", [10000], indirect=True) +@pytest.mark.parametrize( + "trade_data", [{"n_rows": 10_000, "n_companies": 3}], indirect=True +) def test__tick_bars__matches_pandas(trade_data): pd_df = pandas_tick_bars(trade_data.to_pandas(), 100) res = tick_bars(trade_data, timestamp_col="ts_event", bar_size=100) @@ -107,25 +122,57 @@ def test__tick_bars__matches_pandas(trade_data): @pytest.mark.benchmark(group="tick_bars") -@pytest.mark.parametrize("trade_data", [100, 10_000, 1_000_000], indirect=True) +@pytest.mark.parametrize( + "trade_data", + [ + {"n_rows": 1000, "n_companies": 3}, + {"n_rows": 10_000, "n_companies": 3}, + {"n_rows": 1_000_000, "n_companies": 10}, + ], + indirect=True, +) def test__tick_bars__polars_benchmark(benchmark, trade_data): benchmark(tick_bars, trade_data, timestamp_col="ts_event", bar_size=100) @pytest.mark.benchmark(group="tick_bars") -@pytest.mark.parametrize("trade_data", [100, 10_000, 1_000_000], indirect=True) +@pytest.mark.parametrize( + "trade_data", + [ + {"n_rows": 1000, "n_companies": 3}, + {"n_rows": 
10_000, "n_companies": 3}, + {"n_rows": 1_000_000, "n_companies": 10}, + ], + indirect=True, +) def test__tick_bars__pandas_benchmark(benchmark, trade_data): trade_data = trade_data.to_pandas() benchmark(pandas_tick_bars, trade_data, 100) @pytest.mark.benchmark(group="volume_bars") -@pytest.mark.parametrize("trade_data", [100, 10_000, 1_000_000], indirect=True) +@pytest.mark.parametrize( + "trade_data", + [ + {"n_rows": 1000, "n_companies": 3}, + {"n_rows": 10_000, "n_companies": 3}, + {"n_rows": 1_000_000, "n_companies": 10}, + ], + indirect=True, +) def test__volume_bars__polars_benchmark(benchmark, trade_data): benchmark(volume_bars, trade_data, timestamp_col="ts_event", bar_size=10_000) @pytest.mark.benchmark(group="dollar_bars") -@pytest.mark.parametrize("trade_data", [100, 10_000, 1_000_000], indirect=True) +@pytest.mark.parametrize( + "trade_data", + [ + {"n_rows": 1000, "n_companies": 3}, + {"n_rows": 10_000, "n_companies": 3}, + {"n_rows": 1_000_000, "n_companies": 10}, + ], + indirect=True, +) def test__dollar_bars__polars_benchmark(benchmark, trade_data): benchmark(dollar_bars, trade_data, timestamp_col="ts_event", bar_size=100_000)