From 68a47711b15dd6a4be4b902db8b6e9f07875db49 Mon Sep 17 00:00:00 2001 From: edtechre Date: Wed, 13 Dec 2023 23:41:50 -0800 Subject: [PATCH] Adds support for custom fee calculation. Allows fee_mode to be a Callable for custom fee calculation. --- docs/source/reference/pybroker.common.rst | 2 +- src/pybroker/common.py | 16 +++++++++++++- src/pybroker/config.py | 9 ++++++-- src/pybroker/portfolio.py | 26 +++++++++++++++++++---- tests/test_portfolio.py | 10 +++++++++ 5 files changed, 55 insertions(+), 8 deletions(-) diff --git a/docs/source/reference/pybroker.common.rst b/docs/source/reference/pybroker.common.rst index e031f22..efc80ff 100644 --- a/docs/source/reference/pybroker.common.rst +++ b/docs/source/reference/pybroker.common.rst @@ -6,4 +6,4 @@ pybroker.common module :undoc-members: :show-inheritance: :exclude-members: ind_name, symbol, model_name, instance, name, exec_id, - predict_fn + predict_fn, shares, fill_price, order_type diff --git a/src/pybroker/common.py b/src/pybroker/common.py index 632a3f2..a16ce87 100644 --- a/src/pybroker/common.py +++ b/src/pybroker/common.py @@ -15,7 +15,7 @@ from enum import Enum from joblib import Parallel from numpy.typing import NDArray -from typing import Any, Callable, Final, NamedTuple, Optional, Union +from typing import Any, Callable, Final, Literal, NamedTuple, Optional, Union _tf_pattern: Final = re.compile(r"(\d+)([A-Za-z]+)") _tf_abbr: Final = { @@ -148,6 +148,20 @@ class FeeMode(Enum): PER_SHARE = "per_share" +class FeeInfo(NamedTuple): + """Contains info for custom fee calculations. + + Attributes: + shares: Number of shares in order. + fill_price: Fill price of order. + order_type: Type of order, either "buy" or "sell". + """ + + shares: Decimal + fill_price: Decimal + order_type: Literal["buy", "sell"] + + class BarData: r"""Contains data for a series of bars. Each field is a :class:`numpy.ndarray` that contains bar values in the series. The values diff --git a/src/pybroker/config.py b/src/pybroker/config.py index 134e6cb..ce07077 100644 --- a/src/pybroker/config.py +++ b/src/pybroker/config.py @@ -6,7 +6,7 @@ (see LICENSE for details). """ -from pybroker.common import BarData, FeeMode, PriceType +from pybroker.common import BarData, FeeInfo, FeeMode, PriceType from dataclasses import dataclass, field from decimal import Decimal from typing import Callable, Optional, Union @@ -24,6 +24,9 @@ class StrategyConfig: - ``ORDER_PERCENT``: Fee is a percentage of order amount. - ``PER_ORDER``: Fee is a constant amount per order. - ``PER_SHARE``: Fee is a constant amount per share in order. + - ``Callable[[FeeInfo], Decimal]]``: Fees are calculated using a + custom ``Callable`` that is passed + :class:`pybroker.common.FeeInfo`. - ``None``: Fees are disabled (default). fee_amount: Brokerage fee amount. enable_fractional_shares: Whether to enable trading fractional shares. @@ -62,7 +65,9 @@ class StrategyConfig: """ initial_cash: float = field(default=100_000) - fee_mode: Optional[FeeMode] = field(default=None) + fee_mode: Optional[Union[FeeMode, Callable[[FeeInfo], Decimal]]] = field( + default=None + ) fee_amount: float = field(default=0) enable_fractional_shares: bool = field(default=False) max_long_positions: Optional[int] = field(default=None) diff --git a/src/pybroker/portfolio.py b/src/pybroker/portfolio.py index 4ac7748..c4bd61c 100644 --- a/src/pybroker/portfolio.py +++ b/src/pybroker/portfolio.py @@ -14,6 +14,7 @@ from pybroker.common import ( BarData, DataCol, + FeeInfo, FeeMode, PriceType, StopType, @@ -312,7 +313,9 @@ class Portfolio: def __init__( self, cash: float, - fee_mode: Optional[FeeMode] = None, + fee_mode: Optional[ + Union[FeeMode, Callable[[FeeInfo], Decimal], None] + ] = None, fee_amount: Optional[float] = None, enable_fractional_shares: bool = False, max_long_positions: Optional[int] = None, @@ -348,11 +351,26 @@ def __init__( self._entry_id: int = 0 self._trade_id: int = 0 - def _calculate_fees(self, fill_price: Decimal, shares: Decimal) -> Decimal: + def _calculate_fees( + self, + fill_price: Decimal, + shares: Decimal, + order_type: Literal["buy", "sell"], + ) -> Decimal: fees = Decimal() if self._fee_mode is None or self._fee_amount is None: return fees - if self._fee_mode == FeeMode.ORDER_PERCENT: + if callable(self._fee_mode): + fees = to_decimal( + self._fee_mode( + FeeInfo( + shares=shares, + fill_price=fill_price, + order_type=order_type, + ) + ) + ) + elif self._fee_mode == FeeMode.ORDER_PERCENT: fees = self._fee_amount / _DECIMAL_100 * fill_price * shares elif self._fee_mode == FeeMode.PER_ORDER: fees = self._fee_amount @@ -406,7 +424,7 @@ def _add_order( shares: Decimal, ) -> Order: self._order_id += 1 - fees = self._calculate_fees(fill_price, shares) + fees = self._calculate_fees(fill_price, shares, type) order = Order( id=self._order_id, date=date, diff --git a/tests/test_portfolio.py b/tests/test_portfolio.py index e0cfa69..2dc2774 100644 --- a/tests/test_portfolio.py +++ b/tests/test_portfolio.py @@ -732,6 +732,15 @@ def test_sell_when_all_shares_and_fractional(): ) +def calc_fees(fee_info): + assert fee_info.shares == SHARES_1 + if fee_info.order_type == "buy": + assert fee_info.fill_price == FILL_PRICE_1 + else: + assert fee_info.fill_price == FILL_PRICE_3 + return Decimal("9.99") + + @pytest.mark.parametrize( "fee_mode, expected_buy_fees, expected_sell_fees", [ @@ -746,6 +755,7 @@ def test_sell_when_all_shares_and_fractional(): SHARES_1, ), (FeeMode.PER_ORDER, Decimal("1"), Decimal("1")), + (calc_fees, Decimal("9.99"), Decimal("9.99")), ], ) def test_buy_and_sell_when_fees(