diff --git a/brownie/test/strategies.py b/brownie/test/strategies.py index 136fa0890..87d1408ff 100644 --- a/brownie/test/strategies.py +++ b/brownie/test/strategies.py @@ -1,6 +1,6 @@ #!/usr/bin/python3 -from typing import Any, Callable, Iterable, Optional, Tuple, Union +from typing import Any, Callable, Iterable, Optional, Tuple, Union, overload from eth_abi.grammar import BasicType, TupleType, parse from hypothesis import strategies as st @@ -16,6 +16,17 @@ ArrayLengthType = Union[int, list, None] NumberType = Union[float, int, None] +EvmIntType = Literal[ + "int8", "int16", "int24", "int32", "int40", "int48", "int56", "int64", "int72", "int80", "int88", "int96", + "int104", "int112", "int120", "int128", "int136", "int144", "int152", "int160", "int168", "int176", "int184", + "int192", "int200", "int208", "int216", "int224", "int232", "int240", "int248", "int256" +] + +EvmUintType = Literal[ + "uint8", "uint16", "uint24", "uint32", "uint40", "uint48", "uint56", "uint64", "uint72", "uint80", "uint88", "uint96", + "uint104", "uint112", "uint120", "uint128", "uint136", "uint144", "uint152", "uint160", "uint168", "uint176", "uint184", + "uint192", "uint200", "uint208", "uint216", "uint224", "uint232", "uint240", "uint248", "uint256" +] class _DeferredStrategyRepr(DeferredStrategy): def __init__(self, fn: Callable, repr_target: str) -> None: @@ -76,9 +87,9 @@ def _decimal_strategy( @_exclude_filter -def _address_strategy(length: Optional[int] = None) -> SearchStrategy: +def _address_strategy(length: Optional[int] = None, include: list = []) -> SearchStrategy: return _DeferredStrategyRepr( - lambda: st.sampled_from(list(network.accounts)[:length]), "accounts" + lambda: st.sampled_from(list(network.accounts)[:length] + include), "accounts" ) @@ -152,7 +163,12 @@ def _contract_deferred(name): return _DeferredStrategyRepr(lambda: _contract_deferred(contract_name), contract_name) - +@overload +def strategy(type_str: Literal["address"], length: Optional[int] = None, include: list = []): + ... +@overload +def strategy(type_str: Union[EvmIntType, EvmUintType], min_value: Optional[int] = None, max_value: Optional[int] = None): + ... def strategy(type_str: str, **kwargs: Any) -> SearchStrategy: type_str = TYPE_STR_TRANSLATIONS.get(type_str, type_str) if type_str == "fixed168x10":