Skip to content

Commit

Permalink
Refactor forwarder ORM
Browse files Browse the repository at this point in the history
  • Loading branch information
DylanTinianov committed Feb 21, 2024
1 parent 423529c commit 3183b21
Show file tree
Hide file tree
Showing 13 changed files with 108 additions and 87 deletions.
10 changes: 6 additions & 4 deletions core/chains/evm/forwarders/forwarder_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ type FwdMgr struct {
wg sync.WaitGroup
}

func NewFwdMgr(db *sqlx.DB, client evmclient.Client, logpoller evmlogpoller.LogPoller, l logger.Logger, cfg Config, dbConfig pg.QConfig) *FwdMgr {
func NewFwdMgr(db *sqlx.DB, client evmclient.Client, logpoller evmlogpoller.LogPoller, l logger.Logger, cfg Config) *FwdMgr {
lggr := logger.Sugared(logger.Named(l, "EVMForwarderManager"))
fwdMgr := FwdMgr{
logger: lggr,
cfg: cfg,
evmClient: client,
ORM: NewORM(db, lggr, dbConfig),
ORM: NewORM(db),
logpoller: logpoller,
sendersCache: make(map[common.Address][]common.Address),
}
Expand All @@ -80,7 +80,7 @@ func (f *FwdMgr) Start(ctx context.Context) error {
f.logger.Debug("Initializing EVM forwarder manager")
chainId := f.evmClient.ConfiguredChainID()

fwdrs, err := f.ORM.FindForwardersByChain(big.Big(*chainId))
fwdrs, err := f.ORM.FindForwardersByChain(ctx, big.Big(*chainId))
if err != nil {
return errors.Wrapf(err, "Failed to retrieve forwarders for chain %d", chainId)
}
Expand Down Expand Up @@ -113,7 +113,9 @@ func FilterName(addr common.Address) string {

func (f *FwdMgr) ForwarderFor(addr common.Address) (forwarder common.Address, err error) {
// Gets forwarders for current chain.
fwdrs, err := f.ORM.FindForwardersByChain(big.Big(*f.evmClient.ConfiguredChainID()))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
fwdrs, err := f.ORM.FindForwardersByChain(ctx, big.Big(*f.evmClient.ConfiguredChainID()))
if err != nil {
return common.Address{}, err
}
Expand Down
25 changes: 14 additions & 11 deletions core/chains/evm/forwarders/forwarder_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"testing"
"time"

"github.com/smartcontractkit/chainlink-common/pkg/sqlutil"

"github.com/ethereum/go-ethereum/accounts/abi/bind/backends"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
Expand All @@ -26,7 +28,6 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils/evmtest"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest"
"github.com/smartcontractkit/chainlink/v2/core/services/pg"
)

var GetAuthorisedSendersABI = evmtypes.MustGetABI(authorized_receiver.AuthorizedReceiverABI).Methods["getAuthorizedSenders"]
Expand All @@ -39,6 +40,7 @@ func TestFwdMgr_MaybeForwardTransaction(t *testing.T) {
cfg := configtest.NewTestGeneralConfig(t)
evmcfg := evmtest.NewChainScopedConfig(t, cfg)
owner := testutils.MustNewSimTransactor(t)
ctx := testutils.Context(t)

ec := backends.NewSimulatedBackend(map[common.Address]core.GenesisAccount{
owner.From: {
Expand All @@ -61,12 +63,12 @@ func TestFwdMgr_MaybeForwardTransaction(t *testing.T) {

evmClient := client.NewSimulatedBackendClient(t, ec, testutils.FixtureChainID)
lp := logpoller.NewLogPoller(logpoller.NewORM(testutils.FixtureChainID, db, lggr, pgtest.NewQConfig(true)), evmClient, lggr, 100*time.Millisecond, false, 2, 3, 2, 1000)
fwdMgr := forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM(), evmcfg.Database())
fwdMgr.ORM = forwarders.NewORM(db, logger.Test(t), cfg.Database())
fwdMgr := forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM())
fwdMgr.ORM = forwarders.NewORM(db)

fwd, err := fwdMgr.ORM.CreateForwarder(forwarderAddr, ubig.Big(*testutils.FixtureChainID))
fwd, err := fwdMgr.ORM.CreateForwarder(ctx, forwarderAddr, ubig.Big(*testutils.FixtureChainID))
require.NoError(t, err)
lst, err := fwdMgr.ORM.FindForwardersByChain(ubig.Big(*testutils.FixtureChainID))
lst, err := fwdMgr.ORM.FindForwardersByChain(ctx, ubig.Big(*testutils.FixtureChainID))
require.NoError(t, err)
require.Equal(t, len(lst), 1)
require.Equal(t, lst[0].Address, forwarderAddr)
Expand All @@ -79,22 +81,23 @@ func TestFwdMgr_MaybeForwardTransaction(t *testing.T) {
require.NoError(t, err)

cleanupCalled := false
cleanup := func(tx pg.Queryer, evmChainId int64, addr common.Address) error {
cleanup := func(tx sqlutil.Queryer, evmChainId int64, addr common.Address) error {
require.Equal(t, testutils.FixtureChainID.Int64(), evmChainId)
require.Equal(t, forwarderAddr, addr)
require.NotNil(t, tx)
cleanupCalled = true
return nil
}

err = fwdMgr.ORM.DeleteForwarder(fwd.ID, cleanup)
err = fwdMgr.ORM.DeleteForwarder(ctx, fwd.ID, cleanup)
assert.NoError(t, err)
assert.True(t, cleanupCalled)
}

func TestFwdMgr_AccountUnauthorizedToForward_SkipsForwarding(t *testing.T) {
lggr := logger.Test(t)
db := pgtest.NewSqlxDB(t)
ctx := testutils.Context(t)
cfg := configtest.NewTestGeneralConfig(t)
evmcfg := evmtest.NewChainScopedConfig(t, cfg)
owner := testutils.MustNewSimTransactor(t)
Expand All @@ -114,12 +117,12 @@ func TestFwdMgr_AccountUnauthorizedToForward_SkipsForwarding(t *testing.T) {

evmClient := client.NewSimulatedBackendClient(t, ec, testutils.FixtureChainID)
lp := logpoller.NewLogPoller(logpoller.NewORM(testutils.FixtureChainID, db, lggr, pgtest.NewQConfig(true)), evmClient, lggr, 100*time.Millisecond, false, 2, 3, 2, 1000)
fwdMgr := forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM(), evmcfg.Database())
fwdMgr.ORM = forwarders.NewORM(db, logger.Test(t), cfg.Database())
fwdMgr := forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM())
fwdMgr.ORM = forwarders.NewORM(db)

_, err = fwdMgr.ORM.CreateForwarder(forwarderAddr, ubig.Big(*testutils.FixtureChainID))
_, err = fwdMgr.ORM.CreateForwarder(ctx, forwarderAddr, ubig.Big(*testutils.FixtureChainID))
require.NoError(t, err)
lst, err := fwdMgr.ORM.FindForwardersByChain(ubig.Big(*testutils.FixtureChainID))
lst, err := fwdMgr.ORM.FindForwardersByChain(ctx, ubig.Big(*testutils.FixtureChainID))
require.NoError(t, err)
require.Equal(t, len(lst), 1)
require.Equal(t, lst[0].Address, forwarderAddr)
Expand Down
79 changes: 42 additions & 37 deletions core/chains/evm/forwarders/orm.go
Original file line number Diff line number Diff line change
@@ -1,105 +1,110 @@
package forwarders

import (
"context"
"database/sql"

"github.com/smartcontractkit/chainlink-common/pkg/sqlutil"

"github.com/ethereum/go-ethereum/common"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big"
"github.com/smartcontractkit/chainlink/v2/core/services/pg"
)

//go:generate mockery --quiet --name ORM --output ./mocks/ --case=underscore

type ORM interface {
CreateForwarder(addr common.Address, evmChainId big.Big) (fwd Forwarder, err error)
FindForwarders(offset, limit int) ([]Forwarder, int, error)
FindForwardersByChain(evmChainId big.Big) ([]Forwarder, error)
DeleteForwarder(id int64, cleanup func(tx pg.Queryer, evmChainId int64, addr common.Address) error) error
FindForwardersInListByChain(evmChainId big.Big, addrs []common.Address) ([]Forwarder, error)
CreateForwarder(ctx context.Context, addr common.Address, evmChainId big.Big) (fwd Forwarder, err error)
FindForwarders(ctx context.Context, offset, limit int) ([]Forwarder, int, error)
FindForwardersByChain(ctx context.Context, evmChainId big.Big) ([]Forwarder, error)
DeleteForwarder(ctx context.Context, id int64, cleanup func(tx sqlutil.Queryer, evmChainId int64, addr common.Address) error) error
FindForwardersInListByChain(ctx context.Context, evmChainId big.Big, addrs []common.Address) ([]Forwarder, error)
}

type orm struct {
q pg.Q
type DbORM struct {
db sqlutil.Queryer
}

var _ ORM = (*orm)(nil)
var _ ORM = &DbORM{}

func NewORM(db sqlutil.Queryer) *DbORM {
return &DbORM{db: db}
}

func NewORM(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig) *orm {
return &orm{pg.NewQ(db, lggr, cfg)}
func (o *DbORM) Transaction(ctx context.Context, fn func(*DbORM) error) (err error) {
return sqlutil.Transact(ctx, o.new, o.db, nil, fn)
}

// new returns a NewORM like o, but backed by q.
func (o *DbORM) new(q sqlutil.Queryer) *DbORM { return NewORM(q) }

// CreateForwarder creates the Forwarder address associated with the current EVM chain id.
func (o *orm) CreateForwarder(addr common.Address, evmChainId big.Big) (fwd Forwarder, err error) {
func (o *DbORM) CreateForwarder(ctx context.Context, addr common.Address, evmChainId big.Big) (fwd Forwarder, err error) {
sql := `INSERT INTO evm.forwarders (address, evm_chain_id, created_at, updated_at) VALUES ($1, $2, now(), now()) RETURNING *`
err = o.q.Get(&fwd, sql, addr, evmChainId)
err = o.db.GetContext(ctx, &fwd, sql, addr, evmChainId)
return fwd, err
}

// DeleteForwarder removes a forwarder address.
// If cleanup is non-nil, it can be used to perform any chain- or contract-specific cleanup that need to happen atomically
// on forwarder deletion. If cleanup returns an error, forwarder deletion will be aborted.
func (o *orm) DeleteForwarder(id int64, cleanup func(tx pg.Queryer, evmChainID int64, addr common.Address) error) (err error) {
func (o *DbORM) DeleteForwarder(ctx context.Context, id int64, cleanup func(tx sqlutil.Queryer, evmChainID int64, addr common.Address) error) (err error) {
var dest struct {
EvmChainId int64
Address common.Address
}

var rowsAffected int64
err = o.q.Transaction(func(tx pg.Queryer) error {
err = tx.Get(&dest, `SELECT evm_chain_id, address FROM evm.forwarders WHERE id = $1`, id)
return o.Transaction(ctx, func(orm *DbORM) error {
err := orm.db.GetContext(ctx, &dest, `SELECT evm_chain_id, address FROM evm.forwarders WHERE id = $1`, id)
if err != nil {
return err
}
if cleanup != nil {
if err = cleanup(tx, dest.EvmChainId, dest.Address); err != nil {
if err = cleanup(orm.db, dest.EvmChainId, dest.Address); err != nil {
return err
}
}

result, err2 := o.q.Exec(`DELETE FROM evm.forwarders WHERE id = $1`, id)
result, err := orm.db.ExecContext(ctx, `DELETE FROM evm.forwarders WHERE id = $1`, id)
// If the forwarder wasn't found, we still want to delete the filter.
// In that case, the transaction must return nil, even though DeleteForwarder
// will return sql.ErrNoRows
if err2 != nil && !errors.Is(err2, sql.ErrNoRows) {
return err2
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
rowsAffected, err2 = result.RowsAffected()

return err2
rowsAffected, err = result.RowsAffected()
if err == nil && rowsAffected == 0 {
err = sql.ErrNoRows
}
return err
})

if err == nil && rowsAffected == 0 {
err = sql.ErrNoRows
}
return err
}

// FindForwarders returns all forwarder addresses from offset up until limit.
func (o *orm) FindForwarders(offset, limit int) (fwds []Forwarder, count int, err error) {
func (o *DbORM) FindForwarders(ctx context.Context, offset, limit int) (fwds []Forwarder, count int, err error) {
sql := `SELECT count(*) FROM evm.forwarders`
if err = o.q.Get(&count, sql); err != nil {
if err = o.db.GetContext(ctx, &count, sql); err != nil {
return
}

sql = `SELECT * FROM evm.forwarders ORDER BY created_at DESC, id DESC LIMIT $1 OFFSET $2`
if err = o.q.Select(&fwds, sql, limit, offset); err != nil {
if err = o.db.SelectContext(ctx, &fwds, sql, limit, offset); err != nil {
return
}
return
}

// FindForwardersByChain returns all forwarder addresses for a chain.
func (o *orm) FindForwardersByChain(evmChainId big.Big) (fwds []Forwarder, err error) {
func (o *DbORM) FindForwardersByChain(ctx context.Context, evmChainId big.Big) (fwds []Forwarder, err error) {
sql := `SELECT * FROM evm.forwarders where evm_chain_id = $1 ORDER BY created_at DESC, id DESC`
err = o.q.Select(&fwds, sql, evmChainId)
err = o.db.SelectContext(ctx, &fwds, sql, evmChainId)
return
}

func (o *orm) FindForwardersInListByChain(evmChainId big.Big, addrs []common.Address) ([]Forwarder, error) {
func (o *DbORM) FindForwardersInListByChain(ctx context.Context, evmChainId big.Big, addrs []common.Address) ([]Forwarder, error) {
var fwdrs []Forwarder

arg := map[string]interface{}{
Expand All @@ -124,8 +129,8 @@ func (o *orm) FindForwardersInListByChain(evmChainId big.Big, addrs []common.Add
return nil, errors.Wrap(err, "Failed to run sqlx.IN on query")
}

query = o.q.Rebind(query)
err = o.q.Select(&fwdrs, query, args...)
query = o.db.Rebind(query)
err = o.db.SelectContext(ctx, &fwdrs, query, args...)

if err != nil {
return nil, errors.Wrap(err, "Failed to execute query")
Expand Down
20 changes: 10 additions & 10 deletions core/chains/evm/forwarders/orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@ import (
"errors"
"testing"

"github.com/smartcontractkit/chainlink-common/pkg/sqlutil"

"github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/jmoiron/sqlx"

"github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest"
"github.com/smartcontractkit/chainlink/v2/core/services/pg"

"github.com/jmoiron/sqlx"
)

type TestORM struct {
Expand All @@ -27,9 +27,8 @@ func setupORM(t *testing.T) *TestORM {
t.Helper()

var (
db = pgtest.NewSqlxDB(t)
lggr = logger.Test(t)
orm = NewORM(db, lggr, pgtest.NewQConfig(true))
db = pgtest.NewSqlxDB(t)
orm = NewORM(db)
)

return &TestORM{ORM: orm, db: db}
Expand All @@ -41,8 +40,9 @@ func Test_DeleteForwarder(t *testing.T) {
orm := setupORM(t)
addr := testutils.NewAddress()
chainID := testutils.FixtureChainID
ctx := testutils.Context(t)

fwd, err := orm.CreateForwarder(addr, *big.New(chainID))
fwd, err := orm.CreateForwarder(ctx, addr, *big.New(chainID))
require.NoError(t, err)
assert.Equal(t, addr, fwd.Address)

Expand All @@ -56,14 +56,14 @@ func Test_DeleteForwarder(t *testing.T) {
rets := []error{ErrCleaningUp, nil, nil, ErrCleaningUp}
expected := []error{ErrCleaningUp, nil, sql.ErrNoRows, sql.ErrNoRows}

testCleanupFn := func(q pg.Queryer, evmChainID int64, addr common.Address) error {
testCleanupFn := func(q sqlutil.Queryer, evmChainID int64, addr common.Address) error {
require.Less(t, cleanupCalled, len(rets))
cleanupCalled++
return rets[cleanupCalled-1]
}

for _, expect := range expected {
err = orm.DeleteForwarder(fwd.ID, testCleanupFn)
err = orm.DeleteForwarder(ctx, fwd.ID, testCleanupFn)
assert.ErrorIs(t, err, expect)
}
assert.Equal(t, 2, cleanupCalled)
Expand Down
2 changes: 1 addition & 1 deletion core/chains/evm/txmgr/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func NewTxm(
var fwdMgr FwdMgr

if txConfig.ForwardersEnabled() {
fwdMgr = forwarders.NewFwdMgr(db, client, logPoller, lggr, chainConfig, dbConfig)
fwdMgr = forwarders.NewFwdMgr(db, client, logPoller, lggr, chainConfig)
} else {
lggr.Info("EvmForwarderManager: Disabled")
}
Expand Down
4 changes: 2 additions & 2 deletions core/chains/evm/txmgr/txmgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,9 @@ func TestTxm_CreateTransaction(t *testing.T) {
evmConfig.MaxQueued = uint64(1)

// Create mock forwarder, mock authorizedsenders call.
form := forwarders.NewORM(db, logger.Test(t), cfg.Database())
form := forwarders.NewORM(db)
fwdrAddr := testutils.NewAddress()
fwdr, err := form.CreateForwarder(fwdrAddr, ubig.Big(cltest.FixtureChainID))
fwdr, err := form.CreateForwarder(testutils.Context(t), fwdrAddr, ubig.Big(cltest.FixtureChainID))
require.NoError(t, err)
require.Equal(t, fwdr.Address, fwdrAddr)

Expand Down
4 changes: 2 additions & 2 deletions core/cmd/ocr2vrf_configure_commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ func (s *Shell) authorizeForwarder(c *cli.Context, db *sqlx.DB, lggr logger.Logg
}

// Create forwarder for management in forwarder_manager.go.
orm := forwarders.NewORM(db, lggr, s.Config.Database())
_, err = orm.CreateForwarder(common.HexToAddress(forwarderAddress), *ubig.NewI(chainID))
orm := forwarders.NewORM(db)
_, err = orm.CreateForwarder(ctx, common.HexToAddress(forwarderAddress), *ubig.NewI(chainID))
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions core/internal/features/features_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -773,9 +773,9 @@ func setupForwarderEnabledNode(t *testing.T, owner *bind.TransactOpts, portV2 in
b.Commit()

// add forwarder address to be tracked in db
forwarderORM := forwarders.NewORM(app.GetSqlxDB(), logger.TestLogger(t), config.Database())
forwarderORM := forwarders.NewORM(app.GetSqlxDB())
chainID := ubig.Big(*b.Blockchain().Config().ChainID)
_, err = forwarderORM.CreateForwarder(forwarder, chainID)
_, err = forwarderORM.CreateForwarder(testutils.Context(t), forwarder, chainID)
require.NoError(t, err)

return app, p2pKey.PeerID().Raw(), transmitter, forwarder, key
Expand Down
4 changes: 2 additions & 2 deletions core/internal/features/ocr2/features_ocr2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ func setupNodeOCR2(
b.Commit()

// add forwarder address to be tracked in db
forwarderORM := forwarders.NewORM(app.GetSqlxDB(), logger.TestLogger(t), config.Database())
forwarderORM := forwarders.NewORM(app.GetSqlxDB())
chainID := ubig.Big(*b.Blockchain().Config().ChainID)
_, err2 = forwarderORM.CreateForwarder(faddr, chainID)
_, err2 = forwarderORM.CreateForwarder(testutils.Context(t), faddr, chainID)
require.NoError(t, err2)

effectiveTransmitter = faddr
Expand Down
Loading

0 comments on commit 3183b21

Please sign in to comment.