diff --git a/pkg/kgo/consumer_direct_test.go b/pkg/kgo/consumer_direct_test.go index ac385e8e..cc8fca77 100644 --- a/pkg/kgo/consumer_direct_test.go +++ b/pkg/kgo/consumer_direct_test.go @@ -263,6 +263,15 @@ func TestAddRemovePartitions(t *testing.T) { } } +func closed(ch <-chan struct{}) bool { + select { + case <-ch: + return true + default: + return false + } +} + func TestPauseIssue489(t *testing.T) { t.Parallel() @@ -297,8 +306,9 @@ func TestPauseIssue489(t *testing.T) { defer cancel() for i := 0; i < 10; i++ { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) var sawZero, sawOne bool - for !sawZero || !sawOne { + for (!sawZero || !sawOne) && !closed(ctx.Done()) { fs := cl.PollFetches(ctx) fs.EachRecord(func(r *Record) { sawZero = sawZero || r.Partition == 0 @@ -307,7 +317,7 @@ func TestPauseIssue489(t *testing.T) { } cl.PauseFetchPartitions(map[string][]int32{t1: {0}}) sawZero, sawOne = false, false - for i := 0; i < 10; i++ { + for i := 0; i < 10 && !closed(ctx.Done()); i++ { var fs Fetches if i < 5 { fs = cl.PollFetches(ctx) @@ -319,13 +329,86 @@ func TestPauseIssue489(t *testing.T) { sawOne = sawOne || r.Partition == 1 }) } + cancel() if sawZero { - t.Error("saw partition zero even though it was paused") + t.Fatal("saw partition zero even though it was paused") + } + if !sawOne { + t.Fatal("did not see partition one even though it was not paused") } cl.ResumeFetchPartitions(map[string][]int32{t1: {0}}) } } +func TestPauseIssueOct2023(t *testing.T) { + t.Parallel() + + t1, cleanup1 := tmpTopicPartitions(t, 1) + t2, cleanup2 := tmpTopicPartitions(t, 1) + defer cleanup1() + defer cleanup2() + ts := []string{t1, t2} + + cl, _ := NewClient( + getSeedBrokers(), + UnknownTopicRetries(-1), + ConsumeTopics(ts...), + FetchMaxWait(100*time.Millisecond), + ) + defer cl.Close() + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + var exit atomic.Bool + var which int + for !exit.Load() { + r := StringRecord("v") + r.Topic = ts[which%len(ts)] + which++ + cl.Produce(ctx, r, func(r *Record, err error) { + if err == context.Canceled { + exit.Store(true) + } + }) + } + }() + defer cancel() + + for i := 0; i < 10; i++ { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + var sawt1, sawt2 bool + for (!sawt1 || !sawt2) && !closed(ctx.Done()) { + fs := cl.PollFetches(ctx) + fs.EachRecord(func(r *Record) { + sawt1 = sawt1 || r.Topic == t1 + sawt2 = sawt2 || r.Topic == t2 + }) + } + cl.PauseFetchTopics(t1) + sawt1, sawt2 = false, false + for i := 0; i < 10 && !closed(ctx.Done()); i++ { + var fs Fetches + if i < 5 { + fs = cl.PollFetches(ctx) + } else { + fs = cl.PollRecords(ctx, 2) + } + fs.EachRecord(func(r *Record) { + sawt1 = sawt1 || r.Topic == t1 + sawt2 = sawt2 || r.Topic == t2 + }) + } + cancel() + if sawt1 { + t.Fatal("saw topic t1 even though it was paused") + } + if !sawt2 { + t.Fatal("did not see topic t2 even though it was not paused") + } + cl.ResumeFetchTopics(t1) + } +} + func TestIssue523(t *testing.T) { t.Parallel() diff --git a/pkg/kgo/source.go b/pkg/kgo/source.go index 3086e60d..df5e4432 100644 --- a/pkg/kgo/source.go +++ b/pkg/kgo/source.go @@ -355,6 +355,10 @@ func (s *source) takeBuffered(paused pausedTopics) Fetch { // and strip the topic entirely. pps, ok := paused.t(t) if !ok { + for _, o := range ps { + o.from.setOffset(o.cursorOffset) + o.from.allowUsable() + } continue } if strip == nil {