diff --git a/.github/workflows/black_auto.yml b/.github/workflows/black_auto.yml new file mode 100644 index 0000000000..ad1ad2ea4b --- /dev/null +++ b/.github/workflows/black_auto.yml @@ -0,0 +1,43 @@ +--- +name: Run black (auto) + +defaults: + run: + # To load bashrc + shell: bash -ieo pipefail {0} + +on: + pull_request: + branches: [master, dev] + paths: + - "**/*.py" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build: + name: Black + runs-on: ubuntu-latest + + steps: + - name: Checkout Code + uses: actions/checkout@v4 + + - name: Set up Python 3.8 + uses: actions/setup-python@v5 + with: + python-version: 3.8 + + - name: Run black + uses: psf/black@stable + with: + options: "" + summary: false + version: "~= 22.3.0" + + - name: Annotate diff changes using reviewdog + uses: reviewdog/action-suggester@v1 + with: + tool_name: blackfmt diff --git a/FUNDING.json b/FUNDING.json index 189cf2f950..b356f2a74e 100644 --- a/FUNDING.json +++ b/FUNDING.json @@ -6,7 +6,7 @@ }, "drips": { "ethereum": { - "ownedBy": "0x5e2BA02F62bD4efa939e3B80955bBC21d015DbA0" + "ownedBy": "0xc44F30Be3eBBEfdDBB5a85168710b4f0e18f4Ff0" } } } diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index f600d0f43d..fc178db4a7 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -46,12 +46,6 @@ if TYPE_CHECKING: from slither.slithir.variables.variable import SlithIRVariable from slither.core.compilation_unit import SlitherCompilationUnit - from slither.utils.type_helpers import ( - InternalCallType, - HighLevelCallType, - LibraryCallType, - LowLevelCallType, - ) from slither.core.cfg.scope import Scope from slither.core.scope.scope import FileScope @@ -153,11 +147,11 @@ def __init__( self._ssa_vars_written: List["SlithIRVariable"] = [] self._ssa_vars_read: List["SlithIRVariable"] = [] - self._internal_calls: List[Union["Function", "SolidityFunction"]] = [] - self._solidity_calls: List[SolidityFunction] = [] - self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls - self._library_calls: List["LibraryCallType"] = [] - self._low_level_calls: List["LowLevelCallType"] = [] + self._internal_calls: List[InternalCall] = [] # contains solidity calls + self._solidity_calls: List[SolidityCall] = [] + self._high_level_calls: List[Tuple[Contract, HighLevelCall]] = [] # contains library calls + self._library_calls: List[LibraryCall] = [] + self._low_level_calls: List[LowLevelCall] = [] self._external_calls_as_expressions: List[Expression] = [] self._internal_calls_as_expressions: List[Expression] = [] self._irs: List[Operation] = [] @@ -226,8 +220,9 @@ def type(self, new_type: NodeType) -> None: @property def will_return(self) -> bool: if not self.sons and self.type != NodeType.THROW: - if SolidityFunction("revert()") not in self.solidity_calls: - if SolidityFunction("revert(string)") not in self.solidity_calls: + solidity_calls = [ir.function for ir in self.solidity_calls] + if SolidityFunction("revert()") not in solidity_calls: + if SolidityFunction("revert(string)") not in solidity_calls: return True return False @@ -373,44 +368,38 @@ def variables_written_as_expression(self, exprs: List[Expression]) -> None: ################################################################################### @property - def internal_calls(self) -> List["InternalCallType"]: + def internal_calls(self) -> List[InternalCall]: """ - list(Function or SolidityFunction): List of internal/soldiity function calls + list(InternalCall): List of IR operations with internal/solidity function calls """ return list(self._internal_calls) @property - def solidity_calls(self) -> List[SolidityFunction]: + def solidity_calls(self) -> List[SolidityCall]: """ - list(SolidityFunction): List of Soldity calls + list(SolidityCall): List of IR operations with solidity calls """ return list(self._solidity_calls) @property - def high_level_calls(self) -> List["HighLevelCallType"]: + def high_level_calls(self) -> List[HighLevelCall]: """ - list((Contract, Function|Variable)): - List of high level calls (external calls). - A variable is called in case of call to a public state variable + list(HighLevelCall): List of IR operations with high level calls (external calls). Include library calls """ return list(self._high_level_calls) @property - def library_calls(self) -> List["LibraryCallType"]: + def library_calls(self) -> List[LibraryCall]: """ - list((Contract, Function)): - Include library calls + list(LibraryCall): List of IR operations with library calls. """ return list(self._library_calls) @property - def low_level_calls(self) -> List["LowLevelCallType"]: + def low_level_calls(self) -> List[LowLevelCall]: """ - list((Variable|SolidityVariable, str)): List of low_level call - A low level call is defined by - - the variable called - - the name of the function (call/delegatecall/codecall) + list(LowLevelCall): List of IR operations with low_level call """ return list(self._low_level_calls) @@ -529,9 +518,9 @@ def contains_require_or_assert(self) -> bool: bool: True if the node has a require or assert call """ return any( - c.name + ir.function.name in ["require(bool)", "require(bool,string)", "require(bool,error)", "assert(bool)"] - for c in self.internal_calls + for ir in self.internal_calls ) def contains_if(self, include_loop: bool = True) -> bool: @@ -895,11 +884,11 @@ def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements self._vars_written.append(var) if isinstance(ir, InternalCall): - self._internal_calls.append(ir.function) + self._internal_calls.append(ir) if isinstance(ir, SolidityCall): # TODO: consider removing dependancy of solidity_call to internal_call - self._solidity_calls.append(ir.function) - self._internal_calls.append(ir.function) + self._solidity_calls.append(ir) + self._internal_calls.append(ir) if ( isinstance(ir, SolidityCall) and ir.function == SolidityFunction("sstore(uint256,uint256)") @@ -917,22 +906,22 @@ def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements self._vars_read.append(ir.arguments[0]) if isinstance(ir, LowLevelCall): assert isinstance(ir.destination, (Variable, SolidityVariable)) - self._low_level_calls.append((ir.destination, str(ir.function_name.value))) + self._low_level_calls.append(ir) elif isinstance(ir, HighLevelCall) and not isinstance(ir, LibraryCall): # Todo investigate this if condition # It does seem right to compare against a contract # This might need a refactoring if isinstance(ir.destination.type, Contract): - self._high_level_calls.append((ir.destination.type, ir.function)) + self._high_level_calls.append((ir.destination.type, ir)) elif ir.destination == SolidityVariable("this"): func = self.function # Can't use this in a top level function assert isinstance(func, FunctionContract) - self._high_level_calls.append((func.contract, ir.function)) + self._high_level_calls.append((func.contract, ir)) else: try: # Todo this part needs more tests and documentation - self._high_level_calls.append((ir.destination.type.type, ir.function)) + self._high_level_calls.append((ir.destination.type.type, ir)) except AttributeError as error: # pylint: disable=raise-missing-from raise SlitherException( @@ -941,8 +930,8 @@ def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements elif isinstance(ir, LibraryCall): assert isinstance(ir.destination, Contract) assert isinstance(ir.function, Function) - self._high_level_calls.append((ir.destination, ir.function)) - self._library_calls.append((ir.destination, ir.function)) + self._high_level_calls.append((ir.destination, ir)) + self._library_calls.append(ir) self._vars_read = list(set(self._vars_read)) self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index 3f97a33ed2..a2f24fc6cd 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -29,7 +29,6 @@ # pylint: disable=too-many-lines,too-many-instance-attributes,import-outside-toplevel,too-many-nested-blocks if TYPE_CHECKING: - from slither.utils.type_helpers import LibraryCallType, HighLevelCallType, InternalCallType from slither.core.declarations import ( Enum, EventContract, @@ -39,6 +38,7 @@ FunctionContract, CustomErrorContract, ) + from slither.slithir.operations import HighLevelCall, LibraryCall from slither.slithir.variables.variable import SlithIRVariable from slither.core.variables import Variable, StateVariable from slither.core.compilation_unit import SlitherCompilationUnit @@ -106,7 +106,7 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope self._is_incorrectly_parsed: bool = False self._available_functions_as_dict: Optional[Dict[str, "Function"]] = None - self._all_functions_called: Optional[List["InternalCallType"]] = None + self._all_functions_called: Optional[List["Function"]] = None self.compilation_unit: "SlitherCompilationUnit" = compilation_unit self.file_scope: "FileScope" = scope @@ -1023,15 +1023,21 @@ def get_functions_overridden_by(self, function: "Function") -> List["Function"]: ################################################################################### @property - def all_functions_called(self) -> List["InternalCallType"]: + def all_functions_called(self) -> List["Function"]: """ list(Function): List of functions reachable from the contract Includes super, and private/internal functions not shadowed """ + from slither.slithir.operations import Operation + if self._all_functions_called is None: all_functions = [f for f in self.functions + self.modifiers if not f.is_shadowed] # type: ignore all_callss = [f.all_internal_calls() for f in all_functions] + [list(all_functions)] - all_calls = [item for sublist in all_callss for item in sublist] + all_calls = [ + item.function if isinstance(item, Operation) else item + for sublist in all_callss + for item in sublist + ] all_calls = list(set(all_calls)) all_constructors = [c.constructor for c in self.inheritance if c.constructor] @@ -1069,18 +1075,18 @@ def all_state_variables_read(self) -> List["StateVariable"]: return list(set(all_state_variables_read)) @property - def all_library_calls(self) -> List["LibraryCallType"]: + def all_library_calls(self) -> List["LibraryCall"]: """ - list((Contract, Function): List all of the libraries func called + list(LibraryCall): List all of the libraries func called """ all_high_level_callss = [f.all_library_calls() for f in self.functions + self.modifiers] # type: ignore all_high_level_calls = [item for sublist in all_high_level_callss for item in sublist] return list(set(all_high_level_calls)) @property - def all_high_level_calls(self) -> List["HighLevelCallType"]: + def all_high_level_calls(self) -> List[Tuple["Contract", "HighLevelCall"]]: """ - list((Contract, Function|Variable)): List all of the external high level calls + list(Tuple("Contract", "HighLevelCall")): List all of the external high level calls """ all_high_level_callss = [f.all_high_level_calls() for f in self.functions + self.modifiers] # type: ignore all_high_level_calls = [item for sublist in all_high_level_callss for item in sublist] diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index 6e8968dfb2..b91e58f24c 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -31,19 +31,20 @@ # pylint: disable=import-outside-toplevel,too-many-instance-attributes,too-many-statements,too-many-lines if TYPE_CHECKING: - from slither.utils.type_helpers import ( - InternalCallType, - LowLevelCallType, - HighLevelCallType, - LibraryCallType, - ) from slither.core.declarations import Contract, FunctionContract from slither.core.cfg.node import Node, NodeType from slither.core.variables.variable import Variable from slither.slithir.variables.variable import SlithIRVariable from slither.slithir.variables import LocalIRVariable from slither.core.expressions.expression import Expression - from slither.slithir.operations import Operation + from slither.slithir.operations import ( + HighLevelCall, + InternalCall, + LibraryCall, + LowLevelCall, + SolidityCall, + Operation, + ) from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.scope.scope import FileScope @@ -149,11 +150,11 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None: self._vars_read_or_written: List["Variable"] = [] self._solidity_vars_read: List["SolidityVariable"] = [] self._state_vars_written: List["StateVariable"] = [] - self._internal_calls: List["InternalCallType"] = [] - self._solidity_calls: List["SolidityFunction"] = [] - self._low_level_calls: List["LowLevelCallType"] = [] - self._high_level_calls: List["HighLevelCallType"] = [] - self._library_calls: List["LibraryCallType"] = [] + self._internal_calls: List["InternalCall"] = [] + self._solidity_calls: List["SolidityCall"] = [] + self._low_level_calls: List["LowLevelCall"] = [] + self._high_level_calls: List[Tuple["Contract", "HighLevelCall"]] = [] + self._library_calls: List["LibraryCall"] = [] self._external_calls_as_expressions: List["Expression"] = [] self._expression_vars_read: List["Expression"] = [] self._expression_vars_written: List["Expression"] = [] @@ -169,11 +170,11 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None: self._all_expressions: Optional[List["Expression"]] = None self._all_slithir_operations: Optional[List["Operation"]] = None - self._all_internals_calls: Optional[List["InternalCallType"]] = None - self._all_high_level_calls: Optional[List["HighLevelCallType"]] = None - self._all_library_calls: Optional[List["LibraryCallType"]] = None - self._all_low_level_calls: Optional[List["LowLevelCallType"]] = None - self._all_solidity_calls: Optional[List["SolidityFunction"]] = None + self._all_internals_calls: Optional[List["InternalCall"]] = None + self._all_high_level_calls: Optional[List[Tuple["Contract", "HighLevelCall"]]] = None + self._all_library_calls: Optional[List["LibraryCall"]] = None + self._all_low_level_calls: Optional[List["LowLevelCall"]] = None + self._all_solidity_calls: Optional[List["SolidityCall"]] = None self._all_variables_read: Optional[List["Variable"]] = None self._all_variables_written: Optional[List["Variable"]] = None self._all_state_variables_read: Optional[List["StateVariable"]] = None @@ -857,43 +858,42 @@ def slithir_variables(self) -> List["SlithIRVariable"]: ################################################################################### @property - def internal_calls(self) -> List["InternalCallType"]: + def internal_calls(self) -> List["InternalCall"]: """ - list(Function or SolidityFunction): List of function calls (that does not create a transaction) + list(InternalCall): List of IR operations for internal calls """ return list(self._internal_calls) @property - def solidity_calls(self) -> List[SolidityFunction]: + def solidity_calls(self) -> List["SolidityCall"]: """ - list(SolidityFunction): List of Soldity calls + list(SolidityCall): List of IR operations for Solidity calls """ return list(self._solidity_calls) @property - def high_level_calls(self) -> List["HighLevelCallType"]: + def high_level_calls(self) -> List[Tuple["Contract", "HighLevelCall"]]: """ - list((Contract, Function|Variable)): - List of high level calls (external calls). + list(Tuple(Contract, "HighLevelCall")): List of call target contract and IR of the high level call A variable is called in case of call to a public state variable Include library calls """ return list(self._high_level_calls) @property - def library_calls(self) -> List["LibraryCallType"]: + def library_calls(self) -> List["LibraryCall"]: """ - list((Contract, Function)): + list(LibraryCall): List of IR operations for library calls """ return list(self._library_calls) @property - def low_level_calls(self) -> List["LowLevelCallType"]: + def low_level_calls(self) -> List["LowLevelCall"]: """ - list((Variable|SolidityVariable, str)): List of low_level call + list(LowLevelCall): List of IR operations for low level calls A low level call is defined by - the variable called - - the name of the function (call/delegatecall/codecall) + - the name of the function (call/delegatecall/callcode) """ return list(self._low_level_calls) @@ -1121,10 +1121,14 @@ def _explore_functions(self, f_new_values: Callable[["Function"], List]) -> List values = f_new_values(self) explored = [self] to_explore = [ - c for c in self.internal_calls if isinstance(c, Function) and c not in explored + ir.function + for ir in self.internal_calls + if isinstance(ir.function, Function) and ir.function not in explored ] to_explore += [ - c for (_, c) in self.library_calls if isinstance(c, Function) and c not in explored + ir.function + for ir in self.library_calls + if isinstance(ir.function, Function) and ir.function not in explored ] to_explore += [m for m in self.modifiers if m not in explored] @@ -1138,14 +1142,18 @@ def _explore_functions(self, f_new_values: Callable[["Function"], List]) -> List values += f_new_values(f) to_explore += [ - c - for c in f.internal_calls - if isinstance(c, Function) and c not in explored and c not in to_explore + ir.function + for ir in f.internal_calls + if isinstance(ir.function, Function) + and ir.function not in explored + and ir.function not in to_explore ] to_explore += [ - c - for (_, c) in f.library_calls - if isinstance(c, Function) and c not in explored and c not in to_explore + ir.function + for ir in f.library_calls + if isinstance(ir.function, Function) + and ir.function not in explored + and ir.function not in to_explore ] to_explore += [m for m in f.modifiers if m not in explored and m not in to_explore] @@ -1210,31 +1218,31 @@ def all_state_variables_written(self) -> List[StateVariable]: ) return self._all_state_variables_written - def all_internal_calls(self) -> List["InternalCallType"]: + def all_internal_calls(self) -> List["InternalCall"]: """recursive version of internal_calls""" if self._all_internals_calls is None: self._all_internals_calls = self._explore_functions(lambda x: x.internal_calls) return self._all_internals_calls - def all_low_level_calls(self) -> List["LowLevelCallType"]: + def all_low_level_calls(self) -> List["LowLevelCall"]: """recursive version of low_level calls""" if self._all_low_level_calls is None: self._all_low_level_calls = self._explore_functions(lambda x: x.low_level_calls) return self._all_low_level_calls - def all_high_level_calls(self) -> List["HighLevelCallType"]: + def all_high_level_calls(self) -> List[Tuple["Contract", "HighLevelCall"]]: """recursive version of high_level calls""" if self._all_high_level_calls is None: self._all_high_level_calls = self._explore_functions(lambda x: x.high_level_calls) return self._all_high_level_calls - def all_library_calls(self) -> List["LibraryCallType"]: + def all_library_calls(self) -> List["LibraryCall"]: """recursive version of library calls""" if self._all_library_calls is None: self._all_library_calls = self._explore_functions(lambda x: x.library_calls) return self._all_library_calls - def all_solidity_calls(self) -> List[SolidityFunction]: + def all_solidity_calls(self) -> List["SolidityCall"]: """recursive version of solidity calls""" if self._all_solidity_calls is None: self._all_solidity_calls = self._explore_functions(lambda x: x.solidity_calls) @@ -1653,7 +1661,9 @@ def _analyze_calls(self) -> None: internal_calls = [item for sublist in internal_calls for item in sublist] self._internal_calls = list(set(internal_calls)) - self._solidity_calls = [c for c in internal_calls if isinstance(c, SolidityFunction)] + self._solidity_calls = [ + ir for ir in internal_calls if isinstance(ir.function, SolidityFunction) + ] low_level_calls = [x.low_level_calls for x in self.nodes] low_level_calls = [x for x in low_level_calls if x] diff --git a/slither/detectors/assembly/incorrect_return.py b/slither/detectors/assembly/incorrect_return.py index bd5a6d8449..9052979ace 100644 --- a/slither/detectors/assembly/incorrect_return.py +++ b/slither/detectors/assembly/incorrect_return.py @@ -21,10 +21,8 @@ def _assembly_node(function: Function) -> Optional[SolidityCall]: """ - for ir in function.all_slithir_operations(): - if isinstance(ir, SolidityCall) and ir.function == SolidityFunction( - "return(uint256,uint256)" - ): + for ir in function.all_solidity_calls(): + if ir.function == SolidityFunction("return(uint256,uint256)"): return ir return None @@ -71,23 +69,23 @@ def _detect(self) -> List[Output]: for c in self.contracts: for f in c.functions_and_modifiers_declared: - for node in f.nodes: - if node.sons: - for function_called in node.internal_calls: - if isinstance(function_called, Function): - found = _assembly_node(function_called) - if found: - - info: DETECTOR_INFO = [ - f, - " calls ", - function_called, - " which halt the execution ", - found.node, - "\n", - ] - json = self.generate_result(info) - - results.append(json) + for ir in f.internal_calls: + if ir.node.sons: + function_called = ir.function + if isinstance(function_called, Function): + found = _assembly_node(function_called) + if found: + + info: DETECTOR_INFO = [ + f, + " calls ", + function_called, + " which halt the execution ", + found.node, + "\n", + ] + json = self.generate_result(info) + + results.append(json) return results diff --git a/slither/detectors/assembly/return_instead_of_leave.py b/slither/detectors/assembly/return_instead_of_leave.py index a1ad9c87e9..6037059744 100644 --- a/slither/detectors/assembly/return_instead_of_leave.py +++ b/slither/detectors/assembly/return_instead_of_leave.py @@ -6,7 +6,6 @@ DetectorClassification, DETECTOR_INFO, ) -from slither.slithir.operations import SolidityCall from slither.utils.output import Output @@ -42,15 +41,12 @@ class ReturnInsteadOfLeave(AbstractDetector): def _check_function(self, f: Function) -> List[Output]: results: List[Output] = [] - for node in f.nodes: - for ir in node.irs: - if isinstance(ir, SolidityCall) and ir.function == SolidityFunction( - "return(uint256,uint256)" - ): - info: DETECTOR_INFO = [f, " contains an incorrect call to return: ", node, "\n"] - json = self.generate_result(info) + for ir in f.solidity_calls: + if ir.function == SolidityFunction("return(uint256,uint256)"): + info: DETECTOR_INFO = [f, " contains an incorrect call to return: ", ir.node, "\n"] + json = self.generate_result(info) - results.append(json) + results.append(json) return results def _detect(self) -> List[Output]: diff --git a/slither/detectors/attributes/locked_ether.py b/slither/detectors/attributes/locked_ether.py index 91ec686503..efb376e229 100644 --- a/slither/detectors/attributes/locked_ether.py +++ b/slither/detectors/attributes/locked_ether.py @@ -59,7 +59,7 @@ def do_no_send_ether(contract: Contract) -> bool: explored += to_explore to_explore = [] for function in functions: - calls = [c.name for c in function.internal_calls] + calls = [ir.function.name for ir in function.internal_calls] if "suicide(address)" in calls or "selfdestruct(address)" in calls: return False for node in function.nodes: diff --git a/slither/detectors/compiler_bugs/array_by_reference.py b/slither/detectors/compiler_bugs/array_by_reference.py index 47e2af5819..e4dde43608 100644 --- a/slither/detectors/compiler_bugs/array_by_reference.py +++ b/slither/detectors/compiler_bugs/array_by_reference.py @@ -13,8 +13,6 @@ from slither.core.solidity_types.array_type import ArrayType from slither.core.variables.state_variable import StateVariable from slither.core.variables.local_variable import LocalVariable -from slither.slithir.operations.high_level_call import HighLevelCall -from slither.slithir.operations.internal_call import InternalCall from slither.core.cfg.node import Node from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract @@ -117,37 +115,26 @@ def detect_calls_passing_ref_to_function( # pylint: disable=too-many-nested-blocks for contract in contracts: for function in contract.functions_and_modifiers_declared: - for node in function.nodes: + for ir in [ir for _, ir in function.high_level_calls] + function.internal_calls: - # If this node has no expression, skip it. - if not node.expression: + # Verify this references a function in our array modifying functions collection. + if ir.function not in array_modifying_funcs: continue - for ir in node.irs: - # Verify this is a high level call. - if not isinstance(ir, (HighLevelCall, InternalCall)): + # Verify one of these parameters is an array in storage. + for (param, arg) in zip(ir.function.parameters, ir.arguments): + # Verify this argument is a variable that is an array type. + if not isinstance(arg, (StateVariable, LocalVariable)): continue - - # Verify this references a function in our array modifying functions collection. - if ir.function not in array_modifying_funcs: + if not isinstance(arg.type, ArrayType): continue - # Verify one of these parameters is an array in storage. - for (param, arg) in zip(ir.function.parameters, ir.arguments): - # Verify this argument is a variable that is an array type. - if not isinstance(arg, (StateVariable, LocalVariable)): - continue - if not isinstance(arg.type, ArrayType): - continue - - # If it is a state variable OR a local variable referencing storage, we add it to the list. - if ( - isinstance(arg, StateVariable) - or (isinstance(arg, LocalVariable) and arg.location == "storage") - ) and ( - isinstance(param.type, ArrayType) and param.location != "storage" - ): - results.append((node, arg, ir.function)) + # If it is a state variable OR a local variable referencing storage, we add it to the list. + if ( + isinstance(arg, StateVariable) + or (isinstance(arg, LocalVariable) and arg.location == "storage") + ) and (isinstance(param.type, ArrayType) and param.location != "storage"): + results.append((ir.node, arg, ir.function)) return results def _detect(self) -> List[Output]: diff --git a/slither/detectors/erc/erc20/arbitrary_send_erc20.py b/slither/detectors/erc/erc20/arbitrary_send_erc20.py index f060054590..4dc1f8db50 100644 --- a/slither/detectors/erc/erc20/arbitrary_send_erc20.py +++ b/slither/detectors/erc/erc20/arbitrary_send_erc20.py @@ -3,7 +3,7 @@ from slither.analyses.data_dependency.data_dependency import is_dependent from slither.core.cfg.node import Node from slither.core.compilation_unit import SlitherCompilationUnit -from slither.core.declarations import Contract, Function, SolidityVariableComposed +from slither.core.declarations import Contract, Function, SolidityVariableComposed, FunctionContract from slither.core.declarations.solidity_variables import SolidityVariable from slither.slithir.operations import HighLevelCall, LibraryCall @@ -31,11 +31,11 @@ def permit_results(self) -> List[Node]: def _detect_arbitrary_from(self, contract: Contract) -> None: for f in contract.functions: all_high_level_calls = [ - f_called[1].solidity_signature - for f_called in f.high_level_calls - if isinstance(f_called[1], Function) + ir.function.solidity_signature + for _, ir in f.high_level_calls + if isinstance(ir.function, Function) ] - all_library_calls = [f_called[1].solidity_signature for f_called in f.library_calls] + all_library_calls = [ir.function.solidity_signature for ir in f.library_calls] if ( "transferFrom(address,address,uint256)" in all_high_level_calls or "safeTransferFrom(address,address,address,uint256)" in all_library_calls @@ -44,51 +44,50 @@ def _detect_arbitrary_from(self, contract: Contract) -> None: "permit(address,address,uint256,uint256,uint8,bytes32,bytes32)" in all_high_level_calls ): - ArbitrarySendErc20._arbitrary_from(f.nodes, self._permit_results) + ArbitrarySendErc20._arbitrary_from(f, self._permit_results) else: - ArbitrarySendErc20._arbitrary_from(f.nodes, self._no_permit_results) + ArbitrarySendErc20._arbitrary_from(f, self._no_permit_results) @staticmethod - def _arbitrary_from(nodes: List[Node], results: List[Node]) -> None: + def _arbitrary_from(function: FunctionContract, results: List[Node]) -> None: """Finds instances of (safe)transferFrom that do not use msg.sender or address(this) as from parameter.""" - for node in nodes: - for ir in node.irs: - if ( - isinstance(ir, HighLevelCall) - and isinstance(ir.function, Function) - and ir.function.solidity_signature == "transferFrom(address,address,uint256)" - and not ( - is_dependent( - ir.arguments[0], - SolidityVariableComposed("msg.sender"), - node, - ) - or is_dependent( - ir.arguments[0], - SolidityVariable("this"), - node, - ) + for _, ir in function.high_level_calls: + if ( + isinstance(ir, LibraryCall) + and ir.function.solidity_signature + == "safeTransferFrom(address,address,address,uint256)" + and not ( + is_dependent( + ir.arguments[1], + SolidityVariableComposed("msg.sender"), + ir.node, ) - ): - results.append(ir.node) - elif ( - isinstance(ir, LibraryCall) - and ir.function.solidity_signature - == "safeTransferFrom(address,address,address,uint256)" - and not ( - is_dependent( - ir.arguments[1], - SolidityVariableComposed("msg.sender"), - node, - ) - or is_dependent( - ir.arguments[1], - SolidityVariable("this"), - node, - ) + or is_dependent( + ir.arguments[1], + SolidityVariable("this"), + ir.node, ) - ): - results.append(ir.node) + ) + ): + results.append(ir.node) + elif ( + isinstance(ir, HighLevelCall) + and isinstance(ir.function, Function) + and ir.function.solidity_signature == "transferFrom(address,address,uint256)" + and not ( + is_dependent( + ir.arguments[0], + SolidityVariableComposed("msg.sender"), + ir.node, + ) + or is_dependent( + ir.arguments[0], + SolidityVariable("this"), + ir.node, + ) + ) + ): + results.append(ir.node) def detect(self) -> None: """Detect transfers that use arbitrary `from` parameter.""" diff --git a/slither/detectors/functions/dead_code.py b/slither/detectors/functions/dead_code.py index 5cafa16504..3628d10a2c 100644 --- a/slither/detectors/functions/dead_code.py +++ b/slither/detectors/functions/dead_code.py @@ -48,13 +48,15 @@ def _detect(self) -> List[Output]: all_functionss_called = [ f.all_internal_calls() for f in contract.functions_entry_points ] - all_functions_called = [item for sublist in all_functionss_called for item in sublist] + all_functions_called = [ + item.function for sublist in all_functionss_called for item in sublist + ] functions_used |= { f.canonical_name for f in all_functions_called if isinstance(f, Function) } all_libss_called = [f.all_library_calls() for f in contract.functions_entry_points] all_libs_called: List[Tuple[Contract, Function]] = [ - item for sublist in all_libss_called for item in sublist + item.function for sublist in all_libss_called for item in sublist ] functions_used |= { lib[1].canonical_name for lib in all_libs_called if isinstance(lib, tuple) diff --git a/slither/detectors/functions/external_function.py b/slither/detectors/functions/external_function.py index 5858c2baf2..d9cc2bc361 100644 --- a/slither/detectors/functions/external_function.py +++ b/slither/detectors/functions/external_function.py @@ -13,8 +13,7 @@ make_solc_versions, ) from slither.formatters.functions.external_function import custom_format -from slither.slithir.operations import InternalCall, InternalDynamicCall -from slither.slithir.operations import SolidityCall +from slither.slithir.operations import InternalDynamicCall from slither.utils.output import Output @@ -55,11 +54,11 @@ def detect_functions_called(contract: Contract) -> List[Function]: for func in contract.all_functions_called: if not isinstance(func, Function): continue - # Loop through all nodes in the function, add all calls to a list. - for node in func.nodes: - for ir in node.irs: - if isinstance(ir, (InternalCall, SolidityCall)): - result.append(ir.function) + + # Loop through all internal and solidity calls in the function, add them to a list. + for ir in func.internal_calls + func.solidity_calls: + result.append(ir.function) + return result @staticmethod @@ -101,6 +100,7 @@ def get_base_most_function(function: FunctionContract) -> FunctionContract: # Somehow we couldn't resolve it, which shouldn't happen, as the provided function should be found if we could # not find some any more basic. + # pylint: disable=broad-exception-raised raise Exception("Could not resolve the base-most function for the provided function.") @staticmethod diff --git a/slither/detectors/functions/modifier.py b/slither/detectors/functions/modifier.py index 7f14872663..a888d5b703 100644 --- a/slither/detectors/functions/modifier.py +++ b/slither/detectors/functions/modifier.py @@ -17,7 +17,7 @@ def is_revert(node: Node) -> bool: return node.type == NodeType.THROW or any( - c.name in ["revert()", "revert(string"] for c in node.internal_calls + ir.function.name in ["revert()", "revert(string"] for ir in node.internal_calls ) diff --git a/slither/detectors/functions/out_of_order_retryable.py b/slither/detectors/functions/out_of_order_retryable.py index db9096f95f..a11e31ef45 100644 --- a/slither/detectors/functions/out_of_order_retryable.py +++ b/slither/detectors/functions/out_of_order_retryable.py @@ -101,9 +101,9 @@ def _detect_multiple_tickets( # include ops from internal function calls internal_ops = [] - for internal_call in node.internal_calls: - if isinstance(internal_call, Function): - internal_ops += internal_call.all_slithir_operations() + for ir in node.internal_calls: + if isinstance(ir.function, Function): + internal_ops += ir.function.all_slithir_operations() # analyze node for retryable tickets for ir in node.irs + internal_ops: diff --git a/slither/detectors/functions/protected_variable.py b/slither/detectors/functions/protected_variable.py index 5796729262..b9260abd61 100644 --- a/slither/detectors/functions/protected_variable.py +++ b/slither/detectors/functions/protected_variable.py @@ -61,7 +61,9 @@ def _analyze_function(self, function: Function, contract: Contract) -> List[Outp if not function_protection: self.logger.error(f"{function_sig} not found") continue - if function_protection not in function.all_internal_calls(): + if function_protection not in [ + ir.function for ir in function.all_internal_calls() + ]: info: DETECTOR_INFO = [ function, " should have ", diff --git a/slither/detectors/functions/suicidal.py b/slither/detectors/functions/suicidal.py index f0af978ec7..7c7d87f8a6 100644 --- a/slither/detectors/functions/suicidal.py +++ b/slither/detectors/functions/suicidal.py @@ -59,7 +59,7 @@ def detect_suicidal_func(func: FunctionContract) -> bool: if func.visibility not in ["public", "external"]: return False - calls = [c.name for c in func.all_internal_calls()] + calls = [ir.function.name for ir in func.all_internal_calls()] if not ("suicide(address)" in calls or "selfdestruct(address)" in calls): return False diff --git a/slither/detectors/operations/encode_packed.py b/slither/detectors/operations/encode_packed.py index ea7b094df2..b661ddcd72 100644 --- a/slither/detectors/operations/encode_packed.py +++ b/slither/detectors/operations/encode_packed.py @@ -3,14 +3,14 @@ """ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.core.declarations.solidity_variables import SolidityFunction -from slither.slithir.operations import SolidityCall +from slither.core.declarations import Contract, SolidityFunction +from slither.core.variables import Variable from slither.analyses.data_dependency.data_dependency import is_tainted from slither.core.solidity_types import ElementaryType from slither.core.solidity_types import ArrayType -def _is_dynamic_type(arg): +def _is_dynamic_type(arg: Variable): """ Args: arg (function argument) @@ -25,7 +25,7 @@ def _is_dynamic_type(arg): return False -def _detect_abi_encodePacked_collision(contract): +def _detect_abi_encodePacked_collision(contract: Contract): """ Args: contract (Contract) @@ -35,22 +35,19 @@ def _detect_abi_encodePacked_collision(contract): ret = [] # pylint: disable=too-many-nested-blocks for f in contract.functions_and_modifiers_declared: - for n in f.nodes: - for ir in n.irs: - if isinstance(ir, SolidityCall) and ir.function == SolidityFunction( - "abi.encodePacked()" - ): - dynamic_type_count = 0 - for arg in ir.arguments: - if is_tainted(arg, contract) and _is_dynamic_type(arg): - dynamic_type_count += 1 - elif dynamic_type_count > 1: - ret.append((f, n)) - dynamic_type_count = 0 - else: - dynamic_type_count = 0 - if dynamic_type_count > 1: - ret.append((f, n)) + for ir in f.solidity_calls: + if ir.function == SolidityFunction("abi.encodePacked()"): + dynamic_type_count = 0 + for arg in ir.arguments: + if is_tainted(arg, contract) and _is_dynamic_type(arg): + dynamic_type_count += 1 + elif dynamic_type_count > 1: + ret.append((f, ir.node)) + dynamic_type_count = 0 + else: + dynamic_type_count = 0 + if dynamic_type_count > 1: + ret.append((f, ir.node)) return ret diff --git a/slither/detectors/operations/low_level_calls.py b/slither/detectors/operations/low_level_calls.py index 463c748757..4925fc4661 100644 --- a/slither/detectors/operations/low_level_calls.py +++ b/slither/detectors/operations/low_level_calls.py @@ -44,10 +44,9 @@ def detect_low_level_calls( ) -> List[Tuple[FunctionContract, List[Node]]]: ret = [] for f in [f for f in contract.functions if contract == f.contract_declarer]: - nodes = f.nodes - assembly_nodes = [n for n in nodes if self._contains_low_level_calls(n)] - if assembly_nodes: - ret.append((f, assembly_nodes)) + low_level_nodes = [ir.node for ir in f.low_level_calls] + if low_level_nodes: + ret.append((f, low_level_nodes)) return ret def _detect(self) -> List[Output]: diff --git a/slither/detectors/reentrancy/reentrancy.py b/slither/detectors/reentrancy/reentrancy.py index 8dd9aecc05..2982801cb7 100644 --- a/slither/detectors/reentrancy/reentrancy.py +++ b/slither/detectors/reentrancy/reentrancy.py @@ -145,15 +145,16 @@ def analyze_node(self, node: Node, detector: "Reentrancy") -> bool: ) slithir_operations = [] # Add the state variables written in internal calls - for internal_call in node.internal_calls: + for ir in node.internal_calls: # Filter to Function, as internal_call can be a solidity call - if isinstance(internal_call, Function): - for internal_node in internal_call.all_nodes(): + function = ir.function + if isinstance(function, Function): + for internal_node in function.all_nodes(): for read in internal_node.state_variables_read: state_vars_read[read].add(internal_node) for write in internal_node.state_variables_written: state_vars_written[write].add(internal_node) - slithir_operations += internal_call.all_slithir_operations() + slithir_operations += function.all_slithir_operations() contains_call = False diff --git a/slither/detectors/statements/assert_state_change.py b/slither/detectors/statements/assert_state_change.py index 769d730b82..d70495365f 100644 --- a/slither/detectors/statements/assert_state_change.py +++ b/slither/detectors/statements/assert_state_change.py @@ -30,22 +30,22 @@ def detect_assert_state_change( # Loop for each function and modifier. for function in contract.functions_declared + list(contract.modifiers_declared): - for node in function.nodes: + for ir_call in function.internal_calls: # Detect assert() calls - if any(c.name == "assert(bool)" for c in node.internal_calls) and ( + if ir_call.function.name == "assert(bool)" and ( # Detect direct changes to state - node.state_variables_written + ir_call.node.state_variables_written or # Detect changes to state via function calls any( ir - for ir in node.irs + for ir in ir_call.node.irs if isinstance(ir, InternalCall) and ir.function and ir.function.state_variables_written ) ): - results.append((function, node)) + results.append((function, ir_call.node)) # Return the resulting set of nodes return results diff --git a/slither/detectors/statements/controlled_delegatecall.py b/slither/detectors/statements/controlled_delegatecall.py index 32e59d6eb7..bf78b3bf98 100644 --- a/slither/detectors/statements/controlled_delegatecall.py +++ b/slither/detectors/statements/controlled_delegatecall.py @@ -8,20 +8,18 @@ DetectorClassification, DETECTOR_INFO, ) -from slither.slithir.operations import LowLevelCall from slither.utils.output import Output def controlled_delegatecall(function: FunctionContract) -> List[Node]: ret = [] - for node in function.nodes: - for ir in node.irs: - if isinstance(ir, LowLevelCall) and ir.function_name in [ - "delegatecall", - "callcode", - ]: - if is_tainted(ir.destination, function.contract): - ret.append(node) + for ir in function.low_level_calls: + if ir.function_name in [ + "delegatecall", + "callcode", + ]: + if is_tainted(ir.destination, function.contract): + ret.append(ir.node) return ret diff --git a/slither/detectors/statements/divide_before_multiply.py b/slither/detectors/statements/divide_before_multiply.py index e33477135d..6734fb239d 100644 --- a/slither/detectors/statements/divide_before_multiply.py +++ b/slither/detectors/statements/divide_before_multiply.py @@ -56,7 +56,7 @@ def is_assert(node: Node) -> bool: # Old Solidity code where using an internal 'assert(bool)' function # While we dont check that this function is correct, we assume it is # To avoid too many FP - if "assert(bool)" in [c.full_name for c in node.internal_calls]: + if "assert(bool)" in [ir.function.full_name for ir in node.internal_calls]: return True return False diff --git a/slither/detectors/statements/return_bomb.py b/slither/detectors/statements/return_bomb.py index 8b6cd07a29..6d7052cf40 100644 --- a/slither/detectors/statements/return_bomb.py +++ b/slither/detectors/statements/return_bomb.py @@ -9,7 +9,7 @@ DetectorClassification, DETECTOR_INFO, ) -from slither.slithir.operations import LowLevelCall, HighLevelCall +from slither.slithir.operations import HighLevelCall from slither.analyses.data_dependency.data_dependency import is_tainted from slither.utils.output import Output @@ -71,34 +71,31 @@ def is_dynamic_type(ty: Type) -> bool: def get_nodes_for_function(self, function: Function, contract: Contract) -> List[Node]: nodes = [] - for node in function.nodes: - for ir in node.irs: - if isinstance(ir, (HighLevelCall, LowLevelCall)): - if not is_tainted(ir.destination, contract): # type:ignore - # Only interested if the target address is controlled/tainted - continue - - if isinstance(ir, HighLevelCall) and isinstance(ir.function, Function): - # in normal highlevel calls return bombs are _possible_ - # if the return type is dynamic and the caller tries to copy and decode large data - has_dyn = False - if ir.function.return_type: - has_dyn = any( - self.is_dynamic_type(ty) for ty in ir.function.return_type - ) - - if not has_dyn: - continue - - # If a gas budget was specified then the - # user may not know about the return bomb - if ir.call_gas is None: - # if a gas budget was NOT specified then the caller - # may already suspect the call may spend all gas? - continue - - nodes.append(node) - # TODO: check that there is some state change after the call + + for ir in [ir for _, ir in function.high_level_calls] + function.low_level_calls: + if not is_tainted(ir.destination, contract): # type:ignore + # Only interested if the target address is controlled/tainted + continue + + if isinstance(ir, HighLevelCall) and isinstance(ir.function, Function): + # in normal highlevel calls return bombs are _possible_ + # if the return type is dynamic and the caller tries to copy and decode large data + has_dyn = False + if ir.function.return_type: + has_dyn = any(self.is_dynamic_type(ty) for ty in ir.function.return_type) + + if not has_dyn: + continue + + # If a gas budget was specified then the + # user may not know about the return bomb + if ir.call_gas is None: + # if a gas budget was NOT specified then the caller + # may already suspect the call may spend all gas? + continue + + nodes.append(ir.node) + # TODO: check that there is some state change after the call return nodes diff --git a/slither/detectors/statements/unprotected_upgradeable.py b/slither/detectors/statements/unprotected_upgradeable.py index d25aff187d..aeb785da36 100644 --- a/slither/detectors/statements/unprotected_upgradeable.py +++ b/slither/detectors/statements/unprotected_upgradeable.py @@ -7,23 +7,28 @@ DetectorClassification, DETECTOR_INFO, ) -from slither.slithir.operations import LowLevelCall, SolidityCall from slither.utils.output import Output def _can_be_destroyed(contract: Contract) -> List[Function]: targets = [] for f in contract.functions_entry_points: - for ir in f.all_slithir_operations(): - if ( - isinstance(ir, LowLevelCall) and ir.function_name in ["delegatecall", "codecall"] - ) or ( - isinstance(ir, SolidityCall) - and ir.function - in [SolidityFunction("suicide(address)"), SolidityFunction("selfdestruct(address)")] - ): + found = False + for ir in f.all_low_level_calls(): + if ir.function_name in ["delegatecall", "codecall"]: targets.append(f) + found = True break + + if not found: + for ir in f.all_solidity_calls(): + if ir.function in [ + SolidityFunction("suicide(address)"), + SolidityFunction("selfdestruct(address)"), + ]: + targets.append(f) + break + return targets @@ -35,8 +40,8 @@ def _has_initializing_protection(functions: List[Function]) -> bool: for m in f.modifiers: if m.name == "initializer": return True - for ifc in f.all_internal_calls(): - if ifc.name == "_disableInitializers": + for ir in f.all_internal_calls(): + if ir.function.name == "_disableInitializers": return True # to avoid future FPs in different modifier + function naming implementations, we can also implement a broader check for state var "_initialized" being written to in the constructor diff --git a/slither/detectors/variables/var_read_using_this.py b/slither/detectors/variables/var_read_using_this.py index 537eecf8a3..1e4787e363 100644 --- a/slither/detectors/variables/var_read_using_this.py +++ b/slither/detectors/variables/var_read_using_this.py @@ -7,7 +7,6 @@ DetectorClassification, DETECTOR_INFO, ) -from slither.slithir.operations.high_level_call import HighLevelCall from slither.utils.output import Output @@ -54,13 +53,11 @@ def _detect(self) -> List[Output]: @staticmethod def _detect_var_read_using_this(func: Function) -> List[Node]: results: List[Node] = [] - for node in func.nodes: - for ir in node.irs: - if isinstance(ir, HighLevelCall): - if ( - ir.destination == SolidityVariable("this") - and ir.is_static_call() - and ir.function.visibility == "public" - ): - results.append(node) + for _, ir in func.high_level_calls: + if ( + ir.destination == SolidityVariable("this") + and ir.is_static_call() + and ir.function.visibility == "public" + ): + results.append(ir.node) return sorted(results, key=lambda x: x.node_id) diff --git a/slither/printers/call/call_graph.py b/slither/printers/call/call_graph.py index 38225e6d7a..668606760b 100644 --- a/slither/printers/call/call_graph.py +++ b/slither/printers/call/call_graph.py @@ -10,6 +10,7 @@ from slither.core.declarations import Contract, FunctionContract from slither.core.declarations.function import Function +from slither.slithir.operations import HighLevelCall, InternalCall from slither.core.declarations.solidity_variables import SolidityFunction from slither.core.variables.variable import Variable from slither.printers.abstract_printer import AbstractPrinter @@ -49,26 +50,26 @@ def _node(node: str, label: Optional[str] = None) -> str: def _process_internal_call( contract: Contract, function: Function, - internal_call: Union[Function, SolidityFunction], + internal_call: InternalCall, contract_calls: Dict[Contract, Set[str]], solidity_functions: Set[str], solidity_calls: Set[str], ) -> None: - if isinstance(internal_call, (Function)): + if isinstance(internal_call.function, (Function)): contract_calls[contract].add( _edge( _function_node(contract, function), - _function_node(contract, internal_call), + _function_node(contract, internal_call.function), ) ) - elif isinstance(internal_call, (SolidityFunction)): + elif isinstance(internal_call.function, (SolidityFunction)): solidity_functions.add( - _node(_solidity_function_node(internal_call)), + _node(_solidity_function_node(internal_call.function)), ) solidity_calls.add( _edge( _function_node(contract, function), - _solidity_function_node(internal_call), + _solidity_function_node(internal_call.function), ) ) @@ -112,29 +113,29 @@ def _render_solidity_calls(solidity_functions: Set[str], solidity_calls: Set[str def _process_external_call( contract: Contract, function: Function, - external_call: Tuple[Contract, Union[Function, Variable]], + external_call: Tuple[Contract, HighLevelCall], contract_functions: Dict[Contract, Set[str]], external_calls: Set[str], all_contracts: Set[Contract], ) -> None: - external_contract, external_function = external_call + external_contract, ir = external_call if not external_contract in all_contracts: return # add variable as node to respective contract - if isinstance(external_function, (Variable)): + if isinstance(ir.function, (Variable)): contract_functions[external_contract].add( _node( - _function_node(external_contract, external_function), - external_function.name, + _function_node(external_contract, ir.function), + ir.function.name, ) ) external_calls.add( _edge( _function_node(contract, function), - _function_node(external_contract, external_function), + _function_node(external_contract, ir.function), ) ) diff --git a/slither/printers/functions/authorization.py b/slither/printers/functions/authorization.py index 32efeaabeb..288392a468 100644 --- a/slither/printers/functions/authorization.py +++ b/slither/printers/functions/authorization.py @@ -19,7 +19,11 @@ class PrinterWrittenVariablesAndAuthorization(AbstractPrinter): @staticmethod def get_msg_sender_checks(function: Function) -> List[str]: all_functions = ( - [f for f in function.all_internal_calls() if isinstance(f, Function)] + [ + ir.function + for ir in function.all_internal_calls() + if isinstance(ir.function, Function) + ] + [function] + [m for m in function.modifiers if isinstance(m, Function)] ) diff --git a/slither/printers/guidance/echidna.py b/slither/printers/guidance/echidna.py index 0c47fa0f98..35a6091935 100644 --- a/slither/printers/guidance/echidna.py +++ b/slither/printers/guidance/echidna.py @@ -140,8 +140,8 @@ def _extract_assert(contracts: List[Contract]) -> Dict[str, Dict[str, List[Dict] for contract in contracts: functions_using_assert = [] # Dict[str, List[Dict]] = defaultdict(list) for f in contract.functions_entry_points: - for v in f.all_solidity_calls(): - if v == SolidityFunction("assert(bool)"): + for ir in f.all_solidity_calls(): + if ir.function == SolidityFunction("assert(bool)"): functions_using_assert.append(_get_name(f)) break # Revert https://github.com/crytic/slither/pull/2105 until format is supported by echidna. @@ -156,7 +156,7 @@ def _extract_assert(contracts: List[Contract]) -> Dict[str, Dict[str, List[Dict] # Create a named tuple that is serialization in json def json_serializable(cls): - # pylint: disable=unnecessary-comprehension + # pylint: disable=unnecessary-comprehension,unnecessary-dunder-call # TODO: the next line is a quick workaround to prevent pylint from crashing # It can be removed once https://github.com/PyCQA/pylint/pull/3810 is merged my_super = super diff --git a/slither/printers/summary/modifier_calls.py b/slither/printers/summary/modifier_calls.py index cd6c4062e3..225376a3cf 100644 --- a/slither/printers/summary/modifier_calls.py +++ b/slither/printers/summary/modifier_calls.py @@ -29,12 +29,12 @@ def output(self, _filename): table = MyPrettyTable(["Function", "Modifiers"]) for function in contract.functions: modifiers = function.modifiers - for call in function.all_internal_calls(): - if isinstance(call, Function): - modifiers += call.modifiers - for (_, call) in function.all_library_calls(): - if isinstance(call, Function): - modifiers += call.modifiers + for ir in function.all_internal_calls(): + if isinstance(ir.function, Function): + modifiers += ir.function.modifiers + for ir in function.all_library_calls(): + if isinstance(ir.function, Function): + modifiers += ir.function.modifiers table.add_row([function.name, sorted([m.name for m in set(modifiers)])]) txt += "\n" + str(table) self.info(txt) diff --git a/slither/printers/summary/when_not_paused.py b/slither/printers/summary/when_not_paused.py index aaeeeacec2..fc96268ef3 100644 --- a/slither/printers/summary/when_not_paused.py +++ b/slither/printers/summary/when_not_paused.py @@ -11,8 +11,8 @@ def _use_modifier(function: Function, modifier_name: str = "whenNotPaused") -> bool: - for internal_call in function.all_internal_calls(): - if isinstance(internal_call, SolidityFunction): + for ir in function.all_internal_calls(): + if isinstance(ir, function, SolidityFunction): continue if any(modifier.name == modifier_name for modifier in function.modifiers): return True diff --git a/slither/tools/flattening/flattening.py b/slither/tools/flattening/flattening.py index 9cb2abc3f1..182333d3fc 100644 --- a/slither/tools/flattening/flattening.py +++ b/slither/tools/flattening/flattening.py @@ -294,10 +294,9 @@ def _export_list_used_contracts( # pylint: disable=too-many-branches self._export_list_used_contracts(inherited, exported, list_contract, list_top_level) # Find all the external contracts called - externals = contract.all_library_calls + contract.all_high_level_calls - # externals is a list of (contract, function) + # High level calls already includes library calls # We also filter call to itself to avoid infilite loop - externals = list({e[0] for e in externals if e[0] != contract}) + externals = list({e[0] for e in contract.all_high_level_calls if e[0] != contract}) for inherited in externals: self._export_list_used_contracts(inherited, exported, list_contract, list_top_level) diff --git a/slither/tools/possible_paths/possible_paths.py b/slither/tools/possible_paths/possible_paths.py index 6e836e76ab..15218a8720 100644 --- a/slither/tools/possible_paths/possible_paths.py +++ b/slither/tools/possible_paths/possible_paths.py @@ -123,10 +123,14 @@ def __find_target_paths( # Find all function calls in this function (except for low level) called_functions_list = [ - f for (_, f) in function.high_level_calls if isinstance(f, Function) + ir.function + for _, ir in function.high_level_calls + if isinstance(ir.function, Function) + ] + called_functions_list += [ir.function for ir in function.library_calls] + called_functions_list += [ + ir.function for ir in function.internal_calls if isinstance(ir.function, Function) ] - called_functions_list += [f for (_, f) in function.library_calls] - called_functions_list += [f for f in function.internal_calls if isinstance(f, Function)] called_functions = set(called_functions_list) # If any of our target functions are reachable from this function, it's a result. diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 22461dbcf6..59979bca64 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -123,7 +123,9 @@ def compare( ): continue modified_calls = [ - func for func in new_modified_functions if func in function.internal_calls + func + for func in new_modified_functions + if func in [ir.function for ir in function.internal_calls] ] tainted_vars = [ var @@ -179,7 +181,8 @@ def tainted_external_contracts(funcs: List[Function]) -> List[TaintedExternalCon tainted_list: list[TaintedExternalContract] = [] for func in funcs: - for contract, target in func.all_high_level_calls(): + for contract, ir in func.all_high_level_calls(): + target = ir.function if contract.is_library: # Not interested in library calls continue @@ -254,7 +257,11 @@ def tainted_inheriting_contracts( new_taint = TaintedExternalContract(c) for f in c.functions_declared: # Search for functions that call an inherited tainted function or access an inherited tainted variable - internal_calls = [c for c in f.all_internal_calls() if isinstance(c, Function)] + internal_calls = [ + ir.function + for ir in f.all_internal_calls() + if isinstance(ir.function, Function) + ] if any( call.canonical_name == t.canonical_name for t in tainted.tainted_functions diff --git a/tests/unit/slithir/vyper/test_ir_generation.py b/tests/unit/slithir/vyper/test_ir_generation.py index 73c9b5e70b..efcf5ce549 100644 --- a/tests/unit/slithir/vyper/test_ir_generation.py +++ b/tests/unit/slithir/vyper/test_ir_generation.py @@ -35,9 +35,9 @@ def bar(): interface = next(iter(x for x in sl.contracts if x.is_interface)) contract = next(iter(x for x in sl.contracts if not x.is_interface)) func = contract.get_function_from_signature("bar()") - (contract, function) = func.high_level_calls[0] + (contract, ir) = func.high_level_calls[0] assert contract == interface - assert function.signature_str == "foo() returns(int128,uint256)" + assert ir.function.signature_str == "foo() returns(int128,uint256)" def test_phi_entry_point_internal_call(slither_from_vyper_source):