From 0ed053aadf15f72c77ef5abbd945e381a7da868f Mon Sep 17 00:00:00 2001 From: "vitess-bot[bot]" <108069721+vitess-bot[bot]@users.noreply.github.com> Date: Thu, 23 May 2024 09:21:40 +0200 Subject: [PATCH] Cherry-pick afbce6aa877d66c99ddb942ebe610d0ec5a5e5f1 with conflicts --- go/pools/smartconnpool/pool.go | 762 ++++++++++++ go/pools/smartconnpool/pool_test.go | 1082 +++++++++++++++++ go/vt/vttablet/endtoend/misc_test.go | 6 +- go/vt/vttablet/endtoend/stream_test.go | 7 +- go/vt/vttablet/tabletserver/connpool/pool.go | 6 - .../tabletserver/connpool/pool_test.go | 11 +- go/vt/vttablet/tabletserver/debugenv.go | 18 +- go/vt/vttablet/tabletserver/query_executor.go | 14 + go/vt/vttablet/tabletserver/tabletserver.go | 18 +- .../tabletserver/tabletserver_test.go | 12 +- go/vt/vttablet/tabletserver/tx_pool_test.go | 7 + 11 files changed, 1924 insertions(+), 19 deletions(-) create mode 100644 go/pools/smartconnpool/pool.go create mode 100644 go/pools/smartconnpool/pool_test.go diff --git a/go/pools/smartconnpool/pool.go b/go/pools/smartconnpool/pool.go new file mode 100644 index 00000000000..d49032f34a1 --- /dev/null +++ b/go/pools/smartconnpool/pool.go @@ -0,0 +1,762 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package smartconnpool + +import ( + "context" + "math/rand/v2" + "slices" + "sync" + "sync/atomic" + "time" + + "vitess.io/vitess/go/vt/log" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/servenv" + "vitess.io/vitess/go/vt/vterrors" +) + +var ( + // ErrTimeout is returned if a connection get times out. 
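+	// This covers the caller's ctx expiring while waiting for a free connection, as well as
+	// the pool being closed while callers are still waiting; contrast ErrCtxTimeout below,
+	// which is returned when the ctx has already expired before the pool is used.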
+ ErrTimeout = vterrors.New(vtrpcpb.Code_RESOURCE_EXHAUSTED, "connection pool timed out") + + // ErrCtxTimeout is returned if a ctx is already expired by the time the connection pool is used + ErrCtxTimeout = vterrors.New(vtrpcpb.Code_DEADLINE_EXCEEDED, "connection pool context already expired") + + // ErrConnPoolClosed is returned when trying to get a connection from a closed conn pool + ErrConnPoolClosed = vterrors.New(vtrpcpb.Code_INTERNAL, "connection pool is closed") + + // PoolCloseTimeout is how long to wait for all connections to be returned to the pool during close + PoolCloseTimeout = 10 * time.Second +) + +type Metrics struct { + maxLifetimeClosed atomic.Int64 + getCount atomic.Int64 + getWithSettingsCount atomic.Int64 + waitCount atomic.Int64 + waitTime atomic.Int64 + idleClosed atomic.Int64 + diffSetting atomic.Int64 + resetSetting atomic.Int64 +} + +func (m *Metrics) MaxLifetimeClosed() int64 { + return m.maxLifetimeClosed.Load() +} + +func (m *Metrics) GetCount() int64 { + return m.getCount.Load() +} + +func (m *Metrics) GetSettingCount() int64 { + return m.getWithSettingsCount.Load() +} + +func (m *Metrics) WaitCount() int64 { + return m.waitCount.Load() +} + +func (m *Metrics) WaitTime() time.Duration { + return time.Duration(m.waitTime.Load()) +} + +func (m *Metrics) IdleClosed() int64 { + return m.idleClosed.Load() +} + +func (m *Metrics) DiffSettingCount() int64 { + return m.diffSetting.Load() +} + +func (m *Metrics) ResetSettingCount() int64 { + return m.resetSetting.Load() +} + +type Connector[C Connection] func(ctx context.Context) (C, error) +type RefreshCheck func() (bool, error) + +type Config[C Connection] struct { + Capacity int64 + IdleTimeout time.Duration + MaxLifetime time.Duration + RefreshInterval time.Duration + LogWait func(time.Time) +} + +// stackMask is the number of connection setting stacks minus one; +// the number of stacks must always be a power of two +const stackMask = 7 + +// ConnPool is a connection pool for generic connections +type ConnPool[C Connection] struct { + // clean is a connections stack for connections with no Setting applied + clean connStack[C] + // settings are N connection stacks for connections with a Setting applied + // connections are distributed between stacks based on their Setting.bucket + settings [stackMask + 1]connStack[C] + // freshSettingStack is the index in settings to the last stack when a connection + // was pushed, or -1 if no connection with a Setting has been opened in this pool + freshSettingsStack atomic.Int64 + // wait is the list of clients waiting for a connection to be returned to the pool + wait waitlist[C] + + // borrowed is the number of connections that the pool has given out to clients + // and that haven't been returned yet + borrowed atomic.Int64 + // active is the number of connections that the pool has opened; this includes connections + // in the pool and borrowed by clients + active atomic.Int64 + // capacity is the maximum number of connections that this pool can open + capacity atomic.Int64 + + // workers is a waitgroup for all the currently running worker goroutines + workers sync.WaitGroup + close chan struct{} + capacityMu sync.Mutex + + config struct { + // connect is the callback to create a new connection for the pool + connect Connector[C] + // refresh is the callback to check whether the pool needs to be refreshed + refresh RefreshCheck + + // maxCapacity is the maximum value to which capacity can be set; when the pool + // is re-opened, it defaults to this capacity + maxCapacity 
int64 + // maxLifetime is the maximum time a connection can be open + maxLifetime atomic.Int64 + // idleTimeout is the maximum time a connection can remain idle + idleTimeout atomic.Int64 + // refreshInterval is how often to call the refresh check + refreshInterval atomic.Int64 + // logWait is called every time a client must block waiting for a connection + logWait func(time.Time) + } + + Metrics Metrics + Name string +} + +// NewPool creates a new connection pool with the given Config. +// The pool must be ConnPool.Open before it can start giving out connections +func NewPool[C Connection](config *Config[C]) *ConnPool[C] { + pool := &ConnPool[C]{} + pool.freshSettingsStack.Store(-1) + pool.config.maxCapacity = config.Capacity + pool.config.maxLifetime.Store(config.MaxLifetime.Nanoseconds()) + pool.config.idleTimeout.Store(config.IdleTimeout.Nanoseconds()) + pool.config.refreshInterval.Store(config.RefreshInterval.Nanoseconds()) + pool.config.logWait = config.LogWait + pool.wait.init() + + return pool +} + +func (pool *ConnPool[C]) runWorker(close <-chan struct{}, interval time.Duration, worker func(now time.Time) bool) { + pool.workers.Add(1) + + go func() { + tick := time.NewTicker(interval) + + defer tick.Stop() + defer pool.workers.Done() + + for { + select { + case now := <-tick.C: + if !worker(now) { + return + } + case <-close: + return + } + } + }() +} + +func (pool *ConnPool[C]) open() { + pool.close = make(chan struct{}) + pool.capacity.Store(pool.config.maxCapacity) + + // The expire worker takes care of removing from the waiter list any clients whose + // context has been cancelled. + pool.runWorker(pool.close, 1*time.Second, func(_ time.Time) bool { + pool.wait.expire(false) + return true + }) + + idleTimeout := pool.IdleTimeout() + if idleTimeout != 0 { + // The idle worker takes care of closing connections that have been idle too long + pool.runWorker(pool.close, idleTimeout/10, func(now time.Time) bool { + pool.closeIdleResources(now) + return true + }) + } + + refreshInterval := pool.RefreshInterval() + if refreshInterval != 0 && pool.config.refresh != nil { + // The refresh worker periodically checks the refresh callback in this pool + // to decide whether all the connections in the pool need to be cycled + // (this usually only happens when there's a global DNS change). + pool.runWorker(pool.close, refreshInterval, func(_ time.Time) bool { + refresh, err := pool.config.refresh() + if err != nil { + log.Error(err) + } + if refresh { + go pool.reopen() + return false + } + return true + }) + } +} + +// Open starts the background workers that manage the pool and gets it ready +// to start serving out connections. +func (pool *ConnPool[C]) Open(connect Connector[C], refresh RefreshCheck) *ConnPool[C] { + if pool.close != nil { + // already open + return pool + } + + pool.config.connect = connect + pool.config.refresh = refresh + pool.open() + return pool +} + +// Close shuts down the pool. No connections will be returned from ConnPool.Get after calling this, +// but calling ConnPool.Put is still allowed. 
This function will not return until all of the pool's +// connections have been returned or the default PoolCloseTimeout has elapsed +func (pool *ConnPool[C]) Close() { + ctx, cancel := context.WithTimeout(context.Background(), PoolCloseTimeout) + defer cancel() + + if err := pool.CloseWithContext(ctx); err != nil { + log.Errorf("failed to close pool %q: %v", pool.Name, err) + } +} + +// CloseWithContext behaves like Close but allows passing in a Context to time out the +// pool closing operation +func (pool *ConnPool[C]) CloseWithContext(ctx context.Context) error { + pool.capacityMu.Lock() + defer pool.capacityMu.Unlock() + + if pool.close == nil || pool.capacity.Load() == 0 { + // already closed + return nil + } + + // close all the connections in the pool; if we time out while waiting for + // users to return our connections, we still want to finish the shutdown + // for the pool + err := pool.setCapacity(ctx, 0) + + close(pool.close) + pool.workers.Wait() + pool.close = nil + return err +} + +func (pool *ConnPool[C]) reopen() { + pool.capacityMu.Lock() + defer pool.capacityMu.Unlock() + + capacity := pool.capacity.Load() + if capacity == 0 { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), PoolCloseTimeout) + defer cancel() + + // to re-open the connection pool, first set the capacity to 0 so we close + // all the existing connections, as they're now connected to a stale MySQL + // instance. + if err := pool.setCapacity(ctx, 0); err != nil { + log.Errorf("failed to reopen pool %q: %v", pool.Name, err) + } + + // the second call to setCapacity cannot fail because it's only increasing the number + // of connections and doesn't need to shut down any + _ = pool.setCapacity(ctx, capacity) +} + +// IsOpen returns whether the pool is open +func (pool *ConnPool[C]) IsOpen() bool { + return pool.close != nil +} + +// Capacity returns the maximum amount of connections that this pool can maintain open +func (pool *ConnPool[C]) Capacity() int64 { + return pool.capacity.Load() +} + +// MaxCapacity returns the maximum value to which Capacity can be set via ConnPool.SetCapacity +func (pool *ConnPool[C]) MaxCapacity() int64 { + return pool.config.maxCapacity +} + +// InUse returns the number of connections that the pool has lent out to clients and that +// haven't been returned yet. +func (pool *ConnPool[C]) InUse() int64 { + return pool.borrowed.Load() +} + +// Available returns the number of connections that the pool can immediately lend out to +// clients without blocking. +func (pool *ConnPool[C]) Available() int64 { + return pool.capacity.Load() - pool.borrowed.Load() +} + +// Active returns the numer of connections that the pool has currently open. +func (pool *ConnPool[C]) Active() int64 { + return pool.active.Load() +} + +func (pool *ConnPool[D]) IdleTimeout() time.Duration { + return time.Duration(pool.config.idleTimeout.Load()) +} + +func (pool *ConnPool[C]) SetIdleTimeout(duration time.Duration) { + pool.config.idleTimeout.Store(duration.Nanoseconds()) +} + +func (pool *ConnPool[D]) RefreshInterval() time.Duration { + return time.Duration(pool.config.refreshInterval.Load()) +} + +func (pool *ConnPool[C]) recordWait(start time.Time) { + pool.Metrics.waitCount.Add(1) + pool.Metrics.waitTime.Add(time.Since(start).Nanoseconds()) + if pool.config.logWait != nil { + pool.config.logWait(start) + } +} + +// Get returns a connection from the pool with the given Setting applied. 
+// If there are no connections in the pool to be returned, Get blocks until one +// is returned, or until the given ctx is cancelled. +// The connection must be returned to the pool once it's not needed by calling Pooled.Recycle +func (pool *ConnPool[C]) Get(ctx context.Context, setting *Setting) (*Pooled[C], error) { + if ctx.Err() != nil { + return nil, ErrCtxTimeout + } + if pool.capacity.Load() == 0 { + return nil, ErrConnPoolClosed + } + if setting == nil { + return pool.get(ctx) + } + return pool.getWithSetting(ctx, setting) +} + +// put returns a connection to the pool. This is a private API. +// Return connections to the pool by calling Pooled.Recycle +func (pool *ConnPool[C]) put(conn *Pooled[C]) { + pool.borrowed.Add(-1) + + if conn == nil { + var err error + conn, err = pool.connNew(context.Background()) + if err != nil { + pool.closedConn() + return + } + } else { + conn.timeUsed = time.Now() + + lifetime := pool.extendedMaxLifetime() + if lifetime > 0 && time.Until(conn.timeCreated.Add(lifetime)) < 0 { + pool.Metrics.maxLifetimeClosed.Add(1) + conn.Close() + if err := pool.connReopen(context.Background(), conn, conn.timeUsed); err != nil { + pool.closedConn() + return + } + } + } + + if !pool.wait.tryReturnConn(conn) { + connSetting := conn.Conn.Setting() + if connSetting == nil { + pool.clean.Push(conn) + } else { + stack := connSetting.bucket & stackMask + pool.settings[stack].Push(conn) + pool.freshSettingsStack.Store(int64(stack)) + } + } +} + +func (pool *ConnPool[D]) extendedMaxLifetime() time.Duration { + maxLifetime := pool.config.maxLifetime.Load() + if maxLifetime == 0 { + return 0 + } + return time.Duration(maxLifetime) + time.Duration(rand.Uint32N(uint32(maxLifetime))) +} + +func (pool *ConnPool[C]) connReopen(ctx context.Context, dbconn *Pooled[C], now time.Time) error { + var err error + dbconn.Conn, err = pool.config.connect(ctx) + if err != nil { + return err + } + + dbconn.timeUsed = now + dbconn.timeCreated = now + return nil +} + +func (pool *ConnPool[C]) connNew(ctx context.Context) (*Pooled[C], error) { + conn, err := pool.config.connect(ctx) + if err != nil { + return nil, err + } + now := time.Now() + return &Pooled[C]{ + timeCreated: now, + timeUsed: now, + pool: pool, + Conn: conn, + }, nil +} + +func (pool *ConnPool[C]) getFromSettingsStack(setting *Setting) *Pooled[C] { + fresh := pool.freshSettingsStack.Load() + if fresh < 0 { + return nil + } + + var start uint32 + if setting == nil { + start = uint32(fresh) + } else { + start = setting.bucket + } + + for i := uint32(0); i <= stackMask; i++ { + pos := (i + start) & stackMask + if conn, ok := pool.settings[pos].Pop(); ok { + return conn + } + } + return nil +} + +func (pool *ConnPool[C]) closedConn() { + _ = pool.active.Add(-1) +} + +func (pool *ConnPool[C]) getNew(ctx context.Context) (*Pooled[C], error) { + for { + open := pool.active.Load() + if open >= pool.capacity.Load() { + return nil, nil + } + + if pool.active.CompareAndSwap(open, open+1) { + conn, err := pool.connNew(ctx) + if err != nil { + pool.closedConn() + return nil, err + } + return conn, nil + } + } +} + +// get returns a pooled connection with no Setting applied +func (pool *ConnPool[C]) get(ctx context.Context) (*Pooled[C], error) { + pool.Metrics.getCount.Add(1) + + // best case: if there's a connection in the clean stack, return it right away + if conn, ok := pool.clean.Pop(); ok { + pool.borrowed.Add(1) + return conn, nil + } + + // check if we have enough capacity to open a brand-new connection to return + conn, err := 
pool.getNew(ctx) + if err != nil { + return nil, err + } + // if we don't have capacity, try popping a connection from any of the setting stacks + if conn == nil { + conn = pool.getFromSettingsStack(nil) + } + // if there are no connections in the setting stacks and we've lent out connections + // to other clients, wait until one of the connections is returned + if conn == nil { + start := time.Now() + conn, err = pool.wait.waitForConn(ctx, nil) + if err != nil { + return nil, ErrTimeout + } + pool.recordWait(start) + } + // no connections available and no connections to wait for (pool is closed) + if conn == nil { + return nil, ErrTimeout + } + + // if the connection we've acquired has a Setting applied, we must reset it before returning + if conn.Conn.Setting() != nil { + pool.Metrics.resetSetting.Add(1) + + err = conn.Conn.ResetSetting(ctx) + if err != nil { + conn.Close() + err = pool.connReopen(ctx, conn, time.Now()) + if err != nil { + pool.closedConn() + return nil, err + } + } + } + + pool.borrowed.Add(1) + return conn, nil +} + +// getWithSetting returns a connection from the pool with the given Setting applied +func (pool *ConnPool[C]) getWithSetting(ctx context.Context, setting *Setting) (*Pooled[C], error) { + pool.Metrics.getWithSettingsCount.Add(1) + + var err error + // best case: check if there's a connection in the setting stack where our Setting belongs + conn, _ := pool.settings[setting.bucket&stackMask].Pop() + // if there's connection with our setting, try popping a clean connection + if conn == nil { + conn, _ = pool.clean.Pop() + } + // otherwise try opening a brand new connection and we'll apply the setting to it + if conn == nil { + conn, err = pool.getNew(ctx) + if err != nil { + return nil, err + } + } + // try on the _other_ setting stacks, even if we have to reset the Setting for the returned + // connection + if conn == nil { + conn = pool.getFromSettingsStack(setting) + } + // no connections anywhere in the pool; if we've lent out connections to other clients + // wait for one of them + if conn == nil { + start := time.Now() + conn, err = pool.wait.waitForConn(ctx, setting) + if err != nil { + return nil, ErrTimeout + } + pool.recordWait(start) + } + // no connections available and no connections to wait for (pool is closed) + if conn == nil { + return nil, ErrTimeout + } + + // ensure that the setting applied to the connection matches the one we want + connSetting := conn.Conn.Setting() + if connSetting != setting { + // if there's another setting applied, reset it before applying our setting + if connSetting != nil { + pool.Metrics.diffSetting.Add(1) + + err = conn.Conn.ResetSetting(ctx) + if err != nil { + conn.Close() + err = pool.connReopen(ctx, conn, time.Now()) + if err != nil { + pool.closedConn() + return nil, err + } + } + } + // apply our setting now; if we can't we assume that the conn is broken + // and close it without returning to the pool + if err := conn.Conn.ApplySetting(ctx, setting); err != nil { + conn.Close() + pool.closedConn() + return nil, err + } + } + + pool.borrowed.Add(1) + return conn, nil +} + +// SetCapacity changes the capacity (number of open connections) on the pool. +// If the capacity is smaller than the number of connections that there are +// currently open, we'll close enough connections before returning, even if +// that means waiting for clients to return connections to the pool. +// If the given context times out before we've managed to close enough connections +// an error will be returned. 
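+//
+// A minimal usage sketch (illustrative only; assumes the pool is already open):
+//
+//	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+//	defer cancel()
+//	if err := pool.SetCapacity(ctx, 10); err != nil {
+//		// not all excess connections could be closed before ctx expired
+//	}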
+func (pool *ConnPool[C]) SetCapacity(ctx context.Context, newcap int64) error { + pool.capacityMu.Lock() + defer pool.capacityMu.Unlock() + return pool.setCapacity(ctx, newcap) +} + +// setCapacity is the internal implementation for SetCapacity; it must be called +// with pool.capacityMu being held +func (pool *ConnPool[C]) setCapacity(ctx context.Context, newcap int64) error { + if newcap < 0 { + panic("negative capacity") + } + + oldcap := pool.capacity.Swap(newcap) + if oldcap == newcap { + return nil + } + + const delay = 10 * time.Millisecond + + // close connections until we're under capacity + for pool.active.Load() > newcap { + if err := ctx.Err(); err != nil { + return vterrors.Errorf(vtrpcpb.Code_ABORTED, + "timed out while waiting for connections to be returned to the pool (capacity=%d, active=%d, borrowed=%d)", + pool.capacity.Load(), pool.active.Load(), pool.borrowed.Load()) + } + // if we're closing down the pool, make sure there's no clients waiting + // for connections because they won't be returned in the future + if newcap == 0 { + pool.wait.expire(true) + } + + // try closing from connections which are currently idle in the stacks + conn := pool.getFromSettingsStack(nil) + if conn == nil { + conn, _ = pool.clean.Pop() + } + if conn == nil { + time.Sleep(delay) + continue + } + conn.Close() + pool.closedConn() + } + + return nil +} + +func (pool *ConnPool[C]) closeIdleResources(now time.Time) { + timeout := pool.IdleTimeout() + if timeout == 0 { + return + } + if pool.Capacity() == 0 { + return + } + + var conns []*Pooled[C] + + closeInStack := func(s *connStack[C]) { + conns = s.PopAll(conns[:0]) + slices.Reverse(conns) + + for _, conn := range conns { + if conn.timeUsed.Add(timeout).Sub(now) < 0 { + pool.Metrics.idleClosed.Add(1) + conn.Close() + pool.closedConn() + continue + } + + s.Push(conn) + } + } + + for i := 0; i <= stackMask; i++ { + closeInStack(&pool.settings[i]) + } + closeInStack(&pool.clean) +} + +func (pool *ConnPool[C]) StatsJSON() map[string]any { + return map[string]any{ + "Capacity": int(pool.Capacity()), + "Available": int(pool.Available()), + "Active": int(pool.active.Load()), + "InUse": int(pool.InUse()), + "WaitCount": int(pool.Metrics.WaitCount()), + "WaitTime": pool.Metrics.WaitTime(), + "IdleTimeout": pool.IdleTimeout(), + "IdleClosed": int(pool.Metrics.IdleClosed()), + "MaxLifetimeClosed": int(pool.Metrics.MaxLifetimeClosed()), + } +} + +// RegisterStats registers this pool's metrics into a stats Exporter +func (pool *ConnPool[C]) RegisterStats(stats *servenv.Exporter, name string) { + if stats == nil || name == "" { + return + } + + pool.Name = name + + stats.NewGaugeFunc(name+"Capacity", "Tablet server conn pool capacity", func() int64 { + return pool.Capacity() + }) + stats.NewGaugeFunc(name+"Available", "Tablet server conn pool available", func() int64 { + return pool.Available() + }) + stats.NewGaugeFunc(name+"Active", "Tablet server conn pool active", func() int64 { + return pool.Active() + }) + stats.NewGaugeFunc(name+"InUse", "Tablet server conn pool in use", func() int64 { + return pool.InUse() + }) + stats.NewGaugeFunc(name+"MaxCap", "Tablet server conn pool max cap", func() int64 { + // the smartconnpool doesn't have a maximum capacity + return pool.Capacity() + }) + stats.NewCounterFunc(name+"WaitCount", "Tablet server conn pool wait count", func() int64 { + return pool.Metrics.WaitCount() + }) + stats.NewCounterDurationFunc(name+"WaitTime", "Tablet server wait time", func() time.Duration { + return pool.Metrics.WaitTime() + }) + 
stats.NewGaugeDurationFunc(name+"IdleTimeout", "Tablet server idle timeout", func() time.Duration { + return pool.IdleTimeout() + }) + stats.NewCounterFunc(name+"IdleClosed", "Tablet server conn pool idle closed", func() int64 { + return pool.Metrics.IdleClosed() + }) + stats.NewCounterFunc(name+"MaxLifetimeClosed", "Tablet server conn pool refresh closed", func() int64 { + return pool.Metrics.MaxLifetimeClosed() + }) + stats.NewCounterFunc(name+"Get", "Tablet server conn pool get count", func() int64 { + return pool.Metrics.GetCount() + }) + stats.NewCounterFunc(name+"GetSetting", "Tablet server conn pool get with setting count", func() int64 { + return pool.Metrics.GetSettingCount() + }) + stats.NewCounterFunc(name+"DiffSetting", "Number of times pool applied different setting", func() int64 { + return pool.Metrics.DiffSettingCount() + }) + stats.NewCounterFunc(name+"ResetSetting", "Number of times pool reset the setting", func() int64 { + return pool.Metrics.ResetSettingCount() + }) +} diff --git a/go/pools/smartconnpool/pool_test.go b/go/pools/smartconnpool/pool_test.go new file mode 100644 index 00000000000..701327005ad --- /dev/null +++ b/go/pools/smartconnpool/pool_test.go @@ -0,0 +1,1082 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package smartconnpool + +import ( + "context" + "fmt" + "reflect" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + sFoo = &Setting{queryApply: "set foo=1"} + sBar = &Setting{queryApply: "set bar=1"} +) + +type TestState struct { + lastID, open, close, reset atomic.Int64 + waits []time.Time + + chaos struct { + delayConnect time.Duration + failConnect bool + failApply bool + } +} + +func (ts *TestState) LogWait(start time.Time) { + ts.waits = append(ts.waits, start) +} + +type TestConn struct { + counts *TestState + onClose chan struct{} + + setting *Setting + num int64 + timeCreated time.Time + closed bool + failApply bool +} + +func (tr *TestConn) waitForClose() chan struct{} { + tr.onClose = make(chan struct{}) + return tr.onClose +} + +func (tr *TestConn) IsClosed() bool { + return tr.closed +} + +func (tr *TestConn) Setting() *Setting { + return tr.setting +} + +func (tr *TestConn) ResetSetting(ctx context.Context) error { + tr.counts.reset.Add(1) + tr.setting = nil + return nil +} + +func (tr *TestConn) ApplySetting(ctx context.Context, setting *Setting) error { + if tr.failApply { + return fmt.Errorf("ApplySetting failed") + } + tr.setting = setting + return nil +} + +func (tr *TestConn) Close() { + if !tr.closed { + if tr.onClose != nil { + close(tr.onClose) + } + tr.counts.open.Add(-1) + tr.counts.close.Add(1) + tr.closed = true + } +} + +var _ Connection = (*TestConn)(nil) + +func newConnector(state *TestState) Connector[*TestConn] { + return func(ctx context.Context) (*TestConn, error) { + state.open.Add(1) + if state.chaos.delayConnect != 0 { + time.Sleep(state.chaos.delayConnect) + } + if state.chaos.failConnect { + return nil, fmt.Errorf("failed to connect: forced failure") + } + return &TestConn{ + num: state.lastID.Add(1), + timeCreated: time.Now(), + counts: state, + failApply: state.chaos.failApply, + }, nil + } +} + +func TestOpen(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + var resources [10]*Pooled[*TestConn] + var r *Pooled[*TestConn] + var err error + + // Test Get + for i := 0; i < 5; i++ { + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + assert.EqualValues(t, 5-i-1, p.Available()) + assert.Zero(t, p.Metrics.WaitCount()) + assert.Zero(t, len(state.waits)) + assert.Zero(t, p.Metrics.WaitTime()) + assert.EqualValues(t, i+1, state.lastID.Load()) + assert.EqualValues(t, i+1, state.open.Load()) + } + + // Test that Get waits + done := make(chan struct{}) + go func() { + for i := 0; i < 5; i++ { + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + for i := 0; i < 5; i++ { + p.put(resources[i]) + } + close(done) + }() + for i := 0; i < 5; i++ { + // block until we have a client wait for a connection, then offer it + for p.wait.waiting() == 0 { + time.Sleep(time.Millisecond) + } + p.put(resources[i]) + } + <-done + assert.EqualValues(t, 5, p.Metrics.WaitCount()) + assert.Equal(t, 5, len(state.waits)) + // verify start times are monotonic increasing + for i := 1; i < len(state.waits); i++ { + if state.waits[i].Before(state.waits[i-1]) { + t.Errorf("Expecting monotonic increasing start times") + } + } + assert.NotZero(t, p.Metrics.WaitTime()) + 
assert.EqualValues(t, 5, state.lastID.Load()) + // Test Close resource + r, err = p.Get(ctx, nil) + require.NoError(t, err) + r.Close() + // A nil Put should cause the resource to be reopened. + p.put(nil) + assert.EqualValues(t, 5, state.open.Load()) + assert.EqualValues(t, 6, state.lastID.Load()) + + for i := 0; i < 5; i++ { + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + for i := 0; i < 5; i++ { + p.put(resources[i]) + } + assert.EqualValues(t, 5, state.open.Load()) + assert.EqualValues(t, 6, state.lastID.Load()) + + // SetCapacity + err = p.SetCapacity(ctx, 3) + require.NoError(t, err) + assert.EqualValues(t, 3, state.open.Load()) + assert.EqualValues(t, 6, state.lastID.Load()) + assert.EqualValues(t, 3, p.Capacity()) + assert.EqualValues(t, 3, p.Available()) + + err = p.SetCapacity(ctx, 6) + require.NoError(t, err) + assert.EqualValues(t, 6, p.Capacity()) + assert.EqualValues(t, 6, p.Available()) + + for i := 0; i < 6; i++ { + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + for i := 0; i < 6; i++ { + p.put(resources[i]) + } + assert.EqualValues(t, 6, state.open.Load()) + assert.EqualValues(t, 9, state.lastID.Load()) + + // Close + p.Close() + assert.EqualValues(t, 0, p.Capacity()) + assert.EqualValues(t, 0, p.Available()) + assert.EqualValues(t, 0, state.open.Load()) +} + +func TestShrinking(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + var resources [10]*Pooled[*TestConn] + // Leave one empty slot in the pool + for i := 0; i < 4; i++ { + var r *Pooled[*TestConn] + var err error + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + done := make(chan bool) + go func() { + err := p.SetCapacity(ctx, 3) + require.NoError(t, err) + + done <- true + }() + expected := map[string]any{ + "Capacity": 3, + "Available": -1, // negative because we've borrowed past our capacity + "Active": 4, + "InUse": 4, + "WaitCount": 0, + "WaitTime": time.Duration(0), + "IdleTimeout": 1 * time.Second, + "IdleClosed": 0, + "MaxLifetimeClosed": 0, + } + for i := 0; i < 10; i++ { + time.Sleep(10 * time.Millisecond) + stats := p.StatsJSON() + if reflect.DeepEqual(expected, stats) { + break + } + if i == 9 { + assert.Equal(t, expected, stats) + } + } + // There are already 2 resources available in the pool. + // So, returning one should be enough for SetCapacity to complete. 
+ p.put(resources[3]) + <-done + // Return the rest of the resources + for i := 0; i < 3; i++ { + p.put(resources[i]) + } + stats := p.StatsJSON() + expected = map[string]any{ + "Capacity": 3, + "Available": 3, + "Active": 3, + "InUse": 0, + "WaitCount": 0, + "WaitTime": time.Duration(0), + "IdleTimeout": 1 * time.Second, + "IdleClosed": 0, + "MaxLifetimeClosed": 0, + } + assert.Equal(t, expected, stats) + assert.EqualValues(t, 3, state.open.Load()) + + // Ensure no deadlock if SetCapacity is called after we start + // waiting for a resource + var err error + for i := 0; i < 3; i++ { + var r *Pooled[*TestConn] + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + // This will wait because pool is empty + go func() { + r, err := p.Get(ctx, nil) + require.NoError(t, err) + p.put(r) + done <- true + }() + + // This will also wait + go func() { + err := p.SetCapacity(ctx, 2) + require.NoError(t, err) + done <- true + }() + time.Sleep(10 * time.Millisecond) + + // This should not hang + for i := 0; i < 3; i++ { + p.put(resources[i]) + } + <-done + <-done + assert.EqualValues(t, 2, p.Capacity()) + assert.EqualValues(t, 2, p.Available()) + assert.EqualValues(t, 1, p.Metrics.WaitCount()) + assert.EqualValues(t, p.Metrics.WaitCount(), len(state.waits)) + assert.EqualValues(t, 2, state.open.Load()) + + // Test race condition of SetCapacity with itself + err = p.SetCapacity(ctx, 3) + require.NoError(t, err) + for i := 0; i < 3; i++ { + var r *Pooled[*TestConn] + var err error + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + // This will wait because pool is empty + go func() { + r, err := p.Get(ctx, nil) + require.NoError(t, err) + p.put(r) + done <- true + }() + time.Sleep(10 * time.Millisecond) + + // This will wait till we Put + go func() { + err := p.SetCapacity(ctx, 2) + require.NoError(t, err) + }() + time.Sleep(10 * time.Millisecond) + go func() { + err := p.SetCapacity(ctx, 4) + require.NoError(t, err) + }() + time.Sleep(10 * time.Millisecond) + + // This should not hang + for i := 0; i < 3; i++ { + p.put(resources[i]) + } + <-done + + assert.Panics(t, func() { + _ = p.SetCapacity(ctx, -1) + }) + + assert.EqualValues(t, 4, p.Capacity()) + assert.EqualValues(t, 4, p.Available()) +} + +func TestClosing(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + var resources [10]*Pooled[*TestConn] + for i := 0; i < 5; i++ { + var r *Pooled[*TestConn] + var err error + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + ch := make(chan bool) + go func() { + p.Close() + ch <- true + }() + + // Wait for goroutine to call Close + time.Sleep(10 * time.Millisecond) + stats := p.StatsJSON() + expected := map[string]any{ + "Capacity": 0, + "Available": -5, + "Active": 5, + "InUse": 5, + "WaitCount": 0, + "WaitTime": time.Duration(0), + "IdleTimeout": 1 * time.Second, + "IdleClosed": 0, + "MaxLifetimeClosed": 0, + } + assert.Equal(t, expected, stats) + + // Put is allowed when closing + for i := 0; i < 5; i++ { + p.put(resources[i]) + } + + // Wait for Close to return + <-ch + + stats = p.StatsJSON() + expected = map[string]any{ + "Capacity": 0, + "Available": 0, + "Active": 0, + "InUse": 0, + 
"WaitCount": 0, + "WaitTime": time.Duration(0), + "IdleTimeout": 1 * time.Second, + "IdleClosed": 0, + "MaxLifetimeClosed": 0, + } + assert.Equal(t, expected, stats) + assert.EqualValues(t, 5, state.lastID.Load()) + assert.EqualValues(t, 0, state.open.Load()) +} + +func TestReopen(t *testing.T) { + var state TestState + var refreshed atomic.Bool + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + RefreshInterval: 500 * time.Millisecond, + }).Open(newConnector(&state), func() (bool, error) { + refreshed.Store(true) + return true, nil + }) + + var resources [10]*Pooled[*TestConn] + for i := 0; i < 5; i++ { + var r *Pooled[*TestConn] + var err error + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + + time.Sleep(10 * time.Millisecond) + stats := p.StatsJSON() + expected := map[string]any{ + "Capacity": 5, + "Available": 0, + "Active": 5, + "InUse": 5, + "WaitCount": 0, + "WaitTime": time.Duration(0), + "IdleTimeout": 1 * time.Second, + "IdleClosed": 0, + "MaxLifetimeClosed": 0, + } + assert.Equal(t, expected, stats) + + time.Sleep(1 * time.Second) + assert.Truef(t, refreshed.Load(), "did not refresh") + + for i := 0; i < 5; i++ { + p.put(resources[i]) + } + time.Sleep(50 * time.Millisecond) + stats = p.StatsJSON() + expected = map[string]any{ + "Capacity": 5, + "Available": 5, + "Active": 0, + "InUse": 0, + "WaitCount": 0, + "WaitTime": time.Duration(0), + "IdleTimeout": 1 * time.Second, + "IdleClosed": 0, + "MaxLifetimeClosed": 0, + } + assert.Equal(t, expected, stats) + assert.EqualValues(t, 5, state.lastID.Load()) + assert.EqualValues(t, 0, state.open.Load()) +} + +func TestUserClosing(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + var resources [5]*Pooled[*TestConn] + for i := 0; i < 5; i++ { + var err error + resources[i], err = p.Get(ctx, nil) + require.NoError(t, err) + } + + for _, r := range resources[:4] { + r.Recycle() + } + + ch := make(chan error) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + err := p.CloseWithContext(ctx) + ch <- err + close(ch) + }() + + select { + case <-time.After(5 * time.Second): + t.Fatalf("Pool did not shutdown after 5s") + case err := <-ch: + require.Error(t, err) + t.Logf("Shutdown error: %v", err) + } +} + +func TestIdleTimeout(t *testing.T) { + testTimeout := func(t *testing.T, setting *Setting) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: 10 * time.Millisecond, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + defer p.Close() + + var conns []*Pooled[*TestConn] + for i := 0; i < 5; i++ { + r, err := p.Get(ctx, setting) + require.NoError(t, err) + assert.EqualValues(t, i+1, state.open.Load()) + assert.EqualValues(t, 0, p.Metrics.IdleClosed()) + + conns = append(conns, r) + } + + // wait a long while; ensure that none of the conns have been closed + time.Sleep(1 * time.Second) + + var closers []chan struct{} + for _, conn := range conns { + assert.Falsef(t, conn.Conn.IsClosed(), "connection was idle-closed while outside the pool") + closers = append(closers, conn.Conn.waitForClose()) + p.put(conn) + } + + for _, closed := range closers { + <-closed + } + + // no need 
to assert anything: all the connections in the pool should are idle-closed + // now and if they're not the test will timeout and fail + } + + t.Run("WithoutSettings", func(t *testing.T) { testTimeout(t, nil) }) + t.Run("WithSettings", func(t *testing.T) { testTimeout(t, sFoo) }) +} + +func TestIdleTimeoutCreateFail(t *testing.T) { + var state TestState + var connector = newConnector(&state) + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 1, + IdleTimeout: 10 * time.Millisecond, + LogWait: state.LogWait, + }).Open(connector, nil) + + defer p.Close() + + for _, setting := range []*Setting{nil, sFoo} { + r, err := p.Get(ctx, setting) + require.NoError(t, err) + // Change the factory before putting back + // to prevent race with the idle closer, who will + // try to use it. + state.chaos.failConnect = true + p.put(r) + timeout := time.After(1 * time.Second) + for p.Active() != 0 { + select { + case <-timeout: + t.Errorf("Timed out waiting for resource to be closed by idle timeout") + default: + } + } + // reset factory for next run. + state.chaos.failConnect = false + } +} + +func TestMaxLifetime(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 1, + IdleTimeout: 10 * time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + r, err := p.Get(ctx, nil) + require.NoError(t, err) + assert.EqualValues(t, 1, state.open.Load()) + assert.EqualValues(t, 0, p.Metrics.MaxLifetimeClosed()) + + time.Sleep(10 * time.Millisecond) + + p.put(r) + assert.EqualValues(t, 1, state.lastID.Load()) + assert.EqualValues(t, 1, state.open.Load()) + assert.EqualValues(t, 0, p.Metrics.MaxLifetimeClosed()) + + p.Close() + + // maxLifetime > 0 + state.lastID.Store(0) + state.open.Store(0) + + p = NewPool(&Config[*TestConn]{ + Capacity: 1, + IdleTimeout: 10 * time.Second, + MaxLifetime: 10 * time.Millisecond, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + r, err = p.Get(ctx, nil) + require.NoError(t, err) + assert.EqualValues(t, 1, state.open.Load()) + assert.EqualValues(t, 0, p.Metrics.MaxLifetimeClosed()) + + time.Sleep(5 * time.Millisecond) + + p.put(r) + assert.EqualValues(t, 1, state.lastID.Load()) + assert.EqualValues(t, 1, state.open.Load()) + assert.EqualValues(t, 0, p.Metrics.MaxLifetimeClosed()) + + r, err = p.Get(ctx, nil) + require.NoError(t, err) + assert.EqualValues(t, 1, state.open.Load()) + assert.EqualValues(t, 0, p.Metrics.MaxLifetimeClosed()) + + time.Sleep(10 * time.Millisecond * 2) + + p.put(r) + assert.EqualValues(t, 2, state.lastID.Load()) + assert.EqualValues(t, 1, state.open.Load()) + assert.EqualValues(t, 1, p.Metrics.MaxLifetimeClosed()) +} + +func TestExtendedLifetimeTimeout(t *testing.T) { + var state TestState + var connector = newConnector(&state) + var config = &Config[*TestConn]{ + Capacity: 1, + IdleTimeout: time.Second, + MaxLifetime: 0, + LogWait: state.LogWait, + } + + // maxLifetime 0 + p := NewPool(config).Open(connector, nil) + assert.Zero(t, p.extendedMaxLifetime()) + p.Close() + + // maxLifetime > 0 + config.MaxLifetime = 10 * time.Millisecond + for i := 0; i < 10; i++ { + p = NewPool(config).Open(connector, nil) + assert.LessOrEqual(t, config.MaxLifetime, p.extendedMaxLifetime()) + assert.Greater(t, 2*config.MaxLifetime, p.extendedMaxLifetime()) + p.Close() + } +} + +func TestCreateFail(t *testing.T) { + var state TestState + state.chaos.failConnect = true + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + 
IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + for _, setting := range []*Setting{nil, sFoo} { + if _, err := p.Get(ctx, setting); err.Error() != "failed to connect: forced failure" { + t.Errorf("Expecting Failed, received %v", err) + } + stats := p.StatsJSON() + expected := map[string]any{ + "Capacity": 5, + "Available": 5, + "Active": 0, + "InUse": 0, + "WaitCount": 0, + "WaitTime": time.Duration(0), + "IdleTimeout": 1 * time.Second, + "IdleClosed": 0, + "MaxLifetimeClosed": 0, + } + assert.Equal(t, expected, stats) + } +} + +func TestCreateFailOnPut(t *testing.T) { + var state TestState + var connector = newConnector(&state) + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(connector, nil) + + defer p.Close() + + for _, setting := range []*Setting{nil, sFoo} { + _, err := p.Get(ctx, setting) + require.NoError(t, err) + + // change factory to fail the put. + state.chaos.failConnect = true + p.put(nil) + assert.Zero(t, p.Active()) + + // change back for next iteration. + state.chaos.failConnect = false + } +} + +func TestSlowCreateFail(t *testing.T) { + var state TestState + state.chaos.delayConnect = 10 * time.Millisecond + + ctx := context.Background() + ch := make(chan *Pooled[*TestConn]) + + for _, setting := range []*Setting{nil, sFoo} { + p := NewPool(&Config[*TestConn]{ + Capacity: 2, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + state.chaos.failConnect = true + + for i := 0; i < 3; i++ { + go func() { + conn, _ := p.Get(ctx, setting) + ch <- conn + }() + } + assert.Nil(t, <-ch) + assert.Nil(t, <-ch) + assert.Equalf(t, p.Capacity(), int64(2), "pool should not be out of capacity") + assert.Equalf(t, p.Available(), int64(2), "pool should not be out of availability") + + select { + case <-ch: + assert.Fail(t, "there should be no capacity for a third connection") + default: + } + + state.chaos.failConnect = false + conn, err := p.Get(ctx, setting) + require.NoError(t, err) + + p.put(conn) + conn = <-ch + assert.NotNil(t, conn) + p.put(conn) + p.Close() + } +} + +func TestTimeout(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 1, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + defer p.Close() + + // take the only connection available + r, err := p.Get(ctx, nil) + require.NoError(t, err) + + for _, setting := range []*Setting{nil, sFoo} { + // trying to get the connection without a timeout. + newctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + _, err = p.Get(newctx, setting) + cancel() + assert.EqualError(t, err, "connection pool timed out") + + } + + // put the connection take was taken initially. 
+ p.put(r) +} + +func TestExpired(t *testing.T) { + var state TestState + + p := NewPool(&Config[*TestConn]{ + Capacity: 1, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + defer p.Close() + + for _, setting := range []*Setting{nil, sFoo} { + // expired context + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) + _, err := p.Get(ctx, setting) + cancel() + require.EqualError(t, err, "connection pool context already expired") + } +} + +func TestMultiSettings(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + var resources [10]*Pooled[*TestConn] + var r *Pooled[*TestConn] + var err error + + settings := []*Setting{nil, sFoo, sBar, sBar, sFoo} + + // Test Get + for i := 0; i < 5; i++ { + r, err = p.Get(ctx, settings[i]) + require.NoError(t, err) + resources[i] = r + assert.EqualValues(t, 5-i-1, p.Available()) + assert.Zero(t, p.Metrics.WaitCount()) + assert.Zero(t, len(state.waits)) + assert.Zero(t, p.Metrics.WaitTime()) + assert.EqualValues(t, i+1, state.lastID.Load()) + assert.EqualValues(t, i+1, state.open.Load()) + } + + // Test that Get waits + ch := make(chan bool) + go func() { + for i := 0; i < 5; i++ { + r, err = p.Get(ctx, settings[i]) + require.NoError(t, err) + resources[i] = r + } + for i := 0; i < 5; i++ { + p.put(resources[i]) + } + ch <- true + }() + for i := 0; i < 5; i++ { + // Sleep to ensure the goroutine waits + time.Sleep(10 * time.Millisecond) + p.put(resources[i]) + } + <-ch + assert.EqualValues(t, 5, p.Metrics.WaitCount()) + assert.Equal(t, 5, len(state.waits)) + // verify start times are monotonic increasing + for i := 1; i < len(state.waits); i++ { + if state.waits[i].Before(state.waits[i-1]) { + t.Errorf("Expecting monotonic increasing start times") + } + } + assert.NotZero(t, p.Metrics.WaitTime()) + assert.EqualValues(t, 5, state.lastID.Load()) + + // Close + p.Close() + assert.EqualValues(t, 0, p.Capacity()) + assert.EqualValues(t, 0, p.Available()) + assert.EqualValues(t, 0, state.open.Load()) +} + +func TestMultiSettingsWithReset(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + var resources [10]*Pooled[*TestConn] + var r *Pooled[*TestConn] + var err error + + settings := []*Setting{nil, sFoo, sBar, sBar, sFoo} + + // Test Get + for i := 0; i < 5; i++ { + r, err = p.Get(ctx, settings[i]) + require.NoError(t, err) + resources[i] = r + assert.EqualValues(t, 5-i-1, p.Available()) + assert.EqualValues(t, i+1, state.lastID.Load()) + assert.EqualValues(t, i+1, state.open.Load()) + } + + // Put all of them back + for i := 0; i < 5; i++ { + p.put(resources[i]) + } + + // Getting all with same setting. 
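+	// (settings[1] is sFoo; the two pooled connections still carrying sBar must be reset
+	// before sFoo can be applied, which is what the reset counter assertion below checks)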
+ for i := 0; i < 5; i++ { + r, err = p.Get(ctx, settings[1]) // {foo} + require.NoError(t, err) + assert.Truef(t, r.Conn.setting == settings[1], "setting was not properly applied") + resources[i] = r + } + assert.EqualValues(t, 2, state.reset.Load()) // when setting was {bar} and getting for {foo} + assert.EqualValues(t, 0, p.Available()) + assert.EqualValues(t, 5, state.lastID.Load()) + assert.EqualValues(t, 5, state.open.Load()) + + for i := 0; i < 5; i++ { + p.put(resources[i]) + } + + // Close + p.Close() + assert.EqualValues(t, 0, p.Capacity()) + assert.EqualValues(t, 0, p.Available()) + assert.EqualValues(t, 0, state.open.Load()) +} + +func TestApplySettingsFailure(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + var resources []*Pooled[*TestConn] + var r *Pooled[*TestConn] + var err error + + settings := []*Setting{nil, sFoo, sBar, sBar, sFoo} + // get the resource and mark for failure + for i := 0; i < 5; i++ { + r, err = p.Get(ctx, settings[i]) + require.NoError(t, err) + r.Conn.failApply = true + resources = append(resources, r) + } + // put them back + for _, r = range resources { + p.put(r) + } + + // any new connection created will fail to apply setting + state.chaos.failApply = true + + // Get the resource with "foo" setting + // For an applied connection if the setting are same it will be returned as-is. + // Otherwise, will fail to get the resource. + var failCount int + resources = nil + for i := 0; i < 5; i++ { + r, err = p.Get(ctx, settings[1]) + if err != nil { + failCount++ + assert.EqualError(t, err, "ApplySetting failed") + continue + } + resources = append(resources, r) + } + // put them back + for _, r = range resources { + p.put(r) + } + require.Equal(t, 3, failCount) + + // should be able to get all the resource with no setting + resources = nil + for i := 0; i < 5; i++ { + r, err = p.Get(ctx, nil) + require.NoError(t, err) + resources = append(resources, r) + } + // put them back + for _, r = range resources { + p.put(r) + } +} diff --git a/go/vt/vttablet/endtoend/misc_test.go b/go/vt/vttablet/endtoend/misc_test.go index 147730d7319..6c58d0ca5da 100644 --- a/go/vt/vttablet/endtoend/misc_test.go +++ b/go/vt/vttablet/endtoend/misc_test.go @@ -265,8 +265,10 @@ func TestSidecarTables(t *testing.T) { } func TestConsolidation(t *testing.T) { - defer framework.Server.SetPoolSize(framework.Server.PoolSize()) - framework.Server.SetPoolSize(1) + defer framework.Server.SetPoolSize(context.Background(), framework.Server.PoolSize()) + + err := framework.Server.SetPoolSize(context.Background(), 1) + require.NoError(t, err) const tag = "Waits/Histograms/Consolidations/Count" diff --git a/go/vt/vttablet/endtoend/stream_test.go b/go/vt/vttablet/endtoend/stream_test.go index 05045fd6f7d..a3c73dd8152 100644 --- a/go/vt/vttablet/endtoend/stream_test.go +++ b/go/vt/vttablet/endtoend/stream_test.go @@ -17,6 +17,7 @@ limitations under the License. 
package endtoend import ( + "context" "errors" "fmt" "reflect" @@ -98,11 +99,13 @@ func TestStreamConsolidation(t *testing.T) { defaultPoolSize := framework.Server.StreamPoolSize() - framework.Server.SetStreamPoolSize(4) + err = framework.Server.SetStreamPoolSize(context.Background(), 4) + require.NoError(t, err) + framework.Server.SetStreamConsolidationBlocking(true) defer func() { - framework.Server.SetStreamPoolSize(defaultPoolSize) + _ = framework.Server.SetStreamPoolSize(context.Background(), defaultPoolSize) framework.Server.SetStreamConsolidationBlocking(false) }() diff --git a/go/vt/vttablet/tabletserver/connpool/pool.go b/go/vt/vttablet/tabletserver/connpool/pool.go index d2f8efb7af0..b905633c4b2 100644 --- a/go/vt/vttablet/tabletserver/connpool/pool.go +++ b/go/vt/vttablet/tabletserver/connpool/pool.go @@ -34,15 +34,9 @@ import ( "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/mysqlctl" "vitess.io/vitess/go/vt/servenv" - "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vttablet/tabletserver/tabletenv" - - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) -// ErrConnPoolClosed is returned when the connection pool is closed. -var ErrConnPoolClosed = vterrors.New(vtrpcpb.Code_INTERNAL, "internal error: unexpected: conn pool is closed") - const ( getWithoutS = "GetWithoutSettings" getWithS = "GetWithSettings" diff --git a/go/vt/vttablet/tabletserver/connpool/pool_test.go b/go/vt/vttablet/tabletserver/connpool/pool_test.go index 43c27fa817a..dee119fe337 100644 --- a/go/vt/vttablet/tabletserver/connpool/pool_test.go +++ b/go/vt/vttablet/tabletserver/connpool/pool_test.go @@ -69,7 +69,7 @@ func TestConnPoolTimeout(t *testing.T) { require.NoError(t, err) defer dbConn.Recycle() _, err = connPool.Get(context.Background(), nil) - assert.EqualError(t, err, "resource pool timed out") + assert.EqualError(t, err, "connection pool timed out") } func TestConnPoolMaxWaiters(t *testing.T) { @@ -181,6 +181,7 @@ func TestConnPoolSetCapacity(t *testing.T) { connPool := newPool() connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams()) defer connPool.Close() +<<<<<<< HEAD err := connPool.SetCapacity(-10) if err == nil { t.Fatalf("set capacity should return error for negative capacity") @@ -189,6 +190,14 @@ func TestConnPoolSetCapacity(t *testing.T) { if err != nil { t.Fatalf("set capacity should succeed") } +======= + + assert.Panics(t, func() { + _ = connPool.SetCapacity(context.Background(), -10) + }) + err := connPool.SetCapacity(context.Background(), 10) + assert.NoError(t, err) +>>>>>>> afbce6aa87 (connpool: Allow time out during shutdown (#15979)) if connPool.Capacity() != 10 { t.Fatalf("capacity should be 10") } diff --git a/go/vt/vttablet/tabletserver/debugenv.go b/go/vt/vttablet/tabletserver/debugenv.go index e229c46cadd..7e2eb2191c4 100644 --- a/go/vt/vttablet/tabletserver/debugenv.go +++ b/go/vt/vttablet/tabletserver/debugenv.go @@ -17,6 +17,7 @@ limitations under the License. 
package tabletserver import ( + "context" "encoding/json" "fmt" "html" @@ -82,6 +83,17 @@ func debugEnvHandler(tsv *TabletServer, w http.ResponseWriter, r *http.Request) f(ival) msg = fmt.Sprintf("Setting %v to: %v", varname, value) } + setIntValCtx := func(f func(context.Context, int) error) { + ival, err := strconv.Atoi(value) + if err == nil { + err = f(r.Context(), ival) + if err == nil { + msg = fmt.Sprintf("Setting %v to: %v", varname, value) + return + } + } + msg = fmt.Sprintf("Failed setting value for %v: %v", varname, err) + } setInt64Val := func(f func(int64)) { ival, err := strconv.ParseInt(value, 10, 64) if err != nil { @@ -111,11 +123,11 @@ func debugEnvHandler(tsv *TabletServer, w http.ResponseWriter, r *http.Request) } switch varname { case "PoolSize": - setIntVal(tsv.SetPoolSize) + setIntValCtx(tsv.SetPoolSize) case "StreamPoolSize": - setIntVal(tsv.SetStreamPoolSize) + setIntValCtx(tsv.SetStreamPoolSize) case "TxPoolSize": - setIntVal(tsv.SetTxPoolSize) + setIntValCtx(tsv.SetTxPoolSize) case "MaxResultSize": setIntVal(tsv.SetMaxResultSize) case "WarnResultSize": diff --git a/go/vt/vttablet/tabletserver/query_executor.go b/go/vt/vttablet/tabletserver/query_executor.go index b69288329f7..4d4d3323cd0 100644 --- a/go/vt/vttablet/tabletserver/query_executor.go +++ b/go/vt/vttablet/tabletserver/query_executor.go @@ -772,6 +772,7 @@ func (qre *QueryExecutor) getConn() (*connpool.DBConn, error) { span, ctx := trace.NewSpan(qre.ctx, "QueryExecutor.getConn") defer span.Finish() +<<<<<<< HEAD start := time.Now() conn, err := qre.tsv.qe.conns.Get(ctx, qre.setting) @@ -783,12 +784,19 @@ func (qre *QueryExecutor) getConn() (*connpool.DBConn, error) { return nil, err } return nil, err +======= + defer func(start time.Time) { + qre.logStats.WaitingForConnection += time.Since(start) + }(time.Now()) + return qre.tsv.qe.conns.Get(ctx, qre.setting) +>>>>>>> afbce6aa87 (connpool: Allow time out during shutdown (#15979)) } func (qre *QueryExecutor) getStreamConn() (*connpool.DBConn, error) { span, ctx := trace.NewSpan(qre.ctx, "QueryExecutor.getStreamConn") defer span.Finish() +<<<<<<< HEAD start := time.Now() conn, err := qre.tsv.qe.streamConns.Get(ctx, qre.setting) switch err { @@ -799,6 +807,12 @@ func (qre *QueryExecutor) getStreamConn() (*connpool.DBConn, error) { return nil, err } return nil, err +======= + defer func(start time.Time) { + qre.logStats.WaitingForConnection += time.Since(start) + }(time.Now()) + return qre.tsv.qe.streamConns.Get(ctx, qre.setting) +>>>>>>> afbce6aa87 (connpool: Allow time out during shutdown (#15979)) } // txFetch fetches from a TxConnection. diff --git a/go/vt/vttablet/tabletserver/tabletserver.go b/go/vt/vttablet/tabletserver/tabletserver.go index 10a9a72bff0..2afdc29c1d5 100644 --- a/go/vt/vttablet/tabletserver/tabletserver.go +++ b/go/vt/vttablet/tabletserver/tabletserver.go @@ -1979,11 +1979,15 @@ func (tsv *TabletServer) EnableHistorian(enabled bool) { } // SetPoolSize changes the pool size to the specified value. -func (tsv *TabletServer) SetPoolSize(val int) { +func (tsv *TabletServer) SetPoolSize(ctx context.Context, val int) error { if val <= 0 { - return + return nil } +<<<<<<< HEAD tsv.qe.conns.SetCapacity(val) +======= + return tsv.qe.conns.SetCapacity(ctx, int64(val)) +>>>>>>> afbce6aa87 (connpool: Allow time out during shutdown (#15979)) } // PoolSize returns the pool size. @@ -1992,8 +1996,13 @@ func (tsv *TabletServer) PoolSize() int { } // SetStreamPoolSize changes the pool size to the specified value. 
+<<<<<<< HEAD func (tsv *TabletServer) SetStreamPoolSize(val int) { tsv.qe.streamConns.SetCapacity(val) +======= +func (tsv *TabletServer) SetStreamPoolSize(ctx context.Context, val int) error { + return tsv.qe.streamConns.SetCapacity(ctx, int64(val)) +>>>>>>> afbce6aa87 (connpool: Allow time out during shutdown (#15979)) } // SetStreamConsolidationBlocking sets whether the stream consolidator should wait for slow clients @@ -2007,8 +2016,13 @@ func (tsv *TabletServer) StreamPoolSize() int { } // SetTxPoolSize changes the tx pool size to the specified value. +<<<<<<< HEAD func (tsv *TabletServer) SetTxPoolSize(val int) { tsv.te.txPool.scp.conns.SetCapacity(val) +======= +func (tsv *TabletServer) SetTxPoolSize(ctx context.Context, val int) error { + return tsv.te.txPool.scp.conns.SetCapacity(ctx, int64(val)) +>>>>>>> afbce6aa87 (connpool: Allow time out during shutdown (#15979)) } // TxPoolSize returns the tx pool size. diff --git a/go/vt/vttablet/tabletserver/tabletserver_test.go b/go/vt/vttablet/tabletserver/tabletserver_test.go index 6d6e47e65e4..f1d9a1a3aa9 100644 --- a/go/vt/vttablet/tabletserver/tabletserver_test.go +++ b/go/vt/vttablet/tabletserver/tabletserver_test.go @@ -2046,7 +2046,9 @@ func TestConfigChanges(t *testing.T) { newSize := 10 newDuration := time.Duration(10 * time.Millisecond) - tsv.SetPoolSize(newSize) + err := tsv.SetPoolSize(context.Background(), newSize) + require.NoError(t, err) + if val := tsv.PoolSize(); val != newSize { t.Errorf("PoolSize: %d, want %d", val, newSize) } @@ -2054,7 +2056,9 @@ func TestConfigChanges(t *testing.T) { t.Errorf("tsv.qe.connPool.Capacity: %d, want %d", val, newSize) } - tsv.SetStreamPoolSize(newSize) + err = tsv.SetStreamPoolSize(context.Background(), newSize) + require.NoError(t, err) + if val := tsv.StreamPoolSize(); val != newSize { t.Errorf("StreamPoolSize: %d, want %d", val, newSize) } @@ -2062,7 +2066,9 @@ func TestConfigChanges(t *testing.T) { t.Errorf("tsv.qe.streamConnPool.Capacity: %d, want %d", val, newSize) } - tsv.SetTxPoolSize(newSize) + err = tsv.SetTxPoolSize(context.Background(), newSize) + require.NoError(t, err) + if val := tsv.TxPoolSize(); val != newSize { t.Errorf("TxPoolSize: %d, want %d", val, newSize) } diff --git a/go/vt/vttablet/tabletserver/tx_pool_test.go b/go/vt/vttablet/tabletserver/tx_pool_test.go index 3515310c481..67bab15a80e 100644 --- a/go/vt/vttablet/tabletserver/tx_pool_test.go +++ b/go/vt/vttablet/tabletserver/tx_pool_test.go @@ -214,8 +214,15 @@ func primeTxPoolWithConnection(t *testing.T, ctx context.Context) (*fakesqldb.DB db := fakesqldb.New(t) txPool, _ := newTxPool() // Set the capacity to 1 to ensure that the db connection is reused. +<<<<<<< HEAD txPool.scp.conns.SetCapacity(1) txPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams()) +======= + err := txPool.scp.conns.SetCapacity(context.Background(), 1) + require.NoError(t, err) + params := dbconfigs.New(db.ConnParams()) + txPool.Open(params, params, params) +>>>>>>> afbce6aa87 (connpool: Allow time out during shutdown (#15979)) // Run a query to trigger a database connection. That connection will be // reused by subsequent transactions.