diff --git a/.env.sample b/.env.sample new file mode 100644 index 0000000..dd89db4 --- /dev/null +++ b/.env.sample @@ -0,0 +1,12 @@ +# .env.sample + +# URLs for DB connection +ETHEREUM_DB_URL= +GNOSIS_DB_URL= + +# URLs for Node provider connection +ETHEREUM_NODE_URL= +GNOSIS_NODE_URL= + +# optional +INFURA_KEY=infura_key_here diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c9bca91 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.env +__pycache__ +.pytest_cache +src/tempCodeRunnerFile.py \ No newline at end of file diff --git a/README.md b/README.md index 07a5c96..08479bb 100644 --- a/README.md +++ b/README.md @@ -1 +1,17 @@ -# token-imbalances \ No newline at end of file +# token-imbalances + +This script is to calculate the raw token imbalances before and after a settlement. + + +**Install requirements from root directory:** +```bash +pip install -r requirements.txt +``` + +**Environment Variables**: Make sure the `.env` file is correctly set up locally. You can use the `.env.sample` file as reference. + +**From the root directory, run:** + +```bash +python -m src.daemon +``` diff --git a/contracts/__init__.py b/contracts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/contracts/erc20_abi.py b/contracts/erc20_abi.py new file mode 100644 index 0000000..40e7d65 --- /dev/null +++ b/contracts/erc20_abi.py @@ -0,0 +1,222 @@ +erc20_abi = [ + { + "constant": True, + "inputs": [], + "name": "name", + "outputs": [ + { + "name": "", + "type": "string" + } + ], + "payable": False, + "stateMutability": "view", + "type": "function" + }, + { + "constant": False, + "inputs": [ + { + "name": "_spender", + "type": "address" + }, + { + "name": "_value", + "type": "uint256" + } + ], + "name": "approve", + "outputs": [ + { + "name": "", + "type": "bool" + } + ], + "payable": False, + "stateMutability": "nonpayable", + "type": "function" + }, + { + "constant": True, + "inputs": [], + "name": "totalSupply", + "outputs": [ + { + "name": "", + "type": "uint256" + } + ], + "payable": False, + "stateMutability": "view", + "type": "function" + }, + { + "constant": False, + "inputs": [ + { + "name": "_from", + "type": "address" + }, + { + "name": "_to", + "type": "address" + }, + { + "name": "_value", + "type": "uint256" + } + ], + "name": "transferFrom", + "outputs": [ + { + "name": "", + "type": "bool" + } + ], + "payable": False, + "stateMutability": "nonpayable", + "type": "function" + }, + { + "constant": True, + "inputs": [], + "name": "decimals", + "outputs": [ + { + "name": "", + "type": "uint8" + } + ], + "payable": False, + "stateMutability": "view", + "type": "function" + }, + { + "constant": True, + "inputs": [ + { + "name": "_owner", + "type": "address" + } + ], + "name": "balanceOf", + "outputs": [ + { + "name": "balance", + "type": "uint256" + } + ], + "payable": False, + "stateMutability": "view", + "type": "function" + }, + { + "constant": True, + "inputs": [], + "name": "symbol", + "outputs": [ + { + "name": "", + "type": "string" + } + ], + "payable": False, + "stateMutability": "view", + "type": "function" + }, + { + "constant": False, + "inputs": [ + { + "name": "_to", + "type": "address" + }, + { + "name": "_value", + "type": "uint256" + } + ], + "name": "transfer", + "outputs": [ + { + "name": "", + "type": "bool" + } + ], + "payable": False, + "stateMutability": "nonpayable", + "type": "function" + }, + { + "constant": True, + "inputs": [ + { + "name": "_owner", + "type": "address" + }, + { + "name": "_spender", + "type": "address" + } + ], + "name": "allowance", + "outputs": [ + { + "name": "", + "type": "uint256" + } + ], + "payable": False, + "stateMutability": "view", + "type": "function" + }, + { + "payable": True, + "stateMutability": "payable", + "type": "fallback" + }, + { + "anonymous": False, + "inputs": [ + { + "indexed": True, + "name": "owner", + "type": "address" + }, + { + "indexed": True, + "name": "spender", + "type": "address" + }, + { + "indexed": False, + "name": "value", + "type": "uint256" + } + ], + "name": "Approval", + "type": "event" + }, + { + "anonymous": False, + "inputs": [ + { + "indexed": True, + "name": "from", + "type": "address" + }, + { + "indexed": True, + "name": "to", + "type": "address" + }, + { + "indexed": False, + "name": "value", + "type": "uint256" + } + ], + "name": "Transfer", + "type": "event" + } +] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9d0c28b..01f581d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,4 @@ python-dotenv==1.0.0 black==23.3.0 mypy==1.4.1 pylint==2.17.4 -pytest==7.4.0 \ No newline at end of file +pytest==7.4.0 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/balanceof_imbalances.py b/src/balanceof_imbalances.py new file mode 100644 index 0000000..617fc85 --- /dev/null +++ b/src/balanceof_imbalances.py @@ -0,0 +1,103 @@ +import sys +import os +# for debugging purposes +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from web3 import Web3 +from typing import Dict, Optional, Set +from src.config import ETHEREUM_NODE_URL +from src.constants import SETTLEMENT_CONTRACT_ADDRESS, NATIVE_ETH_TOKEN_ADDRESS +from contracts.erc20_abi import erc20_abi + +# conducting sanity test only for ethereum mainnet transactions + +class BalanceOfImbalances: + def __init__(self, ETHEREUM_NODE_URL: str): + self.web3 = Web3(Web3.HTTPProvider(ETHEREUM_NODE_URL)) + + def get_token_balance(self, token_address: str, account: str, block_identifier: int) -> Optional[int]: + """ Retrieve the ERC-20 token balance of an account at a given block. """ + token_contract = self.web3.eth.contract(address=token_address, abi=erc20_abi) + try: + return token_contract.functions.balanceOf(account).call(block_identifier=block_identifier) + except Exception as e: + print(f"Error fetching balance for token {token_address}: {e}") + return None + + def get_eth_balance(self, account: str, block_identifier: int) -> Optional[int]: + """Get the ETH balance for a given account and block number.""" + try: + return self.web3.eth.get_balance(account, block_identifier=block_identifier) + except Exception as e: + print(f"Error fetching ETH balance: {e}") + return None + + def extract_token_addresses(self, tx_receipt: Dict) -> Set[str]: + """Extract unique token addresses from 'Transfer' events in a transaction receipt.""" + token_addresses = set() + transfer_topics = { + self.web3.keccak(text="Transfer(address,address,uint256)").hex(), + self.web3.keccak(text="ERC20Transfer(address,address,uint256)").hex(), + self.web3.keccak(text="Withdrawal(address,uint256)").hex() + } + for log in tx_receipt['logs']: + if log['topics'][0].hex() in transfer_topics: + token_addresses.add(log['address']) + return token_addresses + + def get_transaction_receipt(self, tx_hash: str) -> Optional[Dict]: + """Fetch the transaction receipt for the given hash.""" + try: + return self.web3.eth.get_transaction_receipt(tx_hash) + except Exception as e: + print(f"Error fetching transaction receipt for hash {tx_hash}: {e}") + return None + + def get_balances(self, token_addresses: Set[str], block_number: int) -> Dict[str, Optional[int]]: + """Get balances for all tokens at the given block number.""" + balances = {} + balances[NATIVE_ETH_TOKEN_ADDRESS] = self.get_eth_balance(SETTLEMENT_CONTRACT_ADDRESS, block_number) + + for token_address in token_addresses: + balances[token_address] = self.get_token_balance(token_address, SETTLEMENT_CONTRACT_ADDRESS, block_number) + + return balances + + def calculate_imbalances(self, prev_balances: Dict[str, Optional[int]], final_balances: Dict[str, Optional[int]]) -> Dict[str, int]: + """Calculate imbalances between previous and final balances.""" + imbalances = {} + for token_address in prev_balances: + if prev_balances[token_address] is not None and final_balances[token_address] is not None: + imbalance = final_balances[token_address] - prev_balances[token_address] + imbalances[token_address] = imbalance + return imbalances + + def compute_imbalances(self, tx_hash: str) -> Dict[str, int]: + """Compute token imbalances before and after a transaction.""" + tx_receipt = self.get_transaction_receipt(tx_hash) + if tx_receipt is None: + return {} + + token_addresses = self.extract_token_addresses(tx_receipt) + if not token_addresses: + print("No tokens involved in this transaction.") + return {} + + prev_block = tx_receipt['blockNumber'] - 1 + final_block = tx_receipt['blockNumber'] + + prev_balances = self.get_balances(token_addresses, prev_block) + final_balances = self.get_balances(token_addresses, final_block) + + return self.calculate_imbalances(prev_balances, final_balances) + +def main(): + tx_hash = input("Enter transaction hash: ") + bo = BalanceOfImbalances(ETHEREUM_NODE_URL) + imbalances = bo.compute_imbalances(tx_hash) + print("Token Imbalances:") + for token_address, imbalance in imbalances.items(): + print(f"Token: {token_address}, Imbalance: {imbalance}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..d668d68 --- /dev/null +++ b/src/config.py @@ -0,0 +1,16 @@ +import os +from dotenv import load_dotenv + +load_dotenv() +ETHEREUM_NODE_URL = os.getenv('ETHEREUM_NODE_URL') +GNOSIS_NODE_URL = os.getenv('GNOSIS_NODE_URL') + +CHAIN_RPC_ENDPOINTS = { + 'Ethereum': ETHEREUM_NODE_URL, + 'Gnosis': GNOSIS_NODE_URL +} + +CHAIN_SLEEP_TIMES = { + 'Ethereum': 60, + 'Gnosis': 120 +} \ No newline at end of file diff --git a/src/constants.py b/src/constants.py new file mode 100644 index 0000000..5a0e951 --- /dev/null +++ b/src/constants.py @@ -0,0 +1,6 @@ +from web3 import Web3 + +SETTLEMENT_CONTRACT_ADDRESS = Web3.to_checksum_address('0x9008D19f58AAbD9eD0D60971565AA8510560ab41') +NATIVE_ETH_TOKEN_ADDRESS = Web3.to_checksum_address('0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee') +WETH_TOKEN_ADDRESS = Web3.to_checksum_address('0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2') +SDAI_TOKEN_ADDRESS = Web3.to_checksum_address('0x83F20F44975D03b1b09e64809B757c47f942BEeA') diff --git a/src/daemon.py b/src/daemon.py new file mode 100644 index 0000000..ad7333a --- /dev/null +++ b/src/daemon.py @@ -0,0 +1,93 @@ +import os +import time +import pandas as pd +from web3 import Web3 +from typing import List +from threading import Thread +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine +from src.imbalances_script import RawTokenImbalances +from src.config import CHAIN_RPC_ENDPOINTS, CHAIN_SLEEP_TIMES + +def get_web3_instance(chain_name: str) -> Web3: + return Web3(Web3.HTTPProvider(CHAIN_RPC_ENDPOINTS[chain_name])) + +def get_finalized_block_number(web3: Web3) -> int: + return web3.eth.block_number - 64 + +def create_db_connection(chain_name: str): + """function that creates a connection to the CoW db.""" + if chain_name == 'Ethereum': + db_url = os.getenv("ETHEREUM_DB_URL") + elif chain_name == 'Gnosis': + db_url = os.getenv("GNOSIS_DB_URL") + + return create_engine(f"postgresql+psycopg2://{db_url}") + +def fetch_transaction_hashes(db_connection: Engine, start_block: int, end_block: int) -> List[str]: + """Fetch transaction hashes beginning start_block.""" + query = f""" + SELECT tx_hash + FROM settlements + WHERE block_number >= {start_block} + AND block_number <= {end_block} + """ + + db_hashes = pd.read_sql(query, db_connection) + # converts hashes at memory location to hex + db_hashes['tx_hash'] = db_hashes['tx_hash'].apply(lambda x: f"0x{x.hex()}") + + return db_hashes['tx_hash'].tolist() + +def process_transactions(chain_name: str) -> None: + web3 = get_web3_instance(chain_name) + rt = RawTokenImbalances(web3, chain_name) + sleep_time = CHAIN_SLEEP_TIMES.get(chain_name) + db_connection = create_db_connection(chain_name) + + previous_block = get_finalized_block_number(web3) + unprocessed_txs = [] + + print(f"{chain_name} Daemon started.") + + while True: + try: + latest_block = get_finalized_block_number(web3) + new_txs = fetch_transaction_hashes(db_connection, previous_block, latest_block) + # add any unprocessed hashes for processing, then clear list of unprocessed + all_txs = new_txs + unprocessed_txs + unprocessed_txs.clear() + + for tx in all_txs: + print(f'Processing transaction on {chain_name}: {tx}') + try: + imbalances = rt.compute_imbalances(tx) + print(f"Token Imbalances on {chain_name}:") + for token_address, imbalance in imbalances.items(): + print(f"Token: {token_address}, Imbalance: {imbalance}") + except ValueError as e: + print(e) + unprocessed_txs.append(tx) + + print("Done checks..") + previous_block = latest_block + 1 + except ConnectionError as e: + print(f"Connection error processing transactions on {chain_name}: {e}") + except Exception as e: + print(f"Error processing transactions on {chain_name}: {e}") + + time.sleep(sleep_time) + +def main() -> None: + threads = [] + + for chain_name in CHAIN_RPC_ENDPOINTS.keys(): + thread = Thread(target=process_transactions, args=(chain_name,), daemon=True) + thread.start() + threads.append(thread) + + for thread in threads: + thread.join() + +if __name__ == "__main__": + main() diff --git a/src/imbalances_script.py b/src/imbalances_script.py new file mode 100644 index 0000000..6d2d3ae --- /dev/null +++ b/src/imbalances_script.py @@ -0,0 +1,277 @@ +""" +Steps for computing token imbalances: + +1. Get transaction receipt via tx hash -> get_transaction_receipt() +2. Obtain the transaction trace and extract actions from trace to identify actions + related to native ETH transfers -> get_transaction_trace() and extract_actions() +3. Calculate ETH imbalance via actions by identifying transfers in + and out of a contract address -> calculate_native_eth_imbalance() +4. Extract and categorize relevant events (such as ERC20 transfers, WETH withdrawals, + and sDAI transactions) from the transaction receipt. -> extract_events() +5. Process each event by first decoding it to retrieve event details, i.e. to_address, from_address + and transfer value -> decode_event() +6. If to_address or from_address match the contract address parameter, update inflows/outflows by + adding the transfer value to existing inflow/outflow for the token addresses. +7. Returning to calculate_imbalances(), which finds the imbalance for all token addresses using + inflow-outflow. +8. If actions are not None, it denotes an ETH transfer event, which involves reducing WETH withdrawal + amount- > update_weth_imbalance(). The ETH imbalance is also calculated via -> update_native_eth_imbalance(). +9. update_sdai_imbalance() is called in each iteration and only completes if there is an SDAI transfer + involved which has special handling for its events. +""" + +from web3.datastructures import AttributeDict +from typing import Dict, List, Optional, Tuple +from web3 import Web3 +from src.config import CHAIN_RPC_ENDPOINTS +from src.constants import (SETTLEMENT_CONTRACT_ADDRESS, NATIVE_ETH_TOKEN_ADDRESS, + WETH_TOKEN_ADDRESS, SDAI_TOKEN_ADDRESS) + +EVENT_TOPICS = { + 'Transfer': 'Transfer(address,address,uint256)', + 'ERC20Transfer': 'ERC20Transfer(address,address,uint256)', + 'WithdrawalWETH': 'Withdrawal(address,uint256)', + 'DepositSDAI': 'Deposit(address,address,uint256,uint256)', + 'WithdrawSDAI': 'Withdraw(address,address,address,uint256,uint256)', +} + +def compute_event_topics(web3: Web3) -> Dict[str, str]: + """Compute the event topics for all relevant events.""" + return {name: web3.keccak(text=text).hex() for name, text in EVENT_TOPICS.items()} + +def find_chain_with_tx(tx_hash: str) -> Tuple[str, Web3]: + """ + Find the chain where the transaction is present. + Returns the chain name and the web3 instance. Used for checking single tx hashes. + """ + for chain_name, url in CHAIN_RPC_ENDPOINTS.items(): + web3 = Web3(Web3.HTTPProvider(url)) + if not web3.is_connected(): + print(f"Could not connect to {chain_name}.") + continue + try: + web3.eth.get_transaction_receipt(tx_hash) + print(f"Transaction found on {chain_name}.") + return chain_name, web3 + except Exception as e: + print(f"Transaction not found on {chain_name}: {e}") + raise ValueError(f"Transaction hash {tx_hash} not found on any chain.") + +def _to_int(value: str | int) -> int: + """Convert hex string or integer to integer.""" + try: + return int(value, 16) if isinstance(value, str) and value.startswith('0x') else int(value) + except ValueError: + print(f"Error converting value {value} to integer.") + +class RawTokenImbalances: + def __init__(self, web3: Web3, chain_name: str): + self.web3 = web3 + self.chain_name = chain_name + + def get_transaction_receipt(self, tx_hash: str) -> Optional[Dict]: + """ + Get the transaction receipt from the provided web3 instance. + """ + try: + return self.web3.eth.get_transaction_receipt(tx_hash) + except Exception as e: + print(f"Error getting transaction receipt: {e}") + return None + + def get_transaction_trace(self, tx_hash: str) -> Optional[List[Dict]]: + """ Function used for retreiving trace to identify ETH transfers. """ + try: + res = self.web3.tracing.trace_transaction(tx_hash) + return res + except Exception as err: + print(f"Error occurred while fetching transaction trace: {err}") + return None + + def extract_actions(self, traces: List[AttributeDict], address: str) -> List[Dict]: + """ Identify transfer events in trace involving the specified contract. """ + normalized_address = Web3.to_checksum_address(address) + actions = [] + # input_field = '0x' denotes a native ETH transfer event, which we want to filter for + input_field: str = '0x' + for trace in traces: + if isinstance(trace, AttributeDict): + action = trace.get('action', {}) + input_value = action.get('input', b"").hex() + # filter out action if involved in an ETH transfer event + if input_value == input_field and ( + Web3.to_checksum_address(action.get('from', '')) == normalized_address or + Web3.to_checksum_address(action.get('to', '')) == normalized_address + ): + actions.append(dict(action)) + return actions + + def calculate_native_eth_imbalance(self, actions: List[Dict], address: str) -> int: + """Extract ETH imbalance from transfer actions.""" + # inflow is the total value transferred to address param + inflow = sum( + _to_int(action['value']) + for action in actions + if Web3.to_checksum_address(action.get('to', '')) == address + ) + # outflow is the total value transferred out of address param + outflow = sum( + _to_int(action['value']) + for action in actions + if Web3.to_checksum_address(action.get('from', '')) == address + ) + return inflow - outflow + + def extract_events(self, tx_receipt: Dict) -> Dict[str, List[Dict]]: + """Extract relevant events from the transaction receipt.""" + event_topics = compute_event_topics(self.web3) + # transfer_topics are filtered to find imbalances for most ERC-20 tokens + transfer_topics = {k: v for k, v in event_topics.items() if k in ['Transfer', 'ERC20Transfer']} + # other_topics is used to find imbalances for SDAI, ETH txss + other_topics = {k: v for k, v in event_topics.items() if k not in transfer_topics} + + events = {name: [] for name in EVENT_TOPICS} + for log in tx_receipt['logs']: + log_topic = log['topics'][0].hex() + if log_topic in transfer_topics.values(): + events['Transfer'].append(log) + else: + for event_name, topic in other_topics.items(): + if log_topic == topic: + events[event_name].append(log) + break + return events + + def decode_event(self, event: Dict) -> Tuple[Optional[str], Optional[str], Optional[int]]: + """ + Decode transfer and withdrawal events. + Returns from_address, to_address (for transfer), and value. + """ + try: + from_address = Web3.to_checksum_address("0x" + event['topics'][1].hex()[-40:]) + value_hex = event['data'] + + if isinstance(value_hex, bytes): + value = int.from_bytes(value_hex, byteorder='big') + else: + value = int(value_hex, 16) + + if len(event['topics']) > 2: # Transfer event + to_address = Web3.to_checksum_address("0x" + event['topics'][2].hex()[-40:]) + return from_address, to_address, value + else: # Withdrawal event + return from_address, None, value + except Exception as e: + print(f"Error decoding event: {str(e)}") + return None, None, None + + def process_event(self, event: Dict, inflows: Dict[str, int], outflows: Dict[str, int], address: str) -> None: + """Process a single event to update inflows and outflows.""" + from_address, to_address, value = self.decode_event(event) + if from_address is None or to_address is None: + return + if to_address == address: + inflows[event['address']] = inflows.get(event['address'], 0) + value + if from_address == address: + outflows[event['address']] = outflows.get(event['address'], 0) + value + + def calculate_imbalances(self, events: Dict[str, List[Dict]], address: str) -> Dict[str, int]: + """Calculate token imbalances from events.""" + inflows, outflows = {}, {} + for event in events['Transfer']: + self.process_event(event, inflows, outflows, address) + + imbalances = { + token_address: inflows.get(token_address, 0) - outflows.get(token_address, 0) + for token_address in set(inflows.keys()).union(outflows.keys()) + } + return imbalances + + def update_weth_imbalance(self, events: Dict[str, List[Dict]], actions: List[Dict], imbalances: Dict[str, int], address: str) -> None: + """Update the WETH imbalance in imbalances.""" + weth_inflow = imbalances.get(WETH_TOKEN_ADDRESS, 0) + weth_outflow = 0 + weth_withdrawals = 0 + for event in events['WithdrawalWETH']: + from_address, _, value = self.decode_event(event) + if from_address == address: + weth_withdrawals += value + imbalances[WETH_TOKEN_ADDRESS] = weth_inflow - weth_outflow - weth_withdrawals + + def update_native_eth_imbalance(self, imbalances: Dict[str, int], native_eth_imbalance: Optional[int]) -> None: + """Update the native ETH imbalance in imbalances.""" + if native_eth_imbalance is not None: + imbalances[NATIVE_ETH_TOKEN_ADDRESS] = native_eth_imbalance + + def decode_sdai_event(self, event: Dict) -> int | None: + """Decode sDAI event.""" + try: + # SDAI event has hex value at the end, which needs to be extracted + value_hex = event['data'][-30:] + if isinstance(value_hex, bytes): + value = int.from_bytes(value_hex, byteorder='big') + else: + value = int(value_hex, 16) + return value + except Exception as e: + print(f"Error decoding sDAI event: {str(e)}") + return None + + def process_sdai_event(self, event: Dict, imbalances: Dict[str, int], is_deposit: bool) -> None: + """Process an sDAI deposit or withdrawal event to update imbalances.""" + decoded_event_value = self.decode_sdai_event(event) + if decoded_event_value is None: + return + if is_deposit: + imbalances[SDAI_TOKEN_ADDRESS] = imbalances.get(SDAI_TOKEN_ADDRESS, 0) + decoded_event_value + else: + imbalances[SDAI_TOKEN_ADDRESS] = imbalances.get(SDAI_TOKEN_ADDRESS, 0) - decoded_event_value + + def update_sdai_imbalance(self, events: Dict[str, List[Dict]], imbalances: Dict[str, int]) -> None: + """Update the sDAI imbalance in imbalances.""" + for event in events['DepositSDAI']: + if event['address'] == SDAI_TOKEN_ADDRESS: + self.process_sdai_event(event, imbalances, is_deposit=True) + for event in events['WithdrawSDAI']: + if event['address'] == SDAI_TOKEN_ADDRESS: + self.process_sdai_event(event, imbalances, is_deposit=False) + + def compute_imbalances(self, tx_hash: str) -> Dict[str, int]: + """Compute token imbalances for a given transaction hash.""" + tx_receipt = self.get_transaction_receipt(tx_hash) + if tx_receipt is None: + raise ValueError(f"Transaction hash {tx_hash} not found on chain {self.chain_name}.") + # find trace and actions from trace to track native ETH events + traces = self.get_transaction_trace(tx_hash) + native_eth_imbalance = None + actions = [] + if traces is not None: + actions = self.extract_actions(traces, SETTLEMENT_CONTRACT_ADDRESS) + native_eth_imbalance = self.calculate_native_eth_imbalance(actions, SETTLEMENT_CONTRACT_ADDRESS) + + events = self.extract_events(tx_receipt) + imbalances = self.calculate_imbalances(events, SETTLEMENT_CONTRACT_ADDRESS) + + if actions: + self.update_weth_imbalance(events, actions, imbalances, SETTLEMENT_CONTRACT_ADDRESS) + self.update_native_eth_imbalance(imbalances, native_eth_imbalance) + + self.update_sdai_imbalance(events, imbalances) + + return imbalances + +# main method for finding imbalance for a single tx hash +def main() -> None: + tx_hash = input("Enter transaction hash: ") + chain_name, web3 = find_chain_with_tx(tx_hash) + rt = RawTokenImbalances(web3, chain_name) + try: + imbalances = rt.compute_imbalances(tx_hash) + print(f"Token Imbalances on {chain_name}:") + for token_address, imbalance in imbalances.items(): + print(f"Token: {token_address}, Imbalance: {imbalance}") + except ValueError as e: + print(e) + +if __name__ == "__main__": + main() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/basic_test.py b/tests/basic_test.py new file mode 100644 index 0000000..378756c --- /dev/null +++ b/tests/basic_test.py @@ -0,0 +1,30 @@ +import pytest +from src.imbalances_script import RawTokenImbalances + +@pytest.mark.parametrize("tx_hash, expected_imbalances", [ + # Native ETH buy + ("0x749b557872d7d1f857719f619300df9621631f87338caa706154a3d7040fac9f", + { + "0x6B175474E89094C44Da98b954EedeAC495271d0F": 6286775129763176601, + "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": 12147750061816, + "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE": 221116798827683 + }), + # SDAI sell + ("0xdae82500c69c66db4e4a8c64e1d6a95f3cdc5cb81a5a00228ce6f247b9b8cefd", + { + "0x83F20F44975D03b1b09e64809B757c47f942BEeA": 90419674604117409792, + "0x6B175474E89094C44Da98b954EedeAC495271d0F": 360948092321672598, + }), + # ERC404 Token Buy + ("0xfcb1d20df8a90f5b4646a5d1818da407b3a78cfcb8291f477291f5c01115ca7a", + { + "0x9E9FbDE7C7a83c43913BddC8779158F1368F0413": -11207351687745217, + "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": 64641750602289665, + }), +]) + +def test_imbalances(tx_hash, expected_imbalances): + rt = RawTokenImbalances() + imbalances, _ = rt.compute_imbalances(tx_hash) + for token_address, expected_imbalance in expected_imbalances.items(): + assert imbalances.get(token_address) == expected_imbalance diff --git a/tests/compare_imbalances.py b/tests/compare_imbalances.py new file mode 100644 index 0000000..43900d3 --- /dev/null +++ b/tests/compare_imbalances.py @@ -0,0 +1,53 @@ +""" +Script can be used as a sanity test to compare raw imbalances via RawTokenImbalances class +and the BalanceOfImbalances class. +""" +import time +from web3 import Web3 +from src.config import ETHEREUM_NODE_URL +from src.imbalances_script import RawTokenImbalances +from src.balanceof_imbalances import BalanceOfImbalances +from src.daemon import get_web3_instance, create_db_connection, fetch_transaction_hashes + +RED_COLOR = "\033[91m" +RESET_COLOR = "\033[0m" + +def remove_zero_balances(balances: dict) -> dict: + """Remove entries with zero balance for all tokens.""" + return {token: balance for token, balance in balances.items() if balance != 0} + +def compare_imbalances(tx_hash: str, web3: Web3) -> None: + """Compare imbalances computed by RawTokenImbalances and BalanceOfImbalances.""" + raw_imbalances = RawTokenImbalances(web3, 'Ethereum') + balanceof_imbalances = BalanceOfImbalances(ETHEREUM_NODE_URL) + + raw_result = raw_imbalances.compute_imbalances(tx_hash) + balanceof_result = balanceof_imbalances.compute_imbalances(tx_hash) + + # Remove entries for native ETH with balance 0 + raw_result = remove_zero_balances(raw_result) + balanceof_result = remove_zero_balances(balanceof_result) + + if raw_result != balanceof_result: + print(f"{RED_COLOR}Imbalances do not match for tx: {tx_hash}.\nRaw: {raw_result}\nBalanceOf: {balanceof_result}{RESET_COLOR}") + else: + print(f"Imbalances match for transaction {tx_hash}.") + +def main() -> None: + start_block = int(input("Enter start block number: ")) + end_block = int(input("Enter end block number: ")) + + web3 = get_web3_instance("Ethereum") + db_connection = create_db_connection("Ethereum") + tx_hashes = fetch_transaction_hashes(db_connection, start_block, end_block) + + for tx_hash in tx_hashes: + try: + compare_imbalances(tx_hash, web3) + time.sleep(1) # Delay to avoid rate limits + + except Exception as e: + print(f"Error comparing imbalances for tx {tx_hash}: {e}") + +if __name__ == "__main__": + main()