diff --git a/config.go b/config.go index 2f50ad1..e5f3215 100644 --- a/config.go +++ b/config.go @@ -6,12 +6,13 @@ import ( ) type configuration struct { - txOptions *sql.TxOptions - splitStatement bool - panicOnBindError bool - stackTraceOnError bool - parameterPrefix string - retrySleep time.Duration + txOptions *sql.TxOptions + splitStatement bool + panicOnBindError bool + normalizeContextCancellation bool + stackTraceOnError bool + parameterPrefix string + retrySleep time.Duration } func NewPool(handle *sql.DB, options ...option) ConnectionPool { @@ -31,6 +32,10 @@ func newPool(handle *sql.DB, config configuration) ConnectionPool { pool = NewSplitStatementConnectionPool(pool, config.parameterPrefix) } + if config.normalizeContextCancellation { + pool = NewNormalizeContextCancellationConnectionPool(pool) + } + if config.stackTraceOnError { pool = NewStackTraceConnectionPool(pool) } @@ -59,6 +64,9 @@ func (singleton) TxOptions(value *sql.TxOptions) option { func (singleton) PanicOnBindError(value bool) option { return func(this *configuration) { this.panicOnBindError = value } } +func (singleton) NormalizeContextCancellation(value bool) option { + return func(this *configuration) { this.normalizeContextCancellation = value } +} func (singleton) MySQL() option { return func(this *configuration) { this.splitStatement = true; this.parameterPrefix = "?" } } @@ -86,6 +94,7 @@ func (singleton) defaults(options ...option) []option { var defaultTxOptions = &sql.TxOptions{Isolation: sql.LevelReadCommitted} const defaultStackTraceErrDiagnostics = true const defaultPanicOnBindError = true + const defaultNormalizeContextCancellation = true const defaultSplitStatement = true const defaultParameterPrefix = "?" const defaultRetrySleep = 0 @@ -93,6 +102,7 @@ func (singleton) defaults(options ...option) []option { return append([]option{ Options.TxOptions(defaultTxOptions), Options.PanicOnBindError(defaultPanicOnBindError), + Options.NormalizeContextCancellation(defaultNormalizeContextCancellation), Options.StackTraceErrDiagnostics(defaultStackTraceErrDiagnostics), Options.ParameterPrefix(defaultParameterPrefix), Options.SplitStatement(defaultSplitStatement), diff --git a/normalize_context_cancellation_connection_pool.go b/normalize_context_cancellation_connection_pool.go new file mode 100644 index 0000000..df7069e --- /dev/null +++ b/normalize_context_cancellation_connection_pool.go @@ -0,0 +1,55 @@ +package sqldb + +import ( + "context" + "fmt" + "strings" +) + +type NormalizeContextCancellationConnectionPool struct { + inner ConnectionPool +} + +func NewNormalizeContextCancellationConnectionPool(inner ConnectionPool) *NormalizeContextCancellationConnectionPool { + return &NormalizeContextCancellationConnectionPool{inner: inner} +} + +func (this *NormalizeContextCancellationConnectionPool) Ping(ctx context.Context) error { + return this.normalizeContextCancellationError(this.inner.Ping(ctx)) +} + +func (this *NormalizeContextCancellationConnectionPool) BeginTransaction(ctx context.Context) (Transaction, error) { + if tx, err := this.inner.BeginTransaction(ctx); err == nil { + return NewStackTraceTransaction(tx), nil + } else { + return nil, this.normalizeContextCancellationError(err) + } +} + +func (this *NormalizeContextCancellationConnectionPool) Close() error { + return this.normalizeContextCancellationError(this.inner.Close()) +} + +func (this *NormalizeContextCancellationConnectionPool) Execute(ctx context.Context, statement string, parameters ...interface{}) (uint64, error) { + affected, err := this.inner.Execute(ctx, statement, parameters...) + return affected, this.normalizeContextCancellationError(err) +} + +func (this *NormalizeContextCancellationConnectionPool) Select(ctx context.Context, query string, parameters ...interface{}) (SelectResult, error) { + result, err := this.inner.Select(ctx, query, parameters...) + return result, this.normalizeContextCancellationError(err) +} + +// TODO remove manual check of "use of closed network connection" with release of https://github.com/go-sql-driver/mysql/pull/1615 +func (this *NormalizeContextCancellationConnectionPool) normalizeContextCancellationError(err error) error { + if err == nil { + return nil + } + if strings.Contains(err.Error(), "operation was canceled") { + return fmt.Errorf("%w: %w", context.Canceled, err) + } + if strings.Contains(err.Error(), "use of closed network connection") { + return fmt.Errorf("%w: %w", context.Canceled, err) + } + return err +} diff --git a/normalize_context_cancellation_connection_pool_test.go b/normalize_context_cancellation_connection_pool_test.go new file mode 100644 index 0000000..5e1455e --- /dev/null +++ b/normalize_context_cancellation_connection_pool_test.go @@ -0,0 +1,206 @@ +package sqldb + +import ( + "context" + "errors" + "testing" + + "github.com/smarty/assertions/should" + "github.com/smarty/gunit" +) + +func TestNormalizeContextCancellationConnectionPoolFixture(t *testing.T) { + gunit.Run(new(NormalizeContextCancellationConnectionPoolFixture), t) +} + +type NormalizeContextCancellationConnectionPoolFixture struct { + *gunit.Fixture + + inner *FakeConnectionPool + adapter *NormalizeContextCancellationConnectionPool +} + +func (this *NormalizeContextCancellationConnectionPoolFixture) Setup() { + this.inner = &FakeConnectionPool{} + this.adapter = NewNormalizeContextCancellationConnectionPool(this.inner) +} + +func (this *NormalizeContextCancellationConnectionPoolFixture) TestPing_Successful() { + err := this.adapter.Ping(context.Background()) + + this.So(err, should.BeNil) + this.So(this.inner.pingCalls, should.Equal, 1) +} +func (this *NormalizeContextCancellationConnectionPoolFixture) TestPing_Failed() { + pingErr := errors.New("PING ERROR") + this.inner.pingError = pingErr + + err := this.adapter.Ping(context.Background()) + + this.So(this.inner.pingCalls, should.Equal, 1) + this.So(err, should.Equal, pingErr) +} +func (this *NormalizeContextCancellationConnectionPoolFixture) TestPing_AdaptContextCancelled() { + this.inner.pingError = operationCanceledErr + + err := this.adapter.Ping(context.Background()) + + this.So(this.inner.pingCalls, should.Equal, 1) + this.So(errors.Is(err, operationCanceledErr), should.BeTrue) + this.So(errors.Is(err, context.Canceled), should.BeTrue) +} + +func (this *NormalizeContextCancellationConnectionPoolFixture) TestBeginTransaction_Successful() { + transaction := new(FakeTransaction) + this.inner.transaction = transaction + + tx, err := this.adapter.BeginTransaction(context.Background()) + + this.So(err, should.BeNil) + this.So(this.inner.transactionCalls, should.Equal, 1) + this.So(tx, should.NotBeNil) +} +func (this *NormalizeContextCancellationConnectionPoolFixture) TestBeginTransaction_Failed() { + transactionErr := errors.New("BEGIN TRANSACTION ERROR") + this.inner.transactionError = transactionErr + + tx, err := this.adapter.BeginTransaction(context.Background()) + + this.So(tx, should.BeNil) + this.So(err, should.Equal, transactionErr) +} +func (this *NormalizeContextCancellationConnectionPoolFixture) TestBeginTransaction_AdaptContextCancelled() { + this.inner.transactionError = operationCanceledErr + + tx, err := this.adapter.BeginTransaction(context.Background()) + + this.So(tx, should.BeNil) + this.So(errors.Is(err, operationCanceledErr), should.BeTrue) + this.So(errors.Is(err, context.Canceled), should.BeTrue) +} + +func (this *NormalizeContextCancellationConnectionPoolFixture) TestClose_Successful() { + err := this.adapter.Close() + + this.So(err, should.BeNil) + this.So(this.inner.closeCalls, should.Equal, 1) +} +func (this *NormalizeContextCancellationConnectionPoolFixture) TestClose_Failed() { + closeErr := errors.New("CLOSE ERROR") + this.inner.closeError = closeErr + + err := this.adapter.Close() + + this.So(this.inner.closeCalls, should.Equal, 1) + this.So(err, should.Equal, closeErr) +} +func (this *NormalizeContextCancellationConnectionPoolFixture) TestClose_AdaptContextCancelled() { + this.inner.closeError = operationCanceledErr + + err := this.adapter.Close() + + this.So(this.inner.closeCalls, should.Equal, 1) + this.So(errors.Is(err, operationCanceledErr), should.BeTrue) + this.So(errors.Is(err, context.Canceled), should.BeTrue) +} + +func (this *NormalizeContextCancellationConnectionPoolFixture) TestExecute_Successful() { + this.inner.executeResult = 42 + + result, err := this.adapter.Execute(context.Background(), "statement") + + this.So(result, should.Equal, 42) + this.So(err, should.BeNil) + this.So(this.inner.executeCalls, should.Equal, 1) + this.So(this.inner.executeStatement, should.Equal, "statement") +} +func (this *NormalizeContextCancellationConnectionPoolFixture) TestExecute_Failed() { + this.inner.executeResult = 42 + executeErr := errors.New("EXECUTE ERROR") + this.inner.executeError = executeErr + + result, err := this.adapter.Execute(context.Background(), "statement") + + this.So(result, should.Equal, 42) + this.So(err, should.Equal, executeErr) + this.So(this.inner.executeCalls, should.Equal, 1) + this.So(this.inner.executeStatement, should.Equal, "statement") +} +func (this *NormalizeContextCancellationConnectionPoolFixture) TestExecute_AdaptContextCancelled() { + this.inner.executeResult = 42 + this.inner.executeError = operationCanceledErr + + result, err := this.adapter.Execute(context.Background(), "statement") + + this.So(result, should.Equal, 42) + this.So(this.inner.executeCalls, should.Equal, 1) + this.So(this.inner.executeStatement, should.Equal, "statement") + this.So(errors.Is(err, operationCanceledErr), should.BeTrue) + this.So(errors.Is(err, context.Canceled), should.BeTrue) +} + +func (this *NormalizeContextCancellationConnectionPoolFixture) TestSelect_Successful() { + expectedResult := new(FakeSelectResult) + this.inner.selectResult = expectedResult + + result, err := this.adapter.Select(context.Background(), "query", 1, 2, 3) + + this.So(result, should.Equal, expectedResult) + this.So(err, should.BeNil) + this.So(this.inner.selectCalls, should.Equal, 1) + this.So(this.inner.selectStatement, should.Equal, "query") + this.So(this.inner.selectParameters, should.Equal, []any{1, 2, 3}) +} +func (this *NormalizeContextCancellationConnectionPoolFixture) TestSelect_Failed() { + expectedResult := new(FakeSelectResult) + this.inner.selectResult = expectedResult + selectErr := errors.New("SELECT ERROR") + this.inner.selectError = selectErr + + result, err := this.adapter.Select(context.Background(), "query", 1, 2, 3) + + this.So(result, should.Equal, expectedResult) + this.So(err, should.Equal, selectErr) + this.So(this.inner.selectCalls, should.Equal, 1) + this.So(this.inner.selectStatement, should.Equal, "query") + this.So(this.inner.selectParameters, should.Equal, []any{1, 2, 3}) +} +func (this *NormalizeContextCancellationConnectionPoolFixture) TestSelect_AdaptContextCancelled() { + expectedResult := new(FakeSelectResult) + this.inner.selectResult = expectedResult + this.inner.selectError = operationCanceledErr + + result, err := this.adapter.Select(context.Background(), "query", 1, 2, 3) + + this.So(result, should.Equal, expectedResult) + this.So(this.inner.selectCalls, should.Equal, 1) + this.So(this.inner.selectStatement, should.Equal, "query") + this.So(this.inner.selectParameters, should.Equal, []any{1, 2, 3}) + this.So(errors.Is(err, operationCanceledErr), should.BeTrue) + this.So(errors.Is(err, context.Canceled), should.BeTrue) +} + +func (this *NormalizeContextCancellationConnectionPoolFixture) TestContextCancellationErrorAdapter_NilError() { + err := this.adapter.normalizeContextCancellationError(nil) + this.So(err, should.BeNil) +} +func (this *NormalizeContextCancellationConnectionPoolFixture) TestContextCancellationErrorAdapter_GenericError() { + genericErr := errors.New("generic error") + err := this.adapter.normalizeContextCancellationError(genericErr) + this.So(err, should.Equal, genericErr) +} +func (this *NormalizeContextCancellationConnectionPoolFixture) TestContextCancellationErrorAdapter_OperationCanceledError() { + err := this.adapter.normalizeContextCancellationError(operationCanceledErr) + this.So(errors.Is(err, operationCanceledErr), should.BeTrue) + this.So(errors.Is(err, context.Canceled), should.BeTrue) +} +func (this *NormalizeContextCancellationConnectionPoolFixture) TestContextCancellationErrorAdapter_ClosedConnectionError() { + err := this.adapter.normalizeContextCancellationError(closedNetworkConnectionErr) + this.So(errors.Is(err, closedNetworkConnectionErr), should.BeTrue) + this.So(errors.Is(err, context.Canceled), should.BeTrue) +} + +var ( + operationCanceledErr = errors.New("operation was canceled") + closedNetworkConnectionErr = errors.New("use of closed network connection") +)