Skip to content

Commit

Permalink
kgo: fix deadlock in Produce when using MaxBufferedBytes
Browse files Browse the repository at this point in the history
Copying from the issue,
"""
1) Produce() record A (100 bytes)
2) Produce() record B (50 bytes), waiting for buffer to free
3) Produce() record C (50 bytes), waiting for buffer to free
4) Record A is produced, finishRecordPromise() gets called, detects it was over the limit so publish 1 message to waitBuffer
5) Record B is unlocked, finishRecordPromise() gets called, does not detect it was over the limit (only 50 bytes), so record C is never unblocked and will wait indefinitely on waitBuffer
"""

The fix requires adding a lock while producing. This reuses the existing
lock on the `producer` type. This can lead to a few more spurious
wakeups in other functions that use this same mutex, but that's fine.

The prior algorithm counted anything to produce immediately into the
buffered records and bytes fields; the fix for #777 could not really be
possible unless we avoid counting the "buffered" aspect right away.
Specifically, we need to have a goroutine looping with a sync.Cond that
checks *IF* we add the record, will we still be blocked? This allows us
to wake up all blocked goroutines always (unlike one at a time, the
problem this issue points out), and each goroutine can check under a
lock if they still do not fit.

This also fixes an unreported bug where, if a record WOULD be blocked
but fails early due to no topic / not in a transaction while in a
transactional client, the serial promise finishing goroutine would
deadlock.

Closes #777.
  • Loading branch information
twmb committed Jul 29, 2024
1 parent 58d20a1 commit 24fbb0f
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 71 deletions.
6 changes: 4 additions & 2 deletions pkg/kgo/consumer_direct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,8 @@ func TestPauseIssue489(t *testing.T) {
exit.Store(true)
}
})
time.Sleep(100 * time.Microsecond)
cl.Flush(ctx)
time.Sleep(50 * time.Microsecond)
}
}()
defer cancel()
Expand Down Expand Up @@ -416,7 +417,8 @@ func TestPauseIssueOct2023(t *testing.T) {
exit.Store(true)
}
})
time.Sleep(100 * time.Microsecond)
cl.Flush(ctx)
time.Sleep(50 * time.Microsecond)
}
}()
defer cancel()
Expand Down
62 changes: 62 additions & 0 deletions pkg/kgo/produce_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,75 @@ package kgo

import (
"bytes"
"context"
"errors"
"hash/crc32"
"math/rand"
"strings"
"sync"
"sync/atomic"
"testing"

"github.com/twmb/franz-go/pkg/kbin"
"github.com/twmb/franz-go/pkg/kmsg"
)

func TestClient_Produce(t *testing.T) {
var (
topic, cleanup = tmpTopicPartitions(t, 1)
numWorkers = 50
recsToWrite = int64(20_000)

workers sync.WaitGroup
writeSuccess atomic.Int64
writeFailure atomic.Int64

randRec = func() *Record {
return &Record{
Key: []byte("test"),
Value: []byte(strings.Repeat("x", rand.Intn(1000))),
Topic: topic,
}
}
)
defer cleanup()

cl, _ := newTestClient(MaxBufferedBytes(5000))
defer cl.Close()

// Start N workers that will concurrently write to the same partition.
var recsWritten atomic.Int64
var fatal atomic.Bool
for i := 0; i < numWorkers; i++ {
workers.Add(1)

go func() {
defer workers.Done()

for recsWritten.Add(1) <= recsToWrite {
res := cl.ProduceSync(context.Background(), randRec())
if err := res.FirstErr(); err == nil {
writeSuccess.Add(1)
} else {
if !errors.Is(err, ErrMaxBuffered) {
t.Errorf("unexpected error: %v", err)
fatal.Store(true)
}

writeFailure.Add(1)
}
}
}()
}
workers.Wait()

t.Logf("writes succeeded: %d", writeSuccess.Load())
t.Logf("writes failed: %d", writeFailure.Load())
if fatal.Load() {
t.Fatal("failed")
}
}

// This file contains golden tests against kmsg AppendTo's to ensure our custom
// encoding is correct.

Expand Down
188 changes: 120 additions & 68 deletions pkg/kgo/producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@ import (
)

