Skip to content
This repository has been archived by the owner on Dec 23, 2024. It is now read-only.

Commit

Permalink
Merge pull request #49 from zama-ai/features/add-cmux
Browse files Browse the repository at this point in the history
feat() add ifThenElse to the stack
  • Loading branch information
immortal-tofu authored Jan 4, 2024
2 parents 8dd0820 + 062abc6 commit 43e4919
Show file tree
Hide file tree
Showing 6 changed files with 396 additions and 0 deletions.
114 changes: 114 additions & 0 deletions fhevm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ func toPrecompileInput(isScalar bool, hashes ...common.Hash) []byte {
return ret
}

func toPrecompileInputNoScalar(isScalar bool, hashes ...common.Hash) []byte {
ret := make([]byte, 0)
for _, hash := range hashes {
ret = append(ret, hash.Bytes()...)
}
return ret
}

var scalarBytePadding = make([]byte, 31)

func toLibPrecompileInput(method string, isScalar bool, hashes ...common.Hash) []byte {
Expand Down Expand Up @@ -1249,6 +1257,44 @@ func FheLibRandBounded(t *testing.T, fheUintType FheUintType, upperBound64 uint6
}
}

func FheLibIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) {
var second, third uint64
switch fheUintType {
case FheUint8:
second = 2
third = 1
case FheUint16:
second = 4283
third = 1337
case FheUint32:
second = 1333337
third = 133337
}
signature := "fheIfThenElse(uint256,uint256,uint256)"
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth
addr := common.Address{}
readOnly := false
firstHash := verifyCiphertextInTestMemory(environment, condition, depth, FheUint8).getHash()
secondHash := verifyCiphertextInTestMemory(environment, second, depth, fheUintType).getHash()
thirdHash := verifyCiphertextInTestMemory(environment, third, depth, fheUintType).getHash()
input := toLibPrecompileInputNoScalar(signature, firstHash, secondHash, thirdHash)
out, err := FheLibRun(environment, addr, addr, input, readOnly)
if err != nil {
t.Fatalf("VALUE %v", len(input))
// t.Fatalf(err.Error())
}
res := getVerifiedCiphertextFromEVM(environment, common.BytesToHash(out))
if res == nil {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted, err := res.ciphertext.decrypt()
if err != nil || condition == 1 && decrypted.Uint64() != second || condition == 0 && decrypted.Uint64() != third {
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1)
}
}

func LibTrivialEncrypt(t *testing.T, fheUintType FheUintType) {
var value big.Int
switch fheUintType {
Expand Down Expand Up @@ -2352,6 +2398,44 @@ func FheNot(t *testing.T, fheUintType FheUintType, scalar bool) {
}
}


func FheIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) {
var lhs, rhs uint64
switch fheUintType {
case FheUint8:
lhs = 2
rhs = 1
case FheUint16:
lhs = 4283
rhs = 1337
case FheUint32:
lhs = 1333337
rhs = 133337
}
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth
addr := common.Address{}
readOnly := false
conditionHash := verifyCiphertextInTestMemory(environment, condition, depth, fheUintType).getHash()
lhsHash := verifyCiphertextInTestMemory(environment, lhs, depth, fheUintType).getHash()
rhsHash := verifyCiphertextInTestMemory(environment, rhs, depth, fheUintType).getHash()

input1 := toPrecompileInputNoScalar(false, conditionHash, lhsHash, rhsHash)
out, err := fheIfThenElseRun(environment, addr, addr, input1, readOnly)
if err != nil {
t.Fatalf(err.Error())
}
res := getVerifiedCiphertextFromEVM(environment, common.BytesToHash(out))
if res == nil {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted, err := res.ciphertext.decrypt()
if err != nil || condition == 1 && decrypted.Uint64() != lhs || condition == 0 && decrypted.Uint64() != rhs {
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0)
}
}

