Skip to content

Commit

Permalink
refactor txdb and opts to use URLs
Browse files Browse the repository at this point in the history
  • Loading branch information
krehermann committed Sep 26, 2023
1 parent eda6497 commit cfb1c1b
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 89 deletions.
8 changes: 4 additions & 4 deletions core/chains/evm/headtracker/head_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ func firstHead(t *testing.T, db *sqlx.DB) (h evmtypes.Head) {
func TestHeadTracker_New(t *testing.T) {
t.Parallel()

db := pgtest.NewEVMScopedDB(t)
config := cltest.NewTestChainScopedConfig(t)
db := evmtest.NewScopedDB(t, config.Database())
logger := logger.TestLogger(t)
config := configtest.NewGeneralConfig(t, nil)
ethClient := evmtest.NewEthClientMockWithDefaultChain(t)
ethClient.On("HeadByNumber", mock.Anything, (*big.Int)(nil)).Return(cltest.Head(0), nil)

Expand All @@ -71,9 +71,9 @@ func TestHeadTracker_New(t *testing.T) {
func TestHeadTracker_Save_InsertsAndTrimsTable(t *testing.T) {
t.Parallel()

db := pgtest.NewEVMScopedDB(t)
logger := logger.TestLogger(t)
config := cltest.NewTestChainScopedConfig(t)
db := evmtest.NewScopedDB(t, config.Database())
logger := logger.TestLogger(t)

ethClient := evmtest.NewEthClientMockWithDefaultChain(t)
orm := headtracker.NewORM(db, logger, config.Database(), cltest.FixtureChainID)
Expand Down
4 changes: 2 additions & 2 deletions core/cmd/shell_local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestShell_RunNodeWithPasswords(t *testing.T) {
AppConfig: cfg,
EventBroadcaster: pg.NewNullEventBroadcaster(),
MailMon: &utils.MailboxMonitor{},
DB: pgtest.NewEVMScopedDB(t),
DB: evmtest.NewScopedDB(t, cfg.Database()),
},
}
testRelayers := genTestEVMRelayers(t, opts, keyStore)
Expand Down Expand Up @@ -195,7 +195,7 @@ func TestShell_RunNodeWithAPICredentialsFile(t *testing.T) {
EventBroadcaster: pg.NewNullEventBroadcaster(),

MailMon: &utils.MailboxMonitor{},
DB: pgtest.NewEVMScopedDB(t),
DB: evmtest.NewScopedDB(t, cfg.Database()),
},
}
testRelayers := genTestEVMRelayers(t, opts, keyStore)
Expand Down
3 changes: 1 addition & 2 deletions core/internal/cltest/cltest.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils/evmtest"
clhttptest "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/httptest"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils/keystest"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest"
"github.com/smartcontractkit/chainlink/v2/core/logger"
"github.com/smartcontractkit/chainlink/v2/core/logger/audit"
"github.com/smartcontractkit/chainlink/v2/core/services/chainlink"
Expand Down Expand Up @@ -396,7 +395,7 @@ func NewApplicationWithConfig(t testing.TB, cfg chainlink.GeneralConfig, flagsAn
AppConfig: cfg,
EventBroadcaster: eventBroadcaster,
MailMon: mailMon,
DB: pgtest.NewEVMScopedDB(t),
DB: evmtest.NewScopedDB(t, cfg.Database()),
},
CSAETHKeystore: keyStore,
}
Expand Down
10 changes: 10 additions & 0 deletions core/internal/testutils/evmtest/evmtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/smartcontractkit/sqlx"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/test-go/testify/assert"
"golang.org/x/exp/slices"
"gopkg.in/guregu/null.v4"

Expand All @@ -30,7 +31,9 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller"
"github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr"
evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types"
"github.com/smartcontractkit/chainlink/v2/core/config"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest"
"github.com/smartcontractkit/chainlink/v2/core/logger"
"github.com/smartcontractkit/chainlink/v2/core/services/keystore"
"github.com/smartcontractkit/chainlink/v2/core/services/pg"
Expand Down Expand Up @@ -368,3 +371,10 @@ func (r *RawSub[T]) TrySend(t T) {
case r.ch <- t:
}
}

