Skip to content

Commit

Permalink
Avs registry chaincaller test
Browse files Browse the repository at this point in the history
  • Loading branch information
shrimalmadhur committed Jul 18, 2024
1 parent f88f44a commit fb3a19c
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 314 deletions.
173 changes: 103 additions & 70 deletions services/avsregistry/avsregistry_chaincaller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,90 @@ package avsregistry

import (
"context"
"errors"
"math/big"
"reflect"
"testing"

chainiomocks "github.com/Layr-Labs/eigensdk-go/chainio/mocks"
"github.com/ethereum/go-ethereum/accounts/abi/bind"

opstateretrievar "github.com/Layr-Labs/eigensdk-go/contracts/bindings/OperatorStateRetriever"
"github.com/Layr-Labs/eigensdk-go/crypto/bls"
"github.com/Layr-Labs/eigensdk-go/logging"
servicemocks "github.com/Layr-Labs/eigensdk-go/services/mocks"
"github.com/Layr-Labs/eigensdk-go/types"
"github.com/ethereum/go-ethereum/common"
"go.uber.org/mock/gomock"
)

type fakeAVSRegistryReader struct {
filterQueryRange *big.Int

Check failure on line 20 in services/avsregistry/avsregistry_chaincaller_test.go

View workflow job for this annotation

GitHub Actions / Lint

field `filterQueryRange` is unused (unused)
opAddress []types.OperatorAddr
opPubKeys []types.OperatorPubkeys
operatorId types.OperatorId
socket types.Socket
err error
}

func newFakeAVSRegistryReader(
opr *testOperator,
err error,
) *fakeAVSRegistryReader {
if opr == nil {
return &fakeAVSRegistryReader{}
}
return &fakeAVSRegistryReader{
opAddress: []common.Address{opr.operatorAddr},
opPubKeys: []types.OperatorPubkeys{opr.operatorInfo.Pubkeys},
socket: opr.operatorInfo.Socket,
operatorId: opr.operatorId,
err: err,
}
}

func (f *fakeAVSRegistryReader) GetOperatorFromId(
opts *bind.CallOpts,
operatorId types.OperatorId,
) (common.Address, error) {
return f.opAddress[0], f.err
}

func (f *fakeAVSRegistryReader) GetOperatorsStakeInQuorumsAtBlock(
opts *bind.CallOpts,
quorumNumbers types.QuorumNums,
blockNumber types.BlockNum,
) ([][]opstateretrievar.OperatorStateRetrieverOperator, error) {
return [][]opstateretrievar.OperatorStateRetrieverOperator{
{
{
OperatorId: f.operatorId,
Stake: big.NewInt(123),
},
},
}, nil
}

func (f *fakeAVSRegistryReader) GetCheckSignaturesIndices(
opts *bind.CallOpts,
referenceBlockNumber uint32,
quorumNumbers types.QuorumNums,
nonSignerOperatorIds []types.OperatorId,
) (opstateretrievar.OperatorStateRetrieverCheckSignaturesIndices, error) {
return opstateretrievar.OperatorStateRetrieverCheckSignaturesIndices{}, nil
}

type fakeOperatorInfoService struct {
operatorInfo types.OperatorInfo
}

func newFakeOperatorInfoService(operatorInfo types.OperatorInfo) *fakeOperatorInfoService {
return &fakeOperatorInfoService{
operatorInfo: operatorInfo,
}
}

func (f *fakeOperatorInfoService) GetOperatorInfo(ctx context.Context, operator common.Address) (operatorInfo types.OperatorInfo, operatorFound bool) {
return f.operatorInfo, true
}

