diff --git a/slither/core/solidity_types/function_type.py b/slither/core/solidity_types/function_type.py index 2d644148e..8a328e361 100644 --- a/slither/core/solidity_types/function_type.py +++ b/slither/core/solidity_types/function_type.py @@ -36,7 +36,7 @@ def storage_size(self) -> Tuple[int, bool]: def is_dynamic(self) -> bool: return False - def __str__(self): + def __str__(self) -> str: # Use x.type # x.name may be empty params = ",".join([str(x.type) for x in self._params]) diff --git a/slither/core/variables/variable.py b/slither/core/variables/variable.py index f9ef19024..4af350d31 100644 --- a/slither/core/variables/variable.py +++ b/slither/core/variables/variable.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from slither.core.expressions.expression import Expression + from slither.core.declarations import Function # pylint: disable=too-many-instance-attributes class Variable(SourceMapping): @@ -16,7 +17,7 @@ def __init__(self) -> None: super().__init__() self._name: Optional[str] = None self._initial_expression: Optional["Expression"] = None - self._type: Optional[Type] = None + self._type: Optional[Union[List, Type, "Function", str]] = None self._initialized: Optional[bool] = None self._visibility: Optional[str] = None self._is_constant = False @@ -77,7 +78,7 @@ def name(self, name: str) -> None: self._name = name @property - def type(self) -> Optional[Type]: + def type(self) -> Optional[Union[List, Type, "Function", str]]: return self._type @type.setter @@ -120,7 +121,7 @@ def visibility(self) -> Optional[str]: def visibility(self, v: str) -> None: self._visibility = v - def set_type(self, t: Optional[Union[List, Type, str]]) -> None: + def set_type(self, t: Optional[Union[List, Type, "Function", str]]) -> None: if isinstance(t, str): self._type = ElementaryType(t) return diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index 7d8aa543b..cfeb0fe39 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -36,7 +36,6 @@ ) from slither.core.solidity_types.type import Type from slither.core.solidity_types.type_alias import TypeAliasTopLevel, TypeAlias -from slither.core.variables.function_type_variable import FunctionTypeVariable from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable from slither.slithir.exceptions import SlithIRError @@ -81,7 +80,7 @@ from slither.slithir.tmp_operations.tmp_new_structure import TmpNewStructure from slither.slithir.variables import Constant, ReferenceVariable, TemporaryVariable from slither.slithir.variables import TupleVariable -from slither.utils.function import get_function_id +from slither.utils.function import get_function_id, get_event_id from slither.utils.type import export_nested_types_from_variable from slither.utils.using_for import USING_FOR from slither.visitors.slithir.expression_to_slithir import ExpressionToSlithIR @@ -279,20 +278,6 @@ def is_temporary(ins: Operation) -> bool: ) -def _make_function_type(func: Function) -> FunctionType: - parameters = [] - returns = [] - for parameter in func.parameters: - v = FunctionTypeVariable() - v.name = parameter.name - parameters.append(v) - for return_var in func.returns: - v = FunctionTypeVariable() - v.name = return_var.name - returns.append(v) - return FunctionType(parameters, returns) - - # endregion ################################################################################### ################################################################################### @@ -793,12 +778,29 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo assignment.set_node(ir.node) assignment.lvalue.set_type(ElementaryType("bytes4")) return assignment - if ir.variable_right == "selector" and isinstance( - ir.variable_left.type, (Function) + if ir.variable_right == "selector" and isinstance(ir.variable_left, (Event)): + # the event selector returns a bytes32, which is different from the error/function selector + # which returns a bytes4 + assignment = Assignment( + ir.lvalue, + Constant( + str(get_event_id(ir.variable_left.full_name)), ElementaryType("bytes32") + ), + ElementaryType("bytes32"), + ) + assignment.set_expression(ir.expression) + assignment.set_node(ir.node) + assignment.lvalue.set_type(ElementaryType("bytes32")) + return assignment + if ir.variable_right == "selector" and ( + isinstance(ir.variable_left.type, (Function)) ): assignment = Assignment( ir.lvalue, - Constant(str(get_function_id(ir.variable_left.type.full_name))), + Constant( + str(get_function_id(ir.variable_left.type.full_name)), + ElementaryType("bytes4"), + ), ElementaryType("bytes4"), ) assignment.set_expression(ir.expression) @@ -826,10 +828,9 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo targeted_function = next( (x for x in ir_func.contract.functions if x.name == str(ir.variable_right)) ) - t = _make_function_type(targeted_function) - ir.lvalue.set_type(t) + ir.lvalue.set_type(targeted_function) elif isinstance(left, (Variable, SolidityVariable)): - t = ir.variable_left.type + t = left.type elif isinstance(left, (Contract, Enum, Structure)): t = UserDefinedType(left) # can be None due to temporary operation @@ -846,10 +847,10 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo ir.lvalue.set_type(elems[elem].type) else: assert isinstance(type_t, Contract) - # Allow type propagtion as a Function + # Allow type propagation as a Function # Only for reference variables # This allows to track the selector keyword - # We dont need to check for function collision, as solc prevents the use of selector + # We don't need to check for function collision, as solc prevents the use of selector # if there are multiple functions with the same name f = next( (f for f in type_t.functions if f.name == ir.variable_right), @@ -858,7 +859,7 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo if f: ir.lvalue.set_type(f) else: - # Allow propgation for variable access through contract's name + # Allow propagation for variable access through contract's name # like Base_contract.my_variable v = next( ( diff --git a/slither/utils/function.py b/slither/utils/function.py index 34e6f221b..64c098bfd 100644 --- a/slither/utils/function.py +++ b/slither/utils/function.py @@ -12,3 +12,16 @@ def get_function_id(sig: str) -> int: digest = keccak.new(digest_bits=256) digest.update(sig.encode("utf-8")) return int("0x" + digest.hexdigest()[:8], 16) + + +def get_event_id(sig: str) -> int: + """' + Return the event id of the given signature + Args: + sig (str) + Return: + (int) + """ + digest = keccak.new(digest_bits=256) + digest.update(sig.encode("utf-8")) + return int("0x" + digest.hexdigest(), 16) diff --git a/tests/unit/slithir/test_data/selector.sol b/tests/unit/slithir/test_data/selector.sol new file mode 100644 index 000000000..60ec33cd6 --- /dev/null +++ b/tests/unit/slithir/test_data/selector.sol @@ -0,0 +1,47 @@ +interface I{ + function testFunction(uint a) external ; +} + +contract A{ + function testFunction() public{} +} + +contract Test{ + event TestEvent(); + struct St{ + uint a; + } + error TestError(); + + function testFunction(uint a) public {} + + + function testFunctionStructure(St memory s) public {} + + function returnEvent() public returns (bytes32){ + return TestEvent.selector; + } + + function returnError() public returns (bytes4){ + return TestError.selector; + } + + + function returnFunctionFromContract() public returns (bytes4){ + return I.testFunction.selector; + } + + + function returnFunction() public returns (bytes4){ + return this.testFunction.selector; + } + + function returnFunctionWithStructure() public returns (bytes4){ + return this.testFunctionStructure.selector; + } + + function returnFunctionThroughLocaLVar() public returns(bytes4){ + A a; + return a.testFunction.selector; + } +} \ No newline at end of file diff --git a/tests/unit/slithir/test_selector.py b/tests/unit/slithir/test_selector.py new file mode 100644 index 000000000..34643b58d --- /dev/null +++ b/tests/unit/slithir/test_selector.py @@ -0,0 +1,32 @@ +from pathlib import Path +from slither import Slither +from slither.slithir.operations import Assignment +from slither.slithir.variables import Constant + +TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data" + + +func_to_results = { + "returnEvent()": "16700440330922901039223184000601971290390760458944929668086539975128325467771", + "returnError()": "224292994", + "returnFunctionFromContract()": "890000139", + "returnFunction()": "890000139", + "returnFunctionWithStructure()": "1430834845", + "returnFunctionThroughLocaLVar()": "3781905051", +} + + +def test_enum_max_min(solc_binary_path) -> None: + solc_path = solc_binary_path("0.8.19") + slither = Slither(Path(TEST_DATA_DIR, "selector.sol").as_posix(), solc=solc_path) + + contract = slither.get_contract_from_name("Test")[0] + + for func_name, value in func_to_results.items(): + f = contract.get_function_from_signature(func_name) + assignment = f.slithir_operations[0] + assert ( + isinstance(assignment, Assignment) + and isinstance(assignment.rvalue, Constant) + and assignment.rvalue.value == value + )