From 1ca06f336b983f60307412f808e59e1ce12009c5 Mon Sep 17 00:00:00 2001 From: Karl Floersch Date: Wed, 3 May 2017 22:48:46 -0400 Subject: [PATCH] Add call() & raise TransactionFailed() --- ethereum/tools/tester2.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/ethereum/tools/tester2.py b/ethereum/tools/tester2.py index 1d14c942d..e226fa666 100644 --- a/ethereum/tools/tester2.py +++ b/ethereum/tools/tester2.py @@ -57,6 +57,10 @@ config_string = ':info' # configure_logging(config_string=config_string) +class TransactionFailed(Exception): + pass + + class ABIContract(object): # pylint: disable=too-few-public-methods def __init__(self, _chain, _abi, address): @@ -70,12 +74,15 @@ def __init__(self, _chain, _abi, address): self.translator = abi_translator for function_name in self.translator.function_data: - function = self.method_factory(_chain, function_name) + if self.translator.function_data[function_name]['is_constant']: + function = self.method_factory(_chain.call, function_name) + else: + function = self.method_factory(_chain.tx, function_name) method = types.MethodType(function, self) setattr(self, function_name, method) @staticmethod - def method_factory(test_chain, function_name): + def method_factory(tx_or_call, function_name): """ Return a proxy for calling a contract method with automatic encoding of argument and decoding of results. """ @@ -83,7 +90,7 @@ def method_factory(test_chain, function_name): def kall(self, *args, **kwargs): key = kwargs.get('sender', k0) - result = test_chain.tx( # pylint: disable=protected-access + result = tx_or_call( # pylint: disable=protected-access sender=key, to=self.address, value=kwargs.get('value', 0), @@ -121,9 +128,19 @@ def tx(self, sender=k0, to=b'\x00' * 20, value=0, data=b'', startgas=STARTGAS, g success, output = apply_transaction(self.head_state, transaction) self.block.transactions.append(transaction) if not success: - return False + raise TransactionFailed() return output + def call(self, sender=k0, to=b'\x00' * 20, value=0, data=b'', startgas=STARTGAS, gasprice=GASPRICE): + snapshot = self.snapshot() + try: + output = self.tx(sender, to, value, data, startgas, gasprice) + self.revert(snapshot) + return output + except Exception as e: + self.revert(snapshot) + raise e + def contract(self, sourcecode, args=[], sender=k0, value=0, language='evm', startgas=STARTGAS, gasprice=GASPRICE): if language == 'evm': assert len(args) == 0