Skip to content

Commit

Permalink
Update blsagg service hash fn to return error (#257)
Browse files Browse the repository at this point in the history
* update blsagg hashfn to return error

* make mocks

* moved hashfunction error log inside verifySignature for consistency
  • Loading branch information
samlaf authored May 30, 2024
1 parent 589b2f8 commit 4bd13c6
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 42 deletions.
19 changes: 16 additions & 3 deletions services/bls_aggregation/blsagg.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ var (
OperatorNotPartOfTaskQuorumErrorFn = func(operatorId types.OperatorId, taskIndex types.TaskIndex) error {
return fmt.Errorf("operator %x not part of task %d's quorum", operatorId, taskIndex)
}
HashFunctionError = func(err error) error {
return fmt.Errorf("Failed to hash task response: %w", err)
}
SignatureVerificationError = func(err error) error {
return fmt.Errorf("Failed to verify signature: %w", err)
}
Expand Down Expand Up @@ -261,14 +264,20 @@ func (a *BlsAggregatorService) singleTaskAggregatorGoroutineFunc(
select {
case signedTaskResponseDigest := <-signedTaskRespsC:
a.logger.Debug("Task goroutine received new signed task response digest", "taskIndex", taskIndex, "signedTaskResponseDigest", signedTaskResponseDigest)
// compute the taskResponseDigest using the hash function
taskResponseDigest := a.hashFunction(signedTaskResponseDigest.TaskResponse)

err := a.verifySignature(taskIndex, signedTaskResponseDigest, operatorsAvsStateDict)
signedTaskResponseDigest.SignatureVerificationErrorC <- err
if err != nil {
continue
}

// compute the taskResponseDigest using the hash function
taskResponseDigest, err := a.hashFunction(signedTaskResponseDigest.TaskResponse)
if err != nil {
// this error should never happen, because we've already hashed the taskResponse in verifySignature,
// but keeping here in case the verifySignature implementation ever changes or some catastrophic bug happens..
continue
}
// after verifying signature we aggregate its sig and pubkey, and update the signed stake amount
digestAggregatedOperators, ok := aggregatedOperatorsDict[taskResponseDigest]
if !ok {
Expand Down Expand Up @@ -382,7 +391,11 @@ func (a *BlsAggregatorService) verifySignature(
return OperatorNotPartOfTaskQuorumErrorFn(signedTaskResponseDigest.OperatorId, taskIndex)
}

taskResponseDigest := a.hashFunction(signedTaskResponseDigest.TaskResponse)
taskResponseDigest, err := a.hashFunction(signedTaskResponseDigest.TaskResponse)
if err != nil {
a.logger.Error("Failed to hash task response, skipping.", "taskIndex", taskIndex, "signedTaskResponseDigest", signedTaskResponseDigest, "err", err)
return HashFunctionError(err)
}

// verify that the msg actually came from the correct operator
operatorG2Pubkey := operatorsAvsStateDict[signedTaskResponseDigest.OperatorId].OperatorInfo.Pubkeys.G2Pubkey
Expand Down
92 changes: 55 additions & 37 deletions services/bls_aggregation/blsagg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,22 @@ func TestBlsAgg(t *testing.T) {
// 1 second seems to be enough for tests to pass. Currently takes 5s to run all tests
tasksTimeToExpiry := 1 * time.Second

hashFunction := func(taskResponse types.TaskResponse) types.TaskResponseDigest {
hashFunction := func(taskResponse types.TaskResponse) (types.TaskResponseDigest, error) {
taskResponseBytes, err := json.Marshal(taskResponse)
if err != nil {
panic(err)
return types.TaskResponseDigest{}, err
}
return types.TaskResponseDigest(sha256.Sum256(taskResponseBytes))
return types.TaskResponseDigest(sha256.Sum256(taskResponseBytes)), nil
}

wrongHashFunction := func(taskResponse types.TaskResponse) types.TaskResponseDigest {
wrongHashFunction := func(taskResponse types.TaskResponse) (types.TaskResponseDigest, error) {
taskResponseBytes, err := json.Marshal(taskResponse)
if err != nil {
panic(err)
return types.TaskResponseDigest{}, err
}
// append something to the taskResponseBytes to make it different
taskResponseBytes = append(taskResponseBytes, []byte("something")...)
return types.TaskResponseDigest(sha256.Sum256(taskResponseBytes))
return types.TaskResponseDigest(sha256.Sum256(taskResponseBytes)), nil
}

type mockTaskResponse struct {
Expand All @@ -60,15 +60,16 @@ func TestBlsAgg(t *testing.T) {
taskResponse := mockTaskResponse{123} // Initialize with appropriate data

// Compute the TaskResponseDigest as the SHA-256 sum of the TaskResponse
taskResponseDigest := hashFunction(taskResponse)
taskResponseDigest, err := hashFunction(taskResponse)
require.Nil(t, err)

blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest)

fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1})
noopLogger := logging.NewNoopLogger()
blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger)

err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
require.Nil(t, err)
err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId)
require.Nil(t, err)
Expand Down Expand Up @@ -108,13 +109,14 @@ func TestBlsAgg(t *testing.T) {
quorumThresholdPercentages := []types.QuorumThresholdPercentage{100}
taskResponse := mockTaskResponse{123}

taskResponseDigest := hashFunction(taskResponse)
taskResponseDigest, err := hashFunction(taskResponse)
require.Nil(t, err)

fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3})
noopLogger := logging.NewNoopLogger()
blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger)

err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
require.Nil(t, err)
blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest)
err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId)
Expand Down Expand Up @@ -163,13 +165,14 @@ func TestBlsAgg(t *testing.T) {
quorumNumbers := types.QuorumNums{0, 1}
quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100}
taskResponse := mockTaskResponse{123}
taskResponseDigest := hashFunction(taskResponse)
taskResponseDigest, err := hashFunction(taskResponse)
require.Nil(t, err)

fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2})
noopLogger := logging.NewNoopLogger()
blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger)

