From 425686ae3a5510b7669800f2ce9acdbbefef65a6 Mon Sep 17 00:00:00 2001 From: Nelson Griffiths Date: Wed, 3 Apr 2024 15:52:44 -0600 Subject: [PATCH] initial commit ; --- .DS_Store | Bin 0 -> 8196 bytes .github/workflows/publish_to_pypi.yml | 157 ++ .gitignore | 178 +++ Cargo.lock | 1917 +++++++++++++++++++++++++ Cargo.toml | 17 + Makefile | 29 + README.md | 1 + exercises/Untitled.ipynb | 252 ++++ exercises/test.ipynb | 184 +++ polars_finance/__init__.py | 22 + polars_finance/bars.py | 272 ++++ polars_finance/bet_size.py | 0 polars_finance/cross_validation.py | 8 + polars_finance/feature_importance.py | 17 + polars_finance/frac_diff.py | 0 polars_finance/hyperparams.py | 11 + polars_finance/labels.py | 120 ++ polars_finance/plots.py | 0 polars_finance/sampling_features.py | 19 + polars_finance/ta.py | 11 + polars_finance/utils.py | 106 ++ pyproject.toml | 66 + requirements.txt | 5 + src/bars.rs | 385 +++++ src/labels.rs | 224 +++ src/lib.rs | 11 + src/nbbo.rs | 31 + src/symmetric_cusum_filter.rs | 56 + 28 files changed, 4099 insertions(+) create mode 100644 .DS_Store create mode 100644 .github/workflows/publish_to_pypi.yml create mode 100644 .gitignore create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 Makefile create mode 100644 README.md create mode 100644 exercises/Untitled.ipynb create mode 100644 exercises/test.ipynb create mode 100644 polars_finance/__init__.py create mode 100644 polars_finance/bars.py create mode 100644 polars_finance/bet_size.py create mode 100644 polars_finance/cross_validation.py create mode 100644 polars_finance/feature_importance.py create mode 100644 polars_finance/frac_diff.py create mode 100644 polars_finance/hyperparams.py create mode 100644 polars_finance/labels.py create mode 100644 polars_finance/plots.py create mode 100644 polars_finance/sampling_features.py create mode 100644 polars_finance/ta.py create mode 100644 polars_finance/utils.py create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 src/bars.rs create mode 100644 src/labels.rs create mode 100644 src/lib.rs create mode 100644 src/nbbo.rs create mode 100644 src/symmetric_cusum_filter.rs diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..97abdc11c0cbb4853414a376448183d3028a9780 GIT binary patch literal 8196 zcmeHMO^?!06uoaiC_|iJ)P#*m6SosY5=o33WMJXK&0ut)rnVFrq_jwj%!H7z*8kvN zaO*GezqrzK-z%go;L=Po?@jKzukAVS>lRJvDwV%mFT9}dc!2{1-&D_GwLh|; zMN}XMe*x8)nZf?ZV2|kxUj}A0lz0zypT=p9c<@h71kQxK4Ch;hlP4GBL}X`hzYpP@ zML3Pu_|utsqcE!1e~F`9{#K!A6pga+zIPEUdXXK?!?rzp%SW$*z;iZj+j%`2EUe1C zv%rt+kw2J8!pI!}^6t&ZcY{ScnEP%hxjl8kC>bTIQe7?^&HAa?Xf#)+=JII+?I+Fp zYE?4sKYaA;+!}imKlqBdqloyyN4M4Swq^~s>;BT4`Ci};nE)$p^w_scM~q7D=rPa^ zBEccvq2WG`jDqipC@A4u&n0Yz9UW1KR+!eFI2Mu&gI;2f#z1}q!=phOFXJGIKS6v2 zz;XP9m>six9QhbA%VjOH+$->B=;0Ga>?1a6uzj(_F_Pq#=$yJ(ZwK@3f;AyGJyWS) zl2O4p1NuyJ%#d3Xn2*syIp!rrX+CQyP7+!cz?y?`;EqaRbLbVCr+Z}WQm6vo3&0IA zlg#J>TBWdefW_{iwo*~Ud|uOT4^S-C0|ex#)9-&4LKm&THCG^~=2tDlcX8jc*F`I! z6}VXnh\n", + "shape: (200, 3)
symbolts_eventprice
strdatef64
"AAPL"2021-01-011.0
"AAPL"2021-01-022.0
"AAPL"2021-01-033.0
"AAPL"2021-01-044.0
"AAPL"2021-01-055.0
"AAPL"2021-07-15196.0
"AAPL"2021-07-16197.0
"AAPL"2021-07-17198.0
"AAPL"2021-07-18199.0
"AAPL"2021-07-19200.0
" + ], + "text/plain": [ + "shape: (200, 3)\n", + "┌────────┬────────────┬───────┐\n", + "│ symbol ┆ ts_event ┆ price │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ str ┆ date ┆ f64 │\n", + "╞════════╪════════════╪═══════╡\n", + "│ AAPL ┆ 2021-01-01 ┆ 1.0 │\n", + "│ AAPL ┆ 2021-01-02 ┆ 2.0 │\n", + "│ AAPL ┆ 2021-01-03 ┆ 3.0 │\n", + "│ AAPL ┆ 2021-01-04 ┆ 4.0 │\n", + "│ AAPL ┆ 2021-01-05 ┆ 5.0 │\n", + "│ … ┆ … ┆ … │\n", + "│ AAPL ┆ 2021-07-15 ┆ 196.0 │\n", + "│ AAPL ┆ 2021-07-16 ┆ 197.0 │\n", + "│ AAPL ┆ 2021-07-17 ┆ 198.0 │\n", + "│ AAPL ┆ 2021-07-18 ┆ 199.0 │\n", + "│ AAPL ┆ 2021-07-19 ┆ 200.0 │\n", + "└────────┴────────────┴───────┘" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_df" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "48763c7b-74a6-49e1-ba95-0494b2286186", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (200, 6)
symbolts_eventpriceretlabel1label2
strdatef64f64i32i32
"AAPL"2021-01-011.01.011
"AAPL"2021-01-022.00.511
"AAPL"2021-01-033.00.33333311
"AAPL"2021-01-044.00.2511
"AAPL"2021-01-055.00.211
"AAPL"2021-07-15196.00.00510200
"AAPL"2021-07-16197.00.00507600
"AAPL"2021-07-17198.00.00505100
"AAPL"2021-07-18199.00.00502500
"AAPL"2021-07-19200.0null00
" + ], + "text/plain": [ + "shape: (200, 6)\n", + "┌────────┬────────────┬───────┬──────────┬────────┬────────┐\n", + "│ symbol ┆ ts_event ┆ price ┆ ret ┆ label1 ┆ label2 │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ str ┆ date ┆ f64 ┆ f64 ┆ i32 ┆ i32 │\n", + "╞════════╪════════════╪═══════╪══════════╪════════╪════════╡\n", + "│ AAPL ┆ 2021-01-01 ┆ 1.0 ┆ 1.0 ┆ 1 ┆ 1 │\n", + "│ AAPL ┆ 2021-01-02 ┆ 2.0 ┆ 0.5 ┆ 1 ┆ 1 │\n", + "│ AAPL ┆ 2021-01-03 ┆ 3.0 ┆ 0.333333 ┆ 1 ┆ 1 │\n", + "│ AAPL ┆ 2021-01-04 ┆ 4.0 ┆ 0.25 ┆ 1 ┆ 1 │\n", + "│ AAPL ┆ 2021-01-05 ┆ 5.0 ┆ 0.2 ┆ 1 ┆ 1 │\n", + "│ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n", + "│ AAPL ┆ 2021-07-15 ┆ 196.0 ┆ 0.005102 ┆ 0 ┆ 0 │\n", + "│ AAPL ┆ 2021-07-16 ┆ 197.0 ┆ 0.005076 ┆ 0 ┆ 0 │\n", + "│ AAPL ┆ 2021-07-17 ┆ 198.0 ┆ 0.005051 ┆ 0 ┆ 0 │\n", + "│ AAPL ┆ 2021-07-18 ┆ 199.0 ┆ 0.005025 ┆ 0 ┆ 0 │\n", + "│ AAPL ┆ 2021-07-19 ┆ 200.0 ┆ null ┆ 0 ┆ 0 │\n", + "└────────┴────────────┴───────┴──────────┴────────┴────────┘" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_df.with_columns(\n", + " raw_forward_returns(pl.col(\"price\")).alias(\"ret\"),\n", + " fixed_time_label(pl.col(\"price\"), t=1).alias(\"label1\"),\n", + " fixed_time_label(pl.col(\"price\"), upper_threshold=.2, t=2).alias(\"label2\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "1f2b0729-cba4-4ab0-95a1-ac74fd3dc44f", + "metadata": {}, + "outputs": [], + "source": [ + "import databento as db\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "id": "03d0cc8a-294e-4c37-92b5-f98390984d4b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'xnas-itch-20231018'" + ] + }, + "execution_count": 67, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(Path(\"../data/XNAS-20240403-QMLQV3MJHY/\").glob(\"*.zst\"))[0].name.split(\".\")[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "id": "a620723f-14da-4be5-af32-cbbc327059f4", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a74b3257ddae48fd9243cdf8ec8bfdb2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/250 [00:00\n", + "shape: (5, 2)
datec
datetime[ms]i64
2020-08-02 00:01:001
2020-08-03 00:01:002
2020-08-04 00:01:003
2020-08-05 00:01:004
2020-08-06 00:01:005
" + ], + "text/plain": [ + "shape: (5, 2)\n", + "┌─────────────────────┬─────┐\n", + "│ date ┆ c │\n", + "│ --- ┆ --- │\n", + "│ datetime[ms] ┆ i64 │\n", + "╞═════════════════════╪═════╡\n", + "│ 2020-08-02 00:01:00 ┆ 1 │\n", + "│ 2020-08-03 00:01:00 ┆ 2 │\n", + "│ 2020-08-04 00:01:00 ┆ 3 │\n", + "│ 2020-08-05 00:01:00 ┆ 4 │\n", + "│ 2020-08-06 00:01:00 ┆ 5 │\n", + "└─────────────────────┴─────┘" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_df.select(\n", + " (pl.col(\"date\") + timedelta(days=1, minutes=1)).alias(\"date\"), pl.int_range(1, pl.len() + 1).alias(\"c\")\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 4)
datecvaluecount
datetime[ms]i64i64i64
2020-08-02 00:01:00122
2020-08-03 00:01:00233
2020-08-04 00:01:00344
2020-08-05 00:01:00455
2020-08-06 00:01:00555
" + ], + "text/plain": [ + "shape: (5, 4)\n", + "┌─────────────────────┬─────┬───────┬───────┐\n", + "│ date ┆ c ┆ value ┆ count │\n", + "│ --- ┆ --- ┆ --- ┆ --- │\n", + "│ datetime[ms] ┆ i64 ┆ i64 ┆ i64 │\n", + "╞═════════════════════╪═════╪═══════╪═══════╡\n", + "│ 2020-08-02 00:01:00 ┆ 1 ┆ 2 ┆ 2 │\n", + "│ 2020-08-03 00:01:00 ┆ 2 ┆ 3 ┆ 3 │\n", + "│ 2020-08-04 00:01:00 ┆ 3 ┆ 4 ┆ 4 │\n", + "│ 2020-08-05 00:01:00 ┆ 4 ┆ 5 ┆ 5 │\n", + "│ 2020-08-06 00:01:00 ┆ 5 ┆ 5 ┆ 5 │\n", + "└─────────────────────┴─────┴───────┴───────┘" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_df.select(\n", + " (pl.col(\"date\") + timedelta(days=1, minutes=1)).alias(\"date\"), pl.int_range(1, pl.len() + 1).alias(\"c\")\n", + " ).set_sorted(\"date\").join_asof(\n", + " test_df.set_sorted(\"date\"),\n", + " on=\"date\",\n", + " strategy=\"backward\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "ename": "ComputeError", + "evalue": "n must be a single value.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mComputeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[30], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtest_df\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwith_columns\u001b[49m\u001b[43m(\u001b[49m\u001b[43mshift\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mSeries\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwith_columns\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcol\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdate\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshift\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcol\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mshift\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43malias\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdate_2\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/projects/polars_finance/.venv/lib/python3.11/site-packages/polars/dataframe/frame.py:8366\u001b[0m, in \u001b[0;36mDataFrame.with_columns\u001b[0;34m(self, *exprs, **named_exprs)\u001b[0m\n\u001b[1;32m 8220\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwith_columns\u001b[39m(\n\u001b[1;32m 8221\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 8222\u001b[0m \u001b[38;5;241m*\u001b[39mexprs: IntoExpr \u001b[38;5;241m|\u001b[39m Iterable[IntoExpr],\n\u001b[1;32m 8223\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mnamed_exprs: IntoExpr,\n\u001b[1;32m 8224\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m DataFrame:\n\u001b[1;32m 8225\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 8226\u001b[0m \u001b[38;5;124;03m Add columns to this DataFrame.\u001b[39;00m\n\u001b[1;32m 8227\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 8364\u001b[0m \u001b[38;5;124;03m └─────┴──────┴─────────────┘\u001b[39;00m\n\u001b[1;32m 8365\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 8366\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlazy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwith_columns\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mexprs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnamed_exprs\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollect\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_eager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/projects/polars_finance/.venv/lib/python3.11/site-packages/polars/lazyframe/frame.py:1943\u001b[0m, in \u001b[0;36mLazyFrame.collect\u001b[0;34m(self, type_coercion, predicate_pushdown, projection_pushdown, simplify_expression, slice_pushdown, comm_subplan_elim, comm_subexpr_elim, no_optimization, streaming, background, _eager)\u001b[0m\n\u001b[1;32m 1940\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m background:\n\u001b[1;32m 1941\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m InProcessQuery(ldf\u001b[38;5;241m.\u001b[39mcollect_concurrently())\n\u001b[0;32m-> 1943\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m wrap_df(ldf\u001b[38;5;241m.\u001b[39mcollect())\n", + "\u001b[0;31mComputeError\u001b[0m: n must be a single value." + ] + } + ], + "source": [ + "test_df.with_columns(shift=pl.Series(\"t\", [1, 2, 0, 1, 2])).with_columns(pl.col(\"date\").shift(pl.col(\"shift\")).alias(\"date_2\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/polars_finance/__init__.py b/polars_finance/__init__.py new file mode 100644 index 0000000..1da4d80 --- /dev/null +++ b/polars_finance/__init__.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import polars as pl +from polars.utils.udfs import _get_shared_lib_location + +from polars_finance.utils import parse_into_expr + +if TYPE_CHECKING: + from polars.type_aliases import IntoExpr + +lib = _get_shared_lib_location(__file__) + + +def pig_latinnify(expr: IntoExpr) -> pl.Expr: + expr = parse_into_expr(expr) + return expr.register_plugin( + lib=lib, + symbol="pig_latinnify", + is_elementwise=True, + ) diff --git a/polars_finance/bars.py b/polars_finance/bars.py new file mode 100644 index 0000000..8763414 --- /dev/null +++ b/polars_finance/bars.py @@ -0,0 +1,272 @@ +import polars as pl +from polars.type_aliases import IntoExpr, FrameType +from datetime import timedelta +from typing import Literal +from polars.plugins import register_plugin_function +from polars.utils.udfs import _get_shared_lib_location + +from polars_finance.utils import parse_into_expr + + +lib = _get_shared_lib_location(__file__) + + +def _dynamic_tick_bar_groups(thresholds: IntoExpr) -> pl.Expr: + expr = parse_into_expr(thresholds) + return expr.cast(pl.UInt16).register_plugin( + lib=lib, + symbol="dynamic_tick_bar_groups", + is_elementwise=False, + cast_to_supertypes=True, + ) + + +def _ohlcv_expr(timestamp_col: str, price_col: str, size_col: str) -> list[pl.Expr]: + return [ + pl.first(timestamp_col).name.prefix("begin_"), + pl.last(timestamp_col).name.prefix("end_"), + pl.first(price_col).alias("open"), + pl.max(price_col).alias("high"), + pl.min(price_col).alias("low"), + pl.last(price_col).alias("close"), + ((pl.col(size_col) * pl.col(price_col)).sum() / pl.col(size_col).sum()).alias( + "vwap" + ), + pl.sum(size_col).alias("volume"), + pl.len().alias("n_trades"), + ] + + +def standard_bars( + df: FrameType, + timestamp_col: str = "ts_event", + price_col: str = "price", + size_col: str = "size", + symbol_col: str = "symbol", + bar_size: str = "1m", +): + """ + This function generates standard bars for a given DataFrame. + + + Args: + df (FrameType): The DataFrame/LazyFrame to generate standard bars for. + timestamp_col (str): The name of the timestamp column in the DataFrame. + price_col (str, optional): The name of the price column in the DataFrame. Defaults to "price". + size_col (str, optional): The name of the size column in the DataFrame. Defaults to "size". + symbol_col (str, optional): The name of the symbol column in the DataFrame. Defaults to "symbol". + bar_size (str, optional): The size of the bars to generate. + Can use any number followed by a time symbol. For example: + + 1s = 1 second + 2m = 2 minutes + 3h = 3 hours + 4d = 4 days + + Defaults to "1m". + """ + + ohlcv = ( + df.drop_nulls(subset=price_col) + .sort(timestamp_col) + .with_columns(pl.col(timestamp_col).dt.truncate(bar_size)) + .group_by(timestamp_col, symbol_col) + .agg( + pl.first(price_col).alias("open"), + pl.max(price_col).alias("high"), + pl.min(price_col).alias("low"), + pl.last(price_col).alias("close"), + ( + (pl.col(size_col) * pl.col(price_col)).sum() / pl.col(size_col).sum() + ).alias("vwap"), + pl.sum(size_col).alias("volume"), + pl.len().alias("n_trades"), + ) + .sort(timestamp_col) + ) + return ohlcv + + +def tick_bars( + df: FrameType, + timestamp_col: str = "ts_event", + price_col: str = "price", + size_col: str = "size", + symbol_col: str = "symbol", + bar_size: int | pl.Expr = 100, +): + """ + This function generates tick bars for a given DataFrame. + + The function takes a DataFrame, a timestamp column, a price column, a size column, a symbol column, and a bar size as input. + The bar size is the number of ticks that will be aggregated into a single bar. + + Args: + df (FrameType): The DataFrame/LazyFrame to generate tick bars for. + timestamp_col (str): The name of the timestamp column in the DataFrame. + price_col (str): The name of the price column in the DataFrame. + size_col (str): The name of the size column in the DataFrame. + symbol_col (str): The name of the symbol column in the DataFrame. + bar_size (int): The number of ticks to aggregate into a single bar. + """ + ohlcv = ( + df.drop_nulls(subset=price_col) + .sort(timestamp_col) + .with_columns(pl.col(timestamp_col).dt.date().alias("__date")) + ) + if isinstance(bar_size, int): + ohlcv = ohlcv.with_columns( + ( + ((pl.col(symbol_col).cum_count()).over(symbol_col, "__date") - 1) + // bar_size + ).alias( + "__tick_group", + ) + ) + elif isinstance(bar_size, pl.Expr): + ohlcv = ohlcv.with_columns( + _dynamic_tick_bar_groups(bar_size) + .over(symbol_col, "__date") + .alias("__tick_group") + ) + ohlcv = ( + ohlcv.group_by("__tick_group", symbol_col, "__date") + .agg(_ohlcv_expr(timestamp_col, price_col, size_col)) + .drop("__tick_group", "__date") + .sort(f"end_{timestamp_col}") + ) + return ohlcv + + +def volume_bars( + df: FrameType, + timestamp_col: str = "ts_event", + price_col: str = "price", + size_col: str = "size", + symbol_col: str = "symbol", + bar_size: int | float | pl.Expr = 1_000_000, +): + """ + This function generates volume bars for a given DataFrame. + + The function takes a DataFrame, a timestamp column, a price column, a size column, + a symbol column, and a bar size as input. + The bar size is the total volume that will be aggregated into a single bar. + + Args: + df (FrameType): The DataFrame/LazyFrame to generate volume bars for. + timestamp_col (str): The name of the timestamp column in the DataFrame. + price_col (str): The name of the price column in the DataFrame. + size_col (str): The name of the size column in the DataFrame. + symbol_col (str): The name of the symbol column in the DataFrame. + bar_size (int | float | dict[str, int | float] | pl.Expr): The total volume to + aggregate into a single bar. + + Returns: + FrameType: A DataFrame with volume bars. + """ + df = df.sort(timestamp_col) + if isinstance(bar_size, int | float): + df = df.with_columns( + (pl.lit(bar_size).cast(pl.UInt32).alias("__PFIN_bar_size")) + ) + elif isinstance(bar_size, pl.Expr): + df = df.with_columns(bar_size.cast(pl.UInt32).alias("__PFIN_bar_size")) + else: + raise TypeError("bar_size must be an int, float, dict, or pl.Expr") + + return ( + df.group_by(symbol_col, pl.col(timestamp_col).dt.date()) + .agg( + pl.col(timestamp_col).register_plugin( + lib=lib, + symbol="volume_bars", + is_elementwise=False, + cast_to_supertypes=False, + args=[pl.col(price_col), pl.col(size_col), pl.col("__PFIN_bar_size")], + changes_length=True, + ) + ) + .explode("ohlcv") + .unnest("ohlcv") + ) + + +def dollar_bars( + df: FrameType, + timestamp_col: str = "ts_event", + price_col: str = "price", + size_col: str = "size", + symbol_col: str = "symbol", + bar_size: int | float | pl.Expr = 1_000_000, +): + """ + This function generates dollar bars for a given DataFrame. + + The function takes a DataFrame, a timestamp column, a price column, a size column, + a symbol column, and a bar size as input. + The bar size is the total dollar amount that will be aggregated into a single bar. + + Args: + df (FrameType): The DataFrame/LazyFrame to generate dollar bars for. + timestamp_col (str): The name of the timestamp column in the DataFrame. + price_col (str): The name of the price column in the DataFrame. + size_col (str): The name of the size column in the DataFrame. + symbol_col (str): The name of the symbol column in the DataFrame. + bar_size (int | float | pl.Expr): The total dollar amount to aggregate into a single bar. + + Returns: + FrameType: A DataFrame with dollar bars. + """ + df = df.sort(timestamp_col) + if isinstance(bar_size, int | float): + df = df.with_columns( + (pl.lit(bar_size).cast(pl.Float64).alias("__PFIN_bar_size")) + ) + elif isinstance(bar_size, pl.Expr): + df = df.with_columns(bar_size.cast(pl.Float64).alias("__PFIN_bar_size")) + else: + raise TypeError("bar_size must be an int, float, or pl.Expr") + + return ( + df.group_by(symbol_col, pl.col(timestamp_col).dt.date()) + .agg( + pl.col(timestamp_col).register_plugin( + lib=lib, + symbol="dollar_bars", + is_elementwise=False, + cast_to_supertypes=False, + args=[pl.col(price_col), pl.col(size_col), pl.col("__PFIN_bar_size")], + changes_length=True, + ) + ) + .explode("ohlcv") + .unnest("ohlcv") + ) + + +def tick_imbalance_bars(): + raise NotImplementedError("This function has not been implemented yet.") + + +def volume_imbalance_bars(): + raise NotImplementedError("This function has not been implemented yet.") + + +def dollar_imbalance_bars(): + raise NotImplementedError("This function has not been implemented yet.") + + +def tick_runs_bars(): + raise NotImplementedError("This function has not been implemented yet.") + + +def volume_runs_bars(): + raise NotImplementedError("This function has not been implemented yet.") + + +def dollar_runs_bars(): + raise NotImplementedError("This function has not been implemented yet.") + + +# ETF Trick? diff --git a/polars_finance/bet_size.py b/polars_finance/bet_size.py new file mode 100644 index 0000000..e69de29 diff --git a/polars_finance/cross_validation.py b/polars_finance/cross_validation.py new file mode 100644 index 0000000..fd0ac69 --- /dev/null +++ b/polars_finance/cross_validation.py @@ -0,0 +1,8 @@ +import polars as pl +from polars.type_aliases import IntoExpr + + +def purged_k_fold_cv(): + ... + + # OTher time series splits? diff --git a/polars_finance/feature_importance.py b/polars_finance/feature_importance.py new file mode 100644 index 0000000..47587f9 --- /dev/null +++ b/polars_finance/feature_importance.py @@ -0,0 +1,17 @@ +import polars as pl +from polars.type_aliases import IntoExpr + + +def mdi_feature_importance(): ... + + +def mda_feature_importance(): ... + + +def sfi_feature_importance(): ... + + +def orthogonal_features(): ... + + +def weighted_kendall_tau(): ... diff --git a/polars_finance/frac_diff.py b/polars_finance/frac_diff.py new file mode 100644 index 0000000..e69de29 diff --git a/polars_finance/hyperparams.py b/polars_finance/hyperparams.py new file mode 100644 index 0000000..1524762 --- /dev/null +++ b/polars_finance/hyperparams.py @@ -0,0 +1,11 @@ +import polars as pl +from polars.type_aliases import IntoExpr + + +def purged_k_fold_grid_search(): ... + + +def purged_k_fold_random_search(): ... + + +def log_uniform_gen(): ... diff --git a/polars_finance/labels.py b/polars_finance/labels.py new file mode 100644 index 0000000..05db0ad --- /dev/null +++ b/polars_finance/labels.py @@ -0,0 +1,120 @@ +import polars as pl +from polars.type_aliases import IntoExpr, FrameType +from polars_finance.utils import parse_into_expr +from polars.utils.udfs import _get_shared_lib_location + + +lib = _get_shared_lib_location(__file__) + + +def raw_forward_returns(prices: IntoExpr, n_bars: int = 1): + price_expr = parse_into_expr(prices) + return price_expr.shift(-n_bars) / price_expr - 1 + + +def fixed_time_label( + price_series: IntoExpr, + upper_threshold: float = 0.01, + lower_threshold: float = -0.1, + t: int = 1, + symbol_col: str = "symbol", +): + return_expr = parse_into_expr(price_series) + return_expr = return_expr.shift(-t).over(symbol_col) / return_expr - 1 + return ( + pl.when(return_expr > upper_threshold) + .then(1) + .when(return_expr < lower_threshold) + .then(-1) + .otherwise(0) + ) + + +def fixed_time_dynamic_threshold_label( + price_series: IntoExpr, + span: int = 100, + upper_multiplier: float = 1.0, + lower_multiplier: float = 1.0, + t: int = 1, + symbol_col: str = "symbol", +): + price_expr = parse_into_expr(price_series) + return_expr = price_expr / price_expr.shift(-t).over(symbol_col) - 1 + rolling_std = ( + (price_expr / price_expr.shift(t) - 1).ewm_std(span=span).over(symbol_col) + ) + return ( + pl.when(return_expr > rolling_std * upper_multiplier) + .then(1) + .when(return_expr < rolling_std * -lower_multiplier) + .then(-1) + .otherwise(0) + ) + + +# TODO: Implement this function +def get_vertical_barrier( + df: FrameType, date_col: str, barrier_size: str, symbol_col: str = "symbol" +) -> FrameType: + raise NotImplementedError("This function is not yet implemented.") + + +# TODO: This needs to return a df +# TODO: Write rust funtion for variable shifts +def triple_barrier_label( + df: FrameType, + price_series: IntoExpr, + horizontal_width: IntoExpr, + pt: float, + sl: float, + vertical_barrier: IntoExpr = 5, + min_return: float = 0.0, + use_vertical_barrier_sign: bool = True, + seed_indicator: IntoExpr | None = None, +): + if seed_indicator is None: + seed_indicator = pl.lit(True) + else: + seed_indicator = parse_into_expr(seed_indicator).cast(pl.Boolean) + price_expr = parse_into_expr(price_series) + horizontal_width_expr = parse_into_expr(horizontal_width) + vertical_barrier_expr = parse_into_expr(vertical_barrier) + labels = df.with_columns( + price_expr.register_plugin( + lib, + "triple_barrier_label", + kwargs={ + "pt": pt, + "sl": sl, + "min_return": min_return, + "use_vertical_barrier_sign": use_vertical_barrier_sign, + }, + args=[horizontal_width_expr, vertical_barrier_expr, seed_indicator], + cast_to_supertypes=True, + ) + ) + return labels + + +def metalabel(): ... + + +def label_uniqueness(): ... + + +def avg_label_uniqueness(): ... + + +def sequential_bootstrap(): ... + + +# TODO: MC Experiment on 4.5.4 + + +def return_attribution_weighting(): ... + + +def time_decay_weighting(): ... + + +def class_sample_weighting(): ... diff --git a/polars_finance/plots.py b/polars_finance/plots.py new file mode 100644 index 0000000..e69de29 diff --git a/polars_finance/sampling_features.py b/polars_finance/sampling_features.py new file mode 100644 index 0000000..6e97db0 --- /dev/null +++ b/polars_finance/sampling_features.py @@ -0,0 +1,19 @@ +import polars as pl +from polars.utils.udfs import _get_shared_lib_location + +from polars_finance.utils import parse_into_expr + +from polars.type_aliases import IntoExpr + +lib = _get_shared_lib_location(__file__) + + +def symmetric_cusum_filter(time_series: IntoExpr, threshold: float) -> pl.Expr: + expr = parse_into_expr(time_series) + return expr.register_plugin( + lib=lib, + symbol="symmetric_cusum_filter", + kwargs={"threshold": threshold}, + is_elementwise=False, + cast_to_supertypes=True, + ) diff --git a/polars_finance/ta.py b/polars_finance/ta.py new file mode 100644 index 0000000..8d67daf --- /dev/null +++ b/polars_finance/ta.py @@ -0,0 +1,11 @@ +import polars as pl +from polars_finance.utils import parse_into_expr +from polars.type_aliases import IntoExpr + + +def balance_of_power( + high: IntoExpr, low: IntoExpr, close: IntoExpr, open: IntoExpr +) -> pl.Expr: + return (parse_into_expr(close) - parse_into_expr(open)) / ( + parse_into_expr(high) - parse_into_expr(low) + ) diff --git a/polars_finance/utils.py b/polars_finance/utils.py new file mode 100644 index 0000000..08ded1b --- /dev/null +++ b/polars_finance/utils.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import polars as pl +from polars.type_aliases import FrameType +from datetime import timedelta +from functools import reduce + +if TYPE_CHECKING: + from polars.type_aliases import IntoExpr, PolarsDataType + + +def parse_into_expr( + expr: IntoExpr, + *, + str_as_lit: bool = False, + list_as_lit: bool = True, + dtype: PolarsDataType | None = None, +) -> pl.Expr: + """ + Parse a single input into an expression. + + Parameters + ---------- + expr + The input to be parsed as an expression. + str_as_lit + Interpret string input as a string literal. If set to `False` (default), + strings are parsed as column names. + list_as_lit + Interpret list input as a lit literal, If set to `False`, + lists are parsed as `Series` literals. + dtype + If the input is expected to resolve to a literal with a known dtype, pass + this to the `lit` constructor. + + Returns + ------- + polars.Expr + """ + if isinstance(expr, pl.Expr): + pass + elif isinstance(expr, str) and not str_as_lit: + expr = pl.col(expr) + elif isinstance(expr, list) and not list_as_lit: + expr = pl.lit(pl.Series(expr), dtype=dtype) + else: + expr = pl.lit(expr, dtype=dtype) + + return expr + + +def dynamic_shift( + df: FrameType, value_col: str, shift_col: str, group_col: str | None = None +) -> FrameType: + """ + Shift a column in a DataFrame by a dynamic amount. + + Parameters + ---------- + df + The DataFrame/LazyFrame to shift. + shift_col + The column to use for shifting. + group_col + The column to group by when shifting. + + Returns + ------- + FrameType + """ + df_ind = df.with_row_index() + if isinstance(df, pl.DataFrame): + shift_values = df[shift_col].unique().to_list() + elif isinstance(df, pl.LazyFrame): + shift_values = ( + df.select(pl.col(shift_col).unique()).collect()[shift_col].to_list() + ) + else: + raise ValueError("df must be a DataFrame or LazyFrame") + + shifted_dfs = [] + for shift_val in shift_values: + if group_col is not None: + shifted_dfs.append( + df_ind.select( + "index", + pl.col(value_col) + .shift(shift_val) + .over(group_col) + .alias(f"{value_col}_shifted"), + pl.lit(shift_val).alias(shift_col).cast(pl.Int64), + ) + ) + else: + shifted_dfs.append( + df_ind.select( + "index", + pl.col(value_col).shift(shift_val).alias(f"{value_col}_shifted"), + pl.lit(shift_val).alias(shift_col).cast(pl.Int64), + ) + ) + return df_ind.join( + pl.concat(shifted_dfs), on=["index", shift_col], how="left" + ).drop("index") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2747366 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,66 @@ +[build-system] +requires = ["maturin>=1.1,<2.0", "polars>=0.20.6"] +build-backend = "maturin" + +[project] +name = "polars-finance" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +version="0.1.0" + +dependencies = [ + "polars>=0.20.13", +] +license = { file = "LICENSE" } +readme = "README.md" + +[tool.maturin] +features = ["pyo3/extension-module"] + +[tool.ruff.format] +docstring-code-format = true + +[tool.ruff] +lint.select = [ + "ALL", +] +lint.ignore = [ + 'A003', + 'ANN101', + 'ANN401', + 'ARG002', # todo: enable + 'ARG003', # todo: enable + 'C901', + 'COM812', + 'D100', + 'D103', + 'D104', + 'D105', + 'D107', + 'D203', + 'D212', + 'DTZ', + 'E501', + 'FBT003', # todo: enable + 'FIX', + 'ISC001', + 'PD', + 'PLR0911', + 'PLR0912', + 'PLR5501', + 'PLR2004', + 'PT011', + 'PTH', + 'RET505', + 'S', + 'SLF001', + 'TD', + 'TRY004' +] + +# Allow autofix for all enabled rules (when `--fix`) is provided. +lint.fixable = ["ALL"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4872b0d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +polars +maturin +ruff +pytest + diff --git a/src/bars.rs b/src/bars.rs new file mode 100644 index 0000000..a72730a --- /dev/null +++ b/src/bars.rs @@ -0,0 +1,385 @@ +#![allow(clippy::unused_unit)] +use polars::{lazy::dsl::as_struct, prelude::*}; +use pyo3_polars::derive::polars_expr; + +/// This function calculates dynamic tick bar groups. +/// It takes a slice of Series as input and returns a PolarsResult of a Series with UInt16 type. +/// The function iterates over the thresholds and assigns a group id to each tick based on the threshold. +/// If the row count exceeds the threshold, the group id is incremented and the row count is reset. +#[polars_expr(output_type=UInt16)] +pub fn dynamic_tick_bar_groups(inputs: &[Series]) -> PolarsResult { + let thresholds = inputs[0].u16()?; + let mut group_id: u16 = 0; + let mut row_count: u16 = 0; + let mut builder: PrimitiveChunkedBuilder = + PrimitiveChunkedBuilder::new("group_id", thresholds.len()); + for threshold in thresholds.into_iter() { + match threshold { + Some(threshold) => { + row_count += 1; + builder.append_value(group_id); + if row_count >= threshold { + group_id += 1; + row_count = 0; + } + } + None => builder.append_null(), + } + } + Ok(builder.finish().into_series()) +} + +/// Struct representing a single transaction. +struct Transaction { + dt: i64, + price: f64, + size: u32, +} + +/// Struct representing OHLCV data. +struct OHLCV { + start_dt: i64, + end_dt: i64, + open: f64, + high: f64, + low: f64, + close: f64, + vwap: f64, + volume: u32, + n_transactions: u32, +} + +/// Struct representing a collection of transactions. +struct BarTransactions { + transactions: Vec, +} + +impl BarTransactions { + /// Create a new instance of BarTransactions. + fn new() -> Self { + Self { + transactions: Vec::new(), + } + } + + /// Add a new transaction to the collection. + fn add_transaction(&mut self, price: f64, size: u32, dt: i64) { + self.transactions.push(Transaction { price, size, dt }); + } + + /// Check if the collection is empty. + fn is_empty(&self) -> bool { + self.transactions.is_empty() + } + + /// Clear all transactions from the collection. + fn clear_transactions(&mut self) { + self.transactions.clear(); + } + + /// Get the current volume of the transactions. + fn get_current_volume(&self) -> u32 { + self.transactions.iter().map(|t| t.size).sum() + } + + /// Get the current dollar volume of the transactions. + fn get_current_dollar_volume(&self) -> f64 { + self.transactions + .iter() + .map(|t| t.price * t.size as f64) + .sum() + } + + /// Calculate the OHLCV data from the transactions. + fn calculate_ohlcv(&self) -> OHLCV { + let start_dt = self.transactions.first().unwrap().dt; + let end_dt = self.transactions.last().unwrap().dt; + let open = self.transactions.first().unwrap().price; + let close = self.transactions.last().unwrap().price; + let high = self + .transactions + .iter() + .map(|t| t.price) + .fold(f64::MIN, f64::max); + let low = self + .transactions + .iter() + .map(|t| t.price) + .fold(f64::MAX, f64::min); + let volume = self.transactions.iter().map(|t| t.size).sum::(); + let vwap = self + .transactions + .iter() + .map(|t| t.price * t.size as f64) + .sum::() + / volume as f64; + let n_transactions = self.transactions.len().try_into().unwrap(); + OHLCV { + start_dt, + end_dt, + open, + high, + low, + close, + vwap, + volume, + n_transactions, + } + } +} + +/// Enum to represent the threshold for calculating bars. +enum Threshold { + Volume(u32), // Threshold based on volume + Dollar(f64), // Threshold based on dollar value +} + +/// Implementation of the Threshold enum. +impl Threshold { + /// Create a Threshold enum from a u32 value. + fn from_u32(threshold: Option) -> Option { + threshold.map(Threshold::Volume) + } + + /// Create a Threshold enum from a f64 value. + fn from_f64(threshold: Option) -> Option { + threshold.map(Threshold::Dollar) + } +} + +/// Function to calculate bars from trades. +/// It takes in datetimes, prices, sizes, and thresholds as input. +/// It returns a DataFrame containing the calculated bars. +fn calculate_bars_from_trades( + datetimes: &[Option], // Datetimes of the trades + prices: &[Option], // Prices of the trades + sizes: &[Option], // Sizes of the trades + threshold: &[Option], // Threshold for calculating the bars +) -> PolarsResult { + // TODO: Add dollar volume to OHLCV + let mut bars: Vec = Vec::new(); // Vector to store the calculated bars + let mut start_dt: Vec = Vec::new(); // Vector to store the start datetimes of the bars + let mut end_dt: Vec = Vec::new(); // Vector to store the end datetimes of the bars + let mut opens: Vec = Vec::new(); // Vector to store the opening prices of the bars + let mut highs: Vec = Vec::new(); // Vector to store the highest prices of the bars + let mut lows: Vec = Vec::new(); // Vector to store the lowest prices of the bars + let mut closes: Vec = Vec::new(); // Vector to store the closing prices of the bars + let mut vwap: Vec = Vec::new(); // Vector to store the volume weighted average prices of the bars + let mut volumes: Vec = Vec::new(); // Vector to store the volumes of the bars + let mut n_transactions: Vec = Vec::new(); // Vector to store the number of transactions in the bars + let mut bar_transactions = BarTransactions::new(); // BarTransactions instance to calculate the bars + + // CALCULATE BARS AND ADD TO SERIES THEN CREATE DF + for (((dt, price), size), thresh) in datetimes + .iter() + .zip(prices.iter()) + .zip(sizes.iter()) + .zip(threshold.iter()) + { + match (dt, price, size, thresh) { + (Some(dt), Some(price), Some(mut size), Some(thresh)) => match thresh { + Threshold::Volume(thresh) => { + if size >= thresh - bar_transactions.get_current_volume() { + let mut remaining_size = thresh - bar_transactions.get_current_volume(); + while size >= remaining_size { + bar_transactions.add_transaction(*price, remaining_size, *dt); + let ohlcv = bar_transactions.calculate_ohlcv(); + start_dt.push(ohlcv.start_dt); + end_dt.push(ohlcv.end_dt); + opens.push(ohlcv.open); + highs.push(ohlcv.high); + lows.push(ohlcv.low); + closes.push(ohlcv.close); + vwap.push(ohlcv.vwap); + volumes.push(ohlcv.volume); + n_transactions.push(ohlcv.n_transactions); + bars.push(ohlcv); + bar_transactions.clear_transactions(); + size -= remaining_size; + remaining_size = *thresh; + } + if size > 0 { + bar_transactions.add_transaction(*price, size, *dt); + } + } else { + bar_transactions.add_transaction(*price, size, *dt); + } + } + Threshold::Dollar(thresh) => { + if price * size as f64 + >= *thresh - bar_transactions.get_current_dollar_volume() as f64 + { + let mut remaining = *thresh - bar_transactions.get_current_dollar_volume(); + while price * size as f64 >= remaining { + bar_transactions.add_transaction( + *price, + (remaining / *price) as u32, + *dt, + ); + let ohlcv = bar_transactions.calculate_ohlcv(); + start_dt.push(ohlcv.start_dt); + end_dt.push(ohlcv.end_dt); + opens.push(ohlcv.open); + highs.push(ohlcv.high); + lows.push(ohlcv.low); + closes.push(ohlcv.close); + vwap.push(ohlcv.vwap); + volumes.push(ohlcv.volume); + n_transactions.push(ohlcv.n_transactions); + bars.push(ohlcv); + bar_transactions.clear_transactions(); + size -= (remaining / *price) as u32; + remaining = *thresh; + } + if size > 0 { + bar_transactions.add_transaction(*price, size, *dt); + } + } else { + bar_transactions.add_transaction(*price, size, *dt); + } + } + }, + _ => {} + } + // create an array of len bars and fill with symbol + } + if !bar_transactions.is_empty() { + let ohlcv = bar_transactions.calculate_ohlcv(); + start_dt.push(ohlcv.start_dt); + end_dt.push(ohlcv.end_dt); + opens.push(ohlcv.open); + highs.push(ohlcv.high); + lows.push(ohlcv.low); + closes.push(ohlcv.close); + vwap.push(ohlcv.vwap); + volumes.push(ohlcv.volume); + n_transactions.push(ohlcv.n_transactions); + bars.push(ohlcv); + } + + // Create a DataFrame from the calculated bars + df!( + "start_dt" => Series::from_vec("start_dt", start_dt), + "end_dt" => Series::from_vec("end_dt", end_dt), + "open" => Series::from_vec("open", opens), + "high" => Series::from_vec("high", highs), + "low" => Series::from_vec("low", lows), + "close" => Series::from_vec("close", closes), + "vwap" => Series::from_vec("vwap", vwap), + "volume" => Series::from_vec("volume", volumes), + "n_transactions" => Series::from_vec("n_transactions", n_transactions) + ) +} + +/// Function to define the type of the OHLCV struct. +/// It takes in the input fields of the DataFrame. +/// It returns a Field representing the OHLCV struct. +fn ohlcv_struct_type(_input_fields: &[Field]) -> PolarsResult { + Ok(Field::new( + "ohlcv", + DataType::Struct(vec![ + Field::new("start_dt", DataType::Datetime(TimeUnit::Nanoseconds, None)), + Field::new("end_dt", DataType::Datetime(TimeUnit::Nanoseconds, None)), + Field::new("open", DataType::Float64), + Field::new("high", DataType::Float64), + Field::new("low", DataType::Float64), + Field::new("close", DataType::Float64), + Field::new("vwap", DataType::Float64), + Field::new("volume", DataType::UInt32), + Field::new("n_transactions", DataType::UInt32), + ]), + )) +} + +/// Function to calculate volume bars. +/// It takes in a list of Series as input. +/// It returns a Series containing the calculated volume bars. +#[polars_expr(output_type_func=ohlcv_struct_type)] // FIXME +pub fn volume_bars(inputs: &[Series]) -> PolarsResult { + let dts = inputs[0].datetime()?; // Datetimes of the trades + let dt_type = dts.dtype(); // Type of the datetimes + let dts = dts.to_vec(); // Convert the datetimes to a vector + let prices = inputs[1].f64()?.to_vec(); // Prices of the trades + let sizes = inputs[2].u32()?.to_vec(); // Sizes of the trades + let threshold = inputs[3].u32()?.to_vec(); // Threshold for calculating the bars + let threshold = threshold + .iter() + .map(|&x| Threshold::from_u32(x)) + .collect::>>(); // Convert the threshold to a vector of Threshold enums + + // Calculate the bars from the trades + let bars = calculate_bars_from_trades( + dts.as_slice(), + prices.as_slice(), + sizes.as_slice(), + threshold.as_slice(), + )?; + let s = bars + .lazy() + .with_columns(vec![ + col("start_dt").cast(dt_type.clone()), // Cast the start datetimes to the original type + col("end_dt").cast(dt_type.clone()), // Cast the end datetimes to the original type + ]) + .select([as_struct(vec![ + col("start_dt"), + col("end_dt"), + col("open"), + col("high"), + col("low"), + col("close"), + col("vwap"), + col("volume"), + col("n_transactions"), + ]) + .alias("bar")]) + .collect()? + .column("bar")? + .clone(); // Select the OHLCV struct and cast it to the original type + Ok(s) // Return the calculated bars +} + + +/// Function to calculate dollar bars. +/// It takes in a list of Series as input. +/// It returns a Series containing the calculated dollar bars. +#[polars_expr(output_type_func=ohlcv_struct_type)] +pub fn dollar_bars(inputs: &[Series]) -> PolarsResult { + let dts = inputs[0].datetime()?; // Datetimes of the trades + let dt_type = dts.dtype(); // Type of the datetimes + let dts = dts.to_vec(); // Convert the datetimes to a vector + let prices = inputs[1].f64()?.to_vec(); // Prices of the trades + let sizes = inputs[2].u32()?.to_vec(); // Sizes of the trades + let threshold = inputs[3].f64()?.to_vec(); // Threshold for calculating the bars + let threshold = threshold + .iter() + .map(|&x| Threshold::from_f64(x)) + .collect::>>(); // Convert the threshold to a vector of Threshold enums + let bars = calculate_bars_from_trades( + dts.as_slice(), + prices.as_slice(), + sizes.as_slice(), + threshold.as_slice(), + )?; // Calculate the bars from the trades + let s = bars + .lazy() + .with_columns(vec![ + col("start_dt").cast(dt_type.clone()), // Cast the start datetimes to the original type + col("end_dt").cast(dt_type.clone()), // Cast the end datetimes to the original type + ]) + .select([as_struct(vec![ + col("start_dt"), + col("end_dt"), + col("open"), + col("high"), + col("low"), + col("close"), + col("vwap"), + col("volume"), + col("n_transactions"), + ]) + .alias("bar")]) + .collect()? + .column("bar")? + .clone(); // Select the OHLCV struct and cast it to the original type + Ok(s) // Return the calculated bars +} diff --git a/src/labels.rs b/src/labels.rs new file mode 100644 index 0000000..69d5175 --- /dev/null +++ b/src/labels.rs @@ -0,0 +1,224 @@ +#![allow(clippy::unused_unit)] +use polars::error::ErrString; +use polars::prelude::*; +use pyo3_polars::derive::polars_expr; +use serde::Deserialize; + +#[derive(Deserialize, Debug)] +struct TripleBarrierLabelKwargs { + stop_loss: Option, + profit_taker: Option, + use_vertical_barrier_sign: bool, + min_return: f64, +} + +struct HorizontalBarrier { + lower: Option, + upper: Option, +} + +struct Label { + event: Option, + ret: f64, + n_bars: i64, +} + +fn get_event( + path_prices: &[f64], + stop_loss: Option, + profit_taker: Option, + use_vertical_barrier_sign: bool, + min_return: f64, +) -> Label { + for (i, price) in path_prices.iter().enumerate() { + match (stop_loss, profit_taker) { + (Some(sl), Some(pt)) => { + if *price <= -sl && *price <= -min_return { + return Label { + event: Some(-1), + ret: *price, + n_bars: i as i64, + }; + } else if *price >= pt && *price >= min_return { + return Label { + event: Some(1), + ret: *price, + n_bars: i as i64, + }; + } + } + (None, Some(pt)) => { + if *price >= pt && *price >= min_return { + return Label { + event: Some(1), + ret: *price, + n_bars: i as i64, + }; + } + } + (Some(sl), None) => { + if *price <= -sl && *price <= -min_return { + return Label { + event: Some(-1), + ret: *price, + n_bars: i as i64, + }; + } + } + _ => {} + } + } + if use_vertical_barrier_sign { + if *path_prices.last().unwrap_or(&0.0) < -min_return { + Label { + event: Some(-1), + ret: *path_prices.last().unwrap_or(&0.0), + n_bars: path_prices.len() as i64, + } + } else if *path_prices.last().unwrap_or(&0.0) > min_return { + Label { + event: Some(1), + ret: *path_prices.last().unwrap_or(&0.0), + n_bars: path_prices.len() as i64, + } + } else { + Label { + event: None, + ret: *path_prices.last().unwrap_or(&0.0), + n_bars: path_prices.len() as i64, + } + } + } else { + Label { + event: Some(0), + ret: *path_prices.last().unwrap_or(&0.0), + n_bars: path_prices.len() as i64, + } + } +} + +fn get_horizontal_barriers( + horizontal_widths: &[Option], + stop_loss: Option, + profit_taker: Option, +) -> Vec { + let mut horizontal_barriers = Vec::new(); + for width in horizontal_widths { + let (lower, upper) = match width { + Some(w) => match (stop_loss, profit_taker) { + (Some(sl), Some(pt)) => (Some(sl * w), Some(pt * w)), + (None, Some(pt)) => (None, Some(pt * w)), + (Some(sl), None) => (Some(sl * w), None), + _ => (None, None), + }, + None => (None, None), + }; + horizontal_barriers.push(HorizontalBarrier { lower, upper }); + } + horizontal_barriers +} + +fn get_path_prices(prices: &[f64]) -> Vec { + let first_price = prices[0]; + let mut path_prices = Vec::new(); + for price in prices { + path_prices.push(price / first_price - 1.0); + } + path_prices +} + +fn tbl_struct_type(_input_fields: &[Field]) -> PolarsResult { + Ok(Field::new( + "triple_barrier_label", + DataType::Struct(vec![ + Field::new("label", DataType::Int8), + Field::new("ret", DataType::Float64), + Field::new("n_bars", DataType::Int64), + ]), + )) +} + +#[polars_expr(output_type_func=tbl_struct_type)] +pub fn triple_barrier_label( + inputs: &[Series], + kwargs: TripleBarrierLabelKwargs, +) -> PolarsResult { + let prices = inputs[0].f64()?.to_vec(); + let horizontal_widths = inputs[1].f64()?; + let vertical_barriers = inputs[2].i64()?; + let seed_indicator = inputs[3].bool()?; + let stop_loss = kwargs.stop_loss; + let profit_taker = kwargs.profit_taker; + + if prices.iter().any(|&x| x.is_none()) { + return Err(PolarsError::ComputeError(ErrString::from( + "Missing prices in the input".to_string(), + ))); + } + let prices: Vec = prices.iter().map(|&x| x.unwrap()).collect(); + let horizontal_barriers = + get_horizontal_barriers(&horizontal_widths.to_vec(), stop_loss, profit_taker); + + let mut event_builder: PrimitiveChunkedBuilder = + PrimitiveChunkedBuilder::new("triple_barrier_label_event", prices.len()); + let mut ret_builder: PrimitiveChunkedBuilder = + PrimitiveChunkedBuilder::new("triple_barrier_label_ret", prices.len()); + let mut n_bar_builder: PrimitiveChunkedBuilder = + PrimitiveChunkedBuilder::new("triple_barrier_label_n_bars", prices.len()); + for i in 0..prices.len() { + if !seed_indicator.get(i).unwrap_or(false) { + event_builder.append_null(); + ret_builder.append_null(); + n_bar_builder.append_null(); + } else { + let path_prices = get_path_prices( + &prices[i..vertical_barriers.get(i).unwrap_or(prices.len() as i64) as usize], + ); + let label = get_event( + &path_prices, + horizontal_barriers[i].lower, + horizontal_barriers[i].upper, + kwargs.use_vertical_barrier_sign, + kwargs.min_return, + ); + // TODO: Add n_bars to the output + match label { + Label { + event: Some(e), + ret: _, + n_bars: _, + } => { + event_builder.append_value(e); + ret_builder.append_value(label.ret); + n_bar_builder.append_value(label.n_bars); + } + + Label { + event: None, + ret: _, + n_bars: _, + } => { + event_builder.append_null(); + ret_builder.append_null(); + n_bar_builder.append_null(); + } + } + } + } + let s = df!( + "triple_barrier_label_event" => event_builder.finish(), + "triple_barrier_label_ret" => ret_builder.finish(), + "triple_barrier_label_n_bars" => n_bar_builder.finish() + )? + .lazy() + .select([as_struct(vec![ + col("triple_barrier_label_event"), + col("triple_barrier_label_ret"), + col("triple_barrier_label_n_bars") + ]) + .alias("triple_barrier_label")]) + .collect()? + .column("triple_barrier_label")? + .clone(); + Ok(s) +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..20e0d65 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,11 @@ +mod bars; +mod labels; +mod nbbo; +mod symmetric_cusum_filter; + +#[cfg(target_os = "linux")] +use jemallocator::Jemalloc; + +#[global_allocator] +#[cfg(target_os = "linux")] +static ALLOC: Jemalloc = Jemalloc; diff --git a/src/nbbo.rs b/src/nbbo.rs new file mode 100644 index 0000000..2f510f2 --- /dev/null +++ b/src/nbbo.rs @@ -0,0 +1,31 @@ +// #![allow(clippy::unused_unit)] +// use polars::prelude::*; +// use pyo3_polars::derive::polars_expr; + +// struct BBO { +// bid: Option, +// ask: Option, +// } + +// #[polars_expr(output_type=Float64)] +// fn nbbo(inputs: &[Series]) -> PolarsResult { +// let mut bbo_map: HashMap = HashMap::new(); +// let bid: &Float64Chunked = inputs[0].f64()?; +// let ask: &Float64Chunked = inputs[1].f64()?; +// let publisher_id: &UInt32Chunked = inputs[2].u32()?; +// let bbos = Vec::with_capacity(bid.len()); +// for (i, (bid, ask, publisher_id)) in bid.into_iter().zip(ask).zip(publisher_id).enumerate() { +// let bbo = bbo_map.entry(publisher_id).or_insert(BBO { bid: None, ask: None }); +// if bbo.bid.is_none() || bid > bbo.bid.unwrap() { +// bbo.bid = Some(bid); +// } +// if bbo.ask.is_none() || ask < bbo.ask.unwrap() { +// bbo.ask = Some(ask); +// } +// let best_bid = bbo_map.values().filter_map(|bbo| bbo.bid).max(); +// let best_ask = bbo_map.values().filter_map(|bbo| bbo.ask).min(); +// bbos.push(BBO { bid: best_bid, ask: best_ask }); +// } +// let out = best_bid.zip(best_ask).map(|(bid, ask)| bid - ask); +// Ok(out.into_series()) +// } diff --git a/src/symmetric_cusum_filter.rs b/src/symmetric_cusum_filter.rs new file mode 100644 index 0000000..e78143c --- /dev/null +++ b/src/symmetric_cusum_filter.rs @@ -0,0 +1,56 @@ +#![allow(clippy::unused_unit)] +use polars::prelude::*; +use pyo3_polars::derive::polars_expr; +use serde::Deserialize; + +#[derive(Deserialize, Debug)] +struct CusumKwargs { + threshold: f64, +} + +fn calculate_cusum_filter(diff_series: &ChunkedArray, threshold: f64) -> Vec { + let mut out: Vec = Vec::with_capacity(diff_series.len()); + let mut s_pos = 0.0; + let mut s_neg = 0.0; + for val in diff_series.iter() { + match val { + Some(v) => { + s_pos = (s_pos + v).max(0.0); + s_neg = (s_neg + v).min(0.0); + if s_neg < -threshold { + s_neg = 0.0; + out.push(-1); + } else if s_pos > threshold { + s_pos = 0.0; + out.push(1); + } else { + out.push(0); + } + } + None => out.push(0), + } + } + out +} + +#[polars_expr(output_type=Int8)] +pub fn symmetric_cusum_filter(inputs: &[Series], kwargs: CusumKwargs) -> PolarsResult { + let diff_series = inputs[0].f64()?; + let out = calculate_cusum_filter(diff_series, kwargs.threshold); + Ok(Series::from_vec("cusum_filter", out)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_calculate_cusum_filter() { + let diff_series = Float64Chunked::from_slice("diff_series", &[1.0, 2.0, -3.0, -4.0, 5.0]); + let threshold = 2.0; + let expected = vec![0, 1, -1, -1, 1]; + + let result = calculate_cusum_filter(&diff_series, threshold); + assert_eq!(result, expected); + } +}