diff --git a/HISTORY.rst b/HISTORY.rst index 48e4cac..340468c 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -2,6 +2,10 @@ History ------- + +0.8.151 [2024-11-18] +* fix: web3 batch for different batch sizes + 0.8.150 [2024-11-12] * fix: replace deprecated eth testnets with sepolia diff --git a/credmark/cmf/engine/web3/batch.py b/credmark/cmf/engine/web3/batch.py index 211b23b..4c4eb17 100644 --- a/credmark/cmf/engine/web3/batch.py +++ b/credmark/cmf/engine/web3/batch.py @@ -1,5 +1,16 @@ from abc import ABC, abstractmethod -from typing import Any, Iterator, Literal, NotRequired, Sequence, TypedDict, TypeVar, cast, overload +from typing import ( + Any, + Iterable, + Iterator, + Literal, + NotRequired, + Sequence, + TypedDict, + TypeVar, + cast, + overload, +) from eth_abi.exceptions import DecodingError from eth_typing import ChecksumAddress, HexStr @@ -17,37 +28,47 @@ U = TypeVar("U", bound=tuple) -Payload = TypedDict('Payload', { - 'data': bytes | HexStr | None, - 'output_type': Sequence[str], - 'to': ChecksumAddress, - 'fn_name': str | None, - 'from': NotRequired[ChecksumAddress] -}) +Payload = TypedDict( + "Payload", + { + "data": bytes | HexStr | None, + "output_type": Sequence[str], + "to": ChecksumAddress, + "fn_name": str | None, + "from": NotRequired[ChecksumAddress], + }, +) class Web3Batch(ABC): - def _divide_chunks(self, - big_list: Sequence[T], - chunk_size: int) -> Iterator[Sequence[T]]: - for i in range(0, len(big_list), chunk_size): - yield big_list[i:i + chunk_size] - - def _build_payload(self, - contract_functions: Sequence[ContractFunction], - *, - from_address: ChecksumAddress | None = None): + @abstractmethod + def batch_size(self) -> int: ... + + def _divide_chunks( + self, big_list: Sequence[T], batch_size: int | None = None + ) -> Iterator[Sequence[T]]: + if batch_size is None: + batch_size = self.batch_size() + + for i in range(0, len(big_list), batch_size): + yield big_list[i : i + batch_size] + + def _build_payload( + self, + contract_functions: Sequence[ContractFunction], + *, + from_address: ChecksumAddress | None = None, + ): payloads: list[Payload] = [] params: TxParams = {"gas": Wei(0), "gasPrice": Wei(0)} for contract_function in contract_functions: if not contract_function.address: - raise ValueError( - f"Missing address for batch_call in `{contract_function.fn_name}`" - ) + raise ValueError(f"Missing address for batch_call in `{contract_function.fn_name}`") data = contract_function.build_transaction(params).get("data") outputs: Sequence[ABIFunctionParams] = ( - contract_function.abi["outputs"] if "outputs" in contract_function.abi else []) + contract_function.abi["outputs"] if "outputs" in contract_function.abi else [] + ) output_types = [output["type"] if "type" in output else "" for output in outputs] payload: Payload = { "to": contract_function.address, @@ -61,15 +82,18 @@ def _build_payload(self, return payloads - def _build_payload_same_function(self, - contract_function: ContractFunction, - contract_addresses: Sequence[ChecksumAddress], - *, - from_address: ChecksumAddress | None = None): + def _build_payload_same_function( + self, + contract_function: ContractFunction, + contract_addresses: Sequence[ChecksumAddress], + *, + from_address: ChecksumAddress | None = None, + ): params: TxParams = {"gas": Wei(0), "gasPrice": Wei(0)} data = contract_function.build_transaction(params).get("data") - outputs: Sequence[ABIFunctionParams] = (contract_function.abi["outputs"] - if "outputs" in contract_function.abi else []) + outputs: Sequence[ABIFunctionParams] = ( + contract_function.abi["outputs"] if "outputs" in contract_function.abi else [] + ) output_types = [output["type"] if "type" in output else "" for output in outputs] fn_name = contract_function.fn_name @@ -87,10 +111,7 @@ def _build_payload_same_function(self, return payloads - def _decode_data(self, - fn_name: str | None, - output_types: Sequence[str], - return_data: bytes): + def _decode_data(self, fn_name: str | None, output_types: Sequence[str], return_data: bytes): w3 = credmark.cmf.model.ModelContext.current_context().web3 try: output_data = w3.codec.decode(output_types, return_data) @@ -110,20 +131,20 @@ def _decode_data(self, raise BadFunctionCallOutput(msg) from e - normalized_data = map_abi_data( - BASE_RETURN_NORMALIZERS, output_types, output_data) + normalized_data = map_abi_data(BASE_RETURN_NORMALIZERS, output_types, output_data) if len(normalized_data) == 1: return normalized_data[0] else: return normalized_data - def _decode_results(self, - payloads: list[Payload], - encoded_results: list[MulticallResult], - *, - require_success: bool = False, - ): + def _decode_results( + self, + payloads: list[Payload], + encoded_results: list[MulticallResult], + *, + require_success: bool = False, + ): results: list[MulticallDecodedResult] = [] for payload, multicall_result in zip(payloads, encoded_results): success = multicall_result.success @@ -135,62 +156,68 @@ def _decode_results(self, raise ModelEngineError("Multicall function failed") try: - data = self._decode_data(payload['fn_name'], - payload['output_type'], - encoded_data) \ - if multicall_result.success else encoded_data + data = ( + self._decode_data(payload["fn_name"], payload["output_type"], encoded_data) + if multicall_result.success + else encoded_data + ) except BadFunctionCallOutput: success = False data = multicall_result.return_data - results.append(MulticallDecodedResult( - success, - data, - )) + results.append( + MulticallDecodedResult( + success, + data, + ) + ) return results @abstractmethod - def _process_payloads(self, payloads: list[Payload]) -> list[MulticallResult]: - ... + def _process_payloads(self, payloads: list[Payload]) -> list[MulticallResult]: ... + + def _chunk_process_payloads(self, payloads: list[Payload]) -> list[MulticallResult]: + results: list[MulticallResult] = [] + for chunk in self._divide_chunks(payloads): + results.extend(self._process_payloads(list(chunk))) + return results @overload def call( - self, - contract_functions: Sequence[ContractFunction], - *, - require_success: bool = False, - batch_size: int = 100, - unwrap: Literal[True], - unwrap_default: Any = None, - return_type: type[U] = tuple[Any, ...]) -> U: - ... + self, + contract_functions: Sequence[ContractFunction], + *, + require_success: bool = False, + unwrap: Literal[True], + unwrap_default: Any = None, + return_type: type[U] = tuple[Any, ...], + ) -> U: ... @overload def call( - self, - contract_functions: Sequence[ContractFunction], - *, - require_success: bool = False, - batch_size: int = 100, - unwrap: Literal[False] = ...,) -> tuple[MulticallDecodedResult[Any], ...]: - ... + self, + contract_functions: Sequence[ContractFunction], + *, + require_success: bool = False, + unwrap: Literal[False] = ..., + ) -> tuple[MulticallDecodedResult[Any], ...]: ... def call( - self, - contract_functions: Sequence[ContractFunction], - *, - require_success: bool = False, - batch_size: int = 100, - unwrap: bool = False, - unwrap_default: Any = None, - return_type: type[U] = tuple[Any]) -> U: # pylint: disable=unused-argument + self, + contract_functions: Sequence[ContractFunction], + *, + require_success: bool = False, + unwrap: bool = False, + unwrap_default: Any = None, + return_type: type[U] = tuple[Any, ...], # pylint: disable=unused-argument + ) -> U: results: list[MulticallDecodedResult] = [] - for chunk in self._divide_chunks(contract_functions, batch_size): + for chunk in self._divide_chunks(contract_functions): payloads = self._build_payload(chunk) encoded_results = self._process_payloads(payloads) - results.extend(self._decode_results(payloads, - encoded_results, - require_success=require_success)) + results.extend( + self._decode_results(payloads, encoded_results, require_success=require_success) + ) if unwrap: # pylint disable=consider-using-generator @@ -200,87 +227,89 @@ def call( @overload def call_same_function( - self, - contract_function: ContractFunction, - contract_addresses: Sequence[ChecksumAddress], - *, - require_success: bool = False, - batch_size: int = 100, - fallback_functions: Sequence[ContractFunction] | None = None, - unwrap: Literal[True]) -> list[Any]: - ... + self, + contract_function: ContractFunction, + contract_addresses: Sequence[ChecksumAddress], + *, + require_success: bool = False, + fallback_functions: Sequence[ContractFunction] | None = None, + unwrap: Literal[True], + ) -> list[Any]: ... @overload def call_same_function( - self, - contract_function: ContractFunction, - contract_addresses: Sequence[ChecksumAddress], - *, - require_success: bool = False, - batch_size: int = 100, - fallback_functions: Sequence[ContractFunction] | None = None, - unwrap: Literal[False]) -> list[MulticallDecodedResult[Any]]: - ... + self, + contract_function: ContractFunction, + contract_addresses: Sequence[ChecksumAddress], + *, + require_success: bool = False, + fallback_functions: Sequence[ContractFunction] | None = None, + unwrap: Literal[False], + ) -> list[MulticallDecodedResult[Any]]: ... # pylint: disable=too-many-arguments @overload def call_same_function( - self, - contract_function: ContractFunction, - contract_addresses: Sequence[ChecksumAddress], - *, - require_success: bool = False, - batch_size: int = 100, - fallback_functions: Sequence[ContractFunction] | None = None, - unwrap: Literal[True], - unwrap_default: T = None, - return_type: type[T] | Any = Any) -> list[T]: - ... + self, + contract_function: ContractFunction, + contract_addresses: Sequence[ChecksumAddress], + *, + require_success: bool = False, + fallback_functions: Sequence[ContractFunction] | None = None, + unwrap: Literal[True], + unwrap_default: T = None, + return_type: type[T] | Any = Any, + ) -> list[T]: ... @overload def call_same_function( - self, - contract_function: ContractFunction, - contract_addresses: Sequence[ChecksumAddress], - *, - require_success: bool = False, - batch_size: int = 100, - fallback_functions: Sequence[ContractFunction] | None = None, - unwrap: Literal[False] = ..., - return_type: type[T] | Any = Any) -> list[MulticallDecodedResult[T]]: - ... + self, + contract_function: ContractFunction, + contract_addresses: Sequence[ChecksumAddress], + *, + require_success: bool = False, + fallback_functions: Sequence[ContractFunction] | None = None, + unwrap: Literal[False] = ..., + return_type: type[T] | Any = Any, + ) -> list[MulticallDecodedResult[T]]: ... # pylint: disable=too-many-arguments def call_same_function( # pylint: disable=too-many-locals - self, - contract_function: ContractFunction, - contract_addresses: Sequence[ChecksumAddress], - *, - require_success: bool = False, - batch_size: int = 100, - fallback_functions: Sequence[ContractFunction] | None = None, - unwrap: bool = False, - unwrap_default: Any = None, - return_type: type[T] | Any = Any) -> list[T | None] | list[MulticallDecodedResult[T]]: # pylint: disable=unused-argument + self, + contract_function: ContractFunction, + contract_addresses: Sequence[ChecksumAddress], + *, + require_success: bool = False, + fallback_functions: Sequence[ContractFunction] | None = None, + unwrap: bool = False, + unwrap_default: Any = None, + return_type: type[T] | Any = Any, + ) -> list[T | None] | list[MulticallDecodedResult[T]]: # pylint: disable=unused-argument results: list[MulticallDecodedResult[T]] = [] - for chunk in self._divide_chunks(contract_addresses, batch_size): + for chunk in self._divide_chunks(contract_addresses): payloads = self._build_payload_same_function(contract_function, chunk) encoded_results = self._process_payloads(payloads) - results.extend(self._decode_results( - payloads, - encoded_results, - require_success=require_success and not fallback_functions)) + results.extend( + self._decode_results( + payloads, + encoded_results, + require_success=require_success and not fallback_functions, + ) + ) - failed = [(idx, contract_addresses[idx]) - for idx, decoded_result in enumerate(results) if not decoded_result.success] + failed = [ + (idx, contract_addresses[idx]) + for idx, decoded_result in enumerate(results) + if not decoded_result.success + ] if failed and fallback_functions: fallback_results = self.call_same_function( fallback_functions[0], [address for (_, address) in failed], require_success=require_success, - batch_size=batch_size, - fallback_functions=fallback_functions[1:]) + fallback_functions=fallback_functions[1:], + ) for idx, fallback_result in enumerate(fallback_results): results[failed[idx][0]] = fallback_result diff --git a/credmark/cmf/engine/web3/batch_fallback.py b/credmark/cmf/engine/web3/batch_fallback.py index aee9a94..31407bd 100644 --- a/credmark/cmf/engine/web3/batch_fallback.py +++ b/credmark/cmf/engine/web3/batch_fallback.py @@ -5,10 +5,15 @@ class Web3BatchFallback(Web3Batch): - def __init__(self,): + def __init__( + self, + ): self._multicall = None self._rpc = None + def batch_size(self): + return self.multicall.batch_size() + @property def multicall(self) -> Web3BatchMulticall: """ @@ -33,4 +38,4 @@ def _process_payloads(self, payloads: list[Payload]) -> list[MulticallResult]: try: return self.multicall._process_payloads(payloads) # pylint: disable=protected-access except Exception: - return self.rpc._process_payloads(payloads) # pylint: disable=protected-access + return self.rpc._chunk_process_payloads(payloads) # pylint: disable=protected-access diff --git a/credmark/cmf/engine/web3/batch_multicall.py b/credmark/cmf/engine/web3/batch_multicall.py index 350c074..c5b1c7d 100644 --- a/credmark/cmf/engine/web3/batch_multicall.py +++ b/credmark/cmf/engine/web3/batch_multicall.py @@ -31,10 +31,12 @@ Network.Optimism: 4286263, Network.BSC: 15921452, Network.Polygon: 25770160, - Network.ArbitrumOne: 7654707, Network.Fantom: 33001987, - Network.Avalanche: 11907934, Network.Base: 11907934, + Network.ArbitrumOne: 7654707, + Network.Avalanche: 11907934, + Network.Linea: 42, + Network.Sepolia: 751532, }, ) @@ -42,6 +44,9 @@ class Web3BatchMulticall(Web3Batch): _contract: Contract | None = None + def batch_size(self) -> int: + return 100 + @property def contract(self) -> Contract: context = credmark.cmf.model.ModelContext.current_context() @@ -72,7 +77,6 @@ def _process_payloads(self, payloads: list[Payload]) -> list[MulticallResult]: ] try: result = self.contract.functions.tryAggregate(False, aggregate_parameter).call() - return [MulticallResult(success, data) for success, data in result] except (ContractLogicError, OverflowError, ValueError) as err: raise ModelEngineError("Multicall function failed") from err diff --git a/credmark/cmf/engine/web3/batch_rpc.py b/credmark/cmf/engine/web3/batch_rpc.py index a50c309..feec769 100644 --- a/credmark/cmf/engine/web3/batch_rpc.py +++ b/credmark/cmf/engine/web3/batch_rpc.py @@ -1,16 +1,19 @@ from typing import Any from eth_typing import URI - from hexbytes import HexBytes from web3._utils.request import get_response_from_post_request -from credmark.cmf.engine.web3.helper import MulticallResult + import credmark.cmf.model -from credmark.cmf.model.errors import ModelEngineError from credmark.cmf.engine.web3.batch import Payload, Web3Batch +from credmark.cmf.engine.web3.helper import MulticallResult +from credmark.cmf.model.errors import ModelEngineError class Web3BatchRpc(Web3Batch): + def batch_size(self) -> int: + return 10 + def _process_payloads(self, payloads: list[Payload]) -> list[MulticallResult]: if not payloads: return [] @@ -40,12 +43,11 @@ def _process_payloads(self, payloads: list[Payload]) -> list[MulticallResult]: ) provider_url = context._web3_registry.provider_url_for_chain_id( # pylint: disable=protected-access - context.chain_id) + context.chain_id + ) response = get_response_from_post_request(URI(provider_url), json=queries) if not response.ok: - raise ModelEngineError( - f"Error connecting to {provider_url}: {response.text}" - ) + raise ModelEngineError(f"Error connecting to {provider_url}: {response.text}") results = response.json() @@ -54,9 +56,7 @@ def _process_payloads(self, payloads: list[Payload]) -> list[MulticallResult]: return_values: list[MulticallResult] = [] errors = [] - for payload, result in zip( - payloads, sorted(results, key=lambda x: x["id"]) - ): + for payload, result in zip(payloads, sorted(results, key=lambda x: x["id"])): if "error" in result: fn_name = payload.get("fn_name", HexBytes(payload["data"] or "").hex()) errors.append(f'`{fn_name}`: {result["error"]}')