From de4413a0f5d82d5d8836359fff462b6283f82621 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20=27birdy=27=20Danjou?= Date: Wed, 3 Apr 2024 17:54:24 +0200 Subject: [PATCH] feat: add rotate left and rotate right --- fhevm/contracts_test.go | 96 +++++++++++++++ fhevm/fhelib.go | 12 ++ fhevm/operators_bit.go | 133 ++++++++++++++++++++ fhevm/operators_bit_gas.go | 10 ++ fhevm/tfhe/tfhe_ciphertext.go | 86 +++++++++++++ fhevm/tfhe/tfhe_test.go | 219 +++++++++++++++++++++++++++++++++ fhevm/tfhe/tfhe_wrappers.c | 220 ++++++++++++++++++++++++++++++++++ fhevm/tfhe/tfhe_wrappers.h | 40 +++++++ 8 files changed, 816 insertions(+) diff --git a/fhevm/contracts_test.go b/fhevm/contracts_test.go index 38f866b..28b602b 100644 --- a/fhevm/contracts_test.go +++ b/fhevm/contracts_test.go @@ -824,6 +824,102 @@ func FheLibShr(t *testing.T, fheUintType tfhe.FheUintType, scalar bool) { } } +func FheLibRotl(t *testing.T, fheUintType tfhe.FheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case tfhe.FheUint4: + lhs = 2 + rhs = 1 + case tfhe.FheUint8: + lhs = 2 + rhs = 1 + case tfhe.FheUint16: + lhs = 4283 + rhs = 2 + case tfhe.FheUint32: + lhs = 1333337 + rhs = 3 + case tfhe.FheUint64: + lhs = 13333377777777777 + lhs = 34 + } + expected := lhs << rhs + signature := "fheRotl(uint256,uint256,bytes1)" + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(environment, lhs, depth, fheUintType).GetHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(environment, rhs, depth, fheUintType).GetHash() + } + input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := FheLibRun(environment, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(environment, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.Decrypt() + if err != nil || decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } +} + +func FheLibRotr(t *testing.T, fheUintType tfhe.FheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case tfhe.FheUint4: + lhs = 2 + rhs = 1 + case tfhe.FheUint8: + lhs = 2 + rhs = 1 + case tfhe.FheUint16: + lhs = 4283 + rhs = 3 + case tfhe.FheUint32: + lhs = 1333337 + rhs = 3 + case tfhe.FheUint64: + lhs = 13333377777777777 + lhs = 34 + } + expected := lhs >> rhs + signature := "fheRotr(uint256,uint256,bytes1)" + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(environment, lhs, depth, fheUintType).GetHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(environment, rhs, depth, fheUintType).GetHash() + } + input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := FheLibRun(environment, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(environment, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.Decrypt() + if err != nil || decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } +} + func FheLibNe(t *testing.T, fheUintType tfhe.FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { diff --git a/fhevm/fhelib.go b/fhevm/fhelib.go index 44f7c6b..7139895 100644 --- a/fhevm/fhelib.go +++ b/fhevm/fhelib.go @@ -154,6 +154,18 @@ var fhelibMethods = []*FheLibMethod{ requiredGasFunction: fheShrRequiredGas, runFunction: fheShrRun, }, + { + name: "fheRotl", + argTypes: "(uint256,uint256,bytes1)", + requiredGasFunction: fheRotlRequiredGas, + runFunction: fheRotlRun, + }, + { + name: "fheRotr", + argTypes: "(uint256,uint256,bytes1)", + requiredGasFunction: fheRotrRequiredGas, + runFunction: fheRotrRun, + }, { name: "fheNe", argTypes: "(uint256,uint256,bytes1)", diff --git a/fhevm/operators_bit.go b/fhevm/operators_bit.go index 62b3e42..cd77065 100644 --- a/fhevm/operators_bit.go +++ b/fhevm/operators_bit.go @@ -140,6 +140,139 @@ func fheShrRun(environment EVMEnvironment, caller common.Address, addr common.Ad } } + +func fheRotlRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) { + input = input[:minInt(65, len(input))] + + logger := environment.GetLogger() + + isScalar, err := isScalarOp(input) + if err != nil { + logger.Error("fheShl can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(environment, input) + otelDescribeOperands(runSpan, encryptedOperand(*lhs), encryptedOperand(*rhs)) + if err != nil { + logger.Error("fheShl inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + if lhs.fheUintType() != rhs.fheUintType() { + msg := "fheShl operand type mismatch" + logger.Error(msg, "lhs", lhs.fheUintType(), "rhs", rhs.fheUintType()) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !environment.IsCommitting() && !environment.IsEthCall() { + return importRandomCiphertext(environment, lhs.fheUintType()), nil + } + + result, err := lhs.ciphertext.Rotl(rhs.ciphertext) + if err != nil { + logger.Error("fheRotl failed", "err", err) + return nil, err + } + importCiphertext(environment, result) + + resultHash := result.GetHash() + logger.Info("fheRotl success", "lhs", lhs.hash().Hex(), "rhs", rhs.hash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil + + } else { + lhs, rhs, err := getScalarOperands(environment, input) + otelDescribeOperands(runSpan, encryptedOperand(*lhs), plainOperand(*rhs)) + if err != nil { + logger.Error("fheRotl scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !environment.IsCommitting() && !environment.IsEthCall() { + return importRandomCiphertext(environment, lhs.fheUintType()), nil + } + + result, err := lhs.ciphertext.ScalarRotl(rhs) + if err != nil { + logger.Error("fheRotl failed", "err", err) + return nil, err + } + importCiphertext(environment, result) + + resultHash := result.GetHash() + logger.Info("fheRotl scalar success", "lhs", lhs.hash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) + return resultHash[:], nil + } +} + +func fheRotrRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) { + input = input[:minInt(65, len(input))] + + logger := environment.GetLogger() + + isScalar, err := isScalarOp(input) + if err != nil { + logger.Error("fheRotr can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(environment, input) + otelDescribeOperands(runSpan, encryptedOperand(*lhs), encryptedOperand(*rhs)) + if err != nil { + logger.Error("fheRotr inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + if lhs.fheUintType() != rhs.fheUintType() { + msg := "fheRotr operand type mismatch" + logger.Error(msg, "lhs", lhs.fheUintType(), "rhs", rhs.fheUintType()) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !environment.IsCommitting() && !environment.IsEthCall() { + return importRandomCiphertext(environment, lhs.fheUintType()), nil + } + + result, err := lhs.ciphertext.Rotr(rhs.ciphertext) + if err != nil { + logger.Error("fheRotr failed", "err", err) + return nil, err + } + importCiphertext(environment, result) + + resultHash := result.GetHash() + logger.Info("fheRotr success", "lhs", lhs.hash().Hex(), "rhs", rhs.hash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil + + } else { + lhs, rhs, err := getScalarOperands(environment, input) + otelDescribeOperands(runSpan, encryptedOperand(*lhs), plainOperand(*rhs)) + if err != nil { + logger.Error("fheRotr scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !environment.IsCommitting() && !environment.IsEthCall() { + return importRandomCiphertext(environment, lhs.fheUintType()), nil + } + + result, err := lhs.ciphertext.ScalarRotr(rhs) + if err != nil { + logger.Error("fheRotr failed", "err", err) + return nil, err + } + importCiphertext(environment, result) + + resultHash := result.GetHash() + logger.Info("fheRotr scalar success", "lhs", lhs.hash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) + return resultHash[:], nil + } +} + func fheNegRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) { input = input[:minInt(32, len(input))] diff --git a/fhevm/operators_bit_gas.go b/fhevm/operators_bit_gas.go index bf9ea00..60e10a5 100644 --- a/fhevm/operators_bit_gas.go +++ b/fhevm/operators_bit_gas.go @@ -42,6 +42,16 @@ func fheShrRequiredGas(environment EVMEnvironment, input []byte) uint64 { return fheShlRequiredGas(environment, input) } +func fheRotrRequiredGas(environment EVMEnvironment, input []byte) uint64 { + // Implement in terms of shl, because comparison costs are currently the same. + return fheShlRequiredGas(environment, input) +} + +func fheRotlRequiredGas(environment EVMEnvironment, input []byte) uint64 { + // Implement in terms of shl, because comparison costs are currently the same. + return fheShlRequiredGas(environment, input) +} + func fheNegRequiredGas(environment EVMEnvironment, input []byte) uint64 { input = input[:minInt(32, len(input))] diff --git a/fhevm/tfhe/tfhe_ciphertext.go b/fhevm/tfhe/tfhe_ciphertext.go index d2ecc4f..31f0452 100644 --- a/fhevm/tfhe/tfhe_ciphertext.go +++ b/fhevm/tfhe/tfhe_ciphertext.go @@ -1485,6 +1485,92 @@ func (lhs *TfheCiphertext) ScalarShr(rhs *big.Int) (*TfheCiphertext, error) { fheUint160BinaryScalarNotSupportedOp, false) } + +func (lhs *TfheCiphertext) Rotl(rhs *TfheCiphertext) (*TfheCiphertext, error) { + return lhs.executeBinaryCiphertextOperation(rhs, + boolBinaryNotSupportedOp, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.rotl_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.rotl_fhe_uint8(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.rotl_fhe_uint16(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.rotl_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.rotl_fhe_uint64(lhs, rhs, sks), nil + }, + fheUint160BinaryNotSupportedOp, false) +} + +func (lhs *TfheCiphertext) ScalarRotl(rhs *big.Int) (*TfheCiphertext, error) { + return lhs.executeBinaryScalarOperation(rhs, + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_rotl_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_rotl_fhe_uint8(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_rotl_fhe_uint16(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_rotl_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_rotl_fhe_uint64(lhs, rhs, sks), nil + }, + fheUint160BinaryScalarNotSupportedOp, false) +} + +func (lhs *TfheCiphertext) Rotr(rhs *TfheCiphertext) (*TfheCiphertext, error) { + return lhs.executeBinaryCiphertextOperation(rhs, + boolBinaryNotSupportedOp, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.rotr_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.rotr_fhe_uint8(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.rotr_fhe_uint16(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.rotr_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.rotr_fhe_uint64(lhs, rhs, sks), nil + }, + fheUint160BinaryNotSupportedOp, + false) +} + +func (lhs *TfheCiphertext) ScalarRotr(rhs *big.Int) (*TfheCiphertext, error) { + return lhs.executeBinaryScalarOperation(rhs, + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_rotr_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_rotr_fhe_uint8(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_rotr_fhe_uint16(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_rotr_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_rotr_fhe_uint64(lhs, rhs, sks), nil + }, + fheUint160BinaryScalarNotSupportedOp, false) +} + func (lhs *TfheCiphertext) Eq(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, boolBinaryNotSupportedOp, diff --git a/fhevm/tfhe/tfhe_test.go b/fhevm/tfhe/tfhe_test.go index 221b083..de63ed5 100644 --- a/fhevm/tfhe/tfhe_test.go +++ b/fhevm/tfhe/tfhe_test.go @@ -7,6 +7,7 @@ import ( "log" "math" "math/big" + "math/bits" "os" "testing" ) @@ -715,6 +716,147 @@ func TfheScalarShr(t *testing.T, fheUintType FheUintType) { } } +func TfheRotl(t *testing.T, fheUintType FheUintType) { + var a, b big.Int + var expected uint64 + switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) + expected = uint64(bits.RotateLeft8(uint8(a.Uint64()), int(b.Uint64()))) + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + expected = uint64(bits.RotateLeft8(uint8(a.Uint64()), int(b.Uint64()))) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + expected = uint64(bits.RotateLeft16(uint16(a.Uint64()), int(b.Uint64()))) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + expected = uint64(bits.RotateLeft32(uint32(a.Uint64()), int(b.Uint64()))) + case FheUint64: + a.SetUint64(13371337) + b.SetUint64(45) + expected = bits.RotateLeft64(a.Uint64(), int(b.Uint64())) + } + ctA := new(TfheCiphertext) + ctA.Encrypt(a, fheUintType) + ctB := new(TfheCiphertext) + ctB.Encrypt(b, fheUintType) + ctRes, _ := ctA.Rotl(ctB) + res, err := ctRes.Decrypt() + if err != nil || res.Uint64() != expected { + t.Fatalf("%d != %d", expected, res.Uint64()) + } +} + + +func TfheScalarRotl(t *testing.T, fheUintType FheUintType) { + var a, b big.Int + var expected uint64 + switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) + expected = uint64(bits.RotateLeft8(uint8(a.Uint64()), int(b.Uint64()))) + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + expected = uint64(bits.RotateLeft8(uint8(a.Uint64()), int(b.Uint64()))) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + expected = uint64(bits.RotateLeft16(uint16(a.Uint64()), int(b.Uint64()))) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + expected = uint64(bits.RotateLeft32(uint32(a.Uint64()), int(b.Uint64()))) + case FheUint64: + a.SetUint64(13371337) + b.SetUint64(45) + expected = uint64(bits.RotateLeft64(a.Uint64(), int(b.Uint64()))) + } + ctA := new(TfheCiphertext) + ctA.Encrypt(a, fheUintType) + ctRes, _ := ctA.ScalarRotl(&b) + res, err := ctRes.Decrypt() + if err != nil || res.Uint64() != expected { + t.Fatalf("%d != %d", expected, res.Uint64()) + } +} + +func TfheRotr(t *testing.T, fheUintType FheUintType) { + var a, b big.Int + var expected uint64 + switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) + expected = uint64(bits.RotateLeft8(uint8(a.Uint64()), -int(b.Uint64()))) + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + expected = uint64(bits.RotateLeft8(uint8(a.Uint64()), -int(b.Uint64()))) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + expected = uint64(bits.RotateLeft16(uint16(a.Uint64()), -int(b.Uint64()))) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + expected = uint64(bits.RotateLeft32(uint32(a.Uint64()), -int(b.Uint64()))) + case FheUint64: + a.SetUint64(13371337) + b.SetUint64(1337) + expected = uint64(bits.RotateLeft64(a.Uint64(), -int(b.Uint64()))) + } + ctA := new(TfheCiphertext) + ctA.Encrypt(a, fheUintType) + ctB := new(TfheCiphertext) + ctB.Encrypt(b, fheUintType) + ctRes, _ := ctA.Rotr(ctB) + res, err := ctRes.Decrypt() + if err != nil || res.Uint64() != expected { + t.Fatalf("%d != %d", expected, res.Uint64()) + } +} + +func TfheScalarRotr(t *testing.T, fheUintType FheUintType) { + var a, b big.Int + var expected uint64 + switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) + expected = uint64(bits.RotateLeft8(uint8(a.Uint64()), -int(b.Uint64()))) + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + expected = uint64(bits.RotateLeft8(uint8(a.Uint64()), -int(b.Uint64()))) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + expected = uint64(bits.RotateLeft16(uint16(a.Uint64()), -int(b.Uint64()))) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + expected = uint64(bits.RotateLeft32(uint32(a.Uint64()), -int(b.Uint64()))) + case FheUint64: + a.SetUint64(13371337) + b.SetUint64(1337) + expected = uint64(bits.RotateLeft64(a.Uint64(), -int(b.Uint64()))) + } + ctA := new(TfheCiphertext) + ctA.Encrypt(a, fheUintType) + ctRes, _ := ctA.ScalarRotr(&b) + res, err := ctRes.Decrypt() + if err != nil || res.Uint64() != expected { + t.Fatalf("%d != %d", expected, res.Uint64()) + } +} + func TfheEq(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { @@ -1926,6 +2068,83 @@ func TestTfheScalarShr64(t *testing.T) { TfheScalarShr(t, FheUint64) } + +func TestTfheRotl4(t *testing.T) { + TfheRotl(t, FheUint4) +} + +func TestTfheRotl8(t *testing.T) { + TfheRotl(t, FheUint8) +} + +func TestTfheRotl16(t *testing.T) { + TfheRotl(t, FheUint16) +} + +func TestTfheRotl32(t *testing.T) { + TfheRotl(t, FheUint32) +} + +func TestTfheRotl64(t *testing.T) { + TfheRotl(t, FheUint64) +} + +func TestTfheScalarRotl4(t *testing.T) { + TfheScalarRotl(t, FheUint4) +} + +func TestTfheScalarRotl8(t *testing.T) { + TfheScalarRotl(t, FheUint8) +} + +func TestTfheScalarRotl16(t *testing.T) { + TfheScalarRotl(t, FheUint16) +} + +func TestTfheScalarRotl32(t *testing.T) { + TfheScalarRotl(t, FheUint32) +} + +func TestTfheScalarRotl64(t *testing.T) { + TfheScalarRotl(t, FheUint64) +} + +func TestTfheRotr4(t *testing.T) { + TfheRotr(t, FheUint4) +} + +func TestTfheRotr8(t *testing.T) { + TfheRotr(t, FheUint8) +} + +func TestTfheRotr16(t *testing.T) { + TfheRotr(t, FheUint16) +} + +func TestTfheRotr32(t *testing.T) { + TfheRotr(t, FheUint32) +} + +func TestTfheRotr64(t *testing.T) { + TfheRotr(t, FheUint64) +} + +func TestTfheScalarRotr8(t *testing.T) { + TfheScalarRotr(t, FheUint8) +} + +func TestTfheScalarRotr16(t *testing.T) { + TfheScalarRotr(t, FheUint16) +} + +func TestTfheScalarRotr32(t *testing.T) { + TfheScalarRotr(t, FheUint32) +} + +func TestTfheScalarRotr64(t *testing.T) { + TfheScalarRotr(t, FheUint64) +} + func TestTfheEq4(t *testing.T) { TfheEq(t, FheUint4) } diff --git a/fhevm/tfhe/tfhe_wrappers.c b/fhevm/tfhe/tfhe_wrappers.c index 9a704a2..1f559e5 100644 --- a/fhevm/tfhe/tfhe_wrappers.c +++ b/fhevm/tfhe/tfhe_wrappers.c @@ -1284,6 +1284,226 @@ void* scalar_shr_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* rotl_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_rotate_left(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + +void* rotl_fhe_uint8(void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_rotate_left(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + +void* rotl_fhe_uint16(void* ct1, void* ct2, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_rotate_left(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + +void* rotl_fhe_uint32(void* ct1, void* ct2, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_rotate_left(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + +void* rotl_fhe_uint64(void* ct1, void* ct2, void* sks) +{ + FheUint64* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint64_rotate_left(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + +void* scalar_rotl_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_rotate_left(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + +void* scalar_rotl_fhe_uint8(void* ct, uint8_t pt, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_scalar_rotate_left(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + +void* scalar_rotl_fhe_uint16(void* ct, uint16_t pt, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_scalar_rotate_left(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + +void* scalar_rotl_fhe_uint32(void* ct, uint32_t pt, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_scalar_rotate_left(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + +void* scalar_rotl_fhe_uint64(void* ct, uint64_t pt, void* sks) +{ + FheUint64* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint64_scalar_rotate_left(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + +void* rotr_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_rotate_right(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + +void* rotr_fhe_uint8(void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_rotate_right(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + +void* rotr_fhe_uint16(void* ct1, void* ct2, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_rotate_right(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + +void* rotr_fhe_uint32(void* ct1, void* ct2, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_rotate_right(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + +void* rotr_fhe_uint64(void* ct1, void* ct2, void* sks) +{ + FheUint64* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint64_rotate_right(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + +void* scalar_rotr_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_rotate_right(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + +void* scalar_rotr_fhe_uint8(void* ct, uint8_t pt, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_scalar_rotate_right(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + +void* scalar_rotr_fhe_uint16(void* ct, uint16_t pt, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_scalar_rotate_right(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + +void* scalar_rotr_fhe_uint32(void* ct, uint32_t pt, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_scalar_rotate_right(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + +void* scalar_rotr_fhe_uint64(void* ct, uint64_t pt, void* sks) +{ + FheUint64* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint64_scalar_rotate_right(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + void* eq_fhe_uint4(void* ct1, void* ct2, void* sks) { FheBool* result = NULL; diff --git a/fhevm/tfhe/tfhe_wrappers.h b/fhevm/tfhe/tfhe_wrappers.h index 83d7882..1a4f1ce 100644 --- a/fhevm/tfhe/tfhe_wrappers.h +++ b/fhevm/tfhe/tfhe_wrappers.h @@ -231,6 +231,46 @@ void* scalar_shr_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_shr_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* rotl_fhe_uint4(void* ct1, void* ct2, void* sks); + +void* rotl_fhe_uint8(void* ct1, void* ct2, void* sks); + +void* rotl_fhe_uint16(void* ct1, void* ct2, void* sks); + +void* rotl_fhe_uint32(void* ct1, void* ct2, void* sks); + +void* rotl_fhe_uint64(void* ct1, void* ct2, void* sks); + +void* scalar_rotl_fhe_uint4(void* ct, uint8_t pt, void* sks); + +void* scalar_rotl_fhe_uint8(void* ct, uint8_t pt, void* sks); + +void* scalar_rotl_fhe_uint16(void* ct, uint16_t pt, void* sks); + +void* scalar_rotl_fhe_uint32(void* ct, uint32_t pt, void* sks); + +void* scalar_rotl_fhe_uint64(void* ct, uint64_t pt, void* sks); + +void* rotr_fhe_uint4(void* ct1, void* ct2, void* sks); + +void* rotr_fhe_uint8(void* ct1, void* ct2, void* sks); + +void* rotr_fhe_uint16(void* ct1, void* ct2, void* sks); + +void* rotr_fhe_uint32(void* ct1, void* ct2, void* sks); + +void* rotr_fhe_uint64(void* ct1, void* ct2, void* sks); + +void* scalar_rotr_fhe_uint4(void* ct, uint8_t pt, void* sks); + +void* scalar_rotr_fhe_uint8(void* ct, uint8_t pt, void* sks); + +void* scalar_rotr_fhe_uint16(void* ct, uint16_t pt, void* sks); + +void* scalar_rotr_fhe_uint32(void* ct, uint32_t pt, void* sks); + +void* scalar_rotr_fhe_uint64(void* ct, uint64_t pt, void* sks); + void* eq_fhe_uint4(void* ct1, void* ct2, void* sks); void* eq_fhe_uint8(void* ct1, void* ct2, void* sks);