type producer struct {
bufferedRecords atomicI64
bufferedBytes atomicI64
inflight atomicI64 // high 16: # waiters, low 48: # inflight
inflight atomicI64 // high 16: # waiters, low 48: # inflight

// mu and c are used for flush and drain notifications; mu is used for
// a few other tight locks.
mu sync.Mutex
c *sync.Cond

bufferedRecords int64
bufferedBytes int64

cl *Client

Expand Down Expand Up @@ -45,19 +51,14 @@ type producer struct {
// We must have a producer field for flushing; we cannot just have a
// field on recBufs that is toggled on flush. If we did, then a new
// recBuf could be created and records sent to while we are flushing.
flushing atomicI32 // >0 if flushing, can Flush many times concurrently
blocked atomicI32 // >0 if over max recs or bytes
flushing atomicI32 // >0 if flushing, can Flush many times concurrently
blocked atomicI32 // >0 if over max recs or bytes
blockedBytes int64

aborting atomicI32 // >0 if aborting, can abort many times concurrently

idMu sync.Mutex
idVersion int16
waitBuffer chan struct{}

// mu and c are used for flush and drain notifications; mu is used for
// a few other tight locks.
mu sync.Mutex
c *sync.Cond
idMu sync.Mutex
idVersion int16

batchPromises ringBatchPromise
promisesMu sync.Mutex
Expand Down Expand Up @@ -86,14 +87,18 @@ type producer struct {
// flushing records produced by your client (which can help determine network /
// cluster health).
func (cl *Client) BufferedProduceRecords() int64 {
return cl.producer.bufferedRecords.Load()
cl.producer.mu.Lock()
defer cl.producer.mu.Unlock()
return cl.producer.bufferedRecords + int64(cl.producer.blocked.Load())
}

// BufferedProduceBytes returns the number of bytes currently buffered for
// producing within the client. This is the sum of all keys, values, and header
// keys/values. See the related [BufferedProduceRecords] for more information.
func (cl *Client) BufferedProduceBytes() int64 {
return cl.producer.bufferedBytes.Load()
cl.producer.mu.Lock()
defer cl.producer.mu.Unlock()
return cl.producer.bufferedBytes + cl.producer.blockedBytes
}

type unknownTopicProduces struct {
Expand All @@ -106,7 +111,6 @@ func (p *producer) init(cl *Client) {
p.cl = cl
p.topics = newTopicsPartitions()
p.unknownTopics = make(map[string]*unknownTopicProduces)
p.waitBuffer = make(chan struct{}, math.MaxInt32)
p.idVersion = -1
p.id.Store(&producerID{
id: -1,
Expand Down Expand Up @@ -397,58 +401,93 @@ func (cl *Client) produce(
}
}

var (
userSize = r.userSize()
bufRecs = p.bufferedRecords.Add(1)
bufBytes = p.bufferedBytes.Add(userSize)
overMaxRecs = bufRecs > cl.cfg.maxBufferedRecords
overMaxBytes bool
)
if cl.cfg.maxBufferedBytes > 0 {
if userSize > cl.cfg.maxBufferedBytes {
p.promiseRecord(promisedRec{ctx, promise, r}, kerr.MessageTooLarge)
return
}
overMaxBytes = bufBytes > cl.cfg.maxBufferedBytes
}

// We can now fail the rec after the buffered hook.
if r.Topic == "" {
p.promiseRecord(promisedRec{ctx, promise, r}, errNoTopic)
p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, errNoTopic)
return
}
if cl.cfg.txnID != nil && !p.producingTxn.Load() {
p.promiseRecord(promisedRec{ctx, promise, r}, errNotInTransaction)
p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, errNotInTransaction)
return
}

userSize := r.userSize()
if cl.cfg.maxBufferedBytes > 0 && userSize > cl.cfg.maxBufferedBytes {
p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, kerr.MessageTooLarge)
return
}

// We have to grab the produce lock to check if this record will exceed
// configured limits. We try to keep the logic tight since this is
// effectively a global lock around producing.
var (
nextBufRecs, nextBufBytes int64
overMaxRecs, overMaxBytes bool

calcNums = func() {
nextBufRecs = p.bufferedRecords + 1
nextBufBytes = p.bufferedBytes + userSize
overMaxRecs = nextBufRecs > cl.cfg.maxBufferedRecords
overMaxBytes = cl.cfg.maxBufferedBytes > 0 && nextBufBytes > cl.cfg.maxBufferedBytes
}
)
p.mu.Lock()
calcNums()
if overMaxRecs || overMaxBytes {
if !block || cl.cfg.manualFlushing {
p.mu.Unlock()
p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, ErrMaxBuffered)
return
}

// Before we potentially unlinger, add that we are blocked to
// ensure we do NOT start a linger anymore. We THEN wakeup
// anything that is actively lingering. Note that blocked is
// also used when finishing promises to see if we need to be
// notified.
p.blocked.Add(1)
p.blockedBytes += userSize
p.mu.Unlock()

cl.cfg.logger.Log(LogLevelDebug, "blocking Produce because we are either over max buffered records or max buffered bytes",
"over_max_records", overMaxRecs,
"over_max_bytes", overMaxBytes,
)
// Before we potentially unlinger, add that we are blocked.
// Lingering always checks blocked, so we will not start a
// linger while we are blocked. We THEN wakeup anything that
// is actively lingering.
cl.producer.blocked.Add(1)

cl.unlingerDueToMaxRecsBuffered()
// If the client ctx cancels or the produce ctx cancels, we
// need to un-count our buffering of this record. We also need
// to drain a slot from the waitBuffer chan, which could be
// sent to right when we are erroring.

// We keep the lock when we exit. If we are flushing, we want
// this blocked record to be produced before we return from
// flushing. This blocked record will be accounted for in the
// bufferedRecords addition below, after being removed from
// blocked in the goroutine.
wait := make(chan struct{})
var quit bool
go func() {
defer close(wait)
p.mu.Lock()
calcNums()
for !quit && (overMaxRecs || overMaxBytes) {
p.c.Wait()
calcNums()
}
p.blocked.Add(-1)
p.blockedBytes -= userSize
}()

drainBuffered := func(err error) {
p.promiseRecord(promisedRec{ctx, promise, r}, err)
<-p.waitBuffer
cl.producer.blocked.Add(-1)
}
if !block || cl.cfg.manualFlushing {
drainBuffered(ErrMaxBuffered)
return
p.mu.Lock()
quit = true
p.mu.Unlock()
p.c.Broadcast() // wake the goroutine above
<-wait
p.mu.Unlock() // we wait for the goroutine to exit, then unlock again (since the goroutine leaves the mutex locked)
p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, err)
}

select {
case <-p.waitBuffer:
cl.cfg.logger.Log(LogLevelDebug, "Produce block signaled, continuing to produce")
cl.producer.blocked.Add(-1)
case <-wait:
cl.cfg.logger.Log(LogLevelDebug, "Produce block awoken, we now have space to produce, continuing to partition and produce")
case <-cl.ctx.Done():
drainBuffered(ErrClientClosed)
cl.cfg.logger.Log(LogLevelDebug, "client ctx canceled while blocked in Produce, returning")
Expand All @@ -459,6 +498,9 @@ func (cl *Client) produce(
return
}
}
p.bufferedRecords = nextBufRecs
p.bufferedBytes = nextBufBytes
p.mu.Unlock()

cl.partitionRecord(promisedRec{ctx, promise, r})
}
Expand All @@ -468,6 +510,7 @@ type batchPromise struct {
pid int64
epoch int16
attrs RecordAttrs
beforeBuf bool
partition int32
recs []promisedRec
err error
Expand All @@ -483,6 +526,10 @@ func (p *producer) promiseRecord(pr promisedRec, err error) {
p.promiseBatch(batchPromise{recs: []promisedRec{pr}, err: err})
}

func (p *producer) promiseRecordBeforeBuf(pr promisedRec, err error) {
p.promiseBatch(batchPromise{recs: []promisedRec{pr}, beforeBuf: true, err: err})
}

func (p *producer) finishPromises(b batchPromise) {
cl := p.cl
var more bool
Expand All @@ -495,7 +542,7 @@ start:
pr.ProducerID = b.pid
pr.ProducerEpoch = b.epoch
pr.Attrs = b.attrs
cl.finishRecordPromise(pr, b.err)
cl.finishRecordPromise(pr, b.err, b.beforeBuf)
b.recs[i] = promisedRec{}
}
p.promisesMu.Unlock()
Expand All @@ -509,7 +556,7 @@ start:
}
}

