diff --git a/go.mod b/go.mod index 7bd59cbd5..a6d0e8c6b 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/smartcontractkit/chainlink-common go 1.23.3 require ( + github.com/XSAM/otelsql v0.29.0 github.com/andybalholm/brotli v1.1.1 github.com/atombender/go-jsonschema v0.16.1-0.20240916205339-a74cd4e2851c github.com/bytecodealliance/wasmtime-go/v23 v23.0.0 diff --git a/go.sum b/go.sum index 623f00962..39420f468 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migc github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/Microsoft/hcsshim v0.9.4 h1:mnUj0ivWy6UzbB1uLFqKR6F+ZyiDc7j4iGgHTpO+5+I= github.com/Microsoft/hcsshim v0.9.4/go.mod h1:7pLA8lDk46WKDWlVsENo92gC0XFa8rbKfyFRBqxEbCc= +github.com/XSAM/otelsql v0.29.0 h1:pEw9YXXs8ZrGRYfDc0cmArIz9lci5b42gmP5+tA1Huc= +github.com/XSAM/otelsql v0.29.0/go.mod h1:d3/0xGIGC5RVEE+Ld7KotwaLy6zDeaF3fLJHOPpdN2w= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/apache/arrow-go/v18 v18.0.0 h1:1dBDaSbH3LtulTyOVYaBCHO3yVRwjV+TZaqn3g6V7ZM= diff --git a/pkg/config/build/build.go b/pkg/config/build/build.go new file mode 100644 index 000000000..c9cfdd205 --- /dev/null +++ b/pkg/config/build/build.go @@ -0,0 +1,35 @@ +package build + +import ( + "cmp" + "os" + "runtime/debug" +) + +// Unset is a sentinel value. +const Unset = "unset" + +// Version and Checksum are set at compile time via build arguments. +var ( + // Program is updated to the full main program path if [debug.BuildInfo] is available. + Program = os.Args[0] + // Version is the semantic version of the build or Unset. + Version = Unset + // Checksum is the commit hash of the build or Unset. 
+ Checksum = Unset + ChecksumPrefix = Unset +) + +func init() { + buildInfo, ok := debug.ReadBuildInfo() + if ok { + Program = cmp.Or(buildInfo.Main.Path, Program) + if Version == Unset && buildInfo.Main.Version != "" { + Version = buildInfo.Main.Version + } + if Checksum == Unset && buildInfo.Main.Sum != "" { + Checksum = buildInfo.Main.Sum + } + } + ChecksumPrefix = Checksum[:min(7, len(Checksum))] +} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 74af6ff39..bfd5133d5 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -2,6 +2,7 @@ package logger import ( + "fmt" "io" "reflect" "testing" @@ -9,6 +10,8 @@ import ( "go.uber.org/zap/zapcore" "go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest/observer" + + "github.com/smartcontractkit/chainlink-common/pkg/config/build" ) // Logger is a minimal subset of smartcontractkit/chainlink/core/logger.Logger implemented by go.uber.org/zap.SugaredLogger @@ -52,9 +55,16 @@ func New() (Logger, error) { return defaultConfig.New() } func (c *Config) New() (Logger, error) { return NewWith(func(cfg *zap.Config) { cfg.Level.SetLevel(c.Level) + cfg.InitialFields = map[string]interface{}{ + "version": buildVersion(), + } }) } +func buildVersion() string { + return fmt.Sprintf("%s@%s", build.Version, build.ChecksumPrefix) +} + // NewWith returns a new Logger from a modified [zap.Config]. func NewWith(cfgFn func(*zap.Config)) (Logger, error) { cfg := zap.NewProductionConfig() @@ -83,7 +93,7 @@ func Test(tb testing.TB) Logger { zapcore.DebugLevel, ), ) - return &logger{lggr.Sugar()} + return &logger{lggr.With(zap.String("version", buildVersion())).Sugar()} } // TestSugared returns a new test SugaredLogger. 
diff --git a/pkg/loop/config.go b/pkg/loop/config.go index e63f72f2f..c789e699c 100644 --- a/pkg/loop/config.go +++ b/pkg/loop/config.go @@ -12,8 +12,15 @@ import ( ) const ( - envDatabaseURL = "CL_DATABASE_URL" - envPromPort = "CL_PROMETHEUS_PORT" + envDatabaseURL = "CL_DATABASE_URL" + envDatabaseIdleInTxSessionTimeout = "CL_DATABASE_IDLE_IN_TX_SESSION_TIMEOUT" + envDatabaseLockTimeout = "CL_DATABASE_LOCK_TIMEOUT" + envDatabaseQueryTimeout = "CL_DATABASE_QUERY_TIMEOUT" + envDatabaseLogSQL = "CL_DATABASE_LOG_SQL" + envDatabaseMaxOpenConns = "CL_DATABASE_MAX_OPEN_CONNS" + envDatabaseMaxIdleConns = "CL_DATABASE_MAX_IDLE_CONNS" + + envPromPort = "CL_PROMETHEUS_PORT" envTracingEnabled = "CL_TRACING_ENABLED" envTracingCollectorTarget = "CL_TRACING_COLLECTOR_TARGET" @@ -36,7 +43,13 @@ const ( // EnvConfig is the configuration between the application and the LOOP executable. The values // are fully resolved and static and passed via the environment. type EnvConfig struct { - DatabaseURL *url.URL + DatabaseURL *url.URL + DatabaseIdleInTxSessionTimeout time.Duration + DatabaseLockTimeout time.Duration + DatabaseQueryTimeout time.Duration + DatabaseLogSQL bool + DatabaseMaxOpenConns int + DatabaseMaxIdleConns int PrometheusPort int @@ -66,7 +79,14 @@ func (e *EnvConfig) AsCmdEnv() (env []string) { if e.DatabaseURL != nil { // optional add(envDatabaseURL, e.DatabaseURL.String()) + add(envDatabaseIdleInTxSessionTimeout, e.DatabaseIdleInTxSessionTimeout.String()) + add(envDatabaseLockTimeout, e.DatabaseLockTimeout.String()) + add(envDatabaseQueryTimeout, e.DatabaseQueryTimeout.String()) + add(envDatabaseLogSQL, strconv.FormatBool(e.DatabaseLogSQL)) + add(envDatabaseMaxOpenConns, strconv.Itoa(e.DatabaseMaxOpenConns)) + add(envDatabaseMaxIdleConns, strconv.Itoa(e.DatabaseMaxIdleConns)) } + add(envPromPort, strconv.Itoa(e.PrometheusPort)) add(envTracingEnabled, strconv.FormatBool(e.TracingEnabled)) @@ -99,13 +119,44 @@ func (e *EnvConfig) AsCmdEnv() (env []string) { // parse 
deserializes environment variables func (e *EnvConfig) parse() error { - promPortStr := os.Getenv(envPromPort) var err error - e.DatabaseURL, err = getDatabaseURL() + e.DatabaseURL, err = getEnv(envDatabaseURL, func(s string) (*url.URL, error) { + if s == "" { // DatabaseURL is optional + return nil, nil + } + return url.Parse(s) + }) if err != nil { - return fmt.Errorf("failed to parse %s: %w", envDatabaseURL, err) + return err + } + if e.DatabaseURL != nil { + e.DatabaseIdleInTxSessionTimeout, err = getEnv(envDatabaseIdleInTxSessionTimeout, time.ParseDuration) + if err != nil { + return err + } + e.DatabaseLockTimeout, err = getEnv(envDatabaseLockTimeout, time.ParseDuration) + if err != nil { + return err + } + e.DatabaseQueryTimeout, err = getEnv(envDatabaseQueryTimeout, time.ParseDuration) + if err != nil { + return err + } + e.DatabaseLogSQL, err = getEnv(envDatabaseLogSQL, strconv.ParseBool) + if err != nil { + return err + } + e.DatabaseMaxOpenConns, err = getEnv(envDatabaseMaxOpenConns, strconv.Atoi) + if err != nil { + return err + } + e.DatabaseMaxIdleConns, err = getEnv(envDatabaseMaxIdleConns, strconv.Atoi) + if err != nil { + return err + } } + promPortStr := os.Getenv(envPromPort) e.PrometheusPort, err = strconv.Atoi(promPortStr) if err != nil { return fmt.Errorf("failed to parse %s = %q: %w", envPromPort, promPortStr, err) @@ -211,16 +262,11 @@ func getFloat64OrZero(envKey string) float64 { return f } -// getDatabaseURL parses the CL_DATABASE_URL environment variable. 
-func getDatabaseURL() (*url.URL, error) { - databaseURL := os.Getenv(envDatabaseURL) - if databaseURL == "" { - // DatabaseURL is optional - return nil, nil - } - u, err := url.Parse(databaseURL) +func getEnv[T any](key string, parse func(string) (T, error)) (t T, err error) { + v := os.Getenv(key) + t, err = parse(v) if err != nil { - return nil, fmt.Errorf("invalid %s: %w", envDatabaseURL, err) + err = fmt.Errorf("failed to parse %s=%s: %w", key, v, err) } - return u, nil + return } diff --git a/pkg/loop/config_test.go b/pkg/loop/config_test.go index 78d177aa4..4ef1dbfa1 100644 --- a/pkg/loop/config_test.go +++ b/pkg/loop/config_test.go @@ -19,15 +19,24 @@ import ( func TestEnvConfig_parse(t *testing.T) { cases := []struct { - name string - envVars map[string]string - expectError bool + name string + envVars map[string]string + expectError bool + expectedDatabaseURL string - expectedPrometheusPort int - expectedTracingEnabled bool - expectedTracingCollectorTarget string - expectedTracingSamplingRatio float64 - expectedTracingTLSCertPath string + expectedDatabaseIdleInTxSessionTimeout time.Duration + expectedDatabaseLockTimeout time.Duration + expectedDatabaseQueryTimeout time.Duration + expectedDatabaseLogSQL bool + expectedDatabaseMaxOpenConns int + expectedDatabaseMaxIdleConns int + + expectedPrometheusPort int + expectedTracingEnabled bool + expectedTracingCollectorTarget string + expectedTracingSamplingRatio float64 + expectedTracingTLSCertPath string + expectedTelemetryEnabled bool expectedTelemetryEndpoint string expectedTelemetryInsecureConn bool @@ -43,12 +52,20 @@ func TestEnvConfig_parse(t *testing.T) { name: "All variables set correctly", envVars: map[string]string{ envDatabaseURL: "postgres://user:password@localhost:5432/db", + envDatabaseIdleInTxSessionTimeout: "42s", + envDatabaseLockTimeout: "8m", + envDatabaseQueryTimeout: "7s", + envDatabaseLogSQL: "true", + envDatabaseMaxOpenConns: "9999", + envDatabaseMaxIdleConns: "8080", + envPromPort: 
"8080", envTracingEnabled: "true", envTracingCollectorTarget: "some:target", envTracingSamplingRatio: "1.0", envTracingTLSCertPath: "internal/test/fixtures/client.pem", envTracingAttribute + "XYZ": "value", + envTelemetryEnabled: "true", envTelemetryEndpoint: "example.com/beholder", envTelemetryInsecureConn: "true", @@ -61,13 +78,22 @@ func TestEnvConfig_parse(t *testing.T) { envTelemetryEmitterBatchProcessor: "true", envTelemetryEmitterExportTimeout: "1s", }, - expectError: false, + expectError: false, + expectedDatabaseURL: "postgres://user:password@localhost:5432/db", - expectedPrometheusPort: 8080, - expectedTracingEnabled: true, - expectedTracingCollectorTarget: "some:target", - expectedTracingSamplingRatio: 1.0, - expectedTracingTLSCertPath: "internal/test/fixtures/client.pem", + expectedDatabaseIdleInTxSessionTimeout: 42 * time.Second, + expectedDatabaseLockTimeout: 8 * time.Minute, + expectedDatabaseQueryTimeout: 7 * time.Second, + expectedDatabaseLogSQL: true, + expectedDatabaseMaxOpenConns: 9999, + expectedDatabaseMaxIdleConns: 8080, + + expectedPrometheusPort: 8080, + expectedTracingEnabled: true, + expectedTracingCollectorTarget: "some:target", + expectedTracingSamplingRatio: 1.0, + expectedTracingTLSCertPath: "internal/test/fixtures/client.pem", + expectedTelemetryEnabled: true, expectedTelemetryEndpoint: "example.com/beholder", expectedTelemetryInsecureConn: true, @@ -123,6 +149,25 @@ func TestEnvConfig_parse(t *testing.T) { if config.DatabaseURL.String() != tc.expectedDatabaseURL { t.Errorf("Expected Database URL %s, got %s", tc.expectedDatabaseURL, config.DatabaseURL) } + if config.DatabaseIdleInTxSessionTimeout != tc.expectedDatabaseIdleInTxSessionTimeout { + t.Errorf("Expected Database idle in tx session timeout %s, got %s", tc.expectedDatabaseIdleInTxSessionTimeout, config.DatabaseIdleInTxSessionTimeout) + } + if config.DatabaseLockTimeout != tc.expectedDatabaseLockTimeout { + t.Errorf("Expected Database lock timeout %s, got %s", 
tc.expectedDatabaseLockTimeout, config.DatabaseLockTimeout) + } + if config.DatabaseQueryTimeout != tc.expectedDatabaseQueryTimeout { + t.Errorf("Expected Database query timeout %s, got %s", tc.expectedDatabaseQueryTimeout, config.DatabaseQueryTimeout) + } + if config.DatabaseLogSQL != tc.expectedDatabaseLogSQL { + t.Errorf("Expected Database log sql %t, got %t", tc.expectedDatabaseLogSQL, config.DatabaseLogSQL) + } + if config.DatabaseMaxOpenConns != tc.expectedDatabaseMaxOpenConns { + t.Errorf("Expected Database max open conns %d, got %d", tc.expectedDatabaseMaxOpenConns, config.DatabaseMaxOpenConns) + } + if config.DatabaseMaxIdleConns != tc.expectedDatabaseMaxIdleConns { + t.Errorf("Expected Database max idle conns %d, got %d", tc.expectedDatabaseMaxIdleConns, config.DatabaseMaxIdleConns) + } + if config.PrometheusPort != tc.expectedPrometheusPort { t.Errorf("Expected Prometheus port %d, got %d", tc.expectedPrometheusPort, config.PrometheusPort) } diff --git a/pkg/loop/internal/example-relay/main.go b/pkg/loop/internal/example-relay/main.go new file mode 100644 index 000000000..14aa2a8ad --- /dev/null +++ b/pkg/loop/internal/example-relay/main.go @@ -0,0 +1,124 @@ +// This file contains an example implementation of a relayer plugin. 
+package main + +import ( + "context" + "errors" + "math/big" + + "github.com/hashicorp/go-plugin" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/loop" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" + "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-common/pkg/types/core" +) + +const ( + loggerName = "PluginExample" +) + +func main() { + s := loop.MustNewStartedServer(loggerName) + defer s.Stop() + + p := &pluginRelayer{lggr: s.Logger, ds: s.DataSource} + defer s.Logger.ErrorIfFn(p.Close, "Failed to close") + + s.MustRegister(p) + + stopCh := make(chan struct{}) + defer close(stopCh) + + plugin.Serve(&plugin.ServeConfig{ + HandshakeConfig: loop.PluginRelayerHandshakeConfig(), + Plugins: map[string]plugin.Plugin{ + loop.PluginRelayerName: &loop.GRPCPluginRelayer{ + PluginServer: p, + BrokerConfig: loop.BrokerConfig{ + StopCh: stopCh, + Logger: s.Logger, + GRPCOpts: s.GRPCOpts, + }, + }, + }, + GRPCServer: s.GRPCOpts.NewServer, + }) +} + +type pluginRelayer struct { + lggr logger.Logger + ds sqlutil.DataSource +} + +func (p *pluginRelayer) Ready() error { return nil } + +func (p *pluginRelayer) HealthReport() map[string]error { return map[string]error{p.Name(): nil} } + +func (p *pluginRelayer) Name() string { return p.lggr.Name() } + +func (p *pluginRelayer) NewRelayer(ctx context.Context, config string, keystore core.Keystore, cr core.CapabilitiesRegistry) (loop.Relayer, error) { + return &relayer{lggr: logger.Named(p.lggr, "Relayer"), ds: p.ds}, nil +} + +func (p *pluginRelayer) Close() error { return nil } + +type relayer struct { + lggr logger.Logger + ds sqlutil.DataSource +} + +func (r *relayer) Name() string { return r.lggr.Name() } + +func (r *relayer) Start(ctx context.Context) error { + var names []string + err := r.ds.SelectContext(ctx, &names, "SELECT table_name FROM information_schema.tables WHERE table_schema='public'") // destination must be a pointer so the rows can be scanned into names + if 
err != nil { + return err + } + r.lggr.Info("Queried table names", "names", names) + return nil +} + +func (r *relayer) Close() error { return nil } + +func (r *relayer) Ready() error { return nil } + +func (r *relayer) HealthReport() map[string]error { return map[string]error{r.Name(): nil} } + +func (r *relayer) LatestHead(ctx context.Context) (types.Head, error) { + return types.Head{}, errors.New("unimplemented") +} + +func (r *relayer) GetChainStatus(ctx context.Context) (types.ChainStatus, error) { + return types.ChainStatus{}, errors.New("unimplemented") +} + +func (r *relayer) ListNodeStatuses(ctx context.Context, pageSize int32, pageToken string) (stats []types.NodeStatus, nextPageToken string, total int, err error) { + return nil, "", -1, errors.New("unimplemented") +} + +func (r *relayer) Transact(ctx context.Context, from, to string, amount *big.Int, balanceCheck bool) error { + return errors.New("unimplemented") +} + +func (r *relayer) NewContractWriter(ctx context.Context, chainWriterConfig []byte) (types.ContractWriter, error) { + return nil, errors.New("unimplemented") +} + +func (r *relayer) NewContractReader(ctx context.Context, contractReaderConfig []byte) (types.ContractReader, error) { + return nil, errors.New("unimplemented") +} + +func (r *relayer) NewConfigProvider(ctx context.Context, args types.RelayArgs) (types.ConfigProvider, error) { + return nil, errors.New("unimplemented") +} + +func (r *relayer) NewPluginProvider(ctx context.Context, args types.RelayArgs, args2 types.PluginArgs) (types.PluginProvider, error) { + return nil, errors.New("unimplemented") +} + +func (r *relayer) NewLLOProvider(ctx context.Context, args types.RelayArgs, args2 types.PluginArgs) (types.LLOProvider, error) { + return nil, errors.New("unimplemented") +} diff --git a/pkg/loop/server.go b/pkg/loop/server.go index c866be20b..c859fc7e8 100644 --- a/pkg/loop/server.go +++ b/pkg/loop/server.go @@ -1,14 +1,21 @@ package loop import ( + "context" "fmt" "os" + 
"os/signal" + "time" + "github.com/jmoiron/sqlx" "go.opentelemetry.io/otel/attribute" "github.com/smartcontractkit/chainlink-common/pkg/beholder" + "github.com/smartcontractkit/chainlink-common/pkg/config/build" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil/pg" ) // NewStartedServer returns a started Server. @@ -44,10 +51,13 @@ func MustNewStartedServer(loggerName string) *Server { // Server holds common plugin server fields. type Server struct { - GRPCOpts GRPCOpts - Logger logger.SugaredLogger - promServer *PromServer - checker *services.HealthChecker + GRPCOpts GRPCOpts + Logger logger.SugaredLogger + db *sqlx.DB // optional + dbStatsReporter *pg.StatsReporter // optional + DataSource sqlutil.DataSource // optional + promServer *PromServer + checker *services.HealthChecker } func newServer(loggerName string) (*Server, error) { @@ -66,6 +76,11 @@ func newServer(loggerName string) (*Server, error) { } func (s *Server) start() error { + ctx, stopSig := signal.NotifyContext(context.Background(), os.Interrupt) + defer stopSig() + stopAfter := context.AfterFunc(ctx, stopSig) + defer stopAfter() + var envCfg EnvConfig if err := envCfg.parse(); err != nil { return fmt.Errorf("error getting environment configuration: %w", err) @@ -132,6 +147,27 @@ func (s *Server) start() error { return fmt.Errorf("error starting health checker: %w", err) } + if envCfg.DatabaseURL != nil { + pg.SetApplicationName(envCfg.DatabaseURL, build.Program) + dbURL := envCfg.DatabaseURL.String() + var err error + s.db, err = pg.DBConfig{ + IdleInTxSessionTimeout: envCfg.DatabaseIdleInTxSessionTimeout, + LockTimeout: envCfg.DatabaseLockTimeout, + MaxOpenConns: envCfg.DatabaseMaxOpenConns, + MaxIdleConns: envCfg.DatabaseMaxIdleConns, + }.New(ctx, dbURL, pg.DriverPostgres) + if err != nil { + return 
fmt.Errorf("error connecting to DataBase at %s: %w", envCfg.DatabaseURL.Redacted(), err) // redact: the URL carries the password + } + s.DataSource = sqlutil.WrapDataSource(s.db, s.Logger, + sqlutil.TimeoutHook(func() time.Duration { return envCfg.DatabaseQueryTimeout }), + sqlutil.MonitorHook(func() bool { return envCfg.DatabaseLogSQL })) + + s.dbStatsReporter = pg.NewStatsReporter(s.db.Stats, s.Logger) + s.dbStatsReporter.Start() + } + return nil } @@ -146,6 +182,12 @@ func (s *Server) Register(c services.HealthReporter) error { return s.checker.Re // Stop closes resources and flushes logs. func (s *Server) Stop() { + if s.dbStatsReporter != nil { + s.dbStatsReporter.Stop() + } + if s.db != nil { + s.Logger.ErrorIfFn(s.db.Close, "Failed to close database connection") + } s.Logger.ErrorIfFn(s.checker.Close, "Failed to close health checker") s.Logger.ErrorIfFn(s.promServer.Close, "Failed to close prometheus server") if err := s.Logger.Sync(); err != nil { diff --git a/pkg/loop/telem.go b/pkg/loop/telem.go index c66949b23..f3dae75b4 100644 --- a/pkg/loop/telem.go +++ b/pkg/loop/telem.go @@ -4,7 +4,6 @@ import ( "context" "net" "os" - "runtime/debug" grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" "github.com/prometheus/client_golang/prometheus" @@ -21,6 +20,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" + "github.com/smartcontractkit/chainlink-common/pkg/config/build" loopnet "github.com/smartcontractkit/chainlink-common/pkg/loop/internal/net" ) @@ -100,21 +100,10 @@ func SetupTracing(config TracingConfig) error { } func (config TracingConfig) Attributes() []attribute.KeyValue { - var version string - var service string - buildInfo, ok := debug.ReadBuildInfo() - if !ok { - version = "unknown" - service = "cl-node" - } else { - version = buildInfo.Main.Version - service = buildInfo.Main.Path - } - attributes := []attribute.KeyValue{ - semconv.ServiceNameKey.String(service), + semconv.ServiceNameKey.String(build.Program), 
semconv.ProcessPIDKey.Int(os.Getpid()), - semconv.ServiceVersionKey.String(version), + semconv.ServiceVersionKey.String(build.Version), } for k, v := range config.NodeAttributes { diff --git a/pkg/sqlutil/pg/connection.go b/pkg/sqlutil/pg/connection.go new file mode 100644 index 000000000..054902f49 --- /dev/null +++ b/pkg/sqlutil/pg/connection.go @@ -0,0 +1,121 @@ +package pg + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/XSAM/otelsql" + "github.com/google/uuid" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/stdlib" + "github.com/jmoiron/sqlx" + "github.com/scylladb/go-reflectx" + "go.opentelemetry.io/otel" + semconv "go.opentelemetry.io/otel/semconv/v1.4.0" + + // need to make sure pgx driver is registered before opening connection + _ "github.com/jackc/pgx/v4/stdlib" +) + +// NOTE: This is the default level in Postgres anyway, we just make it explicit here. +const defaultIsolation = sql.LevelReadCommitted + +// Driver is a compiler enforced type used that maps to database driver names +type Driver string + +const ( + // DriverPostgres represents the postgres dialect. + DriverPostgres Driver = "pgx" + // DriverTxWrappedPostgres is useful for tests. + // When the connection is opened, it starts a transaction and all + // operations performed on the DB will be within that transaction. 
+ DriverTxWrappedPostgres Driver = "txdb" +) + +var ( + otelOpts = []otelsql.Option{otelsql.WithAttributes(semconv.DBSystemPostgreSQL), + otelsql.WithTracerProvider(otel.GetTracerProvider()), + otelsql.WithSQLCommenter(true), + otelsql.WithSpanOptions(otelsql.SpanOptions{ + OmitConnResetSession: true, + OmitConnPrepare: true, + OmitRows: true, + OmitConnectorConnect: true, + OmitConnQuery: false, + })} +) + +type DBConfig struct { + IdleInTxSessionTimeout time.Duration // idle_in_transaction_session_timeout + LockTimeout time.Duration // lock_timeout + MaxOpenConns, MaxIdleConns int +} + +func (config DBConfig) New(ctx context.Context, uri string, driver Driver) (*sqlx.DB, error) { + if driver == DriverTxWrappedPostgres { + // txdb uses the uri as a unique identifier for each transaction. Each ORM should + // be encapsulated in its own transaction, and thus needs its own unique id. + // + // We can happily throw away the original uri here because if we are using txdb + // it should have already been set at the point where we called [txdb.Register]. + return config.open(ctx, uuid.NewString(), string(driver)) + } + return config.openDB(ctx, uri, string(driver)) +} + +func (config DBConfig) open(ctx context.Context, uri string, driverName string) (*sqlx.DB, error) { + sqlDB, err := otelsql.Open(driverName, uri, otelOpts...) 
+ if err != nil { + return nil, err + } + if _, err := sqlDB.ExecContext(ctx, config.initSQL()); err != nil { + return nil, err + } + + return config.newDB(ctx, sqlDB, driverName) +} + +func (config DBConfig) openDB(ctx context.Context, uri string, driverName string) (*sqlx.DB, error) { + connConfig, err := pgx.ParseConfig(uri) + if err != nil { + return nil, fmt.Errorf("database: failed to parse config: %w", err) + } + + initSQL := config.initSQL() + connector := stdlib.GetConnector(*connConfig, stdlib.OptionAfterConnect(func(ctx context.Context, c *pgx.Conn) (err error) { + _, err = c.Exec(ctx, initSQL) + return + })) + + sqlDB := otelsql.OpenDB(connector, otelOpts...) + return config.newDB(ctx, sqlDB, driverName) +} + +func (config DBConfig) initSQL() string { + return fmt.Sprintf(`SET TIME ZONE 'UTC'; SET lock_timeout = %d; SET idle_in_transaction_session_timeout = %d; SET default_transaction_isolation = %q`, + config.LockTimeout.Milliseconds(), config.IdleInTxSessionTimeout.Milliseconds(), defaultIsolation) +} + +func (config DBConfig) newDB(ctx context.Context, sqldb *sql.DB, driverName string) (*sqlx.DB, error) { + db := sqlx.NewDb(sqldb, driverName) + db.MapperFunc(reflectx.CamelToSnakeASCII) + db.SetMaxOpenConns(config.MaxOpenConns) + db.SetMaxIdleConns(config.MaxIdleConns) + return db, disallowReplica(ctx, db) +} + +func disallowReplica(ctx context.Context, db *sqlx.DB) error { + var val string + err := db.GetContext(ctx, &val, "SHOW session_replication_role") + if err != nil { + return err + } + + if val == "replica" { + return fmt.Errorf("invalid `session_replication_role`: %s. Refusing to connect to replica database. 
Writing to a replica will corrupt the database", val) + } + + return nil +} diff --git a/pkg/sqlutil/pg/connection_test.go b/pkg/sqlutil/pg/connection_test.go new file mode 100644 index 000000000..374cc5db1 --- /dev/null +++ b/pkg/sqlutil/pg/connection_test.go @@ -0,0 +1,32 @@ +//go:build db + +package pg + +import ( + "context" + "testing" + + "github.com/google/uuid" + _ "github.com/jackc/pgx/v4/stdlib" + "github.com/jmoiron/sqlx" + "github.com/stretchr/testify/require" +) + +func Test_disallowReplica(t *testing.T) { + db, err := sqlx.Open(string(DriverTxWrappedPostgres), uuid.New().String()) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, db.Close()) }) + + _, err = db.Exec("SET session_replication_role= 'origin'") + require.NoError(t, err) + err = disallowReplica(context.Background(), db) + require.NoError(t, err) + + _, err = db.Exec("SET session_replication_role= 'replica'") + require.NoError(t, err) + err = disallowReplica(context.Background(), db) + require.Error(t, err, "replica role should be disallowed") + + _, err = db.Exec("SET session_replication_role= 'not_valid_role'") + require.Error(t, err) +} diff --git a/pkg/sqlutil/pg/stats.go b/pkg/sqlutil/pg/stats.go new file mode 100644 index 000000000..685304e2b --- /dev/null +++ b/pkg/sqlutil/pg/stats.go @@ -0,0 +1,132 @@ +package pg + +import ( + "context" + "database/sql" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +const dbStatsInternal = 10 * time.Second + +var ( + promDBConnsMax = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "db_conns_max", + Help: "Maximum number of open connections to the database.", + }) + promDBConnsOpen = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "db_conns_open", + Help: "The number of established connections both in use and idle.", + }) + promDBConnsInUse = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "db_conns_used", + Help: "The number of connections 
currently in use.", + }) + promDBWaitCount = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "db_wait_count", + Help: "The total number of connections waited for.", + }) + promDBWaitDuration = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "db_wait_time_seconds", + Help: "The total time blocked waiting for a new connection.", + }) +) + +func publishStats(stats sql.DBStats) { + promDBConnsMax.Set(float64(stats.MaxOpenConnections)) + promDBConnsOpen.Set(float64(stats.OpenConnections)) + promDBConnsInUse.Set(float64(stats.InUse)) + + promDBWaitCount.Set(float64(stats.WaitCount)) + promDBWaitDuration.Set(stats.WaitDuration.Seconds()) +} + +type StatsReporterOpt func(*StatsReporter) + +func StatsInterval(d time.Duration) StatsReporterOpt { + return func(r *StatsReporter) { + r.interval = d + } +} + +func StatsCustomReporterFn(fn ReportFn) StatsReporterOpt { + return func(r *StatsReporter) { + r.reportFn = fn + } +} + +type ( + StatFn func() sql.DBStats + ReportFn func(sql.DBStats) +) + +type StatsReporter struct { + statFn StatFn + reportFn ReportFn + interval time.Duration + cancel context.CancelFunc + lggr logger.Logger + once sync.Once + wg sync.WaitGroup +} + +func NewStatsReporter(fn StatFn, lggr logger.Logger, opts ...StatsReporterOpt) *StatsReporter { + r := &StatsReporter{ + statFn: fn, + reportFn: publishStats, + interval: dbStatsInternal, + lggr: logger.Named(lggr, "StatsReporter"), + } + + for _, opt := range opts { + opt(r) + } + + return r +} + +func (r *StatsReporter) Start() { + startOnce := func() { + r.wg.Add(1) + r.lggr.Debug("Starting DB stat reporter") + ctx, cancelFunc := context.WithCancel(context.Background()) + r.cancel = cancelFunc + go r.loop(ctx) + } + + r.once.Do(startOnce) +} + +// Stop stops all resources owned by the reporter and waits +// for all of them to be done +func (r *StatsReporter) Stop() { + if r.cancel != nil { + r.lggr.Debug("Stopping DB stat reporter") + r.cancel() + r.cancel = nil + r.wg.Wait() + } +} + +func (r 
*StatsReporter) loop(ctx context.Context) { + defer r.wg.Done() + + ticker := time.NewTicker(r.interval) + defer ticker.Stop() + + r.reportFn(r.statFn()) + for { + select { + case <-ticker.C: + r.reportFn(r.statFn()) + case <-ctx.Done(): + r.lggr.Debug("stat reporter loop received done. stopping...") + return + } + } +} diff --git a/pkg/sqlutil/pg/stats_test.go b/pkg/sqlutil/pg/stats_test.go new file mode 100644 index 000000000..c98ec10a9 --- /dev/null +++ b/pkg/sqlutil/pg/stats_test.go @@ -0,0 +1,111 @@ +package pg + +import ( + "database/sql" + "strings" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/stretchr/testify/mock" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +// testDbStater implements mocks for the function signatures +// needed by the stat reporte wrapper for statFn +type testDbStater struct { + mock.Mock + t *testing.T + name string + testGauge prometheus.Gauge +} + +func newtestDbStater(t *testing.T, name string) *testDbStater { + return &testDbStater{ + t: t, + name: name, + testGauge: promauto.NewGauge(prometheus.GaugeOpts{ + Name: strings.ReplaceAll(name, " ", "_"), + }), + } +} + +func (s *testDbStater) Stats() sql.DBStats { + s.Called() + return sql.DBStats{} +} + +func (s *testDbStater) Report(stats sql.DBStats) { + s.Called() + s.testGauge.Set(float64(stats.MaxOpenConnections)) +} + +type statScenario struct { + name string + testFn func(*testing.T, *StatsReporter, time.Duration, int) +} + +func TestStatReporter(t *testing.T) { + interval := 2 * time.Millisecond + expectedIntervals := 4 + + lggr := logger.Test(t) + + for _, scenario := range []statScenario{ + {name: "normal_collect_and_stop", testFn: testCollectAndStop}, + {name: "mutli_start", testFn: testMultiStart}, + {name: "multi_stop", testFn: testMultiStop}, + } { + t.Run(scenario.name, func(t *testing.T) { + d := newtestDbStater(t, scenario.name) + 
d.Mock.On("Stats").Return(sql.DBStats{}) + d.Mock.On("Report").Return() + reporter := NewStatsReporter(d.Stats, + lggr, + StatsInterval(interval), + StatsCustomReporterFn(d.Report), + ) + + scenario.testFn( + t, + reporter, + interval, + expectedIntervals, + ) + + d.AssertCalled(t, "Stats") + d.AssertCalled(t, "Report") + }) + } +} + +// test normal stop +func testCollectAndStop(t *testing.T, r *StatsReporter, interval time.Duration, n int) { + r.Start() + time.Sleep(time.Duration(n) * interval) + r.Stop() +} + +// test multiple start calls are idempotent +func testMultiStart(t *testing.T, r *StatsReporter, interval time.Duration, n int) { + ticker := time.NewTicker(time.Duration(n) * interval) + defer ticker.Stop() + + r.Start() + r.Start() + <-ticker.C + r.Stop() +} + +// test multiple stop calls are idempotent +func testMultiStop(t *testing.T, r *StatsReporter, interval time.Duration, n int) { + ticker := time.NewTicker(time.Duration(n) * interval) + defer ticker.Stop() + + r.Start() + <-ticker.C + r.Stop() + r.Stop() +} diff --git a/pkg/sqlutil/pg/url.go b/pkg/sqlutil/pg/url.go new file mode 100644 index 000000000..38cafa304 --- /dev/null +++ b/pkg/sqlutil/pg/url.go @@ -0,0 +1,11 @@ +package pg + +import "net/url" + +func SetApplicationName(u *url.URL, name string) { + // trim to postgres limit + name = name[:min(63, len(name))] + q := u.Query() + q.Set("application_name", name) + u.RawQuery = q.Encode() +}