diff --git a/fhevm/evm.go b/fhevm/evm.go index c6edf3f..6d09c53 100644 --- a/fhevm/evm.go +++ b/fhevm/evm.go @@ -191,26 +191,36 @@ func padArrayTo32Multiple(input []byte) []byte { func Create(evm EVMEnvironment, caller common.Address, code []byte, gas uint64, value *big.Int) (ret []byte, contractAddr common.Address, leftOverGas uint64, err error) { contractAddr = crypto.CreateAddress(caller, evm.GetNonce(caller)) protectedStorageAddr := fhevm_crypto.CreateProtectedStorageContractAddress(contractAddr) + ret, contractAddr, leftOverGas, err = evm.CreateContract(caller, code, leftOverGas, value, contractAddr) + if err != nil { + ret = nil + contractAddr = common.Address{} + return + } _, _, leftOverGas, err = evm.CreateContract(caller, nil, gas, big.NewInt(0), protectedStorageAddr) if err != nil { ret = nil contractAddr = common.Address{} return } - // TODO: consider reverting changes to `protectedStorageAddr` if actual contract creation fails. - return evm.CreateContract(caller, code, leftOverGas, value, contractAddr) + return } func Create2(evm EVMEnvironment, caller common.Address, code []byte, gas uint64, endowment *big.Int, salt *uint256.Int) (ret []byte, contractAddr common.Address, leftOverGas uint64, err error) { codeHash := crypto.Keccak256Hash(code) contractAddr = crypto.CreateAddress2(caller, salt.Bytes32(), codeHash.Bytes()) protectedStorageAddr := fhevm_crypto.CreateProtectedStorageContractAddress(contractAddr) - _, _, leftOverGas, err = evm.CreateContract2(caller, nil, common.Hash{}, gas, big.NewInt(0), protectedStorageAddr) + ret, contractAddr, leftOverGas, err = evm.CreateContract2(caller, code, codeHash, gas, endowment, contractAddr) if err != nil { ret = nil contractAddr = common.Address{} return } - // TODO: consider reverting changes to `protectedStorageAddr` if actual contract creation fails. - return evm.CreateContract2(caller, code, codeHash, leftOverGas, endowment, contractAddr) + _, _, leftOverGas, err = evm.CreateContract2(caller, nil, common.Hash{}, leftOverGas, big.NewInt(0), protectedStorageAddr) + if err != nil { + ret = nil + contractAddr = common.Address{} + return + } + return }