From 4bd13c6fdde831383cf976eb8363340579edcd91 Mon Sep 17 00:00:00 2001 From: Samuel Laferriere Date: Thu, 30 May 2024 23:29:35 +0800 Subject: [PATCH] Update blsagg service hash fn to return error (#257) * update blsagg hashfn to return error * make mocks * moved hashfunction error log inside verifySignature for consistency --- services/bls_aggregation/blsagg.go | 19 ++++- services/bls_aggregation/blsagg_test.go | 92 +++++++++++++++---------- services/mocks/blsagg/blsaggregation.go | 2 +- types/avs.go | 2 +- 4 files changed, 73 insertions(+), 42 deletions(-) diff --git a/services/bls_aggregation/blsagg.go b/services/bls_aggregation/blsagg.go index ddd75cc5..9514bb2b 100644 --- a/services/bls_aggregation/blsagg.go +++ b/services/bls_aggregation/blsagg.go @@ -35,6 +35,9 @@ var ( OperatorNotPartOfTaskQuorumErrorFn = func(operatorId types.OperatorId, taskIndex types.TaskIndex) error { return fmt.Errorf("operator %x not part of task %d's quorum", operatorId, taskIndex) } + HashFunctionError = func(err error) error { + return fmt.Errorf("Failed to hash task response: %w", err) + } SignatureVerificationError = func(err error) error { return fmt.Errorf("Failed to verify signature: %w", err) } @@ -261,14 +264,20 @@ func (a *BlsAggregatorService) singleTaskAggregatorGoroutineFunc( select { case signedTaskResponseDigest := <-signedTaskRespsC: a.logger.Debug("Task goroutine received new signed task response digest", "taskIndex", taskIndex, "signedTaskResponseDigest", signedTaskResponseDigest) - // compute the taskResponseDigest using the hash function - taskResponseDigest := a.hashFunction(signedTaskResponseDigest.TaskResponse) err := a.verifySignature(taskIndex, signedTaskResponseDigest, operatorsAvsStateDict) signedTaskResponseDigest.SignatureVerificationErrorC <- err if err != nil { continue } + + // compute the taskResponseDigest using the hash function + taskResponseDigest, err := a.hashFunction(signedTaskResponseDigest.TaskResponse) + if err != nil { + // this error should never happen, because we've already hashed the taskResponse in verifySignature, + // but keeping here in case the verifySignature implementation ever changes or some catastrophic bug happens.. + continue + } // after verifying signature we aggregate its sig and pubkey, and update the signed stake amount digestAggregatedOperators, ok := aggregatedOperatorsDict[taskResponseDigest] if !ok { @@ -382,7 +391,11 @@ func (a *BlsAggregatorService) verifySignature( return OperatorNotPartOfTaskQuorumErrorFn(signedTaskResponseDigest.OperatorId, taskIndex) } - taskResponseDigest := a.hashFunction(signedTaskResponseDigest.TaskResponse) + taskResponseDigest, err := a.hashFunction(signedTaskResponseDigest.TaskResponse) + if err != nil { + a.logger.Error("Failed to hash task response, skipping.", "taskIndex", taskIndex, "signedTaskResponseDigest", signedTaskResponseDigest, "err", err) + return HashFunctionError(err) + } // verify that the msg actually came from the correct operator operatorG2Pubkey := operatorsAvsStateDict[signedTaskResponseDigest.OperatorId].OperatorInfo.Pubkeys.G2Pubkey diff --git a/services/bls_aggregation/blsagg_test.go b/services/bls_aggregation/blsagg_test.go index 616c6754..21e565ef 100644 --- a/services/bls_aggregation/blsagg_test.go +++ b/services/bls_aggregation/blsagg_test.go @@ -25,22 +25,22 @@ func TestBlsAgg(t *testing.T) { // 1 second seems to be enough for tests to pass. Currently takes 5s to run all tests tasksTimeToExpiry := 1 * time.Second - hashFunction := func(taskResponse types.TaskResponse) types.TaskResponseDigest { + hashFunction := func(taskResponse types.TaskResponse) (types.TaskResponseDigest, error) { taskResponseBytes, err := json.Marshal(taskResponse) if err != nil { - panic(err) + return types.TaskResponseDigest{}, err } - return types.TaskResponseDigest(sha256.Sum256(taskResponseBytes)) + return types.TaskResponseDigest(sha256.Sum256(taskResponseBytes)), nil } - wrongHashFunction := func(taskResponse types.TaskResponse) types.TaskResponseDigest { + wrongHashFunction := func(taskResponse types.TaskResponse) (types.TaskResponseDigest, error) { taskResponseBytes, err := json.Marshal(taskResponse) if err != nil { - panic(err) + return types.TaskResponseDigest{}, err } // append something to the taskResponseBytes to make it different taskResponseBytes = append(taskResponseBytes, []byte("something")...) - return types.TaskResponseDigest(sha256.Sum256(taskResponseBytes)) + return types.TaskResponseDigest(sha256.Sum256(taskResponseBytes)), nil } type mockTaskResponse struct { @@ -60,7 +60,8 @@ func TestBlsAgg(t *testing.T) { taskResponse := mockTaskResponse{123} // Initialize with appropriate data // Compute the TaskResponseDigest as the SHA-256 sum of the TaskResponse - taskResponseDigest := hashFunction(taskResponse) + taskResponseDigest, err := hashFunction(taskResponse) + require.Nil(t, err) blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) @@ -68,7 +69,7 @@ func TestBlsAgg(t *testing.T) { noopLogger := logging.NewNoopLogger() blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) - err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) require.Nil(t, err) @@ -108,13 +109,14 @@ func TestBlsAgg(t *testing.T) { quorumThresholdPercentages := []types.QuorumThresholdPercentage{100} taskResponse := mockTaskResponse{123} - taskResponseDigest := hashFunction(taskResponse) + taskResponseDigest, err := hashFunction(taskResponse) + require.Nil(t, err) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3}) noopLogger := logging.NewNoopLogger() blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) - err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) @@ -163,13 +165,14 @@ func TestBlsAgg(t *testing.T) { quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} taskResponse := mockTaskResponse{123} - taskResponseDigest := hashFunction(taskResponse) + taskResponseDigest, err := hashFunction(taskResponse) + require.Nil(t, err) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) - err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) @@ -217,12 +220,14 @@ func TestBlsAgg(t *testing.T) { // initialize 2 concurrent tasks task1Index := types.TaskIndex(1) task1Response := mockTaskResponse{123} - task1ResponseDigest := hashFunction(task1Response) - err := blsAggServ.InitializeNewTask(task1Index, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + task1ResponseDigest, err := hashFunction(task1Response) + require.Nil(t, err) + err = blsAggServ.InitializeNewTask(task1Index, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) task2Index := types.TaskIndex(2) task2Response := mockTaskResponse{234} - task2ResponseDigest := hashFunction(task2Response) + task2ResponseDigest, err := hashFunction(task2Response) + require.Nil(t, err) err = blsAggServ.InitializeNewTask(task2Index, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) @@ -318,7 +323,8 @@ func TestBlsAgg(t *testing.T) { quorumNumbers := types.QuorumNums{0} quorumThresholdPercentages := []types.QuorumThresholdPercentage{50} taskResponse := mockTaskResponse{123} - taskResponseDigest := hashFunction(taskResponse) + taskResponseDigest, err := hashFunction(taskResponse) + require.Nil(t, err) blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) blockNum := uint32(1) @@ -326,7 +332,7 @@ func TestBlsAgg(t *testing.T) { noopLogger := logging.NewNoopLogger() blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) - err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) require.Nil(t, err) @@ -360,14 +366,15 @@ func TestBlsAgg(t *testing.T) { quorumNumbers := types.QuorumNums{0} quorumThresholdPercentages := []types.QuorumThresholdPercentage{60} taskResponse := mockTaskResponse{123} - taskResponseDigest := hashFunction(taskResponse) + taskResponseDigest, err := hashFunction(taskResponse) + require.Nil(t, err) blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) - err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) require.Nil(t, err) @@ -395,14 +402,15 @@ func TestBlsAgg(t *testing.T) { quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} taskResponse := mockTaskResponse{123} - taskResponseDigest := hashFunction(taskResponse) + taskResponseDigest, err := hashFunction(taskResponse) + require.Nil(t, err) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) - err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) @@ -450,14 +458,15 @@ func TestBlsAgg(t *testing.T) { quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{50, 50} taskResponse := mockTaskResponse{123} - taskResponseDigest := hashFunction(taskResponse) + taskResponseDigest, err := hashFunction(taskResponse) + require.Nil(t, err) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3}) noopLogger := logging.NewNoopLogger() blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) - err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) @@ -507,14 +516,15 @@ func TestBlsAgg(t *testing.T) { quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{60, 60} taskResponse := mockTaskResponse{123} - taskResponseDigest := hashFunction(taskResponse) + taskResponseDigest, err := hashFunction(taskResponse) + require.Nil(t, err) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3}) noopLogger := logging.NewNoopLogger() blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) - err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) @@ -541,14 +551,15 @@ func TestBlsAgg(t *testing.T) { quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} taskResponse := mockTaskResponse{123} - taskResponseDigest := hashFunction(taskResponse) + taskResponseDigest, err := hashFunction(taskResponse) + require.Nil(t, err) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) noopLogger := logging.NewNoopLogger() blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) - err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) @@ -577,14 +588,15 @@ func TestBlsAgg(t *testing.T) { quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} taskResponse := mockTaskResponse{123} - taskResponseDigest := hashFunction(taskResponse) + taskResponseDigest, err := hashFunction(taskResponse) + require.Nil(t, err) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) - err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) @@ -606,14 +618,15 @@ func TestBlsAgg(t *testing.T) { blockNum := uint32(1) taskIndex := types.TaskIndex(0) taskResponse := mockTaskResponse{123} - taskResponseDigest := hashFunction(taskResponse) + taskResponseDigest, err := hashFunction(taskResponse) + require.Nil(t, err) blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) noopLogger := logging.NewNoopLogger() blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) - err := blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) require.Equal(t, TaskNotFoundErrorFn(taskIndex), err) }) @@ -642,13 +655,15 @@ func TestBlsAgg(t *testing.T) { err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) taskResponse1 := mockTaskResponse{1} - taskResponseDigest1 := hashFunction(taskResponse1) + taskResponseDigest1, err := hashFunction(taskResponse1) + require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest1) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse1, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) taskResponse2 := mockTaskResponse{2} - taskResponseDigest2 := hashFunction(taskResponse2) + taskResponseDigest2, err := hashFunction(taskResponse2) + require.Nil(t, err) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest2) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -693,12 +708,14 @@ func TestBlsAgg(t *testing.T) { err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) taskResponse1 := mockTaskResponse{1} - taskResponseDigest1 := hashFunction(taskResponse1) + taskResponseDigest1, err := hashFunction(taskResponse1) + require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest1) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse1, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) taskResponse2 := mockTaskResponse{2} - taskResponseDigest2 := hashFunction(taskResponse2) + taskResponseDigest2, err := hashFunction(taskResponse2) + require.Nil(t, err) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest2) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse2, blsSigOp2, testOperator2.OperatorId) require.Nil(t, err) @@ -721,7 +738,8 @@ func TestBlsAgg(t *testing.T) { quorumThresholdPercentages := []types.QuorumThresholdPercentage{100} taskResponse := mockTaskResponse{123} // Initialize with appropriate data - taskResponseDigest := wrongHashFunction(taskResponse) + taskResponseDigest, err := wrongHashFunction(taskResponse) + require.Nil(t, err) blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) @@ -729,7 +747,7 @@ func TestBlsAgg(t *testing.T) { noopLogger := logging.NewNoopLogger() blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) - err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) require.EqualError(t, err, "Signature verification failed. Incorrect Signature.") diff --git a/services/mocks/blsagg/blsaggregation.go b/services/mocks/blsagg/blsaggregation.go index a2c3f629..418bf836 100644 --- a/services/mocks/blsagg/blsaggregation.go +++ b/services/mocks/blsagg/blsaggregation.go @@ -72,7 +72,7 @@ func (mr *MockBlsAggregationServiceMockRecorder) InitializeNewTask(arg0, arg1, a } // ProcessNewSignature mocks base method. -func (m *MockBlsAggregationService) ProcessNewSignature(arg0 context.Context, arg1 uint32, arg2 types.Bytes32, arg3 *bls.Signature, arg4 types.Bytes32) error { +func (m *MockBlsAggregationService) ProcessNewSignature(arg0 context.Context, arg1 uint32, arg2 any, arg3 *bls.Signature, arg4 types.Bytes32) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ProcessNewSignature", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(error) diff --git a/types/avs.go b/types/avs.go index 22787013..12c4cd26 100644 --- a/types/avs.go +++ b/types/avs.go @@ -10,7 +10,7 @@ type TaskIndex = uint32 type TaskResponseDigest = Bytes32 type TaskResponse = interface{} -type TaskResponseHashFunction func(taskResponse TaskResponse) TaskResponseDigest +type TaskResponseHashFunction func(taskResponse TaskResponse) (TaskResponseDigest, error) type SignedTaskResponseDigest struct { TaskResponse TaskResponse