diff --git a/slither/utils/code_generation.py b/slither/utils/code_generation.py index bb8344d8fb..2157570101 100644 --- a/slither/utils/code_generation.py +++ b/slither/utils/code_generation.py @@ -12,16 +12,18 @@ MappingType, ArrayType, ElementaryType, + TypeAlias, ) -from slither.core.declarations import Structure, Enum, Contract +from slither.core.declarations import Structure, StructureContract, Enum, Contract if TYPE_CHECKING: from slither.core.declarations import FunctionContract, CustomErrorContract from slither.core.variables.state_variable import StateVariable from slither.core.variables.local_variable import LocalVariable + from slither.core.variables.structure_variable import StructureVariable -# pylint: disable=too-many-arguments +# pylint: disable=too-many-arguments,too-many-locals,too-many-branches def generate_interface( contract: "Contract", unroll_structs: bool = True, @@ -56,12 +58,47 @@ def generate_interface( for enum in contract.enums: interface += f" enum {enum.name} {{ {', '.join(enum.values)} }}\n" if include_structs: - for struct in contract.structures: + # Include structures defined in this contract and at the top level + structs = contract.structures + contract.compilation_unit.structures_top_level + # Function signatures may reference other structures as well + # Include structures defined in libraries used for them + for _for in contract.using_for.keys(): + if ( + isinstance(_for, UserDefinedType) + and isinstance(_for.type, StructureContract) + and _for.type not in structs + ): + structs.append(_for.type) + # Include any other structures used as function arguments/returns + for func in contract.functions_entry_points: + for arg in func.parameters + func.returns: + _type = arg.type + if isinstance(_type, ArrayType): + _type = _type.type + while isinstance(_type, MappingType): + _type = _type.type_to + if isinstance(_type, UserDefinedType): + _type = _type.type + if isinstance(_type, Structure) and _type not in structs: + structs.append(_type) + for struct in structs: interface += generate_struct_interface_str(struct, indent=4) + for elem in struct.elems_ordered: + if ( + isinstance(elem.type, UserDefinedType) + and isinstance(elem.type.type, StructureContract) + and elem.type.type not in structs + ): + structs.append(elem.type.type) for var in contract.state_variables_entry_points: - interface += f" function {generate_interface_variable_signature(var, unroll_structs)};\n" + # if any(func.name == var.name for func in contract.functions_entry_points): + # # ignore public variables that override a public function + # continue + var_sig = generate_interface_variable_signature(var, unroll_structs) + if var_sig is not None and var_sig != "": + interface += f" function {var_sig};\n" for func in contract.functions_entry_points: - if func.is_constructor or func.is_fallback or func.is_receive: + if func.is_constructor or func.is_fallback or func.is_receive or not func.is_implemented: continue interface += ( f" function {generate_interface_function_signature(func, unroll_structs)};\n" @@ -75,6 +112,10 @@ def generate_interface_variable_signature( ) -> Optional[str]: if var.visibility in ["private", "internal"]: return None + if isinstance(var.type, UserDefinedType) and isinstance(var.type.type, Structure): + for elem in var.type.type.elems_ordered: + if isinstance(elem.type, MappingType): + return "" if unroll_structs: params = [ convert_type_for_solidity_signature_to_string(x).replace("(", "").replace(")", "") @@ -93,6 +134,11 @@ def generate_interface_variable_signature( _type = _type.type_to while isinstance(_type, (ArrayType, UserDefinedType)): _type = _type.type + if isinstance(_type, TypeAlias): + _type = _type.type + if isinstance(_type, Structure): + if any(isinstance(elem.type, MappingType) for elem in _type.elems_ordered): + return "" ret = str(_type) if isinstance(_type, Structure) or (isinstance(_type, Type) and _type.is_dynamic): ret += " memory" @@ -125,6 +171,8 @@ def format_var(var: "LocalVariable", unroll: bool) -> str: .replace("(", "") .replace(")", "") ) + if var.type.is_dynamic: + return f"{_handle_dynamic_struct_elem(var.type)} {var.location}" if isinstance(var.type, ArrayType) and isinstance( var.type.type, (UserDefinedType, ElementaryType) ): @@ -135,12 +183,14 @@ def format_var(var: "LocalVariable", unroll: bool) -> str: + f" {var.location}" ) if isinstance(var.type, UserDefinedType): - if isinstance(var.type.type, (Structure, Enum)): + if isinstance(var.type.type, Structure): return f"{str(var.type.type)} memory" + if isinstance(var.type.type, Enum): + return str(var.type.type) if isinstance(var.type.type, Contract): return "address" - if var.type.is_dynamic: - return f"{var.type} {var.location}" + if isinstance(var.type, TypeAlias): + return str(var.type.type) return str(var.type) name, _, _ = func.signature @@ -154,6 +204,12 @@ def format_var(var: "LocalVariable", unroll: bool) -> str: view = " view" if func.view and not func.pure else "" pure = " pure" if func.pure else "" payable = " payable" if func.payable else "" + # Make sure the function doesn't return a struct with nested mappings + for ret in func.returns: + if isinstance(ret.type, UserDefinedType) and isinstance(ret.type.type, Structure): + for elem in ret.type.type.elems_ordered: + if isinstance(elem.type, MappingType): + return "" returns = [format_var(ret, unroll_structs) for ret in func.returns] parameters = [format_var(param, unroll_structs) for param in func.parameters] _interface_signature_str = ( @@ -184,17 +240,49 @@ def generate_struct_interface_str(struct: "Structure", indent: int = 0) -> str: spaces += " " definition = f"{spaces}struct {struct.name} {{\n" for elem in struct.elems_ordered: - if isinstance(elem.type, UserDefinedType): - if isinstance(elem.type.type, (Structure, Enum)): + if elem.type.is_dynamic: + definition += f"{spaces} {_handle_dynamic_struct_elem(elem.type)} {elem.name};\n" + elif isinstance(elem.type, UserDefinedType): + if isinstance(elem.type.type, Structure): definition += f"{spaces} {elem.type.type} {elem.name};\n" - elif isinstance(elem.type.type, Contract): - definition += f"{spaces} address {elem.name};\n" + else: + definition += f"{spaces} {convert_type_for_solidity_signature_to_string(elem.type)} {elem.name};\n" + elif isinstance(elem.type, TypeAlias): + definition += f"{spaces} {elem.type.type} {elem.name};\n" else: definition += f"{spaces} {elem.type} {elem.name};\n" definition += f"{spaces}}}\n" return definition +def _handle_dynamic_struct_elem(elem_type: Type) -> str: + assert elem_type.is_dynamic + if isinstance(elem_type, ElementaryType): + return f"{elem_type}" + if isinstance(elem_type, ArrayType): + base_type = elem_type.type + if isinstance(base_type, UserDefinedType): + if isinstance(base_type.type, Contract): + return "address[]" + if isinstance(base_type.type, Enum): + return convert_type_for_solidity_signature_to_string(elem_type) + return f"{base_type.type.name}[]" + return f"{base_type}[]" + if isinstance(elem_type, MappingType): + type_to = elem_type.type_to + type_from = elem_type.type_from + if isinstance(type_from, UserDefinedType) and isinstance(type_from.type, Contract): + type_from = ElementaryType("address") + if isinstance(type_to, MappingType): + return f"mapping({type_from} => {_handle_dynamic_struct_elem(type_to)})" + if isinstance(type_to, UserDefinedType): + if isinstance(type_to.type, Contract): + return f"mapping({type_from} => address)" + return f"mapping({type_from} => {type_to.type.name})" + return f"{elem_type}" + return "" + + def generate_custom_error_interface( error: "CustomErrorContract", unroll_structs: bool = True ) -> str: diff --git a/tests/unit/utils/test_data/code_generation/CodeGeneration.sol b/tests/unit/utils/test_data/code_generation/CodeGeneration.sol index 6f1f63c72f..292a4b43f7 100644 --- a/tests/unit/utils/test_data/code_generation/CodeGeneration.sol +++ b/tests/unit/utils/test_data/code_generation/CodeGeneration.sol @@ -1,7 +1,10 @@ pragma solidity ^0.8.4; + +import "./IFee.sol"; + interface I { - enum SomeEnum { ONE, TWO, THREE } - error ErrorWithEnum(SomeEnum e); + enum SomeEnum { ONE, TWO, THREE } + error ErrorWithEnum(SomeEnum e); } contract TestContract is I { @@ -62,4 +65,10 @@ contract TestContract is I { function setOtherI(I _i) public { otherI = _i; } + + function newFee(uint128 fee) public returns (IFee.Fee memory) { + IFee.Fee memory _fee; + _fee.fee = fee; + return _fee; + } } \ No newline at end of file diff --git a/tests/unit/utils/test_data/code_generation/IFee.sol b/tests/unit/utils/test_data/code_generation/IFee.sol new file mode 100644 index 0000000000..17560f60ce --- /dev/null +++ b/tests/unit/utils/test_data/code_generation/IFee.sol @@ -0,0 +1,5 @@ +interface IFee { + struct Fee { + uint128 fee; + } +} diff --git a/tests/unit/utils/test_data/code_generation/TEST_generated_code.sol b/tests/unit/utils/test_data/code_generation/TEST_generated_code.sol index 373fba9ca2..5240e10633 100644 --- a/tests/unit/utils/test_data/code_generation/TEST_generated_code.sol +++ b/tests/unit/utils/test_data/code_generation/TEST_generated_code.sol @@ -14,6 +14,9 @@ interface ITestContract { struct Nested { St st; } + struct Fee { + uint128 fee; + } function stateA() external returns (uint256); function owner() external returns (address); function structsMap(address,uint256) external returns (uint256); @@ -26,5 +29,6 @@ interface ITestContract { function getSt(uint256) external view returns (uint256); function removeSt(uint256) external; function setOtherI(address) external; + function newFee(uint128) external returns (uint128); } diff --git a/tests/unit/utils/test_data/code_generation/TEST_generated_code_not_unrolled.sol b/tests/unit/utils/test_data/code_generation/TEST_generated_code_not_unrolled.sol index 0cc4dc0404..1154ec4cc4 100644 --- a/tests/unit/utils/test_data/code_generation/TEST_generated_code_not_unrolled.sol +++ b/tests/unit/utils/test_data/code_generation/TEST_generated_code_not_unrolled.sol @@ -14,6 +14,9 @@ interface ITestContract { struct Nested { St st; } + struct Fee { + uint128 fee; + } function stateA() external returns (uint256); function owner() external returns (address); function structsMap(address,uint256) external returns (St memory); @@ -26,5 +29,6 @@ interface ITestContract { function getSt(uint256) external view returns (St memory); function removeSt(St memory) external; function setOtherI(address) external; + function newFee(uint128) external returns (Fee memory); }