type testOperator struct {
operatorAddr common.Address
operatorId types.OperatorId
Expand All @@ -24,7 +94,7 @@ type testOperator struct {

func TestAvsRegistryServiceChainCaller_getOperatorPubkeys(t *testing.T) {
logger := logging.NewNoopLogger()
testOperator := testOperator{
testOperator1 := testOperator{
operatorAddr: common.HexToAddress("0x1"),
operatorId: types.OperatorId{1},
operatorInfo: types.OperatorInfo{
Expand All @@ -38,40 +108,33 @@ func TestAvsRegistryServiceChainCaller_getOperatorPubkeys(t *testing.T) {

// TODO(samlaf): add error test cases
var tests = []struct {
name string
mocksInitializationFunc func(*chainiomocks.MockAVSReader, *servicemocks.MockOperatorsInfoService)
queryOperatorId types.OperatorId
wantErr error
wantOperatorInfo types.OperatorInfo
name string
operator *testOperator
queryOperatorId types.OperatorId
wantErr error
wantOperatorInfo types.OperatorInfo
}{
{
name: "should return operator info",
mocksInitializationFunc: func(mockAvsRegistryReader *chainiomocks.MockAVSReader, mockOperatorsInfoService *servicemocks.MockOperatorsInfoService) {
mockAvsRegistryReader.EXPECT().GetOperatorFromId(gomock.Any(), testOperator.operatorId).Return(testOperator.operatorAddr, nil)
mockOperatorsInfoService.EXPECT().GetOperatorInfo(gomock.Any(), testOperator.operatorAddr).Return(testOperator.operatorInfo, true)
},
queryOperatorId: testOperator.operatorId,
name: "should return operator info",
operator: &testOperator1,
queryOperatorId: testOperator1.operatorId,
wantErr: nil,
wantOperatorInfo: testOperator.operatorInfo,
wantOperatorInfo: testOperator1.operatorInfo,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mocks
mockCtrl := gomock.NewController(t)
mockAvsRegistryReader := chainiomocks.NewMockAVSReader(mockCtrl)
mockOperatorsInfoService := servicemocks.NewMockOperatorsInfoService(mockCtrl)
mockAvsRegistryReader := newFakeAVSRegistryReader(tt.operator, nil)
mockOperatorsInfoService := newFakeOperatorInfoService(tt.operator.operatorInfo)

if tt.mocksInitializationFunc != nil {
tt.mocksInitializationFunc(mockAvsRegistryReader, mockOperatorsInfoService)
}
// Create a new instance of the avsregistry service
service := NewAvsRegistryServiceChainCaller(mockAvsRegistryReader, mockOperatorsInfoService, logger)

// Call the GetOperatorPubkeys method with the test operator address
gotOperatorInfo, gotErr := service.getOperatorInfo(context.Background(), tt.queryOperatorId)
if tt.wantErr != gotErr {
if !errors.Is(gotErr, tt.wantErr) {
t.Fatalf("GetOperatorPubkeys returned wrong error. Got: %v, want: %v.", gotErr, tt.wantErr)
}
if tt.wantErr == nil && !reflect.DeepEqual(tt.wantOperatorInfo, gotOperatorInfo) {
Expand All @@ -83,7 +146,7 @@ func TestAvsRegistryServiceChainCaller_getOperatorPubkeys(t *testing.T) {

func TestAvsRegistryServiceChainCaller_GetOperatorsAvsState(t *testing.T) {
logger := logging.NewNoopLogger()
testOperator := testOperator{
testOperator1 := testOperator{
operatorAddr: common.HexToAddress("0x1"),
operatorId: types.OperatorId{1},
operatorInfo: types.OperatorInfo{
Expand All @@ -97,33 +160,22 @@ func TestAvsRegistryServiceChainCaller_GetOperatorsAvsState(t *testing.T) {

var tests = []struct {
name string
mocksInitializationFunc func(*chainiomocks.MockAVSReader, *servicemocks.MockOperatorsInfoService)
queryQuorumNumbers types.QuorumNums
queryBlockNum types.BlockNum
wantErr error
wantOperatorsAvsStateDict map[types.OperatorId]types.OperatorAvsState
operator *testOperator
}{
{
name: "should return operatorsAvsState",
mocksInitializationFunc: func(mockAvsRegistryReader *chainiomocks.MockAVSReader, mockOperatorsInfoService *servicemocks.MockOperatorsInfoService) {
mockAvsRegistryReader.EXPECT().GetOperatorsStakeInQuorumsAtBlock(gomock.Any(), types.QuorumNums{1}, types.BlockNum(1)).Return([][]opstateretrievar.OperatorStateRetrieverOperator{
{
{
OperatorId: testOperator.operatorId,
Stake: big.NewInt(123),
},
},
}, nil)
mockAvsRegistryReader.EXPECT().GetOperatorFromId(gomock.Any(), testOperator.operatorId).Return(testOperator.operatorAddr, nil)
mockOperatorsInfoService.EXPECT().GetOperatorInfo(gomock.Any(), testOperator.operatorAddr).Return(testOperator.operatorInfo, true)
},
name: "should return operatorsAvsState",
queryQuorumNumbers: types.QuorumNums{1},
operator: &testOperator1,
queryBlockNum: 1,
wantErr: nil,
wantOperatorsAvsStateDict: map[types.OperatorId]types.OperatorAvsState{
testOperator.operatorId: {
OperatorId: testOperator.operatorId,
OperatorInfo: testOperator.operatorInfo,
testOperator1.operatorId: {
OperatorId: testOperator1.operatorId,
OperatorInfo: testOperator1.operatorInfo,
StakePerQuorum: map[types.QuorumNum]types.StakeAmount{1: big.NewInt(123)},
BlockNumber: 1,
},
Expand All @@ -134,19 +186,15 @@ func TestAvsRegistryServiceChainCaller_GetOperatorsAvsState(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mocks
mockCtrl := gomock.NewController(t)
mockAvsRegistryReader := chainiomocks.NewMockAVSReader(mockCtrl)
mockOperatorsInfoService := servicemocks.NewMockOperatorsInfoService(mockCtrl)
mockAvsRegistryReader := newFakeAVSRegistryReader(tt.operator, nil)
mockOperatorsInfoService := newFakeOperatorInfoService(tt.operator.operatorInfo)

if tt.mocksInitializationFunc != nil {
tt.mocksInitializationFunc(mockAvsRegistryReader, mockOperatorsInfoService)
}
// Create a new instance of the avsregistry service
service := NewAvsRegistryServiceChainCaller(mockAvsRegistryReader, mockOperatorsInfoService, logger)

// Call the GetOperatorPubkeys method with the test operator address
gotOperatorsAvsStateDict, gotErr := service.GetOperatorsAvsStateAtBlock(context.Background(), tt.queryQuorumNumbers, tt.queryBlockNum)
if tt.wantErr != gotErr {
if !errors.Is(gotErr, tt.wantErr) {
t.Fatalf("GetOperatorsAvsState returned wrong error. Got: %v, want: %v.", gotErr, tt.wantErr)
}
if tt.wantErr == nil && !reflect.DeepEqual(tt.wantOperatorsAvsStateDict, gotOperatorsAvsStateDict) {
Expand All @@ -158,7 +206,7 @@ func TestAvsRegistryServiceChainCaller_GetOperatorsAvsState(t *testing.T) {

func TestAvsRegistryServiceChainCaller_GetQuorumsAvsState(t *testing.T) {
logger := logging.NewNoopLogger()
testOperator := testOperator{
testOperator1 := testOperator{
operatorAddr: common.HexToAddress("0x1"),
operatorId: types.OperatorId{1},
operatorInfo: types.OperatorInfo{
Expand All @@ -172,27 +220,16 @@ func TestAvsRegistryServiceChainCaller_GetQuorumsAvsState(t *testing.T) {

var tests = []struct {
name string
mocksInitializationFunc func(*chainiomocks.MockAVSReader, *servicemocks.MockOperatorsInfoService)
queryQuorumNumbers types.QuorumNums
queryBlockNum types.BlockNum
wantErr error
wantQuorumsAvsStateDict map[types.QuorumNum]types.QuorumAvsState
operator *testOperator
}{
{
name: "should return operatorsAvsState",
mocksInitializationFunc: func(mockAvsRegistryReader *chainiomocks.MockAVSReader, mockOperatorsInfoService *servicemocks.MockOperatorsInfoService) {
mockAvsRegistryReader.EXPECT().GetOperatorsStakeInQuorumsAtBlock(gomock.Any(), types.QuorumNums{1}, types.BlockNum(1)).Return([][]opstateretrievar.OperatorStateRetrieverOperator{
{
{
OperatorId: testOperator.operatorId,
Stake: big.NewInt(123),
},
},
}, nil)
mockAvsRegistryReader.EXPECT().GetOperatorFromId(gomock.Any(), testOperator.operatorId).Return(testOperator.operatorAddr, nil)
mockOperatorsInfoService.EXPECT().GetOperatorInfo(gomock.Any(), testOperator.operatorAddr).Return(testOperator.operatorInfo, true)
},
name: "should return operatorsAvsState",
queryQuorumNumbers: types.QuorumNums{1},
operator: &testOperator1,
queryBlockNum: 1,
wantErr: nil,
wantQuorumsAvsStateDict: map[types.QuorumNum]types.QuorumAvsState{
Expand All @@ -209,19 +246,15 @@ func TestAvsRegistryServiceChainCaller_GetQuorumsAvsState(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mocks
mockCtrl := gomock.NewController(t)
mockAvsRegistryReader := chainiomocks.NewMockAVSReader(mockCtrl)
mockOperatorsInfoService := servicemocks.NewMockOperatorsInfoService(mockCtrl)
mockAvsRegistryReader := newFakeAVSRegistryReader(tt.operator, nil)
mockOperatorsInfoService := newFakeOperatorInfoService(tt.operator.operatorInfo)

if tt.mocksInitializationFunc != nil {
tt.mocksInitializationFunc(mockAvsRegistryReader, mockOperatorsInfoService)
}
// Create a new instance of the avsregistry service
service := NewAvsRegistryServiceChainCaller(mockAvsRegistryReader, mockOperatorsInfoService, logger)

// Call the GetOperatorPubkeys method with the test operator address
aggG1PubkeyPerQuorum, gotErr := service.GetQuorumsAvsStateAtBlock(context.Background(), tt.queryQuorumNumbers, tt.queryBlockNum)
if tt.wantErr != gotErr {
if !errors.Is(gotErr, tt.wantErr) {
t.Fatalf("GetOperatorsAvsState returned wrong error. Got: %v, want: %v.", gotErr, tt.wantErr)
}
if tt.wantErr == nil && !reflect.DeepEqual(tt.wantQuorumsAvsStateDict, aggG1PubkeyPerQuorum) {
Expand Down
13 changes: 0 additions & 13 deletions services/gen.go

This file was deleted.

88 changes: 0 additions & 88 deletions services/mocks/avsregistry.go

This file was deleted.

Loading

0 comments on commit fb3a19c

Please sign in to comment.