diff --git a/pkg/kgo/consumer_group.go b/pkg/kgo/consumer_group.go
index 6dc6bbd2..43203023 100644
--- a/pkg/kgo/consumer_group.go
+++ b/pkg/kgo/consumer_group.go
@@ -126,14 +126,12 @@ type groupConsumer struct {
 	uncommitted uncommitted

 	// memberID and generation are written to in the join and sync loop,
-	// and mostly read within that loop. The reason these two are under the
-	// mutex is because they are read during commits, which can happen at
-	// any arbitrary moment. It is **recommended** to be done within the
-	// context of a group session, but (a) users may have some unique use
-	// cases, and (b) the onRevoke hook may take longer than a user
+	// and mostly read within that loop. This can be read during commits,
+	// which can happen at any time. It is **recommended** to be done within
+	// the context of a group session, but (a) users may have some unique
+	// use cases, and (b) the onRevoke hook may take longer than a user
 	// expects, which would rotate a session.
-	memberID   string
-	generation int32
+	memberGen groupMemberGen

 	// commitCancel and commitDone are set under mu before firing off an
 	// async commit request. If another commit happens, it cancels the
@@ -155,6 +153,42 @@ type groupConsumer struct {
 	leaveErr error // set before left is closed
 }

+type groupMemberGen struct {
+	v atomic.Value // *groupMemberGenT
+}
+
+type groupMemberGenT struct {
+	memberID   string
+	generation int32
+}
+
+func (g *groupMemberGen) memberID() string {
+	memberID, _ := g.load()
+	return memberID
+}
+
+func (g *groupMemberGen) generation() int32 {
+	_, generation := g.load()
+	return generation
+}
+
+func (g *groupMemberGen) load() (memberID string, generation int32) {
+	v := g.v.Load()
+	if v == nil {
+		return "", -1
+	}
+	t := v.(*groupMemberGenT)
+	return t.memberID, t.generation
+}
+
+func (g *groupMemberGen) store(memberID string, generation int32) {
+	g.v.Store(&groupMemberGenT{memberID, generation})
+}
+
+func (g *groupMemberGen) storeMember(memberID string) {
+	g.store(memberID, g.generation())
+}
+
 // LeaveGroup leaves a group. Close automatically leaves the group, so this is
 // only necessary to call if you plan to leave the group but continue to use
 // the client. If a rebalance is in progress, this function waits for the
@@ -235,12 +269,7 @@ func (cl *Client) GroupMetadata() (string, int32) {
 	if g == nil {
 		return "", -1
 	}
-	g.mu.Lock()
-	defer g.mu.Unlock()
-	if g.memberID == "" {
-		return "", -1
-	}
-	return g.memberID, g.generation
+	return g.memberGen.load()
 }

 func (c *consumer) initGroup() {
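For context on the new type above: the atomic.Value holds an immutable (memberID, generation) pair, so readers such as commits never take g.mu and can never observe a torn mix of an old ID with a new generation. A minimal standalone sketch of the same pattern follows; the memberGen/memberGenT names are illustrative only, not the library's.

package main

import (
	"fmt"
	"sync/atomic"
)

// memberGen mirrors the groupMemberGen idea: one atomic pointer to an
// immutable pair, so a reader sees either the old pair or the new one.
type memberGen struct {
	v atomic.Value // *memberGenT
}

type memberGenT struct {
	memberID   string
	generation int32
}

func (g *memberGen) load() (string, int32) {
	v := g.v.Load()
	if v == nil {
		return "", -1 // defaults before ever joining a group
	}
	t := v.(*memberGenT)
	return t.memberID, t.generation
}

func (g *memberGen) store(memberID string, generation int32) {
	// Always store a freshly allocated pair; mutating a stored value in
	// place would reintroduce the race the atomic pointer removes.
	g.v.Store(&memberGenT{memberID, generation})
}

func main() {
	var g memberGen
	fmt.Println(g.load()) // "" -1 (pre-join defaults)
	g.store("member-abc", 3)
	fmt.Println(g.load()) // member-abc 3
}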
@@ -488,17 +517,18 @@ func (g *groupConsumer) leave(ctx context.Context) {
 		return
 	}

+	memberID := g.memberGen.memberID()
 	g.cfg.logger.Log(LogLevelInfo, "leaving group",
 		"group", g.cfg.group,
-		"member_id", g.memberID, // lock not needed now since nothing can change it (manageDone)
+		"member_id", memberID,
 	)
 	// If we error when leaving, there is not much
 	// we can do. We may as well just return.
 	req := kmsg.NewPtrLeaveGroupRequest()
 	req.Group = g.cfg.group
-	req.MemberID = g.memberID
+	req.MemberID = memberID
 	member := kmsg.NewLeaveGroupRequestMember()
-	member.MemberID = g.memberID
+	member.MemberID = memberID
 	member.Reason = kmsg.StringPtr("client leaving group per normal operation")
 	req.Members = append(req.Members, member)

@@ -940,8 +970,9 @@ func (g *groupConsumer) heartbeat(fetchErrCh <-chan error, s *assignRevokeSessio
 	g.cfg.logger.Log(LogLevelDebug, "heartbeating", "group", g.cfg.group)
 	req := kmsg.NewPtrHeartbeatRequest()
 	req.Group = g.cfg.group
-	req.Generation = g.generation
-	req.MemberID = g.memberID
+	memberID, generation := g.memberGen.load()
+	req.Generation = generation
+	req.MemberID = memberID
 	req.InstanceID = g.cfg.instanceID
 	var resp *kmsg.HeartbeatResponse
 	if resp, err = req.RequestWith(g.ctx, g.cl); err == nil {
@@ -1075,7 +1106,7 @@ start:
 	joinReq.SessionTimeoutMillis = int32(g.cfg.sessionTimeout.Milliseconds())
 	joinReq.RebalanceTimeoutMillis = int32(g.cfg.rebalanceTimeout.Milliseconds())
 	joinReq.ProtocolType = g.cfg.protocol
-	joinReq.MemberID = g.memberID
+	joinReq.MemberID = g.memberGen.memberID()
 	joinReq.InstanceID = g.cfg.instanceID
 	joinReq.Protocols = g.joinGroupProtocols()
 	if joinWhy != "" {
@@ -1120,8 +1151,9 @@ start:

 	syncReq := kmsg.NewPtrSyncGroupRequest()
 	syncReq.Group = g.cfg.group
-	syncReq.Generation = g.generation
-	syncReq.MemberID = g.memberID
+	memberID, generation := g.memberGen.load()
+	syncReq.Generation = generation
+	syncReq.MemberID = memberID
 	syncReq.InstanceID = g.cfg.instanceID
 	syncReq.ProtocolType = &g.cfg.protocol
 	syncReq.Protocol = &protocol
@@ -1168,7 +1200,7 @@ start:
 	// and must trigger a rebalance.
 	if plan != nil && joinResp.SkipAssignment {
 		for _, assign := range plan {
-			if assign.MemberID == g.memberID {
+			if assign.MemberID == memberID {
 				if !bytes.Equal(assign.MemberAssignment, syncResp.MemberAssignment) {
 					g.rejoin("instance group leader restarted and was reassigned old plan, our topic interests changed and we must rejoin to force a rebalance")
 				}
@@ -1184,27 +1216,17 @@ func (g *groupConsumer) handleJoinResp(resp *kmsg.JoinGroupResponse) (restart bo
 	if err = kerr.ErrorForCode(resp.ErrorCode); err != nil {
 		switch err {
 		case kerr.MemberIDRequired:
-			g.mu.Lock()
-			g.memberID = resp.MemberID // KIP-394
-			g.mu.Unlock()
+			g.memberGen.storeMember(resp.MemberID) // KIP-394
 			g.cfg.logger.Log(LogLevelInfo, "join returned MemberIDRequired, rejoining with response's MemberID", "group", g.cfg.group, "member_id", resp.MemberID)
 			return true, "", nil, nil
 		case kerr.UnknownMemberID:
-			g.mu.Lock()
-			g.memberID = ""
-			g.mu.Unlock()
+			g.memberGen.storeMember("")
 			g.cfg.logger.Log(LogLevelInfo, "join returned UnknownMemberID, rejoining without a member id", "group", g.cfg.group)
 			return true, "", nil, nil
 		}
 		return // Request retries as necessary, so this must be a failure
 	}
-
-	// Concurrent committing, while erroneous to do at the moment, could
-	// race with this function. We need to lock setting these two fields.
-	g.mu.Lock()
-	g.memberID = resp.MemberID
-	g.generation = resp.Generation
-	g.mu.Unlock()
+	g.memberGen.store(resp.MemberID, resp.Generation)

 	if resp.Protocol != nil {
 		protocol = *resp.Protocol
@@ -1252,9 +1274,9 @@ func (g *groupConsumer) handleJoinResp(resp *kmsg.JoinGroupResponse) (restart bo
 		g.leader.Store(true)
 		g.cfg.logger.Log(LogLevelInfo, "joined, balancing group",
 			"group", g.cfg.group,
-			"member_id", g.memberID,
+			"member_id", resp.MemberID,
 			"instance_id", strptr{g.cfg.instanceID},
-			"generation", g.generation,
+			"generation", resp.Generation,
 			"balance_protocol", protocol,
 			"leader", true,
 		)
@@ -1263,18 +1285,18 @@
 		g.leader.Store(true)
 		g.cfg.logger.Log(LogLevelInfo, "joined as leader but unable to balance group due to KIP-345 limitations",
 			"group", g.cfg.group,
-			"member_id", g.memberID,
+			"member_id", resp.MemberID,
 			"instance_id", strptr{g.cfg.instanceID},
-			"generation", g.generation,
+			"generation", resp.Generation,
 			"balance_protocol", protocol,
 			"leader", true,
 		)
 	} else {
 		g.cfg.logger.Log(LogLevelInfo, "joined",
 			"group", g.cfg.group,
-			"member_id", g.memberID,
+			"member_id", resp.MemberID,
 			"instance_id", strptr{g.cfg.instanceID},
-			"generation", g.generation,
+			"generation", resp.Generation,
 			"leader", false,
 		)
 	}
@@ -1427,7 +1449,6 @@ func (g *groupConsumer) joinGroupProtocols() []kmsg.JoinGroupRequestProtocol {
 	for t, ps := range g.lastAssigned {
 		lastDup[t] = append([]int32(nil), ps...) // deep copy to allow modifications
 	}
-	gen := g.generation

 	g.mu.Unlock()

@@ -1436,6 +1457,7 @@
 		sort.Slice(partitions, func(i, j int) bool { return partitions[i] < partitions[j] }) // same for partitions
 	}

+	gen := g.memberGen.generation()
 	var protos []kmsg.JoinGroupRequestProtocol
 	for _, balancer := range g.cfg.balancers {
 		proto := kmsg.NewJoinGroupRequestProtocol()
@@ -1931,7 +1953,7 @@ func (g *groupConsumer) updateCommitted(
 	g.mu.Lock()
 	defer g.mu.Unlock()

-	if req.Generation != g.generation {
+	if req.Generation != g.memberGen.generation() {
 		return
 	}
 	if g.uncommitted == nil {
@@ -2764,8 +2786,9 @@ func (g *groupConsumer) commit(

 	req := kmsg.NewPtrOffsetCommitRequest()
 	req.Group = g.cfg.group
-	req.Generation = g.generation
-	req.MemberID = g.memberID
+	memberID, generation := g.memberGen.load()
+	req.Generation = generation
+	req.MemberID = memberID
 	req.InstanceID = g.cfg.instanceID

 	if ctx.Done() != nil {
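One caller-visible effect of the consumer_group.go changes is that Client.GroupMetadata (earlier hunk) now answers from the atomic pair instead of taking the group mutex, so it can be polled from any goroutine, even while commits are in flight. A hedged usage sketch follows; only GroupMetadata itself comes from the diff, while the helper name and the broker/topic/group strings are placeholders.

package main

import (
	"fmt"

	"github.com/twmb/franz-go/pkg/kgo"
)

// logGroupSession is a hypothetical caller-side helper: GroupMetadata reports
// ("", -1) before the client has joined its group, and the current member ID
// and generation afterwards.
func logGroupSession(cl *kgo.Client) {
	memberID, generation := cl.GroupMetadata()
	if generation < 0 {
		fmt.Println("not currently in a group session")
		return
	}
	fmt.Printf("group member %s at generation %d\n", memberID, generation)
}

func main() {
	cl, err := kgo.NewClient(
		kgo.SeedBrokers("localhost:9092"), // assumption: a reachable local broker
		kgo.ConsumerGroup("example-group"),
		kgo.ConsumeTopics("example-topic"),
	)
	if err != nil {
		panic(err)
	}
	defer cl.Close()
	logGroupSession(cl)
}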
diff --git a/pkg/kgo/txn.go b/pkg/kgo/txn.go
index bc1382b0..a4cb8bcf 100644
--- a/pkg/kgo/txn.go
+++ b/pkg/kgo/txn.go
@@ -8,8 +8,9 @@ import (
 	"sync"
 	"time"

-	"github.com/twmb/franz-go/pkg/kerr"
 	"github.com/twmb/franz-go/pkg/kmsg"
+
+	"github.com/twmb/franz-go/pkg/kerr"
 )

 // TransactionEndTry is simply a named bool.
@@ -1060,7 +1061,13 @@ func (cl *Client) commitTransactionOffsets(
 		onDone(kmsg.NewPtrTxnOffsetCommitRequest(), kmsg.NewPtrTxnOffsetCommitResponse(), errNotGroup)
 		return nil
 	}
-	if len(uncommitted) == 0 {
+
+	req, err := g.prepareTxnOffsetCommit(ctx, uncommitted)
+	if err != nil {
+		onDone(req, kmsg.NewPtrTxnOffsetCommitResponse(), err)
+		return g
+	}
+	if len(req.Topics) == 0 {
 		onDone(kmsg.NewPtrTxnOffsetCommitRequest(), kmsg.NewPtrTxnOffsetCommitResponse(), nil)
 		return g
 	}
@@ -1088,7 +1095,7 @@
 	g.mu.Lock()
 	defer g.mu.Unlock()

-	g.commitTxn(ctx, uncommitted, unblockJoinSync)
+	g.commitTxn(ctx, req, unblockJoinSync)

 	return g
 }
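For orientation, commitTransactionOffsets above is the internal path that transactional offset commits ultimately flow through; a consume-transform-produce loop typically exercises it via kgo.GroupTransactSession. A rough end-to-end sketch follows, hedged: the broker address, topics, group, and transactional ID are placeholders, and the session method names should be checked against the kgo version in use.

package main

import (
	"context"

	"github.com/twmb/franz-go/pkg/kgo"
)

func main() {
	// Assumptions: a reachable broker on localhost:9092 and pre-created
	// input/output topics; all names here are placeholders.
	sess, err := kgo.NewGroupTransactSession(
		kgo.SeedBrokers("localhost:9092"),
		kgo.TransactionalID("example-txn-id"),
		kgo.ConsumerGroup("example-group"),
		kgo.ConsumeTopics("input-topic"),
	)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	ctx := context.Background()
	for {
		fetches := sess.PollFetches(ctx)
		if fetches.IsClientClosed() {
			return
		}

		// Build transformed output records for this batch.
		var out []*kgo.Record
		fetches.EachRecord(func(r *kgo.Record) {
			out = append(out, &kgo.Record{Topic: "output-topic", Value: r.Value})
		})

		if err := sess.Begin(); err != nil { // start a transaction for the batch
			panic(err)
		}
		if err := sess.ProduceSync(ctx, out...).FirstErr(); err != nil {
			panic(err)
		}
		// TryCommit flushes, commits the consumed offsets transactionally
		// (the commitTransactionOffsets path above), and ends the transaction.
		if _, err := sess.End(ctx, kgo.TryCommit); err != nil {
			panic(err)
		}
	}
}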
@@ -1139,18 +1146,10 @@ func (cl *Client) addOffsetsToTxn(ctx context.Context, group string) error {
 // commitTxn is ALMOST EXACTLY THE SAME as commit, but changed for txn types
 // and we avoid updateCommitted. We avoid updating because we manually
 // SetOffsets when ending the transaction.
-func (g *groupConsumer) commitTxn(
-	ctx context.Context,
-	uncommitted map[string]map[int32]EpochOffset,
-	onDone func(*kmsg.TxnOffsetCommitRequest, *kmsg.TxnOffsetCommitResponse, error),
-) {
+func (g *groupConsumer) commitTxn(ctx context.Context, req *kmsg.TxnOffsetCommitRequest, onDone func(*kmsg.TxnOffsetCommitRequest, *kmsg.TxnOffsetCommitResponse, error)) {
 	if onDone == nil { // note we must always call onDone
 		onDone = func(_ *kmsg.TxnOffsetCommitRequest, _ *kmsg.TxnOffsetCommitResponse, _ error) {}
 	}
-	if len(uncommitted) == 0 { // only empty if called thru autocommit / default revoke
-		onDone(kmsg.NewPtrTxnOffsetCommitRequest(), kmsg.NewPtrTxnOffsetCommitResponse(), nil)
-		return
-	}

 	if g.commitCancel != nil {
 		g.commitCancel() // cancel any prior commit
@@ -1169,21 +1168,6 @@ func (g *groupConsumer) commitTxn(
 	g.commitCancel = commitCancel
 	g.commitDone = commitDone

-	// We issue this request even if the producer ID is failed; the request
-	// will fail if it is.
-	//
-	// The id must have been set at least once by this point because of
-	// addOffsetsToTxn.
-	id, epoch, _ := g.cl.producerID()
-	req := kmsg.NewPtrTxnOffsetCommitRequest()
-	req.TransactionalID = *g.cl.cfg.txnID
-	req.Group = g.cfg.group
-	req.ProducerID = id
-	req.ProducerEpoch = epoch
-	req.Generation = g.generation
-	req.MemberID = g.memberID
-	req.InstanceID = g.cfg.instanceID
-
 	if ctx.Done() != nil {
 		go func() {
 			select {
@@ -1206,28 +1190,7 @@ func (g *groupConsumer) commitTxn(
 			<-priorDone // wait for any prior request to finish
 		}
 	}
-	g.cl.cfg.logger.Log(LogLevelDebug, "issuing txn offset commit", "uncommitted", uncommitted)
-
-	for topic, partitions := range uncommitted {
-		reqTopic := kmsg.NewTxnOffsetCommitRequestTopic()
-		reqTopic.Topic = topic
-		for partition, eo := range partitions {
-			reqPartition := kmsg.NewTxnOffsetCommitRequestTopicPartition()
-			reqPartition.Partition = partition
-			reqPartition.Offset = eo.Offset
-			reqPartition.LeaderEpoch = eo.Epoch
-			reqPartition.Metadata = &req.MemberID
-			reqTopic.Partitions = append(reqTopic.Partitions, reqPartition)
-		}
-		req.Topics = append(req.Topics, reqTopic)
-	}
-
-	if fn, ok := ctx.Value(txnCommitContextFn).(func(*kmsg.TxnOffsetCommitRequest) error); ok {
-		if err := fn(req); err != nil {
-			onDone(req, nil, err)
-			return
-		}
-	}
+	g.cl.cfg.logger.Log(LogLevelDebug, "issuing txn offset commit", "uncommitted", req)

 	var resp *kmsg.TxnOffsetCommitResponse
 	var err error
@@ -1241,3 +1204,44 @@ func (g *groupConsumer) commitTxn(
 		onDone(req, resp, nil)
 	}()
 }
+
+func (g *groupConsumer) prepareTxnOffsetCommit(ctx context.Context, uncommitted map[string]map[int32]EpochOffset) (*kmsg.TxnOffsetCommitRequest, error) {
+	req := kmsg.NewPtrTxnOffsetCommitRequest()
+
+	// We're now generating the producerID before addOffsetsToTxn.
+	// We will not make this request until after addOffsetsToTxn, but it's possible to fail here due to a failed producerID.
+	id, epoch, err := g.cl.producerID()
+	if err != nil {
+		return req, err
+	}
+
+	req.TransactionalID = *g.cl.cfg.txnID
+	req.Group = g.cfg.group
+	req.ProducerID = id
+	req.ProducerEpoch = epoch
+	memberID, generation := g.memberGen.load()
+	req.Generation = generation
+	req.MemberID = memberID
+	req.InstanceID = g.cfg.instanceID
+
+	for topic, partitions := range uncommitted {
+		reqTopic := kmsg.NewTxnOffsetCommitRequestTopic()
+		reqTopic.Topic = topic
+		for partition, eo := range partitions {
+			reqPartition := kmsg.NewTxnOffsetCommitRequestTopicPartition()
+			reqPartition.Partition = partition
+			reqPartition.Offset = eo.Offset
+			reqPartition.LeaderEpoch = eo.Epoch
+			reqPartition.Metadata = &req.MemberID
+			reqTopic.Partitions = append(reqTopic.Partitions, reqPartition)
+		}
+		req.Topics = append(req.Topics, reqTopic)
+	}
+
+	if fn, ok := ctx.Value(txnCommitContextFn).(func(*kmsg.TxnOffsetCommitRequest) error); ok {
+		if err := fn(req); err != nil {
+			return req, err
+		}
+	}
+	return req, nil
+}
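Note that the ctx.Value(txnCommitContextFn) hook survives the move into prepareTxnOffsetCommit, so request inspection still runs before the commit is issued. A package-internal sketch of how that hook is fed (txnCommitContextFn is unexported, so something like this only compiles inside package kgo, e.g. in a test); the helper name is hypothetical and the snippet assumes the package's existing context, errors, and kmsg imports.

// withTxnCommitCheck is a hypothetical package-internal helper: it stores a
// callback under txnCommitContextFn, which prepareTxnOffsetCommit looks up
// and runs against the built request, aborting the commit if it errors.
func withTxnCommitCheck(ctx context.Context) context.Context {
	return context.WithValue(ctx, txnCommitContextFn,
		func(req *kmsg.TxnOffsetCommitRequest) error {
			if len(req.Topics) == 0 {
				return errors.New("refusing to send an empty txn offset commit")
			}
			return nil
		},
	)
}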