func Decrypt(t *testing.T, fheUintType FheUintType) {
var value uint64
switch fheUintType {
Expand Down Expand Up @@ -2627,6 +2711,21 @@ func TestFheLibRandBounded32(t *testing.T) {
FheLibRandBounded(t, FheUint32, 32)
}

func TestFheLibIfThenElse8(t *testing.T) {
FheLibIfThenElse(t, FheUint8, 1)
FheLibIfThenElse(t, FheUint8, 0)
}

func TestFheLibIfThenElse16(t *testing.T) {
FheLibIfThenElse(t, FheUint16, 1)
FheLibIfThenElse(t, FheUint16, 0)
}

func TestFheLibIfThenElse32(t *testing.T) {
FheLibIfThenElse(t, FheUint32, 1)
FheLibIfThenElse(t, FheUint32, 0)
}

func TestFheLibTrivialEncrypt8(t *testing.T) {
LibTrivialEncrypt(t, FheUint8)
}
Expand Down Expand Up @@ -3079,6 +3178,21 @@ func TestFheNot32(t *testing.T) {
FheNot(t, FheUint32, false)
}

func TestFheIfThenElse8(t *testing.T) {
FheIfThenElse(t, FheUint8, 1)
FheIfThenElse(t, FheUint8, 0)
}

func TestFheIfThenElse16(t *testing.T) {
FheIfThenElse(t, FheUint16, 1)
FheIfThenElse(t, FheUint16, 0)
}

func TestFheIfThenElse32(t *testing.T) {
FheIfThenElse(t, FheUint32, 1)
FheIfThenElse(t, FheUint32, 0)
}

func TestFheScalarMax8(t *testing.T) {
FheMax(t, FheUint8, true)
}
Expand Down
20 changes: 20 additions & 0 deletions fhevm/evm.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,26 @@ func get2VerifiedOperands(environment EVMEnvironment, input []byte) (lhs *verifi
return
}

func get3VerifiedOperands(environment EVMEnvironment, input []byte) (first *verifiedCiphertext, second *verifiedCiphertext, third *verifiedCiphertext, err error) {
if len(input) != 96 {
return nil, nil, nil, errors.New("input needs to contain three 256-bit sized values")
}
first = getVerifiedCiphertext(environment, common.BytesToHash(input[0:32]))
if first == nil {
return nil, nil, nil, errors.New("unverified ciphertext handle")
}
second = getVerifiedCiphertext(environment, common.BytesToHash(input[32:64]))
if second == nil {
return nil, nil, nil, errors.New("unverified ciphertext handle")
}
third = getVerifiedCiphertext(environment, common.BytesToHash(input[64:96]))
if third == nil {
return nil, nil, nil, errors.New("unverified ciphertext handle")
}
err = nil
return
}

func getScalarOperands(environment EVMEnvironment, input []byte) (lhs *verifiedCiphertext, rhs *big.Int, err error) {
if len(input) != 65 {
return nil, nil, errors.New("input needs to contain two 256-bit sized values and 1 8-bit value")
Expand Down
6 changes: 6 additions & 0 deletions fhevm/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type GasCosts struct {
FheReencrypt map[FheUintType]uint64
FheTrivialEncrypt map[FheUintType]uint64
FheRand map[FheUintType]uint64
FheIfThenElse map[FheUintType]uint64
FheVerify map[FheUintType]uint64
FheOptRequire map[FheUintType]uint64
FheOptRequireBitAnd map[FheUintType]uint64
Expand Down Expand Up @@ -150,6 +151,11 @@ func DefaultGasCosts() GasCosts {
FheUint16: EvmNetSstoreInitGas + 2000,
FheUint32: EvmNetSstoreInitGas + 3000,
},
FheIfThenElse: map[FheUintType]uint64{
FheUint8: 61000,
FheUint16: 83000,
FheUint32: 109000,
},
// TODO: As of now, only support FheUint8. All optimistic require predicates are
// downcast to FheUint8 at the solidity level. Eventually move to ebool.
// If there is at least one optimistic require, we need to decrypt it as it was a normal FHE require.
Expand Down
57 changes: 57 additions & 0 deletions fhevm/precompiles.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ var signatureFheBitOr = makeKeccakSignature("fheBitOr(uint256,uint256,bytes1)")
var signatureFheBitXor = makeKeccakSignature("fheBitXor(uint256,uint256,bytes1)")
var signatureFheRand = makeKeccakSignature("fheRand(bytes1)")
var signatureFheRandBounded = makeKeccakSignature("fheRandBounded(uint256,bytes1)")
var signatureFheIfThenElse = makeKeccakSignature("fheIfThenElse(uint256,uint256,uint256)")
var signatureVerifyCiphertext = makeKeccakSignature("verifyCiphertext(bytes)")
var signatureReencrypt = makeKeccakSignature("reencrypt(uint256,uint256)")
var signatureOptimisticRequire = makeKeccakSignature("optimisticRequire(uint256)")
Expand Down Expand Up @@ -149,6 +150,9 @@ func FheLibRequiredGas(environment EVMEnvironment, input []byte) uint64 {
case signatureFheRandBounded:
bwCompatBytes := input[4:minInt(37, len(input))]
return fheRandBoundedRequiredGas(environment, bwCompatBytes)
case signatureFheIfThenElse:
bwCompatBytes := input[4:minInt(100, len(input))]
return fheIfThenElseRequiredGas(environment, bwCompatBytes)
case signatureVerifyCiphertext:
bwCompatBytes := input[4:]
return verifyCiphertextRequiredGas(environment, bwCompatBytes)
Expand Down Expand Up @@ -261,6 +265,9 @@ func FheLibRun(environment EVMEnvironment, caller common.Address, addr common.Ad
case signatureFheRandBounded:
bwCompatBytes := input[4:minInt(37, len(input))]
return fheRandBoundedRun(environment, caller, addr, bwCompatBytes, readOnly)
case signatureFheIfThenElse:
bwCompatBytes := input[4:minInt(100, len(input))]
return fheIfThenElseRun(environment, caller, addr, bwCompatBytes, readOnly)
case signatureVerifyCiphertext:
// first 32 bytes of the payload is offset, then 32 bytes are size of byte array
if len(input) <= 68 {
Expand Down Expand Up @@ -613,6 +620,24 @@ func fheRandBoundedRequiredGas(environment EVMEnvironment, input []byte) uint64
return environment.FhevmParams().GasCosts.FheRand[randType]
}

func fheIfThenElseRequiredGas(environment EVMEnvironment, input []byte) uint64 {
logger := environment.GetLogger()
first, second, third, err := get3VerifiedOperands(environment, input)
if err != nil {
logger.Error("IfThenElse op RequiredGas() inputs not verified", "err", err, "input", hex.EncodeToString(input))
return 0
}
if first.ciphertext.fheUintType != FheUint8 {
logger.Error("IfThenElse op RequiredGas() invalid type for condition", "first", first.ciphertext.fheUintType)
return 0
}
if second.ciphertext.fheUintType != third.ciphertext.fheUintType {
logger.Error("IfThenElse op RequiredGas() operand type mismatch", "second", second.ciphertext.fheUintType, "third", third.ciphertext.fheUintType)
return 0
}
return environment.FhevmParams().GasCosts.FheIfThenElse[second.ciphertext.fheUintType]
}

func verifyCiphertextRequiredGas(environment EVMEnvironment, input []byte) uint64 {
if len(input) <= 1 {
environment.GetLogger().Error(
Expand Down Expand Up @@ -1930,6 +1955,38 @@ func fheRandBoundedRun(environment EVMEnvironment, caller common.Address, addr c
return generateRandom(environment, caller, randType, &bound64)
}


func fheIfThenElseRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
logger := environment.GetLogger()
first, second, third, err := get3VerifiedOperands(environment, input)
if err != nil {
logger.Error("fheIfThenElse inputs not verified", "err", err, "input", hex.EncodeToString(input))
return nil, err
}

if second.ciphertext.fheUintType != third.ciphertext.fheUintType {
msg := "fheIfThenElse operand type mismatch"
logger.Error(msg, "second", second.ciphertext.fheUintType, "third", third.ciphertext.fheUintType)
return nil, errors.New(msg)
}

// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !environment.IsCommitting() && !environment.IsEthCall() {
return importRandomCiphertext(environment, second.ciphertext.fheUintType), nil
}

result, err := first.ciphertext.ifThenElse(second.ciphertext, third.ciphertext)
if err != nil {
logger.Error("fheIfThenElse failed", "err", err)
return nil, err
}
importCiphertext(environment, result)

resultHash := result.getHash()
logger.Info("fheIfThenElse success", "first", first.ciphertext.getHash().Hex(), "second", second.ciphertext.getHash().Hex(), "third", third.ciphertext.getHash().Hex(), "result", resultHash.Hex())
return resultHash[:], nil
}

func verifyCiphertextRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
logger := environment.GetLogger()
if len(input) <= 1 {
Expand Down
Loading

0 comments on commit 43e4919

Please sign in to comment.