From d7994a85f642b39b63b34b728180119368be980f Mon Sep 17 00:00:00 2001 From: afkbyte Date: Fri, 24 May 2024 23:48:45 -0400 Subject: [PATCH 1/8] minimal diff to be able to return the task response bytes --- services/bls_aggregation/blsagg.go | 34 ++++++- services/bls_aggregation/blsagg_test.go | 119 +++++++++++++++--------- 2 files changed, 106 insertions(+), 47 deletions(-) diff --git a/services/bls_aggregation/blsagg.go b/services/bls_aggregation/blsagg.go index d470da18..94aa203f 100644 --- a/services/bls_aggregation/blsagg.go +++ b/services/bls_aggregation/blsagg.go @@ -2,6 +2,7 @@ package blsagg import ( "context" + "crypto/sha256" "errors" "fmt" "math/big" @@ -45,6 +46,7 @@ var ( type BlsAggregationServiceResponse struct { Err error // if Err is not nil, the other fields are not valid TaskIndex types.TaskIndex // unique identifier of the task + TaskResponse types.TaskResponse // the task response that was signed TaskResponseDigest types.TaskResponseDigest // digest of the task response that was signed // The below 8 fields are the data needed to build the IBLSSignatureChecker.NonSignerStakesAndSignature struct // users of this service will need to build the struct themselves by converting the bls points @@ -89,7 +91,7 @@ type BlsAggregationService interface { timeToExpiry time.Duration, ) error - // ProcessNewSignature processes a new signature over a taskResponseDigest for a particular taskIndex by a particular operator + // ProcessNewSignature processes a new signature over a taskResponseDigest (sha256 of the taskResponse) for a particular taskIndex by a particular operator // It verifies that the signature is correct and returns an error if it is not, and then aggregates the signature and stake of // the operator with all other signatures for the same taskIndex and taskResponseDigest pair. // Note: This function currently only verifies signatures over the taskResponseDigest directly, so avs code needs to verify that the digest @@ -97,7 +99,7 @@ type BlsAggregationService interface { ProcessNewSignature( ctx context.Context, taskIndex types.TaskIndex, - taskResponseDigest types.TaskResponseDigest, + taskResponse types.TaskResponse, blsSignature *bls.Signature, operatorId types.OperatorId, ) error @@ -134,6 +136,9 @@ type BlsAggregatorService struct { taskChansMutex sync.RWMutex avsRegistryService avsregistry.AvsRegistryService logger logging.Logger + + // taskResponseMap is a map of taskResponseDigest to taskResponse + taskResponseMap map[types.TaskResponseDigest]types.TaskResponse } var _ BlsAggregationService = (*BlsAggregatorService)(nil) @@ -145,6 +150,7 @@ func NewBlsAggregatorService(avsRegistryService avsregistry.AvsRegistryService, taskChansMutex: sync.RWMutex{}, avsRegistryService: avsRegistryService, logger: logger, + taskResponseMap: make(map[types.TaskResponseDigest]types.TaskResponse), } } @@ -179,7 +185,8 @@ func (a *BlsAggregatorService) InitializeNewTask( func (a *BlsAggregatorService) ProcessNewSignature( ctx context.Context, taskIndex types.TaskIndex, - taskResponseDigest types.TaskResponseDigest, + taskResponse types.TaskResponse, + //taskResponseDigest types.TaskResponseDigest, blsSignature *bls.Signature, operatorId types.OperatorId, ) error { @@ -189,9 +196,16 @@ func (a *BlsAggregatorService) ProcessNewSignature( if !taskInitialized { return TaskNotFoundErrorFn(taskIndex) } + // compute the taskResponseDigest, note that this is now enforcing a specific encoding for the taskResponse + taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) + + // Store the TaskResponse in our mapping + a.taskResponseMap[taskResponseDigest] = taskResponse + signatureVerificationErrorC := make(chan error) // send the task to the goroutine processing this task // and return the error (if any) returned by the signature verification routine + select { // we need to send this as part of select because if the goroutine is processing another SignedTaskResponseDigest // and cannot receive this one, we want the context to be able to cancel the request @@ -316,9 +330,23 @@ func (a *BlsAggregatorService) singleTaskAggregatorGoroutineFunc( } return } + + // Retrieve the TaskResponse from the map + taskResponse := a.taskResponseMap[signedTaskResponseDigest.TaskResponseDigest] + + // verify that the taskResponseDigest that was signed is the digest of the taskResponse + taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) + if signedTaskResponseDigest.TaskResponseDigest != taskResponseDigest { + a.aggregatedResponsesC <- BlsAggregationServiceResponse{ + Err: fmt.Errorf("signedTaskResponseDigest.TaskResponseDigest %x is not the digest of the TaskResponse %x", signedTaskResponseDigest.TaskResponseDigest, taskResponseDigest), + } + return + } + blsAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, + TaskResponse: taskResponse, TaskResponseDigest: signedTaskResponseDigest.TaskResponseDigest, NonSignersPubkeysG1: nonSignersG1Pubkeys, QuorumApksG1: quorumApksG1, diff --git a/services/bls_aggregation/blsagg_test.go b/services/bls_aggregation/blsagg_test.go index 27ad9704..ca6b62bf 100644 --- a/services/bls_aggregation/blsagg_test.go +++ b/services/bls_aggregation/blsagg_test.go @@ -2,6 +2,7 @@ package blsagg import ( "context" + "crypto/sha256" "math/big" "testing" "time" @@ -33,7 +34,11 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := types.TaskResponse{123} // Initialize with appropriate data + + // Compute the TaskResponseDigest as the SHA-256 sum of the TaskResponse + taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) + blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) @@ -42,11 +47,12 @@ func TestBlsAgg(t *testing.T) { err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSig, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, + TaskResponse: taskResponse, TaskResponseDigest: taskResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{}, QuorumApksG1: []*bls.G1Point{testOperator1.BlsKeypair.GetPubKeyG1()}, @@ -73,11 +79,13 @@ func TestBlsAgg(t *testing.T) { StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(300), 1: big.NewInt(100)}, BlsKeypair: newBlsKeyPairPanics("0x3"), } + blockNum := uint32(1) taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100} - taskResponseDigest := types.TaskResponseDigest{123} - blockNum := uint32(1) + taskResponse := types.TaskResponse{123} + + taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3}) noopLogger := logging.NewNoopLogger() @@ -86,18 +94,19 @@ func TestBlsAgg(t *testing.T) { err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp2, testOperator2.OperatorId) require.Nil(t, err) blsSigOp3 := testOperator3.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp3, testOperator3.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp3, testOperator3.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, + TaskResponse: taskResponse, TaskResponseDigest: taskResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{}, QuorumApksG1: []*bls.G1Point{testOperator1.BlsKeypair.GetPubKeyG1(). @@ -126,11 +135,12 @@ func TestBlsAgg(t *testing.T) { StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, BlsKeypair: newBlsKeyPairPanics("0x2"), } + blockNum := uint32(1) taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} - taskResponseDigest := types.TaskResponseDigest{123} - blockNum := uint32(1) + taskResponse := types.TaskResponse{123} + taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() @@ -139,15 +149,16 @@ func TestBlsAgg(t *testing.T) { err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp2, testOperator2.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, + TaskResponse: taskResponse, TaskResponseDigest: taskResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{}, QuorumApksG1: []*bls.G1Point{ @@ -182,30 +193,33 @@ func TestBlsAgg(t *testing.T) { // initialize 2 concurrent tasks task1Index := types.TaskIndex(1) - task1ResponseDigest := types.TaskResponseDigest{123} + task1Response := types.TaskResponse{123} + task1ResponseDigest := types.TaskResponseDigest(sha256.Sum256(task1Response)) err := blsAggServ.InitializeNewTask(task1Index, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) task2Index := types.TaskIndex(2) - task2ResponseDigest := types.TaskResponseDigest{230} + task2Response := types.TaskResponse{234} + task2ResponseDigest := types.TaskResponseDigest(sha256.Sum256(task2Response)) err = blsAggServ.InitializeNewTask(task2Index, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigTask1Op1 := testOperator1.BlsKeypair.SignMessage(task1ResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), task1Index, task1ResponseDigest, blsSigTask1Op1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), task1Index, task1Response, blsSigTask1Op1, testOperator1.OperatorId) require.Nil(t, err) blsSigTask2Op1 := testOperator1.BlsKeypair.SignMessage(task2ResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), task2Index, task2ResponseDigest, blsSigTask2Op1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), task2Index, task2Response, blsSigTask2Op1, testOperator1.OperatorId) require.Nil(t, err) blsSigTask1Op2 := testOperator2.BlsKeypair.SignMessage(task1ResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), task1Index, task1ResponseDigest, blsSigTask1Op2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), task1Index, task1Response, blsSigTask1Op2, testOperator2.OperatorId) require.Nil(t, err) blsSigTask2Op2 := testOperator2.BlsKeypair.SignMessage(task2ResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), task2Index, task2ResponseDigest, blsSigTask2Op2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), task2Index, task2Response, blsSigTask2Op2, testOperator2.OperatorId) require.Nil(t, err) wantAggregationServiceResponseTask1 := BlsAggregationServiceResponse{ Err: nil, TaskIndex: task1Index, + TaskResponse: task1Response, TaskResponseDigest: task1ResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{}, QuorumApksG1: []*bls.G1Point{ @@ -218,6 +232,7 @@ func TestBlsAgg(t *testing.T) { wantAggregationServiceResponseTask2 := BlsAggregationServiceResponse{ Err: nil, TaskIndex: task2Index, + TaskResponse: task2Response, TaskResponseDigest: task2ResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{}, QuorumApksG1: []*bls.G1Point{ @@ -279,7 +294,8 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0} quorumThresholdPercentages := []types.QuorumThresholdPercentage{50} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := types.TaskResponse{123} + taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) blockNum := uint32(1) @@ -289,11 +305,12 @@ func TestBlsAgg(t *testing.T) { err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSig, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, + TaskResponse: taskResponse, TaskResponseDigest: taskResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{testOperator2.BlsKeypair.GetPubKeyG1()}, QuorumApksG1: []*bls.G1Point{testOperator1.BlsKeypair.GetPubKeyG1().Add(testOperator2.BlsKeypair.GetPubKeyG1())}, @@ -319,7 +336,8 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0} quorumThresholdPercentages := []types.QuorumThresholdPercentage{60} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := types.TaskResponse{123} + taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) @@ -328,7 +346,7 @@ func TestBlsAgg(t *testing.T) { err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSig, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: TaskExpiredErrorFn(taskIndex), @@ -353,7 +371,8 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := types.TaskResponse{123} + taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) @@ -363,15 +382,16 @@ func TestBlsAgg(t *testing.T) { err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp2, testOperator2.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, + TaskResponse: taskResponse, TaskResponseDigest: taskResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{}, QuorumApksG1: []*bls.G1Point{ @@ -406,7 +426,8 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{50, 50} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := types.TaskResponse{123} + taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3}) @@ -416,15 +437,16 @@ func TestBlsAgg(t *testing.T) { err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp2, testOperator2.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, + TaskResponse: taskResponse, TaskResponseDigest: taskResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{ testOperator3.BlsKeypair.GetPubKeyG1(), @@ -461,7 +483,8 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{60, 60} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := types.TaskResponse{123} + taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3}) @@ -471,10 +494,10 @@ func TestBlsAgg(t *testing.T) { err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp2, testOperator2.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ @@ -494,7 +517,8 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := types.TaskResponse{123} + taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) @@ -504,7 +528,7 @@ func TestBlsAgg(t *testing.T) { err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ @@ -529,7 +553,8 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := types.TaskResponse{123} + taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) @@ -539,7 +564,7 @@ func TestBlsAgg(t *testing.T) { err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ @@ -557,14 +582,15 @@ func TestBlsAgg(t *testing.T) { } blockNum := uint32(1) taskIndex := types.TaskIndex(0) - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := types.TaskResponse{123} + taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) noopLogger := logging.NewNoopLogger() blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) - err := blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSig, testOperator1.OperatorId) + err := blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) require.Equal(t, TaskNotFoundErrorFn(taskIndex), err) }) @@ -592,22 +618,25 @@ func TestBlsAgg(t *testing.T) { err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) - taskResponseDigest1 := types.TaskResponseDigest{1} + taskResponse1 := types.TaskResponse{1} + taskResponseDigest1 := types.TaskResponseDigest(sha256.Sum256(taskResponse1)) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest1) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest1, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse1, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) - taskResponseDigest2 := types.TaskResponseDigest{2} + taskResponse2 := types.TaskResponse{2} + taskResponseDigest2 := types.TaskResponseDigest(sha256.Sum256(taskResponse2)) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest2) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - err = blsAggServ.ProcessNewSignature(ctx, taskIndex, taskResponseDigest2, blsSigOp2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(ctx, taskIndex, taskResponse2, blsSigOp2, testOperator2.OperatorId) // this should timeout because the task goroutine is blocked on the response channel (since we only listen for it below) require.Equal(t, context.DeadlineExceeded, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, + TaskResponse: taskResponse1, TaskResponseDigest: taskResponseDigest1, NonSignersPubkeysG1: []*bls.G1Point{}, QuorumApksG1: []*bls.G1Point{testOperator1.BlsKeypair.GetPubKeyG1()}, @@ -640,13 +669,15 @@ func TestBlsAgg(t *testing.T) { err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) - taskResponseDigest1 := types.TaskResponseDigest{1} + taskResponse1 := types.TaskResponse{1} + taskResponseDigest1 := types.TaskResponseDigest(sha256.Sum256(taskResponse1)) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest1) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest1, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse1, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) - taskResponseDigest2 := types.TaskResponseDigest{2} + taskResponse2 := types.TaskResponse{2} + taskResponseDigest2 := types.TaskResponseDigest(sha256.Sum256(taskResponse2)) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest2) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest2, blsSigOp2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse2, blsSigOp2, testOperator2.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: TaskExpiredErrorFn(taskIndex), From 86a5aa1496f580baec551e9390fe934632eda41d Mon Sep 17 00:00:00 2001 From: afkbyte Date: Fri, 24 May 2024 23:51:10 -0400 Subject: [PATCH 2/8] add taskReponse type --- types/avs.go | 1 + 1 file changed, 1 insertion(+) diff --git a/types/avs.go b/types/avs.go index 9a10c679..bf76feb0 100644 --- a/types/avs.go +++ b/types/avs.go @@ -8,6 +8,7 @@ import ( type TaskIndex = uint32 type TaskResponseDigest = Bytes32 +type TaskResponse = []byte type SignedTaskResponseDigest struct { TaskResponseDigest TaskResponseDigest From 63b6a39421bc0867931ebedf3c20593b8f61d663 Mon Sep 17 00:00:00 2001 From: afkbyte Date: Sat, 25 May 2024 16:25:10 -0400 Subject: [PATCH 3/8] only add to taskResponseMap if doesn't already exist --- services/bls_aggregation/blsagg.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/services/bls_aggregation/blsagg.go b/services/bls_aggregation/blsagg.go index 94aa203f..6f1877f2 100644 --- a/services/bls_aggregation/blsagg.go +++ b/services/bls_aggregation/blsagg.go @@ -186,7 +186,6 @@ func (a *BlsAggregatorService) ProcessNewSignature( ctx context.Context, taskIndex types.TaskIndex, taskResponse types.TaskResponse, - //taskResponseDigest types.TaskResponseDigest, blsSignature *bls.Signature, operatorId types.OperatorId, ) error { @@ -199,8 +198,12 @@ func (a *BlsAggregatorService) ProcessNewSignature( // compute the taskResponseDigest, note that this is now enforcing a specific encoding for the taskResponse taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) - // Store the TaskResponse in our mapping - a.taskResponseMap[taskResponseDigest] = taskResponse + // check if the taskResponseDigest is already in the map + _, taskResponseExists := a.taskResponseMap[taskResponseDigest] + if !taskResponseExists { + // Store the TaskResponse in our mapping + a.taskResponseMap[taskResponseDigest] = taskResponse + } signatureVerificationErrorC := make(chan error) // send the task to the goroutine processing this task From 994392fc37c2f29c34cc13afe6116d46085622d3 Mon Sep 17 00:00:00 2001 From: afkbyte Date: Sat, 25 May 2024 19:27:51 -0400 Subject: [PATCH 4/8] add interface for hash function --- services/bls_aggregation/blsagg.go | 30 ++--- services/bls_aggregation/blsagg_test.go | 147 ++++++++++++++++-------- types/avs.go | 4 +- 3 files changed, 113 insertions(+), 68 deletions(-) diff --git a/services/bls_aggregation/blsagg.go b/services/bls_aggregation/blsagg.go index 6f1877f2..11f2a19c 100644 --- a/services/bls_aggregation/blsagg.go +++ b/services/bls_aggregation/blsagg.go @@ -2,7 +2,6 @@ package blsagg import ( "context" - "crypto/sha256" "errors" "fmt" "math/big" @@ -91,7 +90,7 @@ type BlsAggregationService interface { timeToExpiry time.Duration, ) error - // ProcessNewSignature processes a new signature over a taskResponseDigest (sha256 of the taskResponse) for a particular taskIndex by a particular operator + // ProcessNewSignature processes a new signature over a taskResponseDigest for a particular taskIndex by a particular operator // It verifies that the signature is correct and returns an error if it is not, and then aggregates the signature and stake of // the operator with all other signatures for the same taskIndex and taskResponseDigest pair. // Note: This function currently only verifies signatures over the taskResponseDigest directly, so avs code needs to verify that the digest @@ -139,11 +138,13 @@ type BlsAggregatorService struct { // taskResponseMap is a map of taskResponseDigest to taskResponse taskResponseMap map[types.TaskResponseDigest]types.TaskResponse + + hashFunction types.TaskResponseHashFunction } var _ BlsAggregationService = (*BlsAggregatorService)(nil) -func NewBlsAggregatorService(avsRegistryService avsregistry.AvsRegistryService, logger logging.Logger) *BlsAggregatorService { +func NewBlsAggregatorService(avsRegistryService avsregistry.AvsRegistryService, hashFunction types.TaskResponseHashFunction, logger logging.Logger) *BlsAggregatorService { return &BlsAggregatorService{ aggregatedResponsesC: make(chan BlsAggregationServiceResponse), signedTaskRespsCs: make(map[types.TaskIndex]chan types.SignedTaskResponseDigest), @@ -151,6 +152,7 @@ func NewBlsAggregatorService(avsRegistryService avsregistry.AvsRegistryService, avsRegistryService: avsRegistryService, logger: logger, taskResponseMap: make(map[types.TaskResponseDigest]types.TaskResponse), + hashFunction: hashFunction, } } @@ -195,8 +197,8 @@ func (a *BlsAggregatorService) ProcessNewSignature( if !taskInitialized { return TaskNotFoundErrorFn(taskIndex) } - // compute the taskResponseDigest, note that this is now enforcing a specific encoding for the taskResponse - taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) + // compute the taskResponseDigest using the hash function + taskResponseDigest := a.hashFunction(taskResponse) // check if the taskResponseDigest is already in the map _, taskResponseExists := a.taskResponseMap[taskResponseDigest] @@ -272,6 +274,7 @@ func (a *BlsAggregatorService) singleTaskAggregatorGoroutineFunc( select { case signedTaskResponseDigest := <-signedTaskRespsC: a.logger.Debug("Task goroutine received new signed task response digest", "taskIndex", taskIndex, "signedTaskResponseDigest", signedTaskResponseDigest) + err := a.verifySignature(taskIndex, signedTaskResponseDigest, operatorsAvsStateDict) signedTaskResponseDigest.SignatureVerificationErrorC <- err if err != nil { @@ -334,18 +337,6 @@ func (a *BlsAggregatorService) singleTaskAggregatorGoroutineFunc( return } - // Retrieve the TaskResponse from the map - taskResponse := a.taskResponseMap[signedTaskResponseDigest.TaskResponseDigest] - - // verify that the taskResponseDigest that was signed is the digest of the taskResponse - taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) - if signedTaskResponseDigest.TaskResponseDigest != taskResponseDigest { - a.aggregatedResponsesC <- BlsAggregationServiceResponse{ - Err: fmt.Errorf("signedTaskResponseDigest.TaskResponseDigest %x is not the digest of the TaskResponse %x", signedTaskResponseDigest.TaskResponseDigest, taskResponseDigest), - } - return - } - blsAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, @@ -402,7 +393,7 @@ func (a *BlsAggregatorService) verifySignature( return OperatorNotPartOfTaskQuorumErrorFn(signedTaskResponseDigest.OperatorId, taskIndex) } - // 0. verify that the msg actually came from the correct operator + // verify that the msg actually came from the correct operator operatorG2Pubkey := operatorsAvsStateDict[signedTaskResponseDigest.OperatorId].OperatorInfo.Pubkeys.G2Pubkey if operatorG2Pubkey == nil { a.logger.Error("Operator G2 pubkey not found", "operatorId", signedTaskResponseDigest.OperatorId, "taskId", taskIndex) @@ -413,6 +404,9 @@ func (a *BlsAggregatorService) verifySignature( "taskResponseDigest", signedTaskResponseDigest.TaskResponseDigest, "blsSignature", signedTaskResponseDigest.BlsSignature, ) + + // if the operator signs a digest that is not the digest of the TaskResponse submitted in ProcessNewTask + // then the signature will not be verified signatureVerified, err := signedTaskResponseDigest.BlsSignature.Verify(operatorG2Pubkey, signedTaskResponseDigest.TaskResponseDigest) if err != nil { return SignatureVerificationError(err) diff --git a/services/bls_aggregation/blsagg_test.go b/services/bls_aggregation/blsagg_test.go index ca6b62bf..616c6754 100644 --- a/services/bls_aggregation/blsagg_test.go +++ b/services/bls_aggregation/blsagg_test.go @@ -3,6 +3,7 @@ package blsagg import ( "context" "crypto/sha256" + "encoding/json" "math/big" "testing" "time" @@ -24,6 +25,28 @@ func TestBlsAgg(t *testing.T) { // 1 second seems to be enough for tests to pass. Currently takes 5s to run all tests tasksTimeToExpiry := 1 * time.Second + hashFunction := func(taskResponse types.TaskResponse) types.TaskResponseDigest { + taskResponseBytes, err := json.Marshal(taskResponse) + if err != nil { + panic(err) + } + return types.TaskResponseDigest(sha256.Sum256(taskResponseBytes)) + } + + wrongHashFunction := func(taskResponse types.TaskResponse) types.TaskResponseDigest { + taskResponseBytes, err := json.Marshal(taskResponse) + if err != nil { + panic(err) + } + // append something to the taskResponseBytes to make it different + taskResponseBytes = append(taskResponseBytes, []byte("something")...) + return types.TaskResponseDigest(sha256.Sum256(taskResponseBytes)) + } + + type mockTaskResponse struct { + Value int + } + t.Run("1 quorum 1 operator 1 correct signature", func(t *testing.T) { testOperator1 := types.TestOperator{ OperatorId: types.OperatorId{1}, @@ -34,16 +57,16 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100} - taskResponse := types.TaskResponse{123} // Initialize with appropriate data + taskResponse := mockTaskResponse{123} // Initialize with appropriate data // Compute the TaskResponseDigest as the SHA-256 sum of the TaskResponse - taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) + taskResponseDigest := hashFunction(taskResponse) blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) @@ -83,13 +106,13 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100} - taskResponse := types.TaskResponse{123} + taskResponse := mockTaskResponse{123} - taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) + taskResponseDigest := hashFunction(taskResponse) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) @@ -139,12 +162,12 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} - taskResponse := types.TaskResponse{123} - taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) @@ -189,17 +212,17 @@ func TestBlsAgg(t *testing.T) { fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) // initialize 2 concurrent tasks task1Index := types.TaskIndex(1) - task1Response := types.TaskResponse{123} - task1ResponseDigest := types.TaskResponseDigest(sha256.Sum256(task1Response)) + task1Response := mockTaskResponse{123} + task1ResponseDigest := hashFunction(task1Response) err := blsAggServ.InitializeNewTask(task1Index, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) task2Index := types.TaskIndex(2) - task2Response := types.TaskResponse{234} - task2ResponseDigest := types.TaskResponseDigest(sha256.Sum256(task2Response)) + task2Response := mockTaskResponse{234} + task2ResponseDigest := hashFunction(task2Response) err = blsAggServ.InitializeNewTask(task2Index, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) @@ -269,7 +292,7 @@ func TestBlsAgg(t *testing.T) { fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) @@ -294,14 +317,14 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0} quorumThresholdPercentages := []types.QuorumThresholdPercentage{50} - taskResponse := types.TaskResponse{123} - taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) @@ -336,13 +359,13 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0} quorumThresholdPercentages := []types.QuorumThresholdPercentage{60} - taskResponse := types.TaskResponse{123} - taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) @@ -371,13 +394,13 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} - taskResponse := types.TaskResponse{123} - taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) @@ -426,13 +449,13 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{50, 50} - taskResponse := types.TaskResponse{123} - taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) @@ -483,13 +506,13 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{60, 60} - taskResponse := types.TaskResponse{123} - taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) @@ -517,13 +540,13 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} - taskResponse := types.TaskResponse{123} - taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) @@ -553,13 +576,13 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} - taskResponse := types.TaskResponse{123} - taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) @@ -582,13 +605,13 @@ func TestBlsAgg(t *testing.T) { } blockNum := uint32(1) taskIndex := types.TaskIndex(0) - taskResponse := types.TaskResponse{123} - taskResponseDigest := types.TaskResponseDigest(sha256.Sum256(taskResponse)) + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) require.Equal(t, TaskNotFoundErrorFn(taskIndex), err) @@ -614,18 +637,18 @@ func TestBlsAgg(t *testing.T) { fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) - taskResponse1 := types.TaskResponse{1} - taskResponseDigest1 := types.TaskResponseDigest(sha256.Sum256(taskResponse1)) + taskResponse1 := mockTaskResponse{1} + taskResponseDigest1 := hashFunction(taskResponse1) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest1) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse1, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) - taskResponse2 := types.TaskResponse{2} - taskResponseDigest2 := types.TaskResponseDigest(sha256.Sum256(taskResponse2)) + taskResponse2 := mockTaskResponse{2} + taskResponseDigest2 := hashFunction(taskResponse2) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest2) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -665,17 +688,17 @@ func TestBlsAgg(t *testing.T) { fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) - taskResponse1 := types.TaskResponse{1} - taskResponseDigest1 := types.TaskResponseDigest(sha256.Sum256(taskResponse1)) + taskResponse1 := mockTaskResponse{1} + taskResponseDigest1 := hashFunction(taskResponse1) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest1) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse1, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) - taskResponse2 := types.TaskResponse{2} - taskResponseDigest2 := types.TaskResponseDigest(sha256.Sum256(taskResponse2)) + taskResponse2 := mockTaskResponse{2} + taskResponseDigest2 := hashFunction(taskResponse2) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest2) err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse2, blsSigOp2, testOperator2.OperatorId) require.Nil(t, err) @@ -685,6 +708,32 @@ func TestBlsAgg(t *testing.T) { gotAggregationServiceResponse := <-blsAggServ.aggregatedResponsesC require.Equal(t, wantAggregationServiceResponse, gotAggregationServiceResponse) }) + + t.Run("1 quorum 1 operator 1 invalid signature (TaskResponseDigest does not match TaskResponse)", func(t *testing.T) { + testOperator1 := types.TestOperator{ + OperatorId: types.OperatorId{1}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x1"), + } + blockNum := uint32(1) + taskIndex := types.TaskIndex(0) + quorumNumbers := types.QuorumNums{0} + quorumThresholdPercentages := []types.QuorumThresholdPercentage{100} + taskResponse := mockTaskResponse{123} // Initialize with appropriate data + + taskResponseDigest := wrongHashFunction(taskResponse) + + blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) + + fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) + noopLogger := logging.NewNoopLogger() + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) + + err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + require.Nil(t, err) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) + require.EqualError(t, err, "Signature verification failed. Incorrect Signature.") + }) } func newBlsKeyPairPanics(hexKey string) *bls.KeyPair { diff --git a/types/avs.go b/types/avs.go index bf76feb0..79e66871 100644 --- a/types/avs.go +++ b/types/avs.go @@ -8,7 +8,9 @@ import ( type TaskIndex = uint32 type TaskResponseDigest = Bytes32 -type TaskResponse = []byte +type TaskResponse = interface{} + +type TaskResponseHashFunction func(taskResponse TaskResponse) TaskResponseDigest type SignedTaskResponseDigest struct { TaskResponseDigest TaskResponseDigest From 32b29a7c33f220d8718c844dbb3b867f9ed657db Mon Sep 17 00:00:00 2001 From: afkbyte Date: Sat, 25 May 2024 19:35:39 -0400 Subject: [PATCH 5/8] fix lint --- services/bls_aggregation/blsagg.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/services/bls_aggregation/blsagg.go b/services/bls_aggregation/blsagg.go index 11f2a19c..c46d2d94 100644 --- a/services/bls_aggregation/blsagg.go +++ b/services/bls_aggregation/blsagg.go @@ -274,6 +274,8 @@ func (a *BlsAggregatorService) singleTaskAggregatorGoroutineFunc( select { case signedTaskResponseDigest := <-signedTaskRespsC: a.logger.Debug("Task goroutine received new signed task response digest", "taskIndex", taskIndex, "signedTaskResponseDigest", signedTaskResponseDigest) + // Retrieve the TaskResponse from the map + taskResponse := a.taskResponseMap[signedTaskResponseDigest.TaskResponseDigest] err := a.verifySignature(taskIndex, signedTaskResponseDigest, operatorsAvsStateDict) signedTaskResponseDigest.SignatureVerificationErrorC <- err From 2177792fa8f14b23cead9ea6a159cfca2b3c0776 Mon Sep 17 00:00:00 2001 From: afkbyte Date: Wed, 29 May 2024 00:43:03 -0400 Subject: [PATCH 6/8] make local --- services/bls_aggregation/blsagg.go | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/services/bls_aggregation/blsagg.go b/services/bls_aggregation/blsagg.go index c46d2d94..3292c334 100644 --- a/services/bls_aggregation/blsagg.go +++ b/services/bls_aggregation/blsagg.go @@ -197,15 +197,6 @@ func (a *BlsAggregatorService) ProcessNewSignature( if !taskInitialized { return TaskNotFoundErrorFn(taskIndex) } - // compute the taskResponseDigest using the hash function - taskResponseDigest := a.hashFunction(taskResponse) - - // check if the taskResponseDigest is already in the map - _, taskResponseExists := a.taskResponseMap[taskResponseDigest] - if !taskResponseExists { - // Store the TaskResponse in our mapping - a.taskResponseMap[taskResponseDigest] = taskResponse - } signatureVerificationErrorC := make(chan error) // send the task to the goroutine processing this task @@ -215,7 +206,7 @@ func (a *BlsAggregatorService) ProcessNewSignature( // we need to send this as part of select because if the goroutine is processing another SignedTaskResponseDigest // and cannot receive this one, we want the context to be able to cancel the request case taskC <- types.SignedTaskResponseDigest{ - TaskResponseDigest: taskResponseDigest, + TaskResponse: taskResponse, BlsSignature: blsSignature, OperatorId: operatorId, SignatureVerificationErrorC: signatureVerificationErrorC, @@ -274,8 +265,8 @@ func (a *BlsAggregatorService) singleTaskAggregatorGoroutineFunc( select { case signedTaskResponseDigest := <-signedTaskRespsC: a.logger.Debug("Task goroutine received new signed task response digest", "taskIndex", taskIndex, "signedTaskResponseDigest", signedTaskResponseDigest) - // Retrieve the TaskResponse from the map - taskResponse := a.taskResponseMap[signedTaskResponseDigest.TaskResponseDigest] + // compute the taskResponseDigest using the hash function + taskResponseDigest := a.hashFunction(signedTaskResponseDigest.TaskResponse) err := a.verifySignature(taskIndex, signedTaskResponseDigest, operatorsAvsStateDict) signedTaskResponseDigest.SignatureVerificationErrorC <- err @@ -283,7 +274,7 @@ func (a *BlsAggregatorService) singleTaskAggregatorGoroutineFunc( continue } // after verifying signature we aggregate its sig and pubkey, and update the signed stake amount - digestAggregatedOperators, ok := aggregatedOperatorsDict[signedTaskResponseDigest.TaskResponseDigest] + digestAggregatedOperators, ok := aggregatedOperatorsDict[taskResponseDigest] if !ok { // first operator to sign on this digest digestAggregatedOperators = aggregatedOperators{ @@ -308,7 +299,7 @@ func (a *BlsAggregatorService) singleTaskAggregatorGoroutineFunc( } // update the aggregatedOperatorsDict. Note that we need to assign the whole struct value at once, // because of https://github.com/golang/go/issues/3117 - aggregatedOperatorsDict[signedTaskResponseDigest.TaskResponseDigest] = digestAggregatedOperators + aggregatedOperatorsDict[taskResponseDigest] = digestAggregatedOperators if checkIfStakeThresholdsMet(a.logger, digestAggregatedOperators.signersTotalStakePerQuorum, totalStakePerQuorum, quorumThresholdPercentagesMap) { nonSignersOperatorIds := []types.OperatorId{} @@ -342,8 +333,8 @@ func (a *BlsAggregatorService) singleTaskAggregatorGoroutineFunc( blsAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, - TaskResponse: taskResponse, - TaskResponseDigest: signedTaskResponseDigest.TaskResponseDigest, + TaskResponse: signedTaskResponseDigest.TaskResponse, + TaskResponseDigest: taskResponseDigest, NonSignersPubkeysG1: nonSignersG1Pubkeys, QuorumApksG1: quorumApksG1, SignersApkG2: digestAggregatedOperators.signersApkG2, @@ -395,6 +386,8 @@ func (a *BlsAggregatorService) verifySignature( return OperatorNotPartOfTaskQuorumErrorFn(signedTaskResponseDigest.OperatorId, taskIndex) } + taskResponseDigest := a.hashFunction(signedTaskResponseDigest.TaskResponse) + // verify that the msg actually came from the correct operator operatorG2Pubkey := operatorsAvsStateDict[signedTaskResponseDigest.OperatorId].OperatorInfo.Pubkeys.G2Pubkey if operatorG2Pubkey == nil { @@ -403,13 +396,13 @@ func (a *BlsAggregatorService) verifySignature( } a.logger.Debug("Verifying signed task response digest signature", "operatorG2Pubkey", operatorG2Pubkey, - "taskResponseDigest", signedTaskResponseDigest.TaskResponseDigest, + "taskResponseDigest", taskResponseDigest, "blsSignature", signedTaskResponseDigest.BlsSignature, ) // if the operator signs a digest that is not the digest of the TaskResponse submitted in ProcessNewTask // then the signature will not be verified - signatureVerified, err := signedTaskResponseDigest.BlsSignature.Verify(operatorG2Pubkey, signedTaskResponseDigest.TaskResponseDigest) + signatureVerified, err := signedTaskResponseDigest.BlsSignature.Verify(operatorG2Pubkey, taskResponseDigest) if err != nil { return SignatureVerificationError(err) } From 7f920debdbb5e0d3eb1057de9ad151f03947f0c8 Mon Sep 17 00:00:00 2001 From: afkbyte Date: Wed, 29 May 2024 00:46:03 -0400 Subject: [PATCH 7/8] add types --- types/avs.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/types/avs.go b/types/avs.go index 79e66871..22787013 100644 --- a/types/avs.go +++ b/types/avs.go @@ -13,7 +13,7 @@ type TaskResponse = interface{} type TaskResponseHashFunction func(taskResponse TaskResponse) TaskResponseDigest type SignedTaskResponseDigest struct { - TaskResponseDigest TaskResponseDigest + TaskResponse TaskResponse BlsSignature *bls.Signature OperatorId OperatorId SignatureVerificationErrorC chan error `json:"-"` // removed from json because channels are not marshallable @@ -21,7 +21,7 @@ type SignedTaskResponseDigest struct { func (strd SignedTaskResponseDigest) LogValue() slog.Value { return slog.GroupValue( - slog.Any("taskResponseDigest", strd.TaskResponseDigest), + slog.Any("taskResponse", strd.TaskResponse), slog.Any("blsSignature", strd.BlsSignature), slog.Any("operatorId", strd.OperatorId), ) From d98558b69b9808e4986647681b89d3215bac1a8b Mon Sep 17 00:00:00 2001 From: afkbyte Date: Wed, 29 May 2024 00:48:21 -0400 Subject: [PATCH 8/8] remove mapping --- services/bls_aggregation/blsagg.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/services/bls_aggregation/blsagg.go b/services/bls_aggregation/blsagg.go index 3292c334..ddd75cc5 100644 --- a/services/bls_aggregation/blsagg.go +++ b/services/bls_aggregation/blsagg.go @@ -136,9 +136,6 @@ type BlsAggregatorService struct { avsRegistryService avsregistry.AvsRegistryService logger logging.Logger - // taskResponseMap is a map of taskResponseDigest to taskResponse - taskResponseMap map[types.TaskResponseDigest]types.TaskResponse - hashFunction types.TaskResponseHashFunction } @@ -151,7 +148,6 @@ func NewBlsAggregatorService(avsRegistryService avsregistry.AvsRegistryService, taskChansMutex: sync.RWMutex{}, avsRegistryService: avsRegistryService, logger: logger, - taskResponseMap: make(map[types.TaskResponseDigest]types.TaskResponse), hashFunction: hashFunction, } }