func (cl *Client) finishRecordPromise(pr promisedRec, err error) {
func (cl *Client) finishRecordPromise(pr promisedRec, err error, beforeBuffering bool) {
p := &cl.producer

if p.hooks != nil && len(p.hooks.unbuffered) > 0 {
Expand All @@ -519,22 +566,27 @@ func (cl *Client) finishRecordPromise(pr promisedRec, err error) {
}

// Capture user size before potential modification by the promise.
//
// We call the promise before finishing the flush notification,
// allowing users of Flush to know all buf recs are done by the
// time we notify flush below.
userSize := pr.userSize()
nowBufBytes := p.bufferedBytes.Add(-userSize)
nowBufRecs := p.bufferedRecords.Add(-1)
wasOverMaxRecs := nowBufRecs >= cl.cfg.maxBufferedRecords
wasOverMaxBytes := cl.cfg.maxBufferedBytes > 0 && nowBufBytes+userSize > cl.cfg.maxBufferedBytes

// We call the promise before finishing the record; this allows users
// of Flush to know that all buffered records are completely done
// before Flush returns.
pr.promise(pr.Record, err)

if wasOverMaxRecs || wasOverMaxBytes {
p.waitBuffer <- struct{}{}
} else if nowBufRecs == 0 && p.flushing.Load() > 0 {
p.mu.Lock()
p.mu.Unlock() //nolint:gocritic,staticcheck // We use the lock as a barrier, unlocking immediately is safe.
// If this record was never buffered, it's size was never accounted
// for on any p field: return early.
if beforeBuffering {
return
}

// Keep the lock as tight as possible: the broadcast can come after.
p.mu.Lock()
p.bufferedBytes -= userSize
p.bufferedRecords--
broadcast := p.blocked.Load() > 0 || p.bufferedRecords == 0 && p.flushing.Load() > 0
p.mu.Unlock()

if broadcast {
p.c.Broadcast()
}
}
Expand Down Expand Up @@ -1021,7 +1073,7 @@ func (cl *Client) Flush(ctx context.Context) error {
defer p.mu.Unlock()
defer close(done)

for !quit && p.bufferedRecords.Load() > 0 {
for !quit && p.bufferedRecords+int64(p.blocked.Load()) > 0 {
p.c.Wait()
}
}()
Expand Down
Loading

0 comments on commit 24fbb0f

Please sign in to comment.