err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
require.Nil(t, err)
blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest)
err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId)
Expand Down Expand Up @@ -217,12 +220,14 @@ func TestBlsAgg(t *testing.T) {
// initialize 2 concurrent tasks
task1Index := types.TaskIndex(1)
task1Response := mockTaskResponse{123}
task1ResponseDigest := hashFunction(task1Response)
err := blsAggServ.InitializeNewTask(task1Index, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
task1ResponseDigest, err := hashFunction(task1Response)
require.Nil(t, err)
err = blsAggServ.InitializeNewTask(task1Index, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
require.Nil(t, err)
task2Index := types.TaskIndex(2)
task2Response := mockTaskResponse{234}
task2ResponseDigest := hashFunction(task2Response)
task2ResponseDigest, err := hashFunction(task2Response)
require.Nil(t, err)
err = blsAggServ.InitializeNewTask(task2Index, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
require.Nil(t, err)

Expand Down Expand Up @@ -318,15 +323,16 @@ func TestBlsAgg(t *testing.T) {
quorumNumbers := types.QuorumNums{0}
quorumThresholdPercentages := []types.QuorumThresholdPercentage{50}
taskResponse := mockTaskResponse{123}
taskResponseDigest := hashFunction(taskResponse)
taskResponseDigest, err := hashFunction(taskResponse)
require.Nil(t, err)
blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest)
blockNum := uint32(1)

fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2})
noopLogger := logging.NewNoopLogger()
blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger)

err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
require.Nil(t, err)
err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId)
require.Nil(t, err)
Expand Down Expand Up @@ -360,14 +366,15 @@ func TestBlsAgg(t *testing.T) {
quorumNumbers := types.QuorumNums{0}
quorumThresholdPercentages := []types.QuorumThresholdPercentage{60}
taskResponse := mockTaskResponse{123}
taskResponseDigest := hashFunction(taskResponse)
taskResponseDigest, err := hashFunction(taskResponse)
require.Nil(t, err)
blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest)

fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2})
noopLogger := logging.NewNoopLogger()
blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger)

err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
require.Nil(t, err)
err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId)
require.Nil(t, err)
Expand Down Expand Up @@ -395,14 +402,15 @@ func TestBlsAgg(t *testing.T) {
quorumNumbers := types.QuorumNums{0, 1}
quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100}
taskResponse := mockTaskResponse{123}
taskResponseDigest := hashFunction(taskResponse)
taskResponseDigest, err := hashFunction(taskResponse)
require.Nil(t, err)
blockNum := uint32(1)

fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2})
noopLogger := logging.NewNoopLogger()
blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger)

err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
require.Nil(t, err)
blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest)
err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId)
Expand Down Expand Up @@ -450,14 +458,15 @@ func TestBlsAgg(t *testing.T) {
quorumNumbers := types.QuorumNums{0, 1}
quorumThresholdPercentages := []types.QuorumThresholdPercentage{50, 50}
taskResponse := mockTaskResponse{123}
taskResponseDigest := hashFunction(taskResponse)
taskResponseDigest, err := hashFunction(taskResponse)
require.Nil(t, err)
blockNum := uint32(1)

fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3})
noopLogger := logging.NewNoopLogger()
blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger)

