diff --git a/pgxpool/pool.go b/pgxpool/pool.go index fdcba7241..92b694edd 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -15,11 +15,13 @@ import ( "github.com/jackc/puddle/v2" ) -var defaultMaxConns = int32(4) -var defaultMinConns = int32(0) -var defaultMaxConnLifetime = time.Hour -var defaultMaxConnIdleTime = time.Minute * 30 -var defaultHealthCheckPeriod = time.Minute +var ( + defaultMaxConns = int32(4) + defaultMinConns = int32(0) + defaultMaxConnLifetime = time.Hour + defaultMaxConnIdleTime = time.Minute * 30 + defaultHealthCheckPeriod = time.Minute +) type connResource struct { conn *pgx.Conn @@ -100,6 +102,11 @@ type Pool struct { closeOnce sync.Once closeChan chan struct{} + + autoLoadTypes []string + reuseTypeMap bool + autoLoadMutex *sync.Mutex + autoLoadTypeInfos []*pgx.TypeInfo } // Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be @@ -147,6 +154,15 @@ type Config struct { // HealthCheckPeriod is the duration between checks of the health of idle connections. HealthCheckPeriod time.Duration + // AutoLoadTypes is a list of user-defined types which should automatically be loaded + // as each new connection is created. This will also load any related types, directly + // or indirectly required to handle these types. + AutoLoadTypes []string + + // ReuseTypeMaps, if enabled, will reuse the typemap information being used by AutoLoadTypes. + // This removes the need to query the database each time a new connection is created. + ReuseTypeMaps bool + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -185,6 +201,8 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { config: config, beforeConnect: config.BeforeConnect, afterConnect: config.AfterConnect, + autoLoadTypes: config.AutoLoadTypes, + reuseTypeMap: config.ReuseTypeMaps, beforeAcquire: config.BeforeAcquire, afterRelease: config.AfterRelease, beforeClose: config.BeforeClose, @@ -196,6 +214,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { healthCheckPeriod: config.HealthCheckPeriod, healthCheckChan: make(chan struct{}, 1), closeChan: make(chan struct{}), + autoLoadMutex: new(sync.Mutex), } if t, ok := config.ConnConfig.Tracer.(AcquireTracer); ok { @@ -237,6 +256,19 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { } } + if len(p.autoLoadTypes) > 0 { + types, err := p.loadTypes(ctx, conn, p.autoLoadTypes) + if err != nil { + conn.Close(ctx) + panic(err) + } + if err = conn.RegisterTypes(types, conn.TypeMap()); err != nil { + conn.Close(ctx) + + panic(err) + } + } + jitterSecs := rand.Float64() * config.MaxConnLifetimeJitter.Seconds() maxAgeTime := time.Now().Add(config.MaxConnLifetime).Add(time.Duration(jitterSecs) * time.Second) @@ -388,6 +420,27 @@ func (p *Pool) Close() { }) } +// loadTypes is used internally to autoload the custom types for a connection, +// potentially reusing previously-loaded typemap information. +func (p *Pool) loadTypes(ctx context.Context, conn *pgx.Conn, typeNames []string) ([]*pgx.TypeInfo, error) { + if p.reuseTypeMap { + p.autoLoadMutex.Lock() + defer p.autoLoadMutex.Unlock() + if p.autoLoadTypeInfos != nil { + return p.autoLoadTypeInfos, nil + } + types, err := conn.LoadTypes(ctx, typeNames) + if err != nil { + return nil, err + } + p.autoLoadTypeInfos = types + return types, err + } + // Avoid needing to acquire the mutex and allow connections to initialise in parallel + // if we have chosen to not reuse the type mapping + return conn.LoadTypes(ctx, typeNames) +} + func (p *Pool) isExpired(res *puddle.Resource[*connResource]) bool { return time.Now().After(res.Value().maxAgeTime) } @@ -482,7 +535,6 @@ func (p *Pool) checkMinConns() error { func (p *Pool) createIdleResources(parentCtx context.Context, targetResources int) error { ctx, cancel := context.WithCancel(parentCtx) defer cancel() - errs := make(chan error, targetResources) for i := 0; i < targetResources; i++ { @@ -495,7 +547,6 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in errs <- err }() } - var firstError error for i := 0; i < targetResources; i++ { err := <-errs diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 90428931b..4c1928eb3 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -261,6 +261,31 @@ func TestPoolBeforeConnect(t *testing.T) { assert.EqualValues(t, "pgx", str) } +func TestAutoLoadTypes(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + db1, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer db1.Close() + db1.Exec(ctx, "DROP DOMAIN IF EXISTS autoload_uint64; CREATE DOMAIN autoload_uint64 as numeric(20,0)") + defer db1.Exec(ctx, "DROP DOMAIN autoload_uint64") + + config.AutoLoadTypes = []string{"autoload_uint64"} + db2, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + + var n uint64 + err = db2.QueryRow(ctx, "select 12::autoload_uint64").Scan(&n) + require.NoError(t, err) + assert.EqualValues(t, uint64(12), n) +} + func TestPoolAfterConnect(t *testing.T) { t.Parallel() @@ -676,7 +701,6 @@ func TestPoolQuery(t *testing.T) { stats = pool.Stat() assert.EqualValues(t, 0, stats.AcquiredConns()) assert.EqualValues(t, 1, stats.TotalConns()) - } func TestPoolQueryRow(t *testing.T) { @@ -1104,7 +1128,6 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) { } t.Fatal("did not reach min pool size") - } func TestPoolSendBatchBatchCloseTwice(t *testing.T) {