diff --git a/src/pybroker/context.py b/src/pybroker/context.py index ed68522..07c38fe 100644 --- a/src/pybroker/context.py +++ b/src/pybroker/context.py @@ -236,7 +236,7 @@ def _verify_pos_type(self, pos_type: str): def calc_target_shares( self, target_size: float, price: float, cash: Optional[float] = None - ) -> int: + ) -> Union[Decimal, int]: r"""Calculates the number of shares given a ``target_size`` allocation and share ``price``. @@ -250,14 +250,18 @@ def calc_target_shares( is used to calculate the number of shares. Returns: - Number of shares given ``target_size`` and share ``price``. + Number of shares given ``target_size`` and share ``price``. If + :attr:`pybroker.config.StrategyConfig.enable_fractional_shares` is + ``True``, then a Decimal is returned. """ - shares = int( + shares = ( (to_decimal(cash) if cash is not None else self._portfolio.equity) * to_decimal(target_size) / to_decimal(price) ) - return max(shares, 0) + if self.config.enable_fractional_shares: + return shares.max(0) + return max(int(shares), 0) def model(self, name: str, symbol: str) -> Any: r"""Returns a trained model. @@ -972,7 +976,7 @@ def calc_target_shares( target_size: float, price: Optional[float] = None, cash: Optional[float] = None, - ) -> int: + ) -> Union[Decimal, int]: r"""Calculates the number of shares given a ``target_size`` allocation and share ``price``. @@ -988,7 +992,9 @@ def calc_target_shares( is used to calculate the number of shares. Returns: - Number of shares given ``target_size`` and share ``price``. + Number of shares given ``target_size`` and share ``price``. If + :attr:`pybroker.config.StrategyConfig.enable_fractional_shares` is + ``True``, then a Decimal is returned. """ price = self.close[-1] if price is None else price return super().calc_target_shares(target_size, price, cash) diff --git a/tests/test_context.py b/tests/test_context.py index 23b9119..568d114 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -461,6 +461,36 @@ def test_calc_target_shares(ctx): assert ctx.calc_target_shares(0.5, 33.50) == 50_000 // 33.5 +def test_calc_target_shares_when_enable_fractional_shares( + col_scope, + ind_scope, + input_scope, + pred_scope, + pending_order_scope, + portfolio, + trained_models, + sym_end_index, + session, + symbol, +): + ctx = ExecContext( + symbol=symbol, + config=StrategyConfig(enable_fractional_shares=True), + portfolio=portfolio, + col_scope=col_scope, + ind_scope=ind_scope, + input_scope=input_scope, + pred_scope=pred_scope, + pending_order_scope=pending_order_scope, + models=trained_models, + sym_end_index=sym_end_index, + session=session, + ) + assert ctx.calc_target_shares(0.5, 33.50) == Decimal("50_000") / Decimal( + "33.5" + ) + + def test_calc_target_shares_with_cash(ctx): assert ctx.calc_target_shares(1 / 3, 20, 10_000) == 166