From 60b8b9e56408705768d7a5a5b81b9b1d531b3128 Mon Sep 17 00:00:00 2001 From: David Date: Thu, 9 Nov 2023 14:35:59 +0800 Subject: [PATCH] Remove most geth dependencies --- crypto/crypto.go | 12 ---- fhevm/contract.go | 4 +- fhevm/evm.go | 37 +++++------ fhevm/instructions.go | 45 ++++++------- fhevm/interface.go | 21 +++--- fhevm/interpreter.go | 6 +- fhevm/precompiles.go | 93 +++++++++++++------------- fhevm/tfhe.go | 13 ++-- fhevm/utils.go | 149 ++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 250 insertions(+), 130 deletions(-) delete mode 100644 crypto/crypto.go create mode 100644 fhevm/utils.go diff --git a/crypto/crypto.go b/crypto/crypto.go deleted file mode 100644 index 148f86d..0000000 --- a/crypto/crypto.go +++ /dev/null @@ -1,12 +0,0 @@ -package crypto - -import ( - "github.com/ethereum/go-ethereum/common" - evm "github.com/ethereum/go-ethereum/crypto" -) - -// CreateProtectedStorageAddress creates an ethereum contract address for protected storage -// given the corresponding contract address -func CreateProtectedStorageContractAddress(b common.Address) common.Address { - return evm.CreateAddress(b, 0) -} diff --git a/fhevm/contract.go b/fhevm/contract.go index b424cef..7d8802e 100644 --- a/fhevm/contract.go +++ b/fhevm/contract.go @@ -1,7 +1,5 @@ package fhevm -import "github.com/ethereum/go-ethereum/common" - type Contract interface { - Address() common.Address + Address() Address } diff --git a/fhevm/evm.go b/fhevm/evm.go index 12822fd..43e8674 100644 --- a/fhevm/evm.go +++ b/fhevm/evm.go @@ -6,10 +6,7 @@ import ( "fmt" "math/big" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" "github.com/holiman/uint256" - fhevm_crypto "github.com/zama-ai/fhevm-go/crypto" ) // A Logger interface for the EVM. @@ -42,7 +39,7 @@ func (*DefaultLogger) Error(msg string, keyvals ...interface{}) { } func makeKeccakSignature(input string) uint32 { - return binary.BigEndian.Uint32(crypto.Keccak256([]byte(input))[0:4]) + return binary.BigEndian.Uint32(Keccak256([]byte(input))[0:4]) } func isScalarOp(input []byte) (bool, error) { @@ -53,7 +50,7 @@ func isScalarOp(input []byte) (bool, error) { return isScalar, nil } -func getVerifiedCiphertext(environment EVMEnvironment, ciphertextHash common.Hash) *verifiedCiphertext { +func getVerifiedCiphertext(environment EVMEnvironment, ciphertextHash Hash) *verifiedCiphertext { return getVerifiedCiphertextFromEVM(environment, ciphertextHash) } @@ -61,11 +58,11 @@ func get2VerifiedOperands(environment EVMEnvironment, input []byte) (lhs *verifi if len(input) != 65 { return nil, nil, errors.New("input needs to contain two 256-bit sized values and 1 8-bit value") } - lhs = getVerifiedCiphertext(environment, common.BytesToHash(input[0:32])) + lhs = getVerifiedCiphertext(environment, BytesToHash(input[0:32])) if lhs == nil { return nil, nil, errors.New("unverified ciphertext handle") } - rhs = getVerifiedCiphertext(environment, common.BytesToHash(input[32:64])) + rhs = getVerifiedCiphertext(environment, BytesToHash(input[32:64])) if rhs == nil { return nil, nil, errors.New("unverified ciphertext handle") } @@ -77,7 +74,7 @@ func getScalarOperands(environment EVMEnvironment, input []byte) (lhs *verifiedC if len(input) != 65 { return nil, nil, errors.New("input needs to contain two 256-bit sized values and 1 8-bit value") } - lhs = getVerifiedCiphertext(environment, common.BytesToHash(input[0:32])) + lhs = getVerifiedCiphertext(environment, BytesToHash(input[0:32])) if lhs == nil { return nil, nil, errors.New("unverified ciphertext handle") } @@ -113,8 +110,8 @@ func importCiphertext(environment EVMEnvironment, ct *tfheCiphertext) *verifiedC func importRandomCiphertext(environment EVMEnvironment, t FheUintType) []byte { nextCtHash := &environment.FhevmData().nextCiphertextHashOnGasEst - ctHashBytes := crypto.Keccak256(nextCtHash.Bytes()) - handle := common.BytesToHash(ctHashBytes) + ctHashBytes := Keccak256(nextCtHash.Bytes()) + handle := BytesToHash(ctHashBytes) ct := new(tfheCiphertext) ct.fheUintType = t ct.hash = &handle @@ -168,27 +165,27 @@ func padArrayTo32Multiple(input []byte) []byte { return input } -func Create(evm EVMEnvironment, caller common.Address, code []byte, gas uint64, value *big.Int) (ret []byte, contractAddr common.Address, leftOverGas uint64, err error) { - contractAddr = crypto.CreateAddress(caller, evm.GetNonce(caller)) - protectedStorageAddr := fhevm_crypto.CreateProtectedStorageContractAddress(contractAddr) +func Create(evm EVMEnvironment, caller Address, code []byte, gas uint64, value *big.Int) (ret []byte, contractAddr Address, leftOverGas uint64, err error) { + contractAddr = CreateAddress(caller, evm.GetNonce(caller)) + protectedStorageAddr := CreateProtectedStorageContractAddress(contractAddr) _, _, leftOverGas, err = evm.CreateContract(caller, nil, gas, big.NewInt(0), protectedStorageAddr) if err != nil { ret = nil - contractAddr = common.Address{} + contractAddr = Address{} return } // TODO: consider reverting changes to `protectedStorageAddr` if actual contract creation fails. return evm.CreateContract(caller, code, leftOverGas, value, contractAddr) } -func Create2(evm EVMEnvironment, caller common.Address, code []byte, gas uint64, endowment *big.Int, salt *uint256.Int) (ret []byte, contractAddr common.Address, leftOverGas uint64, err error) { - codeHash := crypto.Keccak256Hash(code) - contractAddr = crypto.CreateAddress2(caller, salt.Bytes32(), codeHash.Bytes()) - protectedStorageAddr := fhevm_crypto.CreateProtectedStorageContractAddress(contractAddr) - _, _, leftOverGas, err = evm.CreateContract2(caller, nil, common.Hash{}, gas, big.NewInt(0), protectedStorageAddr) +func Create2(evm EVMEnvironment, caller Address, code []byte, gas uint64, endowment *big.Int, salt *uint256.Int) (ret []byte, contractAddr Address, leftOverGas uint64, err error) { + codeHash := Keccak256Hash(code) + contractAddr = CreateAddress2(caller, salt.Bytes32(), codeHash.Bytes()) + protectedStorageAddr := CreateProtectedStorageContractAddress(contractAddr) + _, _, leftOverGas, err = evm.CreateContract2(caller, nil, Hash{}, gas, big.NewInt(0), protectedStorageAddr) if err != nil { ret = nil - contractAddr = common.Address{} + contractAddr = Address{} return } // TODO: consider reverting changes to `protectedStorageAddr` if actual contract creation fails. diff --git a/fhevm/instructions.go b/fhevm/instructions.go index 643c104..dd43814 100644 --- a/fhevm/instructions.go +++ b/fhevm/instructions.go @@ -7,10 +7,7 @@ import ( "math/big" "strings" - "github.com/ethereum/go-ethereum/common" - crypto "github.com/ethereum/go-ethereum/crypto" "github.com/holiman/uint256" - fhevm_crypto "github.com/zama-ai/fhevm-go/crypto" ) var zero = uint256.NewInt(0).Bytes32() @@ -62,11 +59,11 @@ func minUint64(a, b uint64) uint64 { } // If references are still left, reduce refCount by 1. Otherwise, zero out the metadata and the ciphertext slots. -func garbageCollectProtectedStorage(flagHandleLocation common.Hash, handle common.Hash, protectedStorage common.Address, env EVMEnvironment) { +func garbageCollectProtectedStorage(flagHandleLocation Hash, handle Hash, protectedStorage Address, env EVMEnvironment) { // The location of ciphertext metadata is at Keccak256(handle). Doing so avoids attacks from users trying to garbage // collect arbitrary locations in protected storage. Hashing the handle makes it hard to find a preimage such that // it ends up in arbitrary non-zero places in protected stroage. - metadataKey := crypto.Keccak256Hash(handle.Bytes()) + metadataKey := Keccak256Hash(handle.Bytes()) existingMetadataHash := env.GetState(protectedStorage, metadataKey) existingMetadataInt := newInt(existingMetadataHash.Bytes()) @@ -132,7 +129,7 @@ func isVerifiedAtCurrentDepth(environment EVMEnvironment, ct *verifiedCiphertext // Returns a pointer to the ciphertext if the given hash points to a verified ciphertext. // Else, it returns nil. -func getVerifiedCiphertextFromEVM(environment EVMEnvironment, ciphertextHash common.Hash) *verifiedCiphertext { +func getVerifiedCiphertextFromEVM(environment EVMEnvironment, ciphertextHash Hash) *verifiedCiphertext { ct, ok := environment.FhevmData().verifiedCiphertexts[ciphertextHash] if ok && isVerifiedAtCurrentDepth(environment, ct) { return ct @@ -140,7 +137,7 @@ func getVerifiedCiphertextFromEVM(environment EVMEnvironment, ciphertextHash com return nil } -func verifyIfCiphertextHandle(handle common.Hash, env EVMEnvironment, contractAddress common.Address) error { +func verifyIfCiphertextHandle(handle Hash, env EVMEnvironment, contractAddress Address) error { ct, ok := env.FhevmData().verifiedCiphertexts[handle] if ok { // If already existing in memory, skip storage and import the same ciphertext at the current depth. @@ -152,8 +149,8 @@ func verifyIfCiphertextHandle(handle common.Hash, env EVMEnvironment, contractAd return nil } - metadataKey := crypto.Keccak256Hash(handle.Bytes()) - protectedStorage := fhevm_crypto.CreateProtectedStorageContractAddress(contractAddress) + metadataKey := Keccak256Hash(handle.Bytes()) + protectedStorage := CreateProtectedStorageContractAddress(contractAddress) metadataInt := newInt(env.GetState(protectedStorage, metadataKey).Bytes()) if !metadataInt.IsZero() { metadata := newCiphertextMetadata(metadataInt.Bytes32()) @@ -186,7 +183,7 @@ func verifyIfCiphertextHandle(handle common.Hash, env EVMEnvironment, contractAd func OpSload(pc *uint64, env EVMEnvironment, scope ScopeContext) ([]byte, error) { loc := scope.GetStack().Peek() - hash := common.Hash(loc.Bytes32()) + hash := Hash(loc.Bytes32()) val := env.GetState(scope.GetContract().Address(), hash) if err := verifyIfCiphertextHandle(val, env, scope.GetContract().Address()); err != nil { return nil, err @@ -196,12 +193,12 @@ func OpSload(pc *uint64, env EVMEnvironment, scope ScopeContext) ([]byte, error) } // An arbitrary constant value to flag locations in protected storage. -var flag = common.HexToHash("0xa145ffde0100a145ffde0100a145ffde0100a145ffde0100a145ffde0100fab3") +var flag = HexToHash("0xa145ffde0100a145ffde0100a145ffde0100a145ffde0100a145ffde0100fab3") // If a verified ciphertext: // * if the ciphertext does not exist in protected storage, persist it with a refCount = 1 // * if the ciphertexts exists in protected, bump its refCount by 1 -func persistIfVerifiedCiphertext(flagHandleLocation common.Hash, handle common.Hash, protectedStorage common.Address, env EVMEnvironment) { +func persistIfVerifiedCiphertext(flagHandleLocation Hash, handle Hash, protectedStorage Address, env EVMEnvironment) { verifiedCiphertext := getVerifiedCiphertextFromEVM(env, handle) if verifiedCiphertext == nil { return @@ -209,7 +206,7 @@ func persistIfVerifiedCiphertext(flagHandleLocation common.Hash, handle common.H logger := env.GetLogger() // Try to read ciphertext metadata from protected storage. - metadataKey := crypto.Keccak256Hash(handle.Bytes()) + metadataKey := Keccak256Hash(handle.Bytes()) metadataInt := newInt(env.GetState(protectedStorage, metadataKey).Bytes()) metadata := ciphertextMetadata{} @@ -236,7 +233,7 @@ func persistIfVerifiedCiphertext(flagHandleLocation common.Hash, handle common.H ctBytes := verifiedCiphertext.ciphertext.serialize() for i, b := range ctBytes { if i%32 == 0 && i != 0 { - env.SetState(protectedStorage, ciphertextSlot.Bytes32(), common.BytesToHash(ctPart32)) + env.SetState(protectedStorage, ciphertextSlot.Bytes32(), BytesToHash(ctPart32)) ciphertextSlot.AddUint64(ciphertextSlot, 1) ctPart32 = make([]byte, 32) partIdx = 0 @@ -245,7 +242,7 @@ func persistIfVerifiedCiphertext(flagHandleLocation common.Hash, handle common.H partIdx++ } if len(ctPart32) != 0 { - env.SetState(protectedStorage, ciphertextSlot.Bytes32(), common.BytesToHash(ctPart32)) + env.SetState(protectedStorage, ciphertextSlot.Bytes32(), BytesToHash(ctPart32)) } } else { // If metadata exists, bump the refcount by 1. @@ -269,20 +266,20 @@ func OpSstore(pc *uint64, env EVMEnvironment, scope ScopeContext) ([]byte, error return nil, ErrWriteProtection } loc := scope.GetStack().Pop() - locHash := common.BytesToHash(loc.Bytes()) + locHash := BytesToHash(loc.Bytes()) newVal := scope.GetStack().Pop() - newValHash := common.BytesToHash(newVal.Bytes()) - oldValHash := env.GetState(scope.GetContract().Address(), common.Hash(loc.Bytes32())) + newValHash := BytesToHash(newVal.Bytes()) + oldValHash := env.GetState(scope.GetContract().Address(), Hash(loc.Bytes32())) // If the value is the same or if we are not going to commit, don't do anything to protected storage. if newValHash != oldValHash && env.IsCommitting() { - protectedStorage := fhevm_crypto.CreateProtectedStorageContractAddress(scope.GetContract().Address()) + protectedStorage := CreateProtectedStorageContractAddress(scope.GetContract().Address()) // Define flag location as keccak256(keccak256(loc)) in protected storage. Used to mark the location as containing a handle. // Note: We apply the hash function twice to make sure a flag location in protected storage cannot clash with a ciphertext // metadata location that is keccak256(keccak256(ciphertext)). Since a location is 32 bytes, it cannot clash with a well-formed // ciphertext. Therefore, there needs to be a hash collistion for a clash to happen. If hash is applied only once, there could // be a collision, since malicous users could store at loc = keccak256(ciphertext), making the flag clash with metadata. - flagHandleLocation := crypto.Keccak256Hash(crypto.Keccak256Hash(locHash[:]).Bytes()) + flagHandleLocation := Keccak256Hash(Keccak256Hash(locHash[:]).Bytes()) // Since the old value is no longer stored in actual contract storage, run garbage collection on protected storage. garbageCollectProtectedStorage(flagHandleLocation, oldValHash, protectedStorage, env) @@ -297,8 +294,8 @@ func OpSstore(pc *uint64, env EVMEnvironment, scope ScopeContext) ([]byte, error // If there are ciphertext handles in the arguments to a call, delegate them to the callee. // Return a map from ciphertext hash -> depthSet before delegation. -func DelegateCiphertextHandlesInArgs(env EVMEnvironment, args []byte) (verified map[common.Hash]*depthSet) { - verified = make(map[common.Hash]*depthSet) +func DelegateCiphertextHandlesInArgs(env EVMEnvironment, args []byte) (verified map[Hash]*depthSet) { + verified = make(map[Hash]*depthSet) for key, verifiedCiphertext := range env.FhevmData().verifiedCiphertexts { if contains(args, key.Bytes()) && isVerifiedAtCurrentDepth(env, verifiedCiphertext) { if env.IsCommitting() { @@ -314,7 +311,7 @@ func DelegateCiphertextHandlesInArgs(env EVMEnvironment, args []byte) (verified return } -func RestoreVerifiedDepths(env EVMEnvironment, verified map[common.Hash]*depthSet) { +func RestoreVerifiedDepths(env EVMEnvironment, verified map[Hash]*depthSet) { for k, v := range verified { env.FhevmData().verifiedCiphertexts[k].verifiedDepths = v } @@ -356,7 +353,7 @@ func OpReturn(pc *uint64, env EVMEnvironment, scope ScopeContext) []byte { func OpSelfdestruct(pc *uint64, env EVMEnvironment, scope ScopeContext) (beneficiary uint256.Int, balance *big.Int) { beneficiary = scope.GetStack().Pop() - protectedStorage := fhevm_crypto.CreateProtectedStorageContractAddress(scope.GetContract().Address()) + protectedStorage := CreateProtectedStorageContractAddress(scope.GetContract().Address()) balance = env.GetBalance(scope.GetContract().Address()) balance.Add(balance, env.GetBalance(protectedStorage)) env.AddBalance(beneficiary.Bytes20(), balance) diff --git a/fhevm/interface.go b/fhevm/interface.go index dd814ac..cd78312 100644 --- a/fhevm/interface.go +++ b/fhevm/interface.go @@ -3,19 +3,18 @@ package fhevm import ( "math/big" - "github.com/ethereum/go-ethereum/common" "github.com/holiman/uint256" ) type EVMEnvironment interface { // StateDB related functions - GetState(common.Address, common.Hash) common.Hash - SetState(common.Address, common.Hash, common.Hash) - GetNonce(common.Address) uint64 - AddBalance(common.Address, *big.Int) - GetBalance(common.Address) *big.Int + GetState(Address, Hash) Hash + SetState(Address, Hash, Hash) + GetNonce(Address) uint64 + AddBalance(Address, *big.Int) + GetBalance(Address) *big.Int - Suicide(common.Address) bool + Suicide(Address) bool // EVM call stack depth GetDepth() int @@ -28,8 +27,8 @@ type EVMEnvironment interface { IsEthCall() bool IsReadOnly() bool - CreateContract(caller common.Address, code []byte, gas uint64, value *big.Int, address common.Address) ([]byte, common.Address, uint64, error) - CreateContract2(caller common.Address, code []byte, codeHash common.Hash, gas uint64, value *big.Int, address common.Address) ([]byte, common.Address, uint64, error) + CreateContract(caller Address, code []byte, gas uint64, value *big.Int, address Address) ([]byte, Address, uint64, error) + CreateContract2(caller Address, code []byte, codeHash Hash, gas uint64, value *big.Int, address Address) ([]byte, Address, uint64, error) FhevmData() *FhevmData FhevmParams() *FhevmParams @@ -37,7 +36,7 @@ type EVMEnvironment interface { type FhevmData struct { // A map from a ciphertext hash to itself and stack depth at which it is verified - verifiedCiphertexts map[common.Hash]*verifiedCiphertext + verifiedCiphertexts map[Hash]*verifiedCiphertext // All optimistic requires encountered up to that point in the txn execution optimisticRequires []*tfheCiphertext @@ -47,7 +46,7 @@ type FhevmData struct { func NewFhevmData() FhevmData { return FhevmData{ - verifiedCiphertexts: make(map[common.Hash]*verifiedCiphertext), + verifiedCiphertexts: make(map[Hash]*verifiedCiphertext), optimisticRequires: make([]*tfheCiphertext, 0), } } diff --git a/fhevm/interpreter.go b/fhevm/interpreter.go index 6f5f7bb..252422c 100644 --- a/fhevm/interpreter.go +++ b/fhevm/interpreter.go @@ -1,7 +1,5 @@ package fhevm -import "github.com/ethereum/go-ethereum/common" - type ScopeContext interface { GetMemory() Memory GetStack() Stack @@ -50,13 +48,13 @@ type verifiedCiphertext struct { type PrivilegedMemory struct { // A map from a ciphertext hash to itself and stack depths at which it is verified - VerifiedCiphertexts map[common.Hash]*verifiedCiphertext + VerifiedCiphertexts map[Hash]*verifiedCiphertext // All optimistic requires encountered up to that point in the txn execution OptimisticRequires []*tfheCiphertext } var PrivilegedMempory *PrivilegedMemory = &PrivilegedMemory{ - make(map[common.Hash]*verifiedCiphertext), + make(map[Hash]*verifiedCiphertext), make([]*tfheCiphertext, 0), } diff --git a/fhevm/precompiles.go b/fhevm/precompiles.go index 7a9535a..37242f7 100644 --- a/fhevm/precompiles.go +++ b/fhevm/precompiles.go @@ -8,10 +8,7 @@ import ( "errors" "math/big" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" "github.com/holiman/uint256" - fhevm_crypto "github.com/zama-ai/fhevm-go/crypto" "golang.org/x/crypto/chacha20" "golang.org/x/crypto/nacl/box" ) @@ -21,7 +18,7 @@ import ( // contract. type PrecompiledContract interface { RequiredGas(environment *EVMEnvironment, input []byte) uint64 // RequiredGas calculates the contract gas use - Run(environment *EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) (ret []byte, err error) + Run(environment *EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) (ret []byte, err error) } var signatureFheAdd = makeKeccakSignature("fheAdd(uint256,uint256,bytes1)") @@ -154,7 +151,7 @@ func FheLibRequiredGas(environment EVMEnvironment, input []byte) uint64 { } } -func FheLibRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func FheLibRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() if len(input) < 4 { err := errors.New("input must contain at least 4 bytes for method signature") @@ -468,7 +465,7 @@ func fheNegRequiredGas(environment EVMEnvironment, input []byte) uint64 { logger.Error("fheNeg input needs to contain one 256-bit sized value", "input", hex.EncodeToString(input)) return 0 } - ct := getVerifiedCiphertext(environment, common.BytesToHash(input[0:32])) + ct := getVerifiedCiphertext(environment, BytesToHash(input[0:32])) if ct == nil { logger.Error("fheNeg input not verified", "input", hex.EncodeToString(input)) return 0 @@ -587,7 +584,7 @@ func reencryptRequiredGas(environment EVMEnvironment, input []byte) uint64 { logger.Error("reencrypt RequiredGas() input len must be 64 bytes", "input", hex.EncodeToString(input), "len", len(input)) return 0 } - ct := getVerifiedCiphertext(environment, common.BytesToHash(input[0:32])) + ct := getVerifiedCiphertext(environment, BytesToHash(input[0:32])) if ct == nil { logger.Error("reencrypt RequiredGas() input doesn't point to verified ciphertext", "input", hex.EncodeToString(input)) return 0 @@ -602,7 +599,7 @@ func optimisticRequireRequiredGas(environment EVMEnvironment, input []byte) uint "input", hex.EncodeToString(input), "len", len(input)) return 0 } - ct := getVerifiedCiphertext(environment, common.BytesToHash(input)) + ct := getVerifiedCiphertext(environment, BytesToHash(input)) if ct == nil { logger.Error("optimisticRequire RequiredGas() input doesn't point to verified ciphertext", "input", hex.EncodeToString(input)) @@ -635,7 +632,7 @@ func decryptRequiredGas(environment EVMEnvironment, input []byte) uint64 { logger.Error("decrypt RequiredGas() input len must be 32 bytes", "input", hex.EncodeToString(input), "len", len(input)) return 0 } - ct := getVerifiedCiphertext(environment, common.BytesToHash(input)) + ct := getVerifiedCiphertext(environment, BytesToHash(input)) if ct == nil { logger.Error("decrypt RequiredGas() input doesn't point to verified ciphertext", "input", hex.EncodeToString(input)) return 0 @@ -658,7 +655,7 @@ func trivialEncryptRequiredGas(environment EVMEnvironment, input []byte) uint64 } // Implementations -func fheAddRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheAddRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -720,7 +717,7 @@ func fheAddRun(environment EVMEnvironment, caller common.Address, addr common.Ad } } -func fheSubRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheSubRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -782,7 +779,7 @@ func fheSubRun(environment EVMEnvironment, caller common.Address, addr common.Ad } } -func fheMulRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheMulRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -844,7 +841,7 @@ func fheMulRun(environment EVMEnvironment, caller common.Address, addr common.Ad } } -func fheLeRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheLeRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -906,7 +903,7 @@ func fheLeRun(environment EVMEnvironment, caller common.Address, addr common.Add } } -func fheLtRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheLtRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -968,7 +965,7 @@ func fheLtRun(environment EVMEnvironment, caller common.Address, addr common.Add } } -func fheEqRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheEqRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -1030,7 +1027,7 @@ func fheEqRun(environment EVMEnvironment, caller common.Address, addr common.Add } } -func fheGeRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheGeRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -1092,7 +1089,7 @@ func fheGeRun(environment EVMEnvironment, caller common.Address, addr common.Add } } -func fheGtRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheGtRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -1154,7 +1151,7 @@ func fheGtRun(environment EVMEnvironment, caller common.Address, addr common.Add } } -func fheShlRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheShlRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -1216,7 +1213,7 @@ func fheShlRun(environment EVMEnvironment, caller common.Address, addr common.Ad } } -func fheShrRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheShrRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -1278,7 +1275,7 @@ func fheShrRun(environment EVMEnvironment, caller common.Address, addr common.Ad } } -func fheNeRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheNeRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -1340,7 +1337,7 @@ func fheNeRun(environment EVMEnvironment, caller common.Address, addr common.Add } } -func fheMinRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheMinRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -1402,7 +1399,7 @@ func fheMinRun(environment EVMEnvironment, caller common.Address, addr common.Ad } } -func fheMaxRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheMaxRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -1464,7 +1461,7 @@ func fheMaxRun(environment EVMEnvironment, caller common.Address, addr common.Ad } } -func fheNegRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheNegRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() if len(input) != 32 { @@ -1474,7 +1471,7 @@ func fheNegRun(environment EVMEnvironment, caller common.Address, addr common.Ad } - ct := getVerifiedCiphertext(environment, common.BytesToHash(input[0:32])) + ct := getVerifiedCiphertext(environment, BytesToHash(input[0:32])) if ct == nil { msg := "fheNeg input not verified" logger.Error(msg, msg, "input", hex.EncodeToString(input)) @@ -1498,7 +1495,7 @@ func fheNegRun(environment EVMEnvironment, caller common.Address, addr common.Ad return resultHash[:], nil } -func fheNotRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheNotRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() if len(input) != 32 { @@ -1508,7 +1505,7 @@ func fheNotRun(environment EVMEnvironment, caller common.Address, addr common.Ad } - ct := getVerifiedCiphertext(environment, common.BytesToHash(input[0:32])) + ct := getVerifiedCiphertext(environment, BytesToHash(input[0:32])) if ct == nil { msg := "fheNot input not verified" logger.Error(msg, msg, "input", hex.EncodeToString(input)) @@ -1532,7 +1529,7 @@ func fheNotRun(environment EVMEnvironment, caller common.Address, addr common.Ad return resultHash[:], nil } -func fheDivRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheDivRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -1570,7 +1567,7 @@ func fheDivRun(environment EVMEnvironment, caller common.Address, addr common.Ad } } -func fheRemRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheRemRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -1608,7 +1605,7 @@ func fheRemRun(environment EVMEnvironment, caller common.Address, addr common.Ad } } -func fheBitAndRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheBitAndRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -1652,7 +1649,7 @@ func fheBitAndRun(environment EVMEnvironment, caller common.Address, addr common return resultHash[:], nil } -func fheBitOrRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheBitOrRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -1696,7 +1693,7 @@ func fheBitOrRun(environment EVMEnvironment, caller common.Address, addr common. return resultHash[:], nil } -func fheBitXorRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheBitXorRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() isScalar, err := isScalarOp(input) @@ -1758,7 +1755,7 @@ func init() { } } -func fheRandRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fheRandRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() if environment.IsEthCall() { msg := "fheRand cannot be called via EthCall, because it needs to mutate internal state" @@ -1778,7 +1775,7 @@ func fheRandRun(environment EVMEnvironment, caller common.Address, addr common.A } // Get the RNG nonce. - protectedStorage := fhevm_crypto.CreateProtectedStorageContractAddress(caller) + protectedStorage := CreateProtectedStorageContractAddress(caller) currentRngNonceBytes := environment.GetState(protectedStorage, rngNonceKey).Bytes() // Increment the RNG nonce by 1. @@ -1787,10 +1784,10 @@ func fheRandRun(environment EVMEnvironment, caller common.Address, addr common.A environment.SetState(protectedStorage, rngNonceKey, nextRngNonce.Bytes32()) // Compute the seed and use it to create a new cipher. - hasher := crypto.NewKeccakState() + hasher := NewKeccakState() hasher.Write(globalRngSeed) hasher.Write(caller.Bytes()) - seed := common.Hash{} + seed := Hash{} _, err := hasher.Read(seed[:]) if err != nil { return nil, err @@ -1823,7 +1820,7 @@ func fheRandRun(environment EVMEnvironment, caller common.Address, addr common.A return ctHash[:], nil } -func verifyCiphertextRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func verifyCiphertextRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() if len(input) <= 1 { msg := "verifyCiphertext Run() input needs to contain a ciphertext and one byte for its type" @@ -1888,7 +1885,7 @@ func encryptToUserKey(value *big.Int, pubKey []byte) ([]byte, error) { return ct, nil } -func reencryptRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func reencryptRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() if !environment.IsEthCall() { msg := "reencrypt only supported on EthCall" @@ -1900,7 +1897,7 @@ func reencryptRun(environment EVMEnvironment, caller common.Address, addr common logger.Error(msg, "input", hex.EncodeToString(input), "len", len(input)) return nil, errors.New(msg) } - ct := getVerifiedCiphertext(environment, common.BytesToHash(input[0:32])) + ct := getVerifiedCiphertext(environment, BytesToHash(input[0:32])) if ct != nil { // Make sure we don't decrypt before any optimistic requires are checked. optReqResult, optReqErr := evaluateRemainingOptimisticRequires(environment) @@ -1928,14 +1925,14 @@ func reencryptRun(environment EVMEnvironment, caller common.Address, addr common return nil, errors.New(msg) } -func optimisticRequireRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func optimisticRequireRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() if len(input) != 32 { msg := "optimisticRequire input len must be 32 bytes" logger.Error(msg, "input", hex.EncodeToString(input), "len", len(input)) return nil, errors.New(msg) } - ct := getVerifiedCiphertext(environment, common.BytesToHash(input)) + ct := getVerifiedCiphertext(environment, BytesToHash(input)) if ct == nil { msg := "optimisticRequire unverified handle" logger.Error(msg, "input", hex.EncodeToString(input)) @@ -1954,14 +1951,14 @@ func optimisticRequireRun(environment EVMEnvironment, caller common.Address, add return nil, nil } -func decryptRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func decryptRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() if len(input) != 32 { msg := "decrypt input len must be 32 bytes" logger.Error(msg, "input", hex.EncodeToString(input), "len", len(input)) return nil, errors.New(msg) } - ct := getVerifiedCiphertext(environment, common.BytesToHash(input)) + ct := getVerifiedCiphertext(environment, BytesToHash(input)) if ct == nil { msg := "decrypt unverified handle" logger.Error(msg, "input", hex.EncodeToString(input)) @@ -2020,7 +2017,7 @@ func evaluateRemainingOptimisticRequires(environment EVMEnvironment) (bool, erro return true, nil } -func castRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func castRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() if len(input) != 33 { msg := "cast Run() input needs to contain a ciphertext and one byte for its type" @@ -2028,7 +2025,7 @@ func castRun(environment EVMEnvironment, caller common.Address, addr common.Addr return nil, errors.New(msg) } - ct := getVerifiedCiphertext(environment, common.BytesToHash(input[0:32])) + ct := getVerifiedCiphertext(environment, BytesToHash(input[0:32])) if ct == nil { logger.Error("cast input not verified") return nil, errors.New("unverified ciphertext handle") @@ -2064,10 +2061,10 @@ func castRun(environment EVMEnvironment, caller common.Address, addr common.Addr return resHash.Bytes(), nil } -var fhePubKeyHashPrecompile = common.BytesToAddress([]byte{93}) -var fhePubKeyHashSlot = common.Hash{} +var fhePubKeyHashPrecompile = BytesToAddress([]byte{93}) +var fhePubKeyHashSlot = Hash{} -func fhePubKeyRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func fhePubKeyRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { existing := environment.GetState(fhePubKeyHashPrecompile, fhePubKeyHashSlot) if existing != pksHash { msg := "fhePubKey FHE public key hash doesn't match one stored in state" @@ -2087,7 +2084,7 @@ func fhePubKeyRun(environment EVMEnvironment, caller common.Address, addr common } } -func trivialEncryptRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func trivialEncryptRun(environment EVMEnvironment, caller Address, addr Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() if len(input) != 33 { msg := "trivialEncrypt input len must be 33 bytes" diff --git a/fhevm/tfhe.go b/fhevm/tfhe.go index dd08e05..6f827f3 100644 --- a/fhevm/tfhe.go +++ b/fhevm/tfhe.go @@ -1505,9 +1505,6 @@ import ( "os" "path" "unsafe" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" ) func toBufferView(in []byte) C.BufferView { @@ -1531,7 +1528,7 @@ var cks unsafe.Pointer // public key var pks unsafe.Pointer -var pksHash common.Hash +var pksHash Hash // Generate keys for the fhevm (sks, cks, psk) func generateFhevmKeys() (unsafe.Pointer, unsafe.Pointer, unsafe.Pointer) { @@ -1584,7 +1581,7 @@ func InitGlobalKeysFromFiles(keysDir string) error { sks = C.deserialize_server_key(toBufferView(sksBytes)) - pksHash = crypto.Keccak256Hash(pksBytes) + pksHash = Keccak256Hash(pksBytes) pks = C.deserialize_compact_public_key(toBufferView(pksBytes)) cks = C.deserialize_client_key(toBufferView(cksBytes)) @@ -1655,7 +1652,7 @@ const ( // Represents an expanded TFHE ciphertext. type tfheCiphertext struct { serialization []byte - hash *common.Hash + hash *Hash fheUintType FheUintType } @@ -2626,11 +2623,11 @@ func (ct *tfheCiphertext) decrypt() (big.Int, error) { } func (ct *tfheCiphertext) computeHash() { - hash := common.BytesToHash(crypto.Keccak256(ct.serialization)) + hash := BytesToHash(Keccak256(ct.serialization)) ct.hash = &hash } -func (ct *tfheCiphertext) getHash() common.Hash { +func (ct *tfheCiphertext) getHash() Hash { if ct.hash != nil { return *ct.hash } diff --git a/fhevm/utils.go b/fhevm/utils.go new file mode 100644 index 0000000..a1f0eb0 --- /dev/null +++ b/fhevm/utils.go @@ -0,0 +1,149 @@ +package fhevm + +import ( + "encoding/hex" + "hash" + + "github.com/ethereum/go-ethereum/rlp" + "golang.org/x/crypto/sha3" +) + +// Lengths of hashes and addresses in bytes. +const ( + // HashLength is the expected length of the hash + HashLength = 32 + // AddressLength is the expected length of the address + AddressLength = 20 +) + +// Address represents the 20 byte address of an Ethereum account. +type Address [AddressLength]byte + +// Hash represents the 32 byte Keccak256 hash of arbitrary data. +type Hash [HashLength]byte + +// KeccakState wraps sha3.state. In addition to the usual hash methods, it also supports +// Read to get a variable amount of data from the hash state. Read is faster than Sum +// because it doesn't copy the internal state, but also modifies the internal state. +type KeccakState interface { + hash.Hash + Read([]byte) (int, error) +} + +// Bytes gets the byte representation of the underlying hash. +func (h Hash) Bytes() []byte { return h[:] } + +// Hex converts a hash to a hex string. +func (h Hash) Hex() string { return HexEncode(h[:]) } + +// HexEncode encodes b as a hex string with 0x prefix. +func HexEncode(b []byte) string { + enc := make([]byte, len(b)*2+2) + copy(enc, "0x") + hex.Encode(enc[2:], b) + return string(enc) +} + +// SetBytes sets the hash to the value of b. +// If b is larger than len(h), b will be cropped from the left. +func (h *Hash) SetBytes(b []byte) { + if len(b) > len(h) { + b = b[len(b)-HashLength:] + } + + copy(h[HashLength-len(b):], b) +} + +func BytesToAddress(b []byte) Address { + var a Address + a.SetBytes(b) + return a +} + +// BytesToHash sets b to hash. +// If b is larger than len(h), b will be cropped from the left. +func BytesToHash(b []byte) Hash { + var h Hash + h.SetBytes(b) + return h +} + +func (a *Address) SetBytes(b []byte) { + if len(b) > len(a) { + b = b[len(b)-AddressLength:] + } + copy(a[AddressLength-len(b):], b) +} + +func (a Address) Bytes() []byte { return a[:] } + +func CreateAddress(b Address, nonce uint64) Address { + data, _ := rlp.EncodeToBytes([]interface{}{b, nonce}) + return BytesToAddress(Keccak256(data)[12:]) +} + +// CreateAddress2 creates an ethereum address given the address bytes, initial +// contract code hash and a salt. +func CreateAddress2(b Address, salt [32]byte, inithash []byte) Address { + return BytesToAddress(Keccak256([]byte{0xff}, b.Bytes(), salt[:], inithash)[12:]) +} + +// CreateProtectedStorageAddress creates an ethereum contract address for protected storage +// given the corresponding contract address +func CreateProtectedStorageContractAddress(b Address) Address { + return CreateAddress(b, 0) +} + +// NewKeccakState creates a new KeccakState +func NewKeccakState() KeccakState { + return sha3.NewLegacyKeccak256().(KeccakState) +} + +// Keccak256 calculates and returns the Keccak256 hash of the input data. +func Keccak256(data ...[]byte) []byte { + b := make([]byte, 32) + d := NewKeccakState() + for _, b := range data { + d.Write(b) + } + d.Read(b) + return b +} + +// Keccak256Hash calculates and returns the Keccak256 hash of the input data, +// converting it to an internal Hash data structure. +func Keccak256Hash(data ...[]byte) (h Hash) { + d := NewKeccakState() + for _, b := range data { + d.Write(b) + } + d.Read(h[:]) + return h +} + +// HexToHash sets byte representation of s to hash. +// If b is larger than len(h), b will be cropped from the left. +func HexToHash(s string) Hash { return BytesToHash(FromHex(s)) } + +// FromHex returns the bytes represented by the hexadecimal string s. +// s may be prefixed with "0x". +func FromHex(s string) []byte { + if has0xPrefix(s) { + s = s[2:] + } + if len(s)%2 == 1 { + s = "0" + s + } + return Hex2Bytes(s) +} + +// has0xPrefix validates str begins with '0x' or '0X'. +func has0xPrefix(str string) bool { + return len(str) >= 2 && str[0] == '0' && (str[1] == 'x' || str[1] == 'X') +} + +// Hex2Bytes returns the bytes represented by the hexadecimal string str. +func Hex2Bytes(str string) []byte { + h, _ := hex.DecodeString(str) + return h +}