Skip to content
This repository has been archived by the owner on Dec 23, 2024. It is now read-only.

feat: add rotate left and rotate right #100

Merged
merged 1 commit into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions fhevm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,102 @@ func FheLibShr(t *testing.T, fheUintType tfhe.FheUintType, scalar bool) {
}
}

func FheLibRotl(t *testing.T, fheUintType tfhe.FheUintType, scalar bool) {
var lhs, rhs uint64
switch fheUintType {
case tfhe.FheUint4:
lhs = 2
rhs = 1
case tfhe.FheUint8:
lhs = 2
rhs = 1
case tfhe.FheUint16:
lhs = 4283
rhs = 2
case tfhe.FheUint32:
lhs = 1333337
rhs = 3
case tfhe.FheUint64:
lhs = 13333377777777777
lhs = 34
}
expected := lhs << rhs
signature := "fheRotl(uint256,uint256,bytes1)"
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth
addr := common.Address{}
readOnly := false
lhsHash := verifyCiphertextInTestMemory(environment, lhs, depth, fheUintType).GetHash()
var rhsHash common.Hash
if scalar {
rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes())
} else {
rhsHash = verifyCiphertextInTestMemory(environment, rhs, depth, fheUintType).GetHash()
}
input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash)
out, err := FheLibRun(environment, addr, addr, input, readOnly)
if err != nil {
t.Fatalf(err.Error())
}
res := getVerifiedCiphertextFromEVM(environment, common.BytesToHash(out))
if res == nil {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted, err := res.ciphertext.Decrypt()
if err != nil || decrypted.Uint64() != expected {
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected)
}
}

func FheLibRotr(t *testing.T, fheUintType tfhe.FheUintType, scalar bool) {
var lhs, rhs uint64
switch fheUintType {
case tfhe.FheUint4:
lhs = 2
rhs = 1
case tfhe.FheUint8:
lhs = 2
rhs = 1
case tfhe.FheUint16:
lhs = 4283
rhs = 3
case tfhe.FheUint32:
lhs = 1333337
rhs = 3
case tfhe.FheUint64:
lhs = 13333377777777777
lhs = 34
}
expected := lhs >> rhs
signature := "fheRotr(uint256,uint256,bytes1)"
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth
addr := common.Address{}
readOnly := false
lhsHash := verifyCiphertextInTestMemory(environment, lhs, depth, fheUintType).GetHash()
var rhsHash common.Hash
if scalar {
rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes())
} else {
rhsHash = verifyCiphertextInTestMemory(environment, rhs, depth, fheUintType).GetHash()
}
input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash)
out, err := FheLibRun(environment, addr, addr, input, readOnly)
if err != nil {
t.Fatalf(err.Error())
}
res := getVerifiedCiphertextFromEVM(environment, common.BytesToHash(out))
if res == nil {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted, err := res.ciphertext.Decrypt()
if err != nil || decrypted.Uint64() != expected {
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected)
}
}

