diff --git a/pkg/kgo/consumer.go b/pkg/kgo/consumer.go index 76ba397b..ef8f46cd 100644 --- a/pkg/kgo/consumer.go +++ b/pkg/kgo/consumer.go @@ -758,6 +758,29 @@ func (cl *Client) AddConsumeTopics(topics ...string) { cl.triggerUpdateMetadataNow("from AddConsumeTopics") } +// GetConsumeTopics retrives a list of current topics being consumed. +func (cl *Client) GetConsumeTopics() []string { + c := &cl.consumer + if c.g == nil && c.d == nil { + return nil + } + var m map[string]*topicPartitions + var ok bool + if c.g != nil { + m, ok = c.g.tps.v.Load().(topicsPartitionsData) + } else { + m, ok = c.d.tps.v.Load().(topicsPartitionsData) + } + if !ok { + return nil + } + topics := make([]string, 0, len(m)) + for k := range m { + topics = append(topics, k) + } + return topics +} + // AddConsumePartitions adds new partitions to be consumed at the given // offsets. This function works only for direct, non-regex consumers. func (cl *Client) AddConsumePartitions(partitions map[string]map[int32]Offset) { diff --git a/pkg/kgo/consumer_direct_test.go b/pkg/kgo/consumer_direct_test.go index f02b0d69..da476b72 100644 --- a/pkg/kgo/consumer_direct_test.go +++ b/pkg/kgo/consumer_direct_test.go @@ -39,6 +39,47 @@ func TestIssue325(t *testing.T) { } } +// Allow adding a topic to consume after the client is initialized with nothing +// to consume. +func TestConsumeTopicRetrieval(t *testing.T) { + t.Parallel() + topicName := "test" + cl, _ := newTestClient() + defer cl.Close() + topics := cl.GetConsumeTopics() + if len(topics) != 0 { + t.Fatalf("expected no topics, got %v", topics) + } + cl.AddConsumeTopics(topicName) + topics = cl.GetConsumeTopics() + if len(topics) != 1 || topics[0] != topicName { + t.Fatalf("expected to see %v, got %v", topicName, topics) + } +} + +// Allow adding a topic to consume after the client is initialized with nothing +// to consume. +func TestConsumeTopicRetrieval_Many(t *testing.T) { + t.Parallel() + topicName := "test" + cl, _ := newTestClient() + defer cl.Close() + topics := cl.GetConsumeTopics() + if len(topics) != 0 { + t.Fatalf("expected no topics, got %v", topics) + } + for i := 0; i < 100; i++ { + cl.AddConsumeTopics(fmt.Sprintf("%s_%d", topicName, i)) + } + topics = cl.GetConsumeTopics() + sort.Slice(topics, func(i, j int) bool { + return topics[i] < topics[j] + }) + if len(topics) != 100 || topics[0] != fmt.Sprintf("%s_%d", topicName, 0) { + t.Fatalf("expected to see %v, got %v", topicName, topics) + } +} + // Ensure we only consume one partition if we only ask for one partition. func TestIssue337(t *testing.T) { t.Parallel()