diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..a6b413a --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 SmartContract + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..4fcf953 --- /dev/null +++ b/README.md @@ -0,0 +1,4 @@ +# tdh2 +A Go & JavaScript implementation of the TDH2 protocol from [Securing Threshold Cryptosystems against Chosen Ciphertext Attack](https://www.shoup.net/papers/thresh1.pdf) by Shoup & Gennaro. + +**This is untested alpha code. Do not rely on it for any non-experimental use cases!** diff --git a/go/go.work b/go/go.work new file mode 100644 index 0000000..ae05aef --- /dev/null +++ b/go/go.work @@ -0,0 +1,6 @@ +go 1.19 + +use ( + ./ocr2/decryptionplugin + ./tdh2 +) diff --git a/go/go.work.sum b/go/go.work.sum new file mode 100644 index 0000000..8e5ea79 --- /dev/null +++ b/go/go.work.sum @@ -0,0 +1,13 @@ +filippo.io/edwards25519 v1.0.0-rc.1/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns= +github.com/gtank/ristretto255 v0.1.3-0.20210930101514-6bb39798585c/go.mod h1:tDPFhGdt3hJWqtKwx57i9baiB1Cj0yAg22VOPUqm5vY= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/go/ocr2/decryptionplugin/config/config.go b/go/ocr2/decryptionplugin/config/config.go new file mode 100644 index 0000000..aa53ea8 --- /dev/null +++ b/go/ocr2/decryptionplugin/config/config.go @@ -0,0 +1,41 @@ +package config + +import ( + "fmt" + "math" + + "github.com/goplugin/plugin-libocr/commontypes" + "google.golang.org/protobuf/proto" +) + +// This config is stored in the Oracle contract (set via SetConfig()). +// Every SetConfig() call reloads the reporting plugin (DirectRequestReportingPluginFactory.NewReportingPlugin()) +type ReportingPluginConfigWrapper struct { + Config *ReportingPluginConfig +} + +func DecodeReportingPluginConfig(raw []byte) (*ReportingPluginConfigWrapper, error) { + configProto := &ReportingPluginConfig{} + if err := proto.Unmarshal(raw, configProto); err != nil { + return nil, err + } + return &ReportingPluginConfigWrapper{Config: configProto}, nil +} + +func EncodeReportingPluginConfig(rpConfig *ReportingPluginConfigWrapper) ([]byte, error) { + return proto.Marshal(rpConfig.Config) +} + +func EncodeOracleIdtoKeyShareIndex(oracleID commontypes.OracleID, keyShareIndex int) *OracleIDtoKeyShareIndex { + return &OracleIDtoKeyShareIndex{ + OracleId: uint32(oracleID), + KeyShareIndex: uint32(keyShareIndex), + } +} + +func DecodeOracleIdtoKeyShareIndex(oracleIDtoKeyShareIndex *OracleIDtoKeyShareIndex) (commontypes.OracleID, int, error) { + if oracleIDtoKeyShareIndex.OracleId > math.MaxUint8 { + return 0, 0, fmt.Errorf("oracleID is larger than MAX_UINT8") + } + return commontypes.OracleID(oracleIDtoKeyShareIndex.OracleId), int(oracleIDtoKeyShareIndex.KeyShareIndex), nil +} diff --git a/go/ocr2/decryptionplugin/config/config_types.pb.go b/go/ocr2/decryptionplugin/config/config_types.pb.go new file mode 100644 index 0000000..282a6a5 --- /dev/null +++ b/go/ocr2/decryptionplugin/config/config_types.pb.go @@ -0,0 +1,312 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.26.0 +// protoc v3.20.0 +// source: config/config_types.proto + +package config + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type OracleIDtoKeyShareIndex struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + OracleId uint32 `protobuf:"varint,1,opt,name=oracle_id,json=oracleId,proto3" json:"oracle_id,omitempty"` + KeyShareIndex uint32 `protobuf:"varint,2,opt,name=key_share_index,json=keyShareIndex,proto3" json:"key_share_index,omitempty"` +} + +func (x *OracleIDtoKeyShareIndex) Reset() { + *x = OracleIDtoKeyShareIndex{} + if protoimpl.UnsafeEnabled { + mi := &file_config_config_types_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *OracleIDtoKeyShareIndex) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OracleIDtoKeyShareIndex) ProtoMessage() {} + +func (x *OracleIDtoKeyShareIndex) ProtoReflect() protoreflect.Message { + mi := &file_config_config_types_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OracleIDtoKeyShareIndex.ProtoReflect.Descriptor instead. +func (*OracleIDtoKeyShareIndex) Descriptor() ([]byte, []int) { + return file_config_config_types_proto_rawDescGZIP(), []int{0} +} + +func (x *OracleIDtoKeyShareIndex) GetOracleId() uint32 { + if x != nil { + return x.OracleId + } + return 0 +} + +func (x *OracleIDtoKeyShareIndex) GetKeyShareIndex() uint32 { + if x != nil { + return x.KeyShareIndex + } + return 0 +} + +type ReportingPluginConfig struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + MaxQueryLengthBytes uint32 `protobuf:"varint,1,opt,name=max_query_length_bytes,json=maxQueryLengthBytes,proto3" json:"max_query_length_bytes,omitempty"` + MaxObservationLengthBytes uint32 `protobuf:"varint,2,opt,name=max_observation_length_bytes,json=maxObservationLengthBytes,proto3" json:"max_observation_length_bytes,omitempty"` + MaxReportLengthBytes uint32 `protobuf:"varint,3,opt,name=max_report_length_bytes,json=maxReportLengthBytes,proto3" json:"max_report_length_bytes,omitempty"` + RequestCountLimit uint32 `protobuf:"varint,4,opt,name=request_count_limit,json=requestCountLimit,proto3" json:"request_count_limit,omitempty"` + RequestTotalBytesLimit uint32 `protobuf:"varint,5,opt,name=request_total_bytes_limit,json=requestTotalBytesLimit,proto3" json:"request_total_bytes_limit,omitempty"` + RequireLocalRequestCheck bool `protobuf:"varint,6,opt,name=require_local_request_check,json=requireLocalRequestCheck,proto3" json:"require_local_request_check,omitempty"` + PublicKey []byte `protobuf:"bytes,7,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"` + PrivKeyShare []byte `protobuf:"bytes,8,opt,name=priv_key_share,json=privKeyShare,proto3" json:"priv_key_share,omitempty"` + OracleIdToKeyIndex []*OracleIDtoKeyShareIndex `protobuf:"bytes,9,rep,name=oracle_id_to_key_index,json=oracleIdToKeyIndex,proto3" json:"oracle_id_to_key_index,omitempty"` +} + +func (x *ReportingPluginConfig) Reset() { + *x = ReportingPluginConfig{} + if protoimpl.UnsafeEnabled { + mi := &file_config_config_types_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ReportingPluginConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReportingPluginConfig) ProtoMessage() {} + +func (x *ReportingPluginConfig) ProtoReflect() protoreflect.Message { + mi := &file_config_config_types_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReportingPluginConfig.ProtoReflect.Descriptor instead. +func (*ReportingPluginConfig) Descriptor() ([]byte, []int) { + return file_config_config_types_proto_rawDescGZIP(), []int{1} +} + +func (x *ReportingPluginConfig) GetMaxQueryLengthBytes() uint32 { + if x != nil { + return x.MaxQueryLengthBytes + } + return 0 +} + +func (x *ReportingPluginConfig) GetMaxObservationLengthBytes() uint32 { + if x != nil { + return x.MaxObservationLengthBytes + } + return 0 +} + +func (x *ReportingPluginConfig) GetMaxReportLengthBytes() uint32 { + if x != nil { + return x.MaxReportLengthBytes + } + return 0 +} + +func (x *ReportingPluginConfig) GetRequestCountLimit() uint32 { + if x != nil { + return x.RequestCountLimit + } + return 0 +} + +func (x *ReportingPluginConfig) GetRequestTotalBytesLimit() uint32 { + if x != nil { + return x.RequestTotalBytesLimit + } + return 0 +} + +func (x *ReportingPluginConfig) GetRequireLocalRequestCheck() bool { + if x != nil { + return x.RequireLocalRequestCheck + } + return false +} + +func (x *ReportingPluginConfig) GetPublicKey() []byte { + if x != nil { + return x.PublicKey + } + return nil +} + +func (x *ReportingPluginConfig) GetPrivKeyShare() []byte { + if x != nil { + return x.PrivKeyShare + } + return nil +} + +func (x *ReportingPluginConfig) GetOracleIdToKeyIndex() []*OracleIDtoKeyShareIndex { + if x != nil { + return x.OracleIdToKeyIndex + } + return nil +} + +var File_config_config_types_proto protoreflect.FileDescriptor + +var file_config_config_types_proto_rawDesc = []byte{ + 0x0a, 0x19, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x5f, + 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0c, 0x63, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x73, 0x22, 0x5e, 0x0a, 0x17, 0x4f, 0x72, 0x61, + 0x63, 0x6c, 0x65, 0x49, 0x44, 0x74, 0x6f, 0x4b, 0x65, 0x79, 0x53, 0x68, 0x61, 0x72, 0x65, 0x49, + 0x6e, 0x64, 0x65, 0x78, 0x12, 0x1b, 0x0a, 0x09, 0x6f, 0x72, 0x61, 0x63, 0x6c, 0x65, 0x5f, 0x69, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x6f, 0x72, 0x61, 0x63, 0x6c, 0x65, 0x49, + 0x64, 0x12, 0x26, 0x0a, 0x0f, 0x6b, 0x65, 0x79, 0x5f, 0x73, 0x68, 0x61, 0x72, 0x65, 0x5f, 0x69, + 0x6e, 0x64, 0x65, 0x78, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0d, 0x6b, 0x65, 0x79, 0x53, + 0x68, 0x61, 0x72, 0x65, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x22, 0x8e, 0x04, 0x0a, 0x15, 0x52, 0x65, + 0x70, 0x6f, 0x72, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x12, 0x33, 0x0a, 0x16, 0x6d, 0x61, 0x78, 0x5f, 0x71, 0x75, 0x65, 0x72, 0x79, + 0x5f, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x13, 0x6d, 0x61, 0x78, 0x51, 0x75, 0x65, 0x72, 0x79, 0x4c, 0x65, 0x6e, + 0x67, 0x74, 0x68, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x3f, 0x0a, 0x1c, 0x6d, 0x61, 0x78, 0x5f, + 0x6f, 0x62, 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6c, 0x65, 0x6e, 0x67, + 0x74, 0x68, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x19, + 0x6d, 0x61, 0x78, 0x4f, 0x62, 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4c, 0x65, + 0x6e, 0x67, 0x74, 0x68, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x35, 0x0a, 0x17, 0x6d, 0x61, 0x78, + 0x5f, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x5f, 0x62, + 0x79, 0x74, 0x65, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x14, 0x6d, 0x61, 0x78, 0x52, + 0x65, 0x70, 0x6f, 0x72, 0x74, 0x4c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x42, 0x79, 0x74, 0x65, 0x73, + 0x12, 0x2e, 0x0a, 0x13, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x63, 0x6f, 0x75, 0x6e, + 0x74, 0x5f, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x11, 0x72, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x4c, 0x69, 0x6d, 0x69, 0x74, + 0x12, 0x39, 0x0a, 0x19, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x74, 0x6f, 0x74, 0x61, + 0x6c, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x05, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x16, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x54, 0x6f, 0x74, 0x61, + 0x6c, 0x42, 0x79, 0x74, 0x65, 0x73, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x12, 0x3d, 0x0a, 0x1b, 0x72, + 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x5f, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x5f, 0x72, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x5f, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x18, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x75, + 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, + 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x0e, 0x70, 0x72, 0x69, + 0x76, 0x5f, 0x6b, 0x65, 0x79, 0x5f, 0x73, 0x68, 0x61, 0x72, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, + 0x0c, 0x52, 0x0c, 0x70, 0x72, 0x69, 0x76, 0x4b, 0x65, 0x79, 0x53, 0x68, 0x61, 0x72, 0x65, 0x12, + 0x59, 0x0a, 0x16, 0x6f, 0x72, 0x61, 0x63, 0x6c, 0x65, 0x5f, 0x69, 0x64, 0x5f, 0x74, 0x6f, 0x5f, + 0x6b, 0x65, 0x79, 0x5f, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x18, 0x09, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x25, 0x2e, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4f, + 0x72, 0x61, 0x63, 0x6c, 0x65, 0x49, 0x44, 0x74, 0x6f, 0x4b, 0x65, 0x79, 0x53, 0x68, 0x61, 0x72, + 0x65, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x52, 0x12, 0x6f, 0x72, 0x61, 0x63, 0x6c, 0x65, 0x49, 0x64, + 0x54, 0x6f, 0x4b, 0x65, 0x79, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x42, 0x0b, 0x5a, 0x09, 0x2e, 0x2f, + 0x3b, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_config_config_types_proto_rawDescOnce sync.Once + file_config_config_types_proto_rawDescData = file_config_config_types_proto_rawDesc +) + +func file_config_config_types_proto_rawDescGZIP() []byte { + file_config_config_types_proto_rawDescOnce.Do(func() { + file_config_config_types_proto_rawDescData = protoimpl.X.CompressGZIP(file_config_config_types_proto_rawDescData) + }) + return file_config_config_types_proto_rawDescData +} + +var file_config_config_types_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_config_config_types_proto_goTypes = []interface{}{ + (*OracleIDtoKeyShareIndex)(nil), // 0: config_types.OracleIDtoKeyShareIndex + (*ReportingPluginConfig)(nil), // 1: config_types.ReportingPluginConfig +} +var file_config_config_types_proto_depIdxs = []int32{ + 0, // 0: config_types.ReportingPluginConfig.oracle_id_to_key_index:type_name -> config_types.OracleIDtoKeyShareIndex + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_config_config_types_proto_init() } +func file_config_config_types_proto_init() { + if File_config_config_types_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_config_config_types_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*OracleIDtoKeyShareIndex); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_config_config_types_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ReportingPluginConfig); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_config_config_types_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_config_config_types_proto_goTypes, + DependencyIndexes: file_config_config_types_proto_depIdxs, + MessageInfos: file_config_config_types_proto_msgTypes, + }.Build() + File_config_config_types_proto = out.File + file_config_config_types_proto_rawDesc = nil + file_config_config_types_proto_goTypes = nil + file_config_config_types_proto_depIdxs = nil +} diff --git a/go/ocr2/decryptionplugin/config/config_types.proto b/go/ocr2/decryptionplugin/config/config_types.proto new file mode 100644 index 0000000..40153a4 --- /dev/null +++ b/go/ocr2/decryptionplugin/config/config_types.proto @@ -0,0 +1,22 @@ +syntax = "proto3"; + +option go_package = "./;config"; + +package config_types; + +message OracleIDtoKeyShareIndex { + uint32 oracle_id = 1; + uint32 key_share_index = 2; +} + +message ReportingPluginConfig { + uint32 max_query_length_bytes = 1; + uint32 max_observation_length_bytes = 2; + uint32 max_report_length_bytes = 3; + uint32 request_count_limit = 4; + uint32 request_total_bytes_limit = 5; + bool require_local_request_check = 6; + bytes public_key = 7; + bytes priv_key_share = 8; + repeated OracleIDtoKeyShareIndex oracle_id_to_key_index = 9; +} \ No newline at end of file diff --git a/go/ocr2/decryptionplugin/decryption.go b/go/ocr2/decryptionplugin/decryption.go new file mode 100644 index 0000000..8b44cf5 --- /dev/null +++ b/go/ocr2/decryptionplugin/decryption.go @@ -0,0 +1,383 @@ +package decryptionplugin + +import ( + "bytes" + "context" + "errors" + "fmt" + + "github.com/goplugin/plugin-libocr/commontypes" + "github.com/goplugin/plugin-libocr/offchainreporting2/types" + "github.com/goplugin/tdh2/go/ocr2/decryptionplugin/config" + "github.com/goplugin/tdh2/go/tdh2easy" + "google.golang.org/protobuf/proto" +) + +type DecryptionReportingPluginFactory struct { + DecryptionQueue DecryptionQueuingService + Logger commontypes.Logger +} + +type decryptionPlugin struct { + logger commontypes.Logger + decryptionQueue DecryptionQueuingService + publicKey *tdh2easy.PublicKey + privKeyShare *tdh2easy.PrivateShare + oracleToKeyShare map[commontypes.OracleID]int + genericConfig *types.ReportingPluginConfig + specificConfig *config.ReportingPluginConfigWrapper +} + +// NewReportingPlugin complies with ReportingPluginFactory. +func (f DecryptionReportingPluginFactory) NewReportingPlugin(rpConfig types.ReportingPluginConfig) (types.ReportingPlugin, types.ReportingPluginInfo, error) { + pluginConfig, err := config.DecodeReportingPluginConfig(rpConfig.OffchainConfig) + if err != nil { + f.Logger.Error("unable to decode reporting plugin config", commontypes.LogFields{ + "configDigest": rpConfig.ConfigDigest.String(), + }) + return nil, types.ReportingPluginInfo{}, fmt.Errorf("unalbe to decode reporting plugin config: %w", err) + } + + info := types.ReportingPluginInfo{ + Name: "ThresholdDecryption", + UniqueReports: false, // Aggregating any f+1 valid decryption shares result in the same plaintext. Must match setting in OCR2Base.sol. + // TODO calculate limits based on the maximum size of the plaintext and ciphertextID + Limits: types.ReportingPluginLimits{ + MaxQueryLength: int(pluginConfig.Config.GetMaxQueryLengthBytes()), + MaxObservationLength: int(pluginConfig.Config.GetMaxObservationLengthBytes()), + MaxReportLength: int(pluginConfig.Config.GetMaxReportLengthBytes()), + }, + } + + oracleToKeyShare := make(map[commontypes.OracleID]int) + for _, entry := range pluginConfig.Config.OracleIdToKeyIndex { + oID, ksID, err := config.DecodeOracleIdtoKeyShareIndex(entry) + if err != nil { + return nil, types.ReportingPluginInfo{}, fmt.Errorf("unalbe to decode reporting plugin oracle id to key Share index mapping: %w", err) + } + oracleToKeyShare[oID] = ksID + } + + plugin := decryptionPlugin{ + f.Logger, + f.DecryptionQueue, + &tdh2easy.PublicKey{}, + &tdh2easy.PrivateShare{}, + oracleToKeyShare, + &rpConfig, + pluginConfig, + } + + if err = plugin.publicKey.Unmarshal(pluginConfig.Config.PublicKey); err != nil { + return nil, info, fmt.Errorf("cannot unmarshal public key: %w", err) + } + + if err = plugin.privKeyShare.Unmarshal(pluginConfig.Config.PrivKeyShare); err != nil { + return nil, info, fmt.Errorf("cannot unmarshal private key share: %w", err) + } + + return &plugin, info, nil +} + +// Query creates a query with the oldest pending decryption requests. +func (dp *decryptionPlugin) Query(ctx context.Context, ts types.ReportTimestamp) (types.Query, error) { + dp.logger.Debug("DecryptionReporting Query: start", commontypes.LogFields{ + "epoch": ts.Epoch, + "round": ts.Round, + }) + + decryptionRequests := dp.decryptionQueue.GetRequests( + int(dp.specificConfig.Config.RequestCountLimit), + int(dp.specificConfig.Config.RequestTotalBytesLimit), + ) + + queryProto := Query{} + for _, request := range decryptionRequests { + ciphertext := &tdh2easy.Ciphertext{} + if err := ciphertext.UnmarshalVerify(request.Ciphertext, dp.publicKey); err != nil { + dp.logger.Error("DecryptionReporting Query: cannot unmarshal the ciphertext, skipping it", commontypes.LogFields{ + "error": err, + "ciphertextID": request.CiphertextId, + }) + continue + } + queryProto.DecryptionRequests = append(queryProto.GetDecryptionRequests(), &CiphertextWithID{ + CiphertextId: request.CiphertextId, + Ciphertext: request.Ciphertext, + }) + } + + dp.logger.Debug("DecryptionReporting Query: end", commontypes.LogFields{ + "epoch": ts.Epoch, + "round": ts.Round, + "queryLen": len(queryProto.DecryptionRequests), + }) + queryProtoBytes, err := proto.Marshal(&queryProto) + if err != nil { + return nil, fmt.Errorf("cannot marshal query: %w", err) + } + return queryProtoBytes, nil +} + +// Observation creates a decryption share for each request in the query. +// If dp.specificConfig.Config.LocalRequest is true, then the oracle +// only creates a decryption share for the decryption requests which it has locally. +func (dp *decryptionPlugin) Observation(ctx context.Context, ts types.ReportTimestamp, query types.Query) (types.Observation, error) { + dp.logger.Debug("DecryptionReporting Observation: start", commontypes.LogFields{ + "epoch": ts.Epoch, + "round": ts.Round, + }) + + queryProto := &Query{} + if err := proto.Unmarshal(query, queryProto); err != nil { + return nil, fmt.Errorf("cannot unmarshal query: %w", err) + } + + observationProto := Observation{} + for _, request := range queryProto.DecryptionRequests { + ciphertext := &tdh2easy.Ciphertext{} + ciphertextBytes := request.Ciphertext + if err := ciphertext.UnmarshalVerify(ciphertextBytes, dp.publicKey); err != nil { + dp.logger.Error("DecryptionReporting Observation: cannot unmarshal and verify the ciphertext, the leader is faulty", commontypes.LogFields{ + "error": err, + "ciphertextID": request.CiphertextId, + }) + return nil, fmt.Errorf("cannot unmarshal and verify the ciphertext: %w", err) + } + if dp.specificConfig.Config.RequireLocalRequestCheck { + queueCiphertextBytes, err := dp.decryptionQueue.GetCiphertext(request.CiphertextId) + if err != nil && errors.Is(err, ErrNotFound) { + dp.logger.Warn("DecryptionReporting Observation: cannot find ciphertext locally, skipping it", commontypes.LogFields{ + "error": err, + "ciphertextID": request.CiphertextId, + }) + continue + } else if err != nil { + dp.logger.Error("DecryptionReporting Observation: failed when looking for ciphertext locally, skipping it", commontypes.LogFields{ + "error": err, + "ciphertextID": request.CiphertextId, + }) + continue + } + if !bytes.Equal(queueCiphertextBytes, ciphertextBytes) { + dp.logger.Error("DecryptionReporting Observation: local ciphertext does not match the query ciphertext, skipping it", commontypes.LogFields{ + "ciphertextID": request.CiphertextId, + }) + continue + } + } + + decryptionShare, err := tdh2easy.Decrypt(ciphertext, dp.privKeyShare) + if err != nil { + dp.logger.Error("DecryptionReporting Observation: cannot decrypt the ciphertext", commontypes.LogFields{ + "error": err, + "ciphertextID": request.CiphertextId, + }) + continue + } + decryptionShareBytes, err := decryptionShare.Marshal() + if err != nil { + dp.logger.Error("DecryptionReporting Observation: cannot marshal the decryption share, skipping it", commontypes.LogFields{ + "error": err, + "ciphertextID": request.CiphertextId, + }) + continue + } + observationProto.DecryptionShares = append(observationProto.DecryptionShares, &DecryptionShareWithID{ + CiphertextId: request.CiphertextId, + DecryptionShare: decryptionShareBytes, + }) + } + + dp.logger.Debug("DecryptionReporting Observation: end", commontypes.LogFields{ + "epoch": ts.Epoch, + "round": ts.Round, + "decryptedRequests": len(observationProto.DecryptionShares), + "totalRequests": len(queryProto.DecryptionRequests), + }) + observationProtoBytes, err := proto.Marshal(&observationProto) + if err != nil { + return nil, fmt.Errorf("cannot marshal observation: %w", err) + } + return observationProtoBytes, nil +} + +// Report aggregates decryption shares from Observations to derive the plaintext. +func (dp *decryptionPlugin) Report(ctx context.Context, ts types.ReportTimestamp, query types.Query, obs []types.AttributedObservation) (bool, types.Report, error) { + dp.logger.Debug("DecryptionReporting Report: start", commontypes.LogFields{ + "epoch": ts.Epoch, + "round": ts.Round, + "nObservations": len(obs), + }) + + queryProto := &Query{} + if err := proto.Unmarshal(query, queryProto); err != nil { + return false, nil, fmt.Errorf("cannot unmarshal query: %w ", err) + } + ciphertexts := make(map[string]*tdh2easy.Ciphertext) + for _, request := range queryProto.DecryptionRequests { + ciphertext := &tdh2easy.Ciphertext{} + if err := ciphertext.UnmarshalVerify(request.Ciphertext, dp.publicKey); err != nil { + dp.logger.Error("DecryptionReporting Report: cannot unmarshal and verify the ciphertext, the leader is faulty", commontypes.LogFields{ + "error": err, + "ciphertextID": request.CiphertextId, + }) + return false, nil, fmt.Errorf("cannot unmarshal and verify the ciphertext: %w", err) + } + ciphertexts[string(request.CiphertextId)] = ciphertext + } + + fPlusOne := dp.genericConfig.F + 1 + validDecryptionShares := make(map[string][]*tdh2easy.DecryptionShare) + for _, ob := range obs { + observationProto := &Observation{} + if err := proto.Unmarshal(ob.Observation, observationProto); err != nil { + dp.logger.Error("DecryptionReporting Report: cannot unmarshal observation, skipping it", commontypes.LogFields{ + "error": err, + "observer": ob.Observer, + }) + continue + } + + for _, decryptionShareWithID := range observationProto.DecryptionShares { + ciphertextID := string(decryptionShareWithID.CiphertextId) + ciphertext, ok := ciphertexts[ciphertextID] + if !ok { + dp.logger.Error("DecryptionReporting Report: there is not ciphertext in the query with matching id", commontypes.LogFields{ + "ciphertextID": ciphertextID, + "observer": ob.Observer, + }) + continue + } + + validDecryptionShare, err := dp.getValidDecryptionShare(ob.Observer, + ciphertext, decryptionShareWithID.DecryptionShare) + if err != nil { + dp.logger.Error("DecryptionReporting Report: invalid decryption share", commontypes.LogFields{ + "error": err, + "ciphertextID": ciphertextID, + "observer": ob.Observer, + }) + continue + } + + validDecryptionShares[ciphertextID] = append(validDecryptionShares[ciphertextID], validDecryptionShare) + if len(validDecryptionShares[ciphertextID]) >= fPlusOne { + dp.logger.Trace("DecryptionReporting Report: we have already f+1 valid decryption shares", commontypes.LogFields{ + "ciphertextID": ciphertextID, + "observer": ob.Observer, + }) + break + } + } + } + + reportProto := Report{} + for id, decrShares := range validDecryptionShares { + ciphertext, ok := ciphertexts[id] + if !ok { + dp.logger.Error("DecryptionReporting Report: there is not ciphertext in the query with matching id, skipping aggregation of decryption shares", commontypes.LogFields{ + "ciphertextID": id, + }) + continue + } + + // OCR2.0 guaranties 2f+1 observations are from distinct oracles + // which guaranties f+1 valid observations and, hence, f+1 valid decryption shares. + // Therefore, here it is guaranteed that len(decrShares) > f. + plaintext, err := tdh2easy.Aggregate(ciphertext, decrShares, dp.genericConfig.N) + if err != nil { + dp.logger.Error("DecryptionReporting Report: cannot aggregate decryption shares", commontypes.LogFields{ + "error": err, + "ciphertextID": id, + }) + continue + } + + dp.logger.Debug("DecryptionReporting Report: plaintext aggregated successfully", commontypes.LogFields{ + "epoch": ts.Epoch, + "round": ts.Round, + "ciphertextID": id, + }) + reportProto.ProcessedDecryptedRequests = append(reportProto.ProcessedDecryptedRequests, &ProcessedDecryptionRequest{ + CiphertextId: []byte(id), + Plaintext: plaintext, + }) + } + + dp.logger.Debug("DecryptionReporting Report: end", commontypes.LogFields{ + "epoch": ts.Epoch, + "round": ts.Round, + "aggregatedDecryptionShares": len(reportProto.ProcessedDecryptedRequests), + "reporting": len(reportProto.ProcessedDecryptedRequests) > 0, + }) + + if len(reportProto.ProcessedDecryptedRequests) == 0 { + return false, nil, nil + } + + reportBytes, err := proto.Marshal(&reportProto) + if err != nil { + return false, nil, fmt.Errorf("cannot marshal report: %w", err) + } + return true, reportBytes, nil +} + +func (dp *decryptionPlugin) getValidDecryptionShare(observer commontypes.OracleID, + ciphertext *tdh2easy.Ciphertext, decryptionShareBytes []byte) (*tdh2easy.DecryptionShare, error) { + decryptionShare := &tdh2easy.DecryptionShare{} + if err := decryptionShare.Unmarshal(decryptionShareBytes); err != nil { + return nil, fmt.Errorf("cannot unmarshal decryption share: %w", err) + } + + expectedKeyShareIndex, ok := dp.oracleToKeyShare[observer] + if !ok { + return nil, fmt.Errorf("invalid observer ID") + } + + if expectedKeyShareIndex != decryptionShare.Index() { + return nil, fmt.Errorf("invalid decryption share index: expected %d and got %d", expectedKeyShareIndex, decryptionShare.Index()) + } + + if err := tdh2easy.VerifyShare(ciphertext, dp.publicKey, decryptionShare); err != nil { + return nil, fmt.Errorf("decryption share verification failed: %w", err) + } + return decryptionShare, nil +} + +// ShouldAcceptFinalizedReport updates the decryption queue. +// Returns always false as the report will not be transmitted on-chain. +func (dp *decryptionPlugin) ShouldAcceptFinalizedReport(ctx context.Context, ts types.ReportTimestamp, report types.Report) (bool, error) { + dp.logger.Debug("DecryptionReporting ShouldAcceptFinalizedReport: start", commontypes.LogFields{ + "epoch": ts.Epoch, + "round": ts.Round, + }) + + reportProto := &Report{} + if err := proto.Unmarshal(report, reportProto); err != nil { + return false, fmt.Errorf("cannot unmarshal report: %w", err) + } + + for _, item := range reportProto.ProcessedDecryptedRequests { + dp.decryptionQueue.SetResult(item.CiphertextId, item.Plaintext) + } + + dp.logger.Debug("DecryptionReporting ShouldAcceptFinalizedReport: end", commontypes.LogFields{ + "epoch": ts.Epoch, + "round": ts.Round, + "accepting": false, + }) + + return false, nil +} + +// ShouldTransmitAcceptedReport is a no-op +func (dp *decryptionPlugin) ShouldTransmitAcceptedReport(ctx context.Context, ts types.ReportTimestamp, report types.Report) (bool, error) { + return false, nil +} + +// Close complies with ReportingPlugin +func (dp *decryptionPlugin) Close() error { + dp.logger.Debug("DecryptionReporting Close", nil) + return nil +} diff --git a/go/ocr2/decryptionplugin/decryption_test.go b/go/ocr2/decryptionplugin/decryption_test.go new file mode 100644 index 0000000..be631a8 --- /dev/null +++ b/go/ocr2/decryptionplugin/decryption_test.go @@ -0,0 +1,3 @@ +package decryptionplugin + +// TODO diff --git a/go/ocr2/decryptionplugin/go.mod b/go/ocr2/decryptionplugin/go.mod new file mode 100644 index 0000000..d85811a --- /dev/null +++ b/go/ocr2/decryptionplugin/go.mod @@ -0,0 +1,20 @@ +module github.com/goplugin/tdh2/go/ocr2/decryptionplugin + +go 1.20 + +require ( + //github.com/goplugin/plugin-libocr v0.0.0-20230503222226-29f534b2de1a + github.com/goplugin/plugin-libocr v0.1.1-beta //plugin update changes + //github.com/goplugin/tdh2 v0.0.0-20230523083904-ccb0d2ebd7d4 + github.com/goplugin/tdh2 v0.0.1 //plugin update changes + google.golang.org/protobuf v1.30.0 +) + +require ( + github.com/mr-tron/base58 v1.2.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + go.dedis.ch/fixbuf v1.0.3 // indirect + go.dedis.ch/kyber/v3 v3.1.0 // indirect + golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a // indirect + golang.org/x/sys v0.0.0-20210816183151-1e6c022a8912 // indirect +) diff --git a/go/ocr2/decryptionplugin/go.sum b/go/ocr2/decryptionplugin/go.sum new file mode 100644 index 0000000..50f11d9 --- /dev/null +++ b/go/ocr2/decryptionplugin/go.sum @@ -0,0 +1,37 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o= +github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/goplugin/plugin-libocr v0.0.0-20230503222226-29f534b2de1a h1:NFTlIAjSwx9vBYZeUerd0DDsvg+zZ5TSESw4dPNdwJk= +github.com/goplugin/plugin-libocr v0.0.0-20230503222226-29f534b2de1a/go.mod h1:5JnCHuYgmIP9ZyXzgAfI5Iwu0WxBtBKp+ApeT5o1Cjw= +github.com/goplugin/tdh2 v0.0.0-20230523083904-ccb0d2ebd7d4 h1:dFozHOgWYCh3qVrv9274+d14GvgKmVyGbY18SS93ZSU= +github.com/goplugin/tdh2 v0.0.0-20230523083904-ccb0d2ebd7d4/go.mod h1:37DpReCODOYar/xLkieKH/DKeI3ubU3QKSDOjQX1QnY= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +go.dedis.ch/fixbuf v1.0.3 h1:hGcV9Cd/znUxlusJ64eAlExS+5cJDIyTyEG+otu5wQs= +go.dedis.ch/fixbuf v1.0.3/go.mod h1:yzJMt34Wa5xD37V5RTdmp38cz3QhMagdGoem9anUalw= +go.dedis.ch/kyber/v3 v3.0.4/go.mod h1:OzvaEnPvKlyrWyp3kGXlFdp7ap1VC6RkZDTaPikqhsQ= +go.dedis.ch/kyber/v3 v3.0.9/go.mod h1:rhNjUUg6ahf8HEg5HUvVBYoWY4boAafX8tYxX+PS+qg= +go.dedis.ch/kyber/v3 v3.1.0 h1:ghu+kiRgM5JyD9TJ0hTIxTLQlJBR/ehjWvWwYW3XsC0= +go.dedis.ch/kyber/v3 v3.1.0/go.mod h1:kXy7p3STAurkADD+/aZcsznZGKVHEqbtmdIzvPfrs1U= +go.dedis.ch/protobuf v1.0.5/go.mod h1:eIV4wicvi6JK0q/QnfIEGeSFNG0ZeB24kzut5+HaRLo= +go.dedis.ch/protobuf v1.0.7/go.mod h1:pv5ysfkDX/EawiPqcW3ikOxsL5t+BqnV6xHSmE79KI4= +go.dedis.ch/protobuf v1.0.11/go.mod h1:97QR256dnkimeNdfmURz0wAMNVbd1VmLXhG1CrTYrJ4= +golang.org/x/crypto v0.0.0-20190123085648-057139ce5d2b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a h1:kr2P4QFmQr29mSLA43kwrOcgcReGTfbE9N577tCTuBc= +golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/sys v0.0.0-20190124100055-b90733256f2e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20210816183151-1e6c022a8912 h1:uCLL3g5wH2xjxVREVuAbP9JM5PPKjRbXKRa6IBjkzmU= +golang.org/x/sys v0.0.0-20210816183151-1e6c022a8912/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= diff --git a/go/ocr2/decryptionplugin/go_generate.go b/go/ocr2/decryptionplugin/go_generate.go new file mode 100644 index 0000000..293f461 --- /dev/null +++ b/go/ocr2/decryptionplugin/go_generate.go @@ -0,0 +1,4 @@ +//go:generate protoc -I. --go_out=. types.proto +//go:generate protoc -I. --go_out=./config ./config/config_types.proto + +package decryptionplugin diff --git a/go/ocr2/decryptionplugin/queue.go b/go/ocr2/decryptionplugin/queue.go new file mode 100644 index 0000000..d356d30 --- /dev/null +++ b/go/ocr2/decryptionplugin/queue.go @@ -0,0 +1,26 @@ +package decryptionplugin + +import "errors" + +var ErrNotFound = errors.New("not found") + +type CiphertextId = []byte + +type DecryptionRequest struct { + CiphertextId CiphertextId + Ciphertext []byte +} + +type DecryptionQueuingService interface { + // GetRequests returns up to requestCountLimit oldest pending requests + // with total size up to totalBytesLimit bytes size. + GetRequests(requestCountLimit int, totalBytesLimit int) []DecryptionRequest + + // GetCiphertext returns the ciphertext matching ciphertextId + // if it exists in the queue. + // If the ciphertext does not exist it returns ErrNotFound. + GetCiphertext(ciphertextId CiphertextId) ([]byte, error) + + // SetResult sets the plaintext (decrypted ciphertext) which corresponds to ciphertextId. + SetResult(ciphertextId CiphertextId, plaintext []byte) +} diff --git a/go/ocr2/decryptionplugin/types.pb.go b/go/ocr2/decryptionplugin/types.pb.go new file mode 100644 index 0000000..87c2d60 --- /dev/null +++ b/go/ocr2/decryptionplugin/types.pb.go @@ -0,0 +1,502 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.26.0 +// protoc v3.20.0 +// source: types.proto + +package decryptionplugin + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type CiphertextWithID struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + CiphertextId []byte `protobuf:"bytes,1,opt,name=ciphertext_id,json=ciphertextId,proto3" json:"ciphertext_id,omitempty"` + Ciphertext []byte `protobuf:"bytes,2,opt,name=ciphertext,proto3" json:"ciphertext,omitempty"` +} + +func (x *CiphertextWithID) Reset() { + *x = CiphertextWithID{} + if protoimpl.UnsafeEnabled { + mi := &file_types_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CiphertextWithID) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CiphertextWithID) ProtoMessage() {} + +func (x *CiphertextWithID) ProtoReflect() protoreflect.Message { + mi := &file_types_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CiphertextWithID.ProtoReflect.Descriptor instead. +func (*CiphertextWithID) Descriptor() ([]byte, []int) { + return file_types_proto_rawDescGZIP(), []int{0} +} + +func (x *CiphertextWithID) GetCiphertextId() []byte { + if x != nil { + return x.CiphertextId + } + return nil +} + +func (x *CiphertextWithID) GetCiphertext() []byte { + if x != nil { + return x.Ciphertext + } + return nil +} + +type Query struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + DecryptionRequests []*CiphertextWithID `protobuf:"bytes,1,rep,name=decryption_requests,json=decryptionRequests,proto3" json:"decryption_requests,omitempty"` +} + +func (x *Query) Reset() { + *x = Query{} + if protoimpl.UnsafeEnabled { + mi := &file_types_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Query) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Query) ProtoMessage() {} + +func (x *Query) ProtoReflect() protoreflect.Message { + mi := &file_types_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Query.ProtoReflect.Descriptor instead. +func (*Query) Descriptor() ([]byte, []int) { + return file_types_proto_rawDescGZIP(), []int{1} +} + +func (x *Query) GetDecryptionRequests() []*CiphertextWithID { + if x != nil { + return x.DecryptionRequests + } + return nil +} + +type DecryptionShareWithID struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + CiphertextId []byte `protobuf:"bytes,1,opt,name=ciphertext_id,json=ciphertextId,proto3" json:"ciphertext_id,omitempty"` + DecryptionShare []byte `protobuf:"bytes,2,opt,name=decryption_share,json=decryptionShare,proto3" json:"decryption_share,omitempty"` +} + +func (x *DecryptionShareWithID) Reset() { + *x = DecryptionShareWithID{} + if protoimpl.UnsafeEnabled { + mi := &file_types_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DecryptionShareWithID) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DecryptionShareWithID) ProtoMessage() {} + +func (x *DecryptionShareWithID) ProtoReflect() protoreflect.Message { + mi := &file_types_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DecryptionShareWithID.ProtoReflect.Descriptor instead. +func (*DecryptionShareWithID) Descriptor() ([]byte, []int) { + return file_types_proto_rawDescGZIP(), []int{2} +} + +func (x *DecryptionShareWithID) GetCiphertextId() []byte { + if x != nil { + return x.CiphertextId + } + return nil +} + +func (x *DecryptionShareWithID) GetDecryptionShare() []byte { + if x != nil { + return x.DecryptionShare + } + return nil +} + +type Observation struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + DecryptionShares []*DecryptionShareWithID `protobuf:"bytes,1,rep,name=decryption_shares,json=decryptionShares,proto3" json:"decryption_shares,omitempty"` +} + +func (x *Observation) Reset() { + *x = Observation{} + if protoimpl.UnsafeEnabled { + mi := &file_types_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Observation) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Observation) ProtoMessage() {} + +func (x *Observation) ProtoReflect() protoreflect.Message { + mi := &file_types_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Observation.ProtoReflect.Descriptor instead. +func (*Observation) Descriptor() ([]byte, []int) { + return file_types_proto_rawDescGZIP(), []int{3} +} + +func (x *Observation) GetDecryptionShares() []*DecryptionShareWithID { + if x != nil { + return x.DecryptionShares + } + return nil +} + +type ProcessedDecryptionRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + CiphertextId []byte `protobuf:"bytes,1,opt,name=ciphertext_id,json=ciphertextId,proto3" json:"ciphertext_id,omitempty"` + Plaintext []byte `protobuf:"bytes,2,opt,name=plaintext,proto3" json:"plaintext,omitempty"` +} + +func (x *ProcessedDecryptionRequest) Reset() { + *x = ProcessedDecryptionRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_types_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ProcessedDecryptionRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ProcessedDecryptionRequest) ProtoMessage() {} + +func (x *ProcessedDecryptionRequest) ProtoReflect() protoreflect.Message { + mi := &file_types_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ProcessedDecryptionRequest.ProtoReflect.Descriptor instead. +func (*ProcessedDecryptionRequest) Descriptor() ([]byte, []int) { + return file_types_proto_rawDescGZIP(), []int{4} +} + +func (x *ProcessedDecryptionRequest) GetCiphertextId() []byte { + if x != nil { + return x.CiphertextId + } + return nil +} + +func (x *ProcessedDecryptionRequest) GetPlaintext() []byte { + if x != nil { + return x.Plaintext + } + return nil +} + +type Report struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ProcessedDecryptedRequests []*ProcessedDecryptionRequest `protobuf:"bytes,1,rep,name=processedDecryptedRequests,proto3" json:"processedDecryptedRequests,omitempty"` +} + +func (x *Report) Reset() { + *x = Report{} + if protoimpl.UnsafeEnabled { + mi := &file_types_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Report) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Report) ProtoMessage() {} + +func (x *Report) ProtoReflect() protoreflect.Message { + mi := &file_types_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Report.ProtoReflect.Descriptor instead. +func (*Report) Descriptor() ([]byte, []int) { + return file_types_proto_rawDescGZIP(), []int{5} +} + +func (x *Report) GetProcessedDecryptedRequests() []*ProcessedDecryptionRequest { + if x != nil { + return x.ProcessedDecryptedRequests + } + return nil +} + +var File_types_proto protoreflect.FileDescriptor + +var file_types_proto_rawDesc = []byte{ + 0x0a, 0x0b, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x74, + 0x79, 0x70, 0x65, 0x73, 0x22, 0x57, 0x0a, 0x10, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, + 0x78, 0x74, 0x57, 0x69, 0x74, 0x68, 0x49, 0x44, 0x12, 0x23, 0x0a, 0x0d, 0x63, 0x69, 0x70, 0x68, + 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x0c, 0x63, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x49, 0x64, 0x12, 0x1e, 0x0a, + 0x0a, 0x63, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0c, 0x52, 0x0a, 0x63, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x22, 0x51, 0x0a, + 0x05, 0x51, 0x75, 0x65, 0x72, 0x79, 0x12, 0x48, 0x0a, 0x13, 0x64, 0x65, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x73, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x43, 0x69, 0x70, 0x68, + 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x57, 0x69, 0x74, 0x68, 0x49, 0x44, 0x52, 0x12, 0x64, 0x65, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x73, + 0x22, 0x67, 0x0a, 0x15, 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x68, + 0x61, 0x72, 0x65, 0x57, 0x69, 0x74, 0x68, 0x49, 0x44, 0x12, 0x23, 0x0a, 0x0d, 0x63, 0x69, 0x70, + 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, + 0x52, 0x0c, 0x63, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x49, 0x64, 0x12, 0x29, + 0x0a, 0x10, 0x64, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x73, 0x68, 0x61, + 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x64, 0x65, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x68, 0x61, 0x72, 0x65, 0x22, 0x58, 0x0a, 0x0b, 0x4f, 0x62, 0x73, + 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x49, 0x0a, 0x11, 0x64, 0x65, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x73, 0x68, 0x61, 0x72, 0x65, 0x73, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x44, 0x65, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x68, 0x61, 0x72, 0x65, 0x57, 0x69, 0x74, 0x68, 0x49, + 0x44, 0x52, 0x10, 0x64, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x68, 0x61, + 0x72, 0x65, 0x73, 0x22, 0x5f, 0x0a, 0x1a, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x65, 0x64, + 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x63, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x5f, + 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0c, 0x63, 0x69, 0x70, 0x68, 0x65, 0x72, + 0x74, 0x65, 0x78, 0x74, 0x49, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x70, 0x6c, 0x61, 0x69, 0x6e, 0x74, + 0x65, 0x78, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x70, 0x6c, 0x61, 0x69, 0x6e, + 0x74, 0x65, 0x78, 0x74, 0x22, 0x6b, 0x0a, 0x06, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x61, + 0x0a, 0x1a, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x65, 0x64, 0x44, 0x65, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x50, 0x72, 0x6f, 0x63, 0x65, + 0x73, 0x73, 0x65, 0x64, 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x52, 0x1a, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x65, 0x64, + 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x73, 0x42, 0x15, 0x5a, 0x13, 0x2e, 0x2f, 0x3b, 0x64, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_types_proto_rawDescOnce sync.Once + file_types_proto_rawDescData = file_types_proto_rawDesc +) + +func file_types_proto_rawDescGZIP() []byte { + file_types_proto_rawDescOnce.Do(func() { + file_types_proto_rawDescData = protoimpl.X.CompressGZIP(file_types_proto_rawDescData) + }) + return file_types_proto_rawDescData +} + +var file_types_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_types_proto_goTypes = []interface{}{ + (*CiphertextWithID)(nil), // 0: types.CiphertextWithID + (*Query)(nil), // 1: types.Query + (*DecryptionShareWithID)(nil), // 2: types.DecryptionShareWithID + (*Observation)(nil), // 3: types.Observation + (*ProcessedDecryptionRequest)(nil), // 4: types.ProcessedDecryptionRequest + (*Report)(nil), // 5: types.Report +} +var file_types_proto_depIdxs = []int32{ + 0, // 0: types.Query.decryption_requests:type_name -> types.CiphertextWithID + 2, // 1: types.Observation.decryption_shares:type_name -> types.DecryptionShareWithID + 4, // 2: types.Report.processedDecryptedRequests:type_name -> types.ProcessedDecryptionRequest + 3, // [3:3] is the sub-list for method output_type + 3, // [3:3] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name +} + +func init() { file_types_proto_init() } +func file_types_proto_init() { + if File_types_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_types_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CiphertextWithID); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_types_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Query); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_types_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DecryptionShareWithID); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_types_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Observation); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_types_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ProcessedDecryptionRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_types_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Report); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_types_proto_rawDesc, + NumEnums: 0, + NumMessages: 6, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_types_proto_goTypes, + DependencyIndexes: file_types_proto_depIdxs, + MessageInfos: file_types_proto_msgTypes, + }.Build() + File_types_proto = out.File + file_types_proto_rawDesc = nil + file_types_proto_goTypes = nil + file_types_proto_depIdxs = nil +} diff --git a/go/ocr2/decryptionplugin/types.proto b/go/ocr2/decryptionplugin/types.proto new file mode 100644 index 0000000..8968692 --- /dev/null +++ b/go/ocr2/decryptionplugin/types.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +option go_package = "./;decryptionplugin"; + +package types; + +message CiphertextWithID { + bytes ciphertext_id = 1; + bytes ciphertext = 2; +} +message Query { + repeated CiphertextWithID decryption_requests = 1; +} + +message DecryptionShareWithID { + bytes ciphertext_id = 1; + bytes decryption_share = 2; +} + +message Observation { + repeated DecryptionShareWithID decryption_shares = 1; +} + +message ProcessedDecryptionRequest { + bytes ciphertext_id = 1; + bytes plaintext = 2; +} + +message Report { + repeated ProcessedDecryptionRequest processedDecryptedRequests = 1; +} \ No newline at end of file diff --git a/go/tdh2/go.mod b/go/tdh2/go.mod new file mode 100644 index 0000000..c04c134 --- /dev/null +++ b/go/tdh2/go.mod @@ -0,0 +1,14 @@ +module github.com/goplugin/tdh2/go/tdh2 + +go 1.19 + +require ( + github.com/google/go-cmp v0.5.9 + go.dedis.ch/kyber/v3 v3.1.0 +) + +require ( + go.dedis.ch/fixbuf v1.0.3 // indirect + golang.org/x/crypto v0.0.0-20190123085648-057139ce5d2b // indirect + golang.org/x/sys v0.0.0-20190124100055-b90733256f2e // indirect +) diff --git a/go/tdh2/go.sum b/go/tdh2/go.sum new file mode 100644 index 0000000..b754456 --- /dev/null +++ b/go/tdh2/go.sum @@ -0,0 +1,22 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +go.dedis.ch/fixbuf v1.0.3 h1:hGcV9Cd/znUxlusJ64eAlExS+5cJDIyTyEG+otu5wQs= +go.dedis.ch/fixbuf v1.0.3/go.mod h1:yzJMt34Wa5xD37V5RTdmp38cz3QhMagdGoem9anUalw= +go.dedis.ch/kyber/v3 v3.0.4/go.mod h1:OzvaEnPvKlyrWyp3kGXlFdp7ap1VC6RkZDTaPikqhsQ= +go.dedis.ch/kyber/v3 v3.0.9/go.mod h1:rhNjUUg6ahf8HEg5HUvVBYoWY4boAafX8tYxX+PS+qg= +go.dedis.ch/kyber/v3 v3.1.0 h1:ghu+kiRgM5JyD9TJ0hTIxTLQlJBR/ehjWvWwYW3XsC0= +go.dedis.ch/kyber/v3 v3.1.0/go.mod h1:kXy7p3STAurkADD+/aZcsznZGKVHEqbtmdIzvPfrs1U= +go.dedis.ch/protobuf v1.0.5/go.mod h1:eIV4wicvi6JK0q/QnfIEGeSFNG0ZeB24kzut5+HaRLo= +go.dedis.ch/protobuf v1.0.7/go.mod h1:pv5ysfkDX/EawiPqcW3ikOxsL5t+BqnV6xHSmE79KI4= +go.dedis.ch/protobuf v1.0.11/go.mod h1:97QR256dnkimeNdfmURz0wAMNVbd1VmLXhG1CrTYrJ4= +golang.org/x/crypto v0.0.0-20190123085648-057139ce5d2b h1:Elez2XeF2p9uyVj0yEUDqQ56NFcDtcBNkYP7yv8YbUE= +golang.org/x/crypto v0.0.0-20190123085648-057139ce5d2b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/sys v0.0.0-20190124100055-b90733256f2e h1:3GIlrlVLfkoipSReOMNAgApI0ajnalyLa/EZHHca/XI= +golang.org/x/sys v0.0.0-20190124100055-b90733256f2e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/go/tdh2/tdh2/tdh2.go b/go/tdh2/tdh2/tdh2.go new file mode 100644 index 0000000..0d104df --- /dev/null +++ b/go/tdh2/tdh2/tdh2.go @@ -0,0 +1,716 @@ +// Package tdh2 implements the TDH2 protocol (Shoup and Gennaro, 2001: https://www.shoup.net/papers/thresh1.pdf). +package tdh2 + +import ( + "bytes" + "crypto/cipher" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + + "go.dedis.ch/kyber/v3" + "go.dedis.ch/kyber/v3/group/nist" + "go.dedis.ch/kyber/v3/share" +) + +var ( + // defaultHash is the default hash function used. Note, its output size + // determines the input size in TDH2. + defaultHash = sha256.New + // InputSize determines the size of messages and labels. + InputSize = defaultHash().Size() +) + +func parseGroup(group string) (kyber.Group, error) { + switch group { + case nist.NewBlakeSHA256P256().String(): + return nist.NewBlakeSHA256P256(), nil + } + return nil, fmt.Errorf("unsupported group: %q", group) +} + +// PrivateShare is a node's private share. It extends kyber's share.PriShare. +type PrivateShare struct { + group kyber.Group + index int + v kyber.Scalar +} + +func (s PrivateShare) String() string { + return fmt.Sprintf("grp:%s idx:%d", s.group.String(), s.index) +} + +func (s PrivateShare) Index() int { + return s.index +} + +// mulPoint returns a new point with value v*p, where v is a private scalar. +// If p==nil, the returned point has value v*BasePoint. +func (s *PrivateShare) mulPoint(p kyber.Point) kyber.Point { + return s.group.Point().Mul(s.v, p) +} + +// mulScalar returns a new scalar with value v*a where v is a private scalar. +func (s *PrivateShare) mulScalar(a kyber.Scalar) kyber.Scalar { + return s.group.Scalar().Mul(s.v, a) +} + +// privateShareRaw is used for PrivateShare (un)marshaling. +type privateShareRaw struct { + Group string + Index int + V []byte +} + +func (s PrivateShare) Marshal() ([]byte, error) { + v, err := s.v.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("cannot marshal V: %w", err) + } + return json.Marshal(&privateShareRaw{ + Group: s.group.String(), + Index: s.index, + V: v, + }) +} + +func (s *PrivateShare) Unmarshal(data []byte) error { + var raw privateShareRaw + err := json.Unmarshal(data, &raw) + if err != nil { + return fmt.Errorf("cannot unmarshal data: %w", err) + } + + s.group, err = parseGroup(raw.Group) + if err != nil { + return fmt.Errorf("cannot parse group: %w", err) + } + + s.index = raw.Index + if s.v, err = unmarshalScalar(s.group, raw.V); err != nil { + return fmt.Errorf("cannot unmarshal: %w", err) + } + return nil +} + +// PubliKey defines a public and verification key. +type PublicKey struct { + group kyber.Group + g_bar kyber.Point + h kyber.Point + hArray []kyber.Point +} + +func (a *PublicKey) Equal(b *PublicKey) bool { + if a.group.String() != b.group.String() || !a.g_bar.Equal(b.g_bar) || !a.h.Equal(b.h) { + return false + } + if len(a.hArray) != len(b.hArray) { + return false + } + for i := range a.hArray { + if !a.hArray[i].Equal(b.hArray[i]) { + return false + } + } + return true +} + +// MasterSecret keeps the master secret of a TDH2 instance. +type MasterSecret struct { + group kyber.Group + s kyber.Scalar +} + +func (m MasterSecret) String() string { + return fmt.Sprintf("group:%s value:hidden", m.group) +} + +// masterSecretRaw is used for MasterSecret (un)marshaling. +type masterSecretRaw struct { + Group string + S []byte +} + +func (m *MasterSecret) Marshal() ([]byte, error) { + s, err := m.s.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("cannot marshal s: %w", err) + } + return json.Marshal(&masterSecretRaw{ + Group: m.group.String(), + S: s, + }) +} + +func (m *MasterSecret) Unmarshal(data []byte) error { + var raw masterSecretRaw + if err := json.Unmarshal(data, &raw); err != nil { + return fmt.Errorf("cannot unmarshal data: %w", err) + } + var err error + m.group, err = parseGroup(raw.Group) + if err != nil { + return fmt.Errorf("cannot parse group: %w", err) + } + m.s = m.group.Scalar() + if err := m.s.UnmarshalBinary(raw.S); err != nil { + return fmt.Errorf("cannot unmarshal s: %w", err) + } + return nil +} + +// publicKeyRaw is used for PublicKey (un)marshaling. +type publicKeyRaw struct { + Group string + G_bar []byte + H []byte + HArray [][]byte +} + +func (p PublicKey) Marshal() ([]byte, error) { + gbar, err := p.g_bar.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("marshaling G_bar: %w", err) + } + + h, err := p.h.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("marshaling H: %w", err) + } + + harray := [][]byte{} + for _, h := range p.hArray { + d, err := h.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("cannot marshal H: %w", err) + } + harray = append(harray, d) + } + + return json.Marshal(&publicKeyRaw{ + Group: p.group.String(), + G_bar: gbar, + H: h, + HArray: harray, + }) +} + +func (p *PublicKey) Unmarshal(data []byte) error { + var raw publicKeyRaw + err := json.Unmarshal(data, &raw) + if err != nil { + return fmt.Errorf("cannot unmarshal data: %w", err) + } + + p.group, err = parseGroup(raw.Group) + if err != nil { + return fmt.Errorf("cannot parse group: %w", err) + } + + if p.g_bar, err = unmarshalPoint(p.group, raw.G_bar); err != nil { + return fmt.Errorf("unmarshaling G_bar: %w", err) + } + + if p.h, err = unmarshalPoint(p.group, raw.H); err != nil { + return fmt.Errorf("unmarshaling H: %w", err) + } + + p.hArray = []kyber.Point{} + for _, h := range raw.HArray { + new, err := unmarshalPoint(p.group, h) + if err != nil { + return fmt.Errorf("cannot unmarshal point: %w", err) + } + p.hArray = append(p.hArray, new) + } + + return nil +} + +// GenerateKeys generates a master secret, public key, and secret key shares according to TDH2 paper. +// It takes cryptographic group to be used, master secret to be used (if nil, a new secret is generated), +// the total number of nodes n, the number of shares sufficient for decryption k, and a randomness source. +// It returns the master secret (either passed or generated), public key, and secret key shares. +func GenerateKeys(group kyber.Group, ms *MasterSecret, k, n int, rand cipher.Stream) (*MasterSecret, *PublicKey, []*PrivateShare, error) { + if k > n { + return nil, nil, nil, fmt.Errorf("threshold is higher than total number of nodes") + } + if k <= 0 { + return nil, nil, nil, fmt.Errorf("threshold has to be positive") + } + if ms != nil && group.String() != ms.group.String() { + return nil, nil, nil, fmt.Errorf("inconsistent groups") + } + + var s kyber.Scalar + if ms != nil { + s = ms.s + } + poly := share.NewPriPoly(group, k, s, rand) + x := poly.Secret() + if ms != nil && !x.Equal(ms.s) { + return nil, nil, nil, fmt.Errorf("generated wrong secret") + } + + HArray := make([]kyber.Point, n) + shares := poly.Shares(n) + privShares := []*PrivateShare{} + // IDs are assigned consecutively from 0. + for i, s := range shares { + if i != s.I { + return nil, nil, nil, fmt.Errorf("share index=%d, expect=%d", s.I, i) + } + HArray[i] = group.Point().Mul(s.V, nil) + privShares = append(privShares, &PrivateShare{group, s.I, s.V}) + } + + return &MasterSecret{ + group: group, + s: x}, + &PublicKey{ + group: group, + g_bar: group.Point().Pick(rand), + h: group.Point().Mul(x, nil), + hArray: HArray, + }, privShares, nil +} + +// Redeal re-deals private shares such that new quorums can decrypt old ciphertexts. It takes the +// previous public key and master secret as well as the number of nodes sufficient for decrypt k, +// the total number of nodes n, and a randomness source. It returns a new public key and private shares. +// The master secret passed corresponds to the public key returned. The old public key can still be used +// for encryption but it cannot be used for share verification (the new key has to be used instead). +func Redeal(pk *PublicKey, ms *MasterSecret, k, n int, rand cipher.Stream) (*PublicKey, []*PrivateShare, error) { + if ms == nil { + return nil, nil, fmt.Errorf("nil secret") + } + _, new, shares, err := GenerateKeys(pk.group, ms, k, n, rand) + if err != nil { + return nil, nil, fmt.Errorf("cannot generate keys: %w", err) + } + return &PublicKey{ + group: pk.group, + g_bar: pk.g_bar, + h: pk.h, + hArray: new.hArray, + }, shares, nil +} + +// Encrypt a message with a label (see p15 of the paper). +func Encrypt(pk *PublicKey, msg []byte, label []byte, rand cipher.Stream) (*Ciphertext, error) { + r := pk.group.Scalar().Pick(rand) + s := pk.group.Scalar().Pick(rand) + + h, err := hash1(pk.group.String(), pk.group.Point().Mul(r, pk.h)) + if err != nil { + return nil, fmt.Errorf("cannot hash: %w", err) + } + c, err := xor(h, msg) + if err != nil { + return nil, fmt.Errorf("cannot xor: %w", err) + } + + u := pk.group.Point().Mul(r, nil) + w := pk.group.Point().Mul(s, nil) + u_bar := pk.group.Point().Mul(r, pk.g_bar) + w_bar := pk.group.Point().Mul(s, pk.g_bar) + e, err := hash2(c, label, u, w, u_bar, w_bar, pk.group) + if err != nil { + return nil, fmt.Errorf("cannot generate e: %w", err) + } + f := pk.group.Scalar().Add(s, pk.group.Scalar().Mul(r, e.Clone())) + + return &Ciphertext{ + group: pk.group, + c: c, + label: label, + u: u, + u_bar: u_bar, + e: e, + f: f, + }, nil +} + +// VerifyShare verifies the correctness of the decryption share obtained from node i. +// The caller has to ensure that the ciphertext is validated. +func VerifyShare(pk *PublicKey, ctxt *Ciphertext, share *DecryptionShare) error { + if pk.group.String() != ctxt.group.String() { + return fmt.Errorf("incorrect ciphertext group: %q", ctxt.group) + } + + if pk.group.String() != share.group.String() { + return fmt.Errorf("incorrect share group: %q", share.group) + } + + if err := checkEi(pk, ctxt, share); err != nil { + return fmt.Errorf("failed format validity check: %w", err) + } + + return nil +} + +// checkEi checks the validity of param E_i to ensure that it is a DH triple (formula 3 on p13). +func checkEi(pk *PublicKey, ctxt *Ciphertext, share *DecryptionShare) error { + g := pk.group + ui_hat := g.Point().Sub(g.Point().Mul(share.f_i, ctxt.u), g.Point().Mul(share.e_i, share.u_i)) + if share.index >= len(pk.hArray) { + return fmt.Errorf("invalid share index") + } + hi_hat := g.Point().Sub(g.Point().Mul(share.f_i, nil), g.Point().Mul(share.e_i, pk.hArray[share.index])) + ei, err := hash4(share.u_i, ui_hat, hi_hat, pk.group) + if err != nil { + return fmt.Errorf("cannot generate e_i: %w", err) + } + if !share.e_i.Equal(ei) { + return fmt.Errorf("error during the verification of E_i") + } + return nil +} + +// Ciphertext defines a ciphertext as output from the Encryption algorithm. +type Ciphertext struct { + group kyber.Group + c []byte + label []byte + u kyber.Point + u_bar kyber.Point + e kyber.Scalar + f kyber.Scalar +} + +// Verify checks if the ciphertext matches the public key +// (i.e., it checks the validity of param e -- see formula 4 on p15). +func (c *Ciphertext) Verify(pk *PublicKey) error { + if c.group.String() != pk.group.String() { + return fmt.Errorf("group mismatch") + } + w := pk.group.Point().Sub(pk.group.Point().Mul(c.f, nil), pk.group.Point().Mul(c.e, c.u)) + w_bar := pk.group.Point().Sub(pk.group.Point().Mul(c.f, pk.g_bar), pk.group.Point().Mul(c.e, c.u_bar)) + e, err := hash2(c.c, c.label, c.u, w, c.u_bar, w_bar, pk.group) + if err != nil { + return fmt.Errorf("cannot compute e: %w", err) + } + if !c.e.Equal(e) { + return fmt.Errorf("wrong e") + } + return nil +} + +func (a *Ciphertext) Equal(b *Ciphertext) bool { + if a.group.String() != b.group.String() || + !bytes.Equal(a.c, b.c) || + !bytes.Equal(a.label, b.label) || + !a.u.Equal(b.u) || + !a.u_bar.Equal(b.u_bar) || + !a.e.Equal(b.e) || + !a.f.Equal(b.f) { + return false + } + return true + +} + +// Decrypt decrypts a ciphertext using a secret key share x_i according to TDH2 paper. +// The caller has to ensure that the ciphertext is validated. +func (ctxt *Ciphertext) Decrypt(group kyber.Group, x_i *PrivateShare, rand cipher.Stream) (*DecryptionShare, error) { + if group.String() != ctxt.group.String() { + return nil, fmt.Errorf("incorrect ciphertext group: %q", ctxt.group) + } + if group.String() != x_i.group.String() { + return nil, fmt.Errorf("incorrect share group: %q", x_i.group) + } + + s_i := group.Scalar().Pick(rand) + u_i := x_i.mulPoint(ctxt.u) + u_hat := group.Point().Mul(s_i, ctxt.u) + h_hat := group.Point().Mul(s_i, nil) + e_i, err := hash4(u_i, u_hat, h_hat, group) + if err != nil { + return nil, fmt.Errorf("cannot generate e_i: %w", err) + } + f_i := group.Scalar().Add(s_i, x_i.mulScalar(e_i.Clone())) + return &DecryptionShare{ + group: group, + index: x_i.index, + u_i: u_i, + e_i: e_i, + f_i: f_i, + }, nil +} + +// CombineShares combines a set of decryption shares and returns the decrypted message. +// The caller has to ensure that the ciphertext is validated. +func (c *Ciphertext) CombineShares(group kyber.Group, shares []*DecryptionShare, k, n int) ([]byte, error) { + if group.String() != c.group.String() { + return nil, fmt.Errorf("incorrect ciphertext group: %q", c.group) + } + + if len(shares) < k { + return nil, fmt.Errorf("too few shares") + } + + pubShares := []*share.PubShare{} + for _, s := range shares { + if group.String() != s.group.String() { + return nil, fmt.Errorf("incorrect share group: %q", s.group) + } + pubShares = append(pubShares, &share.PubShare{ + I: s.index, + V: s.u_i, + }) + } + + arg, err := share.RecoverCommit(group, pubShares, k, n) + if err != nil { + return nil, fmt.Errorf("cannot recover secret: %w", err) + } + + h, err := hash1(c.group.String(), arg) + if err != nil { + return nil, fmt.Errorf("failed to marshal %q: %w", arg, err) + } + + return xor(h, c.c) +} + +// ciphertextRaw is used for Ciphertext (un)marshaling. +type ciphertextRaw struct { + Group string + C []byte + Label []byte + U []byte + U_bar []byte + E []byte + F []byte +} + +func (c Ciphertext) Marshal() ([]byte, error) { + u, err := c.u.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("cannot marshal U: %w", err) + } + ubar, err := c.u_bar.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("cannot marshal U_bar: %w", err) + } + f, err := c.f.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("cannot marshal F: %w", err) + } + e, err := c.e.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("cannot marshal E: %w", err) + } + return json.Marshal(&ciphertextRaw{ + Group: c.group.String(), + C: c.c, + Label: c.label, + U: u, + U_bar: ubar, + E: e, + F: f, + }) +} + +func (c *Ciphertext) Unmarshal(data []byte) error { + var raw ciphertextRaw + err := json.Unmarshal(data, &raw) + if err != nil { + return fmt.Errorf("cannot unmarshal data: %w", err) + } + c.c = raw.C + c.label = raw.Label + c.group, err = parseGroup(raw.Group) + if err != nil { + return fmt.Errorf("cannot parse group: %w", err) + } + if c.e, err = unmarshalScalar(c.group, raw.E); err != nil { + return fmt.Errorf("cannot unmarshal E: %w", err) + } + if c.u, err = unmarshalPoint(c.group, raw.U); err != nil { + return fmt.Errorf("cannot unmarshal U: %w", err) + } + if c.u_bar, err = unmarshalPoint(c.group, raw.U_bar); err != nil { + return fmt.Errorf("cannot unmarshal U_bar: %w", err) + } + if c.f, err = unmarshalScalar(c.group, raw.F); err != nil { + return fmt.Errorf("cannot unmarshal F: %w", err) + } + return nil +} + +// DecryptionShare defines a decryption share +type DecryptionShare struct { + group kyber.Group + index int + u_i kyber.Point + e_i kyber.Scalar + f_i kyber.Scalar +} + +// TODO(pszal): test + fix tests which currently ignore share equality +func (a *DecryptionShare) Equal(b *DecryptionShare) bool { + if a.group.String() != b.group.String() || + a.index != b.index || + !a.u_i.Equal(b.u_i) || + !a.e_i.Equal(b.e_i) || + !a.f_i.Equal(b.f_i) { + return false + } + return true +} + +func (d DecryptionShare) Index() int { + return d.index +} + +// decryptionShareRaw is used for DecryptionShare (un)marshaling. +type decryptionShareRaw struct { + Group string + Index int + U_i []byte + E_i []byte + F_i []byte +} + +func (d DecryptionShare) Marshal() ([]byte, error) { + u, err := d.u_i.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("cannot marshal U_i: %w", err) + } + f, err := d.f_i.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("cannot marshal F_i: %w", err) + } + e, err := d.e_i.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("cannot marshal E_i: %w", err) + } + return json.Marshal(&decryptionShareRaw{ + Group: d.group.String(), + Index: d.index, + U_i: u, + E_i: e, + F_i: f, + }) +} + +func (d *DecryptionShare) Unmarshal(data []byte) error { + var raw decryptionShareRaw + err := json.Unmarshal(data, &raw) + if err != nil { + return fmt.Errorf("cannot unmarshal data: %w", err) + } + d.index = raw.Index + d.group, err = parseGroup(raw.Group) + if err != nil { + return fmt.Errorf("cannot parse group: %w", err) + } + if d.e_i, err = unmarshalScalar(d.group, raw.E_i); err != nil { + return fmt.Errorf("cannot unmarshal E_i: %w", err) + } + if d.u_i, err = unmarshalPoint(d.group, raw.U_i); err != nil { + return fmt.Errorf("cannot unmarshal U_i: %w", err) + } + if d.f_i, err = unmarshalScalar(d.group, raw.F_i); err != nil { + return fmt.Errorf("cannot unmarshal F_i: %w", err) + } + return nil +} + +// hash is a generic hash function +func hash(msg []byte) []byte { + h := defaultHash() + h.Write(msg) + return h.Sum(nil) +} + +// hash1 is an implementation of the H_1 hash function (see p15 of the paper). +func hash1(group string, g kyber.Point) ([]byte, error) { + point, err := concatenate(group, g) + if err != nil { + return nil, fmt.Errorf("cannot concatenate points: %w", err) + } + return hash(append([]byte("tdh2hash1"), point...)), nil +} + +// hash2 is an implementation of the H_2 hash function (see p15 of the paper). +func hash2(msg, label []byte, g1, g2, g3, g4 kyber.Point, group kyber.Group) (kyber.Scalar, error) { + if len(msg) != len(label) || len(msg) != InputSize { + return nil, fmt.Errorf("message and label must be %dB long", InputSize) + } + + points, err := concatenate(group.String(), g1, g2, g3, g4) + if err != nil { + return nil, fmt.Errorf("cannot concatenate points: %w", err) + } + input := []byte("tdh2hash2") + for _, arg := range [][]byte{msg, label, points} { + input = append(input, arg...) + } + + return group.Scalar().SetBytes(hash(input)), nil +} + +// hash4 is an implementation of the H_4 hash function (see p15 of the paper). +func hash4(g1, g2, g3 kyber.Point, group kyber.Group) (kyber.Scalar, error) { + points, err := concatenate(group.String(), g1, g2, g3) + if err != nil { + return nil, fmt.Errorf("cannot concatenate points: %w", err) + } + h := hash(append([]byte("tdh2hash4"), points...)) + + return group.Scalar().SetBytes(h), nil +} + +// concatenate marshals and concatenates points (elements of a group). It is +// used in hash functions. +func concatenate(group string, points ...kyber.Point) ([]byte, error) { + final := group + for _, point := range points { + p, err := point.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("cannot marshal point=%v err=%v", point, err) + } + final += "," + hex.EncodeToString(p) + } + return []byte(final), nil +} + +// xor computes and returns XOR between two slices. +func xor(a, b []byte) ([]byte, error) { + if len(a) != len(b) { + return nil, fmt.Errorf("length of byte slices is not equivalent: %d != %d", len(a), len(b)) + } + buf := make([]byte, len(a)) + for i := range a { + buf[i] = a[i] ^ b[i] + } + return buf, nil +} + +// unmarshalPoint unmarshals point safely, i.e., avoiding panics present in the kyber lib. +func unmarshalPoint(g kyber.Group, data []byte) (kyber.Point, error) { + if len(data) != g.PointLen() { + return nil, fmt.Errorf("incorrect length") + } + p := g.Point() + if err := p.UnmarshalBinary(data); err != nil { + return nil, err + } + return p, nil +} + +// unmarshalScalar unmarshals scalar safely, i.e., avoiding panics present in the kyber lib. +func unmarshalScalar(g kyber.Group, data []byte) (kyber.Scalar, error) { + if len(data) != g.ScalarLen() { + return nil, fmt.Errorf("incorrect length") + } + s := g.Scalar() + if err := s.UnmarshalBinary(data); err != nil { + return nil, err + } + return s, nil +} diff --git a/go/tdh2/tdh2/tdh2_test.go b/go/tdh2/tdh2/tdh2_test.go new file mode 100644 index 0000000..5d4e571 --- /dev/null +++ b/go/tdh2/tdh2/tdh2_test.go @@ -0,0 +1,1907 @@ +package tdh2 + +import ( + "bytes" + "crypto/cipher" + "crypto/rand" + "encoding/json" + "fmt" + "reflect" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "go.dedis.ch/kyber/v3" + "go.dedis.ch/kyber/v3/group/nist" + "go.dedis.ch/kyber/v3/xof/keccak" +) + +var supportedGroups = []string{ + nist.NewBlakeSHA256P256().String(), +} + +type common interface { + Fatalf(format string, args ...interface{}) +} + +func params(t common, group string) (kyber.Group, cipher.Stream, []byte, []byte) { + if _, ok := t.(*testing.T); ok { + t.(*testing.T).Helper() + } + seed := make([]byte, 64) + if n, err := rand.Read(seed); n != 64 || err != nil { + t.Fatalf("cannot generate seed; n=%d, err=%v", n, err) + } + msg := make([]byte, InputSize) + if _, err := rand.Read(msg); err != nil { + t.Fatalf("rand.Read: %v", err) + } + label := make([]byte, InputSize) + if _, err := rand.Read(label); err != nil { + t.Fatalf("rand.Read: %v", err) + } + g, err := parseGroup(group) + if err != nil { + t.Fatalf("parseGroup: %v", err) + } + return g, keccak.New(seed), msg, label +} + +func TestConcatenate(t *testing.T) { + for _, typ := range supportedGroups { + group, rand, _, _ := params(t, typ) + g1 := group.Point().Pick(rand) + g2 := group.Point().Pick(rand) + g3 := group.Point().Pick(rand) + + out, err := concatenate(group.String(), g1) + if err != nil { + t.Errorf("concatenate(g1): %v", err) + } + size := len(out) + if size == 0 { + t.Errorf("concatenate(g1): empty output") + } + + out, err = concatenate(group.String(), g1, g2) + if err != nil { + t.Errorf("concatenate(g1, g2): %v", err) + } + if len(out) <= size { + t.Errorf("concatenate(g1, g2): output shorter/equal the previous") + } + size = len(out) + + out, err = concatenate(group.String(), g1, g2, g3) + if err != nil { + t.Errorf("concatenate(g1, g2, g3): %v", err) + } + if len(out) <= size { + t.Errorf("concatenate(g1, g2, g3): output shorter/equal the previous") + } + } +} + +func TestXor(t *testing.T) { + for _, tc := range []struct { + name string + a []byte + b []byte + expect []byte + err error + }{ + { + name: "empty", + }, + { + name: "OK", + a: []byte{0, 1, 2, 3}, + b: []byte{4, 5, 6, 7}, + expect: []byte{0 ^ 4, 1 ^ 5, 2 ^ 6, 3 ^ 7}, + }, + { + name: "mismatch", + a: []byte{0, 1, 2, 3}, + b: []byte{4, 5, 6}, + err: cmpopts.AnyError, + }, + } { + t.Run(tc.name, func(t *testing.T) { + out, err := xor(tc.a, tc.b) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("err=%v, want=%v", err, tc.err) + } + if err == nil && !bytes.Equal(out, tc.expect) { + t.Errorf("got=%v, expected=%v", out, tc.expect) + } + }) + } +} + +func TestHash1and4(t *testing.T) { + for _, typ := range supportedGroups { + group, rand, _, _ := params(t, typ) + g1 := group.Point().Pick(rand) + g2 := group.Point().Pick(rand) + g3 := group.Point().Pick(rand) + + out, err := hash1(group.String(), g1) + if err != nil { + t.Errorf("hash1: %v", err) + } + if len(out) != InputSize { + t.Errorf("hash1 output size: %d, expect: %d", len(out), InputSize) + } + + if _, err := hash4(g1, g2, g3, group); err != nil { + t.Errorf("hash4: %v", err) + } + } +} + +func TestHash2(t *testing.T) { + for _, typ := range supportedGroups { + group, rnd, msg, label := params(t, typ) + g1 := group.Point().Pick(rnd) + g2 := group.Point().Pick(rnd) + g3 := group.Point().Pick(rnd) + g4 := group.Point().Pick(rnd) + + for _, tc := range []struct { + name string + msg []byte + label []byte + err error + }{ + { + name: "OK", + msg: msg, + label: label, + }, + { + name: "both short", + msg: msg[:InputSize-1], + label: label[:InputSize-1], + err: cmpopts.AnyError, + }, + { + name: "one shorter", + msg: msg[:InputSize-2], + label: label, + err: cmpopts.AnyError, + }, + } { + t.Run(fmt.Sprintf("test=%q group=%v", tc.name, typ), func(t *testing.T) { + _, err := hash2(tc.msg, tc.label, g1, g2, g3, g4, group) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("err=%v, want=%v", err, tc.err) + } + }) + } + } +} + +func TestGenerateKeys(t *testing.T) { + for _, typ := range supportedGroups { + group, rand, _, _ := params(t, typ) + for _, tc := range []struct { + name string + ms *MasterSecret + k int + n int + err error + }{ + { + name: "0 out of 0", + err: cmpopts.AnyError, + }, + { + name: "0 out of 1", + n: 1, + err: cmpopts.AnyError, + }, + { + name: "1 out of 1", + k: 1, + n: 1, + }, + { + name: "secret ok", + ms: &MasterSecret{ + group: group, + s: group.Scalar().Pick(rand)}, + k: 1, + n: 1, + }, + { + name: "secret wrong group", + ms: &MasterSecret{ + group: nist.NewBlakeSHA256QR512(), + s: group.Scalar().Pick(rand)}, + k: 1, + n: 1, + err: cmpopts.AnyError, + }, + { + name: "-1 out of 1", + k: -1, + n: 1, + err: cmpopts.AnyError, + }, + { + name: "10 out of 9", + k: 10, + n: 9, + err: cmpopts.AnyError, + }, + { + name: "1 out of 10", + k: 1, + n: 10, + }, + { + name: "5 out of 10", + k: 7, + n: 10, + }, + { + name: "10 out of 10", + k: 10, + n: 10, + }, + } { + t.Run(fmt.Sprintf("test=%q group=%v", tc.name, typ), func(t *testing.T) { + ms, pk, shares, err := GenerateKeys(group, tc.ms, tc.k, tc.n, rand) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("err=%v, want=%v", err, tc.err) + } else if err != nil { + return + } + if len(shares) != tc.n { + t.Errorf("got %d shares, expected %d", len(shares), tc.n) + } + if len(pk.hArray) != tc.n { + t.Errorf("got %d vk.HArray, expected %d", len(pk.hArray), tc.n) + } + if tc.ms != nil && !reflect.DeepEqual(ms, tc.ms) { + t.Errorf("got secret=%v, want=%v", ms, tc.ms) + } + }) + } + } +} + +func TestEncrypt(t *testing.T) { + for _, typ := range supportedGroups { + group, rand, msg, label := params(t, typ) + _, pk, _, err := GenerateKeys(group, nil, 3, 5, rand) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + for _, tc := range []struct { + name string + msg []byte + label []byte + err error + }{ + { + name: "OK", + msg: msg, + label: label, + }, + { + name: "wrong msg size", + msg: []byte("msg"), + label: label, + err: cmpopts.AnyError, + }, + } { + t.Run(fmt.Sprintf("test=%q group=%v", tc.name, typ), func(t *testing.T) { + ctxt, err := Encrypt(pk, tc.msg, tc.label, rand) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("err=%v, want=%v", err, tc.err) + } else if err != nil { + return + } + if diff := cmp.Diff(label, ctxt.label); diff != "" { + t.Errorf("label/ctx.Label diff: %v", diff) + } + }) + } + } +} + +func TestDecrypt(t *testing.T) { + wrong := nist.NewBlakeSHA256QR512() + for _, typ := range supportedGroups { + group, rand, msg, label := params(t, typ) + _, pk, shares, err := GenerateKeys(group, nil, 3, 5, rand) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + ctxt, err := Encrypt(pk, msg, label, rand) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + for _, tc := range []struct { + name string + ctxt *Ciphertext + share *PrivateShare + err error + }{ + { + name: "OK", + ctxt: ctxt, + share: shares[2], + }, + { + name: "wrong share group", + ctxt: ctxt, + share: &PrivateShare{ + group: wrong, + index: shares[2].index, + v: shares[2].v, + }, + err: cmpopts.AnyError, + }, + { + name: "wrong group", + ctxt: &Ciphertext{ + group: wrong, + c: ctxt.c, + label: ctxt.label, + u: ctxt.u, + u_bar: ctxt.u_bar, + e: ctxt.e, + f: ctxt.f, + }, + share: shares[2], + err: cmpopts.AnyError, + }, + } { + t.Run(fmt.Sprintf("test=%q group=%v", tc.name, typ), func(t *testing.T) { + if _, err := tc.ctxt.Decrypt(group, tc.share, rand); !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("err=%v, want=%v", err, tc.err) + } + }) + } + } +} + +func TestCtxtVerify(t *testing.T) { + for _, typ := range supportedGroups { + group, rand, msg, label := params(t, typ) + _, pk, _, err := GenerateKeys(group, nil, 3, 5, rand) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + ctxt, err := Encrypt(pk, msg, label, rand) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + for _, tc := range []struct { + name string + ctxt *Ciphertext + err error + }{ + { + name: "OK", + ctxt: ctxt, + }, + { + name: "wrong group", + ctxt: &Ciphertext{ + group: nist.NewBlakeSHA256QR512(), + c: ctxt.c, + label: ctxt.label, + u: ctxt.u, + u_bar: ctxt.u_bar, + e: ctxt.e, + f: ctxt.f, + }, + err: cmpopts.AnyError, + }, + { + name: "broken C", + ctxt: &Ciphertext{ + group: group, + c: []byte("broken"), + label: ctxt.label, + u: ctxt.u, + u_bar: ctxt.u_bar, + e: ctxt.e, + f: ctxt.f, + }, + err: cmpopts.AnyError, + }, + { + name: "broken Label", + ctxt: &Ciphertext{ + group: group, + c: ctxt.c, + label: []byte("label"), + u: ctxt.u, + u_bar: ctxt.u_bar, + e: ctxt.e, + f: ctxt.f, + }, + err: cmpopts.AnyError, + }, + { + name: "broken U", + ctxt: &Ciphertext{ + group: group, + c: ctxt.c, + label: ctxt.label, + u: group.Point().Pick(rand), + u_bar: ctxt.u_bar, + e: ctxt.e, + f: ctxt.f, + }, + err: cmpopts.AnyError, + }, + { + name: "broken U_bar", + ctxt: &Ciphertext{ + group: group, + c: ctxt.c, + label: ctxt.label, + u: ctxt.u, + u_bar: group.Point().Pick(rand), + e: ctxt.e, + f: ctxt.f, + }, + err: cmpopts.AnyError, + }, + { + name: "broken E", + ctxt: &Ciphertext{ + group: group, + c: ctxt.c, + label: ctxt.label, + u: ctxt.u, + u_bar: ctxt.u_bar, + e: group.Scalar().Pick(rand), + f: ctxt.f, + }, + err: cmpopts.AnyError, + }, + { + name: "broken F", + ctxt: &Ciphertext{ + group: group, + c: ctxt.c, + label: ctxt.label, + u: ctxt.u, + u_bar: ctxt.u_bar, + e: ctxt.e, + f: group.Scalar().Pick(rand), + }, + err: cmpopts.AnyError, + }, + } { + t.Run(fmt.Sprintf("test=%q group=%v", tc.name, typ), func(t *testing.T) { + if err := tc.ctxt.Verify(pk); !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("err=%v, want=%v", err, tc.err) + } + }) + } + } +} + +func TestCheckEi(t *testing.T) { + for _, typ := range supportedGroups { + group, rand, msg, label := params(t, typ) + _, pk, shares, err := GenerateKeys(group, nil, 3, 5, rand) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + ctxt, err := Encrypt(pk, msg, label, rand) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + ds, err := ctxt.Decrypt(group, shares[2], rand) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + for _, tc := range []struct { + name string + ctxt *Ciphertext + share *DecryptionShare + err error + }{ + { + name: "OK", + ctxt: ctxt, + share: ds, + }, + { + name: "broken U", + ctxt: &Ciphertext{ + c: ctxt.c, + label: ctxt.label, + u: group.Point().Pick(rand), + u_bar: ctxt.u_bar, + e: ctxt.e, + f: ctxt.f, + }, + share: ds, + err: cmpopts.AnyError, + }, + { + name: "out of band index", + ctxt: ctxt, + share: &DecryptionShare{ + index: 10, + u_i: ds.u_i, + e_i: ds.e_i, + f_i: ds.f_i, + }, + err: cmpopts.AnyError, + }, + { + name: "broken U", + ctxt: ctxt, + share: &DecryptionShare{ + index: ds.index, + u_i: group.Point().Pick(rand), + e_i: ds.e_i, + f_i: ds.f_i, + }, + err: cmpopts.AnyError, + }, + { + name: "broken E", + ctxt: ctxt, + share: &DecryptionShare{ + index: ds.index, + u_i: ds.u_i, + e_i: group.Scalar().Pick(rand), + f_i: ds.f_i, + }, + err: cmpopts.AnyError, + }, + { + name: "broken F", + ctxt: ctxt, + share: &DecryptionShare{ + index: ds.index, + u_i: ds.u_i, + e_i: ds.e_i, + f_i: group.Scalar().Pick(rand), + }, + err: cmpopts.AnyError, + }, + } { + t.Run(fmt.Sprintf("test=%q group=%v", tc.name, typ), func(t *testing.T) { + if err := checkEi(pk, tc.ctxt, tc.share); !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("err=%v, want=%v", err, tc.err) + } + }) + } + } +} + +func TestVerifyShare(t *testing.T) { + for _, typ := range supportedGroups { + group, rand, msg, label := params(t, typ) + _, pk, shares, err := GenerateKeys(group, nil, 3, 5, rand) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + _, pkWrong, _, err := GenerateKeys(group, nil, 3, 5, rand) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + wrong := nist.NewBlakeSHA256QR512() + ctxt, err := Encrypt(pk, msg, label, rand) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + ds, err := ctxt.Decrypt(group, shares[0], rand) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + for _, tc := range []struct { + name string + pk *PublicKey + ctxt *Ciphertext + share *DecryptionShare + err error + }{ + { + name: "OK", + pk: pk, + ctxt: ctxt, + share: ds, + }, + { + name: "wrong pk", + pk: pkWrong, + ctxt: ctxt, + share: ds, + err: cmpopts.AnyError, + }, + { + name: "broken ctxt", + pk: pk, + ctxt: &Ciphertext{ + group: group, + c: ctxt.c, + label: ctxt.label, + u: group.Point().Pick(rand), + u_bar: ctxt.u_bar, + e: ctxt.e, + f: ctxt.f, + }, + share: ds, + err: cmpopts.AnyError, + }, + { + name: "wrong ctxt group", + pk: pk, + ctxt: &Ciphertext{ + group: wrong, + c: ctxt.c, + label: ctxt.label, + u: ctxt.u, + u_bar: ctxt.u_bar, + e: ctxt.e, + f: ctxt.f, + }, + share: ds, + err: cmpopts.AnyError, + }, + { + name: "broken decryption share", + pk: pk, + ctxt: ctxt, + share: &DecryptionShare{ + group: group, + index: ds.index, + u_i: ds.u_i, + e_i: ds.e_i, + f_i: group.Scalar().Pick(rand), + }, + err: cmpopts.AnyError, + }, + { + name: "wrong share group", + pk: pk, + ctxt: ctxt, + share: &DecryptionShare{ + group: wrong, + index: ds.index, + u_i: ds.u_i, + e_i: ds.e_i, + f_i: ds.f_i, + }, + err: cmpopts.AnyError, + }, + } { + t.Run(fmt.Sprintf("test=%q group=%v", tc.name, typ), func(t *testing.T) { + if err := VerifyShare(tc.pk, tc.ctxt, tc.share); !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("err=%v, want=%v", err, tc.err) + } + }) + } + } +} + +func TestCombineShares(t *testing.T) { + for _, typ := range supportedGroups { + group, rand, msg, label := params(t, typ) + _, pk, shares, err := GenerateKeys(group, nil, 3, 5, rand) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + wrong := nist.NewBlakeSHA256QR512() + _, pkWrong, _, err := GenerateKeys(group, nil, 3, 5, rand) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + pkWrong.group = wrong + ctxt, err := Encrypt(pk, msg, label, rand) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + decShares := make([]*DecryptionShare, 5) + for i := range shares { + ds, err := ctxt.Decrypt(group, shares[i], rand) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + decShares[i] = ds + } + for _, tc := range []struct { + name string + ctxt *Ciphertext + shares []*DecryptionShare + k int + n int + err error + }{ + { + name: "OK (all shares)", + ctxt: ctxt, + shares: decShares, + k: 3, + n: 5, + }, + { + name: "OK (min shares)", + ctxt: ctxt, + shares: decShares[:3], + k: 3, + n: 5, + }, + { + name: "OK (reordered shares)", + ctxt: ctxt, + shares: []*DecryptionShare{decShares[4], decShares[3], decShares[0]}, + k: 3, + n: 5, + }, + { + name: "Replayed shares", + ctxt: ctxt, + shares: []*DecryptionShare{decShares[4], decShares[3], decShares[4]}, + k: 3, + n: 5, + err: cmpopts.AnyError, + }, + { + name: "not enough", + ctxt: ctxt, + shares: decShares[:2], + k: 3, + n: 5, + err: cmpopts.AnyError, + }, + { + name: "wrong ctxt group", + ctxt: &Ciphertext{ + group: wrong, + c: ctxt.c, + label: ctxt.label, + u: ctxt.u, + u_bar: ctxt.u_bar, + e: ctxt.e, + f: ctxt.f, + }, + shares: decShares[:3], + k: 3, + n: 5, + err: cmpopts.AnyError, + }, + { + name: "wrong share group", + ctxt: ctxt, + shares: []*DecryptionShare{{ + group: wrong, + index: decShares[4].index, + u_i: decShares[4].u_i, + e_i: decShares[4].e_i, + f_i: decShares[4].f_i, + }}, + k: 1, + n: 5, + err: cmpopts.AnyError, + }, + } { + t.Run(fmt.Sprintf("test=%q group=%v", tc.name, typ), func(t *testing.T) { + out, err := tc.ctxt.CombineShares(group, tc.shares, tc.k, tc.n) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("err=%v, want=%v", err, tc.err) + } + if err != nil { + return + } + if diff := cmp.Diff(msg, out); diff != "" { + t.Errorf("original/decrypted message diff: %v", diff) + } + }) + } + } +} + +func TestParseGroup(t *testing.T) { + for _, tc := range []struct { + group string + err error + }{ + { + group: nist.NewBlakeSHA256P256().String(), + }, + { + group: "wrong", + err: cmpopts.AnyError, + }, + } { + t.Run(fmt.Sprintf("group=%v", tc.group), func(t *testing.T) { + if _, err := parseGroup(tc.group); !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("got err=%v, want=%v", err, tc.err) + } + }) + } +} + +func TestPublicKeyUnmarshal(t *testing.T) { + for _, typ := range supportedGroups { + g, r, _, _ := params(t, typ) + point, err := g.Point().Pick(r).MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary: %v", err) + } + for _, tc := range []struct { + name string + raw []byte + err error + }{ + { + name: "ok", + raw: toJSON(t, &publicKeyRaw{ + Group: typ, + G_bar: point, + H: point, + HArray: [][]byte{point}, + }), + }, + { + name: "broken", + raw: []byte("broken"), + err: cmpopts.AnyError, + }, + { + name: "wrong group", + raw: toJSON(t, &publicKeyRaw{ + Group: "wrong", + G_bar: point, + H: point, + }), + err: cmpopts.AnyError, + }, + { + name: "wrong G", + raw: toJSON(t, &publicKeyRaw{ + Group: typ, + G_bar: []byte("broken"), + H: point, + }), + err: cmpopts.AnyError, + }, + { + name: "wrong H", + raw: toJSON(t, &publicKeyRaw{ + Group: typ, + G_bar: point, + H: []byte("broken"), + }), + err: cmpopts.AnyError, + }, + { + name: "wrong HArray", + raw: toJSON(t, &publicKeyRaw{ + Group: typ, + G_bar: point, + H: point, + HArray: [][]byte{[]byte("wrong")}, + }), + err: cmpopts.AnyError, + }, + } { + t.Run(fmt.Sprintf("test=%q group=%v", tc.name, typ), func(t *testing.T) { + var pk PublicKey + if err := pk.Unmarshal(tc.raw); !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("got err=%v, want=%v", err, tc.err) + } + }) + } + } +} + +func TestPublicKeyMarshal(t *testing.T) { + for _, typ := range supportedGroups { + g, r, _, _ := params(t, typ) + for i, want := range []*PublicKey{ + { + group: g, + g_bar: g.Point().Pick(r), + h: g.Point().Pick(r), + }, + { + group: g, + g_bar: g.Point().Pick(r), + h: g.Point().Pick(r), + hArray: []kyber.Point{g.Point().Pick(r)}, + }, + { + group: g, + g_bar: g.Point().Pick(r), + h: g.Point().Pick(r), + hArray: []kyber.Point{g.Point().Pick(r), g.Point().Pick(r), g.Point().Pick(r)}, + }, + } { + t.Run(fmt.Sprintf("i=%d group=%v", i, typ), func(t *testing.T) { + b, err := want.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var got PublicKey + if err := got.Unmarshal(b); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !got.Equal(want) { + t.Error("public keys not equal") + } + }) + } + } +} + +func TestCiphertextMarshal(t *testing.T) { + for _, typ := range supportedGroups { + g, r, _, _ := params(t, typ) + want := &Ciphertext{ + group: g, + c: []byte("some c"), + label: []byte("some label"), + u: g.Point().Pick(r), + u_bar: g.Point().Pick(r), + e: g.Scalar().Pick(r), + f: g.Scalar().Pick(r), + } + b, err := want.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var got Ciphertext + if err := got.Unmarshal(b); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !got.Equal(want) { + t.Errorf("different ciphertexts") + } + } +} + +func TestCiphertextUnmarshal(t *testing.T) { + for _, typ := range supportedGroups { + g, r, _, _ := params(t, typ) + point, err := g.Point().Pick(r).MarshalBinary() + if err != nil { + t.Fatalf("point.MarshalBinary: %v", err) + } + scalar, err := g.Scalar().Pick(r).MarshalBinary() + if err != nil { + t.Fatalf("scalar.MarshalBinary: %v", err) + } + tmp := g.Scalar().Pick(r) + e, err := tmp.MarshalBinary() + if err != nil { + t.Fatalf("e.MarshalBinary: %v", err) + } + + for _, tc := range []struct { + name string + raw []byte + err error + }{ + { + name: "ok", + raw: toJSON(t, &ciphertextRaw{ + Group: g.String(), + C: []byte("some c"), + Label: []byte("some label"), + U: point, + U_bar: point, + E: e, + F: scalar, + }), + }, + { + name: "broken", + raw: []byte("broken"), + err: cmpopts.AnyError, + }, + { + name: "wrong group", + raw: toJSON(t, &ciphertextRaw{ + Group: "wrong", + C: []byte("some c"), + Label: []byte("some label"), + U: point, + U_bar: point, + E: e, + F: scalar, + }), + err: cmpopts.AnyError, + }, + { + name: "wrong E", + raw: toJSON(t, &ciphertextRaw{ + Group: g.String(), + C: []byte("some c"), + Label: []byte("some label"), + U: point, + U_bar: point, + E: []byte("broken"), + F: scalar, + }), + err: cmpopts.AnyError, + }, + { + name: "wrong U", + raw: toJSON(t, &ciphertextRaw{ + Group: g.String(), + C: []byte("some c"), + Label: []byte("some label"), + U: []byte("123"), + U_bar: point, + E: e, + F: scalar, + }), + err: cmpopts.AnyError, + }, + { + name: "wrong Ubar", + raw: toJSON(t, &ciphertextRaw{ + Group: g.String(), + C: []byte("some c"), + Label: []byte("some label"), + U: point, + U_bar: []byte("123"), + E: e, + F: scalar, + }), + err: cmpopts.AnyError, + }, + { + name: "wrong F", + raw: toJSON(t, &ciphertextRaw{ + Group: g.String(), + C: []byte("some c"), + Label: []byte("some label"), + U: point, + U_bar: point, + E: e, + F: []byte("123"), + }), + err: cmpopts.AnyError, + }, + } { + t.Run(fmt.Sprintf("test=%q group=%v", tc.name, typ), func(t *testing.T) { + var c Ciphertext + if err := c.Unmarshal(tc.raw); !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("got err=%v, want=%v", err, tc.err) + } + }) + } + } +} + +func TestMasterSecretMarshal(t *testing.T) { + for _, typ := range supportedGroups { + g, r, _, _ := params(t, typ) + want := &MasterSecret{ + group: g, + s: g.Scalar().Pick(r), + } + b, err := want.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var got MasterSecret + if err := got.Unmarshal(b); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if got.group.String() != want.group.String() { + t.Errorf("got group=%v, want=%v", got.group, want.group) + } + if !got.s.Equal(want.s) { + t.Errorf("got s=%v, want=%v", got.s, want.s) + } + } +} + +func TestMasterSecretUnmarshal(t *testing.T) { + for _, typ := range supportedGroups { + g, r, _, _ := params(t, typ) + s, err := g.Scalar().Pick(r).MarshalBinary() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + for _, tc := range []struct { + name string + raw []byte + err error + }{ + { + name: "ok", + raw: toJSON(t, &masterSecretRaw{ + Group: g.String(), + S: s, + }), + }, + { + name: "broken", + raw: []byte("broken"), + err: cmpopts.AnyError, + }, + { + name: "broken group", + raw: toJSON(t, &masterSecretRaw{ + Group: "broken", + S: s, + }), + err: cmpopts.AnyError, + }, + { + name: "broken s", + raw: toJSON(t, &masterSecretRaw{ + Group: g.String(), + S: []byte("broken"), + }), + err: cmpopts.AnyError, + }, + } { + t.Run(tc.name, func(t *testing.T) { + var ms MasterSecret + if err := ms.Unmarshal(tc.raw); !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("got err=%v, want=%v", err, tc.err) + } + }) + } + } +} + +func TestDecryptionShareMarshal(t *testing.T) { + for _, typ := range supportedGroups { + g, r, _, _ := params(t, typ) + want := &DecryptionShare{ + group: g, + index: 123, + u_i: g.Point().Pick(r), + e_i: g.Scalar().Pick(r), + f_i: g.Scalar().Pick(r), + } + b, err := want.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var got DecryptionShare + if err := got.Unmarshal(b); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if want.group.String() != got.group.String() { + t.Errorf("got group=%s, want=%v", got.group, want.group) + } + if want.index != got.index { + t.Errorf("got index=%v, want=%v", got.index, want.index) + } + if d := cmp.Diff(got.e_i, want.e_i); d != "" { + t.Errorf("got/want E_i diff=%v", d) + } + if !got.u_i.Equal(want.u_i) { + t.Errorf("got U_i=%v, want=%v", got.u_i, want.u_i) + } + if !got.f_i.Equal(want.f_i) { + t.Errorf("got F_i=%v, want=%v", got.f_i, want.f_i) + } + } +} + +func TestDecryptionShareUnmarshal(t *testing.T) { + for _, typ := range supportedGroups { + g, r, _, _ := params(t, typ) + point, err := g.Point().Pick(r).MarshalBinary() + if err != nil { + t.Fatalf("point.MarshalBinary: %v", err) + } + scalar, err := g.Scalar().Pick(r).MarshalBinary() + if err != nil { + t.Fatalf("scalar.MarshalBinary: %v", err) + } + tmp := g.Scalar().Pick(r) + e, err := tmp.MarshalBinary() + if err != nil { + t.Fatalf("e.MarshalBinary: %v", err) + } + for _, tc := range []struct { + name string + raw []byte + err error + }{ + { + name: "ok", + raw: toJSON(t, &decryptionShareRaw{ + Group: g.String(), + Index: 123, + U_i: point, + E_i: e, + F_i: scalar, + }), + }, + { + name: "broken", + raw: []byte("broken"), + err: cmpopts.AnyError, + }, + { + name: "broken group", + raw: toJSON(t, &decryptionShareRaw{ + Group: "wrong", + Index: 123, + U_i: point, + E_i: e, + F_i: scalar, + }), + err: cmpopts.AnyError, + }, + { + name: "broken E", + raw: toJSON(t, &decryptionShareRaw{ + Group: g.String(), + Index: 123, + U_i: point, + E_i: []byte("broken"), + F_i: scalar, + }), + err: cmpopts.AnyError, + }, + { + name: "broken Ui", + raw: toJSON(t, &decryptionShareRaw{ + Group: g.String(), + Index: 123, + U_i: []byte("broken"), + E_i: e, + F_i: scalar, + }), + err: cmpopts.AnyError, + }, + { + name: "broken Fi", + raw: toJSON(t, &decryptionShareRaw{ + Group: g.String(), + Index: 123, + U_i: point, + E_i: e, + F_i: []byte("broken scalar"), + }), + err: cmpopts.AnyError, + }, + } { + t.Run(fmt.Sprintf("test=%q group=%v", tc.name, typ), func(t *testing.T) { + var ds DecryptionShare + if err := ds.Unmarshal(tc.raw); !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("got err=%v, want=%v", err, tc.err) + } + }) + } + } +} + +func TestPrivateShareMarshal(t *testing.T) { + for _, typ := range supportedGroups { + g, r, _, _ := params(t, typ) + want := &PrivateShare{ + group: g, + index: 123, + v: g.Scalar().Pick(r), + } + data, err := want.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + got := &PrivateShare{} + if err := got.Unmarshal(data); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if want.group.String() != got.group.String() { + t.Errorf("got group=%s, want=%v", got.group, want.group) + } + if !got.v.Equal(want.v) { + t.Errorf("got V=%v, want=%v", got.v, want.v) + } + if got.index != want.index { + t.Errorf("got I=%v, want=%v", got.index, want.index) + } + } +} + +func TestPrivateShareUnmarshal(t *testing.T) { + for _, typ := range supportedGroups { + g, r, _, _ := params(t, typ) + scalar, err := g.Scalar().Pick(r).MarshalBinary() + if err != nil { + t.Fatalf("scalar.MarshalBinary: %v", err) + } + for _, tc := range []struct { + name string + raw []byte + err error + }{ + { + name: "ok", + raw: toJSON(t, &privateShareRaw{ + Group: g.String(), + Index: 123, + V: scalar, + }), + }, + { + name: "broken", + raw: []byte("broken"), + err: cmpopts.AnyError, + }, + { + name: "broken group", + raw: toJSON(t, &privateShareRaw{ + Group: "wrong", + Index: 123, + V: scalar, + }), + err: cmpopts.AnyError, + }, + { + name: "broken V", + raw: toJSON(t, &privateShareRaw{ + Group: g.String(), + Index: 123, + V: []byte("wrong"), + }), + err: cmpopts.AnyError, + }, + } { + t.Run(fmt.Sprintf("test=%q group=%v", tc.name, typ), func(t *testing.T) { + var ps PrivateShare + if err := ps.Unmarshal(tc.raw); !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("got err=%v, want=%v", err, tc.err) + } + }) + } + } +} + +func BenchmarkAll(b *testing.B) { + for _, tc := range []struct { + k int + n int + }{{k: 3, n: 5}, {k: 4, n: 7}, {k: 6, n: 10}, {k: 8, n: 15}} { + for _, typ := range supportedGroups { + // setup and validity checks + group, rand, msg, label := params(b, typ) + _, pk, shares, err := GenerateKeys(group, nil, tc.k, tc.n, rand) + if err != nil { + b.Fatalf("GenerateKeys: %v", err) + } + ctxt, err := Encrypt(pk, msg, label, rand) + if err != nil { + b.Fatalf("Encrypt: %v", err) + } + decShares := make([]*DecryptionShare, tc.n) + for i := range shares { + ds, err := ctxt.Decrypt(group, shares[i], rand) + if err != nil { + b.Fatalf("Decrypt: %v", err) + } + decShares[i] = ds + err = VerifyShare(pk, ctxt, ds) + if err != nil { + b.Fatalf("VerifyShare: %v", err) + } + } + out, err := ctxt.CombineShares(group, decShares[:tc.k], tc.k, tc.n) + if err != nil { + b.Fatalf("CombineShares: %v", err) + } + if diff := cmp.Diff(msg, out); diff != "" { + b.Fatalf("original/decrypted message diff: %v", diff) + } + // run actual benchmarks + b.Run(fmt.Sprintf("%v %d out of %d Generate", typ, tc.k, tc.n), func(b *testing.B) { + for i := 0; i < b.N; i++ { + GenerateKeys(group, nil, tc.k, tc.n, rand) + } + }) + b.Run(fmt.Sprintf("%v %d out of %d Encrypt", typ, tc.k, tc.n), func(b *testing.B) { + for i := 0; i < b.N; i++ { + Encrypt(pk, msg, label, rand) + } + }) + b.Run(fmt.Sprintf("%v %d out of %d Decrypt", typ, tc.k, tc.n), func(b *testing.B) { + for i := 0; i < b.N; i++ { + ctxt.Decrypt(group, shares[i%len(shares)], rand) + } + }) + b.Run(fmt.Sprintf("%v %d out of %d VerifyShare", typ, tc.k, tc.n), func(b *testing.B) { + for i := 0; i < b.N; i++ { + VerifyShare(pk, ctxt, decShares[i%len(decShares)]) + } + }) + b.Run(fmt.Sprintf("%v %d out of %d CombineShares", typ, tc.k, tc.n), func(b *testing.B) { + for i := 0; i < b.N; i++ { + ctxt.CombineShares(group, decShares[:tc.k], tc.k, tc.n) + } + }) + } + } +} + +func BenchmarkEC(b *testing.B) { + for _, typ := range supportedGroups { + g, r, _, _ := params(b, typ) + p := g.Point().Pick(r) + q := g.Point().Pick(r) + s := g.Scalar().Pick(r) + b.Run(fmt.Sprintf("%v Add", typ), func(b *testing.B) { + for i := 0; i < b.N; i++ { + g.Point().Add(p, q) + } + }) + b.Run(fmt.Sprintf("%v Mul", typ), func(b *testing.B) { + for i := 0; i < b.N; i++ { + g.Point().Mul(s, p) + } + }) + b.Run(fmt.Sprintf("%v Sub", typ), func(b *testing.B) { + for i := 0; i < b.N; i++ { + g.Point().Sub(p, q) + } + }) + } +} + +func BenchmarkChecks(b *testing.B) { + for _, typ := range supportedGroups { + group, rand, msg, label := params(b, typ) + _, pk, shares, err := GenerateKeys(group, nil, 1, 1, rand) + if err != nil { + b.Fatalf("GenerateKeys: %v", err) + } + ctxt, err := Encrypt(pk, msg, label, rand) + if err != nil { + b.Fatalf("Encrypt: %v", err) + } + decShares := make([]*DecryptionShare, 1) + for i := range shares { + ds, err := ctxt.Decrypt(group, shares[i], rand) + if err != nil { + b.Fatalf("Decrypt: %v", err) + } + decShares[i] = ds + err = VerifyShare(pk, ctxt, ds) + if err != nil { + b.Fatalf("VerifyShare: %v", err) + } + } + + b.Run(fmt.Sprintf("%v ctxt.Verify", typ), func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err := ctxt.Verify(pk); err != nil { + b.Fatalf("checkE: %v", err) + } + } + }) + + b.Run(fmt.Sprintf("%v checkEi", typ), func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err := checkEi(pk, ctxt, decShares[0]); err != nil { + b.Fatalf("checkEi: %v", err) + } + } + }) + } +} + +func TestRedeal(t *testing.T) { + for _, typ := range supportedGroups { + g, r, _, _ := params(t, typ) + ms, pk, _, err := GenerateKeys(g, nil, 2, 5, r) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + for _, tc := range []struct { + name string + ms *MasterSecret + k, n int + err error + }{ + { + name: "ok", + k: 2, + n: 5, + ms: ms, + }, + { + name: "different sizes", + k: 1, + n: 7, + ms: ms, + }, + { + name: "nil ms", + k: 2, + n: 5, + err: cmpopts.AnyError, + }, + { + name: "wrong ms", + k: 2, + n: 5, + ms: &MasterSecret{ + group: nist.NewBlakeSHA256QR512(), + s: ms.s, + }, + err: cmpopts.AnyError, + }, + } { + t.Run(tc.name, func(t *testing.T) { + got, shares, err := Redeal(pk, tc.ms, tc.k, tc.n, r) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Fatalf("err=%v, want=%v", err, tc.err) + } else if err != nil { + return + } + if len(shares) != tc.n { + t.Errorf("got %d shares, want %d", len(shares), tc.n) + } + if got.group.String() != pk.group.String() { + t.Errorf("got group=%v, want=%v", got.group, pk.group) + } + if !got.g_bar.Equal(pk.g_bar) { + t.Errorf("got g_bar=%v, want=%v", got.g_bar, pk.g_bar) + } + if !got.h.Equal(pk.h) { + t.Errorf("got h=%v, want=%v", got.h, pk.h) + } + if len(got.hArray) != tc.n { + t.Errorf("got hArray len=%v, want=%v", len(got.hArray), tc.n) + } + }) + } + } +} + +func TestRedealNewEncryption(t *testing.T) { + for _, typ := range supportedGroups { + g, r, _, _ := params(t, typ) + ms, pk, _, err := GenerateKeys(g, nil, 3, 5, r) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + msg := []byte("12345678901234567890123456789012") + for _, tc := range []struct { + name string + k, n int + }{ + { + name: "same n,k", + k: 3, + n: 5, + }, + { + name: "smaller quorum", + k: 2, + n: 5, + }, + { + name: "larger quorum", + k: 4, + n: 5, + }, + } { + t.Run(tc.name, func(t *testing.T) { + newPk, shares, err := Redeal(pk, ms, tc.k, tc.n, r) + if err != nil { + t.Fatalf("Redeal: %v", err) + } + c, err := Encrypt(newPk, msg, make([]byte, 32), r) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + ds := []*DecryptionShare{} + for _, sh := range shares { + d, err := c.Decrypt(newPk.group, sh, r) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if err := VerifyShare(newPk, c, d); err != nil { + t.Fatalf("VerifyShare: %v", err) + } + ds = append(ds, d) + } + if m, err := c.CombineShares(newPk.group, ds[:tc.k], tc.k, tc.n); err != nil { + t.Errorf("CombineShares: %v", err) + } else if !cmp.Equal(m, msg) { + t.Errorf("got msg=%v, want=%v", m, msg) + } + }) + } + } +} + +func TestRedealOldDecryption(t *testing.T) { + for _, typ := range supportedGroups { + g, r, _, _ := params(t, typ) + msg := []byte("12345678901234567890123456789012") + ms, pk, _, err := GenerateKeys(g, nil, 3, 5, r) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + c, err := Encrypt(pk, msg, make([]byte, 32), r) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + for _, tc := range []struct { + name string + k, n int + }{ + { + name: "same n,k", + k: 3, + n: 5, + }, + { + name: "smaller quorum", + k: 2, + n: 5, + }, + { + name: "larger quorum", + k: 4, + n: 5, + }, + } { + t.Run(tc.name, func(t *testing.T) { + newPk, shares, err := Redeal(pk, ms, tc.k, tc.n, r) + if err != nil { + t.Fatalf("Redeal: %v", err) + } + ds := []*DecryptionShare{} + for _, sh := range shares { + d, err := c.Decrypt(newPk.group, sh, r) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if err := VerifyShare(newPk, c, d); err != nil { + t.Fatalf("VerifyShare: %v", err) + } + ds = append(ds, d) + } + // try to combine w/o enough new shares + if m, err := c.CombineShares(newPk.group, ds[:tc.k-1], tc.k-1, tc.n); err != nil { + t.Errorf("CombineShares: %v", err) + } else if cmp.Equal(m, msg) { + t.Errorf("got correct message") + } + // now try with enough new shares + if m, err := c.CombineShares(newPk.group, ds[:tc.k], tc.k, tc.n); err != nil { + t.Errorf("CombineShares: %v", err) + } else if !cmp.Equal(m, msg) { + t.Errorf("got msg=%v, want=%v", m, msg) + } + }) + } + } +} + +func TestRedealReuseOldShares(t *testing.T) { + for _, typ := range supportedGroups { + t.Run(typ, func(t *testing.T) { + g, r, _, _ := params(t, typ) + ms, pk, shares, err := GenerateKeys(g, nil, 2, 3, r) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + newPk, _, err := Redeal(pk, ms, 2, 3, r) + if err != nil { + t.Fatalf("Redeal: %v", err) + } + c, err := Encrypt(newPk, make([]byte, 32), make([]byte, 32), r) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + ds, err := c.Decrypt(g, shares[0], r) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + // make sure old shares cannot be used for new encryptions + if err := VerifyShare(newPk, c, ds); err == nil { + t.Error("VerifyShare did not fail") + } + }) + } +} + +func toJSON(t *testing.T, v interface{}) []byte { + t.Helper() + blob, err := json.Marshal(v) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + return blob +} + +type marshaler interface { + Marshal() ([]byte, error) +} + +func mustMarshal(f *testing.F, m marshaler) []byte { + f.Helper() + b, err := m.Marshal() + if err != nil { + f.Fatalf("Marshal: %v", err) + } + return b +} + +func FuzzPrivateShareMarshal(f *testing.F) { + for _, groupStr := range supportedGroups { + g, err := parseGroup(groupStr) + if err != nil { + f.Fatalf("parseGroup: %v", err) + } + f.Add(mustMarshal(f, PrivateShare{ + group: g, + index: 123, + v: g.Scalar().Pick(keccak.New(nil)), + })) + } + f.Fuzz(func(t *testing.T, data []byte) { + var ps1, ps2 PrivateShare + if err := ps1.Unmarshal(data); err != nil { + t.Skip() + } + data1, err := ps1.Marshal() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data, err) + } + if err := ps2.Unmarshal(data1); err != nil { + t.Fatalf("Cannot unmarshal: data=%v err=%v", data1, err) + } + data2, err := ps2.Marshal() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data2, err) + } + if !bytes.Equal(data1, data2) { + t.Errorf("data1=%v data2=%v", data1, data2) + } + if !reflect.DeepEqual(ps1, ps2) { + t.Errorf("ps1=%v data1=%v ps2=%v data2=%v", ps1, data1, ps2, data2) + } + }) +} + +func FuzzPublicKeyMarshal(f *testing.F) { + r := keccak.New(nil) + for _, groupStr := range supportedGroups { + g, err := parseGroup(groupStr) + if err != nil { + f.Fatalf("parseGroup: %v", err) + } + f.Add(mustMarshal(f, PublicKey{ + group: g, + g_bar: g.Point().Pick(r), + h: g.Point().Pick(r), + })) + f.Add(mustMarshal(f, PublicKey{ + group: g, + g_bar: g.Point().Pick(r), + h: g.Point().Pick(r), + hArray: []kyber.Point{g.Point().Pick(r)}, + })) + f.Add(mustMarshal(f, PublicKey{ + group: g, + g_bar: g.Point().Pick(r), + h: g.Point().Pick(r), + hArray: []kyber.Point{g.Point().Pick(r), g.Point().Pick(r)}, + })) + } + f.Fuzz(func(t *testing.T, data []byte) { + var pk1, pk2 PublicKey + if err := pk1.Unmarshal(data); err != nil { + t.Skip() + } + data1, err := pk1.Marshal() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data, err) + } + if err := pk2.Unmarshal(data1); err != nil { + t.Fatalf("Cannot unmarshal: data=%v err=%v", data1, err) + } + data2, err := pk2.Marshal() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data2, err) + } + if !bytes.Equal(data1, data2) { + t.Errorf("data1=%v data2=%v", data1, data2) + } + if !pk1.Equal(&pk2) { + t.Errorf("pk1=%v data1=%v pk2=%v data2=%v", pk1, data1, pk2, data2) + } + }) +} + +func FuzzCiphertextMarshal(f *testing.F) { + r := keccak.New(nil) + for _, groupStr := range supportedGroups { + g, err := parseGroup(groupStr) + if err != nil { + f.Fatalf("parseGroup: %v", err) + } + f.Add(mustMarshal(f, Ciphertext{ + group: g, + c: []byte("ctxt"), + label: []byte("label"), + u: g.Point().Pick(r), + u_bar: g.Point().Pick(r), + e: g.Scalar().Pick(r), + f: g.Scalar().Pick(r), + })) + } + f.Fuzz(func(t *testing.T, data []byte) { + var c1, c2 Ciphertext + if err := c1.Unmarshal(data); err != nil { + t.Skip() + } + data1, err := c1.Marshal() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data, err) + } + if err := c2.Unmarshal(data1); err != nil { + t.Fatalf("Cannot unmarshal: data=%v err=%v", data1, err) + } + data2, err := c2.Marshal() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data2, err) + } + if !bytes.Equal(data1, data2) { + t.Errorf("data1=%v data2=%v", data1, data2) + } + if !c1.Equal(&c2) { + t.Errorf("c1=%v data1=%v c2=%v data2=%v", c1, data1, c2, data2) + } + }) +} + +func FuzzDecryptionShareMarshal(f *testing.F) { + r := keccak.New(nil) + for _, groupStr := range supportedGroups { + g, err := parseGroup(groupStr) + if err != nil { + f.Fatalf("parseGroup: %v", err) + } + f.Add(mustMarshal(f, DecryptionShare{ + group: g, + index: 123, + u_i: g.Point().Pick(r), + e_i: g.Scalar().Pick(r), + f_i: g.Scalar().Pick(r), + })) + } + f.Fuzz(func(t *testing.T, data []byte) { + var ds1, ds2 DecryptionShare + if err := ds1.Unmarshal(data); err != nil { + t.Skip() + } + data1, err := ds1.Marshal() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data, err) + } + if err := ds2.Unmarshal(data1); err != nil { + t.Fatalf("Cannot unmarshal: data=%v err=%v", data1, err) + } + data2, err := ds2.Marshal() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data2, err) + } + if !bytes.Equal(data1, data2) { + t.Errorf("data1=%v data2=%v", data1, data2) + } + if !ds1.Equal(&ds2) { + t.Errorf("ds1=%v data1=%v ds2=%v data2=%v", ds1, data1, ds2, data2) + } + }) +} diff --git a/go/tdh2/tdh2easy/sym.go b/go/tdh2/tdh2easy/sym.go new file mode 100644 index 0000000..ec57b81 --- /dev/null +++ b/go/tdh2/tdh2easy/sym.go @@ -0,0 +1,52 @@ +package tdh2easy + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" +) + +// symKey generates a symmetric key. +func symKey(keySize int) ([]byte, error) { + key := make([]byte, keySize) + if _, err := rand.Read(key); err != nil { + return nil, fmt.Errorf("cannot generate key") + } + return key, nil +} + +// symEncrypt encrypts the message using the AES-GCM cipher. +func symEncrypt(msg, key []byte) ([]byte, []byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, nil, fmt.Errorf("cannot use AES: %v", err) + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, nil, fmt.Errorf("cannot use GCM mode: %v", err) + } + nonce := make([]byte, gcm.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, nil, fmt.Errorf("cannot generate nonce") + } + + return gcm.Seal(nil, nonce, msg, nil), nonce, nil +} + +// symDecrypt decrypts the ciphertext using the AES-GCM cipher. +func symDecrypt(nonce, ctxt, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("cannot use AES: %v", err) + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("cannot use GCM mode: %v", err) + } + if len(nonce) != gcm.NonceSize() { + return nil, fmt.Errorf("nonce must have %dB", gcm.NonceSize()) + } + + return gcm.Open(nil, nonce, ctxt, nil) +} diff --git a/go/tdh2/tdh2easy/sym_test.go b/go/tdh2/tdh2easy/sym_test.go new file mode 100644 index 0000000..7b4284a --- /dev/null +++ b/go/tdh2/tdh2easy/sym_test.go @@ -0,0 +1,141 @@ +package tdh2easy + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func TestSymmetric(t *testing.T) { + key, err := symKey(16) + if err != nil { + t.Fatalf("symmetricKey: %v", err) + } + for _, tc := range []struct { + name string + msg []byte + key []byte + err error + }{ + { + name: "OK", + msg: []byte("msg"), + key: key, + }, + { + name: "OK (empty)", + key: key, + }, + { + name: "OK (long)", + msg: make([]byte, 65536), + key: key, + }, + { + name: "wrong key length", + msg: make([]byte, 65536), + key: key[:4], + err: cmpopts.AnyError, + }, + } { + t.Run(tc.name, func(t *testing.T) { + c, nonce, err := symEncrypt(tc.msg, tc.key) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("err=%v, want=%v", err, tc.err) + } else if err != nil { + return + } + out, err := symDecrypt(nonce, c, key) + if err != nil { + t.Errorf("symmetricDecryption: %v", err) + } + if diff := cmp.Diff(tc.msg, out); diff != "" { + t.Errorf("encrypted/decrypted message diff=%v", diff) + } + }) + } +} + +func TestSymmetricDecryptionFail(t *testing.T) { + msg := []byte("msg") + key, err := symKey(16) + if err != nil { + t.Fatalf("symmetricKey: %v", err) + } + c, nonce, err := symEncrypt(msg, key) + if err != nil { + t.Fatalf("symmetricEncryption: %v", err) + } + for _, tc := range []struct { + name string + nonce []byte + c []byte + key []byte + err error + }{ + { + name: "OK", + key: key, + nonce: nonce, + c: c, + }, + { + name: "wrong key", + key: []byte("key"), + nonce: nonce, + c: c, + err: cmpopts.AnyError, + }, + { + name: "wrong nonce", + key: key, + nonce: []byte("nonce"), + c: c, + err: cmpopts.AnyError, + }, + { + name: "wrong c", + key: key, + nonce: nonce, + c: []byte("ciphertext"), + err: cmpopts.AnyError, + }, + } { + t.Run(tc.name, func(t *testing.T) { + out, err := symDecrypt(nonce, c, key) + if err != nil { + t.Errorf("symmetricDecryption: %v", err) + } + if diff := cmp.Diff(msg, out); diff != "" { + t.Errorf("encrypted/decrypted message diff=%v", diff) + } + }) + } +} + +func FuzzSymEncryption(f *testing.F) { + f.Add(16, []byte("sample message")) + f.Add(24, []byte("another sample message")) + f.Add(32, []byte("and another sample message")) + f.Fuzz(func(t *testing.T, keySize int, msg []byte) { + if keySize != 16 && keySize != 24 && keySize != 32 { + t.Skip() + } + key, err := symKey(keySize) + if err != nil { + t.Fatalf("symKey(%v): %v", keySize, err) + } + c, n, err := symEncrypt(msg, key) + if err != nil { + t.Fatalf("symEncrypt(%v, %v): %v", msg, key, err) + } + p, err := symDecrypt(n, c, key) + if err != nil { + t.Fatalf("symDecryt(%v, %v, %v): %v", n, c, key, err) + } + if d := cmp.Diff(p, msg); d != "" { + t.Fatalf("got/want diff=%v", d) + } + }) +} diff --git a/go/tdh2/tdh2easy/tdh2easy.go b/go/tdh2/tdh2easy/tdh2easy.go new file mode 100644 index 0000000..82fd20d --- /dev/null +++ b/go/tdh2/tdh2easy/tdh2easy.go @@ -0,0 +1,251 @@ +// Package tdh2easy implements an easy interface of TDH2-based hybrid encryption. +package tdh2easy + +import ( + "crypto/rand" + "encoding/json" + "fmt" + + "github.com/goplugin/tdh2/go/tdh2/tdh2" + "go.dedis.ch/kyber/v3" + "go.dedis.ch/kyber/v3/group/nist" + "go.dedis.ch/kyber/v3/xof/keccak" +) + +// key size used in symmetric encryption (AES). 256 bits is a higher securitylevel than provided +// by the EC group deployed, but as tdh2.InputSize is 256 bits we decided to use the same value. +const aes256KeySize = 32 + +// defaultGroup is the default EC group used. +var defaultGroup = nist.NewBlakeSHA256P256() + +// PrivateShare encodes TDH2 private share. +type PrivateShare struct { + p *tdh2.PrivateShare +} + +// Index returns private share index. +func (p *PrivateShare) Index() int { + return p.p.Index() +} + +func (p PrivateShare) Marshal() ([]byte, error) { + return p.p.Marshal() +} + +func (p *PrivateShare) Unmarshal(data []byte) error { + p.p = &tdh2.PrivateShare{} + return p.p.Unmarshal(data) +} + +// DecryptionShare encodes TDH2 decryption share. +type DecryptionShare struct { + d *tdh2.DecryptionShare +} + +// Index returns private share index. +func (d *DecryptionShare) Index() int { + return d.d.Index() +} + +func (d DecryptionShare) Marshal() ([]byte, error) { + return d.d.Marshal() +} + +func (d *DecryptionShare) Unmarshal(data []byte) error { + d.d = &tdh2.DecryptionShare{} + return d.d.Unmarshal(data) +} + +// PublicKey encodes TDH2 public key. +type PublicKey struct { + p *tdh2.PublicKey +} + +func (p PublicKey) Marshal() ([]byte, error) { + return p.p.Marshal() +} + +func (p *PublicKey) Unmarshal(data []byte) error { + p.p = &tdh2.PublicKey{} + return p.p.Unmarshal(data) +} + +// MasterSecret encodes TDH2 master key. +type MasterSecret struct { + m *tdh2.MasterSecret +} + +func (m MasterSecret) Marshal() ([]byte, error) { + return m.m.Marshal() +} + +func (m *MasterSecret) Unmarshal(data []byte) error { + m.m = &tdh2.MasterSecret{} + return m.m.Unmarshal(data) +} + +// Ciphertext encodes hybrid ciphertext. +type Ciphertext struct { + tdh2Ctxt *tdh2.Ciphertext + symCtxt []byte + nonce []byte +} + +// Decrypt returns a decryption share for the ciphertext. +func Decrypt(c *Ciphertext, x_i *PrivateShare) (*DecryptionShare, error) { + xof, err := xof() + if err != nil { + return nil, err + } + d, err := c.tdh2Ctxt.Decrypt(defaultGroup, x_i.p, xof) + if err != nil { + return nil, err + } + return &DecryptionShare{d}, nil +} + +// VerifyShare checks if the share matches the ciphertext and public key. +func VerifyShare(c *Ciphertext, pk *PublicKey, share *DecryptionShare) error { + return tdh2.VerifyShare(pk.p, c.tdh2Ctxt, share.d) +} + +// Aggregate decrypts the TDH2-encrypted key and using it recovers the +// symmetrically encrypted plaintext. It takes decryption shares and +// the total number of participants as the arguments. +// Ciphertext and shares MUST be verified before calling Aggregate. +func Aggregate(c *Ciphertext, shares []*DecryptionShare, n int) ([]byte, error) { + sh := []*tdh2.DecryptionShare{} + for _, s := range shares { + sh = append(sh, s.d) + } + key, err := c.tdh2Ctxt.CombineShares(defaultGroup, sh, len(sh), n) + if err != nil { + return nil, fmt.Errorf("cannot combine shares: %w", err) + } + if aes256KeySize != len(key) { + return nil, fmt.Errorf("incorrect key size") + } + return symDecrypt(c.nonce, c.symCtxt, key) +} + +// xof returns xof used for providing randomness. +func xof() (kyber.XOF, error) { + seed := make([]byte, 64) + if _, err := rand.Read(seed); err != nil { + return nil, fmt.Errorf("cannot generate seed: %w", err) + } + return keccak.New(seed), nil +} + +type ciphertextRaw struct { + TDH2Ctxt []byte + SymCtxt []byte + Nonce []byte +} + +func (c *Ciphertext) Marshal() ([]byte, error) { + ctxt, err := c.tdh2Ctxt.Marshal() + if err != nil { + return nil, fmt.Errorf("cannot marshal TDH2 ciphertext: %w", err) + } + return json.Marshal(&ciphertextRaw{ + TDH2Ctxt: ctxt, + SymCtxt: c.symCtxt, + Nonce: c.nonce, + }) +} + +// UnmarshalVerify unmarshals ciphertext and verifies if it matches the public key. +func (c *Ciphertext) UnmarshalVerify(data []byte, pk *PublicKey) error { + var raw ciphertextRaw + if err := json.Unmarshal(data, &raw); err != nil { + return fmt.Errorf("cannot unmarshal data: %w", err) + } + c.symCtxt = raw.SymCtxt + c.nonce = raw.Nonce + c.tdh2Ctxt = &tdh2.Ciphertext{} + if err := c.tdh2Ctxt.Unmarshal(raw.TDH2Ctxt); err != nil { + return fmt.Errorf("cannot unmarshal TDH2 ciphertext: %w", err) + } + + if err := c.tdh2Ctxt.Verify(pk.p); err != nil { + return fmt.Errorf("tdh2 ciphertext verification: %w", err) + } + return nil +} + +// GenerateKeys generates and returns, the master secret, public key, and private shares. It takes the +// total number of nodes n and a threshold k (the number of shares sufficient for decryption). +func GenerateKeys(k, n int) (*MasterSecret, *PublicKey, []*PrivateShare, error) { + xof, err := xof() + if err != nil { + return nil, nil, nil, err + } + ms, pk, sh, err := tdh2.GenerateKeys(defaultGroup, nil, k, n, xof) + if err != nil { + return nil, nil, nil, err + } + shares := []*PrivateShare{} + for i := range sh { + shares = append(shares, &PrivateShare{sh[i]}) + } + return &MasterSecret{ms}, &PublicKey{pk}, shares, nil +} + +// Redeal re-deals private shares such that new quorums can decrypt old ciphertexts. +// It takes the previous public key and master secret as well as the number of nodes +// sufficient for decrypt k, and the total number of nodes n. It returns a new public +// key and private shares. The master secret passed corresponds to the public key returned. +// The old public key can still be used for encryption but it cannot be used for share +// verification (the new key has to be used instead). +func Redeal(pk *PublicKey, ms *MasterSecret, k, n int) (*PublicKey, []*PrivateShare, error) { + xof, err := xof() + if err != nil { + return nil, nil, err + } + p, sh, err := tdh2.Redeal(pk.p, ms.m, k, n, xof) + if err != nil { + return nil, nil, err + } + shares := []*PrivateShare{} + for i := range sh { + shares = append(shares, &PrivateShare{sh[i]}) + } + return &PublicKey{p}, shares, nil +} + +// Encrypt generates a fresh symmetric key, encrypts and authenticates +// the message with it, and encrypts the key using TDH2. It returns a +// struct encoding the generated ciphertexts. +func Encrypt(pk *PublicKey, msg []byte) (*Ciphertext, error) { + if aes256KeySize != tdh2.InputSize { + return nil, fmt.Errorf("incorrect key size") + } + // generate a fresh key and encrypt the message + key, err := symKey(tdh2.InputSize) + if err != nil { + return nil, fmt.Errorf("cannot generate key: %w", err) + } + // for each encryption a fresh key and nonce are generated, + // therefore the probability of nonce misuse is negligible + symCtxt, nonce, err := symEncrypt(msg, key) + if err != nil { + return nil, fmt.Errorf("cannot encrypt message: %w", err) + } + + xof, err := xof() + if err != nil { + return nil, err + } + // encrypt the key with TDH2 using empty label + tdh2Ctxt, err := tdh2.Encrypt(pk.p, key, make([]byte, tdh2.InputSize), xof) + if err != nil { + return nil, fmt.Errorf("cannot TDH2 encrypt: %w", err) + } + return &Ciphertext{ + tdh2Ctxt: tdh2Ctxt, + symCtxt: symCtxt, + nonce: nonce, + }, nil +} diff --git a/go/tdh2/tdh2easy/tdh2easy_test.go b/go/tdh2/tdh2easy/tdh2easy_test.go new file mode 100644 index 0000000..71216e3 --- /dev/null +++ b/go/tdh2/tdh2easy/tdh2easy_test.go @@ -0,0 +1,590 @@ +package tdh2easy + +import ( + "bytes" + "encoding/json" + "reflect" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/goplugin/tdh2/go/tdh2/tdh2" + "go.dedis.ch/kyber/v3/group/nist" + "go.dedis.ch/kyber/v3/xof/keccak" +) + +func TestShareIndex(t *testing.T) { + _, pk, sh, err := GenerateKeys(5, 10) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + for i := range sh { + if sh[i].Index() != i { + t.Errorf("index=%v, want=%v", sh[i].Index(), i) + } + } + c, err := Encrypt(pk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + for i, s := range sh { + ds, err := Decrypt(c, s) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if ds.Index() != i { + t.Errorf("index=%v, want=%v", ds.Index(), i) + } + } +} + +func TestPrivateShareMarshal(t *testing.T) { + _, _, want, err := GenerateKeys(2, 3) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + b, err := want[0].Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var got PrivateShare + if err := got.Unmarshal(b); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !reflect.DeepEqual(got.p, want[0].p) { + t.Errorf("got=%v want=%v", got, want[0]) + } + if err := got.Unmarshal([]byte("broken")); err == nil { + t.Errorf("Unmarshal did not fail") + } +} + +func TestDecryptionShareMarshal(t *testing.T) { + _, pk, sh, err := GenerateKeys(2, 3) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + c, err := Encrypt(pk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + want, err := Decrypt(c, sh[0]) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + b, err := want.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var got DecryptionShare + if err := got.Unmarshal(b); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !reflect.DeepEqual(got.d, want.d) { + t.Errorf("got=%v want=%v", got, want) + } + if err := got.Unmarshal([]byte("broken")); err == nil { + t.Errorf("Unmarshal did not fail") + } +} + +func TestPublicKeyMarshal(t *testing.T) { + _, want, _, err := GenerateKeys(2, 3) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + b, err := want.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var got PublicKey + if err := got.Unmarshal(b); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !got.p.Equal(want.p) { + t.Errorf("got=%v want=%v", got, want) + } + if err := got.Unmarshal([]byte("broken")); err == nil { + t.Errorf("Unmarshal did not fail") + } +} + +func TestMasterSecretMarshal(t *testing.T) { + want, _, _, err := GenerateKeys(2, 3) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + b, err := want.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var got MasterSecret + if err := got.Unmarshal(b); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !reflect.DeepEqual(got.m, want.m) { + t.Errorf("got=%v want=%v", got, want) + } + if err := got.Unmarshal([]byte("broken")); err == nil { + t.Errorf("Unmarshal did not fail") + } +} + +func TestCiphertextDecrypt(t *testing.T) { + _, pk, share, err := GenerateKeys(1, 1) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + _, _, wrong, err := tdh2.GenerateKeys(nist.NewBlakeSHA256QR512(), nil, 1, 1, keccak.New(nil)) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + c, err := Encrypt(pk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + if _, err := Decrypt(c, share[0]); err != nil { + t.Errorf("Decrypt: %v", err) + } + if _, err := Decrypt(c, &PrivateShare{wrong[0]}); err == nil { + t.Errorf("Decrypt did not fail") + } +} + +func TestCiphertextVerifyShare(t *testing.T) { + _, pk, share, err := GenerateKeys(1, 1) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + _, _, wrongShare, err := GenerateKeys(1, 1) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + c, err := Encrypt(pk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + ds, err := Decrypt(c, share[0]) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + wrongDs, err := Decrypt(c, wrongShare[0]) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if err := VerifyShare(c, pk, ds); err != nil { + t.Errorf("VerifyShare: %v", err) + } + if err := VerifyShare(c, pk, wrongDs); err == nil { + t.Errorf("VerifyShare did not fail") + } +} + +func TestAggregate(t *testing.T) { + k := 3 + n := 5 + _, pk, shares, err := GenerateKeys(k, n) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + msg := []byte("message") + c, err := Encrypt(pk, msg) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + decShares := make([]*DecryptionShare, n) + for i := range shares { + ds, err := Decrypt(c, shares[i]) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + decShares[i] = ds + } + for _, tc := range []struct { + name string + ctxt *Ciphertext + shares []*DecryptionShare + err error + }{ + { + name: "OK (all shares)", + ctxt: c, + shares: decShares, + }, + { + name: "OK (min shares)", + ctxt: c, + shares: decShares[:k], + }, + { + name: "not enough shares", + ctxt: c, + shares: decShares[:2], + err: cmpopts.AnyError, + }, + { + name: "wrong nonce", + ctxt: &Ciphertext{ + tdh2Ctxt: c.tdh2Ctxt, + symCtxt: c.symCtxt, + nonce: make([]byte, len(c.nonce)), + }, + shares: decShares, + err: cmpopts.AnyError, + }, + { + name: "wrong nonce size", + ctxt: &Ciphertext{ + tdh2Ctxt: c.tdh2Ctxt, + symCtxt: c.symCtxt, + nonce: []byte("nonce"), + }, + shares: decShares, + err: cmpopts.AnyError, + }, + { + name: "wrong symmetric ciphertext", + ctxt: &Ciphertext{ + tdh2Ctxt: c.tdh2Ctxt, + symCtxt: []byte("ciphertext"), + nonce: c.nonce, + }, + shares: decShares, + err: cmpopts.AnyError, + }, + } { + t.Run(tc.name, func(t *testing.T) { + out, err := Aggregate(tc.ctxt, tc.shares, n) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("err=%v, want=%v", err, tc.err) + } else if err != nil { + return + } + if diff := cmp.Diff(msg, out); diff != "" { + t.Errorf("encrypted decrypted message diff=%v", diff) + } + }) + } +} + +func TestCiphertextMarshal(t *testing.T) { + _, pk, _, err := GenerateKeys(1, 1) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + want, err := Encrypt(pk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + b, err := want.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var got Ciphertext + if err := got.UnmarshalVerify(b, pk); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if d := cmp.Diff(got.symCtxt, want.symCtxt); d != "" { + t.Errorf("got/want Ciphertext diff=%v", d) + } + if d := cmp.Diff(got.nonce, want.nonce); d != "" { + t.Errorf("got/want Nonce diff=%v", d) + } + if !got.tdh2Ctxt.Equal(want.tdh2Ctxt) { + t.Errorf("different ciphertexts") + } +} + +func TestCiphertextUnmarshal(t *testing.T) { + _, pk, _, err := GenerateKeys(1, 1) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + _, wrong, _, err := GenerateKeys(1, 1) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + c, err := Encrypt(pk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + cRaw, err := c.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + brokenTdh2, err := json.Marshal(&ciphertextRaw{ + TDH2Ctxt: []byte("broken"), + SymCtxt: []byte("ciphertext"), + Nonce: []byte("nonce"), + }) + if err != nil { + t.Fatalf("json.Marshal: %v", err) + } + for _, tc := range []struct { + name string + raw []byte + pk *PublicKey + err error + }{ + { + name: "ok", + raw: cRaw, + pk: pk, + }, + { + name: "wrong pk", + raw: cRaw, + pk: wrong, + err: cmpopts.AnyError, + }, + { + name: "broken", + raw: []byte("broken"), + pk: pk, + err: cmpopts.AnyError, + }, + { + name: "broken tdh2 ciphertext", + raw: brokenTdh2, + pk: pk, + err: cmpopts.AnyError, + }, + } { + t.Run(tc.name, func(t *testing.T) { + var hc Ciphertext + if err := hc.UnmarshalVerify(tc.raw, tc.pk); !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("got err=%v, want=%v", err, tc.err) + } + }) + } +} + +func TestRedealEncryptNew(t *testing.T) { + ms, pk, _, err := GenerateKeys(3, 5) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + want := []byte("msg") + for _, tc := range []struct { + name string + k, n int + }{ + { + name: "same n,k", + k: 3, + n: 5, + }, + { + name: "smaller quorum", + k: 2, + n: 5, + }, + { + name: "larger quorum", + k: 4, + n: 5, + }, + } { + t.Run(tc.name, func(t *testing.T) { + // generate new instance + newPk, shares, err := Redeal(pk, ms, tc.k, tc.n) + if err != nil { + t.Fatalf("Redeal: %v", err) + } + // encrypt and decrypt using new keys + c, err := Encrypt(newPk, want) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + ds := []*DecryptionShare{} + for _, sh := range shares { + d, err := Decrypt(c, sh) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if err := VerifyShare(c, newPk, d); err != nil { + t.Fatalf("VerifyShare: %v", err) + } + ds = append(ds, d) + } + if got, err := Aggregate(c, ds[:tc.k], tc.n); err != nil { + t.Errorf("Aggregate: %v", err) + } else if !cmp.Equal(got, want) { + t.Errorf("got=%v, want=%v", got, want) + } + }) + } +} + +func TestRedealDecryptOld(t *testing.T) { + ms, pk, _, err := GenerateKeys(3, 5) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + want := []byte("msg") + c, err := Encrypt(pk, want) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + for _, tc := range []struct { + name string + k, n int + }{ + { + name: "same n,k", + k: 3, + n: 5, + }, + { + name: "smaller quorum", + k: 2, + n: 5, + }, + { + name: "larger quorum", + k: 4, + n: 5, + }, + } { + t.Run(tc.name, func(t *testing.T) { + // generate new instance + new, shares, err := Redeal(pk, ms, tc.k, tc.n) + if err != nil { + t.Fatalf("Redeal: %v", err) + } + // try to decrypt old ciphertext + ds := []*DecryptionShare{} + for _, sh := range shares { + d, err := Decrypt(c, sh) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if err := VerifyShare(c, new, d); err != nil { + t.Fatalf("VerifyShare: %v", err) + } + ds = append(ds, d) + } + // should fail w/o enough shares + if _, err := Aggregate(c, ds[:tc.k-1], tc.n); err == nil { + t.Error("Aggregate did not fail") + } + // try with enough shares + if got, err := Aggregate(c, ds[:tc.k], tc.n); err != nil { + t.Errorf("Aggregate: %v", err) + } else if !cmp.Equal(got, want) { + t.Errorf("got=%v, want=%v", got, want) + } + }) + } +} + +func TestRedealReuseOldShares(t *testing.T) { + ms, pk, shares, err := GenerateKeys(3, 5) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + newPk, _, err := Redeal(pk, ms, 3, 5) + if err != nil { + t.Fatalf("Redeal: %v", err) + } + c, err := Encrypt(newPk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + // use old share for decryption + ds, err := Decrypt(c, shares[0]) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + // make sure old shares cannot be used for new encryptions + if err := VerifyShare(c, newPk, ds); err == nil { + t.Error("VerifyShare did not fail") + } +} + +func FuzzCiphertextMarshal(f *testing.F) { + _, pk, _, err := GenerateKeys(1, 1) + if err != nil { + f.Fatalf("Keys: %v", err) + } + xof, err := xof() + if err != nil { + f.Fatalf("xof: %v", err) + } + tdh2Input := make([]byte, tdh2.InputSize) + f.Add(tdh2Input, []byte("symcCtxt"), []byte("nonce")) + f.Fuzz(func(t *testing.T, key, symCtxt, nonce []byte) { + if len(key) != tdh2.InputSize { + t.Skip() + } + tdh2Ctxt, err := tdh2.Encrypt(pk.p, key, tdh2Input, xof) + if err != nil { + t.Fatalf("Encrypt(%v): %v", key, err) + } + want := Ciphertext{ + tdh2Ctxt: tdh2Ctxt, + symCtxt: symCtxt, + nonce: nonce, + } + b, err := want.Marshal() + if err != nil { + t.Fatalf("Marshal(%v): %v", want, err) + } + var got Ciphertext + if err := got.UnmarshalVerify(b, pk); err != nil { + t.Fatalf("UnmarshalVerify(%v): %v", b, err) + } + }) +} + +func FuzzCiphertextUnmarshal(f *testing.F) { + _, pk, _, err := GenerateKeys(1, 1) + if err != nil { + f.Fatalf("Keys: %v", err) + } + tdh2Ctxt, err := tdh2.Encrypt(pk.p, make([]byte, tdh2.InputSize), make([]byte, tdh2.InputSize), keccak.New(nil)) + if err != nil { + f.Fatalf("Encrypt: %v", err) + } + c := Ciphertext{ + tdh2Ctxt: tdh2Ctxt, + symCtxt: []byte("symCtxt"), + nonce: []byte("nonce"), + } + b, err := c.Marshal() + if err != nil { + f.Fatalf("Marshal: %v", err) + } + f.Add(b) + f.Fuzz(func(t *testing.T, data []byte) { + var c1, c2 Ciphertext + if err := c1.UnmarshalVerify(data, pk); err != nil { + t.Skip() + } + data1, err := c1.Marshal() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data, err) + } + if err := c2.UnmarshalVerify(data1, pk); err != nil { + t.Fatalf("Cannot unmarshal: data=%v err=%v", data1, err) + } + data2, err := c2.Marshal() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data2, err) + } + if !bytes.Equal(data1, data2) { + t.Errorf("data1=%v data2=%v", data1, data2) + } + if !bytes.Equal(c1.symCtxt, c2.symCtxt) { + t.Errorf("c1.symCtxt=%v data1=%v c2.symCtxt=%v data2=%v", c1.symCtxt, data1, c2.symCtxt, data2) + + } + if !bytes.Equal(c1.nonce, c2.nonce) { + t.Errorf("c1.nonce=%v data1=%v c2.nonce=%v data2=%v", c1.nonce, data1, c2.nonce, data2) + + } + if !c1.tdh2Ctxt.Equal(c2.tdh2Ctxt) { + t.Errorf("c1.tdh2Ctxt=%v data1=%v c2.tdh2Ctxt=%v data2=%v", c1.tdh2Ctxt, data1, c2.tdh2Ctxt, data2) + } + }) +} diff --git a/js/tdh2/README.md b/js/tdh2/README.md new file mode 100644 index 0000000..d9eefac --- /dev/null +++ b/js/tdh2/README.md @@ -0,0 +1,3 @@ +See test/ for an example. + +A TS project (Uniswap v2 interface) required `echo "/* global BigInt */" >> node_modules/bcrypto/lib/js/bn.js` after `npm install` diff --git a/js/tdh2/decs.d.ts b/js/tdh2/decs.d.ts new file mode 100644 index 0000000..565d411 --- /dev/null +++ b/js/tdh2/decs.d.ts @@ -0,0 +1 @@ +declare module "tdh2" diff --git a/js/tdh2/package-lock.json b/js/tdh2/package-lock.json new file mode 100644 index 0000000..2c1987e --- /dev/null +++ b/js/tdh2/package-lock.json @@ -0,0 +1,73 @@ +{ + "name": "tdh2", + "lockfileVersion": 2, + "requires": true, + "packages": { + "": { + "dependencies": { + "bcrypto": "^5.4.0", + "uniq": "^1.0.1" + } + }, + "node_modules/bcrypto": { + "version": "5.4.0", + "resolved": "https://registry.npmjs.org/bcrypto/-/bcrypto-5.4.0.tgz", + "integrity": "sha512-KDX2CR29o6ZoqpQndcCxFZAtYA1jDMnXU3jmCfzP44g++Cu7AHHtZN/JbrN/MXAg9SLvtQ8XISG+eVD9zH1+Jg==", + "hasInstallScript": true, + "dependencies": { + "bufio": "~1.0.7", + "loady": "~0.0.5" + }, + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/bufio": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/bufio/-/bufio-1.0.7.tgz", + "integrity": "sha512-bd1dDQhiC+bEbEfg56IdBv7faWa6OipMs/AFFFvtFnB3wAYjlwQpQRZ0pm6ZkgtfL0pILRXhKxOiQj6UzoMR7A==", + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/loady": { + "version": "0.0.5", + "resolved": "https://registry.npmjs.org/loady/-/loady-0.0.5.tgz", + "integrity": "sha512-uxKD2HIj042/HBx77NBcmEPsD+hxCgAtjEWlYNScuUjIsh/62Uyu39GOR68TBR68v+jqDL9zfftCWoUo4y03sQ==", + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/uniq": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/uniq/-/uniq-1.0.1.tgz", + "integrity": "sha512-Gw+zz50YNKPDKXs+9d+aKAjVwpjNwqzvNpLigIruT4HA9lMZNdMqs9x07kKHB/L9WRzqp4+DlTU5s4wG2esdoA==" + } + }, + "dependencies": { + "bcrypto": { + "version": "5.4.0", + "resolved": "https://registry.npmjs.org/bcrypto/-/bcrypto-5.4.0.tgz", + "integrity": "sha512-KDX2CR29o6ZoqpQndcCxFZAtYA1jDMnXU3jmCfzP44g++Cu7AHHtZN/JbrN/MXAg9SLvtQ8XISG+eVD9zH1+Jg==", + "requires": { + "bufio": "~1.0.7", + "loady": "~0.0.5" + } + }, + "bufio": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/bufio/-/bufio-1.0.7.tgz", + "integrity": "sha512-bd1dDQhiC+bEbEfg56IdBv7faWa6OipMs/AFFFvtFnB3wAYjlwQpQRZ0pm6ZkgtfL0pILRXhKxOiQj6UzoMR7A==" + }, + "loady": { + "version": "0.0.5", + "resolved": "https://registry.npmjs.org/loady/-/loady-0.0.5.tgz", + "integrity": "sha512-uxKD2HIj042/HBx77NBcmEPsD+hxCgAtjEWlYNScuUjIsh/62Uyu39GOR68TBR68v+jqDL9zfftCWoUo4y03sQ==" + }, + "uniq": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/uniq/-/uniq-1.0.1.tgz", + "integrity": "sha512-Gw+zz50YNKPDKXs+9d+aKAjVwpjNwqzvNpLigIruT4HA9lMZNdMqs9x07kKHB/L9WRzqp4+DlTU5s4wG2esdoA==" + } + } +} diff --git a/js/tdh2/package.json b/js/tdh2/package.json new file mode 100644 index 0000000..a403ca9 --- /dev/null +++ b/js/tdh2/package.json @@ -0,0 +1,6 @@ +{ + "dependencies": { + "bcrypto": "^5.4.0", + "uniq": "^1.0.1" + } +} diff --git a/js/tdh2/tdh2.js b/js/tdh2/tdh2.js new file mode 100644 index 0000000..6181cbd --- /dev/null +++ b/js/tdh2/tdh2.js @@ -0,0 +1,126 @@ +const rnd = require('bcrypto/lib/random'); +const sha256 = require('bcrypto/lib/sha256'); +const elliptic = require('bcrypto/lib/js/elliptic'); +const cipher = require('bcrypto/lib/cipher'); + +const { + ShortCurve, + EdwardsCurve, + curves +} = elliptic; + +const { + Cipher, + Decipher, + enc, + dec +} = cipher; + + +const p256 = new curves.P256(); +const groupName = "P256"; +const tdh2InputSize = 32; + +function toHexString(byteArray) { + return Array.from(byteArray, function(byte) { + return ('0' + (byte & 0xFF).toString(16)).slice(-2); + }).join('') +} + +function tdh2Encrypt(pub, msg, label) { + if (pub.Group != groupName) + throw Error('invalid group') + g_bar = p256.decodePoint(Buffer.from(pub.G_bar, 'base64')) + h = p256.decodePoint(Buffer.from(pub.H, 'base64')) + + const r = p256.randomScalar(rnd); + const s = p256.randomScalar(rnd); + + const c = xor(hash1(h.mul(r)), msg) + + const u = p256.g.mul(r) + const w = p256.g.mul(s) + const uBar = g_bar.mul(r) + const wBar = g_bar.mul(s) + + const e = hash2(c, label, u, w, uBar, wBar) + const f = s.add(r.mul(e).mod(p256.n)).mod(p256.n) + + return JSON.stringify({ + Group: groupName, + C: c.toString('base64'), + Label: label.toString('base64'), + U: p256.encodePoint(u, false).toString('base64'), + U_bar: p256.encodePoint(uBar, false).toString('base64'), + E: p256.encodeScalar(e).toString('base64'), + F: p256.encodeScalar(f).toString('base64'), + }) +} + +function concatenate(points) { + var out = groupName; + for (let i = 0; i < points.length; i++) { + out += "," + toHexString(p256.encodePoint(points[i], false)); + } + + return Buffer.from(out); +} + +function hash1(point) { + return sha256.digest(Buffer.concat([ + Buffer.from("tdh2hash1"), + concatenate([point]) + ])); +} + +function hash2(msg, label, p1, p2, p3, p4) { + if (msg.length != tdh2InputSize) + throw new Error('message has incorrect length'); + + if (label.length != tdh2InputSize) + throw new Error('label has incorrect length'); + + const h = sha256.digest(Buffer.concat([ + Buffer.from("tdh2hash2"), + msg, + label, + concatenate([p1,p2,p3,p4]) + ])); + + return p256.decodeScalar(h) +} + +function xor(a, b) { + if (a.length != b.length) + throw new Error('buffers with different lengths'); + + var out = Buffer.alloc(a.length) + for (var i = 0; i < a.length; i++) { + out[i] = a[i] ^ b[i]; + } + + return out +} + +function encrypt(pub, msg) { + const ciph = new Cipher('AES-256-GCM'); + const key = rnd.randomBytes(tdh2InputSize); + const nonce = rnd.randomBytes(12); + + ciph.init(key, nonce); + const ctxt = Buffer.concat([ + ciph.update(msg), + ciph.final(), + ciph.getAuthTag() + ]); + + const tdh2Ctxt = tdh2Encrypt(pub, key, Buffer.alloc(tdh2InputSize)); + + return JSON.stringify({ + TDH2Ctxt: Buffer.from(tdh2Ctxt).toString('base64'), + SymCtxt: ctxt.toString('base64'), + Nonce: nonce.toString('base64'), + }) +} + +module.exports = { encrypt } diff --git a/js/tdh2/test/js_test.go b/js/tdh2/test/js_test.go new file mode 100644 index 0000000..ee8148e --- /dev/null +++ b/js/tdh2/test/js_test.go @@ -0,0 +1,59 @@ +package test + +import ( + "bytes" + "encoding/base64" + "os/exec" + "testing" + + "github.com/goplugin/tdh2/go/tdh2easy" +) + +func TestJS(t *testing.T) { + _, pk, sh, err := tdh2easy.GenerateKeys(2, 3) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + b, err := pk.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + + cmdArgs := []string{"test.js", string(b)} + cmd := exec.Command("node", cmdArgs...) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("Failed to run test.js: %s", err) + } + pairs := bytes.Split(output, []byte("\n")) + // it contains the last empty newline + if len(pairs) < 3 || len(pairs)%2 == 0 { + t.Fatalf("Incorrect script output: %v", pairs) + } + pairs = pairs[:len(pairs)-1] + for i := 0; i < len(pairs)/2; i++ { + want, err := base64.StdEncoding.DecodeString(string(pairs[2*i])) + if err != nil { + t.Fatalf("b64Decode: %v", err) + } + var c tdh2easy.Ciphertext + if err := c.UnmarshalVerify(pairs[2*i+1], pk); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + dec := []*tdh2easy.DecryptionShare{} + for _, s := range sh { + d, err := tdh2easy.Decrypt(&c, s) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + dec = append(dec, d) + } + got, err := tdh2easy.Aggregate(&c, dec, 3) + if err != nil { + t.Fatalf("Aggregate: %v", err) + } + if !bytes.Equal(got, want) { + t.Errorf("got=%v; want=%v", got, want) + } + } +} diff --git a/js/tdh2/test/package.json b/js/tdh2/test/package.json new file mode 100644 index 0000000..3dbc1ca --- /dev/null +++ b/js/tdh2/test/package.json @@ -0,0 +1,3 @@ +{ + "type": "module" +} diff --git a/js/tdh2/test/test.js b/js/tdh2/test/test.js new file mode 100644 index 0000000..5d01236 --- /dev/null +++ b/js/tdh2/test/test.js @@ -0,0 +1,11 @@ +import { randomBytes, randomInt } from 'crypto' +import pkg from '../tdh2.js'; +const { encrypt } = pkg; + +const pub = JSON.parse(process.argv.slice(2)[0]); + +for (let i = 0; i < 100; i++) { + const msg = randomBytes(randomInt(1, 5000)) + console.log(msg.toString('base64')) + console.log(encrypt(pub, msg)) +}