func FheLibNe(t *testing.T, fheUintType tfhe.FheUintType, scalar bool) {
var lhs, rhs uint64
switch fheUintType {
Expand Down
12 changes: 12 additions & 0 deletions fhevm/fhelib.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ var fhelibMethods = []*FheLibMethod{
requiredGasFunction: fheShrRequiredGas,
runFunction: fheShrRun,
},
{
name: "fheRotl",
argTypes: "(uint256,uint256,bytes1)",
requiredGasFunction: fheRotlRequiredGas,
runFunction: fheRotlRun,
},
{
name: "fheRotr",
argTypes: "(uint256,uint256,bytes1)",
requiredGasFunction: fheRotrRequiredGas,
runFunction: fheRotrRun,
},
{
name: "fheNe",
argTypes: "(uint256,uint256,bytes1)",
Expand Down
133 changes: 133 additions & 0 deletions fhevm/operators_bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,139 @@ func fheShrRun(environment EVMEnvironment, caller common.Address, addr common.Ad
}
}


func fheRotlRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) {
input = input[:minInt(65, len(input))]

logger := environment.GetLogger()

isScalar, err := isScalarOp(input)
if err != nil {
logger.Error("fheShl can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input))
return nil, err
}

if !isScalar {
lhs, rhs, err := get2VerifiedOperands(environment, input)
otelDescribeOperands(runSpan, encryptedOperand(*lhs), encryptedOperand(*rhs))
if err != nil {
logger.Error("fheShl inputs not verified", "err", err, "input", hex.EncodeToString(input))
return nil, err
}
if lhs.fheUintType() != rhs.fheUintType() {
msg := "fheShl operand type mismatch"
logger.Error(msg, "lhs", lhs.fheUintType(), "rhs", rhs.fheUintType())
return nil, errors.New(msg)
}

// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !environment.IsCommitting() && !environment.IsEthCall() {
return importRandomCiphertext(environment, lhs.fheUintType()), nil
}

result, err := lhs.ciphertext.Rotl(rhs.ciphertext)
if err != nil {
logger.Error("fheRotl failed", "err", err)
return nil, err
}
importCiphertext(environment, result)

resultHash := result.GetHash()
logger.Info("fheRotl success", "lhs", lhs.hash().Hex(), "rhs", rhs.hash().Hex(), "result", resultHash.Hex())
return resultHash[:], nil

} else {
lhs, rhs, err := getScalarOperands(environment, input)
otelDescribeOperands(runSpan, encryptedOperand(*lhs), plainOperand(*rhs))
if err != nil {
logger.Error("fheRotl scalar inputs not verified", "err", err, "input", hex.EncodeToString(input))
return nil, err
}

// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !environment.IsCommitting() && !environment.IsEthCall() {
return importRandomCiphertext(environment, lhs.fheUintType()), nil
}

result, err := lhs.ciphertext.ScalarRotl(rhs)
if err != nil {
logger.Error("fheRotl failed", "err", err)
return nil, err
}
importCiphertext(environment, result)

resultHash := result.GetHash()
logger.Info("fheRotl scalar success", "lhs", lhs.hash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex())
return resultHash[:], nil
}
}

func fheRotrRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) {
input = input[:minInt(65, len(input))]

logger := environment.GetLogger()

isScalar, err := isScalarOp(input)
if err != nil {
logger.Error("fheRotr can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input))
return nil, err
}

if !isScalar {
lhs, rhs, err := get2VerifiedOperands(environment, input)
otelDescribeOperands(runSpan, encryptedOperand(*lhs), encryptedOperand(*rhs))
if err != nil {
logger.Error("fheRotr inputs not verified", "err", err, "input", hex.EncodeToString(input))
return nil, err
}
if lhs.fheUintType() != rhs.fheUintType() {
msg := "fheRotr operand type mismatch"
logger.Error(msg, "lhs", lhs.fheUintType(), "rhs", rhs.fheUintType())
return nil, errors.New(msg)
}

// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !environment.IsCommitting() && !environment.IsEthCall() {
return importRandomCiphertext(environment, lhs.fheUintType()), nil
}

result, err := lhs.ciphertext.Rotr(rhs.ciphertext)
if err != nil {
logger.Error("fheRotr failed", "err", err)
return nil, err
}
importCiphertext(environment, result)

resultHash := result.GetHash()
logger.Info("fheRotr success", "lhs", lhs.hash().Hex(), "rhs", rhs.hash().Hex(), "result", resultHash.Hex())
return resultHash[:], nil

} else {
lhs, rhs, err := getScalarOperands(environment, input)
otelDescribeOperands(runSpan, encryptedOperand(*lhs), plainOperand(*rhs))
if err != nil {
logger.Error("fheRotr scalar inputs not verified", "err", err, "input", hex.EncodeToString(input))
return nil, err
}

// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !environment.IsCommitting() && !environment.IsEthCall() {
return importRandomCiphertext(environment, lhs.fheUintType()), nil
}

result, err := lhs.ciphertext.ScalarRotr(rhs)
if err != nil {
logger.Error("fheRotr failed", "err", err)
return nil, err
}
importCiphertext(environment, result)

resultHash := result.GetHash()
logger.Info("fheRotr scalar success", "lhs", lhs.hash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex())
return resultHash[:], nil
}
}

func fheNegRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) {
input = input[:minInt(32, len(input))]

Expand Down
10 changes: 10 additions & 0 deletions fhevm/operators_bit_gas.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ func fheShrRequiredGas(environment EVMEnvironment, input []byte) uint64 {
return fheShlRequiredGas(environment, input)
}

func fheRotrRequiredGas(environment EVMEnvironment, input []byte) uint64 {
// Implement in terms of shl, because comparison costs are currently the same.
return fheShlRequiredGas(environment, input)
}

func fheRotlRequiredGas(environment EVMEnvironment, input []byte) uint64 {
// Implement in terms of shl, because comparison costs are currently the same.
return fheShlRequiredGas(environment, input)
}

func fheNegRequiredGas(environment EVMEnvironment, input []byte) uint64 {
input = input[:minInt(32, len(input))]

Expand Down
86 changes: 86 additions & 0 deletions fhevm/tfhe/tfhe_ciphertext.go
Original file line number Diff line number Diff line change
Expand Up @@ -1485,6 +1485,92 @@ func (lhs *TfheCiphertext) ScalarShr(rhs *big.Int) (*TfheCiphertext, error) {
fheUint160BinaryScalarNotSupportedOp, false)
}


func (lhs *TfheCiphertext) Rotl(rhs *TfheCiphertext) (*TfheCiphertext, error) {
return lhs.executeBinaryCiphertextOperation(rhs,
boolBinaryNotSupportedOp,
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotl_fhe_uint4(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotl_fhe_uint8(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotl_fhe_uint16(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotl_fhe_uint32(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotl_fhe_uint64(lhs, rhs, sks), nil
},
fheUint160BinaryNotSupportedOp, false)
}

func (lhs *TfheCiphertext) ScalarRotl(rhs *big.Int) (*TfheCiphertext, error) {
return lhs.executeBinaryScalarOperation(rhs,
boolBinaryScalarNotSupportedOp,
func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) {
return C.scalar_rotl_fhe_uint4(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) {
return C.scalar_rotl_fhe_uint8(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) {
return C.scalar_rotl_fhe_uint16(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) {
return C.scalar_rotl_fhe_uint32(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) {
return C.scalar_rotl_fhe_uint64(lhs, rhs, sks), nil
},
fheUint160BinaryScalarNotSupportedOp, false)
}

func (lhs *TfheCiphertext) Rotr(rhs *TfheCiphertext) (*TfheCiphertext, error) {
return lhs.executeBinaryCiphertextOperation(rhs,
boolBinaryNotSupportedOp,
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotr_fhe_uint4(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotr_fhe_uint8(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotr_fhe_uint16(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotr_fhe_uint32(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotr_fhe_uint64(lhs, rhs, sks), nil
},
fheUint160BinaryNotSupportedOp,
false)
}

func (lhs *TfheCiphertext) ScalarRotr(rhs *big.Int) (*TfheCiphertext, error) {
return lhs.executeBinaryScalarOperation(rhs,
boolBinaryScalarNotSupportedOp,
func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) {
return C.scalar_rotr_fhe_uint4(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) {
return C.scalar_rotr_fhe_uint8(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) {
return C.scalar_rotr_fhe_uint16(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) {
return C.scalar_rotr_fhe_uint32(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) {
return C.scalar_rotr_fhe_uint64(lhs, rhs, sks), nil
},
fheUint160BinaryScalarNotSupportedOp, false)
}

func (lhs *TfheCiphertext) Eq(rhs *TfheCiphertext) (*TfheCiphertext, error) {
return lhs.executeBinaryCiphertextOperation(rhs,
boolBinaryNotSupportedOp,
Expand Down
Loading
Loading