Skip to content
This repository has been archived by the owner on Dec 23, 2024. It is now read-only.

feat: array equality as a precompile #113

Merged
merged 1 commit into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
769 changes: 769 additions & 0 deletions fhevm/contracts_test.go

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions fhevm/fhelib.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ var fhelibMethods = []*FheLibMethod{
requiredGasFunction: fheIfThenElseRequiredGas,
runFunction: fheIfThenElseRun,
},
{
name: "fheArrayEq",
argTypes: "(uint256[],uint256[])",
requiredGasFunction: fheArrayEqRequiredGas,
runFunction: fheArrayEqRun,
},
{
name: "fhePubKey",
argTypes: "(bytes1)",
Expand Down
109 changes: 109 additions & 0 deletions fhevm/operators_comparison.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ package fhevm
import (
"encoding/hex"
"errors"
"fmt"
"math/big"
"strings"

"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
"github.com/zama-ai/fhevm-go/fhevm/tfhe"
"go.opentelemetry.io/otel/trace"
Expand Down Expand Up @@ -570,3 +574,108 @@ func fheIfThenElseRun(environment EVMEnvironment, caller common.Address, addr co
logger.Info("fheIfThenElse success", "first", first.hash().Hex(), "second", second.hash().Hex(), "third", third.hash().Hex(), "result", resultHash.Hex())
return resultHash[:], nil
}

// TODO: implement as part of fhelibMethods.
const fheArrayEqAbiJson = `
[
{
"name": "fheArrayEq",
"type": "function",
"inputs": [
{
"name": "lhs",
"type": "uint256[]"
},
{
"name": "rhs",
"type": "uint256[]"
}
],
"outputs": [
{
"name": "",
"type": "uint256"
}
]
}
]
`

var arrayEqMethod abi.Method

func init() {
reader := strings.NewReader(fheArrayEqAbiJson)
arrayEqAbi, err := abi.JSON(reader)
if err != nil {
panic(err)
}

var ok bool
arrayEqMethod, ok = arrayEqAbi.Methods["fheArrayEq"]
if !ok {
panic("couldn't find the fheArrayEq method")
}
}

func getVerifiedCiphertexts(environment EVMEnvironment, unpacked interface{}) ([]*tfhe.TfheCiphertext, error) {
big, ok := unpacked.([]*big.Int)
if !ok {
return nil, fmt.Errorf("fheArrayEq failed to cast to []*big.Int")
}
ret := make([]*tfhe.TfheCiphertext, 0, len(big))
for _, b := range big {
ct := getVerifiedCiphertext(environment, common.BigToHash(b))
if ct == nil {
return nil, fmt.Errorf("fheArrayEq unverified ciphertext")
}
ret = append(ret, ct.ciphertext)
}
return ret, nil
}

func fheArrayEqRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) {
logger := environment.GetLogger()

unpacked, err := arrayEqMethod.Inputs.UnpackValues(input)
if err != nil {
msg := "fheArrayEqRun failed to unpack input"
logger.Error(msg, "err", err)
return nil, err
}

if len(unpacked) != 2 {
err := fmt.Errorf("fheArrayEqRun unexpected unpacked len: %d", len(unpacked))
logger.Error(err.Error())
return nil, err
}

lhs, err := getVerifiedCiphertexts(environment, unpacked[0])
if err != nil {
msg := "fheArrayEqRun failed to get lhs to verified ciphertexts"
logger.Error(msg, "err", err)
return nil, err
}

rhs, err := getVerifiedCiphertexts(environment, unpacked[1])
if err != nil {
msg := "fheArrayEqRun failed to get rhs to verified ciphertexts"
logger.Error(msg, "err", err)
return nil, err
}

// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !environment.IsCommitting() && !environment.IsEthCall() {
return importRandomCiphertext(environment, tfhe.FheBool), nil
}

result, err := tfhe.EqArray(lhs, rhs)
if err != nil {
msg := "fheArrayEqRun failed to execute"
logger.Error(msg, "err", err)
return nil, err
}
importCiphertext(environment, result)
resultHash := result.GetHash()
logger.Info("fheArrayEqRun success", "result", resultHash.Hex())
return resultHash[:], nil
}
65 changes: 65 additions & 0 deletions fhevm/operators_comparison_gas.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package fhevm

