From 9494d87252b662051a1d3a9cf7305681c4a8a1fc Mon Sep 17 00:00:00 2001 From: Travis Bischel Date: Wed, 17 Jul 2024 23:40:53 -0600 Subject: [PATCH] kgo: allow record ctx cancelation to propagate a bit more If a record's context is canceled, we now allow it to be failed in two more locations: * while the producer ID is loading -- we can actually now cancel the producer ID loading request (which may also benefit people using transactions that want to force quit the client) * while a sink is backing off due to request failures For people using transactions, canceling a context now allows you to force quit in more areas, but the same caveat applies: your client will likely end up in an invalid transactional state and be unable to continue. For #769. --- pkg/kgo/errors.go | 17 ++++- pkg/kgo/helpers_test.go | 26 +++++++ pkg/kgo/produce_request_test.go | 3 +- pkg/kgo/producer.go | 26 ++++--- pkg/kgo/sink.go | 121 ++++++++++++++++++++++++++++++-- pkg/kgo/source.go | 3 + pkg/kgo/txn.go | 26 +++---- 7 files changed, 192 insertions(+), 30 deletions(-) diff --git a/pkg/kgo/errors.go b/pkg/kgo/errors.go index 37c4f8fc..3ff1dbfe 100644 --- a/pkg/kgo/errors.go +++ b/pkg/kgo/errors.go @@ -53,7 +53,7 @@ func isRetryableBrokerErr(err error) bool { } // We could have a retryable producer ID failure, which then bubbled up // as errProducerIDLoadFail so as to be retried later. - if errors.Is(err, errProducerIDLoadFail) { + if pe := (*errProducerIDLoadFail)(nil); errors.As(err, &pe) { return true } // We could have chosen a broker, and then a concurrent metadata update @@ -139,8 +139,6 @@ var ( // restart a new connection ourselves. errSaslReauthLoop = errors.New("the broker is repeatedly giving us sasl lifetimes that are too short to write a request") - errProducerIDLoadFail = errors.New("unable to initialize a producer ID due to request failures") - // A temporary error returned when Kafka replies with a different // correlation ID than we were expecting for the request the client // issued. @@ -224,6 +222,19 @@ type ErrFirstReadEOF struct { err error } +type errProducerIDLoadFail struct { + err error +} + +func (e *errProducerIDLoadFail) Error() string { + if e.err == nil { + return "unable to initialize a producer ID due to request failures" + } + return fmt.Sprintf("unable to initialize a producer ID due to request failures: %v", e.err) +} + +func (e *errProducerIDLoadFail) Unwrap() error { return e.err } + const ( firstReadSASL uint8 = iota firstReadTLS diff --git a/pkg/kgo/helpers_test.go b/pkg/kgo/helpers_test.go index ccf32865..078fbf77 100644 --- a/pkg/kgo/helpers_test.go +++ b/pkg/kgo/helpers_test.go @@ -55,6 +55,32 @@ var ( npartitionsAt int64 ) +type slowConn struct { + net.Conn +} + +func (s *slowConn) Write(p []byte) (int, error) { + time.Sleep(100 * time.Millisecond) + return s.Conn.Write(p) +} + +func (s *slowConn) Read(p []byte) (int, error) { + time.Sleep(100 * time.Millisecond) + return s.Conn.Read(p) +} + +type slowDialer struct { + d net.Dialer +} + +func (s *slowDialer) DialContext(ctx context.Context, network, host string) (net.Conn, error) { + c, err := s.d.DialContext(ctx, network, host) + if err != nil { + return nil, err + } + return &slowConn{c}, nil +} + func init() { var err error if n, _ := strconv.Atoi(os.Getenv("KGO_TEST_RF")); n > 0 { diff --git a/pkg/kgo/produce_request_test.go b/pkg/kgo/produce_request_test.go index 66e06177..efd26270 100644 --- a/pkg/kgo/produce_request_test.go +++ b/pkg/kgo/produce_request_test.go @@ -5,6 +5,7 @@ import ( "context" "errors" "hash/crc32" + "strings" "testing" "time" @@ -90,7 +91,7 @@ func TestIssue769(t *testing.T) { case <-timer.C: t.Fatal("expected record to fail within 3s") } - if pe := (*errProducerIDLoadFail)(nil); !errors.As(rerr, &pe) || !errors.Is(pe.err, context.Canceled) { + if pe := (*errProducerIDLoadFail)(nil); !errors.As(rerr, &pe) || !(errors.Is(pe.err, context.Canceled) || strings.Contains(pe.err.Error(), "canceled")) { t.Errorf("got %v != exp errProducerIDLoadFail{context.Canceled}", rerr) } } diff --git a/pkg/kgo/producer.go b/pkg/kgo/producer.go index cfc6aa14..d1f42721 100644 --- a/pkg/kgo/producer.go +++ b/pkg/kgo/producer.go @@ -358,7 +358,11 @@ func (cl *Client) TryProduce( // retries. If any of these conditions are hit and it is currently safe to fail // records, all buffered records for the relevant partition are failed. Only // the first record's context in a batch is considered when determining whether -// the batch should be canceled. +// the batch should be canceled. A record is not safe to fail if the client +// is idempotently producing and a request has been sent; in this case, the +// client cannot know if the broker actually processed the request (if so, then +// removing the records from the client will create errors the next time you +// produce). // // If the client is transactional and a transaction has not been begun, the // promise is immediately called with an error corresponding to not being in a @@ -626,7 +630,7 @@ func (cl *Client) ProducerID(ctx context.Context) (int64, int16, error) { go func() { defer close(done) - id, epoch, err = cl.producerID() + id, epoch, err = cl.producerID(ctx2fn(ctx)) }() select { @@ -648,7 +652,7 @@ var errReloadProducerID = errors.New("producer id needs reloading") // initProducerID initializes the client's producer ID for idempotent // producing only (no transactions, which are more special). After the first // load, this clears all buffered unknown topics. -func (cl *Client) producerID() (int64, int16, error) { +func (cl *Client) producerID(ctxFn func() context.Context) (int64, int16, error) { p := &cl.producer id := p.id.Load().(*producerID) @@ -677,7 +681,7 @@ func (cl *Client) producerID() (int64, int16, error) { } p.id.Store(id) } else { - newID, keep := cl.doInitProducerID(id.id, id.epoch) + newID, keep := cl.doInitProducerID(ctxFn, id.id, id.epoch) if keep { id = newID // Whenever we have a new producer ID, we need @@ -695,7 +699,7 @@ func (cl *Client) producerID() (int64, int16, error) { id = &producerID{ id: id.id, epoch: id.epoch, - err: errProducerIDLoadFail, + err: &errProducerIDLoadFail{newID.err}, } } } @@ -772,7 +776,7 @@ func (cl *Client) failProducerID(id int64, epoch int16, err error) { // doInitProducerID inits the idempotent ID and potentially the transactional // producer epoch, returning whether to keep the result. -func (cl *Client) doInitProducerID(lastID int64, lastEpoch int16) (*producerID, bool) { +func (cl *Client) doInitProducerID(ctxFn func() context.Context, lastID int64, lastEpoch int16) (*producerID, bool) { cl.cfg.logger.Log(LogLevelInfo, "initializing producer id") req := kmsg.NewPtrInitProducerIDRequest() req.TransactionalID = cl.cfg.txnID @@ -782,7 +786,8 @@ func (cl *Client) doInitProducerID(lastID int64, lastEpoch int16) (*producerID, req.TransactionTimeoutMillis = int32(cl.cfg.txnTimeout.Milliseconds()) } - resp, err := req.RequestWith(cl.ctx, cl) + ctx := ctxFn() + resp, err := req.RequestWith(ctx, cl) if err != nil { if errors.Is(err, errUnknownRequestKey) || errors.Is(err, errBrokerTooOld) { cl.cfg.logger.Log(LogLevelInfo, "unable to initialize a producer id because the broker is too old or the client is pinned to an old version, continuing without a producer id") @@ -887,13 +892,14 @@ func (cl *Client) addUnknownTopicRecord(pr promisedRec) { } unknown.buffered = append(unknown.buffered, pr) if len(unknown.buffered) == 1 { - go cl.waitUnknownTopic(pr.ctx, pr.Topic, unknown) + go cl.waitUnknownTopic(pr.ctx, pr.Record.Context, pr.Topic, unknown) } } // waitUnknownTopic waits for a notification func (cl *Client) waitUnknownTopic( - rctx context.Context, + pctx context.Context, // context passed to Produce + rctx context.Context, // context on the record itself topic string, unknown *unknownTopicProduces, ) { @@ -921,6 +927,8 @@ func (cl *Client) waitUnknownTopic( for err == nil { select { + case <-pctx.Done(): + err = pctx.Err() case <-rctx.Done(): err = rctx.Err() case <-cl.ctx.Done(): diff --git a/pkg/kgo/sink.go b/pkg/kgo/sink.go index 49bff3b1..8d19c572 100644 --- a/pkg/kgo/sink.go +++ b/pkg/kgo/sink.go @@ -208,6 +208,7 @@ func (s *sink) maybeBackoff() { select { case <-after.C: case <-s.cl.ctx.Done(): + case <-s.anyCtx().Done(): } } @@ -247,6 +248,34 @@ func (s *sink) drain() { } } +// Returns the first context encountered ranging across all records. +// This does not use defers to make it clear at the return that all +// unlocks are called in proper order. Ideally, do not call this func +// due to lock intensity. +func (s *sink) anyCtx() context.Context { + s.recBufsMu.Lock() + for _, recBuf := range s.recBufs { + recBuf.mu.Lock() + if len(recBuf.batches) > 0 { + batch0 := recBuf.batches[0] + batch0.mu.Lock() + if batch0.canFailFromLoadErrs && len(batch0.records) > 0 { + r0 := batch0.records[0] + if rctx := r0.cancelingCtx(); rctx != nil { + batch0.mu.Unlock() + recBuf.mu.Unlock() + s.recBufsMu.Unlock() + return rctx + } + } + batch0.mu.Unlock() + } + recBuf.mu.Unlock() + } + s.recBufsMu.Unlock() + return context.Background() +} + func (s *sink) produce(sem <-chan struct{}) bool { var produced bool defer func() { @@ -267,6 +296,7 @@ func (s *sink) produce(sem <-chan struct{}) bool { // - auth failure // - transactional: a produce failure that failed the producer ID // - AddPartitionsToTxn failure (see just below) + // - some head-of-line context failure // // All but the first error is fatal. Recovery may be possible with // EndTransaction in specific cases, but regardless, all buffered @@ -275,10 +305,71 @@ func (s *sink) produce(sem <-chan struct{}) bool { // NOTE: we init the producer ID before creating a request to ensure we // are always using the latest id/epoch with the proper sequence // numbers. (i.e., resetAllSequenceNumbers && producerID logic combo). - id, epoch, err := s.cl.producerID() + // + // For the first-discovered-record-head-of-line context, we want to + // avoid looking it up if possible (which is why producerID takes a + // ctxFn). If we do use one, we want to be sure that the + // context.Canceled error is from *that* context rather than the client + // context or something else. So, we go through some special care to + // track setting the ctx / looking up if it is canceled. + var holCtxMu sync.Mutex + var holCtx context.Context + ctxFn := func() context.Context { + holCtxMu.Lock() + defer holCtxMu.Unlock() + holCtx = s.anyCtx() + return holCtx + } + isHolCtxDone := func() bool { + holCtxMu.Lock() + defer holCtxMu.Unlock() + if holCtx == nil { + return false + } + select { + case <-holCtx.Done(): + return true + default: + } + return false + } + + id, epoch, err := s.cl.producerID(ctxFn) if err != nil { + var pe *errProducerIDLoadFail switch { - case errors.Is(err, errProducerIDLoadFail): + case errors.As(err, &pe): + if errors.Is(pe.err, context.Canceled) && isHolCtxDone() { + // Some head-of-line record in a partition had a context cancelation. + // We look for any partition with HOL cancelations and fail them all. + s.cl.cfg.logger.Log(LogLevelInfo, "the first record in some partition(s) had a context cancelation; failing all relevant partitions", "broker", logID(s.nodeID)) + s.recBufsMu.Lock() + defer s.recBufsMu.Unlock() + for _, recBuf := range s.recBufs { + recBuf.mu.Lock() + var failAll bool + if len(recBuf.batches) > 0 { + batch0 := recBuf.batches[0] + batch0.mu.Lock() + if batch0.canFailFromLoadErrs && len(batch0.records) > 0 { + r0 := batch0.records[0] + if rctx := r0.cancelingCtx(); rctx != nil { + select { + case <-rctx.Done(): + failAll = true // we must not call failAllRecords here, because failAllRecords locks batches! + default: + } + } + } + batch0.mu.Unlock() + } + if failAll { + recBuf.failAllRecords(err) + } + recBuf.mu.Unlock() + } + return true + } s.cl.bumpRepeatedLoadErr(err) s.cl.cfg.logger.Log(LogLevelWarn, "unable to load producer ID, bumping client's buffered record load errors by 1 and retrying") return true // whatever caused our produce, we did nothing, so keep going @@ -385,6 +476,9 @@ func (s *sink) doSequenced( promise: promise, } + // We can NOT use any record context. If we do, we force the request to + // fail while also force the batch to be unfailable (due to no + // response), br, err := s.cl.brokerOrErr(s.cl.ctx, s.nodeID, errUnknownBroker) if err != nil { wait.err = err @@ -432,6 +526,11 @@ func (s *sink) doTxnReq( req.batches.eachOwnerLocked(seqRecBatch.removeFromTxn) } }() + // We do NOT let record context cancelations fail this request: doing + // so would put the transactional ID in an unknown state. This is + // similar to the warning we give in the txn.go file, but the + // difference there is the user knows explicitly at the function call + // that canceling the context will opt them into invalid state. err = s.cl.doWithConcurrentTransactions(s.cl.ctx, "AddPartitionsToTxn", func() error { stripped, err = s.issueTxnReq(req, txnReq) return err @@ -1393,6 +1492,16 @@ type promisedRec struct { *Record } +func (pr promisedRec) cancelingCtx() context.Context { + if pr.ctx.Done() != nil { + return pr.ctx + } + if pr.Context.Done() != nil { + return pr.Context + } + return nil +} + // recBatch is the type used for buffering records before they are written. type recBatch struct { owner *recBuf // who owns us @@ -1421,10 +1530,12 @@ type recBatch struct { // Returns an error if the batch should fail. func (b *recBatch) maybeFailErr(cfg *cfg) error { if len(b.records) > 0 { - ctx := b.records[0].ctx + r0 := &b.records[0] select { - case <-ctx.Done(): - return ctx.Err() + case <-r0.ctx.Done(): + return r0.ctx.Err() + case <-r0.Context.Done(): + return r0.Context.Err() default: } } diff --git a/pkg/kgo/source.go b/pkg/kgo/source.go index 603b0864..7a55ef0c 100644 --- a/pkg/kgo/source.go +++ b/pkg/kgo/source.go @@ -894,6 +894,9 @@ func (s *source) fetch(consumerSession *consumerSession, doneFetch chan<- struct // reload offsets *always* triggers a metadata update. if updateWhy != nil { why := updateWhy.reason(fmt.Sprintf("fetch had inner topic errors from broker %d", s.nodeID)) + // loadWithSessionNow triggers a metadata update IF there are + // offsets to reload. If there are no offsets to reload, we + // trigger one here. if !reloadOffsets.loadWithSessionNow(consumerSession, why) { if updateWhy.isOnly(kerr.UnknownTopicOrPartition) || updateWhy.isOnly(kerr.UnknownTopicID) { s.cl.triggerUpdateMetadata(false, why) diff --git a/pkg/kgo/txn.go b/pkg/kgo/txn.go index 7df9c65e..25cfd443 100644 --- a/pkg/kgo/txn.go +++ b/pkg/kgo/txn.go @@ -13,6 +13,8 @@ import ( "github.com/twmb/franz-go/pkg/kerr" ) +func ctx2fn(ctx context.Context) func() context.Context { return func() context.Context { return ctx } } + // TransactionEndTry is simply a named bool. type TransactionEndTry bool @@ -468,7 +470,7 @@ func (cl *Client) BeginTransaction() error { return errors.New("invalid attempt to begin a transaction while already in a transaction") } - needRecover, didRecover, err := cl.maybeRecoverProducerID() + needRecover, didRecover, err := cl.maybeRecoverProducerID(context.Background()) if needRecover && !didRecover { cl.cfg.logger.Log(LogLevelInfo, "unable to begin transaction due to unrecoverable producer id error", "err", err) return fmt.Errorf("producer ID has a fatal, unrecoverable error, err: %w", err) @@ -557,7 +559,7 @@ func (cl *Client) EndAndBeginTransaction( // expect to be in one. defer func() { if rerr == nil { - needRecover, didRecover, err := cl.maybeRecoverProducerID() + needRecover, didRecover, err := cl.maybeRecoverProducerID(ctx) if needRecover && !didRecover { cl.cfg.logger.Log(LogLevelInfo, "unable to begin transaction due to unrecoverable producer id error", "err", err) rerr = fmt.Errorf("producer ID has a fatal, unrecoverable error, err: %w", err) @@ -620,12 +622,12 @@ func (cl *Client) EndAndBeginTransaction( } // From EndTransaction: if the pid has an error, we may try to recover. - id, epoch, err := cl.producerID() + id, epoch, err := cl.producerID(ctx2fn(ctx)) if err != nil { if commit { return kerr.OperationNotAttempted } - if _, didRecover, _ := cl.maybeRecoverProducerID(); didRecover { + if _, didRecover, _ := cl.maybeRecoverProducerID(ctx); didRecover { return nil } } @@ -882,7 +884,7 @@ func (cl *Client) EndTransaction(ctx context.Context, commit TransactionEndTry) return nil } - id, epoch, err := cl.producerID() + id, epoch, err := cl.producerID(ctx2fn(ctx)) if err != nil { if commit { return kerr.OperationNotAttempted @@ -892,7 +894,7 @@ func (cl *Client) EndTransaction(ctx context.Context, commit TransactionEndTry) // there is no reason to issue an abort now that the id is // different. Otherwise, we issue our EndTxn which will likely // fail, but that is ok, we will just return error. - _, didRecover, _ := cl.maybeRecoverProducerID() + _, didRecover, _ := cl.maybeRecoverProducerID(ctx) if didRecover { return nil } @@ -939,11 +941,11 @@ func (cl *Client) EndTransaction(ctx context.Context, commit TransactionEndTry) // error), whether it is possible to recover, and, if not, the error. // // We call this when beginning a transaction or when ending with an abort. -func (cl *Client) maybeRecoverProducerID() (necessary, did bool, err error) { +func (cl *Client) maybeRecoverProducerID(ctx context.Context) (necessary, did bool, err error) { cl.producer.mu.Lock() defer cl.producer.mu.Unlock() - id, epoch, err := cl.producerID() + id, epoch, err := cl.producerID(ctx2fn(ctx)) if err == nil { return false, false, nil } @@ -1009,7 +1011,7 @@ start: select { case <-time.After(backoff): case <-ctx.Done(): - cl.cfg.logger.Log(LogLevelError, fmt.Sprintf("abandoning %s retry due to client ctx quitting", name)) + cl.cfg.logger.Log(LogLevelError, fmt.Sprintf("abandoning %s retry due to request ctx quitting", name)) return err case <-cl.ctx.Done(): cl.cfg.logger.Log(LogLevelError, fmt.Sprintf("abandoning %s retry due to client ctx quitting", name)) @@ -1081,7 +1083,7 @@ func (cl *Client) commitTransactionOffsets( } if !g.offsetsAddedToTxn { - if err := cl.addOffsetsToTxn(g.ctx, g.cfg.group); err != nil { + if err := cl.addOffsetsToTxn(ctx, g.cfg.group); err != nil { if onDone != nil { onDone(nil, nil, err) } @@ -1111,7 +1113,7 @@ func (cl *Client) commitTransactionOffsets( // this initializes one if it is not yet initialized. This would only be the // case if trying to commit before any records have been sent. func (cl *Client) addOffsetsToTxn(ctx context.Context, group string) error { - id, epoch, err := cl.producerID() + id, epoch, err := cl.producerID(ctx2fn(ctx)) if err != nil { return err } @@ -1218,7 +1220,7 @@ func (g *groupConsumer) prepareTxnOffsetCommit(ctx context.Context, uncommitted // We're now generating the producerID before addOffsetsToTxn. // We will not make this request until after addOffsetsToTxn, but it's possible to fail here due to a failed producerID. - id, epoch, err := g.cl.producerID() + id, epoch, err := g.cl.producerID(ctx2fn(ctx)) if err != nil { return req, err }