Skip to content

Commit

Permalink
feat: make persisting allowlist compatible with older contract
Browse files Browse the repository at this point in the history
  • Loading branch information
agparadiso committed Jan 26, 2024
1 parent c76a01e commit 98744a8
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 1 deletion.
11 changes: 11 additions & 0 deletions core/services/gateway/handlers/functions/allowlist/allowlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,17 @@ func (a *onchainAllowlist) updateFromContractV1(ctx context.Context, blockNum *b
if err != nil {
return errors.Wrap(err, "error calling GetAllAllowedSenders")
}

err = a.orm.PurgeAllowedSenders()
if err != nil {
a.lggr.Errorf("failed to purge allowedSenderList: %w", err)
}

err = a.orm.CreateAllowedSenders(allowedSenderList)
if err != nil {
a.lggr.Errorf("failed to update stored allowedSenderList: %w", err)
}

} else {
err = a.syncBlockedSenders(ctx, tosContract, blockNum)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ func TestAllowlist_UpdateAndCheck(t *testing.T) {
}

orm := amocks.NewORM(t)
orm.On("PurgeAllowedSenders").Times(1).Return(nil)
orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil)

allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t))
require.NoError(t, err)

Expand Down Expand Up @@ -99,7 +102,9 @@ func TestAllowlist_UpdatePeriodically(t *testing.T) {
}

orm := amocks.NewORM(t)
orm.On("PurgeAllowedSenders").Times(1).Return(nil)
orm.On("GetAllowedSenders", uint(0), uint(1000)).Return([]common.Address{}, nil)
orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil)

allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t))
require.NoError(t, err)
Expand All @@ -114,6 +119,7 @@ func TestAllowlist_UpdatePeriodically(t *testing.T) {
return allowlist.Allow(common.HexToAddress(addr1)) && !allowlist.Allow(common.HexToAddress(addr3))
}, testutils.WaitTimeout(t), time.Second).Should(gomega.BeTrue())
}

func TestAllowlist_UpdateFromContract(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -151,7 +157,7 @@ func TestAllowlist_UpdateFromContract(t *testing.T) {
}, testutils.WaitTimeout(t), time.Second).Should(gomega.BeTrue())
})

t.Run("OK-fetch_complete_list_of_allowed_senders_without_storing", func(t *testing.T) {
t.Run("OK-fetch_complete_list_of_allowed_senders", func(t *testing.T) {
ctx, cancel := context.WithCancel(testutils.Context(t))
client := mocks.NewClient(t)
client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil)
Expand All @@ -171,6 +177,9 @@ func TestAllowlist_UpdateFromContract(t *testing.T) {
}

orm := amocks.NewORM(t)
orm.On("PurgeAllowedSenders").Times(1).Return(nil)
orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil)

allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t))
require.NoError(t, err)

Expand Down
24 changes: 24 additions & 0 deletions core/services/gateway/handlers/functions/allowlist/mocks/orm.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions core/services/gateway/handlers/functions/allowlist/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type ORM interface {
GetAllowedSenders(offset, limit uint, qopts ...pg.QOpt) ([]common.Address, error)
CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg.QOpt) error
DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg.QOpt) error
PurgeAllowedSenders(qopts ...pg.QOpt) error
}

type orm struct {
Expand Down Expand Up @@ -91,6 +92,8 @@ func (o *orm) CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg.
return nil
}

// DeleteAllowedSenders is used to remove blocked senders from the functions_allowlist table.
// This is achieved by specifying a list of blockedSenders to remove.
func (o *orm) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg.QOpt) error {
var valuesPlaceholder []string
for i := 1; i <= len(blockedSenders); i++ {
Expand Down Expand Up @@ -121,3 +124,24 @@ func (o *orm) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg.

return nil
}

// PurgeAllowedSenders will remove all the allowed senders for the configured orm routerContractAddress
func (o *orm) PurgeAllowedSenders(qopts ...pg.QOpt) error {
stmt := fmt.Sprintf(`
DELETE FROM %s
WHERE router_contract_address = $1;`, tableName)

res, err := o.q.WithOpts(qopts...).Exec(stmt, o.routerContractAddress)
if err != nil {
return err
}

rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}

o.lggr.Debugf("Successfully purged allowed senders for routerContractAddress: %s. rowsAffected: %d", o.routerContractAddress, rowsAffected)

return nil
}
65 changes: 65 additions & 0 deletions core/services/gateway/handlers/functions/allowlist/orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,71 @@ func TestORM_DeleteAllowedSenders(t *testing.T) {
})
}

func TestORM_PurgeAllowedSenders(t *testing.T) {
t.Parallel()

t.Run("OK-purge_allowed_list", func(t *testing.T) {
orm, err := setupORM(t)
require.NoError(t, err)
add1 := testutils.NewAddress()
add2 := testutils.NewAddress()
add3 := testutils.NewAddress()
err = orm.CreateAllowedSenders([]common.Address{add1, add2, add3})
require.NoError(t, err)

results, err := orm.GetAllowedSenders(0, 10)
require.NoError(t, err)
require.Equal(t, 3, len(results), "incorrect results length")
require.Equal(t, add1, results[0])

err = orm.PurgeAllowedSenders()
require.NoError(t, err)

results, err = orm.GetAllowedSenders(0, 10)
require.NoError(t, err)
require.Equal(t, 0, len(results), "incorrect results length")
})

t.Run("OK-purge_allowed_list_for_contract_address", func(t *testing.T) {
orm1, err := setupORM(t)
require.NoError(t, err)
add1 := testutils.NewAddress()
add2 := testutils.NewAddress()
err = orm1.CreateAllowedSenders([]common.Address{add1, add2})
require.NoError(t, err)

results, err := orm1.GetAllowedSenders(0, 10)
require.NoError(t, err)
require.Equal(t, 2, len(results), "incorrect results length")
require.Equal(t, add1, results[0])

orm2, err := setupORM(t)
require.NoError(t, err)
add3 := testutils.NewAddress()
add4 := testutils.NewAddress()
err = orm2.CreateAllowedSenders([]common.Address{add3, add4})
require.NoError(t, err)

results, err = orm2.GetAllowedSenders(0, 10)
require.NoError(t, err)
require.Equal(t, 2, len(results), "incorrect results length")
require.Equal(t, add3, results[0])

err = orm2.PurgeAllowedSenders()
require.NoError(t, err)

results, err = orm2.GetAllowedSenders(0, 10)
require.NoError(t, err)
require.Equal(t, 0, len(results), "incorrect results length")

results, err = orm1.GetAllowedSenders(0, 10)
require.NoError(t, err)
require.Equal(t, 2, len(results), "incorrect results length")
require.Equal(t, add1, results[0])
require.Equal(t, add2, results[1])
})
}

func Test_NewORM(t *testing.T) {
t.Run("OK-create_ORM", func(t *testing.T) {
_, err := allowlist.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), pgtest.NewQConfig(true), testutils.NewAddress())
Expand Down

0 comments on commit 98744a8

Please sign in to comment.