diff --git a/app/services/contract_metadata_service.py b/app/services/contract_metadata_service.py index 3949eb9..b5dd611 100644 --- a/app/services/contract_metadata_service.py +++ b/app/services/contract_metadata_service.py @@ -17,6 +17,7 @@ SourcifyClientConfigurationProblem, ) from safe_eth.eth.clients.etherscan_client_v2 import AsyncEtherscanClientV2 +from safe_eth.eth.utils import fast_to_checksum_address from sqlmodel.ext.asyncio.session import AsyncSession from app.config import settings @@ -191,6 +192,10 @@ async def process_contract_metadata( ) contract.abi_id = abi.id contract.name = contract_metadata.metadata.name + if contract_metadata.metadata.implementation: + contract.implementation = HexBytes( + contract_metadata.metadata.implementation + ) with_metadata = True else: with_metadata = False @@ -199,6 +204,14 @@ async def process_contract_metadata( await contract.update(session=session) return with_metadata + @staticmethod + def get_proxy_implementation_address( + contract_metadata: EnhancedContractMetadata, + ) -> ChecksumAddress | None: + if contract_metadata.metadata and contract_metadata.metadata.implementation: + return fast_to_checksum_address(contract_metadata.metadata.implementation) + return None + @staticmethod async def should_attempt_download( session: AsyncSession, diff --git a/app/tests/mocks/contract_metadata_mocks.py b/app/tests/mocks/contract_metadata_mocks.py index a9df5e1..e400290 100644 --- a/app/tests/mocks/contract_metadata_mocks.py +++ b/app/tests/mocks/contract_metadata_mocks.py @@ -45,6 +45,54 @@ ], False, ) + +etherscan_proxy_metadata_mock = ContractMetadata( + "Etherscan Uxio Proxy Contract", + [ + { + "anonymous": False, + "inputs": [ + { + "indexed": False, + "internalType": "address", + "name": "etherscanParam", + "type": "address", + } + ], + "name": "AddedOwner", + "type": "event", + }, + { + "constant": False, + "inputs": [ + { + "internalType": "address", + "name": "_masterCopy", + "type": "address", + } + ], + "name": "changeMasterCopy", + "outputs": [], + "payable": False, + "stateMutability": "nonpayable", + "type": "function", + }, + { + "constant": False, + "inputs": [ + {"internalType": "uint256", "name": "_threshold", "type": "uint256"} + ], + "name": "changeThreshold", + "outputs": [], + "payable": False, + "stateMutability": "nonpayable", + "type": "function", + }, + ], + False, + "0x43506849D7C04F9138D1A2050bbF3A0c054402dd", +) + sourcify_metadata_mock = ContractMetadata( "Sourcify Uxio Contract", [ diff --git a/app/tests/services/test_contract_metadata.py b/app/tests/services/test_contract_metadata.py index fd313d2..dd15905 100644 --- a/app/tests/services/test_contract_metadata.py +++ b/app/tests/services/test_contract_metadata.py @@ -26,6 +26,7 @@ from ..mocks.contract_metadata_mocks import ( blockscout_metadata_mock, etherscan_metadata_mock, + etherscan_proxy_metadata_mock, sourcify_metadata_mock, ) @@ -174,9 +175,30 @@ async def test_process_contract_metadata(self, session: AsyncSession): self.assertIsNotNone(contract) self.assertEqual(HexBytes(contract.address), HexBytes(random_address)) self.assertEqual(contract.name, etherscan_metadata_mock.name) + self.assertIsNone(contract.implementation) self.assertEqual(contract.abi.abi_json, etherscan_metadata_mock.abi) self.assertEqual(contract.chain_id, 1) self.assertEqual(contract.fetch_retries, 1) + + # New proxy contract + proxy_contract_data = EnhancedContractMetadata( + address=random_address, + metadata=etherscan_proxy_metadata_mock, + source=ClientSource.ETHERSCAN, + chain_id=1, + ) + await ContractMetadataService.process_contract_metadata( + session, proxy_contract_data + ) + proxy_contract = await Contract.get_contract( + session, address=HexBytes(random_address), chain_id=1 + ) + self.assertIsNotNone(proxy_contract) + self.assertEqual( + proxy_contract.implementation, + HexBytes("0x43506849d7c04f9138d1a2050bbf3a0c054402dd"), + ) + # Same contract shouldn't be updated without abi contract_data.metadata = None await ContractMetadataService.process_contract_metadata(session, contract_data) @@ -260,3 +282,30 @@ async def test_should_attempt_download(self, session: AsyncSession): session, fast_to_checksum_address(random_address), 100, 0 ) ) + + def test_get_proxy_implementation_address(self): + random_address = Account.create().address + proxy_contract_data = EnhancedContractMetadata( + address=random_address, + metadata=etherscan_proxy_metadata_mock, + source=ClientSource.ETHERSCAN, + chain_id=1, + ) + proxy_implementation_address = ( + ContractMetadataService.get_proxy_implementation_address( + proxy_contract_data + ) + ) + self.assertEqual( + proxy_implementation_address, "0x43506849D7C04F9138D1A2050bbF3A0c054402dd" + ) + + contract_data = EnhancedContractMetadata( + address=random_address, + metadata=etherscan_metadata_mock, + source=ClientSource.ETHERSCAN, + chain_id=1, + ) + self.assertIsNone( + ContractMetadataService.get_proxy_implementation_address(contract_data) + ) diff --git a/app/tests/workers/test_tasks.py b/app/tests/workers/test_tasks.py index 850b609..d9f7cf8 100644 --- a/app/tests/workers/test_tasks.py +++ b/app/tests/workers/test_tasks.py @@ -1,9 +1,13 @@ import json import unittest from typing import Any, Awaitable +from unittest import mock +from unittest.mock import MagicMock from dramatiq.worker import Worker +from eth_account import Account from hexbytes import HexBytes +from safe_eth.eth.clients import AsyncEtherscanClientV2 from sqlmodel.ext.asyncio.session import AsyncSession from app.datasources.db.database import database_session @@ -11,6 +15,10 @@ from app.workers.tasks import get_contract_metadata_task, redis_broker, test_task from ..datasources.db.db_async_conn import DbAsyncConn +from ..mocks.contract_metadata_mocks import ( + etherscan_metadata_mock, + etherscan_proxy_metadata_mock, +) class TestTasks(unittest.TestCase): @@ -62,8 +70,19 @@ async def asyncTearDown(self): await super().asyncTearDown() self.worker.stop() + def _wait_tasks_execution(self): + redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1) + while len(redis_tasks) > 0: + redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1) + + @mock.patch.object( + AsyncEtherscanClientV2, "async_get_contract_metadata", autospec=True + ) @database_session - async def test_get_contract_metadata_task(self, session: AsyncSession): + async def test_get_contract_metadata_task( + self, etherscan_get_contract_metadata_mock: MagicMock, session: AsyncSession + ): + etherscan_get_contract_metadata_mock.return_value = etherscan_metadata_mock contract_address = "0xd9Db270c1B5E3Bd161E8c8503c55cEABeE709552" chain_id = 100 get_contract_metadata_task.fn(contract_address, chain_id) @@ -71,3 +90,35 @@ async def test_get_contract_metadata_task(self, session: AsyncSession): session, HexBytes(contract_address), chain_id ) self.assertIsNotNone(contract) + self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 1) + + @mock.patch.object( + AsyncEtherscanClientV2, "async_get_contract_metadata", autospec=True + ) + @database_session + async def test_get_contract_metadata_task_with_proxy( + self, etherscan_get_contract_metadata_mock: MagicMock, session: AsyncSession + ): + etherscan_get_contract_metadata_mock.side_effect = [ + etherscan_proxy_metadata_mock, + etherscan_metadata_mock, + ] + contract_address = Account.create().address + proxy_implementation_address = "0x43506849D7C04F9138D1A2050bbF3A0c054402dd" + chain_id = 1 + + get_contract_metadata_task.fn(address=contract_address, chain_id=chain_id) + + self._wait_tasks_execution() + + contract = await Contract.get_contract( + session, HexBytes(contract_address), chain_id + ) + self.assertIsNotNone(contract) + + proxy_implementation = await Contract.get_contract( + session, HexBytes(proxy_implementation_address), chain_id + ) + self.assertIsNotNone(proxy_implementation) + + self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 2) diff --git a/app/workers/tasks.py b/app/workers/tasks.py index f255ded..54b9b73 100644 --- a/app/workers/tasks.py +++ b/app/workers/tasks.py @@ -7,9 +7,11 @@ from safe_eth.eth.utils import fast_to_checksum_address from sqlmodel.ext.asyncio.session import AsyncSession -from app.config import settings -from app.datasources.db.database import database_session -from app.services.contract_metadata_service import get_contract_metadata_service +from ..config import settings +from ..datasources.db.database import database_session +from ..services.contract_metadata_service import get_contract_metadata_service + +logger = logging.getLogger(__name__) redis_broker = RedisBroker(url=settings.REDIS_URL) redis_broker.add_middleware(PeriodiqMiddleware(skip_delay=60)) @@ -28,7 +30,7 @@ def test_task(message: str) -> None: async def test_task(message: str) -> None: """ - logging.info(f"Message processed! -> {message}") + logger.info(f"Message processed! -> {message}") return @@ -42,8 +44,10 @@ async def get_contract_metadata_task( if await contract_metadata_service.should_attempt_download( session, address, chain_id, 0 ): - logging.info( - f"Downloading contract metadata for {address} and chain {chain_id}" + logger.info( + "Downloading contract metadata for contract=%s and chain=%s", + address, + chain_id, ) contract_metadata = await contract_metadata_service.get_contract_metadata( fast_to_checksum_address(address), chain_id @@ -52,12 +56,27 @@ async def get_contract_metadata_task( session, contract_metadata ) if result: - logging.info( - f"Success download contract metadata for {address} and chain {chain_id}" + logger.info( + "Success download contract metadata for contract=%s and chain=%s", + address, + chain_id, ) else: - logging.info( - f"Failed to download contract metadata for {address} and chain {chain_id}" + logger.info( + "Failed to download contract metadata for contract=%s and chain=%s", + address, + chain_id, + ) + + if proxy_implementation_address := contract_metadata_service.get_proxy_implementation_address( + contract_metadata + ): + logger.info( + "Downloading proxy implementation metadata from address=%s for contract=%s and chain=%s", + proxy_implementation_address, + address, + chain_id, ) + get_contract_metadata_task.send(proxy_implementation_address, chain_id) else: - logging.debug(f"Skipping contract with address {address} and chain {chain_id}") + logger.debug("Skipping contract=%s and chain=%s", address, chain_id)