Skip to content

Commit

Permalink
feat(type): overloads for strategy kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored May 25, 2024
1 parent a7daa70 commit 38d74a5
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions brownie/test/strategies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/python3

from typing import Any, Callable, Iterable, Optional, Tuple, Union
from typing import Any, Callable, Iterable, Optional, Tuple, Union, overload

from eth_abi.grammar import BasicType, TupleType, parse
from hypothesis import strategies as st
Expand All @@ -16,6 +16,17 @@
ArrayLengthType = Union[int, list, None]
NumberType = Union[float, int, None]

EvmIntType = Literal[
"int8", "int16", "int24", "int32", "int40", "int48", "int56", "int64", "int72", "int80", "int88", "int96",
"int104", "int112", "int120", "int128", "int136", "int144", "int152", "int160", "int168", "int176", "int184",
"int192", "int200", "int208", "int216", "int224", "int232", "int240", "int248", "int256"
]

EvmUintType = Literal[
"uint8", "uint16", "uint24", "uint32", "uint40", "uint48", "uint56", "uint64", "uint72", "uint80", "uint88", "uint96",
"uint104", "uint112", "uint120", "uint128", "uint136", "uint144", "uint152", "uint160", "uint168", "uint176", "uint184",
"uint192", "uint200", "uint208", "uint216", "uint224", "uint232", "uint240", "uint248", "uint256"
]

class _DeferredStrategyRepr(DeferredStrategy):
def __init__(self, fn: Callable, repr_target: str) -> None:
Expand Down Expand Up @@ -76,9 +87,9 @@ def _decimal_strategy(


@_exclude_filter
def _address_strategy(length: Optional[int] = None) -> SearchStrategy:
def _address_strategy(length: Optional[int] = None, include: list = []) -> SearchStrategy:
return _DeferredStrategyRepr(
lambda: st.sampled_from(list(network.accounts)[:length]), "accounts"
lambda: st.sampled_from(list(network.accounts)[:length] + include), "accounts"
)


Expand Down Expand Up @@ -152,7 +163,12 @@ def _contract_deferred(name):

return _DeferredStrategyRepr(lambda: _contract_deferred(contract_name), contract_name)


@overload
def strategy(type_str: Literal["address"], length: Optional[int] = None, include: list = []):
...
@overload
def strategy(type_str: Union[EvmIntType, EvmUintType], min_value: Optional[int] = None, max_value: Optional[int] = None):
...
def strategy(type_str: str, **kwargs: Any) -> SearchStrategy:
type_str = TYPE_STR_TRANSLATIONS.get(type_str, type_str)
if type_str == "fixed168x10":
Expand Down

0 comments on commit 38d74a5

Please sign in to comment.