From 4beb114eba6cdb6ecacf14e36ffd8b0df53003f2 Mon Sep 17 00:00:00 2001 From: Levent DEMIR Date: Mon, 11 Mar 2024 14:46:53 +0100 Subject: [PATCH 1/3] feat: add 160 bit enc/dec --- README.md | 2 +- fhevm/kms/kms.pb.go | 90 +++++---- fhevm/kms/kms_grpc.pb.go | 2 +- fhevm/operators_arithmetic.go | 10 +- fhevm/operators_bit.go | 4 +- fhevm/operators_comparison.go | 16 +- fhevm/operators_crypto.go | 67 ++++++- fhevm/tfhe/tfhe_ciphertext.go | 306 ++++++++++++++++++++++++------ fhevm/tfhe/tfhe_key_management.go | 1 + fhevm/tfhe/tfhe_test.go | 179 ++++++++++++++--- fhevm/tfhe/tfhe_wrappers.c | 134 +++++++++++++ fhevm/tfhe/tfhe_wrappers.go | 78 ++++++++ fhevm/tfhe/tfhe_wrappers.h | 24 +++ proto/kms.proto | 4 +- 14 files changed, 759 insertions(+), 158 deletions(-) diff --git a/README.md b/README.md index 4cdf2a1..0dba691 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ replace( ## Regenerate protobuff files To re-generate these files, install `protoc`, `protoc-gen-go` and `protoc-gen-go-grpc` and run protoc -`cd proto && protoc --go_out=../kms --go_opt=paths=source_relative --go-grpc_out=../kms --go-grpc_opt=paths=source_relative kms.proto && cd ..`. +`cd proto && protoc --go_out=../fhevm/kms --go_opt=paths=source_relative --go-grpc_out=../fhevm/kms --go-grpc_opt=paths=source_relative kms.proto && cd ..`. ## Documentation diff --git a/fhevm/kms/kms.pb.go b/fhevm/kms/kms.pb.go index 2fa8c80..6180a86 100644 --- a/fhevm/kms/kms.pb.go +++ b/fhevm/kms/kms.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.32.0 -// protoc v4.25.2 +// protoc-gen-go v1.33.0 +// protoc v3.19.6 // source: kms.proto package kms @@ -23,12 +23,14 @@ const ( type FheType int32 const ( - FheType_Bool FheType = 0 - FheType_Euint4 FheType = 1 - FheType_Euint8 FheType = 2 - FheType_Euint16 FheType = 3 - FheType_Euint32 FheType = 4 - FheType_Euint64 FheType = 5 + FheType_Bool FheType = 0 + FheType_Euint4 FheType = 1 + FheType_Euint8 FheType = 2 + FheType_Euint16 FheType = 3 + FheType_Euint32 FheType = 4 + FheType_Euint64 FheType = 5 + FheType_Euint128 FheType = 6 + FheType_Euint160 FheType = 7 ) // Enum value maps for FheType. @@ -40,14 +42,18 @@ var ( 3: "Euint16", 4: "Euint32", 5: "Euint64", + 6: "Euint128", + 7: "Euint160", } FheType_value = map[string]int32{ - "Bool": 0, - "Euint4": 1, - "Euint8": 2, - "Euint16": 3, - "Euint32": 4, - "Euint64": 5, + "Bool": 0, + "Euint4": 1, + "Euint8": 2, + "Euint16": 3, + "Euint32": 4, + "Euint64": 5, + "Euint128": 6, + "Euint160": 7, } ) @@ -211,7 +217,7 @@ type DecryptionResponse struct { Signature []byte `protobuf:"bytes,1,opt,name=signature,proto3" json:"signature,omitempty"` FheType FheType `protobuf:"varint,2,opt,name=fhe_type,json=fheType,proto3,enum=kms.FheType" json:"fhe_type,omitempty"` - Plaintext uint64 `protobuf:"varint,3,opt,name=plaintext,proto3" json:"plaintext,omitempty"` + Plaintext []byte `protobuf:"bytes,3,opt,name=plaintext,proto3" json:"plaintext,omitempty"` } func (x *DecryptionResponse) Reset() { @@ -260,11 +266,11 @@ func (x *DecryptionResponse) GetFheType() FheType { return FheType_Bool } -func (x *DecryptionResponse) GetPlaintext() uint64 { +func (x *DecryptionResponse) GetPlaintext() []byte { if x != nil { return x.Plaintext } - return 0 + return nil } type ReencryptionRequest struct { @@ -418,7 +424,7 @@ var file_kms_proto_rawDesc = []byte{ 0x74, 0x75, 0x72, 0x65, 0x12, 0x27, 0x0a, 0x08, 0x66, 0x68, 0x65, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0c, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x46, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x52, 0x07, 0x66, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1c, 0x0a, - 0x09, 0x70, 0x6c, 0x61, 0x69, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x04, + 0x09, 0x70, 0x6c, 0x61, 0x69, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x70, 0x6c, 0x61, 0x69, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x22, 0x9a, 0x01, 0x0a, 0x13, 0x52, 0x65, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x08, 0x66, 0x68, 0x65, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, @@ -437,33 +443,35 @@ var file_kms_proto_rawDesc = []byte{ 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x12, 0x27, 0x0a, 0x08, 0x66, 0x68, 0x65, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0c, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x46, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x52, 0x07, 0x66, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, - 0x2a, 0x52, 0x0a, 0x07, 0x46, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x08, 0x0a, 0x04, 0x42, + 0x2a, 0x6e, 0x0a, 0x07, 0x46, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x08, 0x0a, 0x04, 0x42, 0x6f, 0x6f, 0x6c, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x45, 0x75, 0x69, 0x6e, 0x74, 0x34, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x45, 0x75, 0x69, 0x6e, 0x74, 0x38, 0x10, 0x02, 0x12, 0x0b, 0x0a, 0x07, 0x45, 0x75, 0x69, 0x6e, 0x74, 0x31, 0x36, 0x10, 0x03, 0x12, 0x0b, 0x0a, 0x07, 0x45, 0x75, 0x69, 0x6e, 0x74, 0x33, 0x32, 0x10, 0x04, 0x12, 0x0b, 0x0a, 0x07, 0x45, 0x75, 0x69, 0x6e, 0x74, - 0x36, 0x34, 0x10, 0x05, 0x32, 0xa3, 0x02, 0x0a, 0x0b, 0x4b, 0x6d, 0x73, 0x45, 0x6e, 0x64, 0x70, - 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x47, 0x0a, 0x14, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, - 0x5f, 0x61, 0x6e, 0x64, 0x5f, 0x64, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x12, 0x16, 0x2e, 0x6b, - 0x6d, 0x73, 0x2e, 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x44, 0x65, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4d, 0x0a, - 0x16, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x5f, 0x61, 0x6e, 0x64, 0x5f, 0x72, 0x65, - 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x12, 0x18, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x52, 0x65, - 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x19, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x52, 0x65, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3a, 0x0a, 0x07, - 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x12, 0x16, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x44, 0x65, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x17, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x40, 0x0a, 0x09, 0x52, 0x65, 0x65, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x12, 0x18, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x52, 0x65, 0x65, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x19, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x52, 0x65, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x21, 0x5a, 0x1f, 0x67, 0x69, - 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x7a, 0x61, 0x6d, 0x61, 0x2d, 0x61, 0x69, - 0x2f, 0x66, 0x68, 0x65, 0x76, 0x6d, 0x2d, 0x67, 0x6f, 0x2f, 0x6b, 0x6d, 0x73, 0x62, 0x06, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x36, 0x34, 0x10, 0x05, 0x12, 0x0c, 0x0a, 0x08, 0x45, 0x75, 0x69, 0x6e, 0x74, 0x31, 0x32, 0x38, + 0x10, 0x06, 0x12, 0x0c, 0x0a, 0x08, 0x45, 0x75, 0x69, 0x6e, 0x74, 0x31, 0x36, 0x30, 0x10, 0x07, + 0x32, 0xa3, 0x02, 0x0a, 0x0b, 0x4b, 0x6d, 0x73, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, + 0x12, 0x47, 0x0a, 0x14, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x5f, 0x61, 0x6e, 0x64, + 0x5f, 0x64, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x12, 0x16, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x44, + 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x17, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, + 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4d, 0x0a, 0x16, 0x56, 0x61, 0x6c, + 0x69, 0x64, 0x61, 0x74, 0x65, 0x5f, 0x61, 0x6e, 0x64, 0x5f, 0x72, 0x65, 0x65, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x12, 0x18, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x52, 0x65, 0x65, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, + 0x6b, 0x6d, 0x73, 0x2e, 0x52, 0x65, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3a, 0x0a, 0x07, 0x44, 0x65, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x12, 0x16, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x6b, 0x6d, + 0x73, 0x2e, 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x40, 0x0a, 0x09, 0x52, 0x65, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x12, 0x18, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x52, 0x65, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x6b, 0x6d, + 0x73, 0x2e, 0x52, 0x65, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x27, 0x5a, 0x25, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x7a, 0x61, 0x6d, 0x61, 0x2d, 0x61, 0x69, 0x2f, 0x66, 0x68, 0x65, + 0x76, 0x6d, 0x2d, 0x67, 0x6f, 0x2f, 0x66, 0x68, 0x65, 0x76, 0x6d, 0x2f, 0x6b, 0x6d, 0x73, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/fhevm/kms/kms_grpc.pb.go b/fhevm/kms/kms_grpc.pb.go index fa5c1f5..c6fc9c3 100644 --- a/fhevm/kms/kms_grpc.pb.go +++ b/fhevm/kms/kms_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.3.0 -// - protoc v4.25.2 +// - protoc v3.19.6 // source: kms.proto package kms diff --git a/fhevm/operators_arithmetic.go b/fhevm/operators_arithmetic.go index 87c99e9..d57de27 100644 --- a/fhevm/operators_arithmetic.go +++ b/fhevm/operators_arithmetic.go @@ -61,7 +61,7 @@ func fheAddRun(environment EVMEnvironment, caller common.Address, addr common.Ad return importRandomCiphertext(environment, lhs.fheUintType()), nil } - result, err := lhs.ciphertext.ScalarAdd(rhs.Uint64()) + result, err := lhs.ciphertext.ScalarAdd(rhs) if err != nil { logger.Error("fheAdd failed", "err", err) return nil, err @@ -127,7 +127,7 @@ func fheSubRun(environment EVMEnvironment, caller common.Address, addr common.Ad return importRandomCiphertext(environment, lhs.fheUintType()), nil } - result, err := lhs.ciphertext.ScalarSub(rhs.Uint64()) + result, err := lhs.ciphertext.ScalarSub(rhs) if err != nil { logger.Error("fheSub failed", "err", err) return nil, err @@ -193,7 +193,7 @@ func fheMulRun(environment EVMEnvironment, caller common.Address, addr common.Ad return importRandomCiphertext(environment, lhs.fheUintType()), nil } - result, err := lhs.ciphertext.ScalarMul(rhs.Uint64()) + result, err := lhs.ciphertext.ScalarMul(rhs) if err != nil { logger.Error("fheMul failed", "err", err) return nil, err @@ -234,7 +234,7 @@ func fheDivRun(environment EVMEnvironment, caller common.Address, addr common.Ad return importRandomCiphertext(environment, lhs.fheUintType()), nil } - result, err := lhs.ciphertext.ScalarDiv(rhs.Uint64()) + result, err := lhs.ciphertext.ScalarDiv(rhs) if err != nil { logger.Error("fheDiv failed", "err", err) return nil, err @@ -275,7 +275,7 @@ func fheRemRun(environment EVMEnvironment, caller common.Address, addr common.Ad return importRandomCiphertext(environment, lhs.fheUintType()), nil } - result, err := lhs.ciphertext.ScalarRem(rhs.Uint64()) + result, err := lhs.ciphertext.ScalarRem(rhs) if err != nil { logger.Error("fheRem failed", "err", err) return nil, err diff --git a/fhevm/operators_bit.go b/fhevm/operators_bit.go index 502f3e7..62b3e42 100644 --- a/fhevm/operators_bit.go +++ b/fhevm/operators_bit.go @@ -61,7 +61,7 @@ func fheShlRun(environment EVMEnvironment, caller common.Address, addr common.Ad return importRandomCiphertext(environment, lhs.fheUintType()), nil } - result, err := lhs.ciphertext.ScalarShl(rhs.Uint64()) + result, err := lhs.ciphertext.ScalarShl(rhs) if err != nil { logger.Error("fheShl failed", "err", err) return nil, err @@ -127,7 +127,7 @@ func fheShrRun(environment EVMEnvironment, caller common.Address, addr common.Ad return importRandomCiphertext(environment, lhs.fheUintType()), nil } - result, err := lhs.ciphertext.ScalarShr(rhs.Uint64()) + result, err := lhs.ciphertext.ScalarShr(rhs) if err != nil { logger.Error("fheShr failed", "err", err) return nil, err diff --git a/fhevm/operators_comparison.go b/fhevm/operators_comparison.go index 713f759..3df181d 100644 --- a/fhevm/operators_comparison.go +++ b/fhevm/operators_comparison.go @@ -62,7 +62,7 @@ func fheLeRun(environment EVMEnvironment, caller common.Address, addr common.Add return importRandomCiphertext(environment, tfhe.FheBool), nil } - result, err := lhs.ciphertext.ScalarLe(rhs.Uint64()) + result, err := lhs.ciphertext.ScalarLe(rhs) if err != nil { logger.Error("fheLe failed", "err", err) return nil, err @@ -128,7 +128,7 @@ func fheLtRun(environment EVMEnvironment, caller common.Address, addr common.Add return importRandomCiphertext(environment, tfhe.FheBool), nil } - result, err := lhs.ciphertext.ScalarLt(rhs.Uint64()) + result, err := lhs.ciphertext.ScalarLt(rhs) if err != nil { logger.Error("fheLt failed", "err", err) return nil, err @@ -194,7 +194,7 @@ func fheEqRun(environment EVMEnvironment, caller common.Address, addr common.Add return importRandomCiphertext(environment, tfhe.FheBool), nil } - result, err := lhs.ciphertext.ScalarEq(rhs.Uint64()) + result, err := lhs.ciphertext.ScalarEq(rhs) if err != nil { logger.Error("fheEq failed", "err", err) return nil, err @@ -260,7 +260,7 @@ func fheGeRun(environment EVMEnvironment, caller common.Address, addr common.Add return importRandomCiphertext(environment, tfhe.FheBool), nil } - result, err := lhs.ciphertext.ScalarGe(rhs.Uint64()) + result, err := lhs.ciphertext.ScalarGe(rhs) if err != nil { logger.Error("fheGe failed", "err", err) return nil, err @@ -326,7 +326,7 @@ func fheGtRun(environment EVMEnvironment, caller common.Address, addr common.Add return importRandomCiphertext(environment, tfhe.FheBool), nil } - result, err := lhs.ciphertext.ScalarGt(rhs.Uint64()) + result, err := lhs.ciphertext.ScalarGt(rhs) if err != nil { logger.Error("fheGt failed", "err", err) return nil, err @@ -392,7 +392,7 @@ func fheNeRun(environment EVMEnvironment, caller common.Address, addr common.Add return importRandomCiphertext(environment, tfhe.FheBool), nil } - result, err := lhs.ciphertext.ScalarNe(rhs.Uint64()) + result, err := lhs.ciphertext.ScalarNe(rhs) if err != nil { logger.Error("fheNe failed", "err", err) return nil, err @@ -458,7 +458,7 @@ func fheMinRun(environment EVMEnvironment, caller common.Address, addr common.Ad return importRandomCiphertext(environment, lhs.fheUintType()), nil } - result, err := lhs.ciphertext.ScalarMin(rhs.Uint64()) + result, err := lhs.ciphertext.ScalarMin(rhs) if err != nil { logger.Error("fheMin failed", "err", err) return nil, err @@ -524,7 +524,7 @@ func fheMaxRun(environment EVMEnvironment, caller common.Address, addr common.Ad return importRandomCiphertext(environment, lhs.fheUintType()), nil } - result, err := lhs.ciphertext.ScalarMax(rhs.Uint64()) + result, err := lhs.ciphertext.ScalarMax(rhs) if err != nil { logger.Error("fheMax failed", "err", err) return nil, err diff --git a/fhevm/operators_crypto.go b/fhevm/operators_crypto.go index e3db9b7..bea796d 100644 --- a/fhevm/operators_crypto.go +++ b/fhevm/operators_crypto.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "encoding/hex" "errors" + "fmt" "math/big" "time" @@ -115,6 +116,8 @@ func reencryptRun(environment EVMEnvironment, caller common.Address, addr common fheType = kms.FheType_Euint32 case tfhe.FheUint64: fheType = kms.FheType_Euint64 + case tfhe.FheUint160: + fheType = kms.FheType_Euint160 } pubKey := input[32:64] @@ -165,7 +168,6 @@ func reencryptRun(environment EVMEnvironment, caller common.Address, addr common return nil, errors.New(msg) } - func decryptRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) { input = input[:minInt(32, len(input))] @@ -205,9 +207,7 @@ func decryptRun(environment EVMEnvironment, caller common.Address, addr common.A // Always return a 32-byte big-endian integer. ret := make([]byte, 32) - bigIntValue := big.NewInt(0) - bigIntValue.SetUint64(plaintext) - bigIntValue.FillBytes(ret) + plaintext.FillBytes(ret) return ret, nil } @@ -237,7 +237,7 @@ func getCiphertextRun(environment EVMEnvironment, caller common.Address, addr co return ciphertext.bytes, nil } -func decryptValue(environment EVMEnvironment, ct *tfhe.TfheCiphertext) (uint64, error) { +func decryptValue(environment EVMEnvironment, ct *tfhe.TfheCiphertext) (*big.Int, error) { logger := environment.GetLogger() var fheType kms.FheType @@ -254,6 +254,8 @@ func decryptValue(environment EVMEnvironment, ct *tfhe.TfheCiphertext) (uint64, fheType = kms.FheType_Euint32 case tfhe.FheUint64: fheType = kms.FheType_Euint64 + case tfhe.FheUint160: + fheType = kms.FheType_Euint160 } // TODO: generate merkle proof for some data @@ -271,7 +273,7 @@ func decryptValue(environment EVMEnvironment, ct *tfhe.TfheCiphertext) (uint64, conn, err := grpc.Dial(kms.KmsEndpointAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - return 0, errors.New("kms unreachable") + return nil, errors.New("kms unreachable") } defer conn.Close() @@ -283,10 +285,59 @@ func decryptValue(environment EVMEnvironment, ct *tfhe.TfheCiphertext) (uint64, res, err := ep.Decrypt(ctx, decryptionRequest) if err != nil { logger.Error("decrypt failed", "err", err) - return 0, err + return nil, err + } + + // plaintext is a byte slice + plaintextBytes := res.Plaintext + + // Variable to hold the resulting big.Int + var plaintextBigInt *big.Int + + switch fheType { + case kms.FheType_Bool, kms.FheType_Euint4, kms.FheType_Euint8: + + if len(plaintextBytes) > 0 { + plaintextBigInt = big.NewInt(int64(plaintextBytes[0])) + } else { + return nil, errors.New("decryption resulted in empty plaintext for a single-byte FheType") + } + case kms.FheType_Euint16: + // For Euint16, ensure plaintextBytes has at least 2 bytes. + if len(plaintextBytes) >= 2 { + // Use binary.BigEndian.Uint16 to convert bytes to uint16, then to big.Int. + uintVal := binary.BigEndian.Uint16(plaintextBytes) + plaintextBigInt = big.NewInt(int64(uintVal)) + } else { + return nil, errors.New("decryption resulted in insufficient bytes for FheType_Euint16") + } + case kms.FheType_Euint32: + // Similar to Euint16, but with 4 bytes to uint32. + if len(plaintextBytes) >= 4 { + uintVal := binary.BigEndian.Uint32(plaintextBytes) + plaintextBigInt = big.NewInt(int64(uintVal)) + } else { + return nil, errors.New("decryption resulted in insufficient bytes for FheType_Euint32") + } + case kms.FheType_Euint64: + // For Euint64, ensure there are 8 bytes to work with. + if len(plaintextBytes) >= 8 { + uintVal := binary.BigEndian.Uint64(plaintextBytes) + plaintextBigInt = new(big.Int).SetUint64(uintVal) + } else { + return nil, errors.New("decryption resulted in insufficient bytes for FheType_Euint64") + } + case kms.FheType_Euint160: + logger.Info("decrypt success", "plaintextBytes", plaintextBytes) + logger.Info("decrypt success", "plaintextBytes", fmt.Sprintf("%v", plaintextBytes)) + // Special handling for FheUint160, already covered. + plaintextBigInt, err = tfhe.U256BytesToBigInt(plaintextBytes) + default: + return nil, fmt.Errorf("unsupported FheType: %v", fheType) } - return uint64(res.Plaintext), err + return plaintextBigInt, nil + } func castRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) { diff --git a/fhevm/tfhe/tfhe_ciphertext.go b/fhevm/tfhe/tfhe_ciphertext.go index ad39c0b..d2ecc4f 100644 --- a/fhevm/tfhe/tfhe_ciphertext.go +++ b/fhevm/tfhe/tfhe_ciphertext.go @@ -17,12 +17,14 @@ import ( type FheUintType uint8 const ( - FheBool FheUintType = 0 - FheUint4 FheUintType = 1 - FheUint8 FheUintType = 2 - FheUint16 FheUintType = 3 - FheUint32 FheUintType = 4 - FheUint64 FheUintType = 5 + FheBool FheUintType = 0 + FheUint4 FheUintType = 1 + FheUint8 FheUintType = 2 + FheUint16 FheUintType = 3 + FheUint32 FheUintType = 4 + FheUint64 FheUintType = 5 + FheUint128 FheUintType = 6 + FheUint160 FheUintType = 7 ) func (t FheUintType) String() string { @@ -39,13 +41,17 @@ func (t FheUintType) String() string { return "fheUint32" case FheUint64: return "fheUint64" + case FheUint128: + return "fheUint128" + case FheUint160: + return "fheUint160" default: return "unknownFheUintType" } } func IsValidFheType(t byte) bool { - if uint8(t) < uint8(FheBool) || uint8(t) > uint8(FheUint64) { + if uint8(t) < uint8(FheBool) || uint8(t) > uint8(FheUint160) { return false } return true @@ -69,6 +75,14 @@ func boolBinaryScalarNotSupportedOp(lhs unsafe.Pointer, rhs C.bool) (unsafe.Poin return nil, errors.New("Bool is not supported") } +func fheUint160BinaryNotSupportedOp(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return nil, errors.New("FHEUint160 is not supported") +} + +func fheUint160BinaryScalarNotSupportedOp(lhs unsafe.Pointer, rhs C.U256) (unsafe.Pointer, error) { + return nil, errors.New("FHEUint160 is not supported") +} + func boolUnaryNotSupportedOp(lhs unsafe.Pointer) (unsafe.Pointer, error) { return nil, errors.New("Bool is not supported") } @@ -112,6 +126,12 @@ func (ct *TfheCiphertext) Deserialize(in []byte, t FheUintType) error { return errors.New("FheUint64 ciphertext deserialization failed") } C.destroy_fhe_uint64(ptr) + case FheUint160: + ptr := C.deserialize_fhe_uint160(toDynamicBufferView((in))) + if ptr == nil { + return errors.New("FheUint160 ciphertext deserialization failed") + } + C.destroy_fhe_uint160(ptr) default: panic("deserialize: unexpected ciphertext type") } @@ -192,6 +212,17 @@ func (ct *TfheCiphertext) DeserializeCompact(in []byte, t FheUintType) error { if err != nil { return err } + case FheUint160: + ptr := C.deserialize_compact_fhe_uint160(toDynamicBufferView((in))) + if ptr == nil { + return errors.New("compact FheUint160 ciphertext deserialization failed") + } + var err error + ct.Serialization, err = serialize(ptr, t) + C.destroy_fhe_uint160(ptr) + if err != nil { + return err + } default: panic("deserializeCompact: unexpected ciphertext type") } @@ -252,6 +283,17 @@ func (ct *TfheCiphertext) Encrypt(value big.Int, t FheUintType) *TfheCiphertext if err != nil { panic(err) } + case FheUint160: + input, err := bigIntToU256(&value) + if err != nil { + panic(err) + } + ptr = C.public_key_encrypt_fhe_uint160(pks, input) + ct.Serialization, err = serialize(ptr, t) + C.destroy_fhe_uint160(ptr) + if err != nil { + panic(err) + } default: panic("encrypt: unexpected ciphertext type") } @@ -310,6 +352,17 @@ func (ct *TfheCiphertext) TrivialEncrypt(value big.Int, t FheUintType) *TfheCiph if err != nil { panic(err) } + case FheUint160: + input, err := bigIntToU256(&value) + if err != nil { + panic(err) + } + ptr = C.trivial_encrypt_fhe_uint160(sks, *input) + ct.Serialization, err = serialize(ptr, t) + C.destroy_fhe_uint160(ptr) + if err != nil { + panic(err) + } default: panic("trivialEncrypt: unexpected ciphertext type") } @@ -468,6 +521,7 @@ func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, op16 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), op32 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), op64 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), + op160 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), returnBool bool) (*TfheCiphertext, error) { if lhs.FheUintType != rhs.FheUintType { return nil, errors.New("binary operations are only well-defined for identical types") @@ -678,6 +732,40 @@ func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, } res.Serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) C.destroy_dynamic_buffer(res_ser) + case FheUint160: + lhs_ptr := C.deserialize_fhe_uint160(toDynamicBufferView((lhs.Serialization))) + if lhs_ptr == nil { + return nil, errors.New("160 bit binary op deserialization failed") + } + rhs_ptr := C.deserialize_fhe_uint160(toDynamicBufferView((rhs.Serialization))) + if rhs_ptr == nil { + C.destroy_fhe_uint160(lhs_ptr) + return nil, errors.New("160 bit binary op deserialization failed") + } + res_ptr, err := op160(lhs_ptr, rhs_ptr) + if err != nil { + return nil, err + } + C.destroy_fhe_uint160(lhs_ptr) + C.destroy_fhe_uint160(rhs_ptr) + if res_ptr == nil { + return nil, errors.New("160 bit binary op failed") + } + if returnBool { + ret := C.serialize_fhe_bool(res_ptr, res_ser) + C.destroy_fhe_bool(res_ptr) + if ret != 0 { + return nil, errors.New("bool binary op serialization failed") + } + } else { + ret := C.serialize_fhe_uint160(res_ptr, res_ser) + C.destroy_fhe_uint160(res_ptr) + if ret != 0 { + return nil, errors.New("160 bit binary op serialization failed") + } + } + res.Serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_dynamic_buffer(res_ser) default: panic("binary op unexpected ciphertext type") } @@ -851,13 +939,15 @@ func (first *TfheCiphertext) executeTernaryCiphertextOperation(lhs *TfheCipherte return res, nil } -func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, +// Update: Switched 'rhs' from uint64 to *big.Int to enable 160-bit operations (eq,ne). +func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs *big.Int, opBool func(lhs unsafe.Pointer, rhs C.bool) (unsafe.Pointer, error), op4 func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error), op8 func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error), op16 func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error), op32 func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error), op64 func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error), + op160 func(lhs unsafe.Pointer, rhs C.U256) (unsafe.Pointer, error), returnBool bool) (*TfheCiphertext, error) { res := new(TfheCiphertext) if returnBool { @@ -865,6 +955,7 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, } else { res.FheUintType = lhs.FheUintType } + rhs_uint64 := rhs.Uint64() res_ser := &C.DynamicBuffer{} switch lhs.FheUintType { case FheBool: @@ -872,7 +963,7 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, if lhs_ptr == nil { return nil, errors.New("Bool scalar op deserialization failed") } - scalar := C.bool(rhs == 1) + scalar := C.bool(rhs_uint64 == 1) res_ptr, err := opBool(lhs_ptr, scalar) if err != nil { return nil, err @@ -893,7 +984,7 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, if lhs_ptr == nil { return nil, errors.New("4 bit scalar op deserialization failed") } - scalar := C.uint8_t(rhs) + scalar := C.uint8_t(rhs_uint64) res_ptr, err := op4(lhs_ptr, scalar) if err != nil { return nil, err @@ -922,7 +1013,7 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, if lhs_ptr == nil { return nil, errors.New("8 bit scalar op deserialization failed") } - scalar := C.uint8_t(rhs) + scalar := C.uint8_t(rhs_uint64) res_ptr, err := op8(lhs_ptr, scalar) if err != nil { return nil, err @@ -951,7 +1042,7 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, if lhs_ptr == nil { return nil, errors.New("16 bit scalar op deserialization failed") } - scalar := C.uint16_t(rhs) + scalar := C.uint16_t(rhs_uint64) res_ptr, err := op16(lhs_ptr, scalar) if err != nil { return nil, err @@ -980,7 +1071,7 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, if lhs_ptr == nil { return nil, errors.New("32 bit scalar op deserialization failed") } - scalar := C.uint32_t(rhs) + scalar := C.uint32_t(rhs_uint64) res_ptr, err := op32(lhs_ptr, scalar) if err != nil { return nil, err @@ -1009,7 +1100,7 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, if lhs_ptr == nil { return nil, errors.New("64 bit scalar op deserialization failed") } - scalar := C.uint64_t(rhs) + scalar := C.uint64_t(rhs_uint64) res_ptr, err := op64(lhs_ptr, scalar) if err != nil { return nil, err @@ -1033,6 +1124,38 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, } res.Serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) C.destroy_dynamic_buffer(res_ser) + case FheUint160: + lhs_ptr := C.deserialize_fhe_uint160(toDynamicBufferView((lhs.Serialization))) + if lhs_ptr == nil { + return nil, errors.New("160 bit scalar op deserialization failed") + } + + scalar, err := bigIntToU256(rhs) + + res_ptr, err := op160(lhs_ptr, *scalar) + if err != nil { + return nil, err + } + C.destroy_fhe_uint160(lhs_ptr) + if res_ptr == nil { + return nil, errors.New("160 bit scalar op failed") + } + if returnBool { + ret := C.serialize_fhe_bool(res_ptr, res_ser) + C.destroy_fhe_bool(res_ptr) + if ret != 0 { + return nil, errors.New("Bool scalar op serialization failed") + } + } else { + ret := C.serialize_fhe_uint160(res_ptr, res_ser) + C.destroy_fhe_uint160(res_ptr) + if ret != 0 { + return nil, errors.New("160 bit scalar op serialization failed") + } + } + res.Serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_dynamic_buffer(res_ser) + default: panic("scalar op unexpected ciphertext type") } @@ -1057,10 +1180,11 @@ func (lhs *TfheCiphertext) Add(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.add_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryNotSupportedOp, false) } -func (lhs *TfheCiphertext) ScalarAdd(rhs uint64) (*TfheCiphertext, error) { +func (lhs *TfheCiphertext) ScalarAdd(rhs *big.Int) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, boolBinaryScalarNotSupportedOp, func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { @@ -1077,7 +1201,8 @@ func (lhs *TfheCiphertext) ScalarAdd(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_add_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryScalarNotSupportedOp, false) } func (lhs *TfheCiphertext) Sub(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1097,10 +1222,11 @@ func (lhs *TfheCiphertext) Sub(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.sub_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryNotSupportedOp, false) } -func (lhs *TfheCiphertext) ScalarSub(rhs uint64) (*TfheCiphertext, error) { +func (lhs *TfheCiphertext) ScalarSub(rhs *big.Int) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, boolBinaryScalarNotSupportedOp, func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { @@ -1117,7 +1243,8 @@ func (lhs *TfheCiphertext) ScalarSub(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_sub_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryScalarNotSupportedOp, false) } func (lhs *TfheCiphertext) Mul(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1137,10 +1264,11 @@ func (lhs *TfheCiphertext) Mul(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.mul_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryNotSupportedOp, false) } -func (lhs *TfheCiphertext) ScalarMul(rhs uint64) (*TfheCiphertext, error) { +func (lhs *TfheCiphertext) ScalarMul(rhs *big.Int) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, boolBinaryScalarNotSupportedOp, func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { @@ -1157,10 +1285,11 @@ func (lhs *TfheCiphertext) ScalarMul(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_mul_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryScalarNotSupportedOp, false) } -func (lhs *TfheCiphertext) ScalarDiv(rhs uint64) (*TfheCiphertext, error) { +func (lhs *TfheCiphertext) ScalarDiv(rhs *big.Int) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, boolBinaryScalarNotSupportedOp, func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { @@ -1177,10 +1306,11 @@ func (lhs *TfheCiphertext) ScalarDiv(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_div_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryScalarNotSupportedOp, false) } -func (lhs *TfheCiphertext) ScalarRem(rhs uint64) (*TfheCiphertext, error) { +func (lhs *TfheCiphertext) ScalarRem(rhs *big.Int) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, boolBinaryScalarNotSupportedOp, func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { @@ -1197,7 +1327,8 @@ func (lhs *TfheCiphertext) ScalarRem(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_rem_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryScalarNotSupportedOp, false) } func (lhs *TfheCiphertext) Bitand(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1219,7 +1350,8 @@ func (lhs *TfheCiphertext) Bitand(rhs *TfheCiphertext) (*TfheCiphertext, error) }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.bitand_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryNotSupportedOp, false) } func (lhs *TfheCiphertext) Bitor(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1241,7 +1373,8 @@ func (lhs *TfheCiphertext) Bitor(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.bitor_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryNotSupportedOp, false) } func (lhs *TfheCiphertext) Bitxor(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1263,7 +1396,8 @@ func (lhs *TfheCiphertext) Bitxor(rhs *TfheCiphertext) (*TfheCiphertext, error) }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.bitxor_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryNotSupportedOp, false) } func (lhs *TfheCiphertext) Shl(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1283,10 +1417,11 @@ func (lhs *TfheCiphertext) Shl(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.shl_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryNotSupportedOp, false) } -func (lhs *TfheCiphertext) ScalarShl(rhs uint64) (*TfheCiphertext, error) { +func (lhs *TfheCiphertext) ScalarShl(rhs *big.Int) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, boolBinaryScalarNotSupportedOp, func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { @@ -1303,7 +1438,8 @@ func (lhs *TfheCiphertext) ScalarShl(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_shl_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryScalarNotSupportedOp, false) } func (lhs *TfheCiphertext) Shr(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1323,10 +1459,12 @@ func (lhs *TfheCiphertext) Shr(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.shr_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryNotSupportedOp, + false) } -func (lhs *TfheCiphertext) ScalarShr(rhs uint64) (*TfheCiphertext, error) { +func (lhs *TfheCiphertext) ScalarShr(rhs *big.Int) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, boolBinaryScalarNotSupportedOp, func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { @@ -1343,7 +1481,8 @@ func (lhs *TfheCiphertext) ScalarShr(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_shr_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryScalarNotSupportedOp, false) } func (lhs *TfheCiphertext) Eq(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1363,10 +1502,14 @@ func (lhs *TfheCiphertext) Eq(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.eq_fhe_uint64(lhs, rhs, sks), nil - }, true) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.eq_fhe_uint160(lhs, rhs, sks), nil + }, + true) } -func (lhs *TfheCiphertext) ScalarEq(rhs uint64) (*TfheCiphertext, error) { +func (lhs *TfheCiphertext) ScalarEq(rhs *big.Int) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, boolBinaryScalarNotSupportedOp, func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { @@ -1383,7 +1526,11 @@ func (lhs *TfheCiphertext) ScalarEq(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_eq_fhe_uint64(lhs, rhs, sks), nil - }, true) + }, + func(lhs unsafe.Pointer, rhs C.U256) (unsafe.Pointer, error) { + return C.scalar_eq_fhe_uint160(lhs, rhs, sks), nil + }, + true) } func (lhs *TfheCiphertext) Ne(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1403,10 +1550,14 @@ func (lhs *TfheCiphertext) Ne(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.ne_fhe_uint64(lhs, rhs, sks), nil - }, true) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.ne_fhe_uint160(lhs, rhs, sks), nil + }, + true) } -func (lhs *TfheCiphertext) ScalarNe(rhs uint64) (*TfheCiphertext, error) { +func (lhs *TfheCiphertext) ScalarNe(rhs *big.Int) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, boolBinaryScalarNotSupportedOp, func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { @@ -1423,7 +1574,11 @@ func (lhs *TfheCiphertext) ScalarNe(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_ne_fhe_uint64(lhs, rhs, sks), nil - }, true) + }, + func(lhs unsafe.Pointer, rhs C.U256) (unsafe.Pointer, error) { + return C.scalar_ne_fhe_uint160(lhs, rhs, sks), nil + }, + true) } func (lhs *TfheCiphertext) Ge(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1443,10 +1598,12 @@ func (lhs *TfheCiphertext) Ge(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.ge_fhe_uint64(lhs, rhs, sks), nil - }, true) + }, + fheUint160BinaryNotSupportedOp, + true) } -func (lhs *TfheCiphertext) ScalarGe(rhs uint64) (*TfheCiphertext, error) { +func (lhs *TfheCiphertext) ScalarGe(rhs *big.Int) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, boolBinaryScalarNotSupportedOp, func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { @@ -1463,7 +1620,8 @@ func (lhs *TfheCiphertext) ScalarGe(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_ge_fhe_uint64(lhs, rhs, sks), nil - }, true) + }, fheUint160BinaryScalarNotSupportedOp, + true) } func (lhs *TfheCiphertext) Gt(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1483,10 +1641,12 @@ func (lhs *TfheCiphertext) Gt(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.gt_fhe_uint64(lhs, rhs, sks), nil - }, true) + }, + fheUint160BinaryNotSupportedOp, + true) } -func (lhs *TfheCiphertext) ScalarGt(rhs uint64) (*TfheCiphertext, error) { +func (lhs *TfheCiphertext) ScalarGt(rhs *big.Int) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, boolBinaryScalarNotSupportedOp, func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { @@ -1503,7 +1663,8 @@ func (lhs *TfheCiphertext) ScalarGt(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_gt_fhe_uint64(lhs, rhs, sks), nil - }, true) + }, fheUint160BinaryScalarNotSupportedOp, + true) } func (lhs *TfheCiphertext) Le(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1523,10 +1684,12 @@ func (lhs *TfheCiphertext) Le(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.le_fhe_uint64(lhs, rhs, sks), nil - }, true) + }, + fheUint160BinaryNotSupportedOp, + true) } -func (lhs *TfheCiphertext) ScalarLe(rhs uint64) (*TfheCiphertext, error) { +func (lhs *TfheCiphertext) ScalarLe(rhs *big.Int) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, boolBinaryScalarNotSupportedOp, func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { @@ -1544,7 +1707,9 @@ func (lhs *TfheCiphertext) ScalarLe(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_le_fhe_uint64(lhs, rhs, sks), nil - }, true) + }, + fheUint160BinaryScalarNotSupportedOp, + true) } func (lhs *TfheCiphertext) Lt(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1564,10 +1729,12 @@ func (lhs *TfheCiphertext) Lt(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.lt_fhe_uint64(lhs, rhs, sks), nil - }, true) + }, + fheUint160BinaryNotSupportedOp, + true) } -func (lhs *TfheCiphertext) ScalarLt(rhs uint64) (*TfheCiphertext, error) { +func (lhs *TfheCiphertext) ScalarLt(rhs *big.Int) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, boolBinaryScalarNotSupportedOp, func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { @@ -1584,7 +1751,9 @@ func (lhs *TfheCiphertext) ScalarLt(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_lt_fhe_uint64(lhs, rhs, sks), nil - }, true) + }, + fheUint160BinaryScalarNotSupportedOp, + true) } func (lhs *TfheCiphertext) Min(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1604,10 +1773,11 @@ func (lhs *TfheCiphertext) Min(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.min_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryNotSupportedOp, false) } -func (lhs *TfheCiphertext) ScalarMin(rhs uint64) (*TfheCiphertext, error) { +func (lhs *TfheCiphertext) ScalarMin(rhs *big.Int) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, boolBinaryScalarNotSupportedOp, func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { @@ -1624,7 +1794,8 @@ func (lhs *TfheCiphertext) ScalarMin(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_min_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryScalarNotSupportedOp, false) } func (lhs *TfheCiphertext) Max(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1644,10 +1815,11 @@ func (lhs *TfheCiphertext) Max(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.max_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryNotSupportedOp, false) } -func (lhs *TfheCiphertext) ScalarMax(rhs uint64) (*TfheCiphertext, error) { +func (lhs *TfheCiphertext) ScalarMax(rhs *big.Int) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, boolBinaryScalarNotSupportedOp, func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { @@ -1664,7 +1836,8 @@ func (lhs *TfheCiphertext) ScalarMax(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_max_fhe_uint64(lhs, rhs, sks), nil - }, false) + }, + fheUint160BinaryScalarNotSupportedOp, false) } func (lhs *TfheCiphertext) Neg() (*TfheCiphertext, error) { @@ -2237,6 +2410,19 @@ func (ct *TfheCiphertext) Decrypt() (big.Int, error) { ret = C.decrypt_fhe_uint64(cks, ptr, &result) C.destroy_fhe_uint64(ptr) value = uint64(result) + case FheUint160: + ptr := C.deserialize_fhe_uint160(toDynamicBufferView(ct.Serialization)) + if ptr == nil { + return *new(big.Int).SetUint64(0), errors.New("failed to deserialize FheUint160") + } + var result C.U256 + ret = C.decrypt_fhe_uint160(cks, ptr, &result) + if ret != 0 { + return *new(big.Int).SetUint64(0), errors.New("failed to decrypt FheUint160") + } + C.destroy_fhe_uint160(ptr) + resultBigInt := *u256ToBigInt(result) + return resultBigInt, nil default: panic("decrypt: unexpected ciphertext type") } diff --git a/fhevm/tfhe/tfhe_key_management.go b/fhevm/tfhe/tfhe_key_management.go index 9be572a..4daf6d1 100644 --- a/fhevm/tfhe/tfhe_key_management.go +++ b/fhevm/tfhe/tfhe_key_management.go @@ -79,6 +79,7 @@ func initCiphertextSizes() { compactFheCiphertextSize[FheUint16] = uint(len(EncryptAndSerializeCompact(0, FheUint16))) compactFheCiphertextSize[FheUint32] = uint(len(EncryptAndSerializeCompact(0, FheUint32))) compactFheCiphertextSize[FheUint64] = uint(len(EncryptAndSerializeCompact(0, FheUint64))) + compactFheCiphertextSize[FheUint160] = uint(len(EncryptAndSerializeCompact(0, FheUint160))) } func InitGlobalKeysFromFiles(keysDir string) error { diff --git a/fhevm/tfhe/tfhe_test.go b/fhevm/tfhe/tfhe_test.go index 167440c..221b083 100644 --- a/fhevm/tfhe/tfhe_test.go +++ b/fhevm/tfhe/tfhe_test.go @@ -2,7 +2,9 @@ package tfhe import ( "bytes" + "encoding/hex" "fmt" + "log" "math" "math/big" "os" @@ -37,12 +39,25 @@ func TfheEncryptDecrypt(t *testing.T, fheUintType FheUintType) { val.SetUint64(1333337) case FheUint64: val.SetUint64(13333377777777777) + + case FheUint160: + hexValue := "12345676876661323221435343" + byteValue, err := hex.DecodeString(hexValue) + if err != nil { + log.Fatalf("Failed to decode hex string: %v", err) + } + val.SetBytes(byteValue) } ct := new(TfheCiphertext) ct.Encrypt(val, fheUintType) res, err := ct.Decrypt() - if err != nil || res.Uint64() != val.Uint64() { - t.Fatalf("%d != %d", val.Uint64(), res.Uint64()) + + if err != nil { + t.Fatalf("Decryption failed: %v", err) + } + + if res.Cmp(&val) != 0 { + t.Fatalf("Decryption result does not match the original value. Expected %s, got %s", val.Text(10), res.Text(10)) } } @@ -61,12 +76,23 @@ func TfheTrivialEncryptDecrypt(t *testing.T, fheUintType FheUintType) { val.SetUint64(1333337) case FheUint64: val.SetUint64(13333377777777777) + case FheUint160: + hexValue := "12345676876661323221435343" + byteValue, err := hex.DecodeString(hexValue) + if err != nil { + log.Fatalf("Failed to decode hex string: %v", err) + } + val.SetBytes(byteValue) } ct := new(TfheCiphertext) ct.TrivialEncrypt(val, fheUintType) res, err := ct.Decrypt() - if err != nil || res.Uint64() != val.Uint64() { - t.Fatalf("%d != %d", val.Uint64(), res.Uint64()) + if err != nil { + t.Fatalf("Decryption failed: %v", err) + } + + if res.Cmp(&val) != 0 { + t.Fatalf("Decryption result does not match the original value. Expected %s, got %s", val.Text(10), res.Text(10)) } } @@ -85,6 +111,13 @@ func TfheSerializeDeserialize(t *testing.T, fheUintType FheUintType) { val = *big.NewInt(1333337) case FheUint64: val = *big.NewInt(13333377777777777) + case FheUint160: + hexValue := "12345676876661323221435343" + byteValue, err := hex.DecodeString(hexValue) + if err != nil { + log.Fatalf("Failed to decode hex string: %v", err) + } + val.SetBytes(byteValue) } ct1 := new(TfheCiphertext) ct1.Encrypt(val, fheUintType) @@ -115,6 +148,8 @@ func TfheSerializeDeserializeCompact(t *testing.T, fheUintType FheUintType) { val = 1333337 case FheUint64: val = 13333377777777777 + case FheUint160: + val = 13333377777777777 } ser := EncryptAndSerializeCompact(val, fheUintType) @@ -157,6 +192,13 @@ func TfheTrivialSerializeDeserialize(t *testing.T, fheUintType FheUintType) { val = *big.NewInt(1333337) case FheUint64: val = *big.NewInt(13333377777777777) + case FheUint160: + hexValue := "12345676876661323221435343" + byteValue, err := hex.DecodeString(hexValue) + if err != nil { + log.Fatalf("Failed to decode hex string: %v", err) + } + val.SetBytes(byteValue) } ct1 := new(TfheCiphertext) ct1.TrivialEncrypt(val, fheUintType) @@ -198,6 +240,7 @@ func TfheDeserializeCompact(t *testing.T, fheUintType FheUintType) { case FheUint64: val = 13333377777777777 } + ser := EncryptAndSerializeCompact(val, fheUintType) ct := new(TfheCiphertext) err := ct.DeserializeCompact(ser, fheUintType) @@ -271,7 +314,7 @@ func TfheScalarAdd(t *testing.T, fheUintType FheUintType) { expected := new(big.Int).Add(&a, &b) ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) - ctRes, _ := ctA.ScalarAdd(b.Uint64()) + ctRes, _ := ctA.ScalarAdd(&b) res, err := ctRes.Decrypt() if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) @@ -331,7 +374,7 @@ func TfheScalarSub(t *testing.T, fheUintType FheUintType) { expected := new(big.Int).Sub(&a, &b) ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) - ctRes, _ := ctA.ScalarSub(b.Uint64()) + ctRes, _ := ctA.ScalarSub(&b) res, err := ctRes.Decrypt() if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) @@ -391,7 +434,7 @@ func TfheScalarMul(t *testing.T, fheUintType FheUintType) { expected := new(big.Int).Mul(&a, &b) ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) - ctRes, _ := ctA.ScalarMul(b.Uint64()) + ctRes, _ := ctA.ScalarMul(&b) res, err := ctRes.Decrypt() if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) @@ -420,7 +463,7 @@ func TfheScalarDiv(t *testing.T, fheUintType FheUintType) { expected := new(big.Int).Div(&a, &b) ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) - ctRes, _ := ctA.ScalarDiv(b.Uint64()) + ctRes, _ := ctA.ScalarDiv(&b) res, err := ctRes.Decrypt() if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) @@ -449,7 +492,7 @@ func TfheScalarRem(t *testing.T, fheUintType FheUintType) { expected := new(big.Int).Rem(&a, &b) ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) - ctRes, _ := ctA.ScalarRem(b.Uint64()) + ctRes, _ := ctA.ScalarRem(&b) res, err := ctRes.Decrypt() if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) @@ -605,7 +648,7 @@ func TfheScalarShl(t *testing.T, fheUintType FheUintType) { expected := new(big.Int).Lsh(&a, uint(b.Uint64())) ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) - ctRes, _ := ctA.ScalarShl(b.Uint64()) + ctRes, _ := ctA.ScalarShl(&b) res, err := ctRes.Decrypt() if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) @@ -665,7 +708,7 @@ func TfheScalarShr(t *testing.T, fheUintType FheUintType) { expected := new(big.Int).Rsh(&a, uint(b.Uint64())) ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) - ctRes, _ := ctA.ScalarShr(b.Uint64()) + ctRes, _ := ctA.ScalarShr(&b) res, err := ctRes.Decrypt() if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) @@ -690,14 +733,24 @@ func TfheEq(t *testing.T, fheUintType FheUintType) { case FheUint64: a.SetUint64(1337) b.SetUint64(1337) + case FheUint160: + hexValue := "12345676876661323221435343" + byteValue, err := hex.DecodeString(hexValue) + if err != nil { + log.Fatalf("Failed to decode hex string: %v", err) + } + a.SetBytes(byteValue) + b.SetBytes(byteValue) } + var expected uint64 - expectedBool := a.Uint64() == b.Uint64() - if expectedBool { + expectedPlain := a.Cmp(&b) + if expectedPlain == 0 { expected = 1 } else { expected = 0 } + ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) ctB := new(TfheCiphertext) @@ -707,6 +760,7 @@ func TfheEq(t *testing.T, fheUintType FheUintType) { if err != nil || res.Uint64() != expected { t.Fatalf("%d != %d", expected, res.Uint64()) } + } func TfheScalarEq(t *testing.T, fheUintType FheUintType) { @@ -727,17 +781,25 @@ func TfheScalarEq(t *testing.T, fheUintType FheUintType) { case FheUint64: a.SetUint64(13371337) b.SetUint64(1337) + case FheUint160: + hexValue := "12345676876661323221435343" + byteValue, err := hex.DecodeString(hexValue) + if err != nil { + log.Fatalf("Failed to decode hex string: %v", err) + } + a.SetBytes(byteValue) + b.SetBytes(byteValue) } var expected uint64 - expectedBool := a.Uint64() == b.Uint64() - if expectedBool { + expectedPlain := a.Cmp(&b) + if expectedPlain == 0 { expected = 1 } else { expected = 0 } ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) - ctRes, _ := ctA.ScalarEq(b.Uint64()) + ctRes, _ := ctA.ScalarEq(&b) res, err := ctRes.Decrypt() if err != nil || res.Uint64() != expected { t.Fatalf("%d != %d", expected, res.Uint64()) @@ -762,13 +824,22 @@ func TfheNe(t *testing.T, fheUintType FheUintType) { case FheUint64: a.SetUint64(1337) b.SetUint64(1337) + case FheUint160: + hexValue := "12345676876661323221435343" + byteValue, err := hex.DecodeString(hexValue) + if err != nil { + log.Fatalf("Failed to decode hex string: %v", err) + } + a.SetBytes(byteValue) + b.SetUint64(8888) } + var expected uint64 - expectedBool := a.Uint64() != b.Uint64() - if expectedBool { - expected = 1 - } else { + expectedPlain := a.Cmp(&b) + if expectedPlain == 0 { expected = 0 + } else { + expected = 1 } ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) @@ -799,17 +870,27 @@ func TfheScalarNe(t *testing.T, fheUintType FheUintType) { case FheUint64: a.SetUint64(13371337) b.SetUint64(1337) + case FheUint160: + hexValue := "12345676876661323221435343" + byteValue, err := hex.DecodeString(hexValue) + if err != nil { + log.Fatalf("Failed to decode hex string: %v", err) + } + a.SetBytes(byteValue) + b.SetUint64(8888) } + var expected uint64 - expectedBool := a.Uint64() != b.Uint64() - if expectedBool { - expected = 1 - } else { + // No != for big.Int + expectedPlain := a.Cmp(&b) + if expectedPlain == 0 { expected = 0 + } else { + expected = 1 } ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) - ctRes, _ := ctA.ScalarNe(b.Uint64()) + ctRes, _ := ctA.ScalarNe(&b) res, err := ctRes.Decrypt() if err != nil || res.Uint64() != expected { t.Fatalf("%d != %d", expected, res.Uint64()) @@ -872,7 +953,7 @@ func TfheScalarGe(t *testing.T, fheUintType FheUintType) { } ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) - ctRes1, _ := ctA.ScalarGe(b.Uint64()) + ctRes1, _ := ctA.ScalarGe(&b) res1, err := ctRes1.Decrypt() if err != nil || res1.Uint64() != 1 { t.Fatalf("%d != %d", 0, res1.Uint64()) @@ -935,7 +1016,7 @@ func TfheScalarGt(t *testing.T, fheUintType FheUintType) { } ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) - ctRes1, _ := ctA.ScalarGt(b.Uint64()) + ctRes1, _ := ctA.ScalarGt(&b) res1, err := ctRes1.Decrypt() if err != nil || res1.Uint64() != 1 { t.Fatalf("%d != %d", 0, res1.Uint64()) @@ -998,7 +1079,7 @@ func TfheScalarLe(t *testing.T, fheUintType FheUintType) { } ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) - ctRes1, _ := ctA.ScalarLe(b.Uint64()) + ctRes1, _ := ctA.ScalarLe(&b) res1, err := ctRes1.Decrypt() if err != nil || res1.Uint64() != 0 { t.Fatalf("%d != %d", 0, res1.Uint64()) @@ -1061,7 +1142,7 @@ func TfheScalarLt(t *testing.T, fheUintType FheUintType) { } ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) - ctRes1, _ := ctA.ScalarLt(b.Uint64()) + ctRes1, _ := ctA.ScalarLt(&b) res1, err := ctRes1.Decrypt() if err != nil || res1.Uint64() != 0 { t.Fatalf("%d != %d", 0, res1.Uint64()) @@ -1124,7 +1205,7 @@ func TfheScalarMin(t *testing.T, fheUintType FheUintType) { } ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) - ctRes1, _ := ctA.ScalarMin(b.Uint64()) + ctRes1, _ := ctA.ScalarMin(&b) res1, err1 := ctRes1.Decrypt() if err1 != nil || res1.Uint64() != b.Uint64() { t.Fatalf("%d != %d", 0, res1.Uint64()) @@ -1187,7 +1268,7 @@ func TfheScalarMax(t *testing.T, fheUintType FheUintType) { } ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) - ctRes1, _ := ctA.ScalarMax(b.Uint64()) + ctRes1, _ := ctA.ScalarMax(&b) res1, err1 := ctRes1.Decrypt() if err1 != nil || res1.Uint64() != a.Uint64() { t.Fatalf("%d != %d", 0, res1.Uint64()) @@ -1365,6 +1446,10 @@ func TestTfheEncryptDecrypt64(t *testing.T) { TfheEncryptDecrypt(t, FheUint64) } +func TestTfheEncryptDecrypt160(t *testing.T) { + TfheEncryptDecrypt(t, FheUint160) +} + func TestTfheTrivialEncryptDecryptBool(t *testing.T) { TfheTrivialEncryptDecrypt(t, FheBool) } @@ -1389,6 +1474,10 @@ func TestTfheTrivialEncryptDecrypt64(t *testing.T) { TfheTrivialEncryptDecrypt(t, FheUint64) } +func TestTfheTrivialEncryptDecrypt160(t *testing.T) { + TfheTrivialEncryptDecrypt(t, FheUint160) +} + func TestTfheSerializeDeserializeBool(t *testing.T) { TfheSerializeDeserialize(t, FheBool) } @@ -1413,6 +1502,10 @@ func TestTfheSerializeDeserialize64(t *testing.T) { TfheSerializeDeserialize(t, FheUint64) } +func TestTfheSerializeDeserialize160(t *testing.T) { + TfheSerializeDeserialize(t, FheUint160) +} + func TestTfheSerializeDeserializeCompactBool(t *testing.T) { TfheSerializeDeserializeCompact(t, FheBool) } @@ -1433,6 +1526,10 @@ func TestTfheSerializeDeserializeCompact64(t *testing.T) { TfheSerializeDeserializeCompact(t, FheUint64) } +func TestTfheSerializeDeserializeCompact160(t *testing.T) { + TfheSerializeDeserializeCompact(t, FheUint160) +} + func TestTfheTrivialSerializeDeserializeBool(t *testing.T) { TfheTrivialSerializeDeserialize(t, FheBool) } @@ -1457,6 +1554,10 @@ func TestTfheTrivialSerializeDeserialize64(t *testing.T) { TfheTrivialSerializeDeserialize(t, FheUint64) } +func TestTfheTrivialSerializeDeserialize160(t *testing.T) { + TfheTrivialSerializeDeserialize(t, FheUint160) +} + func TestTfheDeserializeFailureBool(t *testing.T) { TfheDeserializeFailure(t, FheBool) } @@ -1845,6 +1946,10 @@ func TestTfheEq64(t *testing.T) { TfheEq(t, FheUint64) } +func TestTfheEq160(t *testing.T) { + TfheEq(t, FheUint160) +} + func TestTfheScalarEq4(t *testing.T) { TfheScalarEq(t, FheUint4) } @@ -1865,6 +1970,10 @@ func TestTfheScalarEq64(t *testing.T) { TfheScalarEq(t, FheUint64) } +func TestTfheScalarEq160(t *testing.T) { + TfheScalarEq(t, FheUint160) +} + func TestTfheNe4(t *testing.T) { TfheNe(t, FheUint8) } @@ -1885,6 +1994,10 @@ func TestTfheNe64(t *testing.T) { TfheNe(t, FheUint64) } +func TestTfheNe160(t *testing.T) { + TfheNe(t, FheUint160) +} + func TestTfheScalarNe4(t *testing.T) { TfheScalarNe(t, FheUint4) } @@ -1905,6 +2018,10 @@ func TestTfheScalarNe64(t *testing.T) { TfheScalarNe(t, FheUint64) } +func TestTfheScalarNe160(t *testing.T) { + TfheScalarNe(t, FheUint160) +} + func TestTfheGe4(t *testing.T) { TfheGe(t, FheUint4) } diff --git a/fhevm/tfhe/tfhe_wrappers.c b/fhevm/tfhe/tfhe_wrappers.c index 31efeec..9a704a2 100644 --- a/fhevm/tfhe/tfhe_wrappers.c +++ b/fhevm/tfhe/tfhe_wrappers.c @@ -351,6 +351,46 @@ void* deserialize_compact_fhe_uint64(DynamicBufferView in) { return ct; } + +int serialize_fhe_uint160(void *ct, DynamicBuffer* out) { + return fhe_uint160_serialize(ct, out); +} + +void* deserialize_fhe_uint160(DynamicBufferView in) { + FheUint160* ct = NULL; + const int r = fhe_uint160_deserialize(in, &ct); + if(r != 0) { + return NULL; + } + return ct; +} + + +void* deserialize_compact_fhe_uint160(DynamicBufferView in) { + CompactFheUint160List* list = NULL; + FheUint160* ct = NULL; + + int r = compact_fhe_uint160_list_deserialize(in, &list); + if(r != 0) { + return NULL; + } + size_t len = 0; + r = compact_fhe_uint160_list_len(list, &len); + // Expect only 1 ciphertext in the list. + if(r != 0 || len != 1) { + r = compact_fhe_uint160_list_destroy(list); + assert(r == 0); + return NULL; + } + r = compact_fhe_uint160_list_expand(list, &ct, 1); + if(r != 0) { + ct = NULL; + } + r = compact_fhe_uint160_list_destroy(list); + assert(r == 0); + return ct; +} + void destroy_fhe_bool(void* ct) { const int r = fhe_bool_destroy(ct); assert(r == 0); @@ -381,6 +421,11 @@ void destroy_fhe_uint64(void* ct) { assert(r == 0); } +void destroy_fhe_uint160(void* ct) { + const int r = fhe_uint160_destroy(ct); + assert(r == 0); +} + void* add_fhe_uint4(void* ct1, void* ct2, void* sks) { FheUint4* result = NULL; @@ -1294,6 +1339,17 @@ void* eq_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* eq_fhe_uint160(void* ct1, void* ct2, void* sks) +{ + FheBool* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint160_eq(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + void* scalar_eq_fhe_uint4(void* ct, uint8_t pt, void* sks) { FheBool* result = NULL; @@ -1349,6 +1405,17 @@ void* scalar_eq_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* scalar_eq_fhe_uint160(void* ct, struct U256 pt, void* sks) +{ + FheBool* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint160_scalar_eq(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + void* ne_fhe_uint4(void* ct1, void* ct2, void* sks) { FheBool* result = NULL; @@ -1404,6 +1471,17 @@ void* ne_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* ne_fhe_uint160(void* ct1, void* ct2, void* sks) +{ + FheBool* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint160_ne(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + void* scalar_ne_fhe_uint4(void* ct, uint8_t pt, void* sks) { FheBool* result = NULL; @@ -1459,6 +1537,17 @@ void* scalar_ne_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* scalar_ne_fhe_uint160(void* ct, struct U256 pt, void* sks) +{ + FheBool* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint160_scalar_ne(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + void* ge_fhe_uint4(void* ct1, void* ct2, void* sks) { FheBool* result = NULL; @@ -2320,6 +2409,11 @@ int decrypt_fhe_uint64(void* cks, void* ct, uint64_t* res) return fhe_uint64_decrypt(ct, cks, res); } +int decrypt_fhe_uint160(void* cks, void* ct, struct U256 *res) +{ + return fhe_uint160_decrypt(ct, cks, res); +} + void* public_key_encrypt_fhe_bool(void* pks, bool value) { CompactFheBoolList* list = NULL; FheBool* ct = NULL; @@ -2416,6 +2510,22 @@ void* public_key_encrypt_fhe_uint64(void* pks, uint64_t value) { return ct; } +void* public_key_encrypt_fhe_uint160(void* pks, struct U256 *value) { + CompactFheUint160List* list = NULL; + FheUint160* ct = NULL; + + int r = compact_fhe_uint160_list_try_encrypt_with_compact_public_key_u256(value, 1, pks, &list); + assert(r == 0); + + r = compact_fhe_uint160_list_expand(list, &ct, 1); + assert(r == 0); + + r = compact_fhe_uint160_list_destroy(list); + assert(r == 0); + + return ct; +} + void* trivial_encrypt_fhe_bool(void* sks, bool value) { FheBool* ct = NULL; @@ -2482,6 +2592,17 @@ void* trivial_encrypt_fhe_uint64(void* sks, uint64_t value) { return ct; } +void* trivial_encrypt_fhe_uint160(void* sks, struct U256 value) { + FheUint160* ct = NULL; + + checked_set_server_key(sks); + + int r = fhe_uint160_try_encrypt_trivial_u256(value, &ct); + assert(r == 0); + + return ct; +} + void public_key_encrypt_and_serialize_fhe_bool_list(void* pks, bool value, DynamicBuffer* out) { CompactFheBoolList* list = NULL; @@ -2560,6 +2681,19 @@ void public_key_encrypt_and_serialize_fhe_uint64_list(void* pks, uint64_t value, assert(r == 0); } +void public_key_encrypt_and_serialize_fhe_uint160_list(void* pks, struct U256 *value, DynamicBuffer* out) { + CompactFheUint160List* list = NULL; + FheUint160* ct = NULL; + + int r = compact_fhe_uint160_list_try_encrypt_with_compact_public_key_u256(value, 1, pks, &list); + assert(r == 0); + + r = compact_fhe_uint160_list_serialize(list, out); + assert(r == 0); + + r = compact_fhe_uint160_list_destroy(list); + assert(r == 0); +} void* cast_4_8(void* ct, void* sks) { FheUint8* result = NULL; diff --git a/fhevm/tfhe/tfhe_wrappers.go b/fhevm/tfhe/tfhe_wrappers.go index a017be5..fc5823e 100644 --- a/fhevm/tfhe/tfhe_wrappers.go +++ b/fhevm/tfhe/tfhe_wrappers.go @@ -13,7 +13,10 @@ import "C" import ( _ "embed" + "encoding/binary" "errors" + "fmt" + "math/big" "unsafe" ) @@ -40,6 +43,8 @@ func serialize(ptr unsafe.Pointer, t FheUintType) ([]byte, error) { ret = C.serialize_fhe_uint32(ptr, out) case FheUint64: ret = C.serialize_fhe_uint64(ptr, out) + case FheUint160: + ret = C.serialize_fhe_uint160(ptr, out) default: panic("serialize: unexpected ciphertext type") } @@ -84,9 +89,82 @@ func EncryptAndSerializeCompact(value uint64, fheUintType FheUintType) []byte { C.public_key_encrypt_and_serialize_fhe_uint32_list(pks, C.uint32_t(value), out) case FheUint64: C.public_key_encrypt_and_serialize_fhe_uint64_list(pks, C.uint64_t(value), out) + case FheUint160: + // TODO + // This function is used to compute ciphertext size, the given value is generally 0, + value_big := new(big.Int).SetUint64(value) + input, err := bigIntToU256(value_big) + if err != nil { + panic(err) + } + C.public_key_encrypt_and_serialize_fhe_uint160_list(pks, input, out) } ser := C.GoBytes(unsafe.Pointer(out.pointer), C.int(out.length)) C.destroy_dynamic_buffer(out) return ser } + +// bigIntToU256 uses u256_from_big_endian_bytes to convert big.Int to U256 +func bigIntToU256(value *big.Int) (*C.U256, error) { + // Convert big.Int to 32-byte big-endian slice + bytes := value.Bytes() + if len(bytes) > 32 { + return nil, fmt.Errorf("big.Int too large for U256") + } + paddedBytes := make([]byte, 32-len(bytes)) // Padding + paddedBytes = append(paddedBytes, bytes...) + + var result C.U256 + + _, err := C.u256_from_big_endian_bytes((*C.uint8_t)(unsafe.Pointer(&paddedBytes[0])), C.size_t(32), &result) + if err != nil { + return nil, fmt.Errorf("failed to convert big.Int to U256: %v", err) + } + + return &result, nil +} + +// u256ToBigInt converts a U256 to a *big.Int. +func u256ToBigInt(u256 C.U256) *big.Int { + // Allocate a byte slice with enough space (32 bytes for U256) + buf := make([]byte, 32) + + // Call the C function to fill the buffer with the big-endian bytes of U256 + C.u256_big_endian_bytes(u256, (*C.uint8_t)(unsafe.Pointer(&buf[0])), C.size_t(len(buf))) + + return new(big.Int).SetBytes(buf) +} + +// U256BytesToBigInt takes a 32-byte big-endian slice and returns a big.Int. +func U256BytesToBigInt(plaintextBytes []byte) (*big.Int, error) { + if len(plaintextBytes) != 32 { + return nil, fmt.Errorf("byte slice is not the correct length for U256: got %d bytes, want 32", len(plaintextBytes)) + } + + // Split the byte slice into four u64 parts considering big-endian encoding + w0 := binary.BigEndian.Uint64(plaintextBytes[0:8]) + w1 := binary.BigEndian.Uint64(plaintextBytes[8:16]) + w2 := binary.BigEndian.Uint64(plaintextBytes[16:24]) + w3 := binary.BigEndian.Uint64(plaintextBytes[24:32]) + + // Print the u64 parts for verification + // fmt.Printf("U256\n") + // fmt.Printf("w0: %d\n", w0) + // fmt.Printf("w1: %d\n", w1) + // fmt.Printf("w2: %d\n", w2) + // fmt.Printf("w3: %d\n", w3) + + // Combine the u64 parts into low and high u128 parts to construct the big.Int + low := new(big.Int).SetUint64(w0) + low.Or(low, new(big.Int).Lsh(new(big.Int).SetUint64(w1), 64)) + + high := new(big.Int).SetUint64(w2) + high.Or(high, new(big.Int).Lsh(new(big.Int).SetUint64(w3), 64)) + + // Shift the high part by 128 bits to the left and add it to the low part + bigIntValue := new(big.Int).Lsh(high, 128) + bigIntValue.Add(bigIntValue, low) + + return bigIntValue, nil +} diff --git a/fhevm/tfhe/tfhe_wrappers.h b/fhevm/tfhe/tfhe_wrappers.h index 3fa8d16..83d7882 100644 --- a/fhevm/tfhe/tfhe_wrappers.h +++ b/fhevm/tfhe/tfhe_wrappers.h @@ -55,6 +55,12 @@ void* deserialize_fhe_uint64(DynamicBufferView in); void* deserialize_compact_fhe_uint64(DynamicBufferView in); +int serialize_fhe_uint160(void *ct, DynamicBuffer* out); + +void* deserialize_fhe_uint160(DynamicBufferView in); + +void* deserialize_compact_fhe_uint160(DynamicBufferView in); + void destroy_fhe_bool(void* ct); void destroy_fhe_uint4(void* ct); @@ -67,6 +73,8 @@ void destroy_fhe_uint32(void* ct); void destroy_fhe_uint64(void* ct); +void destroy_fhe_uint160(void* ct); + void* add_fhe_uint4(void* ct1, void* ct2, void* sks); void* add_fhe_uint8(void* ct1, void* ct2, void* sks); @@ -233,6 +241,8 @@ void* eq_fhe_uint32(void* ct1, void* ct2, void* sks); void* eq_fhe_uint64(void* ct1, void* ct2, void* sks); +void* eq_fhe_uint160(void* ct1, void* ct2, void* sks); + void* scalar_eq_fhe_uint4(void* ct, uint8_t pt, void* sks); void* scalar_eq_fhe_uint8(void* ct, uint8_t pt, void* sks); @@ -243,6 +253,8 @@ void* scalar_eq_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_eq_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* scalar_eq_fhe_uint160(void* ct, struct U256 pt, void* sks); + void* ne_fhe_uint4(void* ct1, void* ct2, void* sks); void* ne_fhe_uint8(void* ct1, void* ct2, void* sks); @@ -253,6 +265,8 @@ void* ne_fhe_uint32(void* ct1, void* ct2, void* sks); void* ne_fhe_uint64(void* ct1, void* ct2, void* sks); +void* ne_fhe_uint160(void* ct1, void* ct2, void* sks); + void* scalar_ne_fhe_uint4(void* ct, uint8_t pt, void* sks); void* scalar_ne_fhe_uint8(void* ct, uint8_t pt, void* sks); @@ -263,6 +277,8 @@ void* scalar_ne_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_ne_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* scalar_ne_fhe_uint160(void* ct, struct U256 pt, void* sks); + void* ge_fhe_uint4(void* ct1, void* ct2, void* sks); void* ge_fhe_uint8(void* ct1, void* ct2, void* sks); @@ -427,6 +443,8 @@ int decrypt_fhe_uint32(void* cks, void* ct, uint32_t* res); int decrypt_fhe_uint64(void* cks, void* ct, uint64_t* res); +int decrypt_fhe_uint160(void* cks, void* ct, struct U256 *res); + void* public_key_encrypt_fhe_bool(void* pks, bool value); void* public_key_encrypt_fhe_uint4(void* pks, uint8_t value); @@ -439,6 +457,8 @@ void* public_key_encrypt_fhe_uint32(void* pks, uint32_t value); void* public_key_encrypt_fhe_uint64(void* pks, uint64_t value); +void* public_key_encrypt_fhe_uint160(void* pks, struct U256 *value); + void* trivial_encrypt_fhe_bool(void* sks, bool value); void* trivial_encrypt_fhe_uint4(void* sks, uint8_t value); @@ -451,6 +471,8 @@ void* trivial_encrypt_fhe_uint32(void* sks, uint32_t value); void* trivial_encrypt_fhe_uint64(void* sks, uint64_t value); +void* trivial_encrypt_fhe_uint160(void* sks, struct U256 value); + void public_key_encrypt_and_serialize_fhe_bool_list(void* pks, bool value, DynamicBuffer* out); void public_key_encrypt_and_serialize_fhe_uint4_list(void* pks, uint8_t value, DynamicBuffer* out); @@ -463,6 +485,8 @@ void public_key_encrypt_and_serialize_fhe_uint32_list(void* pks, uint32_t value, void public_key_encrypt_and_serialize_fhe_uint64_list(void* pks, uint64_t value, DynamicBuffer* out); +void public_key_encrypt_and_serialize_fhe_uint160_list(void* pks, struct U256 *value, DynamicBuffer* out); + void* cast_bool_4(void* ct, void* sks); void* cast_bool_8(void* ct, void* sks); diff --git a/proto/kms.proto b/proto/kms.proto index 7a709ba..82422c0 100644 --- a/proto/kms.proto +++ b/proto/kms.proto @@ -17,6 +17,8 @@ enum FheType { Euint16 = 3; Euint32 = 4; Euint64 = 5; + Euint128 = 6; + Euint160 = 7; } message Proof { @@ -34,7 +36,7 @@ message DecryptionRequest { message DecryptionResponse { bytes signature = 1; FheType fhe_type = 2; - uint64 plaintext = 3; + bytes plaintext = 3; } message ReencryptionRequest { From c3abd74a359f03eef288166eed7cbde361767526 Mon Sep 17 00:00:00 2001 From: Levent DEMIR Date: Mon, 18 Mar 2024 01:44:06 +0100 Subject: [PATCH 2/3] chore: refactor deserialization of decrypted values --- fhevm/operators_crypto.go | 45 +------------------------------------ fhevm/tfhe/tfhe_wrappers.go | 34 ---------------------------- 2 files changed, 1 insertion(+), 78 deletions(-) diff --git a/fhevm/operators_crypto.go b/fhevm/operators_crypto.go index bea796d..6e1e41f 100644 --- a/fhevm/operators_crypto.go +++ b/fhevm/operators_crypto.go @@ -6,7 +6,6 @@ import ( "encoding/binary" "encoding/hex" "errors" - "fmt" "math/big" "time" @@ -292,49 +291,7 @@ func decryptValue(environment EVMEnvironment, ct *tfhe.TfheCiphertext) (*big.Int plaintextBytes := res.Plaintext // Variable to hold the resulting big.Int - var plaintextBigInt *big.Int - - switch fheType { - case kms.FheType_Bool, kms.FheType_Euint4, kms.FheType_Euint8: - - if len(plaintextBytes) > 0 { - plaintextBigInt = big.NewInt(int64(plaintextBytes[0])) - } else { - return nil, errors.New("decryption resulted in empty plaintext for a single-byte FheType") - } - case kms.FheType_Euint16: - // For Euint16, ensure plaintextBytes has at least 2 bytes. - if len(plaintextBytes) >= 2 { - // Use binary.BigEndian.Uint16 to convert bytes to uint16, then to big.Int. - uintVal := binary.BigEndian.Uint16(plaintextBytes) - plaintextBigInt = big.NewInt(int64(uintVal)) - } else { - return nil, errors.New("decryption resulted in insufficient bytes for FheType_Euint16") - } - case kms.FheType_Euint32: - // Similar to Euint16, but with 4 bytes to uint32. - if len(plaintextBytes) >= 4 { - uintVal := binary.BigEndian.Uint32(plaintextBytes) - plaintextBigInt = big.NewInt(int64(uintVal)) - } else { - return nil, errors.New("decryption resulted in insufficient bytes for FheType_Euint32") - } - case kms.FheType_Euint64: - // For Euint64, ensure there are 8 bytes to work with. - if len(plaintextBytes) >= 8 { - uintVal := binary.BigEndian.Uint64(plaintextBytes) - plaintextBigInt = new(big.Int).SetUint64(uintVal) - } else { - return nil, errors.New("decryption resulted in insufficient bytes for FheType_Euint64") - } - case kms.FheType_Euint160: - logger.Info("decrypt success", "plaintextBytes", plaintextBytes) - logger.Info("decrypt success", "plaintextBytes", fmt.Sprintf("%v", plaintextBytes)) - // Special handling for FheUint160, already covered. - plaintextBigInt, err = tfhe.U256BytesToBigInt(plaintextBytes) - default: - return nil, fmt.Errorf("unsupported FheType: %v", fheType) - } + plaintextBigInt := new(big.Int).SetBytes(plaintextBytes) return plaintextBigInt, nil diff --git a/fhevm/tfhe/tfhe_wrappers.go b/fhevm/tfhe/tfhe_wrappers.go index fc5823e..2fc2c8c 100644 --- a/fhevm/tfhe/tfhe_wrappers.go +++ b/fhevm/tfhe/tfhe_wrappers.go @@ -13,7 +13,6 @@ import "C" import ( _ "embed" - "encoding/binary" "errors" "fmt" "math/big" @@ -135,36 +134,3 @@ func u256ToBigInt(u256 C.U256) *big.Int { return new(big.Int).SetBytes(buf) } - -// U256BytesToBigInt takes a 32-byte big-endian slice and returns a big.Int. -func U256BytesToBigInt(plaintextBytes []byte) (*big.Int, error) { - if len(plaintextBytes) != 32 { - return nil, fmt.Errorf("byte slice is not the correct length for U256: got %d bytes, want 32", len(plaintextBytes)) - } - - // Split the byte slice into four u64 parts considering big-endian encoding - w0 := binary.BigEndian.Uint64(plaintextBytes[0:8]) - w1 := binary.BigEndian.Uint64(plaintextBytes[8:16]) - w2 := binary.BigEndian.Uint64(plaintextBytes[16:24]) - w3 := binary.BigEndian.Uint64(plaintextBytes[24:32]) - - // Print the u64 parts for verification - // fmt.Printf("U256\n") - // fmt.Printf("w0: %d\n", w0) - // fmt.Printf("w1: %d\n", w1) - // fmt.Printf("w2: %d\n", w2) - // fmt.Printf("w3: %d\n", w3) - - // Combine the u64 parts into low and high u128 parts to construct the big.Int - low := new(big.Int).SetUint64(w0) - low.Or(low, new(big.Int).Lsh(new(big.Int).SetUint64(w1), 64)) - - high := new(big.Int).SetUint64(w2) - high.Or(high, new(big.Int).Lsh(new(big.Int).SetUint64(w3), 64)) - - // Shift the high part by 128 bits to the left and add it to the low part - bigIntValue := new(big.Int).Lsh(high, 128) - bigIntValue.Add(bigIntValue, low) - - return bigIntValue, nil -} From 70acdcc5b31e332b710c59f7a4996a0529931915 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20=27birdy=27=20Danjou?= Date: Mon, 18 Mar 2024 13:32:04 +0100 Subject: [PATCH 3/3] feat: add correct gas price for FheUint160 --- fhevm/params.go | 1 + 1 file changed, 1 insertion(+) diff --git a/fhevm/params.go b/fhevm/params.go index e98d0ab..b9c4a36 100644 --- a/fhevm/params.go +++ b/fhevm/params.go @@ -137,6 +137,7 @@ func DefaultGasCosts() GasCosts { tfhe.FheUint16: 44000 + AdjustFHEGas, tfhe.FheUint32: 72000 + AdjustFHEGas, tfhe.FheUint64: 76000 + AdjustFHEGas, + tfhe.FheUint160: 80000 + AdjustFHEGas, }, FheLe: map[tfhe.FheUintType]uint64{ tfhe.FheUint4: 60000 + AdjustFHEGas,