From 06e9cdfbe528af01d254398a6d7934eec3654146 Mon Sep 17 00:00:00 2001 From: Petar Ivanov <29689712+dartdart26@users.noreply.github.com> Date: Thu, 23 May 2024 11:17:28 +0200 Subject: [PATCH] feat: array equality as a precompile MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Uses geth's abi package to parse uint256[] Solidity inputs. Add missing gas costs for FheCast and FhePubKey. Co-authored-by: Clément Danjou --- fhevm/contracts_test.go | 769 ++++++++++++++++++++++++++++++ fhevm/fhelib.go | 6 + fhevm/operators_comparison.go | 109 +++++ fhevm/operators_comparison_gas.go | 65 +++ fhevm/params.go | 4 + fhevm/tfhe/tfhe_ciphertext.go | 182 +++++-- fhevm/tfhe/tfhe_test.go | 238 +++++++++ fhevm/tfhe/tfhe_wrappers.c | 55 +++ fhevm/tfhe/tfhe_wrappers.h | 10 + tfhe-rs | 2 +- 10 files changed, 1407 insertions(+), 33 deletions(-) diff --git a/fhevm/contracts_test.go b/fhevm/contracts_test.go index 28b602b..60f5f39 100644 --- a/fhevm/contracts_test.go +++ b/fhevm/contracts_test.go @@ -2942,6 +2942,555 @@ func FheRand(t *testing.T, fheUintType tfhe.FheUintType) { } } +func FheArrayEq(t *testing.T, fheUintType tfhe.FheUintType) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big() + rhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + out, err := fheArrayEqRun(environment, addr, addr, input, readOnly, nil) + if err != nil { + t.Fatalf(err.Error()) + } else if len(out) != 32 { + t.Fatalf("fheArrayEq expected output len of 32, got %v", len(out)) + } + + if len(environment.FhevmData().verifiedCiphertexts) != 7 { + t.Fatalf("fheArrayEq expected 7 verified ciphertext") + } + + hash := common.BytesToHash(out) + decrypted, err := environment.FhevmData().verifiedCiphertexts[hash].ciphertext.Decrypt() + if err != nil { + t.Fatalf(err.Error()) + } + if !decrypted.IsUint64() || decrypted.Uint64() != 1 { + t.Fatalf("fheArrayEq expected result of 1, got: %s", decrypted.String()) + } +} + +func FheArrayEqGas(t *testing.T, fheUintType tfhe.FheUintType) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big() + rhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + gas := fheArrayEqRequiredGas(environment, input) + numBits := 3 * fheUintType.NumBits() + if numBits <= 4 && gas != environment.fhevmParams.GasCosts.FheEq[tfhe.FheUint4] { + t.Fatalf("fheArrayEq unexpected gas value") + } else if numBits > 4 && numBits <= 8 && gas != environment.fhevmParams.GasCosts.FheEq[tfhe.FheUint8] { + t.Fatalf("fheArrayEq unexpected gas value") + } else if numBits > 8 && numBits <= 16 && gas != environment.fhevmParams.GasCosts.FheEq[tfhe.FheUint16] { + t.Fatalf("fheArrayEq unexpected gas value") + } else if numBits > 16 && numBits <= 32 && gas != environment.fhevmParams.GasCosts.FheEq[tfhe.FheUint32] { + t.Fatalf("fheArrayEq unexpected gas value") + } else if numBits > 32 && numBits <= 64 && gas != environment.fhevmParams.GasCosts.FheEq[tfhe.FheUint64] { + t.Fatalf("fheArrayEq unexpected gas value") + } else if numBits > 64 && numBits <= 160 && gas != environment.fhevmParams.GasCosts.FheEq[tfhe.FheUint160] { + t.Fatalf("fheArrayEq unexpected gas value") + } else if numBits > 160 && gas <= environment.fhevmParams.GasCosts.FheEq[tfhe.FheUint160] { + t.Fatalf("fheArrayEq unexpected gas value") + } +} + +func FheArrayEqSameLenNotEqual(t *testing.T, fheUintType tfhe.FheUintType) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 6, depth, fheUintType).GetHash().Big() + rhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + out, err := fheArrayEqRun(environment, addr, addr, input, readOnly, nil) + if err != nil { + t.Fatalf(err.Error()) + } else if len(out) != 32 { + t.Fatalf("fheArrayEq expected output len of 32, got %v", len(out)) + } + + if len(environment.FhevmData().verifiedCiphertexts) != 7 { + t.Fatalf("fheArrayEq expected 7 verified ciphertext") + } + + hash := common.BytesToHash(out) + decrypted, err := environment.FhevmData().verifiedCiphertexts[hash].ciphertext.Decrypt() + if err != nil { + t.Fatalf(err.Error()) + } + if !decrypted.IsUint64() || decrypted.Uint64() != 0 { + t.Fatalf("fheArrayEq expected result of 0, got: %s", decrypted.String()) + } +} + +func FheArrayEqDifferentLen(t *testing.T, fheUintType tfhe.FheUintType) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + + lhs := make([]*big.Int, 2) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 6, depth, fheUintType).GetHash().Big() + rhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + out, err := fheArrayEqRun(environment, addr, addr, input, readOnly, nil) + if err != nil { + t.Fatalf(err.Error()) + } else if len(out) != 32 { + t.Fatalf("fheArrayEq expected output len of 32, got %v", len(out)) + } + + if len(environment.FhevmData().verifiedCiphertexts) != 6 { + t.Fatalf("fheArrayEq expected 6 verified ciphertext") + } + + hash := common.BytesToHash(out) + decrypted, err := environment.FhevmData().verifiedCiphertexts[hash].ciphertext.Decrypt() + if err != nil { + t.Fatalf(err.Error()) + } + if !decrypted.IsUint64() || decrypted.Uint64() != 0 { + t.Fatalf("fheArrayEq expected result of 0, got: %s", decrypted.String()) + } +} + +func FheArrayEqDifferentLenGas(t *testing.T, fheUintType tfhe.FheUintType) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + + lhs := make([]*big.Int, 2) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big() + rhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + gas := fheArrayEqRequiredGas(environment, input) + if gas != environment.fhevmParams.GasCosts.FheTrivialEncrypt[tfhe.FheBool] { + t.Fatalf("fheArrayEq expected trivial encryption of bool gas value") + } +} + +func FheArrayEqBothEmpty(t *testing.T, fheUintType tfhe.FheUintType) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + + lhs := make([]*big.Int, 0) + rhs := make([]*big.Int, 0) + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + out, err := fheArrayEqRun(environment, addr, addr, input, readOnly, nil) + if err != nil { + t.Fatalf(err.Error()) + } else if len(out) != 32 { + t.Fatalf("fheArrayEq expected output len of 32, got %v", len(out)) + } + + if len(environment.FhevmData().verifiedCiphertexts) != 1 { + t.Fatalf("fheArrayEq expected 1 verified ciphertext") + } + + hash := common.BytesToHash(out) + decrypted, err := environment.FhevmData().verifiedCiphertexts[hash].ciphertext.Decrypt() + if err != nil { + t.Fatalf(err.Error()) + } + if !decrypted.IsUint64() || decrypted.Uint64() != 1 { + t.Fatalf("fheArrayEq expected result of 1, got: %s", decrypted.String()) + } +} + +func FheArrayEqBothEmptyGas(t *testing.T, fheUintType tfhe.FheUintType) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + + lhs := make([]*big.Int, 0) + rhs := make([]*big.Int, 0) + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + gas := fheArrayEqRequiredGas(environment, input) + if gas != environment.fhevmParams.GasCosts.FheTrivialEncrypt[tfhe.FheBool] { + t.Fatalf("fheArrayEq expected trivial encryption of bool gas value") + } +} + +func TestFheArrayEqDifferentTypesInLhs(t *testing.T) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint16).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint16).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint16).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint16).GetHash().Big() + rhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint16).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + _, err := fheArrayEqRun(environment, addr, addr, input, readOnly, nil) + if err == nil { + t.Fatalf("fheArrayEq expected an error") + } + + if len(environment.FhevmData().verifiedCiphertexts) != 6 { + t.Fatalf("fheArrayEq expected 6 verified ciphertext") + } +} + +func TestFheArrayEqDifferentTypesInLhsGas(t *testing.T) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint16).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint16).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint16).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint16).GetHash().Big() + rhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint16).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + gas := fheArrayEqRequiredGas(environment, input) + if gas != 0 { + t.Fatalf("fheArrayEq expected 0 gas value") + } +} + +func TestFheArrayEqDifferentTypesInRhs(t *testing.T) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint16).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint16).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint16).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint16).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big() + rhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint16).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + _, err := fheArrayEqRun(environment, addr, addr, input, readOnly, nil) + if err == nil { + t.Fatalf("fheArrayEq expected an error") + } + + if len(environment.FhevmData().verifiedCiphertexts) != 6 { + t.Fatalf("fheArrayEq expected 6 verified ciphertext") + } +} + +func TestFheArrayEqDifferentTypesInRhsGas(t *testing.T) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint16).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint16).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint16).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint16).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big() + rhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint16).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + gas := fheArrayEqRequiredGas(environment, input) + if gas != 0 { + t.Fatalf("fheArrayEq expected 0 gas value") + } +} + +func TestFheArrayEqUnsupportedType(t *testing.T) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheBool).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheBool).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 0, depth, tfhe.FheBool).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheBool).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheBool).GetHash().Big() + rhs[2] = verifyCiphertextInTestMemory(environment, 0, depth, tfhe.FheBool).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + _, err := fheArrayEqRun(environment, addr, addr, input, readOnly, nil) + if err == nil { + t.Fatalf("fheArrayEq expected an error") + } + + if len(environment.FhevmData().verifiedCiphertexts) != 6 { + t.Fatalf("fheArrayEq expected 6 verified ciphertext") + } +} + +func TestFheArrayEqUnsupportedTypeGas(t *testing.T) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheBool).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheBool).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 0, depth, tfhe.FheBool).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheBool).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheBool).GetHash().Big() + rhs[2] = verifyCiphertextInTestMemory(environment, 0, depth, tfhe.FheBool).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + gas := fheArrayEqRequiredGas(environment, input) + if gas != 0 { + t.Fatalf("fheArrayEq expected 0 gas value") + } +} + +func FheArrayEqInsufficientBytes(t *testing.T, fheUintType tfhe.FheUintType) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big() + rhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + input = input[:37] + + _, err := fheArrayEqRun(environment, addr, addr, input, readOnly, nil) + if err == nil { + t.Fatalf("fheArrayEq expected an error") + } + + if len(environment.FhevmData().verifiedCiphertexts) != 6 { + t.Fatalf("fheArrayEq expected 6 verified ciphertext") + } +} + +func FheArrayEqInsufficientBytesGas(t *testing.T, fheUintType tfhe.FheUintType) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big() + rhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + input = input[:37] + + gas := fheArrayEqRequiredGas(environment, input) + if gas != 0 { + t.Fatalf("fheArrayEq expected 0 gas value") + } +} + +func FheArrayEqNoRhs(t *testing.T, fheUintType tfhe.FheUintType) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs) + + _, err := fheArrayEqRun(environment, addr, addr, input, readOnly, nil) + if err == nil { + t.Fatalf("fheArrayEq expected an error") + } + + if len(environment.FhevmData().verifiedCiphertexts) != 3 { + t.Fatalf("fheArrayEq expected 3 verified ciphertext") + } +} + +func FheArrayEqNoRhsGas(t *testing.T, fheUintType tfhe.FheUintType) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs) + + gas := fheArrayEqRequiredGas(environment, input) + if gas != 0 { + t.Fatalf("fheArrayEq expected 0 gas value") + } +} + +func TestFheArrayEqUnverifiedCtInLhs(t *testing.T) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big() + lhs[0].Add(lhs[0], big.NewInt(1)) + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big() + rhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + _, err := fheArrayEqRun(environment, addr, addr, input, readOnly, nil) + if err == nil { + t.Fatalf("fheArrayEq expected an error") + } + + if len(environment.FhevmData().verifiedCiphertexts) != 6 { + t.Fatalf("fheArrayEq expected 6 verified ciphertext") + } +} + +func TestFheArrayEqUnverifiedCtInLhsGas(t *testing.T) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big() + lhs[0].Add(lhs[0], big.NewInt(1)) + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big() + rhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + gas := fheArrayEqRequiredGas(environment, input) + if gas != 0 { + t.Fatalf("fheArrayEq expected 0 gas value") + } +} + +func TestFheArrayEqUnverifiedCtInRhs(t *testing.T) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big() + + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big() + rhs[1].Add(rhs[1], big.NewInt(1)) + rhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + _, err := fheArrayEqRun(environment, addr, addr, input, readOnly, nil) + if err == nil { + t.Fatalf("fheArrayEq expected an error") + } + + if len(environment.FhevmData().verifiedCiphertexts) != 6 { + t.Fatalf("fheArrayEq expected 6 verified ciphertext") + } +} + +func TestFheArrayEqUnverifiedCtInRhsGas(t *testing.T) { + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + + lhs := make([]*big.Int, 3) + lhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big() + lhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big() + lhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big() + rhs := make([]*big.Int, 3) + rhs[0] = verifyCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big() + rhs[1] = verifyCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big() + rhs[1].Add(lhs[0], big.NewInt(1)) + rhs[2] = verifyCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big() + input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs) + + gas := fheArrayEqRequiredGas(environment, input) + if gas != 0 { + t.Fatalf("fheArrayEq expected 0 gas value") + } +} + func TestVerifyCiphertextInvalidType(t *testing.T) { depth := 1 environment := newTestEVMEnvironment() @@ -4496,3 +5045,223 @@ func TestFheLibGetCiphertext32(t *testing.T) { func TestFheLibGetCiphertext64(t *testing.T) { FheLibGetCiphertext(t, tfhe.FheUint64) } + +func TestFheArrayEq4(t *testing.T) { + FheArrayEq(t, tfhe.FheUint4) +} + +func TestFheArrayEq8(t *testing.T) { + FheArrayEq(t, tfhe.FheUint8) +} + +func TestFheArrayEq16(t *testing.T) { + FheArrayEq(t, tfhe.FheUint16) +} + +func TestFheArrayEq32(t *testing.T) { + FheArrayEq(t, tfhe.FheUint32) +} + +func TestFheArrayEq64(t *testing.T) { + FheArrayEq(t, tfhe.FheUint64) +} + +func TestFheArrayEqGas4(t *testing.T) { + FheArrayEqGas(t, tfhe.FheUint4) +} + +func TestFheArrayEqGas8(t *testing.T) { + FheArrayEqGas(t, tfhe.FheUint8) +} + +func TestFheArrayEqGas16(t *testing.T) { + FheArrayEqGas(t, tfhe.FheUint16) +} + +func TestFheArrayEqGas32(t *testing.T) { + FheArrayEqGas(t, tfhe.FheUint32) +} + +func TestFheArrayEqGas64(t *testing.T) { + FheArrayEqGas(t, tfhe.FheUint64) +} + +func TestFheArrayEqSameLenNotEqual4(t *testing.T) { + FheArrayEqSameLenNotEqual(t, tfhe.FheUint4) +} + +func TestFheArrayEqSameLenNotEqual8(t *testing.T) { + FheArrayEqSameLenNotEqual(t, tfhe.FheUint8) +} + +func TestFheArrayEqSameLenNotEqual16(t *testing.T) { + FheArrayEqSameLenNotEqual(t, tfhe.FheUint16) +} + +func TestFheArrayEqSameLenNotEqual32(t *testing.T) { + FheArrayEqSameLenNotEqual(t, tfhe.FheUint32) +} + +func TestFheArrayEqSameLenNotEqual64(t *testing.T) { + FheArrayEqSameLenNotEqual(t, tfhe.FheUint64) +} + +func TestFheArrayNotEqDifferentLen4(t *testing.T) { + FheArrayEqDifferentLen(t, tfhe.FheUint4) +} + +func TestFheArrayEqDifferentLen8(t *testing.T) { + FheArrayEqDifferentLen(t, tfhe.FheUint8) +} + +func TestFheArrayEqDifferentLen16(t *testing.T) { + FheArrayEqDifferentLen(t, tfhe.FheUint16) +} + +func TesFheArrayEqDifferentLen32(t *testing.T) { + FheArrayEqDifferentLen(t, tfhe.FheUint32) +} + +func TestFheArrayEqDifferentLen64(t *testing.T) { + FheArrayEqDifferentLen(t, tfhe.FheUint64) +} + +func TestFheArrayNotEqDifferentLenGas4(t *testing.T) { + FheArrayEqDifferentLenGas(t, tfhe.FheUint4) +} + +func TestFheArrayEqDifferentLenGas8(t *testing.T) { + FheArrayEqDifferentLenGas(t, tfhe.FheUint8) +} + +func TestFheArrayEqDifferentLenGas16(t *testing.T) { + FheArrayEqDifferentLenGas(t, tfhe.FheUint16) +} + +func TesFheArrayEqDifferentLenGas32(t *testing.T) { + FheArrayEqDifferentLenGas(t, tfhe.FheUint32) +} + +func TestFheArrayEqDifferentLenGas64(t *testing.T) { + FheArrayEqDifferentLenGas(t, tfhe.FheUint64) +} + +func TestFheArrayEqBothEmpty4(t *testing.T) { + FheArrayEqBothEmpty(t, tfhe.FheUint4) +} + +func TestFheArrayEqBothEmpty8(t *testing.T) { + FheArrayEqBothEmpty(t, tfhe.FheUint8) +} + +func TestFheArrayEqBothEmpty16(t *testing.T) { + FheArrayEqBothEmpty(t, tfhe.FheUint16) +} + +func TesFheArrayEqBothEmpty32(t *testing.T) { + FheArrayEqBothEmpty(t, tfhe.FheUint32) +} + +func TestFheArrayEqBothEmpty64(t *testing.T) { + FheArrayEqBothEmpty(t, tfhe.FheUint64) +} + +func TestFheArrayEqBothEmptyGas4(t *testing.T) { + FheArrayEqBothEmptyGas(t, tfhe.FheUint4) +} + +func TestFheArrayEqBothEmptyGas8(t *testing.T) { + FheArrayEqBothEmptyGas(t, tfhe.FheUint8) +} + +func TestFheArrayEqBothEmptyGas16(t *testing.T) { + FheArrayEqBothEmptyGas(t, tfhe.FheUint16) +} + +func TesFheArrayEqBothEmptyGas32(t *testing.T) { + FheArrayEqBothEmptyGas(t, tfhe.FheUint32) +} + +func TestFheArrayEqBothEmptyGas64(t *testing.T) { + FheArrayEqBothEmptyGas(t, tfhe.FheUint64) +} + +func TestFheArrayEqInsufficientBytes4(t *testing.T) { + FheArrayEqInsufficientBytes(t, tfhe.FheUint4) +} + +func TestFheArrayEqInsufficientBytes8(t *testing.T) { + FheArrayEqInsufficientBytes(t, tfhe.FheUint8) +} + +func TestFheArrayEqInsufficientBytes16(t *testing.T) { + FheArrayEqInsufficientBytes(t, tfhe.FheUint16) +} + +func TesFheArrayEqInsufficientBytes32(t *testing.T) { + FheArrayEqInsufficientBytes(t, tfhe.FheUint32) +} + +func TestFheArrayEqInsufficientBytes64(t *testing.T) { + FheArrayEqInsufficientBytes(t, tfhe.FheUint64) +} + +func TestFheArrayEqInsufficientBytesGas4(t *testing.T) { + FheArrayEqInsufficientBytesGas(t, tfhe.FheUint4) +} + +func TestFheArrayEqInsufficientBytesGas8(t *testing.T) { + FheArrayEqInsufficientBytesGas(t, tfhe.FheUint8) +} + +func TestFheArrayEqInsufficientBytesGas16(t *testing.T) { + FheArrayEqInsufficientBytesGas(t, tfhe.FheUint16) +} + +func TesFheArrayEqInsufficientBytesGas32(t *testing.T) { + FheArrayEqInsufficientBytes(t, tfhe.FheUint32) +} + +func TestFheArrayEqInsufficientBytesGas64(t *testing.T) { + FheArrayEqInsufficientBytes(t, tfhe.FheUint64) +} + +func TestFheArrayEqNoRhs4(t *testing.T) { + FheArrayEqNoRhs(t, tfhe.FheUint4) +} + +func TestFheArrayEqNoRhs8(t *testing.T) { + FheArrayEqNoRhs(t, tfhe.FheUint8) +} + +func TestFheArrayEqNoRhs16(t *testing.T) { + FheArrayEqNoRhs(t, tfhe.FheUint16) +} + +func TesFheArrayEqNoRhs32(t *testing.T) { + FheArrayEqNoRhs(t, tfhe.FheUint32) +} + +func TestFheArrayEqNoRhs64(t *testing.T) { + FheArrayEqNoRhs(t, tfhe.FheUint64) +} + +func TestFheArrayEqNoRhsGas4(t *testing.T) { + FheArrayEqNoRhsGas(t, tfhe.FheUint4) +} + +func TestFheArrayEqNoRhsGas8(t *testing.T) { + FheArrayEqNoRhsGas(t, tfhe.FheUint8) +} + +func TestFheArrayEqNoRhsGas16(t *testing.T) { + FheArrayEqNoRhsGas(t, tfhe.FheUint16) +} + +func TesFheArrayEqNoRhsGas32(t *testing.T) { + FheArrayEqNoRhsGas(t, tfhe.FheUint32) +} + +func TestFheArrayEqNoRhsGas64(t *testing.T) { + FheArrayEqNoRhsGas(t, tfhe.FheUint64) +} diff --git a/fhevm/fhelib.go b/fhevm/fhelib.go index 7139895..b3e955e 100644 --- a/fhevm/fhelib.go +++ b/fhevm/fhelib.go @@ -208,6 +208,12 @@ var fhelibMethods = []*FheLibMethod{ requiredGasFunction: fheIfThenElseRequiredGas, runFunction: fheIfThenElseRun, }, + { + name: "fheArrayEq", + argTypes: "(uint256[],uint256[])", + requiredGasFunction: fheArrayEqRequiredGas, + runFunction: fheArrayEqRun, + }, { name: "fhePubKey", argTypes: "(bytes1)", diff --git a/fhevm/operators_comparison.go b/fhevm/operators_comparison.go index 3df181d..6c68faa 100644 --- a/fhevm/operators_comparison.go +++ b/fhevm/operators_comparison.go @@ -3,7 +3,11 @@ package fhevm import ( "encoding/hex" "errors" + "fmt" + "math/big" + "strings" + "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" "github.com/zama-ai/fhevm-go/fhevm/tfhe" "go.opentelemetry.io/otel/trace" @@ -570,3 +574,108 @@ func fheIfThenElseRun(environment EVMEnvironment, caller common.Address, addr co logger.Info("fheIfThenElse success", "first", first.hash().Hex(), "second", second.hash().Hex(), "third", third.hash().Hex(), "result", resultHash.Hex()) return resultHash[:], nil } + +// TODO: implement as part of fhelibMethods. +const fheArrayEqAbiJson = ` + [ + { + "name": "fheArrayEq", + "type": "function", + "inputs": [ + { + "name": "lhs", + "type": "uint256[]" + }, + { + "name": "rhs", + "type": "uint256[]" + } + ], + "outputs": [ + { + "name": "", + "type": "uint256" + } + ] + } + ] +` + +var arrayEqMethod abi.Method + +func init() { + reader := strings.NewReader(fheArrayEqAbiJson) + arrayEqAbi, err := abi.JSON(reader) + if err != nil { + panic(err) + } + + var ok bool + arrayEqMethod, ok = arrayEqAbi.Methods["fheArrayEq"] + if !ok { + panic("couldn't find the fheArrayEq method") + } +} + +func getVerifiedCiphertexts(environment EVMEnvironment, unpacked interface{}) ([]*tfhe.TfheCiphertext, error) { + big, ok := unpacked.([]*big.Int) + if !ok { + return nil, fmt.Errorf("fheArrayEq failed to cast to []*big.Int") + } + ret := make([]*tfhe.TfheCiphertext, 0, len(big)) + for _, b := range big { + ct := getVerifiedCiphertext(environment, common.BigToHash(b)) + if ct == nil { + return nil, fmt.Errorf("fheArrayEq unverified ciphertext") + } + ret = append(ret, ct.ciphertext) + } + return ret, nil +} + +func fheArrayEqRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) { + logger := environment.GetLogger() + + unpacked, err := arrayEqMethod.Inputs.UnpackValues(input) + if err != nil { + msg := "fheArrayEqRun failed to unpack input" + logger.Error(msg, "err", err) + return nil, err + } + + if len(unpacked) != 2 { + err := fmt.Errorf("fheArrayEqRun unexpected unpacked len: %d", len(unpacked)) + logger.Error(err.Error()) + return nil, err + } + + lhs, err := getVerifiedCiphertexts(environment, unpacked[0]) + if err != nil { + msg := "fheArrayEqRun failed to get lhs to verified ciphertexts" + logger.Error(msg, "err", err) + return nil, err + } + + rhs, err := getVerifiedCiphertexts(environment, unpacked[1]) + if err != nil { + msg := "fheArrayEqRun failed to get rhs to verified ciphertexts" + logger.Error(msg, "err", err) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !environment.IsCommitting() && !environment.IsEthCall() { + return importRandomCiphertext(environment, tfhe.FheBool), nil + } + + result, err := tfhe.EqArray(lhs, rhs) + if err != nil { + msg := "fheArrayEqRun failed to execute" + logger.Error(msg, "err", err) + return nil, err + } + importCiphertext(environment, result) + resultHash := result.GetHash() + logger.Info("fheArrayEqRun success", "result", resultHash.Hex()) + return resultHash[:], nil +} diff --git a/fhevm/operators_comparison_gas.go b/fhevm/operators_comparison_gas.go index 563bfad..7417509 100644 --- a/fhevm/operators_comparison_gas.go +++ b/fhevm/operators_comparison_gas.go @@ -2,6 +2,7 @@ package fhevm import ( "encoding/hex" + "fmt" "github.com/zama-ai/fhevm-go/fhevm/tfhe" ) @@ -141,3 +142,67 @@ func fheIfThenElseRequiredGas(environment EVMEnvironment, input []byte) uint64 { } return environment.FhevmParams().GasCosts.FheIfThenElse[second.fheUintType()] } + +func fheArrayEqRequiredGas(environment EVMEnvironment, input []byte) uint64 { + logger := environment.GetLogger() + + unpacked, err := arrayEqMethod.Inputs.UnpackValues(input) + if err != nil { + msg := "fheArrayEqRun RequiredGas() failed to unpack input" + logger.Error(msg, "err", err) + return 0 + } + + if len(unpacked) != 2 { + err := fmt.Errorf("fheArrayEqRun RequiredGas() unexpected unpacked len: %d", len(unpacked)) + logger.Error(err.Error()) + return 0 + } + + lhs, err := getVerifiedCiphertexts(environment, unpacked[0]) + if err != nil { + msg := "fheArrayEqRun RequiredGas() failed to get lhs to verified ciphertexts" + logger.Error(msg, "err", err) + return 0 + } + + rhs, err := getVerifiedCiphertexts(environment, unpacked[1]) + if err != nil { + msg := "fheArrayEqRun RequiredGas() failed to get rhs to verified ciphertexts" + logger.Error(msg, "err", err) + return 0 + } + + if len(lhs) != len(rhs) || (len(lhs) == 0 && len(rhs) == 0) { + return environment.FhevmParams().GasCosts.FheTrivialEncrypt[tfhe.FheBool] + } + + numElements := len(lhs) + elementType := lhs[0].Type() + // TODO: tie to supported types in tfhe.TfheCiphertext.EqArray() + if elementType != tfhe.FheUint4 && elementType != tfhe.FheUint8 && elementType != tfhe.FheUint16 && elementType != tfhe.FheUint32 && elementType != tfhe.FheUint64 { + return 0 + } + for i := range lhs { + if lhs[i].Type() != elementType || rhs[i].Type() != elementType { + return 0 + } + } + + numBits := elementType.NumBits() * uint(numElements) + if numBits <= 4 { + return environment.FhevmParams().GasCosts.FheEq[tfhe.FheUint4] + } else if numBits <= 8 { + return environment.FhevmParams().GasCosts.FheEq[tfhe.FheUint8] + } else if numBits <= 16 { + return environment.FhevmParams().GasCosts.FheEq[tfhe.FheUint16] + } else if numBits <= 32 { + return environment.FhevmParams().GasCosts.FheEq[tfhe.FheUint32] + } else if numBits <= 64 { + return environment.FhevmParams().GasCosts.FheEq[tfhe.FheUint64] + } else if numBits <= 160 { + return environment.FhevmParams().GasCosts.FheEq[tfhe.FheUint160] + } else { + return (environment.FhevmParams().GasCosts.FheEq[tfhe.FheUint160] + environment.FhevmParams().GasCosts.FheArrayEqBigArrayFactor) * (uint64(numBits) / 160) + } +} diff --git a/fhevm/params.go b/fhevm/params.go index 2353c99..02a5be0 100644 --- a/fhevm/params.go +++ b/fhevm/params.go @@ -43,6 +43,7 @@ type GasCosts struct { FheShift map[tfhe.FheUintType]uint64 FheScalarShift map[tfhe.FheUintType]uint64 FheEq map[tfhe.FheUintType]uint64 + FheArrayEqBigArrayFactor uint64 // TODO: either rename or come up with a better solution FheLe map[tfhe.FheUintType]uint64 FheMinMax map[tfhe.FheUintType]uint64 FheScalarMinMax map[tfhe.FheUintType]uint64 @@ -60,6 +61,8 @@ type GasCosts struct { func DefaultGasCosts() GasCosts { return GasCosts{ + FheCast: 200, + FhePubKey: 50, FheAddSub: map[tfhe.FheUintType]uint64{ tfhe.FheUint4: 55000 + AdjustFHEGas, tfhe.FheUint8: 84000 + AdjustFHEGas, @@ -132,6 +135,7 @@ func DefaultGasCosts() GasCosts { tfhe.FheUint64: 76000 + AdjustFHEGas, tfhe.FheUint160: 80000 + AdjustFHEGas, }, + FheArrayEqBigArrayFactor: 1000, FheLe: map[tfhe.FheUintType]uint64{ tfhe.FheUint4: 60000 + AdjustFHEGas, tfhe.FheUint8: 72000 + AdjustFHEGas, diff --git a/fhevm/tfhe/tfhe_ciphertext.go b/fhevm/tfhe/tfhe_ciphertext.go index cf306e6..529f7a2 100644 --- a/fhevm/tfhe/tfhe_ciphertext.go +++ b/fhevm/tfhe/tfhe_ciphertext.go @@ -6,6 +6,7 @@ package tfhe import "C" import ( "errors" + "fmt" "math/big" "unsafe" @@ -46,7 +47,30 @@ func (t FheUintType) String() string { case FheUint160: return "fheUint160" default: - return "unknownFheUintType" + return "unknown FheUintType" + } +} + +func (t FheUintType) NumBits() uint { + switch t { + case FheBool: + return 1 + case FheUint4: + return 4 + case FheUint8: + return 8 + case FheUint16: + return 16 + case FheUint32: + return 32 + case FheUint64: + return 64 + case FheUint128: + return 128 + case FheUint160: + return 160 + default: + panic("unknown FheUintType") } } @@ -87,54 +111,63 @@ func boolUnaryNotSupportedOp(lhs unsafe.Pointer) (unsafe.Pointer, error) { return nil, errors.New("Bool is not supported") } -// Deserializes a TFHE ciphertext. -func (ct *TfheCiphertext) Deserialize(in []byte, t FheUintType) error { +// Deserializes `in` and returns a C pointer to the ciphertext. +// Expects that the caller will destroy the returned ciphertext via destroyCiphertext(). +func Deserialize(in []byte, t FheUintType) unsafe.Pointer { + switch t { + case FheBool: + return C.deserialize_fhe_bool(toDynamicBufferView(in)) + case FheUint4: + return C.deserialize_fhe_uint4(toDynamicBufferView(in)) + case FheUint8: + return C.deserialize_fhe_uint8(toDynamicBufferView(in)) + case FheUint16: + return C.deserialize_fhe_uint16(toDynamicBufferView(in)) + case FheUint32: + return C.deserialize_fhe_uint32(toDynamicBufferView(in)) + case FheUint64: + return C.deserialize_fhe_uint64(toDynamicBufferView(in)) + case FheUint160: + return C.deserialize_fhe_uint160(toDynamicBufferView(in)) + default: + panic("Deserialize: unexpected ciphertext type") + } +} + +// Destroys the ciphertext that is pointed to by `ptr`. +func destroyCiphertext(ptr unsafe.Pointer, t FheUintType) { switch t { case FheBool: - ptr := C.deserialize_fhe_bool(toDynamicBufferView((in))) - if ptr == nil { - return errors.New("FheBool ciphertext deserialization failed") - } C.destroy_fhe_bool(ptr) case FheUint4: - ptr := C.deserialize_fhe_uint4(toDynamicBufferView((in))) - if ptr == nil { - return errors.New("FheUint4 ciphertext deserialization failed") - } C.destroy_fhe_uint4(ptr) case FheUint8: - ptr := C.deserialize_fhe_uint8(toDynamicBufferView((in))) - if ptr == nil { - return errors.New("FheUint8 ciphertext deserialization failed") - } C.destroy_fhe_uint8(ptr) case FheUint16: - ptr := C.deserialize_fhe_uint16(toDynamicBufferView((in))) - if ptr == nil { - return errors.New("FheUint16 ciphertext deserialization failed") - } C.destroy_fhe_uint16(ptr) case FheUint32: - ptr := C.deserialize_fhe_uint32(toDynamicBufferView((in))) - if ptr == nil { - return errors.New("FheUint32 ciphertext deserialization failed") - } C.destroy_fhe_uint32(ptr) case FheUint64: - ptr := C.deserialize_fhe_uint64(toDynamicBufferView((in))) - if ptr == nil { - return errors.New("FheUint64 ciphertext deserialization failed") - } C.destroy_fhe_uint64(ptr) case FheUint160: - ptr := C.deserialize_fhe_uint160(toDynamicBufferView((in))) - if ptr == nil { - return errors.New("FheUint160 ciphertext deserialization failed") - } C.destroy_fhe_uint160(ptr) default: - panic("deserialize: unexpected ciphertext type") + panic("destroyCiphertext: unexpected ciphertext type") + } +} + +// Expects that the caller will destroy the pointer via destroyCiphertext(). +func (ct *TfheCiphertext) DeserializeToPtr() unsafe.Pointer { + return Deserialize(ct.Serialize(), ct.FheUintType) +} + +// Deserializes a TFHE ciphertext. +func (ct *TfheCiphertext) Deserialize(in []byte, t FheUintType) error { + ptr := Deserialize(in, t) + if ptr == nil { + return fmt.Errorf("%s ciphertext deserialization failed", t.String()) } + destroyCiphertext(ptr, t) ct.FheUintType = t ct.Serialization = in ct.computeHash() @@ -2562,3 +2595,88 @@ func (ct *TfheCiphertext) GetHash() common.Hash { ct.computeHash() return *ct.Hash } + +// Caller is responsible for freeing the returned pointers. +func arrayToCiphertextPointerArray(arr []*TfheCiphertext, expectedType FheUintType) []unsafe.Pointer { + ret := make([]unsafe.Pointer, 0, len(arr)) + for _, ct := range arr { + if ct.Type() != expectedType { + return ret + } + ptr := ct.DeserializeToPtr() + if ptr == nil { + return ret + } + ret = append(ret, ptr) + } + return ret +} + +func destroyCiphertextPointerArray(arr []unsafe.Pointer, t FheUintType) { + for _, p := range arr { + destroyCiphertext(p, t) + } +} + +func EqArray(lhs []*TfheCiphertext, rhs []*TfheCiphertext) (*TfheCiphertext, error) { + result := new(TfheCiphertext) + if len(lhs) == 0 && len(rhs) == 0 { + // If both lhs and rhs are empty, return a trivial encryption of true. + result.TrivialEncrypt(*big.NewInt(1), FheBool) + } else if len(lhs) != len(rhs) { + // If lengths are different, return a trivial encryption of false. + result.TrivialEncrypt(*big.NewInt(0), FheBool) + } else { + // Make sure types are the same. + lhsType := lhs[0].Type() + rhsType := rhs[0].Type() + if lhsType != rhsType { + msg := fmt.Sprintf("EqArray: lhs type %d is different from rhs type %d", lhsType, rhsType) + return nil, errors.New(msg) + } + numOfElements := len(lhs) + elementsType := lhsType + + // Convert to C pointers. + lhsPtrs := arrayToCiphertextPointerArray(lhs, elementsType) + defer destroyCiphertextPointerArray(lhsPtrs, elementsType) + rhsPtrs := arrayToCiphertextPointerArray(rhs, elementsType) + defer destroyCiphertextPointerArray(rhsPtrs, elementsType) + + // Make sure all are of the same type. + if len(lhsPtrs) != numOfElements || len(rhsPtrs) != numOfElements { + return nil, errors.New("EqArray: elements are of different types") + } + + // Do the FHE computation. + var resultPtr unsafe.Pointer + switch elementsType { + case FheUint4: + resultPtr = C.eq_fhe_array_uint4(unsafe.Pointer(&lhsPtrs[0]), (C.size_t)(numOfElements), unsafe.Pointer(&rhsPtrs[0]), (C.size_t)(numOfElements), sks) + case FheUint8: + resultPtr = C.eq_fhe_array_uint8(unsafe.Pointer(&lhsPtrs[0]), (C.size_t)(numOfElements), unsafe.Pointer(&rhsPtrs[0]), (C.size_t)(numOfElements), sks) + case FheUint16: + resultPtr = C.eq_fhe_array_uint16(unsafe.Pointer(&lhsPtrs[0]), (C.size_t)(numOfElements), unsafe.Pointer(&rhsPtrs[0]), (C.size_t)(numOfElements), sks) + case FheUint32: + resultPtr = C.eq_fhe_array_uint32(unsafe.Pointer(&lhsPtrs[0]), (C.size_t)(numOfElements), unsafe.Pointer(&rhsPtrs[0]), (C.size_t)(numOfElements), sks) + case FheUint64: + resultPtr = C.eq_fhe_array_uint64(unsafe.Pointer(&lhsPtrs[0]), (C.size_t)(numOfElements), unsafe.Pointer(&rhsPtrs[0]), (C.size_t)(numOfElements), sks) + default: + return nil, fmt.Errorf("EqArray: unsupported ciphertext type %d", elementsType) + } + if resultPtr == nil { + return nil, errors.New("EqArray: FHE computation failed") + } + defer C.destroy_fhe_bool(resultPtr) + ser := &C.DynamicBuffer{} + ret := C.serialize_fhe_bool(resultPtr, ser) + if ret != 0 { + return nil, errors.New("EqArray: bool serialization failed") + } + defer C.destroy_dynamic_buffer(ser) + result.Serialization = C.GoBytes(unsafe.Pointer(ser.pointer), C.int(ser.length)) + result.FheUintType = FheBool + result.computeHash() + } + return result, nil +} diff --git a/fhevm/tfhe/tfhe_test.go b/fhevm/tfhe/tfhe_test.go index d63d5d1..ed0fa8f 100644 --- a/fhevm/tfhe/tfhe_test.go +++ b/fhevm/tfhe/tfhe_test.go @@ -1576,6 +1576,164 @@ func TfheCast(t *testing.T, fheUintTypeFrom FheUintType, fheUintTypeTo FheUintTy } } +func TfheEqArrayEqual(t *testing.T, fheUintType FheUintType) { + lhs := make([]*TfheCiphertext, 0) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(4), fheUintType)) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(7), fheUintType)) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(10), fheUintType)) + + rhs := make([]*TfheCiphertext, 0) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(4), fheUintType)) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(7), fheUintType)) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(10), fheUintType)) + + result, err := EqArray(lhs, rhs) + if err != nil { + t.Fatalf("EqArray failed: %v", err) + } + decrypted, err := result.Decrypt() + if err != nil { + t.Fatalf("EqArray decrypt failed: %v", err) + } + if !decrypted.IsUint64() || decrypted.Uint64() != 1 { + t.Fatalf("EqArray expected result of 1, got: %s", decrypted.String()) + } +} + +func TfheEqArrayCompareToSelf(t *testing.T, fheUintType FheUintType) { + lhs := make([]*TfheCiphertext, 0) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(4), fheUintType)) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(7), fheUintType)) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(10), fheUintType)) + rhs := lhs + + result, err := EqArray(lhs, rhs) + if err != nil { + t.Fatalf("EqArray failed: %v", err) + } + decrypted, err := result.Decrypt() + if err != nil { + t.Fatalf("EqArray decrypt failed: %v", err) + } + if !decrypted.IsUint64() || decrypted.Uint64() != 1 { + t.Fatalf("EqArray expected result of 1, got: %s", decrypted.String()) + } +} + +func TfheEqArrayNotEqualSameLen(t *testing.T, fheUintType FheUintType) { + lhs := make([]*TfheCiphertext, 0) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(4), fheUintType)) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(7), fheUintType)) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(10), fheUintType)) + + rhs := make([]*TfheCiphertext, 0) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(4), fheUintType)) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(6), fheUintType)) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(10), fheUintType)) + + result, err := EqArray(lhs, rhs) + if err != nil { + t.Fatalf("EqArray failed: %v", err) + } + decrypted, err := result.Decrypt() + if err != nil { + t.Fatalf("EqArray decrypt failed: %v", err) + } + if !decrypted.IsUint64() || decrypted.Uint64() != 0 { + t.Fatalf("EqArray expected result of 0, got: %s", decrypted.String()) + } +} + +func TfheEqArrayNotEqualDifferentLen(t *testing.T, fheUintType FheUintType) { + lhs := make([]*TfheCiphertext, 0) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(4), fheUintType)) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(7), fheUintType)) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(10), fheUintType)) + + rhs := make([]*TfheCiphertext, 0) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(4), fheUintType)) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(6), fheUintType)) + + result, err := EqArray(lhs, rhs) + if err != nil { + t.Fatalf("EqArray failed: %v", err) + } + decrypted, err := result.Decrypt() + if err != nil { + t.Fatalf("EqArray decrypt failed: %v", err) + } + if !decrypted.IsUint64() || decrypted.Uint64() != 0 { + t.Fatalf("EqArray expected result of 0, got: %s", decrypted.String()) + } +} + +func TestTfheEqArrayEqualBothEmpty(t *testing.T) { + lhs := make([]*TfheCiphertext, 0) + rhs := make([]*TfheCiphertext, 0) + result, err := EqArray(lhs, rhs) + if err != nil { + t.Fatalf("EqArray failed: %v", err) + } + decrypted, err := result.Decrypt() + if err != nil { + t.Fatalf("EqArray decrypt failed: %v", err) + } + if !decrypted.IsUint64() || decrypted.Uint64() != 1 { + t.Fatalf("EqArray expected result of 1, got: %s", decrypted.String()) + } +} + +func TestTfheEqArrayDifferentTypesInLhs(t *testing.T) { + lhs := make([]*TfheCiphertext, 0) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(4), FheUint32)) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(7), FheUint32)) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(10), FheUint64)) + + rhs := make([]*TfheCiphertext, 0) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(4), FheUint32)) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(6), FheUint32)) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(10), FheUint32)) + + _, err := EqArray(lhs, rhs) + if err == nil { + t.Fatalf("EqArray expected error") + } +} + +func TestTfheEqArrayDifferentTypesInRhs(t *testing.T) { + lhs := make([]*TfheCiphertext, 0) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(4), FheUint32)) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(7), FheUint32)) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(10), FheUint32)) + + rhs := make([]*TfheCiphertext, 0) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(4), FheUint32)) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(6), FheUint16)) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(10), FheUint32)) + + _, err := EqArray(lhs, rhs) + if err == nil { + t.Fatalf("EqArray expected error") + } +} + +func TestTfheEqArrayUnsupportedType(t *testing.T) { + lhs := make([]*TfheCiphertext, 0) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(4), FheBool)) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(7), FheBool)) + lhs = append(lhs, new(TfheCiphertext).Encrypt(*big.NewInt(10), FheBool)) + + rhs := make([]*TfheCiphertext, 0) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(4), FheBool)) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(6), FheBool)) + rhs = append(rhs, new(TfheCiphertext).Encrypt(*big.NewInt(10), FheBool)) + + _, err := EqArray(lhs, rhs) + if err == nil { + t.Fatalf("EqArray expected error") + } +} + func TestTfheEncryptDecryptBool(t *testing.T) { TfheEncryptDecrypt(t, FheBool) } @@ -2623,3 +2781,83 @@ func TestTfhe64Cast16(t *testing.T) { func TestTfhe64Cast32(t *testing.T) { TfheCast(t, FheUint64, FheUint32) } + +func TestTfheEqArrayEqual4(t *testing.T) { + TfheEqArrayEqual(t, FheUint4) +} + +func TestTfheEqArrayEqual8(t *testing.T) { + TfheEqArrayEqual(t, FheUint8) +} + +func TestTfheEqArrayEqual16(t *testing.T) { + TfheEqArrayEqual(t, FheUint16) +} + +func TestTfheEqArrayEqual32(t *testing.T) { + TfheEqArrayEqual(t, FheUint32) +} + +func TestTfheEqArrayEqual64(t *testing.T) { + TfheEqArrayEqual(t, FheUint64) +} + +func TestTfheEqArrayCompareToSelf4(t *testing.T) { + TfheEqArrayCompareToSelf(t, FheUint4) +} + +func TestTfheEqArrayCompareToSelf8(t *testing.T) { + TfheEqArrayCompareToSelf(t, FheUint8) +} + +func TestTfheEqArrayCompareToSelf16(t *testing.T) { + TfheEqArrayCompareToSelf(t, FheUint16) +} + +func TestTfheEqArrayCompareToSelf32(t *testing.T) { + TfheEqArrayCompareToSelf(t, FheUint32) +} + +func TestTfheEqArrayCompareToSelf64(t *testing.T) { + TfheEqArrayCompareToSelf(t, FheUint64) +} + +func TestTfheEqArrayNotEqualSameLen4(t *testing.T) { + TfheEqArrayNotEqualSameLen(t, FheUint4) +} + +func TestTfheEqArrayNotEqualSameLen8(t *testing.T) { + TfheEqArrayNotEqualSameLen(t, FheUint8) +} + +func TestTfheEqArrayNotEqualSameLen16(t *testing.T) { + TfheEqArrayNotEqualSameLen(t, FheUint16) +} + +func TestTfheEqArrayNotEqualSameLen32(t *testing.T) { + TfheEqArrayNotEqualSameLen(t, FheUint32) +} + +func TestTfheEqArrayNotEqualSameLen64(t *testing.T) { + TfheEqArrayNotEqualSameLen(t, FheUint64) +} + +func TestTfheEqArrayNotEqualDifferentLen4(t *testing.T) { + TfheEqArrayNotEqualSameLen(t, FheUint4) +} + +func TestTfheEqArrayNotEqualDifferentLen8(t *testing.T) { + TfheEqArrayNotEqualSameLen(t, FheUint8) +} + +func TestTfheEqArrayNotEqualDifferentLen16(t *testing.T) { + TfheEqArrayNotEqualSameLen(t, FheUint16) +} + +func TestTfheEqArrayNotEqualDifferentLen32(t *testing.T) { + TfheEqArrayNotEqualSameLen(t, FheUint32) +} + +func TestTfheEqArrayNotEqualDifferentLen64(t *testing.T) { + TfheEqArrayNotEqualSameLen(t, FheUint64) +} diff --git a/fhevm/tfhe/tfhe_wrappers.c b/fhevm/tfhe/tfhe_wrappers.c index 51ec46c..823f413 100644 --- a/fhevm/tfhe/tfhe_wrappers.c +++ b/fhevm/tfhe/tfhe_wrappers.c @@ -1636,6 +1636,61 @@ void* scalar_eq_fhe_uint160(void* ct, struct U256 pt, void* sks) return result; } +void* eq_fhe_array_uint4(void* ct1, size_t ct1_len, void* ct2, size_t ct2_len, void* sks) +{ + FheBool* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_array_eq(ct1, ct1_len, ct2, ct2_len, &result); + if(r != 0) return NULL; + return result; +} + +void* eq_fhe_array_uint8(void* ct1, size_t ct1_len, void* ct2, size_t ct2_len, void* sks) +{ + FheBool* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_array_eq(ct1, ct1_len, ct2, ct2_len, &result); + if(r != 0) return NULL; + return result; +} + +void* eq_fhe_array_uint16(void* ct1, size_t ct1_len, void* ct2, size_t ct2_len, void* sks) +{ + FheBool* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_array_eq(ct1, ct1_len, ct2, ct2_len, &result); + if(r != 0) return NULL; + return result; +} + +void* eq_fhe_array_uint32(void* ct1, size_t ct1_len, void* ct2, size_t ct2_len, void* sks) +{ + FheBool* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_array_eq(ct1, ct1_len, ct2, ct2_len, &result); + if(r != 0) return NULL; + return result; +} + +void* eq_fhe_array_uint64(void* ct1, size_t ct1_len, void* ct2, size_t ct2_len, void* sks) +{ + FheBool* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint64_array_eq(ct1, ct1_len, ct2, ct2_len, &result); + if(r != 0) return NULL; + return result; +} + void* ne_fhe_uint4(void* ct1, void* ct2, void* sks) { FheBool* result = NULL; diff --git a/fhevm/tfhe/tfhe_wrappers.h b/fhevm/tfhe/tfhe_wrappers.h index 1a1035b..6597255 100644 --- a/fhevm/tfhe/tfhe_wrappers.h +++ b/fhevm/tfhe/tfhe_wrappers.h @@ -295,6 +295,16 @@ void* scalar_eq_fhe_uint64(void* ct, uint64_t pt, void* sks); void* scalar_eq_fhe_uint160(void* ct, struct U256 pt, void* sks); +void* eq_fhe_array_uint4(void* ct1, size_t ct1_len, void* ct2, size_t ct2_len, void* sks); + +void* eq_fhe_array_uint8(void* ct1, size_t ct1_len, void* ct2, size_t ct2_len, void* sks); + +void* eq_fhe_array_uint16(void* ct1, size_t ct1_len, void* ct2, size_t ct2_len, void* sks); + +void* eq_fhe_array_uint32(void* ct1, size_t ct1_len, void* ct2, size_t ct2_len, void* sks); + +void* eq_fhe_array_uint64(void* ct1, size_t ct1_len, void* ct2, size_t ct2_len, void* sks); + void* ne_fhe_uint4(void* ct1, void* ct2, void* sks); void* ne_fhe_uint8(void* ct1, void* ct2, void* sks); diff --git a/tfhe-rs b/tfhe-rs index 5b65386..0d7a88e 160000 --- a/tfhe-rs +++ b/tfhe-rs @@ -1 +1 @@ -Subproject commit 5b653864b7e2c865c85c9d83efb2b85f08f53045 +Subproject commit 0d7a88e640a98612c0743d01062bb81d2a7e4a23