diff --git a/pkg/kgo/consumer.go b/pkg/kgo/consumer.go index 475270ed..4c88b112 100644 --- a/pkg/kgo/consumer.go +++ b/pkg/kgo/consumer.go @@ -433,6 +433,10 @@ func (cl *Client) PollRecords(ctx context.Context, maxPollRecords int) Fetches { }() } + c.pausedMu.Lock() + defer c.pausedMu.Unlock() + paused := c.loadPaused() + // A group can grab the consumer lock then the group mu and // assign partitions. The group mu is grabbed to update its // uncommitted map. Assigning partitions clears sources ready @@ -451,13 +455,13 @@ func (cl *Client) PollRecords(ctx context.Context, maxPollRecords int) Fetches { c.sourcesReadyMu.Lock() if maxPollRecords < 0 { for _, ready := range c.sourcesReadyForDraining { - fetches = append(fetches, ready.takeBuffered()) + fetches = append(fetches, ready.takeBuffered(paused)) } c.sourcesReadyForDraining = nil } else { for len(c.sourcesReadyForDraining) > 0 && maxPollRecords > 0 { source := c.sourcesReadyForDraining[0] - fetch, taken, drained := source.takeNBuffered(maxPollRecords) + fetch, taken, drained := source.takeNBuffered(paused, maxPollRecords) if drained { c.sourcesReadyForDraining = c.sourcesReadyForDraining[1:] } @@ -555,9 +559,7 @@ func (cl *Client) UpdateFetchMaxBytes(maxBytes, maxPartBytes int32) { // PauseFetchTopics sets the client to no longer fetch the given topics and // returns all currently paused topics. Paused topics persist until resumed. // You can call this function with no topics to simply receive the list of -// currently paused topics. Pausing topics drops everything currently buffered -// and kills any in flight fetch requests to ensure nothing that is paused -// can be returned anymore from polling. +// currently paused topics. // // Pausing topics is independent from pausing individual partitions with the // PauseFetchPartitions method. If you pause partitions for a topic with @@ -569,15 +571,8 @@ func (cl *Client) PauseFetchTopics(topics ...string) []string { if len(topics) == 0 { return c.loadPaused().pausedTopics() } - c.pausedMu.Lock() defer c.pausedMu.Unlock() - defer func() { - c.mu.Lock() - defer c.mu.Unlock() - c.assignPartitions(nil, assignBumpSession, nil, fmt.Sprintf("pausing fetch topics %v", topics)) - }() - paused := c.clonePaused() paused.addTopics(topics...) c.storePaused(paused) @@ -587,9 +582,7 @@ func (cl *Client) PauseFetchTopics(topics ...string) []string { // PauseFetchPartitions sets the client to no longer fetch the given partitions // and returns all currently paused partitions. Paused partitions persist until // resumed. You can call this function with no partitions to simply receive the -// list of currently paused partitions. Pausing partitions drops everything -// currently buffered and kills any in flight fetch requests to ensure nothing -// that is paused can be returned anymore from polling. +// list of currently paused partitions. // // Pausing individual partitions is independent from pausing topics with the // PauseFetchTopics method. If you pause partitions for a topic with @@ -601,15 +594,8 @@ func (cl *Client) PauseFetchPartitions(topicPartitions map[string][]int32) map[s if len(topicPartitions) == 0 { return c.loadPaused().pausedPartitions() } - c.pausedMu.Lock() defer c.pausedMu.Unlock() - defer func() { - c.mu.Lock() - defer c.mu.Unlock() - c.assignPartitions(nil, assignBumpSession, nil, fmt.Sprintf("pausing fetch partitions %v", topicPartitions)) - }() - paused := c.clonePaused() paused.addPartitions(topicPartitions) c.storePaused(paused) @@ -884,10 +870,6 @@ const ( // The counterpart to assignInvalidateMatching, assignSetMatching // resets all matching partitions to the specified offset / epoch. assignSetMatching - - // For pausing, we want to drop anything inflight. We start a new - // session with the old tps. - assignBumpSession ) func (h assignHow) String() string { @@ -902,8 +884,6 @@ func (h assignHow) String() string { return "unassigning and purging any partition matching the input topics" case assignSetMatching: return "reassigning any currently assigned matching partition to the input" - case assignBumpSession: - return "bumping internal consumer session to drop anything currently in flight" } return "" } @@ -984,8 +964,6 @@ func (c *consumer) assignPartitions(assignments map[string]map[int32]Offset, how // if we had no session before, which is why we need to pass in // our topicPartitions. session = c.guardSessionChange(tps) - } else if how == assignBumpSession { - loadOffsets, tps = c.stopSession() } else { loadOffsets, _ = c.stopSession() @@ -1032,7 +1010,7 @@ func (c *consumer) assignPartitions(assignments map[string]map[int32]Offset, how // assignment went straight to listing / epoch loading, and // that list/epoch never finished. switch how { - case assignWithoutInvalidating, assignBumpSession: + case assignWithoutInvalidating: // Nothing to do -- this is handled above. case assignInvalidateAll: loadOffsets = listOrEpochLoads{} diff --git a/pkg/kgo/consumer_direct_test.go b/pkg/kgo/consumer_direct_test.go index 548ab575..ac385e8e 100644 --- a/pkg/kgo/consumer_direct_test.go +++ b/pkg/kgo/consumer_direct_test.go @@ -307,8 +307,13 @@ func TestPauseIssue489(t *testing.T) { } cl.PauseFetchPartitions(map[string][]int32{t1: {0}}) sawZero, sawOne = false, false - for i := 0; i < 5; i++ { - fs := cl.PollFetches(ctx) + for i := 0; i < 10; i++ { + var fs Fetches + if i < 5 { + fs = cl.PollFetches(ctx) + } else { + fs = cl.PollRecords(ctx, 2) + } fs.EachRecord(func(r *Record) { sawZero = sawZero || r.Partition == 0 sawOne = sawOne || r.Partition == 1 diff --git a/pkg/kgo/source.go b/pkg/kgo/source.go index 226ef458..e375654b 100644 --- a/pkg/kgo/source.go +++ b/pkg/kgo/source.go @@ -344,8 +344,38 @@ func (s *source) hook(f *Fetch, buffered, polled bool) { } // takeBuffered drains a buffered fetch and updates offsets. -func (s *source) takeBuffered() Fetch { - return s.takeBufferedFn(true, usedOffsets.finishUsingAllWithSet) +func (s *source) takeBuffered(paused pausedTopics) Fetch { + var strip mtmps + f := s.takeBufferedFn(true, func(os usedOffsets) { + os.eachOffset(func(o *cursorOffsetNext) { + if paused != nil && paused.has(o.from.topic, o.from.partition) { + o.from.allowUsable() + strip.add(o.from.topic, o.from.partition) + return + } + o.from.setOffset(o.cursorOffset) + o.from.allowUsable() + }) + }) + if strip != nil { + keep := f.Topics[:0] + for _, t := range f.Topics { + if strip.has(t.Topic, -1) { + continue + } + keepp := t.Partitions[:0] + for _, p := range t.Partitions { + if strip.has(t.Topic, p.Partition) { + continue + } + keepp = append(keepp, p) + } + t.Partitions = keepp + keep = append(keep, t) + } + f.Topics = keep + } + return f } func (s *source) discardBuffered() { @@ -359,7 +389,7 @@ func (s *source) discardBuffered() { // // This returns the number of records taken and whether the source has been // completely drained. -func (s *source) takeNBuffered(n int) (Fetch, int, bool) { +func (s *source) takeNBuffered(paused pausedTopics, n int) (Fetch, int, bool) { var r Fetch var taken int @@ -368,6 +398,17 @@ func (s *source) takeNBuffered(n int) (Fetch, int, bool) { for len(bf.Topics) > 0 && n > 0 { t := &bf.Topics[0] + // If the topic is outright paused, we allowUsable all + // partitions in the topic and skip the topic entirely. + if paused != nil && paused.has(t.Topic, -1) { + bf.Topics = bf.Topics[1:] + for _, pCursor := range b.usedOffsets[t.Topic] { + pCursor.from.allowUsable() + } + delete(b.usedOffsets, t.Topic) + continue + } + r.Topics = append(r.Topics, *t) rt := &r.Topics[len(r.Topics)-1] rt.Partitions = nil @@ -377,6 +418,17 @@ func (s *source) takeNBuffered(n int) (Fetch, int, bool) { for len(t.Partitions) > 0 && n > 0 { p := &t.Partitions[0] + if paused != nil && paused.has(t.Topic, p.Partition) { + t.Partitions = t.Partitions[1:] + pCursor := tCursors[p.Partition] + pCursor.from.allowUsable() + delete(tCursors, p.Partition) + if len(tCursors) == 0 { + delete(b.usedOffsets, t.Topic) + } + continue + } + rt.Partitions = append(rt.Partitions, *p) rp := &rt.Partitions[len(rt.Partitions)-1] @@ -402,7 +454,7 @@ func (s *source) takeNBuffered(n int) (Fetch, int, bool) { if len(tCursors) == 0 { delete(b.usedOffsets, t.Topic) } - break + continue } lastReturnedRecord := rp.Records[len(rp.Records)-1] @@ -422,7 +474,7 @@ func (s *source) takeNBuffered(n int) (Fetch, int, bool) { drained := len(bf.Topics) == 0 if drained { - s.takeBuffered() + s.takeBuffered(nil) } return r, taken, drained } diff --git a/pkg/kgo/topics_and_partitions.go b/pkg/kgo/topics_and_partitions.go index 9e6e9e9e..ccdf8e76 100644 --- a/pkg/kgo/topics_and_partitions.go +++ b/pkg/kgo/topics_and_partitions.go @@ -108,6 +108,21 @@ func (m mtmps) onlyt(t string) bool { return exists && len(ps) == 0 } +func (m mtmps) has(t string, p int32) bool { + if m == nil { + return false + } + ps, exists := m[t] + if !exists { + return false + } + if p == -1 { + return true + } + _, exists = ps[p] + return exists +} + func (m mtmps) remove(t string, p int32) { if m == nil { return