diff --git a/pkg/kgo/consumer_direct_test.go b/pkg/kgo/consumer_direct_test.go index 884c427b..dc12083d 100644 --- a/pkg/kgo/consumer_direct_test.go +++ b/pkg/kgo/consumer_direct_test.go @@ -338,7 +338,8 @@ func TestPauseIssue489(t *testing.T) { exit.Store(true) } }) - time.Sleep(100 * time.Microsecond) + cl.Flush(ctx) + time.Sleep(50 * time.Microsecond) } }() defer cancel() @@ -416,7 +417,8 @@ func TestPauseIssueOct2023(t *testing.T) { exit.Store(true) } }) - time.Sleep(100 * time.Microsecond) + cl.Flush(ctx) + time.Sleep(50 * time.Microsecond) } }() defer cancel() diff --git a/pkg/kgo/produce_request_test.go b/pkg/kgo/produce_request_test.go index 8a652a2f..2c7cde56 100644 --- a/pkg/kgo/produce_request_test.go +++ b/pkg/kgo/produce_request_test.go @@ -2,13 +2,75 @@ package kgo import ( "bytes" + "context" + "errors" "hash/crc32" + "math/rand" + "strings" + "sync" + "sync/atomic" "testing" "github.com/twmb/franz-go/pkg/kbin" "github.com/twmb/franz-go/pkg/kmsg" ) +func TestClient_Produce(t *testing.T) { + var ( + topic, cleanup = tmpTopicPartitions(t, 1) + numWorkers = 50 + recsToWrite = int64(20_000) + + workers sync.WaitGroup + writeSuccess atomic.Int64 + writeFailure atomic.Int64 + + randRec = func() *Record { + return &Record{ + Key: []byte("test"), + Value: []byte(strings.Repeat("x", rand.Intn(1000))), + Topic: topic, + } + } + ) + defer cleanup() + + cl, _ := newTestClient(MaxBufferedBytes(5000)) + defer cl.Close() + + // Start N workers that will concurrently write to the same partition. + var recsWritten atomic.Int64 + var fatal atomic.Bool + for i := 0; i < numWorkers; i++ { + workers.Add(1) + + go func() { + defer workers.Done() + + for recsWritten.Add(1) <= recsToWrite { + res := cl.ProduceSync(context.Background(), randRec()) + if err := res.FirstErr(); err == nil { + writeSuccess.Add(1) + } else { + if !errors.Is(err, ErrMaxBuffered) { + t.Errorf("unexpected error: %v", err) + fatal.Store(true) + } + + writeFailure.Add(1) + } + } + }() + } + workers.Wait() + + t.Logf("writes succeeded: %d", writeSuccess.Load()) + t.Logf("writes failed: %d", writeFailure.Load()) + if fatal.Load() { + t.Fatal("failed") + } +} + // This file contains golden tests against kmsg AppendTo's to ensure our custom // encoding is correct. diff --git a/pkg/kgo/producer.go b/pkg/kgo/producer.go index f09b4712..ce8fdb61 100644 --- a/pkg/kgo/producer.go +++ b/pkg/kgo/producer.go @@ -14,9 +14,15 @@ import ( ) type producer struct { - bufferedRecords atomicI64 - bufferedBytes atomicI64 - inflight atomicI64 // high 16: # waiters, low 48: # inflight + inflight atomicI64 // high 16: # waiters, low 48: # inflight + + // mu and c are used for flush and drain notifications; mu is used for + // a few other tight locks. + mu sync.Mutex + c *sync.Cond + + bufferedRecords int64 + bufferedBytes int64 cl *Client @@ -45,19 +51,14 @@ type producer struct { // We must have a producer field for flushing; we cannot just have a // field on recBufs that is toggled on flush. If we did, then a new // recBuf could be created and records sent to while we are flushing. - flushing atomicI32 // >0 if flushing, can Flush many times concurrently - blocked atomicI32 // >0 if over max recs or bytes + flushing atomicI32 // >0 if flushing, can Flush many times concurrently + blocked atomicI32 // >0 if over max recs or bytes + blockedBytes int64 aborting atomicI32 // >0 if aborting, can abort many times concurrently - idMu sync.Mutex - idVersion int16 - waitBuffer chan struct{} - - // mu and c are used for flush and drain notifications; mu is used for - // a few other tight locks. - mu sync.Mutex - c *sync.Cond + idMu sync.Mutex + idVersion int16 batchPromises ringBatchPromise promisesMu sync.Mutex @@ -86,14 +87,18 @@ type producer struct { // flushing records produced by your client (which can help determine network / // cluster health). func (cl *Client) BufferedProduceRecords() int64 { - return cl.producer.bufferedRecords.Load() + cl.producer.mu.Lock() + defer cl.producer.mu.Unlock() + return cl.producer.bufferedRecords + int64(cl.producer.blocked.Load()) } // BufferedProduceBytes returns the number of bytes currently buffered for // producing within the client. This is the sum of all keys, values, and header // keys/values. See the related [BufferedProduceRecords] for more information. func (cl *Client) BufferedProduceBytes() int64 { - return cl.producer.bufferedBytes.Load() + cl.producer.mu.Lock() + defer cl.producer.mu.Unlock() + return cl.producer.bufferedBytes + cl.producer.blockedBytes } type unknownTopicProduces struct { @@ -106,7 +111,6 @@ func (p *producer) init(cl *Client) { p.cl = cl p.topics = newTopicsPartitions() p.unknownTopics = make(map[string]*unknownTopicProduces) - p.waitBuffer = make(chan struct{}, math.MaxInt32) p.idVersion = -1 p.id.Store(&producerID{ id: -1, @@ -397,58 +401,93 @@ func (cl *Client) produce( } } - var ( - userSize = r.userSize() - bufRecs = p.bufferedRecords.Add(1) - bufBytes = p.bufferedBytes.Add(userSize) - overMaxRecs = bufRecs > cl.cfg.maxBufferedRecords - overMaxBytes bool - ) - if cl.cfg.maxBufferedBytes > 0 { - if userSize > cl.cfg.maxBufferedBytes { - p.promiseRecord(promisedRec{ctx, promise, r}, kerr.MessageTooLarge) - return - } - overMaxBytes = bufBytes > cl.cfg.maxBufferedBytes - } - + // We can now fail the rec after the buffered hook. if r.Topic == "" { - p.promiseRecord(promisedRec{ctx, promise, r}, errNoTopic) + p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, errNoTopic) return } if cl.cfg.txnID != nil && !p.producingTxn.Load() { - p.promiseRecord(promisedRec{ctx, promise, r}, errNotInTransaction) + p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, errNotInTransaction) return } + userSize := r.userSize() + if cl.cfg.maxBufferedBytes > 0 && userSize > cl.cfg.maxBufferedBytes { + p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, kerr.MessageTooLarge) + return + } + + // We have to grab the produce lock to check if this record will exceed + // configured limits. We try to keep the logic tight since this is + // effectively a global lock around producing. + var ( + nextBufRecs, nextBufBytes int64 + overMaxRecs, overMaxBytes bool + + calcNums = func() { + nextBufRecs = p.bufferedRecords + 1 + nextBufBytes = p.bufferedBytes + userSize + overMaxRecs = nextBufRecs > cl.cfg.maxBufferedRecords + overMaxBytes = cl.cfg.maxBufferedBytes > 0 && nextBufBytes > cl.cfg.maxBufferedBytes + } + ) + p.mu.Lock() + calcNums() if overMaxRecs || overMaxBytes { + if !block || cl.cfg.manualFlushing { + p.mu.Unlock() + p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, ErrMaxBuffered) + return + } + + // Before we potentially unlinger, add that we are blocked to + // ensure we do NOT start a linger anymore. We THEN wakeup + // anything that is actively lingering. Note that blocked is + // also used when finishing promises to see if we need to be + // notified. + p.blocked.Add(1) + p.blockedBytes += userSize + p.mu.Unlock() + cl.cfg.logger.Log(LogLevelDebug, "blocking Produce because we are either over max buffered records or max buffered bytes", "over_max_records", overMaxRecs, "over_max_bytes", overMaxBytes, ) - // Before we potentially unlinger, add that we are blocked. - // Lingering always checks blocked, so we will not start a - // linger while we are blocked. We THEN wakeup anything that - // is actively lingering. - cl.producer.blocked.Add(1) + cl.unlingerDueToMaxRecsBuffered() - // If the client ctx cancels or the produce ctx cancels, we - // need to un-count our buffering of this record. We also need - // to drain a slot from the waitBuffer chan, which could be - // sent to right when we are erroring. + + // We keep the lock when we exit. If we are flushing, we want + // this blocked record to be produced before we return from + // flushing. This blocked record will be accounted for in the + // bufferedRecords addition below, after being removed from + // blocked in the goroutine. + wait := make(chan struct{}) + var quit bool + go func() { + defer close(wait) + p.mu.Lock() + calcNums() + for !quit && (overMaxRecs || overMaxBytes) { + p.c.Wait() + calcNums() + } + p.blocked.Add(-1) + p.blockedBytes -= userSize + }() + drainBuffered := func(err error) { - p.promiseRecord(promisedRec{ctx, promise, r}, err) - <-p.waitBuffer - cl.producer.blocked.Add(-1) - } - if !block || cl.cfg.manualFlushing { - drainBuffered(ErrMaxBuffered) - return + p.mu.Lock() + quit = true + p.mu.Unlock() + p.c.Broadcast() // wake the goroutine above + <-wait + p.mu.Unlock() // we wait for the goroutine to exit, then unlock again (since the goroutine leaves the mutex locked) + p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, err) } + select { - case <-p.waitBuffer: - cl.cfg.logger.Log(LogLevelDebug, "Produce block signaled, continuing to produce") - cl.producer.blocked.Add(-1) + case <-wait: + cl.cfg.logger.Log(LogLevelDebug, "Produce block awoken, we now have space to produce, continuing to partition and produce") case <-cl.ctx.Done(): drainBuffered(ErrClientClosed) cl.cfg.logger.Log(LogLevelDebug, "client ctx canceled while blocked in Produce, returning") @@ -459,6 +498,9 @@ func (cl *Client) produce( return } } + p.bufferedRecords = nextBufRecs + p.bufferedBytes = nextBufBytes + p.mu.Unlock() cl.partitionRecord(promisedRec{ctx, promise, r}) } @@ -468,6 +510,7 @@ type batchPromise struct { pid int64 epoch int16 attrs RecordAttrs + beforeBuf bool partition int32 recs []promisedRec err error @@ -483,6 +526,10 @@ func (p *producer) promiseRecord(pr promisedRec, err error) { p.promiseBatch(batchPromise{recs: []promisedRec{pr}, err: err}) } +func (p *producer) promiseRecordBeforeBuf(pr promisedRec, err error) { + p.promiseBatch(batchPromise{recs: []promisedRec{pr}, beforeBuf: true, err: err}) +} + func (p *producer) finishPromises(b batchPromise) { cl := p.cl var more bool @@ -495,7 +542,7 @@ start: pr.ProducerID = b.pid pr.ProducerEpoch = b.epoch pr.Attrs = b.attrs - cl.finishRecordPromise(pr, b.err) + cl.finishRecordPromise(pr, b.err, b.beforeBuf) b.recs[i] = promisedRec{} } p.promisesMu.Unlock() @@ -509,7 +556,7 @@ start: } } -func (cl *Client) finishRecordPromise(pr promisedRec, err error) { +func (cl *Client) finishRecordPromise(pr promisedRec, err error, beforeBuffering bool) { p := &cl.producer if p.hooks != nil && len(p.hooks.unbuffered) > 0 { @@ -519,22 +566,27 @@ func (cl *Client) finishRecordPromise(pr promisedRec, err error) { } // Capture user size before potential modification by the promise. + // + // We call the promise before finishing the flush notification, + // allowing users of Flush to know all buf recs are done by the + // time we notify flush below. userSize := pr.userSize() - nowBufBytes := p.bufferedBytes.Add(-userSize) - nowBufRecs := p.bufferedRecords.Add(-1) - wasOverMaxRecs := nowBufRecs >= cl.cfg.maxBufferedRecords - wasOverMaxBytes := cl.cfg.maxBufferedBytes > 0 && nowBufBytes+userSize > cl.cfg.maxBufferedBytes - - // We call the promise before finishing the record; this allows users - // of Flush to know that all buffered records are completely done - // before Flush returns. pr.promise(pr.Record, err) - if wasOverMaxRecs || wasOverMaxBytes { - p.waitBuffer <- struct{}{} - } else if nowBufRecs == 0 && p.flushing.Load() > 0 { - p.mu.Lock() - p.mu.Unlock() //nolint:gocritic,staticcheck // We use the lock as a barrier, unlocking immediately is safe. + // If this record was never buffered, it's size was never accounted + // for on any p field: return early. + if beforeBuffering { + return + } + + // Keep the lock as tight as possible: the broadcast can come after. + p.mu.Lock() + p.bufferedBytes -= userSize + p.bufferedRecords-- + broadcast := p.blocked.Load() > 0 || p.bufferedRecords == 0 && p.flushing.Load() > 0 + p.mu.Unlock() + + if broadcast { p.c.Broadcast() } } @@ -1021,7 +1073,7 @@ func (cl *Client) Flush(ctx context.Context) error { defer p.mu.Unlock() defer close(done) - for !quit && p.bufferedRecords.Load() > 0 { + for !quit && p.bufferedRecords+int64(p.blocked.Load()) > 0 { p.c.Wait() } }() diff --git a/pkg/kgo/sink.go b/pkg/kgo/sink.go index 49bff3b1..f5283ddf 100644 --- a/pkg/kgo/sink.go +++ b/pkg/kgo/sink.go @@ -258,7 +258,7 @@ func (s *sink) produce(sem <-chan struct{}) bool { // We could have been triggered from a metadata update even though the // user is not producing at all. If we have no buffered records, let's // avoid potentially creating a producer ID. - if s.cl.producer.bufferedRecords.Load() == 0 { + if s.cl.BufferedProduceRecords() == 0 { return false }