From fc263af752afa59d5b8cf1bf7afd2a21af1574d7 Mon Sep 17 00:00:00 2001 From: nareshmmr Date: Tue, 26 Nov 2024 11:42:10 +0100 Subject: [PATCH] base 1 commit --- .github/workflows/go.yaml | 31 + go/go.work.sum | 2 + go/ocr2/decryptionplugin/config/config.go | 22 +- .../config/config_types.pb.go | 61 +- .../config/config_types.proto | 3 - .../config/mocks/config_parser.go | 54 ++ go/ocr2/decryptionplugin/decryption.go | 43 +- go/ocr2/decryptionplugin/decryption_test.go | 771 +++++++++++++++++- go/ocr2/decryptionplugin/go.mod | 5 +- go/ocr2/decryptionplugin/go.sum | 5 +- go/tdh2/go.mod | 11 +- go/tdh2/go.sum | 20 - go/tdh2/internal/group/LICENSE | 377 +++++++++ go/tdh2/internal/group/group.go | 156 ++++ go/tdh2/internal/group/mod/int.go | 219 +++++ go/tdh2/internal/group/mod/int_test.go | 47 ++ go/tdh2/internal/group/nist/curve.go | 246 ++++++ go/tdh2/internal/group/nist/group_test.go | 178 ++++ go/tdh2/internal/group/share/poly.go | 152 ++++ go/tdh2/internal/group/share/poly_test.go | 169 ++++ go/tdh2/internal/group/test/group.go | 153 ++++ go/tdh2/internal/group/test/test.go | 349 ++++++++ go/tdh2/tdh2/tdh2.go | 150 ++-- go/tdh2/tdh2/tdh2_test.go | 84 +- {js/tdh2/test => go/tdh2/tdh2easy}/js_test.go | 18 +- go/tdh2/tdh2easy/tdh2easy.go | 44 +- go/tdh2/tdh2easy/tdh2easy_test.go | 21 +- 27 files changed, 3127 insertions(+), 264 deletions(-) create mode 100644 .github/workflows/go.yaml create mode 100644 go/ocr2/decryptionplugin/config/mocks/config_parser.go create mode 100644 go/tdh2/internal/group/LICENSE create mode 100644 go/tdh2/internal/group/group.go create mode 100644 go/tdh2/internal/group/mod/int.go create mode 100644 go/tdh2/internal/group/mod/int_test.go create mode 100644 go/tdh2/internal/group/nist/curve.go create mode 100644 go/tdh2/internal/group/nist/group_test.go create mode 100644 go/tdh2/internal/group/share/poly.go create mode 100644 go/tdh2/internal/group/share/poly_test.go create mode 100644 go/tdh2/internal/group/test/group.go create mode 100644 go/tdh2/internal/group/test/test.go rename {js/tdh2/test => go/tdh2/tdh2easy}/js_test.go (78%) diff --git a/.github/workflows/go.yaml b/.github/workflows/go.yaml new file mode 100644 index 0000000..baef4ad --- /dev/null +++ b/.github/workflows/go.yaml @@ -0,0 +1,31 @@ +name: Go package + +on: [push] + +jobs: + build: + + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: '1.20' + + - name: Build and test plugin + working-directory: ./go/ocr2/decryptionplugin + run: | + go build -v ./... + go test -v ./... + + - name: Download npm deps + working-directory: ./js/tdh2 + run: npm install + + - name: Build and test TDH2 + working-directory: ./go/tdh2 + run: | + go build -v ./... + go test -v ./... diff --git a/go/go.work.sum b/go/go.work.sum index 8e5ea79..44a132b 100644 --- a/go/go.work.sum +++ b/go/go.work.sum @@ -1,5 +1,7 @@ filippo.io/edwards25519 v1.0.0-rc.1/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns= github.com/gtank/ristretto255 v0.1.3-0.20210930101514-6bb39798585c/go.mod h1:tDPFhGdt3hJWqtKwx57i9baiB1Cj0yAg22VOPUqm5vY= +github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= diff --git a/go/ocr2/decryptionplugin/config/config.go b/go/ocr2/decryptionplugin/config/config.go index aa53ea8..0bc3c31 100644 --- a/go/ocr2/decryptionplugin/config/config.go +++ b/go/ocr2/decryptionplugin/config/config.go @@ -1,10 +1,6 @@ package config import ( - "fmt" - "math" - - "github.com/goplugin/plugin-libocr/commontypes" "google.golang.org/protobuf/proto" ) @@ -26,16 +22,14 @@ func EncodeReportingPluginConfig(rpConfig *ReportingPluginConfigWrapper) ([]byte return proto.Marshal(rpConfig.Config) } -func EncodeOracleIdtoKeyShareIndex(oracleID commontypes.OracleID, keyShareIndex int) *OracleIDtoKeyShareIndex { - return &OracleIDtoKeyShareIndex{ - OracleId: uint32(oracleID), - KeyShareIndex: uint32(keyShareIndex), - } +//go:generate mockery --quiet --name ConfigParser --output ./mocks/ --case=underscore +type ConfigParser interface { + ParseConfig(offchainConfig []byte) (*ReportingPluginConfigWrapper, error) } -func DecodeOracleIdtoKeyShareIndex(oracleIDtoKeyShareIndex *OracleIDtoKeyShareIndex) (commontypes.OracleID, int, error) { - if oracleIDtoKeyShareIndex.OracleId > math.MaxUint8 { - return 0, 0, fmt.Errorf("oracleID is larger than MAX_UINT8") - } - return commontypes.OracleID(oracleIDtoKeyShareIndex.OracleId), int(oracleIDtoKeyShareIndex.KeyShareIndex), nil +type DefaultConfigParser struct { +} + +func (p *DefaultConfigParser) ParseConfig(offchainConfig []byte) (*ReportingPluginConfigWrapper, error) { + return DecodeReportingPluginConfig(offchainConfig) } diff --git a/go/ocr2/decryptionplugin/config/config_types.pb.go b/go/ocr2/decryptionplugin/config/config_types.pb.go index 282a6a5..a1b5bf5 100644 --- a/go/ocr2/decryptionplugin/config/config_types.pb.go +++ b/go/ocr2/decryptionplugin/config/config_types.pb.go @@ -80,15 +80,12 @@ type ReportingPluginConfig struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - MaxQueryLengthBytes uint32 `protobuf:"varint,1,opt,name=max_query_length_bytes,json=maxQueryLengthBytes,proto3" json:"max_query_length_bytes,omitempty"` - MaxObservationLengthBytes uint32 `protobuf:"varint,2,opt,name=max_observation_length_bytes,json=maxObservationLengthBytes,proto3" json:"max_observation_length_bytes,omitempty"` - MaxReportLengthBytes uint32 `protobuf:"varint,3,opt,name=max_report_length_bytes,json=maxReportLengthBytes,proto3" json:"max_report_length_bytes,omitempty"` - RequestCountLimit uint32 `protobuf:"varint,4,opt,name=request_count_limit,json=requestCountLimit,proto3" json:"request_count_limit,omitempty"` - RequestTotalBytesLimit uint32 `protobuf:"varint,5,opt,name=request_total_bytes_limit,json=requestTotalBytesLimit,proto3" json:"request_total_bytes_limit,omitempty"` - RequireLocalRequestCheck bool `protobuf:"varint,6,opt,name=require_local_request_check,json=requireLocalRequestCheck,proto3" json:"require_local_request_check,omitempty"` - PublicKey []byte `protobuf:"bytes,7,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"` - PrivKeyShare []byte `protobuf:"bytes,8,opt,name=priv_key_share,json=privKeyShare,proto3" json:"priv_key_share,omitempty"` - OracleIdToKeyIndex []*OracleIDtoKeyShareIndex `protobuf:"bytes,9,rep,name=oracle_id_to_key_index,json=oracleIdToKeyIndex,proto3" json:"oracle_id_to_key_index,omitempty"` + MaxQueryLengthBytes uint32 `protobuf:"varint,1,opt,name=max_query_length_bytes,json=maxQueryLengthBytes,proto3" json:"max_query_length_bytes,omitempty"` + MaxObservationLengthBytes uint32 `protobuf:"varint,2,opt,name=max_observation_length_bytes,json=maxObservationLengthBytes,proto3" json:"max_observation_length_bytes,omitempty"` + MaxReportLengthBytes uint32 `protobuf:"varint,3,opt,name=max_report_length_bytes,json=maxReportLengthBytes,proto3" json:"max_report_length_bytes,omitempty"` + RequestCountLimit uint32 `protobuf:"varint,4,opt,name=request_count_limit,json=requestCountLimit,proto3" json:"request_count_limit,omitempty"` + RequestTotalBytesLimit uint32 `protobuf:"varint,5,opt,name=request_total_bytes_limit,json=requestTotalBytesLimit,proto3" json:"request_total_bytes_limit,omitempty"` + RequireLocalRequestCheck bool `protobuf:"varint,6,opt,name=require_local_request_check,json=requireLocalRequestCheck,proto3" json:"require_local_request_check,omitempty"` } func (x *ReportingPluginConfig) Reset() { @@ -165,27 +162,6 @@ func (x *ReportingPluginConfig) GetRequireLocalRequestCheck() bool { return false } -func (x *ReportingPluginConfig) GetPublicKey() []byte { - if x != nil { - return x.PublicKey - } - return nil -} - -func (x *ReportingPluginConfig) GetPrivKeyShare() []byte { - if x != nil { - return x.PrivKeyShare - } - return nil -} - -func (x *ReportingPluginConfig) GetOracleIdToKeyIndex() []*OracleIDtoKeyShareIndex { - if x != nil { - return x.OracleIdToKeyIndex - } - return nil -} - var File_config_config_types_proto protoreflect.FileDescriptor var file_config_config_types_proto_rawDesc = []byte{ @@ -197,7 +173,7 @@ var file_config_config_types_proto_rawDesc = []byte{ 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x6f, 0x72, 0x61, 0x63, 0x6c, 0x65, 0x49, 0x64, 0x12, 0x26, 0x0a, 0x0f, 0x6b, 0x65, 0x79, 0x5f, 0x73, 0x68, 0x61, 0x72, 0x65, 0x5f, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0d, 0x6b, 0x65, 0x79, 0x53, - 0x68, 0x61, 0x72, 0x65, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x22, 0x8e, 0x04, 0x0a, 0x15, 0x52, 0x65, + 0x68, 0x61, 0x72, 0x65, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x22, 0xee, 0x02, 0x0a, 0x15, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x33, 0x0a, 0x16, 0x6d, 0x61, 0x78, 0x5f, 0x71, 0x75, 0x65, 0x72, 0x79, 0x5f, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, @@ -220,17 +196,7 @@ var file_config_config_types_proto_rawDesc = []byte{ 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x5f, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x18, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x75, - 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, - 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x0e, 0x70, 0x72, 0x69, - 0x76, 0x5f, 0x6b, 0x65, 0x79, 0x5f, 0x73, 0x68, 0x61, 0x72, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, - 0x0c, 0x52, 0x0c, 0x70, 0x72, 0x69, 0x76, 0x4b, 0x65, 0x79, 0x53, 0x68, 0x61, 0x72, 0x65, 0x12, - 0x59, 0x0a, 0x16, 0x6f, 0x72, 0x61, 0x63, 0x6c, 0x65, 0x5f, 0x69, 0x64, 0x5f, 0x74, 0x6f, 0x5f, - 0x6b, 0x65, 0x79, 0x5f, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x18, 0x09, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x25, 0x2e, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4f, - 0x72, 0x61, 0x63, 0x6c, 0x65, 0x49, 0x44, 0x74, 0x6f, 0x4b, 0x65, 0x79, 0x53, 0x68, 0x61, 0x72, - 0x65, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x52, 0x12, 0x6f, 0x72, 0x61, 0x63, 0x6c, 0x65, 0x49, 0x64, - 0x54, 0x6f, 0x4b, 0x65, 0x79, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x42, 0x0b, 0x5a, 0x09, 0x2e, 0x2f, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x42, 0x0b, 0x5a, 0x09, 0x2e, 0x2f, 0x3b, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } @@ -252,12 +218,11 @@ var file_config_config_types_proto_goTypes = []interface{}{ (*ReportingPluginConfig)(nil), // 1: config_types.ReportingPluginConfig } var file_config_config_types_proto_depIdxs = []int32{ - 0, // 0: config_types.ReportingPluginConfig.oracle_id_to_key_index:type_name -> config_types.OracleIDtoKeyShareIndex - 1, // [1:1] is the sub-list for method output_type - 1, // [1:1] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name } func init() { file_config_config_types_proto_init() } diff --git a/go/ocr2/decryptionplugin/config/config_types.proto b/go/ocr2/decryptionplugin/config/config_types.proto index 40153a4..df9269d 100644 --- a/go/ocr2/decryptionplugin/config/config_types.proto +++ b/go/ocr2/decryptionplugin/config/config_types.proto @@ -16,7 +16,4 @@ message ReportingPluginConfig { uint32 request_count_limit = 4; uint32 request_total_bytes_limit = 5; bool require_local_request_check = 6; - bytes public_key = 7; - bytes priv_key_share = 8; - repeated OracleIDtoKeyShareIndex oracle_id_to_key_index = 9; } \ No newline at end of file diff --git a/go/ocr2/decryptionplugin/config/mocks/config_parser.go b/go/ocr2/decryptionplugin/config/mocks/config_parser.go new file mode 100644 index 0000000..886ae11 --- /dev/null +++ b/go/ocr2/decryptionplugin/config/mocks/config_parser.go @@ -0,0 +1,54 @@ +// Code generated by mockery v2.28.1. DO NOT EDIT. + +package mocks + +import ( + config "github.com/goplugin/tdh2/go/ocr2/decryptionplugin/config" + mock "github.com/stretchr/testify/mock" +) + +// ConfigParser is an autogenerated mock type for the ConfigParser type +type ConfigParser struct { + mock.Mock +} + +// ParseConfig provides a mock function with given fields: offchainConfig +func (_m *ConfigParser) ParseConfig(offchainConfig []byte) (*config.ReportingPluginConfigWrapper, error) { + ret := _m.Called(offchainConfig) + + var r0 *config.ReportingPluginConfigWrapper + var r1 error + if rf, ok := ret.Get(0).(func([]byte) (*config.ReportingPluginConfigWrapper, error)); ok { + return rf(offchainConfig) + } + if rf, ok := ret.Get(0).(func([]byte) *config.ReportingPluginConfigWrapper); ok { + r0 = rf(offchainConfig) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*config.ReportingPluginConfigWrapper) + } + } + + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(offchainConfig) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type mockConstructorTestingTNewConfigParser interface { + mock.TestingT + Cleanup(func()) +} + +// NewConfigParser creates a new instance of ConfigParser. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewConfigParser(t mockConstructorTestingTNewConfigParser) *ConfigParser { + mock := &ConfigParser{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/go/ocr2/decryptionplugin/decryption.go b/go/ocr2/decryptionplugin/decryption.go index 8b44cf5..f05de25 100644 --- a/go/ocr2/decryptionplugin/decryption.go +++ b/go/ocr2/decryptionplugin/decryption.go @@ -9,13 +9,17 @@ import ( "github.com/goplugin/plugin-libocr/commontypes" "github.com/goplugin/plugin-libocr/offchainreporting2/types" "github.com/goplugin/tdh2/go/ocr2/decryptionplugin/config" - "github.com/goplugin/tdh2/go/tdh2easy" + "github.com/goplugin/tdh2/go/tdh2/tdh2easy" "google.golang.org/protobuf/proto" ) type DecryptionReportingPluginFactory struct { - DecryptionQueue DecryptionQueuingService - Logger commontypes.Logger + DecryptionQueue DecryptionQueuingService + ConfigParser config.ConfigParser + PublicKey *tdh2easy.PublicKey + PrivKeyShare *tdh2easy.PrivateShare + OracleToKeyShare map[commontypes.OracleID]int + Logger commontypes.Logger } type decryptionPlugin struct { @@ -30,12 +34,12 @@ type decryptionPlugin struct { // NewReportingPlugin complies with ReportingPluginFactory. func (f DecryptionReportingPluginFactory) NewReportingPlugin(rpConfig types.ReportingPluginConfig) (types.ReportingPlugin, types.ReportingPluginInfo, error) { - pluginConfig, err := config.DecodeReportingPluginConfig(rpConfig.OffchainConfig) + pluginConfig, err := f.ConfigParser.ParseConfig(rpConfig.OffchainConfig) if err != nil { f.Logger.Error("unable to decode reporting plugin config", commontypes.LogFields{ "configDigest": rpConfig.ConfigDigest.String(), }) - return nil, types.ReportingPluginInfo{}, fmt.Errorf("unalbe to decode reporting plugin config: %w", err) + return nil, types.ReportingPluginInfo{}, fmt.Errorf("unable to decode reporting plugin config: %w", err) } info := types.ReportingPluginInfo{ @@ -49,33 +53,16 @@ func (f DecryptionReportingPluginFactory) NewReportingPlugin(rpConfig types.Repo }, } - oracleToKeyShare := make(map[commontypes.OracleID]int) - for _, entry := range pluginConfig.Config.OracleIdToKeyIndex { - oID, ksID, err := config.DecodeOracleIdtoKeyShareIndex(entry) - if err != nil { - return nil, types.ReportingPluginInfo{}, fmt.Errorf("unalbe to decode reporting plugin oracle id to key Share index mapping: %w", err) - } - oracleToKeyShare[oID] = ksID - } - plugin := decryptionPlugin{ f.Logger, f.DecryptionQueue, - &tdh2easy.PublicKey{}, - &tdh2easy.PrivateShare{}, - oracleToKeyShare, + f.PublicKey, + f.PrivKeyShare, + f.OracleToKeyShare, &rpConfig, pluginConfig, } - if err = plugin.publicKey.Unmarshal(pluginConfig.Config.PublicKey); err != nil { - return nil, info, fmt.Errorf("cannot unmarshal public key: %w", err) - } - - if err = plugin.privKeyShare.Unmarshal(pluginConfig.Config.PrivKeyShare); err != nil { - return nil, info, fmt.Errorf("cannot unmarshal private key share: %w", err) - } - return &plugin, info, nil } @@ -261,13 +248,13 @@ func (dp *decryptionPlugin) Report(ctx context.Context, ts types.ReportTimestamp continue } - validDecryptionShares[ciphertextID] = append(validDecryptionShares[ciphertextID], validDecryptionShare) - if len(validDecryptionShares[ciphertextID]) >= fPlusOne { + if len(validDecryptionShares[ciphertextID]) < fPlusOne { + validDecryptionShares[ciphertextID] = append(validDecryptionShares[ciphertextID], validDecryptionShare) + } else { dp.logger.Trace("DecryptionReporting Report: we have already f+1 valid decryption shares", commontypes.LogFields{ "ciphertextID": ciphertextID, "observer": ob.Observer, }) - break } } } diff --git a/go/ocr2/decryptionplugin/decryption_test.go b/go/ocr2/decryptionplugin/decryption_test.go index be631a8..caeb20a 100644 --- a/go/ocr2/decryptionplugin/decryption_test.go +++ b/go/ocr2/decryptionplugin/decryption_test.go @@ -1,3 +1,772 @@ package decryptionplugin -// TODO +import ( + "bytes" + "context" + "errors" + "fmt" + "reflect" + "sort" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/goplugin/plugin-libocr/commontypes" + "github.com/goplugin/plugin-libocr/offchainreporting2/types" + "github.com/goplugin/plugin-libocr/ragep2p/loggers" + "github.com/goplugin/tdh2/go/ocr2/decryptionplugin/config" + "github.com/goplugin/tdh2/go/ocr2/decryptionplugin/config/mocks" + "github.com/goplugin/tdh2/go/tdh2/tdh2easy" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" +) + +// dummyLogger implements a dummy logger for testing only. +type dummyLogger struct{} + +func (l dummyLogger) Trace(msg string, fields commontypes.LogFields) {} +func (l dummyLogger) Debug(msg string, fields commontypes.LogFields) {} +func (l dummyLogger) Info(msg string, fields commontypes.LogFields) {} +func (l dummyLogger) Warn(msg string, fields commontypes.LogFields) {} +func (l dummyLogger) Error(msg string, fields commontypes.LogFields) {} +func (l dummyLogger) Critical(msg string, fields commontypes.LogFields) {} + +// queue implements the DecryptionQueuingService interface. +type queue struct { + q []DecryptionRequest + res [][]byte +} + +func (q *queue) GetRequests(requestCountLimit int, totalBytesLimit int) []DecryptionRequest { + stop := 0 + for i, tot := 0, 0; i < len(q.q) && i < requestCountLimit; i++ { + tot += len(q.q[i].Ciphertext) + if tot > totalBytesLimit { + break + } + stop++ + } + out := q.q[:stop] + q.q = q.q[stop:] + return out +} + +func (q *queue) GetCiphertext(ciphertextId CiphertextId) ([]byte, error) { + if bytes.Equal([]byte("please fail"), ciphertextId) { + return nil, fmt.Errorf("some error") + } + for _, e := range q.q { + if bytes.Equal(ciphertextId, e.CiphertextId) { + return e.Ciphertext, nil + } + } + return nil, ErrNotFound +} + +func (q *queue) SetResult(ciphertextId CiphertextId, plaintext []byte) { + q.res = append(q.res, ciphertextId) + q.res = append(q.res, plaintext) +} + +func makeConfig(t *testing.T, c *config.ReportingPluginConfig) types.ReportingPluginConfig { + t.Helper() + conf, err := config.EncodeReportingPluginConfig(&config.ReportingPluginConfigWrapper{ + Config: c, + }) + if err != nil { + t.Fatalf("EncodeReportingPluginConfig: %v", err) + } + return types.ReportingPluginConfig{OffchainConfig: conf} + +} + +func TestNewReportingPlugin(t *testing.T) { + _, pk, sh, err := tdh2easy.GenerateKeys(1, 1) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + for _, tc := range []struct { + name string + conf types.ReportingPluginConfig + err error + }{ + { + name: "ok", + conf: makeConfig(t, &config.ReportingPluginConfig{ + MaxQueryLengthBytes: 1, + MaxObservationLengthBytes: 2, + MaxReportLengthBytes: 3, + }), + }, + { + name: "ok minimal", + conf: makeConfig(t, &config.ReportingPluginConfig{}), + }, + { + name: "broken conf", + conf: types.ReportingPluginConfig{ + OffchainConfig: []byte("broken"), + }, + err: cmpopts.AnyError, + }, + } { + t.Run(tc.name, func(t *testing.T) { + factory := DecryptionReportingPluginFactory{ + Logger: dummyLogger{}, + PublicKey: pk, + PrivKeyShare: sh[0], + ConfigParser: &config.DefaultConfigParser{}, + } + plugin, info, err := factory.NewReportingPlugin(tc.conf) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Fatalf("err=%v, want=%v", err, tc.err) + } else if err != nil { + return + } + conf, err := config.DecodeReportingPluginConfig(tc.conf.OffchainConfig) + if err != nil { + t.Fatalf("DecodeReportingPluginConfig: %v", err) + } + if a, b := info.Limits.MaxQueryLength, int(conf.Config.MaxQueryLengthBytes); a != b { + t.Errorf("info.Limits.MaxQueryLength=%v, want=%v", a, b) + + } + if a, b := info.Limits.MaxObservationLength, int(conf.Config.MaxObservationLengthBytes); a != b { + t.Errorf("info.Limits.MaxObservationLength=%v, want=%v", a, b) + } + if a, b := info.Limits.MaxReportLength, int(conf.Config.MaxReportLengthBytes); a != b { + t.Errorf("info.Limits.MaxReportLength=%v, want=%v", a, b) + } + p := plugin.(*decryptionPlugin) + if !reflect.DeepEqual(p.publicKey, pk) { + t.Errorf("got pubkey %v, want %v", p.publicKey, pk) + } + if !reflect.DeepEqual(p.privKeyShare, sh[0]) { + t.Errorf("got privkey %v, want %v", p.privKeyShare, sh[0]) + } + }) + } +} + +func TestGetValidDecryptionShare(t *testing.T) { + _, pk, sh, err := tdh2easy.GenerateKeys(1, 2) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + c, err := tdh2easy.Encrypt(pk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + ds, err := tdh2easy.Decrypt(c, sh[1]) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + dsRaw, err := ds.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + c2, err := tdh2easy.Encrypt(pk, []byte("msg2")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + dp := &decryptionPlugin{ + oracleToKeyShare: map[commontypes.OracleID]int{ + 10: 0, + 123: 1, + }, + publicKey: pk, + } + for _, tc := range []struct { + name string + id commontypes.OracleID + c *tdh2easy.Ciphertext + share []byte + err error + }{ + { + name: "ok", + id: 123, + c: c, + share: dsRaw, + }, + { + name: "no oracle", + id: 1, + c: c, + share: dsRaw, + err: cmpopts.AnyError, + }, + { + name: "wrong index", + id: 10, + c: c, + share: dsRaw, + err: cmpopts.AnyError, + }, + { + name: "wrong share", + id: 123, + c: c2, + share: dsRaw, + err: cmpopts.AnyError, + }, + { + name: "broken ds", + id: 123, + c: c, + share: []byte("broken"), + err: cmpopts.AnyError, + }, + } { + t.Run(tc.name, func(t *testing.T) { + got, err := dp.getValidDecryptionShare(tc.id, tc.c, tc.share) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Fatalf("err=%v, want=%v", err, tc.err) + } else if err != nil { + return + } + if !reflect.DeepEqual(got, ds) { + t.Errorf("got ds=%v, want=%v", got, ds) + } + }) + } + +} + +func TestQuery(t *testing.T) { + _, pk, _, err := tdh2easy.GenerateKeys(1, 2) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + ctxts := []*CiphertextWithID{} + for i := 0; i < 10; i++ { + id := []byte(fmt.Sprintf("%d", i)) + c, err := tdh2easy.Encrypt(pk, id) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + raw, err := c.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + ctxts = append(ctxts, &CiphertextWithID{ + CiphertextId: id, + Ciphertext: raw, + }) + } + for _, tc := range []struct { + name string + in []*CiphertextWithID + want []*CiphertextWithID + }{ + { + name: "empty", + }, + { + name: "one", + in: ctxts[:1], + want: ctxts[:1], + }, + { + name: "all", + in: ctxts, + want: ctxts, + }, + { + name: "one wrong", + in: append(ctxts, &CiphertextWithID{ + CiphertextId: []byte("1"), + Ciphertext: []byte("broken"), + }), + want: ctxts, + }, + } { + t.Run(tc.name, func(t *testing.T) { + q := &queue{} + for _, e := range tc.in { + q.q = append(q.q, DecryptionRequest{ + CiphertextId: e.CiphertextId, + Ciphertext: e.Ciphertext, + }) + } + dp := &decryptionPlugin{ + logger: dummyLogger{}, + publicKey: pk, + specificConfig: &config.ReportingPluginConfigWrapper{ + Config: &config.ReportingPluginConfig{ + RequestCountLimit: 999, + RequestTotalBytesLimit: 999999, + }, + }, + decryptionQueue: q, + } + b, err := dp.Query(context.Background(), types.ReportTimestamp{}) + if err != nil { + t.Fatalf("Query: %v", err) + } + got := Query{} + if err := proto.Unmarshal(b, &got); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if d := cmp.Diff(got.DecryptionRequests, tc.want, cmpopts.IgnoreUnexported(CiphertextWithID{})); d != "" { + t.Errorf("got/want diff=%v", d) + } + }) + } +} + +func TestShouldAcceptFinalizedReport(t *testing.T) { + r := &Report{ + ProcessedDecryptedRequests: []*ProcessedDecryptionRequest{ + { + CiphertextId: []byte("id1"), + Plaintext: []byte("p1"), + }, + { + CiphertextId: []byte("id2"), + Plaintext: []byte("p2"), + }, + { + CiphertextId: []byte("id3"), + Plaintext: []byte("p3"), + }, + }, + } + b, err := proto.Marshal(r) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + for _, tc := range []struct { + name string + in []byte + want [][]byte + err error + }{ + { + name: "empty", + }, + { + name: "broken", + in: []byte("broken"), + err: cmpopts.AnyError, + }, + { + name: "ok", + in: b, + want: [][]byte{[]byte("id1"), []byte("p1"), []byte("id2"), []byte("p2"), []byte("id3"), []byte("p3")}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + dp := &decryptionPlugin{ + logger: dummyLogger{}, + decryptionQueue: &queue{}, + } + transmit, err := dp.ShouldAcceptFinalizedReport(context.Background(), types.ReportTimestamp{}, tc.in) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Fatalf("err=%v, want=%v", err, tc.err) + } else if err != nil { + return + } + if transmit { + t.Errorf("ShouldAcceptFinalizedReport returned true") + } + q := dp.decryptionQueue.(*queue) + if d := cmp.Diff(q.res, tc.want); d != "" { + t.Errorf("got/want diff=%v", d) + } + }) + } +} + +func makeQuery(t *testing.T, c []*CiphertextWithID) []byte { + t.Helper() + b, err := proto.Marshal(&Query{ + DecryptionRequests: c, + }) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + return b +} + +type ctxtWithId struct { + id []byte + c *tdh2easy.Ciphertext +} + +func TestObservation(t *testing.T) { + _, pk, sh, err := tdh2easy.GenerateKeys(1, 2) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + q := &queue{} + ctxts := []*ctxtWithId{} + ctxtsRaw := []*CiphertextWithID{} + for i := 0; i < 10; i++ { + id := []byte(fmt.Sprintf("%d", i)) + c, err := tdh2easy.Encrypt(pk, id) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + raw, err := c.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + ctxtsRaw = append(ctxtsRaw, &CiphertextWithID{ + CiphertextId: id, + Ciphertext: raw, + }) + // add only 5 to the queue + if i < 5 { + q.q = append(q.q, DecryptionRequest{ + CiphertextId: id, + Ciphertext: raw, + }) + } + ctxts = append(ctxts, &ctxtWithId{ + id: id, + c: c, + }) + } + for _, tc := range []struct { + name string + query []byte + local bool + queue DecryptionQueuingService + err error + want []*ctxtWithId + }{ + { + name: "broken", + query: []byte("broken"), + err: cmpopts.AnyError, + }, + { + name: "empty", + query: makeQuery(t, nil), + }, + { + name: "one", + query: makeQuery(t, ctxtsRaw[:1]), + want: ctxts[:1], + }, + { + name: "many", + query: makeQuery(t, ctxtsRaw), + want: ctxts, + }, + { + name: "many locally queued", + query: makeQuery(t, ctxtsRaw[:5]), + local: true, + queue: q, + want: ctxts[:5], + }, + { + name: "some locally queued, some not found", + query: makeQuery(t, ctxtsRaw), + local: true, + queue: q, + want: ctxts[:5], + }, + { + name: "queue failing", + query: makeQuery(t, append(ctxtsRaw[:5], &CiphertextWithID{ + CiphertextId: []byte("please fail"), + Ciphertext: ctxtsRaw[5].Ciphertext, + })), + local: true, + queue: q, + want: ctxts[:5], + }, + { + name: "queued ciphertext mismatch", + query: makeQuery(t, append(ctxtsRaw[:4], &CiphertextWithID{ + CiphertextId: ctxtsRaw[4].CiphertextId, + Ciphertext: ctxtsRaw[5].Ciphertext, + })), + local: true, + queue: q, + want: ctxts[:4], + }, + { + name: "broken ciphertext", + query: makeQuery(t, append(ctxtsRaw[:3], &CiphertextWithID{ + CiphertextId: []byte("id"), + Ciphertext: []byte("broken"), + })), + err: cmpopts.AnyError, + }, + } { + t.Run(tc.name, func(t *testing.T) { + dp := &decryptionPlugin{ + logger: dummyLogger{}, + publicKey: pk, + privKeyShare: sh[1], + specificConfig: &config.ReportingPluginConfigWrapper{ + Config: &config.ReportingPluginConfig{ + RequireLocalRequestCheck: tc.local, + }, + }, + decryptionQueue: tc.queue, + } + b, err := dp.Observation(context.Background(), types.ReportTimestamp{}, tc.query) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Fatalf("err=%v, want=%v", err, tc.err) + } else if err != nil { + return + } + var got Observation + if err := proto.Unmarshal(b, &got); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if a, b := len(got.DecryptionShares), len(tc.want); a != b { + t.Errorf("got %v dec shares, want %v", a, b) + } + for i := 0; i < len(got.DecryptionShares) && i < len(tc.want); i++ { + if a, b := got.DecryptionShares[i].CiphertextId, tc.want[i].id; !bytes.Equal(a, b) { + t.Errorf("got id=%v, want=%v", a, b) + } + var ds tdh2easy.DecryptionShare + if err := ds.Unmarshal(got.DecryptionShares[i].DecryptionShare); err != nil { + t.Errorf("Unmarshal: %v", err) + continue + } + if ds.Index() != 1 { + t.Errorf("got index=%v, want=1", ds.Index()) + } + if err := tdh2easy.VerifyShare(tc.want[i].c, pk, &ds); err != nil { + t.Errorf("VerifyShare id=%v err=%v", tc.want[i].id, err) + } + } + }) + } +} + +func makeObservations(t *testing.T, oracle2ids map[int][]string, id2shares map[string][][]byte) []types.AttributedObservation { + t.Helper() + var out []types.AttributedObservation + for oracle, ids := range oracle2ids { + decShares := []*DecryptionShareWithID{} + for _, id := range ids { + decShares = append(decShares, &DecryptionShareWithID{ + CiphertextId: []byte(id), + DecryptionShare: id2shares[id][oracle], + }) + } + ob, err := proto.Marshal(&Observation{ + DecryptionShares: decShares, + }) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + out = append(out, types.AttributedObservation{ + Observer: commontypes.OracleID(oracle), + Observation: ob, + }) + } + return out +} + +func TestReport(t *testing.T) { + _, pk, sh, err := tdh2easy.GenerateKeys(3, 5) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + want := []*ProcessedDecryptionRequest{} + ctxts := []*CiphertextWithID{} + shares := map[string][][]byte{} + // generate id-plaintext pairs, "id0"->"0", "id1"->"1", "id2"->"2" + for i := 0; i < 3; i++ { + id := []byte(fmt.Sprintf("id%d", i)) + plain := []byte(fmt.Sprintf("%d", i)) + c, err := tdh2easy.Encrypt(pk, plain) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + raw, err := c.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + ctxts = append(ctxts, &CiphertextWithID{ + CiphertextId: id, + Ciphertext: raw, + }) + want = append(want, &ProcessedDecryptionRequest{ + CiphertextId: id, + Plaintext: plain, + }) + for _, s := range sh { + ds, err := tdh2easy.Decrypt(c, s) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + b, err := ds.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + shares[string(id)] = append(shares[string(id)], b) + } + } + for _, tc := range []struct { + name string + query []byte + obs []types.AttributedObservation + err error + wantProcessed bool + want []*ProcessedDecryptionRequest + }{ + { + name: "empty", + query: makeQuery(t, nil), + }, + { + name: "broken query", + query: []byte("broken"), + err: cmpopts.AnyError, + }, + { + name: "broken ciphertext", + query: makeQuery(t, append(ctxts, &CiphertextWithID{ + CiphertextId: []byte("id"), + Ciphertext: []byte("broken"), + })), + err: cmpopts.AnyError, + }, + { + name: "nothing processed (no shares)", + query: makeQuery(t, ctxts), + }, + { + name: "nothing processed (no enough shares)", + query: makeQuery(t, ctxts), + obs: makeObservations(t, map[int][]string{ + 0: {"id0", "id1", "id2"}, + 1: {"id0", "id1", "id2"}, + }, shares), + }, + { + name: "one processed", + query: makeQuery(t, ctxts[:1]), + obs: makeObservations(t, map[int][]string{ + 0: {"id0", "id1", "id2"}, + 1: {"id0", "id1", "id2"}, + 2: {"id0", "id1", "id2"}, + }, shares), + wantProcessed: true, + want: want[:1], + }, + { + name: "two processed", + query: makeQuery(t, ctxts), + obs: makeObservations(t, map[int][]string{ + 0: {"id0", "id1", "id2"}, + 1: {"id0", "id1", "id2"}, + 2: {"id0", "id1"}, + }, shares), + wantProcessed: true, + want: want[:2], + }, + { + name: "all processed", + query: makeQuery(t, ctxts), + obs: makeObservations(t, map[int][]string{ + 0: {"id0", "id1", "id2"}, + 1: {"id0", "id1", "id2"}, + 2: {"id0", "id1", "id2"}, + }, shares), + wantProcessed: true, + want: want, + }, + { + name: "all processed, more shares than needed", + query: makeQuery(t, ctxts), + obs: makeObservations(t, map[int][]string{ + 0: {"id0", "id1", "id2"}, + 1: {"id0", "id1", "id2"}, + 2: {"id0", "id1", "id2"}, + 3: {"id0", "id1", "id2"}, + }, shares), + wantProcessed: true, + want: want, + }, + { + name: "nothing processed (wrong oracle-index mapping)", + query: makeQuery(t, ctxts), + obs: makeObservations(t, map[int][]string{ + 0: {"id0", "id1", "id2"}, + 1: {"id0", "id1", "id2"}, + 4: {"id0", "id1", "id2"}, + }, shares), + }, + { + name: "all processed, one broken obs", + query: makeQuery(t, ctxts), + obs: append(makeObservations(t, map[int][]string{ + 0: {"id0", "id1", "id2"}, + 1: {"id0", "id1", "id2"}, + 2: {"id0", "id1", "id2"}, + }, shares), types.AttributedObservation{ + Observer: 4, + Observation: []byte("broken"), + }), + wantProcessed: true, + want: want, + }, + } { + t.Run(tc.name, func(t *testing.T) { + dp := &decryptionPlugin{ + logger: dummyLogger{}, + publicKey: pk, + genericConfig: &types.ReportingPluginConfig{ + F: 2, + }, + oracleToKeyShare: map[commontypes.OracleID]int{ + 0: 0, + 1: 1, + 2: 2, + 3: 3, + 4: 5, // wrong mapping + }, + } + processed, reportBytes, err := dp.Report(context.Background(), types.ReportTimestamp{}, tc.query, tc.obs) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Fatalf("err=%v, want=%v", err, tc.err) + } else if err != nil { + return + } + if processed != tc.wantProcessed { + t.Errorf("got processed=%v, want=%v", processed, tc.wantProcessed) + } + var report Report + if err := proto.Unmarshal(reportBytes, &report); err != nil { + t.Errorf("Unmarshal: %v", err) + } + // make sure processed requests are sorted before comparison + got := report.ProcessedDecryptedRequests + sort.Slice(got, func(i, j int) bool { + return string(got[i].CiphertextId) < string(got[j].CiphertextId) + }) + if d := cmp.Diff(got, tc.want, cmpopts.IgnoreUnexported(ProcessedDecryptionRequest{})); d != "" { + t.Errorf("got/want diff=%v", d) + } + }) + } +} + +func TestNewReportingPlugin_CustomConfigParser(t *testing.T) { + customParser := mocks.NewConfigParser(t) + factory := DecryptionReportingPluginFactory{ + ConfigParser: customParser, + Logger: loggers.MakeLogrusLogger(), + } + + customParser.On("ParseConfig", mock.Anything).Return(&config.ReportingPluginConfigWrapper{}, nil).Once() + _, _, err := factory.NewReportingPlugin(types.ReportingPluginConfig{}) + require.NoError(t, err) + + customParser.On("ParseConfig", mock.Anything).Return(nil, errors.New("error")).Once() + _, _, err = factory.NewReportingPlugin(types.ReportingPluginConfig{}) + require.Error(t, err) +} diff --git a/go/ocr2/decryptionplugin/go.mod b/go/ocr2/decryptionplugin/go.mod index d85811a..b9e72ac 100644 --- a/go/ocr2/decryptionplugin/go.mod +++ b/go/ocr2/decryptionplugin/go.mod @@ -3,10 +3,11 @@ module github.com/goplugin/tdh2/go/ocr2/decryptionplugin go 1.20 require ( + github.com/google/go-cmp v0.5.9 //github.com/goplugin/plugin-libocr v0.0.0-20230503222226-29f534b2de1a github.com/goplugin/plugin-libocr v0.1.1-beta //plugin update changes - //github.com/goplugin/tdh2 v0.0.0-20230523083904-ccb0d2ebd7d4 - github.com/goplugin/tdh2 v0.0.1 //plugin update changes + //github.com/goplugin/tdh2/go/tdh2 v0.0.0-20230524070358-28006f3fdc99 + github.com/goplugin/tdh2/go/tdh2 v0.0.1 //plugin update changes google.golang.org/protobuf v1.30.0 ) diff --git a/go/ocr2/decryptionplugin/go.sum b/go/ocr2/decryptionplugin/go.sum index 50f11d9..c82349c 100644 --- a/go/ocr2/decryptionplugin/go.sum +++ b/go/ocr2/decryptionplugin/go.sum @@ -3,6 +3,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o= github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -11,8 +12,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/goplugin/plugin-libocr v0.0.0-20230503222226-29f534b2de1a h1:NFTlIAjSwx9vBYZeUerd0DDsvg+zZ5TSESw4dPNdwJk= github.com/goplugin/plugin-libocr v0.0.0-20230503222226-29f534b2de1a/go.mod h1:5JnCHuYgmIP9ZyXzgAfI5Iwu0WxBtBKp+ApeT5o1Cjw= -github.com/goplugin/tdh2 v0.0.0-20230523083904-ccb0d2ebd7d4 h1:dFozHOgWYCh3qVrv9274+d14GvgKmVyGbY18SS93ZSU= -github.com/goplugin/tdh2 v0.0.0-20230523083904-ccb0d2ebd7d4/go.mod h1:37DpReCODOYar/xLkieKH/DKeI3ubU3QKSDOjQX1QnY= +github.com/goplugin/tdh2/go/tdh2 v0.0.0-20230524070358-28006f3fdc99 h1:XkM9YPlI0uUxp4INWXk/Nxc+k/QhSPpi04owatSR3t4= +github.com/goplugin/tdh2/go/tdh2 v0.0.0-20230524070358-28006f3fdc99/go.mod h1:Jf9J8VVTgeONnXq/Dtv634P+JxbSn5IK2lNww84PiIY= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/go/tdh2/go.mod b/go/tdh2/go.mod index c04c134..0831170 100644 --- a/go/tdh2/go.mod +++ b/go/tdh2/go.mod @@ -2,13 +2,4 @@ module github.com/goplugin/tdh2/go/tdh2 go 1.19 -require ( - github.com/google/go-cmp v0.5.9 - go.dedis.ch/kyber/v3 v3.1.0 -) - -require ( - go.dedis.ch/fixbuf v1.0.3 // indirect - golang.org/x/crypto v0.0.0-20190123085648-057139ce5d2b // indirect - golang.org/x/sys v0.0.0-20190124100055-b90733256f2e // indirect -) +require github.com/google/go-cmp v0.5.9 diff --git a/go/tdh2/go.sum b/go/tdh2/go.sum index b754456..62841cd 100644 --- a/go/tdh2/go.sum +++ b/go/tdh2/go.sum @@ -1,22 +1,2 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -go.dedis.ch/fixbuf v1.0.3 h1:hGcV9Cd/znUxlusJ64eAlExS+5cJDIyTyEG+otu5wQs= -go.dedis.ch/fixbuf v1.0.3/go.mod h1:yzJMt34Wa5xD37V5RTdmp38cz3QhMagdGoem9anUalw= -go.dedis.ch/kyber/v3 v3.0.4/go.mod h1:OzvaEnPvKlyrWyp3kGXlFdp7ap1VC6RkZDTaPikqhsQ= -go.dedis.ch/kyber/v3 v3.0.9/go.mod h1:rhNjUUg6ahf8HEg5HUvVBYoWY4boAafX8tYxX+PS+qg= -go.dedis.ch/kyber/v3 v3.1.0 h1:ghu+kiRgM5JyD9TJ0hTIxTLQlJBR/ehjWvWwYW3XsC0= -go.dedis.ch/kyber/v3 v3.1.0/go.mod h1:kXy7p3STAurkADD+/aZcsznZGKVHEqbtmdIzvPfrs1U= -go.dedis.ch/protobuf v1.0.5/go.mod h1:eIV4wicvi6JK0q/QnfIEGeSFNG0ZeB24kzut5+HaRLo= -go.dedis.ch/protobuf v1.0.7/go.mod h1:pv5ysfkDX/EawiPqcW3ikOxsL5t+BqnV6xHSmE79KI4= -go.dedis.ch/protobuf v1.0.11/go.mod h1:97QR256dnkimeNdfmURz0wAMNVbd1VmLXhG1CrTYrJ4= -golang.org/x/crypto v0.0.0-20190123085648-057139ce5d2b h1:Elez2XeF2p9uyVj0yEUDqQ56NFcDtcBNkYP7yv8YbUE= -golang.org/x/crypto v0.0.0-20190123085648-057139ce5d2b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/sys v0.0.0-20190124100055-b90733256f2e h1:3GIlrlVLfkoipSReOMNAgApI0ajnalyLa/EZHHca/XI= -golang.org/x/sys v0.0.0-20190124100055-b90733256f2e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/go/tdh2/internal/group/LICENSE b/go/tdh2/internal/group/LICENSE new file mode 100644 index 0000000..b7b98e4 --- /dev/null +++ b/go/tdh2/internal/group/LICENSE @@ -0,0 +1,377 @@ +This code is derived from https://github.com/dedis/kyber (v3.1.0) and its original license is below. + +This code is (c) by DEDIS/EPFL 2017 under the MPL v2 or later version. + +Mozilla Public License Version 2.0 +================================== + +1. Definitions +-------------- + +1.1. "Contributor" + means each individual or legal entity that creates, contributes to + the creation of, or owns Covered Software. + +1.2. "Contributor Version" + means the combination of the Contributions of others (if any) used + by a Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + means Source Code Form to which the initial Contributor has attached + the notice in Exhibit A, the Executable Form of such Source Code + Form, and Modifications of such Source Code Form, in each case + including portions thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + (a) that the initial Contributor has attached the notice described + in Exhibit B to the Covered Software; or + + (b) that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the + terms of a Secondary License. + +1.6. "Executable Form" + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + means a work that combines Covered Software with other material, in + a separate file or files, that is not Covered Software. + +1.8. "License" + means this document. + +1.9. "Licensable" + means having the right to grant, to the maximum extent possible, + whether at the time of the initial grant or subsequently, any and + all of the rights conveyed by this License. + +1.10. "Modifications" + means any of the following: + + (a) any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered + Software; or + + (b) any new file in Source Code Form that contains any Covered + Software. + +1.11. "Patent Claims" of a Contributor + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the + License, by the making, using, selling, offering for sale, having + made, import, or transfer of either its Contributions or its + Contributor Version. + +1.12. "Secondary License" + means either the GNU General Public License, Version 2.0, the GNU + Lesser General Public License, Version 2.1, the GNU Affero General + Public License, Version 3.0, or any later versions of those + licenses. + +1.13. "Source Code Form" + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. License Grants and Conditions +-------------------------------- + +2.1. Grants + +Each Contributor hereby grants You a world-wide, royalty-free, +non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + +(b) under Patent Claims of such Contributor to make, use, sell, offer + for sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + +The licenses granted in Section 2.1 with respect to any Contribution +become effective for each Contribution on the date the Contributor first +distributes such Contribution. + +2.3. Limitations on Grant Scope + +The licenses granted in this Section 2 are the only rights granted under +this License. No additional rights or licenses will be implied from the +distribution or licensing of Covered Software under this License. +Notwithstanding Section 2.1(b) above, no patent license is granted by a +Contributor: + +(a) for any code that a Contributor has removed from Covered Software; + or + +(b) for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + +(c) under Patent Claims infringed by Covered Software in the absence of + its Contributions. + +This License does not grant any rights in the trademarks, service marks, +or logos of any Contributor (except as may be necessary to comply with +the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + +No Contributor makes additional grants as a result of Your choice to +distribute the Covered Software under a subsequent version of this +License (see Section 10.2) or under the terms of a Secondary License (if +permitted under the terms of Section 3.3). + +2.5. Representation + +Each Contributor represents that the Contributor believes its +Contributions are its original creation(s) or it has sufficient rights +to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + +This License is not intended to limit any rights You have under +applicable copyright doctrines of fair use, fair dealing, or other +equivalents. + +2.7. Conditions + +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted +in Section 2.1. + +3. Responsibilities +------------------- + +3.1. Distribution of Source Form + +All distribution of Covered Software in Source Code Form, including any +Modifications that You create or to which You contribute, must be under +the terms of this License. You must inform recipients that the Source +Code Form of the Covered Software is governed by the terms of this +License, and how they can obtain a copy of this License. You may not +attempt to alter or restrict the recipients' rights in the Source Code +Form. + +3.2. Distribution of Executable Form + +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code + Form, as described in Section 3.1, and You must inform recipients of + the Executable Form how they can obtain a copy of such Source Code + Form by reasonable means in a timely manner, at a charge no more + than the cost of distribution to the recipient; and + +(b) You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter + the recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + +You may create and distribute a Larger Work under terms of Your choice, +provided that You also comply with the requirements of this License for +the Covered Software. If the Larger Work is a combination of Covered +Software with a work governed by one or more Secondary Licenses, and the +Covered Software is not Incompatible With Secondary Licenses, this +License permits You to additionally distribute such Covered Software +under the terms of such Secondary License(s), so that the recipient of +the Larger Work may, at their option, further distribute the Covered +Software under the terms of either this License or such Secondary +License(s). + +3.4. Notices + +You may not remove or alter the substance of any license notices +(including copyright notices, patent notices, disclaimers of warranty, +or limitations of liability) contained within the Source Code Form of +the Covered Software, except that You may alter any license notices to +the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + +You may choose to offer, and to charge a fee for, warranty, support, +indemnity or liability obligations to one or more recipients of Covered +Software. However, You may do so only on Your own behalf, and not on +behalf of any Contributor. You must make it absolutely clear that any +such warranty, support, indemnity, or liability obligation is offered by +You alone, and You hereby agree to indemnify every Contributor for any +liability incurred by such Contributor as a result of warranty, support, +indemnity or liability terms You offer. You may include additional +disclaimers of warranty and limitations of liability specific to any +jurisdiction. + +4. Inability to Comply Due to Statute or Regulation +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this +License with respect to some or all of the Covered Software due to +statute, judicial order, or regulation then You must: (a) comply with +the terms of this License to the maximum extent possible; and (b) +describe the limitations and the code they affect. Such description must +be placed in a text file included with all distributions of the Covered +Software under this License. Except to the extent prohibited by statute +or regulation, such description must be sufficiently detailed for a +recipient of ordinary skill to be able to understand it. + +5. Termination +-------------- + +5.1. The rights granted under this License will terminate automatically +if You fail to comply with any of its terms. However, if You become +compliant, then the rights granted under this License from a particular +Contributor are reinstated (a) provisionally, unless and until such +Contributor explicitly and finally terminates Your grants, and (b) on an +ongoing basis, if such Contributor fails to notify You of the +non-compliance by some reasonable means prior to 60 days after You have +come back into compliance. Moreover, Your grants from a particular +Contributor are reinstated on an ongoing basis if such Contributor +notifies You of the non-compliance by some reasonable means, this is the +first time You have received notice of non-compliance with this License +from such Contributor, and You become compliant prior to 30 days after +Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent +infringement claim (excluding declaratory judgment actions, +counter-claims, and cross-claims) alleging that a Contributor Version +directly or indirectly infringes any patent, then the rights granted to +You by any and all Contributors for the Covered Software under Section +2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all +end user license agreements (excluding distributors and resellers) which +have been validly granted by You or Your distributors under this License +prior to termination shall survive termination. + +************************************************************************ +* * +* 6. Disclaimer of Warranty * +* ------------------------- * +* * +* Covered Software is provided under this License on an "as is" * +* basis, without warranty of any kind, either expressed, implied, or * +* statutory, including, without limitation, warranties that the * +* Covered Software is free of defects, merchantable, fit for a * +* particular purpose or non-infringing. The entire risk as to the * +* quality and performance of the Covered Software is with You. * +* Should any Covered Software prove defective in any respect, You * +* (not any Contributor) assume the cost of any necessary servicing, * +* repair, or correction. This disclaimer of warranty constitutes an * +* essential part of this License. No use of any Covered Software is * +* authorized under this License except under this disclaimer. * +* * +************************************************************************ + +************************************************************************ +* * +* 7. Limitation of Liability * +* -------------------------- * +* * +* Under no circumstances and under no legal theory, whether tort * +* (including negligence), contract, or otherwise, shall any * +* Contributor, or anyone who distributes Covered Software as * +* permitted above, be liable to You for any direct, indirect, * +* special, incidental, or consequential damages of any character * +* including, without limitation, damages for lost profits, loss of * +* goodwill, work stoppage, computer failure or malfunction, or any * +* and all other commercial damages or losses, even if such party * +* shall have been informed of the possibility of such damages. This * +* limitation of liability shall not apply to liability for death or * +* personal injury resulting from such party's negligence to the * +* extent applicable law prohibits such limitation. Some * +* jurisdictions do not allow the exclusion or limitation of * +* incidental or consequential damages, so this exclusion and * +* limitation may not apply to You. * +* * +************************************************************************ + +8. Litigation +------------- + +Any litigation relating to this License may be brought only in the +courts of a jurisdiction where the defendant maintains its principal +place of business and such litigation shall be governed by laws of that +jurisdiction, without reference to its conflict-of-law provisions. +Nothing in this Section shall prevent a party's ability to bring +cross-claims or counter-claims. + +9. Miscellaneous +---------------- + +This License represents the complete agreement concerning the subject +matter hereof. If any provision of this License is held to be +unenforceable, such provision shall be reformed only to the extent +necessary to make it enforceable. Any law or regulation which provides +that the language of a contract shall be construed against the drafter +shall not be used to construe this License against a Contributor. + +10. Versions of the License +--------------------------- + +10.1. New Versions + +Mozilla Foundation is the license steward. Except as provided in Section +10.3, no one other than the license steward has the right to modify or +publish new versions of this License. Each version will be given a +distinguishing version number. + +10.2. Effect of New Versions + +You may distribute the Covered Software under the terms of the version +of the License under which You originally received the Covered Software, +or under the terms of any subsequent version published by the license +steward. + +10.3. Modified Versions + +If you create software not governed by this License, and you want to +create a new license for such software, you may create and use a +modified version of this License if you rename the license and remove +any references to the name of the license steward (except to note that +such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary +Licenses + +If You choose to distribute Source Code Form that is Incompatible With +Secondary Licenses under the terms of this version of the License, the +notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice +------------------------------------------- + + This Source Code Form is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular +file, then You may include the notice in a location (such as a LICENSE +file in a relevant directory) where a recipient would be likely to look +for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice +--------------------------------------------------------- + + This Source Code Form is "Incompatible With Secondary Licenses", as + defined by the Mozilla Public License, v. 2.0. \ No newline at end of file diff --git a/go/tdh2/internal/group/group.go b/go/tdh2/internal/group/group.go new file mode 100644 index 0000000..02c9f0f --- /dev/null +++ b/go/tdh2/internal/group/group.go @@ -0,0 +1,156 @@ +// package group provides interfaces for group-related objects. +package group + +import ( + "crypto/cipher" + "encoding" +) + +/* +Marshaling is a basic interface representing fixed-length (or known-length) +cryptographic objects or structures having a built-in binary encoding. +Implementors must ensure that calls to these methods do not modify +the underlying object so that other users of the object can access +it concurrently. +*/ +type Marshaling interface { + encoding.BinaryMarshaler + encoding.BinaryUnmarshaler + + // String returns the human readable string representation of the object. + String() string + + // Encoded length of this object in bytes. + MarshalSize() int +} + +// Scalar represents a scalar value by which +// a Point (group element) may be encrypted to produce another Point. +// This is an exponent in DSA-style groups, +// in which security is based on the Discrete Logarithm assumption, +// and a scalar multiplier in elliptic curve groups. +type Scalar interface { + Marshaling + + // Equality test for two Scalars derived from the same Group. + Equal(s2 Scalar) bool + + // Clone creates a new Scalar with the same value. + Clone() Scalar + + // SetInt64 sets the receiver to a small integer value. + SetInt64(v int64) Scalar + + // Set to the additive identity (0). + Zero() Scalar + + // Set to the modular sum of scalars a and b. + Add(a, b Scalar) Scalar + + // Set to the modular difference a - b. + Sub(a, b Scalar) Scalar + + // Set to the modular negation of scalar a. + Neg(a Scalar) Scalar + + // Set to the multiplicative identity (1). + One() Scalar + + // Set to the modular product of scalars a and b. + Mul(a, b Scalar) Scalar + + // Set to the modular division of scalar a by scalar b. + Div(a, b Scalar) Scalar + + // Set to the modular inverse of scalar a. + Inv(a Scalar) Scalar + + // Set to a fresh random or pseudo-random scalar. + Pick(rand cipher.Stream) Scalar + + // SetBytes sets the scalar from a byte-slice, + // reducing if necessary to the appropriate modulus. + // The endianess of the byte-slice is determined by the + // implementation. + SetBytes([]byte) Scalar +} + +// Point represents an element of a public-key cryptographic Group. +// For example, +// this is a number modulo the prime P in a DSA-style Schnorr group, +// or an (x, y) point on an elliptic curve. +// A Point can contain a Diffie-Hellman public key, an ElGamal ciphertext, etc. +type Point interface { + Marshaling + + // Equality test for two Points derived from the same Group. + Equal(s2 Point) bool + + // Null sets the receiver to the neutral identity element. + Null() Point + + // Base sets the receiver to this group's standard base point. + Base() Point + + // Pick sets the receiver to a fresh random or pseudo-random Point. + Pick(rand cipher.Stream) Point + + // Set sets the receiver equal to another Point p. + Set(p Point) Point + + // Clone clones the underlying point. + Clone() Point + + // Add points so that their scalars add homomorphically. + Add(a, b Point) Point + + // Subtract points so that their scalars subtract homomorphically. + Sub(a, b Point) Point + + // Set to the negation of point a. + Neg(a Point) Point + + // Multiply point p by the scalar s. + // If p == nil, multiply with the standard base point Base(). + Mul(s Scalar, p Point) Point +} + +// Group interface represents a mathematical group +// usable for Diffie-Hellman key exchange, ElGamal encryption, +// and the related body of public-key cryptographic algorithms +// and zero-knowledge proof methods. +// The Group interface is designed in particular to be a generic front-end +// to both traditional DSA-style modular arithmetic groups +// and ECDSA-style elliptic curves: +// the caller of this interface's methods +// need not know or care which specific mathematical construction +// underlies the interface. +// +// The Group interface is essentially just a "constructor" interface +// enabling the caller to generate the two particular types of objects +// relevant to DSA-style public-key cryptography; +// we call these objects Points and Scalars. +// The caller must explicitly initialize or set a new Point or Scalar object +// to some value before using it as an input to some other operation +// involving Point and/or Scalar objects. +// For example, to compare a point P against the neutral (identity) element, +// you might use P.Equal(suite.Point().Null()), +// but not just P.Equal(suite.Point()). +// +// It is expected that any implementation of this interface +// should satisfy suitable hardness assumptions for the applicable group: +// e.g., that it is cryptographically hard for an adversary to +// take an encrypted Point and the known generator it was based on, +// and derive the Scalar with which the Point was encrypted. +// Any implementation is also expected to satisfy +// the standard homomorphism properties that Diffie-Hellman +// and the associated body of public-key cryptography are based on. +type Group interface { + String() string + + ScalarLen() int // Max length of scalars in bytes + Scalar() Scalar // Create new scalar + + PointLen() int // Max length of point in bytes + Point() Point // Create new point +} diff --git a/go/tdh2/internal/group/mod/int.go b/go/tdh2/internal/group/mod/int.go new file mode 100644 index 0000000..a31b111 --- /dev/null +++ b/go/tdh2/internal/group/mod/int.go @@ -0,0 +1,219 @@ +// Package mod contains a generic implementation of finite field arithmetic +// on integer fields with a constant modulus. +package mod + +import ( + "crypto/cipher" + "encoding/hex" + "errors" + "math/big" + + "github.com/goplugin/tdh2/go/tdh2/internal/group" +) + +// Int is a generic implementation of finite field arithmetic +// on integer finite fields with a given constant modulus, +// built using Go's built-in big.Int package. +// Int satisfies the group.Scalar interface, +// and hence serves as a basic implementation of group.Scalar, +// e.g., representing discrete-log exponents of Schnorr groups +// or scalar multipliers for elliptic curves. +// +// Int offers an API similar to and compatible with big.Int, +// but "carries around" a pointer to the relevant modulus +// and automatically normalizes the value to that modulus +// after all arithmetic operations, simplifying modular arithmetic. +// Binary operations assume that the source(s) +// have the same modulus, but do not check this assumption. +// Unary and binary arithmetic operations may be performed on uninitialized +// target objects, and receive the modulus of the first operand. +// For efficiency the modulus field M is a pointer, +// whose target is assumed never to change. +type Int struct { + V big.Int // Integer value from 0 through M-1 + M *big.Int // Modulus for finite field arithmetic +} + +// NewInt64 creates a new Int with a given int64 value and big.Int modulus. +func NewInt64(v int64, M *big.Int) *Int { + i := &Int{M: M} + i.V.SetInt64(v).Mod(&i.V, M) + return i +} + +// Return the Int's integer value in hexadecimal string representation. +func (i *Int) String() string { + return hex.EncodeToString(i.V.Bytes()) +} + +// Equal returns true if the two Ints are equal +func (i *Int) Equal(s2 group.Scalar) bool { + return i.V.Cmp(&s2.(*Int).V) == 0 +} + +// Clone returns a separate duplicate of this Int. +func (i *Int) Clone() group.Scalar { + ni := &Int{M: i.M} + ni.V.Set(&i.V).Mod(&i.V, i.M) + return ni +} + +// Zero set the Int to the value 0. The modulus must already be initialized. +func (i *Int) Zero() group.Scalar { + i.V.SetInt64(0) + return i +} + +// One sets the Int to the value 1. The modulus must already be initialized. +func (i *Int) One() group.Scalar { + i.V.SetInt64(1) + return i +} + +// SetInt64 sets the Int to an arbitrary 64-bit "small integer" value. +// The modulus must already be initialized. +func (i *Int) SetInt64(v int64) group.Scalar { + i.V.SetInt64(v).Mod(&i.V, i.M) + return i +} + +// Add sets the target to a + b mod M, where M is a's modulus.. +func (i *Int) Add(a, b group.Scalar) group.Scalar { + ai := a.(*Int) + bi := b.(*Int) + i.M = ai.M + i.V.Add(&ai.V, &bi.V).Mod(&i.V, i.M) + return i +} + +// Sub sets the target to a - b mod M. +// Target receives a's modulus. +func (i *Int) Sub(a, b group.Scalar) group.Scalar { + ai := a.(*Int) + bi := b.(*Int) + i.M = ai.M + i.V.Sub(&ai.V, &bi.V).Mod(&i.V, i.M) + return i +} + +// Neg sets the target to -a mod M. +func (i *Int) Neg(a group.Scalar) group.Scalar { + ai := a.(*Int) + i.M = ai.M + if ai.V.Sign() > 0 { + i.V.Sub(i.M, &ai.V) + } else { + i.V.SetUint64(0) + } + return i +} + +// Mul sets the target to a * b mod M. +// Target receives a's modulus. +func (i *Int) Mul(a, b group.Scalar) group.Scalar { + ai := a.(*Int) + bi := b.(*Int) + i.M = ai.M + i.V.Mul(&ai.V, &bi.V).Mod(&i.V, i.M) + return i +} + +// Div sets the target to a * b^-1 mod M, where b^-1 is the modular inverse of b. +func (i *Int) Div(a, b group.Scalar) group.Scalar { + ai := a.(*Int) + bi := b.(*Int) + var t big.Int + i.M = ai.M + i.V.Mul(&ai.V, t.ModInverse(&bi.V, i.M)) + i.V.Mod(&i.V, i.M) + return i +} + +// Inv sets the target to the modular inverse of a with respect to modulus M. +func (i *Int) Inv(a group.Scalar) group.Scalar { + ai := a.(*Int) + i.M = ai.M + i.V.ModInverse(&a.(*Int).V, i.M) + return i +} + +// Pick a [pseudo-]random integer modulo M +// using bits from the given stream cipher. +// This code is adopted from Go's elliptic.GenerateKey() +// and the rejection sampling can lead to up to a two-fold +// slowdown, if M is not close to 2**bitSize. +func (i *Int) Pick(rand cipher.Stream) group.Scalar { + var n *big.Int + // This is just a bitmask with the number of ones starting at 8 then + // incrementing by index. To account for fields with bitsizes that are not a whole + // number of bytes, we mask off the unnecessary bits. h/t agl + mask := []byte{0xff, 0x1, 0x3, 0x7, 0xf, 0x1f, 0x3f, 0x7f} + bitSize := i.M.BitLen() + byteLen := (bitSize + 7) / 8 + b := make([]byte, byteLen) + + for { + rand.XORKeyStream(b, b) + // We have to mask off any excess bits in the case that the size of the + // underlying field is not a whole number of bytes. + b[0] &= mask[bitSize%8] + // This is because, in tests, rand will return all zeros and we don't + // want to get the point at infinity and loop forever. + b[1] ^= 0x42 + + n = new(big.Int).SetBytes(b) + // If the scalar is out of range, sample another random number. + if n.Cmp(i.M) < 0 { + break + } + } + + i.V.Set(n) + return i +} + +// MarshalSize returns the length in bytes of encoded integers with modulus M. +// The length of encoded Ints depends only on the size of the modulus, +// and not on the the value of the encoded integer, +// making the encoding is fixed-length for simplicity and security. +func (i *Int) MarshalSize() int { + return (i.M.BitLen() + 7) / 8 +} + +// MarshalBinary encodes the value of this Int into a byte-slice exactly Len() bytes long. +// It uses big endian. +func (i *Int) MarshalBinary() ([]byte, error) { + l := i.MarshalSize() + b := i.V.Bytes() // may be shorter than l + offset := l - len(b) + + if offset != 0 { + nb := make([]byte, l) + copy(nb[offset:], b) + b = nb + } + return b, nil +} + +// UnmarshalBinary tries to decode a Int from a byte-slice buffer. +// Returns an error if the buffer is not exactly Len() bytes long +// or if the contents of the buffer represents an out-of-range integer. +func (i *Int) UnmarshalBinary(buf []byte) error { + if len(buf) != i.MarshalSize() { + return errors.New("UnmarshalBinary: wrong size buffer") + } + + i.V.SetBytes(buf) + if i.V.Cmp(i.M) >= 0 { + return errors.New("UnmarshalBinary: value out of range") + } + return nil +} + +// SetBytes set the value value to a number represented +// by a byte string. +func (i *Int) SetBytes(a []byte) group.Scalar { + var buff = a + i.V.SetBytes(buff).Mod(&i.V, i.M) + return i +} diff --git a/go/tdh2/internal/group/mod/int_test.go b/go/tdh2/internal/group/mod/int_test.go new file mode 100644 index 0000000..e06c039 --- /dev/null +++ b/go/tdh2/internal/group/mod/int_test.go @@ -0,0 +1,47 @@ +package mod + +import ( + "bytes" + "crypto/elliptic" + "math/big" + "testing" +) + +func FuzzIntMarshal(f *testing.F) { + mods := []*big.Int{elliptic.P256().Params().N, elliptic.P384().Params().N, elliptic.P521().Params().N} + for idx, m := range mods { + i := NewInt64(0, m) + b, err := i.MarshalBinary() + if err != nil { + f.Fatalf("MarshalBinary: %v", err) + } + f.Add(idx, b) + } + f.Fuzz(func(t *testing.T, idx int, data []byte) { + if idx < 0 || idx >= len(mods) { + t.Skip() + } + i1 := NewInt64(0, mods[idx]) + i2 := NewInt64(0, mods[idx]) + if err := i1.UnmarshalBinary(data); err != nil { + t.Skip() + } + data1, err := i1.MarshalBinary() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data, err) + } + if err := i2.UnmarshalBinary(data1); err != nil { + t.Fatalf("Cannot unmarshal: data=%v err=%v", data1, err) + } + data2, err := i2.MarshalBinary() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data2, err) + } + if !bytes.Equal(data1, data2) { + t.Errorf("data1=%v data2=%v", data1, data2) + } + if !i1.Equal(i2) { + t.Errorf("ps1=%v data1=%v ps2=%v data2=%v", i1, data1, i2, data2) + } + }) +} diff --git a/go/tdh2/internal/group/nist/curve.go b/go/tdh2/internal/group/nist/curve.go new file mode 100644 index 0000000..83de2b2 --- /dev/null +++ b/go/tdh2/internal/group/nist/curve.go @@ -0,0 +1,246 @@ +// Package nist implements cryptographic groups and ciphersuites +// based on the NIST standards, using Go's built-in crypto library. +package nist + +import ( + "crypto/cipher" + "crypto/elliptic" + "errors" + "fmt" + "math/big" + + "github.com/goplugin/tdh2/go/tdh2/internal/group" + "github.com/goplugin/tdh2/go/tdh2/internal/group/mod" +) + +// streamReader implements io.Reader from cipher.Stream +type streamReader struct { + stream cipher.Stream +} + +func (s *streamReader) Read(p []byte) (int, error) { + s.stream.XORKeyStream(p, p) + return len(p), nil +} + +type curvePoint struct { + x, y *big.Int + c *curve +} + +func (p *curvePoint) String() string { + return "(" + p.x.String() + "," + p.y.String() + ")" +} + +func (p *curvePoint) Equal(p2 group.Point) bool { + cp2 := p2.(*curvePoint) + + // Make sure both coordinates are normalized. + // Apparently Go's elliptic curve code doesn't always ensure this. + M := p.c.p.P + p.x.Mod(p.x, M) + p.y.Mod(p.y, M) + cp2.x.Mod(cp2.x, M) + cp2.y.Mod(cp2.y, M) + + return p.x.Cmp(cp2.x) == 0 && p.y.Cmp(cp2.y) == 0 +} + +func (p *curvePoint) Null() group.Point { + p.x = new(big.Int).SetInt64(0) + p.y = new(big.Int).SetInt64(0) + return p +} + +func (p *curvePoint) Base() group.Point { + p.x = p.c.p.Gx + p.y = p.c.p.Gy + return p +} + +func (p *curvePoint) Valid() bool { + // The IsOnCurve function in Go's elliptic curve package + // doesn't consider the point-at-infinity to be "on the curve" + return p.c.IsOnCurve(p.x, p.y) || + (p.x.Sign() == 0 && p.y.Sign() == 0) +} + +func (p *curvePoint) Pick(rand cipher.Stream) group.Point { + var err error + _, p.x, p.y, err = elliptic.GenerateKey(p.c, &streamReader{rand}) + if err != nil { + // It cannot panic since GenerateKey returns errors only on reading + // from the randomness source which is deterministic in our case. + panic(fmt.Sprintf("cannot generate point: %v", err)) + } + return p +} + +func (p *curvePoint) Add(a, b group.Point) group.Point { + ca := a.(*curvePoint) + cb := b.(*curvePoint) + p.x, p.y = p.c.Add(ca.x, ca.y, cb.x, cb.y) + return p +} + +func (p *curvePoint) Sub(a, b group.Point) group.Point { + ca := a.(*curvePoint) + cb := b.(*curvePoint) + + cbn := p.c.Point().Neg(cb).(*curvePoint) + p.x, p.y = p.c.Add(ca.x, ca.y, cbn.x, cbn.y) + return p +} + +func (p *curvePoint) Neg(a group.Point) group.Point { + s := p.c.Scalar().One() + s.Neg(s) + return p.Mul(s, a).(*curvePoint) +} + +func (p *curvePoint) Mul(s group.Scalar, b group.Point) group.Point { + cs := s.(*mod.Int) + if b != nil { + cb := b.(*curvePoint) + p.x, p.y = p.c.ScalarMult(cb.x, cb.y, cs.V.Bytes()) + } else { + p.x, p.y = p.c.ScalarBaseMult(cs.V.Bytes()) + } + return p +} + +func (p *curvePoint) MarshalSize() int { + coordlen := (p.c.Params().BitSize + 7) >> 3 + return 1 + 2*coordlen // uncompressed ANSI X9.62 representation +} + +func (p *curvePoint) MarshalBinary() ([]byte, error) { + return elliptic.Marshal(p.c, p.x, p.y), nil +} + +func (p *curvePoint) UnmarshalBinary(buf []byte) error { + if len(buf) != p.MarshalSize() { + return errors.New("wrong buffer size") + } + // Check whether all bytes after first one are 0, so we + // just return the initial point. Read everything to + // prevent timing-leakage. + var c byte + for _, b := range buf[1:] { + c |= b + } + if c != 0 { + p.x, p.y = elliptic.Unmarshal(p.c, buf) + if p.x == nil || !p.Valid() { + return errors.New("invalid elliptic curve point") + } + } else { + // All bytes are 0, so we initialize x and y + p.x = big.NewInt(0) + p.y = big.NewInt(0) + } + return nil +} + +// Curve is an implementation of the group.Group interface +// for NIST elliptic curves, built on Go's native elliptic curve library. +type curve struct { + elliptic.Curve + p *elliptic.CurveParams +} + +// Return the number of bytes in the encoding of a Scalar for this curve. +func (c *curve) ScalarLen() int { return (c.p.N.BitLen() + 7) / 8 } + +// Create a Scalar associated with this curve. The scalars created by +// this package implement group.Scalar's SetBytes method, interpreting +// the bytes as a big-endian integer, so as to be compatible with the +// Go standard library's big.Int type. +func (c *curve) Scalar() group.Scalar { + return mod.NewInt64(0, c.p.N) +} + +// Number of bytes required to store one coordinate on this curve +func (c *curve) coordLen() int { + return (c.p.BitSize + 7) / 8 +} + +// Return the number of bytes in the encoding of a Point for this curve. +// Currently uses uncompressed ANSI X9.62 format with both X and Y coordinates; +// this could change. +func (c *curve) PointLen() int { + return 1 + 2*c.coordLen() // ANSI X9.62: 1 header byte plus 2 coords +} + +// Create a Point associated with this curve. +func (c *curve) Point() group.Point { + p := new(curvePoint) + p.c = c + return p +} + +func (p *curvePoint) Set(P group.Point) group.Point { + p.x = P.(*curvePoint).x + p.y = P.(*curvePoint).y + return p +} + +func (p *curvePoint) Clone() group.Point { + return &curvePoint{x: p.x, y: p.y, c: p.c} +} + +// Return the order of this curve: the prime N in the curve parameters. +func (c *curve) Order() *big.Int { + return c.p.N +} + +// P256 implements the group.Group interface for the NIST P-256 elliptic curve. +type P256 struct { + curve +} + +func (curve *P256) String() string { + return "P256" +} + +// NewP256 returns a new instance of P256. +func NewP256() *P256 { + var g P256 + g.curve.Curve = elliptic.P256() + g.p = g.Params() + return &g +} + +// P384 implements the group.Group interface for the NIST P-384 elliptic curve. +type P384 struct { + curve +} + +func (curve *P384) String() string { + return "P384" +} + +// NewP384 returns a new instance of P384. +func NewP384() *P384 { + var g P384 + g.curve.Curve = elliptic.P384() + g.p = g.Params() + return &g +} + +// P521 implements the group.Group interface for the NIST P-521 elliptic curve. +type P521 struct { + curve +} + +func (curve *P521) String() string { + return "P521" +} + +// NewP521 returns a new instance of P521. +func NewP521() *P521 { + var g P521 + g.curve.Curve = elliptic.P521() + g.p = g.Params() + return &g +} diff --git a/go/tdh2/internal/group/nist/group_test.go b/go/tdh2/internal/group/nist/group_test.go new file mode 100644 index 0000000..b43af35 --- /dev/null +++ b/go/tdh2/internal/group/nist/group_test.go @@ -0,0 +1,178 @@ +package nist + +import ( + "bytes" + "testing" + + "github.com/goplugin/tdh2/go/tdh2/internal/group" + "github.com/goplugin/tdh2/go/tdh2/internal/group/test" +) + +var benchmarks = []*test.GroupBench{ + test.NewGroupBench(NewP256()), + test.NewGroupBench(NewP384()), + test.NewGroupBench(NewP521()), +} + +func TestSetBytesBE(t *testing.T) { + for _, b := range benchmarks { + t.Run(b.String(), func(t *testing.T) { + s := b.G.Scalar() + s.SetBytes([]byte{0, 1, 2, 3}) + // 010203 because initial 0 is trimmed in String(), and 03 (last byte of BE) ends up + // in the LSB of the bigint. + if s.String() != "010203" { + t.Fatal("unexpected result from String():", s.String()) + } + }) + } +} + +func TestGroup(t *testing.T) { + for _, bench := range benchmarks { + t.Run(bench.String(), func(t *testing.T) { + test.GroupTest(t, bench.G) + }) + } +} + +func BenchmarkScalarAdd(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.ScalarAdd(b.N) }) + } +} + +func BenchmarkScalarSub(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.ScalarSub(b.N) }) + } +} + +func BenchmarkScalarNeg(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.ScalarNeg(b.N) }) + } +} + +func BenchmarkScalarMul(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.ScalarMul(b.N) }) + } +} + +func BenchmarkScalarDiv(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.ScalarDiv(b.N) }) + } +} + +func BenchmarkScalarInv(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.ScalarInv(b.N) }) + } +} + +func BenchmarkScalarPick(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.ScalarPick(b.N) }) + } +} + +func BenchmarkScalarEncode(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.ScalarEncode(b.N) }) + } +} + +func BenchmarkScalarDecode(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.ScalarDecode(b.N) }) + } +} + +func BenchmarkPointAdd(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.PointAdd(b.N) }) + } +} + +func BenchmarkPointSub(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.PointSub(b.N) }) + } +} + +func BenchmarkPointNeg(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.PointNeg(b.N) }) + } +} + +func BenchmarkPointMul(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.PointMul(b.N) }) + } +} + +func BenchmarkPointBaseMul(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.PointBaseMul(b.N) }) + } +} + +func BenchmarkPointPick(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.PointPick(b.N) }) + } +} + +func BenchmarkPointEncode(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.PointEncode(b.N) }) + } +} + +func BenchmarkPointDecode(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.String(), func(b *testing.B) { bench.PointDecode(b.N) }) + } +} + +func FuzzCurvePointMarshal(f *testing.F) { + groups := []group.Group{NewP256(), NewP384(), NewP521()} + for idx, g := range groups { + p := g.Point().Base() + b, err := p.MarshalBinary() + if err != nil { + f.Fatalf("MarshalBinary: %v", err) + } + f.Add(idx, b) + } + f.Fuzz(func(t *testing.T, idx int, data []byte) { + if idx < 0 || idx >= len(groups) { + t.Skip() + } + p1 := groups[idx].Point() + p2 := groups[idx].Point() + if err := p1.UnmarshalBinary(data); err != nil { + t.Skip() + } + data1, err := p1.MarshalBinary() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data, err) + } + if err := p2.UnmarshalBinary(data1); err != nil { + t.Fatalf("Cannot unmarshal: data=%v err=%v", data1, err) + } + data2, err := p2.MarshalBinary() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data2, err) + } + if !bytes.Equal(data1, data2) { + t.Errorf("data1=%v data2=%v", data1, data2) + } + if !p1.Equal(p2) { + t.Errorf("ps1=%v data1=%v ps2=%v data2=%v", p1, data1, p2, data2) + } + }) +} diff --git a/go/tdh2/internal/group/share/poly.go b/go/tdh2/internal/group/share/poly.go new file mode 100644 index 0000000..2ad9cf1 --- /dev/null +++ b/go/tdh2/internal/group/share/poly.go @@ -0,0 +1,152 @@ +// Package share implements Shamir secret sharing and polynomial commitments. +// Shamir's scheme allows you to split a secret value into multiple parts, so called +// shares, by evaluating a secret sharing polynomial at certain indices. The +// shared secret can only be reconstructed (via Lagrange interpolation) if a +// threshold of the participants provide their shares. A polynomial commitment +// scheme allows a committer to commit to a secret sharing polynomial so that +// a verifier can check the claimed evaluations of the committed polynomial. +// Both schemes of this package are core building blocks for more advanced +// secret sharing techniques. +package share + +import ( + "crypto/cipher" + "errors" + "fmt" + "sort" + "strings" + + "github.com/goplugin/tdh2/go/tdh2/internal/group" +) + +// PriShare represents a private share. +type PriShare struct { + I int // Index of the private share + V group.Scalar // Value of the private share +} + +func (p *PriShare) String() string { + return fmt.Sprintf("{%d:%s}", p.I, p.V) +} + +// PriPoly represents a secret sharing polynomial. +type PriPoly struct { + g group.Group // Cryptographic group + coeffs []group.Scalar // Coefficients of the polynomial +} + +// NewPriPoly creates a new secret sharing polynomial using the provided +// cryptographic group, the secret sharing threshold t, and the secret to be +// shared s. If s is nil, a new s is chosen using the provided randomness +// stream rand. +func NewPriPoly(grp group.Group, t int, s group.Scalar, rand cipher.Stream) *PriPoly { + coeffs := make([]group.Scalar, t) + coeffs[0] = s + if coeffs[0] == nil { + coeffs[0] = grp.Scalar().Pick(rand) + } + for i := 1; i < t; i++ { + coeffs[i] = grp.Scalar().Pick(rand) + } + return &PriPoly{g: grp, coeffs: coeffs} +} + +// Secret returns the shared secret p(0), i.e., the constant term of the polynomial. +func (p *PriPoly) Secret() group.Scalar { + return p.coeffs[0] +} + +// Eval computes the private share v = p(i). +func (p *PriPoly) Eval(i int) *PriShare { + xi := p.g.Scalar().SetInt64(1 + int64(i)) + v := p.g.Scalar().Zero() + for j := len(p.coeffs) - 1; j >= 0; j-- { + v.Mul(v, xi) + v.Add(v, p.coeffs[j]) + } + return &PriShare{i, v} +} + +// Shares creates a list of n private shares p(1),...,p(n). +func (p *PriPoly) Shares(n int) []*PriShare { + shares := make([]*PriShare, n) + for i := range shares { + shares[i] = p.Eval(i) + } + return shares +} + +func (p *PriPoly) String() string { + var strs = make([]string, len(p.coeffs)) + for i, c := range p.coeffs { + strs[i] = c.String() + } + return "[ " + strings.Join(strs, ", ") + " ]" +} + +// PubShare represents a public share. +type PubShare struct { + I int // Index of the public share + V group.Point // Value of the public share +} + +// xyCommits is the public version of xScalars. +func xyCommit(g group.Group, shares []*PubShare, t, n int) (map[int]group.Scalar, map[int]group.Point) { + // we are sorting first the shares since the shares may be unrelated for + // some applications. In this case, all participants needs to interpolate on + // the exact same order shares. + sorted := make([]*PubShare, 0, n) + for _, share := range shares { + if share != nil { + sorted = append(sorted, share) + } + } + sort.Slice(sorted, func(i, j int) bool { return sorted[i].I < sorted[j].I }) + + x := make(map[int]group.Scalar) + y := make(map[int]group.Point) + + for _, s := range sorted { + if s == nil || s.V == nil || s.I < 0 { + continue + } + idx := s.I + x[idx] = g.Scalar().SetInt64(int64(idx + 1)) + y[idx] = s.V + if len(x) == t { + break + } + } + return x, y +} + +// RecoverCommit reconstructs the secret commitment p(0) from a list of public +// shares using Lagrange interpolation. +func RecoverCommit(g group.Group, shares []*PubShare, t, n int) (group.Point, error) { + x, y := xyCommit(g, shares, t, n) + if len(x) < t { + return nil, errors.New("share: not enough good public shares to reconstruct secret commitment") + } + + num := g.Scalar() + den := g.Scalar() + tmp := g.Scalar() + Acc := g.Point().Null() + Tmp := g.Point() + + for i, xi := range x { + num.One() + den.One() + for j, xj := range x { + if i == j { + continue + } + num.Mul(num, xj) + den.Mul(den, tmp.Sub(xj, xi)) + } + Tmp.Mul(num.Div(num, den), y[i]) + Acc.Add(Acc, Tmp) + } + + return Acc, nil +} diff --git a/go/tdh2/internal/group/share/poly_test.go b/go/tdh2/internal/group/share/poly_test.go new file mode 100644 index 0000000..60d7d9d --- /dev/null +++ b/go/tdh2/internal/group/share/poly_test.go @@ -0,0 +1,169 @@ +package share + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "testing" + + "github.com/goplugin/tdh2/go/tdh2/internal/group" + "github.com/goplugin/tdh2/go/tdh2/internal/group/nist" +) + +var groups = []group.Group{ + nist.NewP256(), + nist.NewP384(), + nist.NewP521(), +} + +func randStream(t *testing.T) cipher.Stream { + block, err := aes.NewCipher(make([]byte, 16)) + if err != nil { + t.Fatalf("NewCipher: %v", err) + } + iv := make([]byte, aes.BlockSize) + if _, err := rand.Read(iv); err != nil { + t.Fatalf("Read: %v", err) + } + return cipher.NewCTR(block, iv) +} + +func pubShares(g group.Group, shares []*PriShare) []*PubShare { + out := []*PubShare{} + for _, s := range shares { + out = append(out, &PubShare{ + I: s.I, + V: g.Point().Mul(s.V, nil), + }) + } + return out +} + +func TestRecoveryWithoutSecret(test *testing.T) { + for _, g := range groups { + test.Run(g.String(), func(test *testing.T) { + n := 10 + t := n/2 + 1 + + priPoly := NewPriPoly(g, t, nil, randStream(test)) + shares := priPoly.Shares(n) + pubShares := pubShares(g, shares) + + recovered, err := RecoverCommit(g, pubShares, t, n) + if err != nil { + test.Fatal(err) + } + + if !recovered.Equal(g.Point().Mul(priPoly.Secret(), nil)) { + test.Fatal("recovered commit does not match initial value") + } + + }) + } +} + +func TestRecoveryWithSecret(test *testing.T) { + for _, g := range groups { + test.Run(g.String(), func(test *testing.T) { + n := 10 + t := n/2 + 1 + s := g.Scalar().Pick(randStream(test)) + + priPoly := NewPriPoly(g, t, s, randStream(test)) + if !s.Equal(priPoly.Secret()) { + test.Fatalf("secrets not equal") + } + shares := priPoly.Shares(n) + pubShares := pubShares(g, shares) + + recovered, err := RecoverCommit(g, pubShares, t, n) + if err != nil { + test.Fatal(err) + } + + if !recovered.Equal(g.Point().Mul(s, nil)) { + test.Fatal("recovered commit does not match initial value") + } + }) + } +} + +func TestPublicRecoveryOutIndex(test *testing.T) { + for _, g := range groups { + test.Run(g.String(), func(test *testing.T) { + n := 10 + t := n/2 + 1 + + priPoly := NewPriPoly(g, t, nil, randStream(test)) + pubShares := pubShares(g, priPoly.Shares(n)) + comm := g.Point().Mul(priPoly.Secret(), nil) + + selected := pubShares[n-t:] + if len(selected) != t { + test.Fatalf("len(selected) != t") + } + newN := t + 1 + + recovered, err := RecoverCommit(g, selected, t, newN) + if err != nil { + test.Fatal(err) + } + + if !recovered.Equal(comm) { + test.Fatal("recovered commit does not match initial value") + } + }) + } +} + +func TestPublicRecoveryDelete(test *testing.T) { + for _, g := range groups { + test.Run(g.String(), func(test *testing.T) { + n := 10 + t := n/2 + 1 + + priPoly := NewPriPoly(g, t, nil, randStream(test)) + shares := pubShares(g, priPoly.Shares(n)) + comm := g.Point().Mul(priPoly.Secret(), nil) + + // Corrupt a few shares + shares[2] = nil + shares[5] = nil + shares[7] = nil + shares[8] = nil + + recovered, err := RecoverCommit(g, shares, t, n) + if err != nil { + test.Fatal(err) + } + + if !recovered.Equal(comm) { + test.Fatal("recovered commit does not match initial value") + } + }) + } +} + +func TestPublicRecoveryDeleteFail(test *testing.T) { + for _, g := range groups { + test.Run(g.String(), func(test *testing.T) { + n := 10 + t := n/2 + 1 + + priPoly := NewPriPoly(g, t, nil, randStream(test)) + shares := pubShares(g, priPoly.Shares(n)) + + // Corrupt one more share than acceptable + shares[1] = nil + shares[2] = nil + shares[5] = nil + shares[7] = nil + shares[8] = nil + + _, err := RecoverCommit(g, shares, t, n) + if err == nil { + test.Fatal("recovered commit unexpectably") + } + }) + } +} diff --git a/go/tdh2/internal/group/test/group.go b/go/tdh2/internal/group/test/group.go new file mode 100644 index 0000000..fd18a55 --- /dev/null +++ b/go/tdh2/internal/group/test/group.go @@ -0,0 +1,153 @@ +package test + +import ( + "github.com/goplugin/tdh2/go/tdh2/internal/group" +) + +// GroupBench is a generic benchmark suite for group.groups. +type GroupBench struct { + G group.Group + + // Random secrets and points for testing + x, y group.Scalar + X, Y group.Point + xe []byte // encoded Scalar + Xe []byte // encoded Point +} + +// NewGroupBench returns a new GroupBench. +func NewGroupBench(g group.Group) *GroupBench { + var gb GroupBench + rng := randomNew() + gb.G = g + gb.x = g.Scalar().Pick(rng) + gb.y = g.Scalar().Pick(rng) + gb.xe, _ = gb.x.MarshalBinary() + gb.X = g.Point().Pick(rng) + gb.Y = g.Point().Pick(rng) + gb.Xe, _ = gb.X.MarshalBinary() + return &gb +} + +func (gb GroupBench) String() string { + return gb.G.String() +} + +// ScalarAdd benchmarks the addition operation for scalars +func (gb GroupBench) ScalarAdd(iters int) { + for i := 1; i < iters; i++ { + gb.x.Add(gb.x, gb.y) + } +} + +// ScalarSub benchmarks the substraction operation for scalars +func (gb GroupBench) ScalarSub(iters int) { + for i := 1; i < iters; i++ { + gb.x.Sub(gb.x, gb.y) + } +} + +// ScalarNeg benchmarks the negation operation for scalars +func (gb GroupBench) ScalarNeg(iters int) { + for i := 1; i < iters; i++ { + gb.x.Neg(gb.x) + } +} + +// ScalarMul benchmarks the multiplication operation for scalars +func (gb GroupBench) ScalarMul(iters int) { + for i := 1; i < iters; i++ { + gb.x.Mul(gb.x, gb.y) + } +} + +// ScalarDiv benchmarks the division operation for scalars +func (gb GroupBench) ScalarDiv(iters int) { + for i := 1; i < iters; i++ { + gb.x.Div(gb.x, gb.y) + } +} + +// ScalarInv benchmarks the inverse operation for scalars +func (gb GroupBench) ScalarInv(iters int) { + for i := 1; i < iters; i++ { + gb.x.Inv(gb.x) + } +} + +// ScalarPick benchmarks the Pick operation for scalars +func (gb GroupBench) ScalarPick(iters int) { + for i := 1; i < iters; i++ { + gb.x.Pick(randomNew()) + } +} + +// ScalarEncode benchmarks the marshalling operation for scalars +func (gb GroupBench) ScalarEncode(iters int) { + for i := 1; i < iters; i++ { + _, _ = gb.x.MarshalBinary() + } +} + +// ScalarDecode benchmarks the unmarshalling operation for scalars +func (gb GroupBench) ScalarDecode(iters int) { + for i := 1; i < iters; i++ { + _ = gb.x.UnmarshalBinary(gb.xe) + } +} + +// PointAdd benchmarks the addition operation for points +func (gb GroupBench) PointAdd(iters int) { + for i := 1; i < iters; i++ { + gb.X.Add(gb.X, gb.Y) + } +} + +// PointSub benchmarks the substraction operation for points +func (gb GroupBench) PointSub(iters int) { + for i := 1; i < iters; i++ { + gb.X.Sub(gb.X, gb.Y) + } +} + +// PointNeg benchmarks the negation operation for points +func (gb GroupBench) PointNeg(iters int) { + for i := 1; i < iters; i++ { + gb.X.Neg(gb.X) + } +} + +// PointMul benchmarks the multiplication operation for points +func (gb GroupBench) PointMul(iters int) { + for i := 1; i < iters; i++ { + gb.X.Mul(gb.y, gb.X) + } +} + +// PointBaseMul benchmarks the base multiplication operation for points +func (gb GroupBench) PointBaseMul(iters int) { + for i := 1; i < iters; i++ { + gb.X.Mul(gb.y, nil) + } +} + +// PointPick benchmarks the pick-ing operation for points +func (gb GroupBench) PointPick(iters int) { + for i := 1; i < iters; i++ { + gb.X.Pick(randomNew()) + } +} + +// PointEncode benchmarks the encoding operation for points +func (gb GroupBench) PointEncode(iters int) { + for i := 1; i < iters; i++ { + _, _ = gb.X.MarshalBinary() + } +} + +// PointDecode benchmarks the decoding operation for points +func (gb GroupBench) PointDecode(iters int) { + for i := 1; i < iters; i++ { + _ = gb.X.UnmarshalBinary(gb.Xe) + } +} diff --git a/go/tdh2/internal/group/test/test.go b/go/tdh2/internal/group/test/test.go new file mode 100644 index 0000000..cdaffc1 --- /dev/null +++ b/go/tdh2/internal/group/test/test.go @@ -0,0 +1,349 @@ +// Package test contains generic testing and benchmarking infrastructure +// for cryptographic groups and ciphersuites. +package test + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "testing" + + "github.com/goplugin/tdh2/go/tdh2/internal/group" +) + +func testPointSet(t *testing.T, g group.Group, rand cipher.Stream) { + N := 1000 + null := g.Point().Null() + for i := 0; i < N; i++ { + P1 := g.Point().Pick(rand) + P2 := g.Point() + P2.Set(P1) + if !P1.Equal(P2) { + t.Errorf("Set() set to a different point: %v != %v", P1, P2) + } + if !P1.Equal(null) { + P1.Add(P1, P1) + if P1.Equal(P2) { + t.Errorf("Modifying P1 shouldn't modify P2: %v == %v", P1, P2) + } + } + } +} + +func testPointClone(t *testing.T, g group.Group, rand cipher.Stream) { + N := 1000 + null := g.Point().Null() + for i := 0; i < N; i++ { + P1 := g.Point().Pick(rand) + P2 := P1.Clone() + if !P1.Equal(P2) { + t.Errorf("Clone didn't work for point: %v != %v", P1, P2) + } + if !P1.Equal(null) { + P1.Add(P1, P1) + if P1.Equal(P2) { + t.Errorf("Modifying P1 shouldn't modify P2: %v == %v", P1, P2) + } + } + } +} + +func testScalarSet(t *testing.T, g group.Group, rand cipher.Stream) { + N := 1000 + zero := g.Scalar().Zero() + one := g.Scalar().One() + for i := 0; i < N; i++ { + s1 := g.Scalar().Pick(rand) + s2 := s1.Clone() + if !s1.Equal(s2) { + t.Errorf("Set() set to a different scalar: %v != %v", s1, s2) + } + if !s1.Equal(zero) && !s1.Equal(one) { + s1.Mul(s1, s1) + if s1.Equal(s2) { + t.Errorf("Modifying s1 shouldn't modify s2: %v == %v", s1, s2) + } + } + } +} + +func testScalarClone(t *testing.T, g group.Group, rand cipher.Stream) { + N := 1000 + zero := g.Scalar().Zero() + one := g.Scalar().One() + for i := 0; i < N; i++ { + s1 := g.Scalar().Pick(rand) + s2 := s1.Clone() + if !s1.Equal(s2) { + t.Errorf("Clone didn't work for scalar: %v != %v", s1, s2) + } + if !s1.Equal(zero) && !s1.Equal(one) { + s1.Mul(s1, s1) + if s1.Equal(s2) { + t.Errorf("Modifying s1 shouldn't modify s2: %v == %v", s1, s2) + } + } + } +} + +// Apply a generic set of validation tests to a cryptographic Group, +// using a given source of [pseudo-]randomness. +// +// Returns a log of the pseudorandom Points produced in the test, +// for comparison across alternative implementations +// that are supposed to be equivalent. +func testGroup(t *testing.T, g group.Group, rand cipher.Stream) []group.Point { + t.Logf("\nTesting group '%s': %d-byte Point, %d-byte Scalar\n", + g.String(), g.PointLen(), g.ScalarLen()) + + points := make([]group.Point, 0) + ptmp := g.Point() + stmp := g.Scalar() + pzero := g.Point().Null() + szero := g.Scalar().Zero() + sone := g.Scalar().One() + + // Do a simple Diffie-Hellman test + s1 := g.Scalar().Pick(rand) + s2 := g.Scalar().Pick(rand) + if s1.Equal(szero) { + t.Errorf("first secret is scalar zero %v", s1) + } + if s2.Equal(szero) { + t.Errorf("second secret is scalar zero %v", s2) + } + if s1.Equal(s2) { + t.Errorf("not getting unique secrets: picked %s twice", s1) + } + + gen := g.Point().Base() + points = append(points, gen) + + // Sanity-check relationship between addition and multiplication + p1 := g.Point().Add(gen, gen) + p2 := g.Point().Mul(stmp.SetInt64(2), nil) + if !p1.Equal(p2) { + t.Errorf("multiply by two doesn't work: %v == %v (+) %[2]v != %[2]v (x) 2 == %v", p1, gen, p2) + } + p1.Add(p1, p1) + p2.Mul(stmp.SetInt64(4), nil) + if !p1.Equal(p2) { + t.Errorf("multiply by four doesn't work: %v (+) %[1]v != %v (x) 4 == %v", + g.Point().Add(gen, gen), gen, p2) + } + points = append(points, p1) + + // Find out if this curve has a prime order: + // if the curve does not offer a method IsPrimeOrder, + // then assume that it is. + type canCheckPrimeOrder interface { + IsPrimeOrder() bool + } + primeOrder := true + if gpo, ok := g.(canCheckPrimeOrder); ok { + primeOrder = gpo.IsPrimeOrder() + } + + // Verify additive and multiplicative identities of the generator. + ptmp.Mul(stmp.SetInt64(-1), nil).Add(ptmp, gen) + if !ptmp.Equal(pzero) { + t.Errorf("generator additive identity doesn't work: %v (x) -1 (+) %v != %v the group point identity", + ptmp.Mul(stmp.SetInt64(-1), nil), gen, pzero) + } + // secret.Inv works only in prime-order groups + if primeOrder { + ptmp.Mul(stmp.SetInt64(2), nil).Mul(stmp.Inv(stmp), ptmp) + if !ptmp.Equal(gen) { + t.Errorf("generator multiplicative identity doesn't work:\n%v (x) %v = %v\n%[3]v (x) %v = %v", + ptmp.Base().String(), stmp.SetInt64(2).String(), + ptmp.Mul(stmp.SetInt64(2), nil).String(), + stmp.Inv(stmp).String(), + ptmp.Mul(stmp.SetInt64(2), nil).Mul(stmp.Inv(stmp), ptmp).String()) + } + } + + p1.Mul(s1, gen) + p2.Mul(s2, gen) + if p1.Equal(p2) { + t.Errorf("encryption isn't producing unique points: %v (x) %v == %v (x) %[2]v == %[4]v", s1, gen, s2, p1) + } + points = append(points, p1) + + dh1 := g.Point().Mul(s2, p1) + dh2 := g.Point().Mul(s1, p2) + if !dh1.Equal(dh2) { + t.Errorf("Diffie-Hellman didn't work: %v == %v (x) %v != %v (x) %v == %v", dh1, s2, p1, s1, p2, dh2) + } + points = append(points, dh1) + t.Logf("shared secret = %v", dh1) + + // Test secret inverse to get from dh1 back to p1 + if primeOrder { + ptmp.Mul(g.Scalar().Inv(s2), dh1) + if !ptmp.Equal(p1) { + t.Errorf("Scalar inverse didn't work: %v != (-)%v (x) %v == %v", p1, s2, dh1, ptmp) + } + } + + // Zero and One identity secrets + //println("dh1^0 = ",ptmp.Mul(dh1, szero).String()) + if !ptmp.Mul(szero, dh1).Equal(pzero) { + t.Errorf("Encryption with secret=0 didn't work: %v (x) %v == %v != %v", szero, dh1, ptmp, pzero) + } + if !ptmp.Mul(sone, dh1).Equal(dh1) { + t.Errorf("Encryption with secret=1 didn't work: %v (x) %v == %v != %[2]v", sone, dh1, ptmp) + } + + // Additive homomorphic identities + ptmp.Add(p1, p2) + stmp.Add(s1, s2) + pt2 := g.Point().Mul(stmp, gen) + if !pt2.Equal(ptmp) { + t.Errorf("Additive homomorphism doesn't work: %v + %v == %v, %[3]v (x) %v == %v != %v == %v (+) %v", + s1, s2, stmp, gen, pt2, ptmp, p1, p2) + } + ptmp.Sub(p1, p2) + stmp.Sub(s1, s2) + pt2.Mul(stmp, gen) + if !pt2.Equal(ptmp) { + t.Errorf("Additive homomorphism doesn't work: %v - %v == %v, %[3]v (x) %v == %v != %v == %v (-) %v", + s1, s2, stmp, gen, pt2, ptmp, p1, p2) + } + st2 := g.Scalar().Neg(s2) + st2.Add(s1, st2) + if !stmp.Equal(st2) { + t.Errorf("Scalar.Neg doesn't work: -%v == %v, %[2]v + %v == %v != %v", + s2, g.Scalar().Neg(s2), s1, st2, stmp) + } + pt2.Neg(p2).Add(pt2, p1) + if !pt2.Equal(ptmp) { + t.Errorf("Point.Neg doesn't work: (-)%v == %v, %[2]v (+) %v == %v != %v", + p2, g.Point().Neg(p2), p1, pt2, ptmp) + } + + // Multiplicative homomorphic identities + stmp.Mul(s1, s2) + if !ptmp.Mul(stmp, gen).Equal(dh1) { + t.Errorf("Multiplicative homomorphism doesn't work: %v * %v == %v, %[2]v (x) %v == %v != %v", + s1, s2, stmp, gen, ptmp, dh1) + } + if primeOrder { + st2.Inv(s2) + st2.Mul(st2, stmp) + if !st2.Equal(s1) { + t.Errorf("Scalar division doesn't work: %v^-1 * %v == %v * %[2]v == %[4]v != %v", + s2, stmp, g.Scalar().Inv(s2), st2, s1) + } + st2.Div(stmp, s2) + if !st2.Equal(s1) { + t.Errorf("Scalar division doesn't work: %v / %v == %v != %v", + stmp, s2, st2, s1) + } + } + + // Test randomly picked points + last := gen + for i := 0; i < 5; i++ { + rgen := g.Point().Pick(rand) + if rgen.Equal(last) { + t.Errorf("Pick() not producing unique points: got %v twice", rgen) + } + last = rgen + + ptmp.Mul(stmp.SetInt64(-1), rgen).Add(ptmp, rgen) + if !ptmp.Equal(pzero) { + t.Errorf("random generator fails additive identity: %v (x) %v == %v, %v (+) %[3]v == %[5]v != %v", + g.Scalar().SetInt64(-1), rgen, g.Point().Mul(g.Scalar().SetInt64(-1), rgen), + rgen, g.Point().Mul(g.Scalar().SetInt64(-1), rgen), pzero) + } + if primeOrder { + ptmp.Mul(stmp.SetInt64(2), rgen).Mul(stmp.Inv(stmp), ptmp) + if !ptmp.Equal(rgen) { + t.Errorf("random generator fails multiplicative identity: %v (x) (2 (x) %v) == %v != %[2]v", + stmp, rgen, ptmp) + } + } + points = append(points, rgen) + } + + // Test verifiable secret sharing + + // Test encoding and decoding + for i := 0; i < 5; i++ { + s := g.Scalar().Pick(rand) + b, err := s.MarshalBinary() + if err != nil { + t.Errorf("encoding of secret fails: " + err.Error()) + } + if err := stmp.UnmarshalBinary(b); err != nil { + t.Errorf("decoding of secret fails: " + err.Error()) + } + if !stmp.Equal(s) { + t.Errorf("decoding produces different secret than encoded") + } + + p := g.Point().Pick(rand) + b, err = p.MarshalBinary() + if err != nil { + t.Errorf("encoding of point fails: " + err.Error()) + } + if err := ptmp.UnmarshalBinary(b); err != nil { + t.Errorf("decoding of point fails: " + err.Error()) + } + if !ptmp.Equal(p) { + t.Errorf("decoding produces different point than encoded") + } + } + + // Test that we can marshal/ unmarshal null point + pzero = g.Point().Null() + b, _ := pzero.MarshalBinary() + repzero := g.Point() + err := repzero.UnmarshalBinary(b) + if err != nil { + t.Errorf("Could not unmarshall binary %v: %v", b, err) + } + + testPointSet(t, g, rand) + testPointClone(t, g, rand) + testScalarSet(t, g, rand) + testScalarClone(t, g, rand) + + return points +} + +// GroupTest applies a generic set of validation tests to a cryptographic Group. +func GroupTest(t *testing.T, g group.Group) { + testGroup(t, g, randomNew()) +} + +// CompareGroups tests two group implementations that are supposed to be equivalent, +// and compare their results. +func CompareGroups(t *testing.T, fn func(key []byte) cipher.Stream, g1, g2 group.Group) { + // Produce test results from the same pseudorandom seed + r1 := testGroup(t, g1, fn(nil)) + r2 := testGroup(t, g2, fn(nil)) + + // Compare resulting Points + for i := range r1 { + b1, _ := r1[i].MarshalBinary() + b2, _ := r2[i].MarshalBinary() + if !bytes.Equal(b1, b2) { + t.Errorf("unequal result-pair %v\n1: %v\n2: %v", + i, r1[i], r2[i]) + } + } +} + +func randomNew() cipher.Stream { + block, err := aes.NewCipher(make([]byte, 16)) + if err != nil { + panic(err) + } + iv := make([]byte, aes.BlockSize) + if _, err := rand.Read(iv); err != nil { + panic(err) + } + return cipher.NewCTR(block, iv) +} diff --git a/go/tdh2/tdh2/tdh2.go b/go/tdh2/tdh2/tdh2.go index 0d104df..d171994 100644 --- a/go/tdh2/tdh2/tdh2.go +++ b/go/tdh2/tdh2/tdh2.go @@ -9,9 +9,9 @@ import ( "encoding/json" "fmt" - "go.dedis.ch/kyber/v3" - "go.dedis.ch/kyber/v3/group/nist" - "go.dedis.ch/kyber/v3/share" + "github.com/goplugin/tdh2/go/tdh2/internal/group" + "github.com/goplugin/tdh2/go/tdh2/internal/group/nist" + "github.com/goplugin/tdh2/go/tdh2/internal/group/share" ) var ( @@ -22,19 +22,23 @@ var ( InputSize = defaultHash().Size() ) -func parseGroup(group string) (kyber.Group, error) { +func parseGroup(group string) (group.Group, error) { switch group { - case nist.NewBlakeSHA256P256().String(): - return nist.NewBlakeSHA256P256(), nil + case nist.NewP256().String(): + return nist.NewP256(), nil + case nist.NewP384().String(): + return nist.NewP384(), nil + case nist.NewP521().String(): + return nist.NewP521(), nil } return nil, fmt.Errorf("unsupported group: %q", group) } -// PrivateShare is a node's private share. It extends kyber's share.PriShare. +// PrivateShare is a node's private share. It extends group.s share.PriShare. type PrivateShare struct { - group kyber.Group + group group.Group index int - v kyber.Scalar + v group.Scalar } func (s PrivateShare) String() string { @@ -47,12 +51,12 @@ func (s PrivateShare) Index() int { // mulPoint returns a new point with value v*p, where v is a private scalar. // If p==nil, the returned point has value v*BasePoint. -func (s *PrivateShare) mulPoint(p kyber.Point) kyber.Point { +func (s *PrivateShare) mulPoint(p group.Point) group.Point { return s.group.Point().Mul(s.v, p) } // mulScalar returns a new scalar with value v*a where v is a private scalar. -func (s *PrivateShare) mulScalar(a kyber.Scalar) kyber.Scalar { +func (s *PrivateShare) mulScalar(a group.Scalar) group.Scalar { return s.group.Scalar().Mul(s.v, a) } @@ -88,7 +92,8 @@ func (s *PrivateShare) Unmarshal(data []byte) error { } s.index = raw.Index - if s.v, err = unmarshalScalar(s.group, raw.V); err != nil { + s.v = s.group.Scalar() + if err = s.v.UnmarshalBinary(raw.V); err != nil { return fmt.Errorf("cannot unmarshal: %w", err) } return nil @@ -96,10 +101,10 @@ func (s *PrivateShare) Unmarshal(data []byte) error { // PubliKey defines a public and verification key. type PublicKey struct { - group kyber.Group - g_bar kyber.Point - h kyber.Point - hArray []kyber.Point + group group.Group + g_bar group.Point + h group.Point + hArray []group.Point } func (a *PublicKey) Equal(b *PublicKey) bool { @@ -119,8 +124,8 @@ func (a *PublicKey) Equal(b *PublicKey) bool { // MasterSecret keeps the master secret of a TDH2 instance. type MasterSecret struct { - group kyber.Group - s kyber.Scalar + group group.Group + s group.Scalar } func (m MasterSecret) String() string { @@ -209,18 +214,20 @@ func (p *PublicKey) Unmarshal(data []byte) error { return fmt.Errorf("cannot parse group: %w", err) } - if p.g_bar, err = unmarshalPoint(p.group, raw.G_bar); err != nil { + p.g_bar = p.group.Point() + if err = p.g_bar.UnmarshalBinary(raw.G_bar); err != nil { return fmt.Errorf("unmarshaling G_bar: %w", err) } - if p.h, err = unmarshalPoint(p.group, raw.H); err != nil { + p.h = p.group.Point() + if err = p.h.UnmarshalBinary(raw.H); err != nil { return fmt.Errorf("unmarshaling H: %w", err) } - p.hArray = []kyber.Point{} + p.hArray = []group.Point{} for _, h := range raw.HArray { - new, err := unmarshalPoint(p.group, h) - if err != nil { + new := p.group.Point() + if err = new.UnmarshalBinary(h); err != nil { return fmt.Errorf("cannot unmarshal point: %w", err) } p.hArray = append(p.hArray, new) @@ -233,28 +240,28 @@ func (p *PublicKey) Unmarshal(data []byte) error { // It takes cryptographic group to be used, master secret to be used (if nil, a new secret is generated), // the total number of nodes n, the number of shares sufficient for decryption k, and a randomness source. // It returns the master secret (either passed or generated), public key, and secret key shares. -func GenerateKeys(group kyber.Group, ms *MasterSecret, k, n int, rand cipher.Stream) (*MasterSecret, *PublicKey, []*PrivateShare, error) { +func GenerateKeys(grp group.Group, ms *MasterSecret, k, n int, rand cipher.Stream) (*MasterSecret, *PublicKey, []*PrivateShare, error) { if k > n { return nil, nil, nil, fmt.Errorf("threshold is higher than total number of nodes") } if k <= 0 { return nil, nil, nil, fmt.Errorf("threshold has to be positive") } - if ms != nil && group.String() != ms.group.String() { + if ms != nil && grp.String() != ms.group.String() { return nil, nil, nil, fmt.Errorf("inconsistent groups") } - var s kyber.Scalar + var s group.Scalar if ms != nil { s = ms.s } - poly := share.NewPriPoly(group, k, s, rand) + poly := share.NewPriPoly(grp, k, s, rand) x := poly.Secret() if ms != nil && !x.Equal(ms.s) { return nil, nil, nil, fmt.Errorf("generated wrong secret") } - HArray := make([]kyber.Point, n) + HArray := make([]group.Point, n) shares := poly.Shares(n) privShares := []*PrivateShare{} // IDs are assigned consecutively from 0. @@ -262,17 +269,17 @@ func GenerateKeys(group kyber.Group, ms *MasterSecret, k, n int, rand cipher.Str if i != s.I { return nil, nil, nil, fmt.Errorf("share index=%d, expect=%d", s.I, i) } - HArray[i] = group.Point().Mul(s.V, nil) - privShares = append(privShares, &PrivateShare{group, s.I, s.V}) + HArray[i] = grp.Point().Mul(s.V, nil) + privShares = append(privShares, &PrivateShare{grp, s.I, s.V}) } return &MasterSecret{ - group: group, + group: grp, s: x}, &PublicKey{ - group: group, - g_bar: group.Point().Pick(rand), - h: group.Point().Mul(x, nil), + group: grp, + g_bar: grp.Point().Pick(rand), + h: grp.Point().Mul(x, nil), hArray: HArray, }, privShares, nil } @@ -371,13 +378,13 @@ func checkEi(pk *PublicKey, ctxt *Ciphertext, share *DecryptionShare) error { // Ciphertext defines a ciphertext as output from the Encryption algorithm. type Ciphertext struct { - group kyber.Group + group group.Group c []byte label []byte - u kyber.Point - u_bar kyber.Point - e kyber.Scalar - f kyber.Scalar + u group.Point + u_bar group.Point + e group.Scalar + f group.Scalar } // Verify checks if the ciphertext matches the public key @@ -414,7 +421,7 @@ func (a *Ciphertext) Equal(b *Ciphertext) bool { // Decrypt decrypts a ciphertext using a secret key share x_i according to TDH2 paper. // The caller has to ensure that the ciphertext is validated. -func (ctxt *Ciphertext) Decrypt(group kyber.Group, x_i *PrivateShare, rand cipher.Stream) (*DecryptionShare, error) { +func (ctxt *Ciphertext) Decrypt(group group.Group, x_i *PrivateShare, rand cipher.Stream) (*DecryptionShare, error) { if group.String() != ctxt.group.String() { return nil, fmt.Errorf("incorrect ciphertext group: %q", ctxt.group) } @@ -442,7 +449,7 @@ func (ctxt *Ciphertext) Decrypt(group kyber.Group, x_i *PrivateShare, rand ciphe // CombineShares combines a set of decryption shares and returns the decrypted message. // The caller has to ensure that the ciphertext is validated. -func (c *Ciphertext) CombineShares(group kyber.Group, shares []*DecryptionShare, k, n int) ([]byte, error) { +func (c *Ciphertext) CombineShares(group group.Group, shares []*DecryptionShare, k, n int) ([]byte, error) { if group.String() != c.group.String() { return nil, fmt.Errorf("incorrect ciphertext group: %q", c.group) } @@ -526,16 +533,20 @@ func (c *Ciphertext) Unmarshal(data []byte) error { if err != nil { return fmt.Errorf("cannot parse group: %w", err) } - if c.e, err = unmarshalScalar(c.group, raw.E); err != nil { + c.e = c.group.Scalar() + if err = c.e.UnmarshalBinary(raw.E); err != nil { return fmt.Errorf("cannot unmarshal E: %w", err) } - if c.u, err = unmarshalPoint(c.group, raw.U); err != nil { + c.u = c.group.Point() + if err = c.u.UnmarshalBinary(raw.U); err != nil { return fmt.Errorf("cannot unmarshal U: %w", err) } - if c.u_bar, err = unmarshalPoint(c.group, raw.U_bar); err != nil { + c.u_bar = c.group.Point() + if err = c.u_bar.UnmarshalBinary(raw.U_bar); err != nil { return fmt.Errorf("cannot unmarshal U_bar: %w", err) } - if c.f, err = unmarshalScalar(c.group, raw.F); err != nil { + c.f = c.group.Scalar() + if err = c.f.UnmarshalBinary(raw.F); err != nil { return fmt.Errorf("cannot unmarshal F: %w", err) } return nil @@ -543,11 +554,11 @@ func (c *Ciphertext) Unmarshal(data []byte) error { // DecryptionShare defines a decryption share type DecryptionShare struct { - group kyber.Group + group group.Group index int - u_i kyber.Point - e_i kyber.Scalar - f_i kyber.Scalar + u_i group.Point + e_i group.Scalar + f_i group.Scalar } // TODO(pszal): test + fix tests which currently ignore share equality @@ -608,13 +619,16 @@ func (d *DecryptionShare) Unmarshal(data []byte) error { if err != nil { return fmt.Errorf("cannot parse group: %w", err) } - if d.e_i, err = unmarshalScalar(d.group, raw.E_i); err != nil { + d.e_i = d.group.Scalar() + if err = d.e_i.UnmarshalBinary(raw.E_i); err != nil { return fmt.Errorf("cannot unmarshal E_i: %w", err) } - if d.u_i, err = unmarshalPoint(d.group, raw.U_i); err != nil { + d.u_i = d.group.Point() + if err = d.u_i.UnmarshalBinary(raw.U_i); err != nil { return fmt.Errorf("cannot unmarshal U_i: %w", err) } - if d.f_i, err = unmarshalScalar(d.group, raw.F_i); err != nil { + d.f_i = d.group.Scalar() + if err = d.f_i.UnmarshalBinary(raw.F_i); err != nil { return fmt.Errorf("cannot unmarshal F_i: %w", err) } return nil @@ -628,7 +642,7 @@ func hash(msg []byte) []byte { } // hash1 is an implementation of the H_1 hash function (see p15 of the paper). -func hash1(group string, g kyber.Point) ([]byte, error) { +func hash1(group string, g group.Point) ([]byte, error) { point, err := concatenate(group, g) if err != nil { return nil, fmt.Errorf("cannot concatenate points: %w", err) @@ -637,7 +651,7 @@ func hash1(group string, g kyber.Point) ([]byte, error) { } // hash2 is an implementation of the H_2 hash function (see p15 of the paper). -func hash2(msg, label []byte, g1, g2, g3, g4 kyber.Point, group kyber.Group) (kyber.Scalar, error) { +func hash2(msg, label []byte, g1, g2, g3, g4 group.Point, group group.Group) (group.Scalar, error) { if len(msg) != len(label) || len(msg) != InputSize { return nil, fmt.Errorf("message and label must be %dB long", InputSize) } @@ -655,7 +669,7 @@ func hash2(msg, label []byte, g1, g2, g3, g4 kyber.Point, group kyber.Group) (ky } // hash4 is an implementation of the H_4 hash function (see p15 of the paper). -func hash4(g1, g2, g3 kyber.Point, group kyber.Group) (kyber.Scalar, error) { +func hash4(g1, g2, g3 group.Point, group group.Group) (group.Scalar, error) { points, err := concatenate(group.String(), g1, g2, g3) if err != nil { return nil, fmt.Errorf("cannot concatenate points: %w", err) @@ -667,7 +681,7 @@ func hash4(g1, g2, g3 kyber.Point, group kyber.Group) (kyber.Scalar, error) { // concatenate marshals and concatenates points (elements of a group). It is // used in hash functions. -func concatenate(group string, points ...kyber.Point) ([]byte, error) { +func concatenate(group string, points ...group.Point) ([]byte, error) { final := group for _, point := range points { p, err := point.MarshalBinary() @@ -690,27 +704,3 @@ func xor(a, b []byte) ([]byte, error) { } return buf, nil } - -// unmarshalPoint unmarshals point safely, i.e., avoiding panics present in the kyber lib. -func unmarshalPoint(g kyber.Group, data []byte) (kyber.Point, error) { - if len(data) != g.PointLen() { - return nil, fmt.Errorf("incorrect length") - } - p := g.Point() - if err := p.UnmarshalBinary(data); err != nil { - return nil, err - } - return p, nil -} - -// unmarshalScalar unmarshals scalar safely, i.e., avoiding panics present in the kyber lib. -func unmarshalScalar(g kyber.Group, data []byte) (kyber.Scalar, error) { - if len(data) != g.ScalarLen() { - return nil, fmt.Errorf("incorrect length") - } - s := g.Scalar() - if err := s.UnmarshalBinary(data); err != nil { - return nil, err - } - return s, nil -} diff --git a/go/tdh2/tdh2/tdh2_test.go b/go/tdh2/tdh2/tdh2_test.go index 5d4e571..e315f6d 100644 --- a/go/tdh2/tdh2/tdh2_test.go +++ b/go/tdh2/tdh2/tdh2_test.go @@ -2,6 +2,7 @@ package tdh2 import ( "bytes" + "crypto/aes" "crypto/cipher" "crypto/rand" "encoding/json" @@ -11,20 +12,43 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "go.dedis.ch/kyber/v3" - "go.dedis.ch/kyber/v3/group/nist" - "go.dedis.ch/kyber/v3/xof/keccak" + "github.com/goplugin/tdh2/go/tdh2/internal/group" + "github.com/goplugin/tdh2/go/tdh2/internal/group/nist" ) var supportedGroups = []string{ - nist.NewBlakeSHA256P256().String(), + nist.NewP256().String(), + nist.NewP384().String(), + nist.NewP521().String(), +} + +// unsupported implements an unsupported group +type unsupported nist.P256 + +func (u *unsupported) String() string { + return "unsupported" +} +func newUnsupported() *unsupported { + return (*unsupported)(nist.NewP256()) } type common interface { Fatalf(format string, args ...interface{}) } -func params(t common, group string) (kyber.Group, cipher.Stream, []byte, []byte) { +func randStream(t common) cipher.Stream { + block, err := aes.NewCipher(make([]byte, 16)) + if err != nil { + t.Fatalf("NewCipher: %v", err) + } + iv := make([]byte, aes.BlockSize) + if _, err := rand.Read(iv); err != nil { + t.Fatalf("Read: %w", err) + } + return cipher.NewCTR(block, iv) +} + +func params(t common, group string) (group.Group, cipher.Stream, []byte, []byte) { if _, ok := t.(*testing.T); ok { t.(*testing.T).Helper() } @@ -44,7 +68,7 @@ func params(t common, group string) (kyber.Group, cipher.Stream, []byte, []byte) if err != nil { t.Fatalf("parseGroup: %v", err) } - return g, keccak.New(seed), msg, label + return g, randStream(t), msg, label } func TestConcatenate(t *testing.T) { @@ -216,7 +240,7 @@ func TestGenerateKeys(t *testing.T) { { name: "secret wrong group", ms: &MasterSecret{ - group: nist.NewBlakeSHA256QR512(), + group: newUnsupported(), s: group.Scalar().Pick(rand)}, k: 1, n: 1, @@ -312,7 +336,7 @@ func TestEncrypt(t *testing.T) { } func TestDecrypt(t *testing.T) { - wrong := nist.NewBlakeSHA256QR512() + wrong := newUnsupported() for _, typ := range supportedGroups { group, rand, msg, label := params(t, typ) _, pk, shares, err := GenerateKeys(group, nil, 3, 5, rand) @@ -391,7 +415,7 @@ func TestCtxtVerify(t *testing.T) { { name: "wrong group", ctxt: &Ciphertext{ - group: nist.NewBlakeSHA256QR512(), + group: newUnsupported(), c: ctxt.c, label: ctxt.label, u: ctxt.u, @@ -593,7 +617,7 @@ func TestVerifyShare(t *testing.T) { if err != nil { t.Fatalf("GenerateKeys: %v", err) } - wrong := nist.NewBlakeSHA256QR512() + wrong := newUnsupported() ctxt, err := Encrypt(pk, msg, label, rand) if err != nil { t.Fatalf("Encrypt: %v", err) @@ -695,7 +719,7 @@ func TestCombineShares(t *testing.T) { if err != nil { t.Fatalf("GenerateKeys: %v", err) } - wrong := nist.NewBlakeSHA256QR512() + wrong := newUnsupported() _, pkWrong, _, err := GenerateKeys(group, nil, 3, 5, rand) if err != nil { t.Fatalf("GenerateKeys: %v", err) @@ -808,10 +832,20 @@ func TestCombineShares(t *testing.T) { func TestParseGroup(t *testing.T) { for _, tc := range []struct { group string + want group.Group err error }{ { - group: nist.NewBlakeSHA256P256().String(), + group: nist.NewP256().String(), + want: nist.NewP256(), + }, + { + group: nist.NewP384().String(), + want: nist.NewP384(), + }, + { + group: nist.NewP521().String(), + want: nist.NewP521(), }, { group: "wrong", @@ -819,8 +853,14 @@ func TestParseGroup(t *testing.T) { }, } { t.Run(fmt.Sprintf("group=%v", tc.group), func(t *testing.T) { - if _, err := parseGroup(tc.group); !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + got, err := parseGroup(tc.group) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { t.Errorf("got err=%v, want=%v", err, tc.err) + } else if err != nil { + return + } + if reflect.TypeOf(got) != reflect.TypeOf(tc.want) { + t.Errorf("got %T, want %T", got, tc.want) } }) } @@ -913,13 +953,13 @@ func TestPublicKeyMarshal(t *testing.T) { group: g, g_bar: g.Point().Pick(r), h: g.Point().Pick(r), - hArray: []kyber.Point{g.Point().Pick(r)}, + hArray: []group.Point{g.Point().Pick(r)}, }, { group: g, g_bar: g.Point().Pick(r), h: g.Point().Pick(r), - hArray: []kyber.Point{g.Point().Pick(r), g.Point().Pick(r), g.Point().Pick(r)}, + hArray: []group.Point{g.Point().Pick(r), g.Point().Pick(r), g.Point().Pick(r)}, }, } { t.Run(fmt.Sprintf("i=%d group=%v", i, typ), func(t *testing.T) { @@ -1531,7 +1571,7 @@ func TestRedeal(t *testing.T) { k: 2, n: 5, ms: &MasterSecret{ - group: nist.NewBlakeSHA256QR512(), + group: newUnsupported(), s: ms.s, }, err: cmpopts.AnyError, @@ -1746,7 +1786,7 @@ func FuzzPrivateShareMarshal(f *testing.F) { f.Add(mustMarshal(f, PrivateShare{ group: g, index: 123, - v: g.Scalar().Pick(keccak.New(nil)), + v: g.Scalar().Pick(randStream(f)), })) } f.Fuzz(func(t *testing.T, data []byte) { @@ -1775,7 +1815,7 @@ func FuzzPrivateShareMarshal(f *testing.F) { } func FuzzPublicKeyMarshal(f *testing.F) { - r := keccak.New(nil) + r := randStream(f) for _, groupStr := range supportedGroups { g, err := parseGroup(groupStr) if err != nil { @@ -1790,13 +1830,13 @@ func FuzzPublicKeyMarshal(f *testing.F) { group: g, g_bar: g.Point().Pick(r), h: g.Point().Pick(r), - hArray: []kyber.Point{g.Point().Pick(r)}, + hArray: []group.Point{g.Point().Pick(r)}, })) f.Add(mustMarshal(f, PublicKey{ group: g, g_bar: g.Point().Pick(r), h: g.Point().Pick(r), - hArray: []kyber.Point{g.Point().Pick(r), g.Point().Pick(r)}, + hArray: []group.Point{g.Point().Pick(r), g.Point().Pick(r)}, })) } f.Fuzz(func(t *testing.T, data []byte) { @@ -1825,7 +1865,7 @@ func FuzzPublicKeyMarshal(f *testing.F) { } func FuzzCiphertextMarshal(f *testing.F) { - r := keccak.New(nil) + r := randStream(f) for _, groupStr := range supportedGroups { g, err := parseGroup(groupStr) if err != nil { @@ -1867,7 +1907,7 @@ func FuzzCiphertextMarshal(f *testing.F) { } func FuzzDecryptionShareMarshal(f *testing.F) { - r := keccak.New(nil) + r := randStream(f) for _, groupStr := range supportedGroups { g, err := parseGroup(groupStr) if err != nil { diff --git a/js/tdh2/test/js_test.go b/go/tdh2/tdh2easy/js_test.go similarity index 78% rename from js/tdh2/test/js_test.go rename to go/tdh2/tdh2easy/js_test.go index ee8148e..219f7ea 100644 --- a/js/tdh2/test/js_test.go +++ b/go/tdh2/tdh2easy/js_test.go @@ -1,16 +1,16 @@ -package test +package tdh2easy import ( "bytes" "encoding/base64" "os/exec" "testing" - - "github.com/goplugin/tdh2/go/tdh2easy" ) +const jsTestPath = "../../../js/tdh2/test/test.js" + func TestJS(t *testing.T) { - _, pk, sh, err := tdh2easy.GenerateKeys(2, 3) + _, pk, sh, err := GenerateKeys(2, 3) if err != nil { t.Fatalf("GenerateKeys: %v", err) } @@ -19,7 +19,7 @@ func TestJS(t *testing.T) { t.Fatalf("Marshal: %v", err) } - cmdArgs := []string{"test.js", string(b)} + cmdArgs := []string{jsTestPath, string(b)} cmd := exec.Command("node", cmdArgs...) output, err := cmd.CombinedOutput() if err != nil { @@ -36,19 +36,19 @@ func TestJS(t *testing.T) { if err != nil { t.Fatalf("b64Decode: %v", err) } - var c tdh2easy.Ciphertext + var c Ciphertext if err := c.UnmarshalVerify(pairs[2*i+1], pk); err != nil { t.Fatalf("Unmarshal: %v", err) } - dec := []*tdh2easy.DecryptionShare{} + dec := []*DecryptionShare{} for _, s := range sh { - d, err := tdh2easy.Decrypt(&c, s) + d, err := Decrypt(&c, s) if err != nil { t.Fatalf("Decrypt: %v", err) } dec = append(dec, d) } - got, err := tdh2easy.Aggregate(&c, dec, 3) + got, err := Aggregate(&c, dec, 3) if err != nil { t.Fatalf("Aggregate: %v", err) } diff --git a/go/tdh2/tdh2easy/tdh2easy.go b/go/tdh2/tdh2easy/tdh2easy.go index 82fd20d..d93bfe2 100644 --- a/go/tdh2/tdh2easy/tdh2easy.go +++ b/go/tdh2/tdh2easy/tdh2easy.go @@ -2,14 +2,14 @@ package tdh2easy import ( + "crypto/aes" + "crypto/cipher" "crypto/rand" "encoding/json" "fmt" + "github.com/goplugin/tdh2/go/tdh2/internal/group/nist" "github.com/goplugin/tdh2/go/tdh2/tdh2" - "go.dedis.ch/kyber/v3" - "go.dedis.ch/kyber/v3/group/nist" - "go.dedis.ch/kyber/v3/xof/keccak" ) // key size used in symmetric encryption (AES). 256 bits is a higher securitylevel than provided @@ -17,7 +17,7 @@ import ( const aes256KeySize = 32 // defaultGroup is the default EC group used. -var defaultGroup = nist.NewBlakeSHA256P256() +var defaultGroup = nist.NewP256() // PrivateShare encodes TDH2 private share. type PrivateShare struct { @@ -94,11 +94,11 @@ type Ciphertext struct { // Decrypt returns a decryption share for the ciphertext. func Decrypt(c *Ciphertext, x_i *PrivateShare) (*DecryptionShare, error) { - xof, err := xof() + r, err := randStream() if err != nil { return nil, err } - d, err := c.tdh2Ctxt.Decrypt(defaultGroup, x_i.p, xof) + d, err := c.tdh2Ctxt.Decrypt(defaultGroup, x_i.p, r) if err != nil { return nil, err } @@ -129,13 +129,21 @@ func Aggregate(c *Ciphertext, shares []*DecryptionShare, n int) ([]byte, error) return symDecrypt(c.nonce, c.symCtxt, key) } -// xof returns xof used for providing randomness. -func xof() (kyber.XOF, error) { - seed := make([]byte, 64) - if _, err := rand.Read(seed); err != nil { - return nil, fmt.Errorf("cannot generate seed: %w", err) +// randStream returns a stream cipher used for providing randomness. +func randStream() (cipher.Stream, error) { + key := make([]byte, aes256KeySize) + if _, err := rand.Read(key); err != nil { + return nil, fmt.Errorf("cannot generate key: %w", err) + } + iv := make([]byte, aes.BlockSize) + if _, err := rand.Read(iv); err != nil { + return nil, fmt.Errorf("cannot generate iv: %w", err) + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("cannot init aes: %w", err) } - return keccak.New(seed), nil + return cipher.NewCTR(block, iv), nil } type ciphertextRaw struct { @@ -178,11 +186,11 @@ func (c *Ciphertext) UnmarshalVerify(data []byte, pk *PublicKey) error { // GenerateKeys generates and returns, the master secret, public key, and private shares. It takes the // total number of nodes n and a threshold k (the number of shares sufficient for decryption). func GenerateKeys(k, n int) (*MasterSecret, *PublicKey, []*PrivateShare, error) { - xof, err := xof() + r, err := randStream() if err != nil { return nil, nil, nil, err } - ms, pk, sh, err := tdh2.GenerateKeys(defaultGroup, nil, k, n, xof) + ms, pk, sh, err := tdh2.GenerateKeys(defaultGroup, nil, k, n, r) if err != nil { return nil, nil, nil, err } @@ -200,11 +208,11 @@ func GenerateKeys(k, n int) (*MasterSecret, *PublicKey, []*PrivateShare, error) // The old public key can still be used for encryption but it cannot be used for share // verification (the new key has to be used instead). func Redeal(pk *PublicKey, ms *MasterSecret, k, n int) (*PublicKey, []*PrivateShare, error) { - xof, err := xof() + r, err := randStream() if err != nil { return nil, nil, err } - p, sh, err := tdh2.Redeal(pk.p, ms.m, k, n, xof) + p, sh, err := tdh2.Redeal(pk.p, ms.m, k, n, r) if err != nil { return nil, nil, err } @@ -234,12 +242,12 @@ func Encrypt(pk *PublicKey, msg []byte) (*Ciphertext, error) { return nil, fmt.Errorf("cannot encrypt message: %w", err) } - xof, err := xof() + r, err := randStream() if err != nil { return nil, err } // encrypt the key with TDH2 using empty label - tdh2Ctxt, err := tdh2.Encrypt(pk.p, key, make([]byte, tdh2.InputSize), xof) + tdh2Ctxt, err := tdh2.Encrypt(pk.p, key, make([]byte, tdh2.InputSize), r) if err != nil { return nil, fmt.Errorf("cannot TDH2 encrypt: %w", err) } diff --git a/go/tdh2/tdh2easy/tdh2easy_test.go b/go/tdh2/tdh2easy/tdh2easy_test.go index 71216e3..0c1d0da 100644 --- a/go/tdh2/tdh2easy/tdh2easy_test.go +++ b/go/tdh2/tdh2easy/tdh2easy_test.go @@ -8,9 +8,8 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/goplugin/tdh2/go/tdh2/internal/group/nist" "github.com/goplugin/tdh2/go/tdh2/tdh2" - "go.dedis.ch/kyber/v3/group/nist" - "go.dedis.ch/kyber/v3/xof/keccak" ) func TestShareIndex(t *testing.T) { @@ -135,7 +134,11 @@ func TestCiphertextDecrypt(t *testing.T) { if err != nil { t.Fatalf("GenerateKeys: %v", err) } - _, _, wrong, err := tdh2.GenerateKeys(nist.NewBlakeSHA256QR512(), nil, 1, 1, keccak.New(nil)) + r, err := randStream() + if err != nil { + t.Fatalf("RandStream: %v", err) + } + _, _, wrong, err := tdh2.GenerateKeys(nist.NewP521(), nil, 1, 1, r) if err != nil { t.Fatalf("GenerateKeys: %v", err) } @@ -507,9 +510,9 @@ func FuzzCiphertextMarshal(f *testing.F) { if err != nil { f.Fatalf("Keys: %v", err) } - xof, err := xof() + r, err := randStream() if err != nil { - f.Fatalf("xof: %v", err) + f.Fatalf("randStream: %v", err) } tdh2Input := make([]byte, tdh2.InputSize) f.Add(tdh2Input, []byte("symcCtxt"), []byte("nonce")) @@ -517,7 +520,7 @@ func FuzzCiphertextMarshal(f *testing.F) { if len(key) != tdh2.InputSize { t.Skip() } - tdh2Ctxt, err := tdh2.Encrypt(pk.p, key, tdh2Input, xof) + tdh2Ctxt, err := tdh2.Encrypt(pk.p, key, tdh2Input, r) if err != nil { t.Fatalf("Encrypt(%v): %v", key, err) } @@ -542,7 +545,11 @@ func FuzzCiphertextUnmarshal(f *testing.F) { if err != nil { f.Fatalf("Keys: %v", err) } - tdh2Ctxt, err := tdh2.Encrypt(pk.p, make([]byte, tdh2.InputSize), make([]byte, tdh2.InputSize), keccak.New(nil)) + r, err := randStream() + if err != nil { + f.Fatalf("ranStream: %v", err) + } + tdh2Ctxt, err := tdh2.Encrypt(pk.p, make([]byte, tdh2.InputSize), make([]byte, tdh2.InputSize), r) if err != nil { f.Fatalf("Encrypt: %v", err) }