diff --git a/pkg/kgo/producer.go b/pkg/kgo/producer.go
index 39fcb5e6..6e2fa3fa 100644
--- a/pkg/kgo/producer.go
+++ b/pkg/kgo/producer.go
@@ -16,11 +16,8 @@ import (
 type producer struct {
 	inflight atomicI64 // high 16: # waiters, low 48: # inflight
 
-	// mu and c are used for flush and drain notifications; mu is used for
-	// a few other tight locks.
-	mu sync.Mutex
-	c  *sync.Cond
-
+	produceMu sync.Mutex
+	produceC  *sync.Cond
 	bufferedRecords int64
 	bufferedBytes   int64
 
@@ -60,6 +57,11 @@ type producer struct {
 	idMu      sync.Mutex
 	idVersion int16
 
+	// mu and c are used for flush and drain notifications; mu is used for
+	// a few other tight locks.
+	mu sync.Mutex
+	c  *sync.Cond
+
 	batchPromises ringBatchPromise
 	promisesMu    sync.Mutex
 
@@ -87,8 +89,8 @@ type producer struct {
 // flushing records produced by your client (which can help determine network /
 // cluster health).
 func (cl *Client) BufferedProduceRecords() int64 {
-	cl.producer.mu.Lock()
-	defer cl.producer.mu.Unlock()
+	cl.producer.produceMu.Lock()
+	defer cl.producer.produceMu.Unlock()
 	return cl.producer.bufferedRecords + int64(cl.producer.blocked.Load())
 }
 
@@ -96,8 +98,8 @@ func (cl *Client) BufferedProduceRecords() int64 {
 // producing within the client. This is the sum of all keys, values, and header
 // keys/values. See the related [BufferedProduceRecords] for more information.
 func (cl *Client) BufferedProduceBytes() int64 {
-	cl.producer.mu.Lock()
-	defer cl.producer.mu.Unlock()
+	cl.producer.produceMu.Lock()
+	defer cl.producer.produceMu.Unlock()
 	return cl.producer.bufferedBytes + cl.producer.blockedBytes
 }
 
@@ -117,6 +119,7 @@ func (p *producer) init(cl *Client) {
 		epoch: -1,
 		err:   errReloadProducerID,
 	})
+	p.produceC = sync.NewCond(&p.produceMu)
 	p.c = sync.NewCond(&p.mu)
 
 	inithooks := func() {
@@ -431,11 +434,11 @@ func (cl *Client) produce(
 			overMaxBytes = cl.cfg.maxBufferedBytes > 0 && nextBufBytes > cl.cfg.maxBufferedBytes
 		}
 	)
-	p.mu.Lock()
+	p.produceMu.Lock()
 	calcNums()
 	if overMaxRecs || overMaxBytes {
 		if !block || cl.cfg.manualFlushing {
-			p.mu.Unlock()
+			p.produceMu.Unlock()
 			p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, ErrMaxBuffered)
 			return
 		}
@@ -447,7 +450,7 @@ func (cl *Client) produce(
 		// notified.
 		p.blocked.Add(1)
 		p.blockedBytes += userSize
-		p.mu.Unlock()
+		p.produceMu.Unlock()
 
 		cl.cfg.logger.Log(LogLevelDebug, "blocking Produce because we are either over max buffered records or max buffered bytes",
 			"over_max_records", overMaxRecs,
@@ -456,35 +459,34 @@
 
 		cl.unlingerDueToMaxRecsBuffered()
 
+		// We keep the lock when we exit.
 		wait := make(chan struct{})
 		var quit bool
 		go func() {
 			defer close(wait)
-			p.mu.Lock()
-			defer p.mu.Unlock()
+			p.produceMu.Lock()
 			calcNums()
 			for !quit && (overMaxRecs || overMaxBytes) {
-				p.c.Wait()
+				p.produceC.Wait()
 				calcNums()
 			}
 			p.blocked.Add(-1)
 			p.blockedBytes -= userSize
-			p.c.Broadcast() // ensure Flush is awoken, if need be
 		}()
 
 		drainBuffered := func(err error) {
-			p.mu.Lock()
+			p.produceMu.Lock()
 			quit = true
-			p.mu.Unlock()
-			p.c.Broadcast() // wake the goroutine above
+			p.produceMu.Unlock()
+			p.produceC.Broadcast() // wake the goroutine above
+			<-wait
+			p.produceMu.Unlock()
 			p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, err)
 		}
 
 		select {
 		case <-wait:
 			cl.cfg.logger.Log(LogLevelDebug, "Produce block awoken, we now have space to produce, continuing to partition and produce")
-			p.mu.Lock() // lock before modifying p.buffered<> below
-			calcNums()  // update numbers within final lock
 		case <-cl.ctx.Done():
 			drainBuffered(ErrClientClosed)
 			cl.cfg.logger.Log(LogLevelDebug, "client ctx canceled while blocked in Produce, returning")
@@ -497,7 +499,7 @@ func (cl *Client) produce(
 	}
 	p.bufferedRecords = nextBufRecs
 	p.bufferedBytes = nextBufBytes
-	p.mu.Unlock()
+	p.produceMu.Unlock()
 
 	cl.partitionRecord(promisedRec{ctx, promise, r})
 }
@@ -577,13 +579,18 @@ func (cl *Client) finishRecordPromise(pr promisedRec, err error, beforeBuffering
 	}
 
 	// Keep the lock as tight as possible: the broadcast can come after.
-	p.mu.Lock()
+	p.produceMu.Lock()
 	p.bufferedBytes -= userSize
 	p.bufferedRecords--
-	broadcast := p.blocked.Load() > 0 || p.bufferedRecords == 0 && p.flushing.Load() > 0
-	p.mu.Unlock()
+	broadcastC := p.blocked.Load() > 0
+	broadcastFlush := p.shouldBroadcastFlush()
+	p.produceMu.Unlock()
 
-	if broadcast {
+	if broadcastC {
+		p.produceC.Broadcast()
+	} else if broadcastFlush {
+		p.mu.Lock()
+		p.mu.Unlock() //nolint:gocritic,staticcheck // We use the lock as a barrier, unlocking immediately is safe.
 		p.c.Broadcast()
 	}
 }
@@ -1033,6 +1040,10 @@ func (cl *Client) unlingerDueToMaxRecsBuffered() {
 	cl.cfg.logger.Log(LogLevelDebug, "unlingered all partitions due to hitting max buffered")
 }
 
+func (p *producer) shouldBroadcastFlush() bool {
+	return p.blocked.Load() == 0 && p.bufferedRecords == 0 && p.flushing.Load() > 0
+}
+
 // Flush hangs waiting for all buffered records to be flushed, stopping all
 // lingers if necessary.
 //
@@ -1070,7 +1081,7 @@ func (cl *Client) Flush(ctx context.Context) error {
 		defer p.mu.Unlock()
 		defer close(done)
 
-		for !quit && p.bufferedRecords+int64(p.blocked.Load()) > 0 {
+		for !quit && cl.BufferedProduceRecords() > 0 {
 			p.c.Wait()
 		}
 	}()
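
For readers following the locking change above: the patch splits the producer's single mutex/cond pair in two. produceMu/produceC now guard the buffered record/byte counters and wake Produce calls blocked on the max-buffered limits, while mu/c are reserved for flush and drain notifications, so a Flush sleeping on c no longer shares a condition variable (or its mutex) with blocked producers. The sketch below is a minimal, self-contained illustration of that two-condition-variable pattern under simplified assumptions; every name in it (boundedBuffer, Add, Done, Flush) is hypothetical and is not part of kgo's API.

// bounded_buffer_sketch.go: an illustrative sketch of the split-lock pattern
// used in the patch above. Hypothetical types and names; not kgo code.
package main

import (
	"fmt"
	"sync"
	"time"
)

type boundedBuffer struct {
	produceMu sync.Mutex // guards buffered; produceC wakes producers blocked on capacity
	produceC  *sync.Cond

	mu sync.Mutex // used only for flush notifications; c wakes flushers
	c  *sync.Cond

	buffered int
	capacity int
}

func newBoundedBuffer(capacity int) *boundedBuffer {
	b := &boundedBuffer{capacity: capacity}
	b.produceC = sync.NewCond(&b.produceMu)
	b.c = sync.NewCond(&b.mu)
	return b
}

// Add blocks until there is room to buffer one more record.
func (b *boundedBuffer) Add() {
	b.produceMu.Lock()
	for b.buffered >= b.capacity {
		b.produceC.Wait()
	}
	b.buffered++
	b.produceMu.Unlock()
}

// Done marks one record as finished. It wakes producers blocked on capacity
// and, once the buffer drains, wakes flushers. (The patch is more selective:
// it broadcasts only one of the two conds, based on p.blocked and p.flushing.)
func (b *boundedBuffer) Done() {
	b.produceMu.Lock()
	b.buffered--
	empty := b.buffered == 0
	b.produceMu.Unlock()

	b.produceC.Broadcast()
	if empty {
		b.mu.Lock()
		b.mu.Unlock() // lock/unlock as a barrier so a concurrent Flush cannot miss the wakeup
		b.c.Broadcast()
	}
}

// Flush waits until the buffer is empty. It sleeps on c, not produceC, and a
// sleeping flusher never holds produceMu, so it cannot block producers.
func (b *boundedBuffer) Flush() {
	b.mu.Lock()
	defer b.mu.Unlock()
	for b.snapshot() > 0 {
		b.c.Wait()
	}
}

func (b *boundedBuffer) snapshot() int {
	b.produceMu.Lock()
	defer b.produceMu.Unlock()
	return b.buffered
}

func main() {
	b := newBoundedBuffer(2)

	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			b.Add()                          // blocks while two records are already buffered
			time.Sleep(5 * time.Millisecond) // pretend the record is in flight
			b.Done()
		}()
	}

	time.Sleep(time.Millisecond) // let a few records get buffered first
	b.Flush()                    // sleeps on c until everything buffered so far has drained
	wg.Wait()
	fmt.Println("flushed; buffered =", b.snapshot())
}

The mu.Lock/mu.Unlock barrier before c.Broadcast mirrors the //nolint line added in finishRecordPromise: a flusher evaluates its condition and calls c.Wait while holding mu, so briefly taking mu before broadcasting guarantees the flusher is either already waiting or has already observed an empty buffer, and the notification cannot fall into the gap between its check and its Wait.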