err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
require.Nil(t, err)
blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest)
err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId)
Expand Down Expand Up @@ -507,14 +516,15 @@ func TestBlsAgg(t *testing.T) {
quorumNumbers := types.QuorumNums{0, 1}
quorumThresholdPercentages := []types.QuorumThresholdPercentage{60, 60}
taskResponse := mockTaskResponse{123}
taskResponseDigest := hashFunction(taskResponse)
taskResponseDigest, err := hashFunction(taskResponse)
require.Nil(t, err)
blockNum := uint32(1)

fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3})
noopLogger := logging.NewNoopLogger()
blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger)

err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
require.Nil(t, err)
blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest)
err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId)
Expand All @@ -541,14 +551,15 @@ func TestBlsAgg(t *testing.T) {
quorumNumbers := types.QuorumNums{0, 1}
quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100}
taskResponse := mockTaskResponse{123}
taskResponseDigest := hashFunction(taskResponse)
taskResponseDigest, err := hashFunction(taskResponse)
require.Nil(t, err)
blockNum := uint32(1)

fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1})
noopLogger := logging.NewNoopLogger()
blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger)

err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
require.Nil(t, err)
blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest)
err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId)
Expand Down Expand Up @@ -577,14 +588,15 @@ func TestBlsAgg(t *testing.T) {
quorumNumbers := types.QuorumNums{0, 1}
quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100}
taskResponse := mockTaskResponse{123}
taskResponseDigest := hashFunction(taskResponse)
taskResponseDigest, err := hashFunction(taskResponse)
require.Nil(t, err)
blockNum := uint32(1)

fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2})
noopLogger := logging.NewNoopLogger()
blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger)

err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
require.Nil(t, err)
blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest)
err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId)
Expand All @@ -606,14 +618,15 @@ func TestBlsAgg(t *testing.T) {
blockNum := uint32(1)
taskIndex := types.TaskIndex(0)
taskResponse := mockTaskResponse{123}
taskResponseDigest := hashFunction(taskResponse)
taskResponseDigest, err := hashFunction(taskResponse)
require.Nil(t, err)
blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest)

fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1})
noopLogger := logging.NewNoopLogger()
blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger)

err := blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId)
err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId)
require.Equal(t, TaskNotFoundErrorFn(taskIndex), err)
})

Expand Down Expand Up @@ -642,13 +655,15 @@ func TestBlsAgg(t *testing.T) {
err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
require.Nil(t, err)
taskResponse1 := mockTaskResponse{1}
taskResponseDigest1 := hashFunction(taskResponse1)
taskResponseDigest1, err := hashFunction(taskResponse1)
require.Nil(t, err)
blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest1)
err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse1, blsSigOp1, testOperator1.OperatorId)
require.Nil(t, err)

taskResponse2 := mockTaskResponse{2}
taskResponseDigest2 := hashFunction(taskResponse2)
taskResponseDigest2, err := hashFunction(taskResponse2)
require.Nil(t, err)
blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest2)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
Expand Down Expand Up @@ -693,12 +708,14 @@ func TestBlsAgg(t *testing.T) {
err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
require.Nil(t, err)
taskResponse1 := mockTaskResponse{1}
taskResponseDigest1 := hashFunction(taskResponse1)
taskResponseDigest1, err := hashFunction(taskResponse1)
require.Nil(t, err)
blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest1)
err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse1, blsSigOp1, testOperator1.OperatorId)
require.Nil(t, err)
taskResponse2 := mockTaskResponse{2}
taskResponseDigest2 := hashFunction(taskResponse2)
taskResponseDigest2, err := hashFunction(taskResponse2)
require.Nil(t, err)
blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest2)
err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse2, blsSigOp2, testOperator2.OperatorId)
require.Nil(t, err)
Expand All @@ -721,15 +738,16 @@ func TestBlsAgg(t *testing.T) {
quorumThresholdPercentages := []types.QuorumThresholdPercentage{100}
taskResponse := mockTaskResponse{123} // Initialize with appropriate data

taskResponseDigest := wrongHashFunction(taskResponse)
taskResponseDigest, err := wrongHashFunction(taskResponse)
require.Nil(t, err)

blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest)

fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1})
noopLogger := logging.NewNoopLogger()
blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger)

err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
err = blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry)
require.Nil(t, err)
err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId)
require.EqualError(t, err, "Signature verification failed. Incorrect Signature.")
Expand Down
2 changes: 1 addition & 1 deletion services/mocks/blsagg/blsaggregation.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 4bd13c6

Please sign in to comment.