func NewScopedDB(t testing.TB, cfg config.Database) *sqlx.DB {
// hack to scope to evm schema. the value "evm" will need to be dynamic to support multiple relayers
url, err := pg.SchemaScopedConnection(cfg.URL(), "evm")
assert.NoError(t, err, "failed to create evm scoped db")
return pgtest.NewSqlxDB(t, pgtest.WithURL(*url))
}
33 changes: 16 additions & 17 deletions core/internal/testutils/pgtest/pgtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package pgtest

import (
"database/sql"
"encoding/json"
"net/url"
"testing"

"github.com/google/uuid"
Expand Down Expand Up @@ -30,24 +30,23 @@ func NewSqlDB(t *testing.T) *sql.DB {
return db
}

func NewEVMScopedDB(t testing.TB) *sqlx.DB {
// hack to scope to evm schema. the value "evm" will need to be dynamic to support multiple relayers
url := pg.SchemaScopedConnection(defaultDBURL, "evm")
return NewSqlxDB(t, WithURL(url))
func uniqueConnection(t testing.TB) *url.URL {
url := testutils.MustParseURL(t, defaultDBURL.String())
// inject uuid by default because the transaction wrapped driver requires it
q := url.Query()
q.Add("uuid", uuid.New().String())
url.RawQuery = q.Encode()
return url
}

func NewSqlxDB(t testing.TB, opts ...ConnectionOpt) *sqlx.DB {
testutils.SkipShortDB(t)
conn := &pg.ConnectionScope{
UUID: uuid.New().String(),
URL: defaultDBURL,
}

url := uniqueConnection(t)
for _, opt := range opts {
opt(conn)
opt(url)
}
enc, err := json.Marshal(conn)
require.NoError(t, err)
db, err := sqlx.Open(string(dialects.TransactionWrappedPostgres), string(enc))
db, err := sqlx.Open(string(dialects.TransactionWrappedPostgres), url.String())
require.NoError(t, err)
t.Cleanup(func() { assert.NoError(t, db.Close()) })

Expand All @@ -56,11 +55,11 @@ func NewSqlxDB(t testing.TB, opts ...ConnectionOpt) *sqlx.DB {
return db
}

type ConnectionOpt func(conn *pg.ConnectionScope)
type ConnectionOpt func(*url.URL)

func WithURL(url string) ConnectionOpt {
return func(conn *pg.ConnectionScope) {
conn.URL = url
func WithURL(override url.URL) ConnectionOpt {
return func(u *url.URL) {
u = &override
}
}

Expand Down
75 changes: 44 additions & 31 deletions core/internal/testutils/pgtest/txdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"flag"
"fmt"
"io"
Expand Down Expand Up @@ -37,7 +36,7 @@ import (
// store to use the raw DialectPostgres dialect and setup a one-use database.
// See heavyweight.FullTestDB() as a convenience function to help you do this,
// but please use sparingly because as it's name implies, it is expensive.
var defaultDBURL string // global used to create a [NewSqlxDB]
var defaultDBURL *url.URL // global used to create a [NewSqlxDB]
func init() {
testing.Init()
if !flag.Parsed() {
Expand All @@ -47,26 +46,25 @@ func init() {
// -short tests don't need a DB
return
}
defaultDBURL = string(env.DatabaseURL.Get())
if defaultDBURL == "" {
envDBURL := string(env.DatabaseURL.Get())
if envDBURL == "" {
panic("you must provide a CL_DATABASE_URL environment variable")
}

parsed, err := url.Parse(defaultDBURL)
var err error
defaultDBURL, err = url.Parse(envDBURL)
if err != nil {
panic(err)
}
if parsed.Path == "" {
msg := fmt.Sprintf("invalid %[1]s: `%[2]s`. You must set %[1]s env var to point to your test database. Note that the test database MUST end in `_test` to differentiate from a possible production DB. HINT: Try %[1]s=postgresql://postgres@localhost:5432/chainlink_test?sslmode=disable", env.DatabaseURL, parsed.String())
if defaultDBURL.Path == "" {
msg := fmt.Sprintf("invalid %[1]s: `%[2]s`. You must set %[1]s env var to point to your test database. Note that the test database MUST end in `_test` to differentiate from a possible production DB. HINT: Try %[1]s=postgresql://postgres@localhost:5432/chainlink_test?sslmode=disable", env.DatabaseURL, defaultDBURL.String())
panic(msg)
}
if !strings.HasSuffix(parsed.Path, "_test") {
msg := fmt.Sprintf("cannot run tests against database named `%s`. Note that the test database MUST end in `_test` to differentiate from a possible production DB. HINT: Try %s=postgresql://postgres@localhost:5432/chainlink_test?sslmode=disable", parsed.Path[1:], env.DatabaseURL)
if !strings.HasSuffix(defaultDBURL.Path, "_test") {
msg := fmt.Sprintf("cannot run tests against database named `%s`. Note that the test database MUST end in `_test` to differentiate from a possible production DB. HINT: Try %s=postgresql://postgres@localhost:5432/chainlink_test?sslmode=disable", defaultDBURL.Path[1:], env.DatabaseURL)
panic(msg)
}
name := string(dialects.TransactionWrappedPostgres)
sql.Register(name, &txDriver{
dbURL: defaultDBURL,
conns: make(map[string]map[string]*conn),
})
sqlx.BindDriver(name, sqlx.DOLLAR)
Expand All @@ -80,45 +78,60 @@ type txDriver struct {
sync.Mutex
db map[string]*sql.DB // url -> db
conns map[string]map[string]*conn // url -> (uuid -> db) so we can close per url

dbURL string
}

// jsonConnectionScope must be a json-encoded [ConnectionScope]. The Open interface requires
// a string
func (d *txDriver) Open(jsonConnectionScope string) (driver.Conn, error) {
func cleanseURL(u *url.URL) {
q := u.Query()
q.Del("uuid")
u.RawQuery = q.Encode()

}
func (d *txDriver) Open(connection string) (driver.Conn, error) {
d.Lock()
defer d.Unlock()
var scope pg.ConnectionScope
err := json.Unmarshal([]byte(jsonConnectionScope), &scope)

dbURL, err := url.Parse(connection)
if err != nil {
return nil, fmt.Errorf("pgtest tx driver failed to parse connection string %s: must be json encoded scope: %w", jsonConnectionScope, err)
return nil, fmt.Errorf("pgtest tx driver failed to parse connection string %s: %w", connection, err)
}

// tx db requires client connection to self-identify a uuid for connection routing
// extract in and track the resulting url for scoping connections
queryVals := dbURL.Query()
uuid := queryVals.Get("uuid")
if uuid == "" {
return nil, fmt.Errorf("txdb can open connection: missing `uuid` query parameter in connection string %s", connection)
}
queryVals.Del("uuid")
dbURL.RawQuery = queryVals.Encode() // encode is sorted

endpt := dbURL.String()

// initialize dbs if the first call
if d.db == nil {
d.db = make(map[string]*sql.DB)
}
db, exists := d.db[scope.URL]
db, exists := d.db[endpt]
if !exists {
db, err = sql.Open("pgx", scope.URL)
db, err = sql.Open("pgx", endpt)
if err != nil {
return nil, err
}
d.db[scope.URL] = db
d.conns[scope.URL] = make(map[string]*conn)
d.db[endpt] = db
d.conns[endpt] = make(map[string]*conn)
}

c, exists := d.conns[scope.URL][scope.UUID]
c, exists := d.conns[endpt][uuid]
if !exists || !c.tryOpen() {
tx, err := db.Begin()
if err != nil {
return nil, err
}
c = &conn{tx: tx, opened: 1, scope: scope}
c = &conn{tx: tx, opened: 1, scope: pg.ConnectionScope{Endpoint: endpt, UUID: uuid}}
c.removeSelf = func() error {
return d.deleteConn(c)
}
d.conns[scope.URL][scope.UUID] = c
d.conns[endpt][uuid] = c
}
return c, nil
}
Expand All @@ -130,15 +143,15 @@ func (d *txDriver) deleteConn(c *conn) error {
d.Lock()
defer d.Unlock()

if d.conns[c.scope.URL][c.scope.UUID] != c {
if d.conns[c.scope.Endpoint][c.scope.UUID] != c {
return nil // already been replaced
}
delete(d.conns[c.scope.URL], c.scope.UUID)
if len(d.conns[c.scope.URL]) == 0 && d.db[c.scope.URL] != nil {
if err := d.db[c.scope.URL].Close(); err != nil {
delete(d.conns[c.scope.Endpoint], c.scope.UUID)
if len(d.conns[c.scope.Endpoint]) == 0 && d.db[c.scope.Endpoint] != nil {
if err := d.db[c.scope.Endpoint].Close(); err != nil {
return err
}
delete(d.db, c.scope.URL)
delete(d.db, c.scope.Endpoint)
}
return nil
}
Expand Down
5 changes: 3 additions & 2 deletions core/services/chainlink/relayer_chain_interoperators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/chains/starknet"
"github.com/smartcontractkit/chainlink/v2/core/internal/cltest"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils/evmtest"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest"
"github.com/smartcontractkit/chainlink/v2/core/logger"
"github.com/smartcontractkit/chainlink/v2/core/services/chainlink"
Expand Down Expand Up @@ -209,7 +210,7 @@ func TestCoreRelayerChainInteroperators(t *testing.T) {
AppConfig: cfg,
EventBroadcaster: pg.NewNullEventBroadcaster(),
MailMon: &utils.MailboxMonitor{},
DB: pgtest.NewEVMScopedDB(t),
DB: evmtest.NewScopedDB(t, cfg.Database()),
},
CSAETHKeystore: keyStore,
}),
Expand Down Expand Up @@ -284,7 +285,7 @@ func TestCoreRelayerChainInteroperators(t *testing.T) {
AppConfig: cfg,
EventBroadcaster: pg.NewNullEventBroadcaster(),
MailMon: &utils.MailboxMonitor{},
DB: pgtest.NewEVMScopedDB(t),
DB: evmtest.NewScopedDB(t, cfg.Database()),
},
CSAETHKeystore: keyStore,
}),
Expand Down
20 changes: 8 additions & 12 deletions core/services/pg/connection.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package pg

import (
"encoding/json"
"fmt"
"net/url"
"time"

"github.com/google/uuid"
Expand All @@ -21,8 +21,8 @@ type ConnectionConfig interface {
}

type ConnectionScope struct {
URL string
UUID string
Endpoint string
UUID string
}

func NewConnection(uri string, dialect dialects.DialectName, config ConnectionConfig) (db *sqlx.DB, err error) {
Expand All @@ -31,18 +31,14 @@ func NewConnection(uri string, dialect dialects.DialectName, config ConnectionCo
// should be encapsulated in it's own transaction, and thus needs its own
// unique id.
//
// We can happily throw away the original uri here because if we are using
// txdb it should have already been set at the point where we called
// txdb.Register
s := ConnectionScope{
URL: uri,
UUID: uuid.New().String(),
}
b, err := json.Marshal(&s)
u, err := url.Parse(uri)
if err != nil {
return nil, err
}
uri = string(b)
q := u.Query()
q.Add("uuid", uuid.New().String())
u.RawQuery = q.Encode()
uri = u.String()
}

// Initialize sql/sqlx
Expand Down
Loading

0 comments on commit cfb1c1b

Please sign in to comment.