From b28c4b39ad7ed479901ca7204afe8be22f38ac49 Mon Sep 17 00:00:00 2001 From: Gabriel Paradiso Date: Fri, 26 Jan 2024 15:04:32 +0100 Subject: [PATCH] feat: make persisting allowlist compatible with older contract --- .../handlers/functions/allowlist/allowlist.go | 11 ++++ .../functions/allowlist/allowlist_test.go | 11 +++- .../handlers/functions/allowlist/mocks/orm.go | 24 +++++++ .../handlers/functions/allowlist/orm.go | 24 +++++++ .../handlers/functions/allowlist/orm_test.go | 65 +++++++++++++++++++ 5 files changed, 134 insertions(+), 1 deletion(-) diff --git a/core/services/gateway/handlers/functions/allowlist/allowlist.go b/core/services/gateway/handlers/functions/allowlist/allowlist.go index 20dc92ced70..cd92e0c984f 100644 --- a/core/services/gateway/handlers/functions/allowlist/allowlist.go +++ b/core/services/gateway/handlers/functions/allowlist/allowlist.go @@ -219,6 +219,17 @@ func (a *onchainAllowlist) updateFromContractV1(ctx context.Context, blockNum *b if err != nil { return errors.Wrap(err, "error calling GetAllAllowedSenders") } + + err = a.orm.PurgeAllowedSenders() + if err != nil { + a.lggr.Errorf("failed to purge allowedSenderList: %w", err) + } + + err = a.orm.CreateAllowedSenders(allowedSenderList) + if err != nil { + a.lggr.Errorf("failed to update stored allowedSenderList: %w", err) + } + } else { err = a.syncBlockedSenders(ctx, tosContract, blockNum) if err != nil { diff --git a/core/services/gateway/handlers/functions/allowlist/allowlist_test.go b/core/services/gateway/handlers/functions/allowlist/allowlist_test.go index e8cbca80b94..103ed412433 100644 --- a/core/services/gateway/handlers/functions/allowlist/allowlist_test.go +++ b/core/services/gateway/handlers/functions/allowlist/allowlist_test.go @@ -50,6 +50,9 @@ func TestAllowlist_UpdateAndCheck(t *testing.T) { } orm := amocks.NewORM(t) + orm.On("PurgeAllowedSenders").Times(1).Return(nil) + orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -99,7 +102,9 @@ func TestAllowlist_UpdatePeriodically(t *testing.T) { } orm := amocks.NewORM(t) + orm.On("PurgeAllowedSenders").Times(1).Return(nil) orm.On("GetAllowedSenders", uint(0), uint(1000)).Return([]common.Address{}, nil) + orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -114,6 +119,7 @@ func TestAllowlist_UpdatePeriodically(t *testing.T) { return allowlist.Allow(common.HexToAddress(addr1)) && !allowlist.Allow(common.HexToAddress(addr3)) }, testutils.WaitTimeout(t), time.Second).Should(gomega.BeTrue()) } + func TestAllowlist_UpdateFromContract(t *testing.T) { t.Parallel() @@ -151,7 +157,7 @@ func TestAllowlist_UpdateFromContract(t *testing.T) { }, testutils.WaitTimeout(t), time.Second).Should(gomega.BeTrue()) }) - t.Run("OK-fetch_complete_list_of_allowed_senders_without_storing", func(t *testing.T) { + t.Run("OK-fetch_complete_list_of_allowed_senders", func(t *testing.T) { ctx, cancel := context.WithCancel(testutils.Context(t)) client := mocks.NewClient(t) client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil) @@ -171,6 +177,9 @@ func TestAllowlist_UpdateFromContract(t *testing.T) { } orm := amocks.NewORM(t) + orm.On("PurgeAllowedSenders").Times(1).Return(nil) + orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) diff --git a/core/services/gateway/handlers/functions/allowlist/mocks/orm.go b/core/services/gateway/handlers/functions/allowlist/mocks/orm.go index c2ba27c3a24..daff33d8902 100644 --- a/core/services/gateway/handlers/functions/allowlist/mocks/orm.go +++ b/core/services/gateway/handlers/functions/allowlist/mocks/orm.go @@ -101,6 +101,30 @@ func (_m *ORM) GetAllowedSenders(offset uint, limit uint, qopts ...pg.QOpt) ([]c return r0, r1 } +// PurgeAllowedSenders provides a mock function with given fields: qopts +func (_m *ORM) PurgeAllowedSenders(qopts ...pg.QOpt) error { + _va := make([]interface{}, len(qopts)) + for _i := range qopts { + _va[_i] = qopts[_i] + } + var _ca []interface{} + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for PurgeAllowedSenders") + } + + var r0 error + if rf, ok := ret.Get(0).(func(...pg.QOpt) error); ok { + r0 = rf(qopts...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // NewORM creates a new instance of ORM. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewORM(t interface { diff --git a/core/services/gateway/handlers/functions/allowlist/orm.go b/core/services/gateway/handlers/functions/allowlist/orm.go index 07ee1ea3b3b..ccacec81a43 100644 --- a/core/services/gateway/handlers/functions/allowlist/orm.go +++ b/core/services/gateway/handlers/functions/allowlist/orm.go @@ -18,6 +18,7 @@ type ORM interface { GetAllowedSenders(offset, limit uint, qopts ...pg.QOpt) ([]common.Address, error) CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg.QOpt) error DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg.QOpt) error + PurgeAllowedSenders(qopts ...pg.QOpt) error } type orm struct { @@ -91,6 +92,8 @@ func (o *orm) CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg. return nil } +// DeleteAllowedSenders is used to remove blocked senders from the functions_allowlist table. +// This is achieved by specifying a list of blockedSenders to remove. func (o *orm) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg.QOpt) error { var valuesPlaceholder []string for i := 1; i <= len(blockedSenders); i++ { @@ -121,3 +124,24 @@ func (o *orm) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg. return nil } + +// PurgeAllowedSenders will remove all the allowed senders for the configured orm routerContractAddress +func (o *orm) PurgeAllowedSenders(qopts ...pg.QOpt) error { + stmt := fmt.Sprintf(` + DELETE FROM %s + WHERE router_contract_address = $1;`, tableName) + + res, err := o.q.WithOpts(qopts...).Exec(stmt, o.routerContractAddress) + if err != nil { + return err + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return err + } + + o.lggr.Debugf("Successfully purged allowed senders for routerContractAddress: %s. rowsAffected: %d", o.routerContractAddress, rowsAffected) + + return nil +} diff --git a/core/services/gateway/handlers/functions/allowlist/orm_test.go b/core/services/gateway/handlers/functions/allowlist/orm_test.go index 0f63e83cd5f..1d357616fab 100644 --- a/core/services/gateway/handlers/functions/allowlist/orm_test.go +++ b/core/services/gateway/handlers/functions/allowlist/orm_test.go @@ -174,6 +174,71 @@ func TestORM_DeleteAllowedSenders(t *testing.T) { }) } +func TestORM_PurgeAllowedSenders(t *testing.T) { + t.Parallel() + + t.Run("OK-purge_allowed_list", func(t *testing.T) { + orm, err := setupORM(t) + require.NoError(t, err) + add1 := testutils.NewAddress() + add2 := testutils.NewAddress() + add3 := testutils.NewAddress() + err = orm.CreateAllowedSenders([]common.Address{add1, add2, add3}) + require.NoError(t, err) + + results, err := orm.GetAllowedSenders(0, 10) + require.NoError(t, err) + require.Equal(t, 3, len(results), "incorrect results length") + require.Equal(t, add1, results[0]) + + err = orm.PurgeAllowedSenders() + require.NoError(t, err) + + results, err = orm.GetAllowedSenders(0, 10) + require.NoError(t, err) + require.Equal(t, 0, len(results), "incorrect results length") + }) + + t.Run("OK-purge_allowed_list_for_contract_address", func(t *testing.T) { + orm1, err := setupORM(t) + require.NoError(t, err) + add1 := testutils.NewAddress() + add2 := testutils.NewAddress() + err = orm1.CreateAllowedSenders([]common.Address{add1, add2}) + require.NoError(t, err) + + results, err := orm1.GetAllowedSenders(0, 10) + require.NoError(t, err) + require.Equal(t, 2, len(results), "incorrect results length") + require.Equal(t, add1, results[0]) + + orm2, err := setupORM(t) + require.NoError(t, err) + add3 := testutils.NewAddress() + add4 := testutils.NewAddress() + err = orm2.CreateAllowedSenders([]common.Address{add3, add4}) + require.NoError(t, err) + + results, err = orm2.GetAllowedSenders(0, 10) + require.NoError(t, err) + require.Equal(t, 2, len(results), "incorrect results length") + require.Equal(t, add3, results[0]) + + err = orm2.PurgeAllowedSenders() + require.NoError(t, err) + + results, err = orm2.GetAllowedSenders(0, 10) + require.NoError(t, err) + require.Equal(t, 0, len(results), "incorrect results length") + + results, err = orm1.GetAllowedSenders(0, 10) + require.NoError(t, err) + require.Equal(t, 2, len(results), "incorrect results length") + require.Equal(t, add1, results[0]) + require.Equal(t, add2, results[1]) + }) +} + func Test_NewORM(t *testing.T) { t.Run("OK-create_ORM", func(t *testing.T) { _, err := allowlist.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), pgtest.NewQConfig(true), testutils.NewAddress())