import (
"encoding/hex"
"fmt"

"github.com/zama-ai/fhevm-go/fhevm/tfhe"
)
Expand Down Expand Up @@ -141,3 +142,67 @@ func fheIfThenElseRequiredGas(environment EVMEnvironment, input []byte) uint64 {
}
return environment.FhevmParams().GasCosts.FheIfThenElse[second.fheUintType()]
}

func fheArrayEqRequiredGas(environment EVMEnvironment, input []byte) uint64 {
logger := environment.GetLogger()

unpacked, err := arrayEqMethod.Inputs.UnpackValues(input)
if err != nil {
msg := "fheArrayEqRun RequiredGas() failed to unpack input"
logger.Error(msg, "err", err)
return 0
}

if len(unpacked) != 2 {
err := fmt.Errorf("fheArrayEqRun RequiredGas() unexpected unpacked len: %d", len(unpacked))
logger.Error(err.Error())
return 0
}

lhs, err := getVerifiedCiphertexts(environment, unpacked[0])
if err != nil {
msg := "fheArrayEqRun RequiredGas() failed to get lhs to verified ciphertexts"
logger.Error(msg, "err", err)
return 0
}

rhs, err := getVerifiedCiphertexts(environment, unpacked[1])
if err != nil {
msg := "fheArrayEqRun RequiredGas() failed to get rhs to verified ciphertexts"
logger.Error(msg, "err", err)
return 0
}

if len(lhs) != len(rhs) || (len(lhs) == 0 && len(rhs) == 0) {
return environment.FhevmParams().GasCosts.FheTrivialEncrypt[tfhe.FheBool]
}

numElements := len(lhs)
elementType := lhs[0].Type()
// TODO: tie to supported types in tfhe.TfheCiphertext.EqArray()
if elementType != tfhe.FheUint4 && elementType != tfhe.FheUint8 && elementType != tfhe.FheUint16 && elementType != tfhe.FheUint32 && elementType != tfhe.FheUint64 {
return 0
}
for i := range lhs {
if lhs[i].Type() != elementType || rhs[i].Type() != elementType {
return 0
}
}

numBits := elementType.NumBits() * uint(numElements)
if numBits <= 4 {
return environment.FhevmParams().GasCosts.FheEq[tfhe.FheUint4]
} else if numBits <= 8 {
return environment.FhevmParams().GasCosts.FheEq[tfhe.FheUint8]
} else if numBits <= 16 {
return environment.FhevmParams().GasCosts.FheEq[tfhe.FheUint16]
} else if numBits <= 32 {
return environment.FhevmParams().GasCosts.FheEq[tfhe.FheUint32]
} else if numBits <= 64 {
return environment.FhevmParams().GasCosts.FheEq[tfhe.FheUint64]
} else if numBits <= 160 {
immortal-tofu marked this conversation as resolved.
Show resolved Hide resolved
return environment.FhevmParams().GasCosts.FheEq[tfhe.FheUint160]
} else {
return (environment.FhevmParams().GasCosts.FheEq[tfhe.FheUint160] + environment.FhevmParams().GasCosts.FheArrayEqBigArrayFactor) * (uint64(numBits) / 160)
}
}
4 changes: 4 additions & 0 deletions fhevm/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type GasCosts struct {
FheShift map[tfhe.FheUintType]uint64
FheScalarShift map[tfhe.FheUintType]uint64
FheEq map[tfhe.FheUintType]uint64
FheArrayEqBigArrayFactor uint64 // TODO: either rename or come up with a better solution
FheLe map[tfhe.FheUintType]uint64
FheMinMax map[tfhe.FheUintType]uint64
FheScalarMinMax map[tfhe.FheUintType]uint64
Expand All @@ -60,6 +61,8 @@ type GasCosts struct {

func DefaultGasCosts() GasCosts {
return GasCosts{
FheCast: 200,
FhePubKey: 50,
FheAddSub: map[tfhe.FheUintType]uint64{
tfhe.FheUint4: 55000 + AdjustFHEGas,
tfhe.FheUint8: 84000 + AdjustFHEGas,
Expand Down Expand Up @@ -132,6 +135,7 @@ func DefaultGasCosts() GasCosts {
tfhe.FheUint64: 76000 + AdjustFHEGas,
tfhe.FheUint160: 80000 + AdjustFHEGas,
},
FheArrayEqBigArrayFactor: 1000,
FheLe: map[tfhe.FheUintType]uint64{
tfhe.FheUint4: 60000 + AdjustFHEGas,
tfhe.FheUint8: 72000 + AdjustFHEGas,
Expand Down
Loading
Loading