Merge pull request #605 from twmb/580
kgo: allow PreTxnCommitFnContext to modify empty offsets
twmb authored Oct 22, 2023
2 parents d156322 + 54a7418 commit ec02fac
Showing 2 changed files with 121 additions and 94 deletions.
113 changes: 68 additions & 45 deletions pkg/kgo/consumer_group.go
@@ -126,14 +126,12 @@ type groupConsumer struct {
uncommitted uncommitted

// memberID and generation are written to in the join and sync loop,
// and mostly read within that loop. The reason these two are under the
// mutex is because they are read during commits, which can happen at
// any arbitrary moment. It is **recommended** to be done within the
// context of a group session, but (a) users may have some unique use
// cases, and (b) the onRevoke hook may take longer than a user
// and mostly read within that loop. This can be read during commits,
// which can happen at any time. It is **recommended** to be done within
// the context of a group session, but (a) users may have some unique
// use cases, and (b) the onRevoke hook may take longer than a user
// expects, which would rotate a session.
memberID string
generation int32
memberGen groupMemberGen

// commitCancel and commitDone are set under mu before firing off an
// async commit request. If another commit happens, it cancels the
@@ -155,6 +153,42 @@ type groupConsumer struct {
leaveErr error // set before left is closed
}

type groupMemberGen struct {
v atomic.Value // *groupMemberGenT
}

type groupMemberGenT struct {
memberID string
generation int32
}

func (g *groupMemberGen) memberID() string {
memberID, _ := g.load()
return memberID
}

func (g *groupMemberGen) generation() int32 {
_, generation := g.load()
return generation
}

func (g *groupMemberGen) load() (memberID string, generation int32) {
v := g.v.Load()
if v == nil {
return "", -1
}
t := v.(*groupMemberGenT)
return t.memberID, t.generation
}

func (g *groupMemberGen) store(memberID string, generation int32) {
g.v.Store(&groupMemberGenT{memberID, generation})
}

func (g *groupMemberGen) storeMember(memberID string) {
g.store(memberID, g.generation())
}

// LeaveGroup leaves a group. Close automatically leaves the group, so this is
// only necessary to call if you plan to leave the group but continue to use
// the client. If a rebalance is in progress, this function waits for the
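
The hunk above replaces the mutex-guarded memberID/generation pair with a single memberGen field: an atomic.Value that always holds an immutable *groupMemberGenT snapshot, so commits and heartbeats can read a consistent pair without taking g.mu. A minimal, self-contained sketch of the same pattern (names here are illustrative, not part of kgo):

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// memberGenT is the immutable snapshot; every update stores a fresh pointer,
// so readers can never observe a half-written member ID / generation pair.
type memberGenT struct {
	memberID   string
	generation int32
}

type memberGen struct{ v atomic.Value } // holds *memberGenT

func (m *memberGen) load() (string, int32) {
	v := m.v.Load()
	if v == nil {
		return "", -1 // same "not yet joined" sentinel the diff uses
	}
	t := v.(*memberGenT)
	return t.memberID, t.generation
}

func (m *memberGen) store(id string, gen int32) { m.v.Store(&memberGenT{id, gen}) }

func main() {
	var mg memberGen
	fmt.Println(mg.load()) // "" -1 before the first store
	mg.store("member-abc", 3)
	id, gen := mg.load() // safe from any goroutine, no mutex required
	fmt.Println(id, gen)
}
```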
@@ -235,12 +269,7 @@ func (cl *Client) GroupMetadata() (string, int32) {
if g == nil {
return "", -1
}
g.mu.Lock()
defer g.mu.Unlock()
if g.memberID == "" {
return "", -1
}
return g.memberID, g.generation
return g.memberGen.load()
}

func (c *consumer) initGroup() {
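
With the snapshot in place, GroupMetadata becomes a single lock-free load; the empty-member special case is gone because load() already returns "", -1 until the first join. A hedged usage sketch; the seed broker, group, and topic below are placeholders, not values from this commit:

```go
package main

import (
	"fmt"

	"github.com/twmb/franz-go/pkg/kgo"
)

func main() {
	// Placeholder configuration for illustration only.
	cl, err := kgo.NewClient(
		kgo.SeedBrokers("localhost:9092"),
		kgo.ConsumerGroup("my-group"),
		kgo.ConsumeTopics("my-topic"),
	)
	if err != nil {
		panic(err)
	}
	defer cl.Close()

	// Returns "", -1 until the client has joined a group (or if no
	// group is configured), matching the load() sentinel above.
	memberID, generation := cl.GroupMetadata()
	fmt.Println(memberID, generation)
}
```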
@@ -488,17 +517,18 @@ func (g *groupConsumer) leave(ctx context.Context) {
return
}

memberID := g.memberGen.memberID()
g.cfg.logger.Log(LogLevelInfo, "leaving group",
"group", g.cfg.group,
"member_id", g.memberID, // lock not needed now since nothing can change it (manageDone)
"member_id", memberID,
)
// If we error when leaving, there is not much
// we can do. We may as well just return.
req := kmsg.NewPtrLeaveGroupRequest()
req.Group = g.cfg.group
req.MemberID = g.memberID
req.MemberID = memberID
member := kmsg.NewLeaveGroupRequestMember()
member.MemberID = g.memberID
member.MemberID = memberID
member.Reason = kmsg.StringPtr("client leaving group per normal operation")
req.Members = append(req.Members, member)

@@ -940,8 +970,9 @@ func (g *groupConsumer) heartbeat(fetchErrCh <-chan error, s *assignRevokeSessio
g.cfg.logger.Log(LogLevelDebug, "heartbeating", "group", g.cfg.group)
req := kmsg.NewPtrHeartbeatRequest()
req.Group = g.cfg.group
req.Generation = g.generation
req.MemberID = g.memberID
memberID, generation := g.memberGen.load()
req.Generation = generation
req.MemberID = memberID
req.InstanceID = g.cfg.instanceID
var resp *kmsg.HeartbeatResponse
if resp, err = req.RequestWith(g.ctx, g.cl); err == nil {
@@ -1075,7 +1106,7 @@ start:
joinReq.SessionTimeoutMillis = int32(g.cfg.sessionTimeout.Milliseconds())
joinReq.RebalanceTimeoutMillis = int32(g.cfg.rebalanceTimeout.Milliseconds())
joinReq.ProtocolType = g.cfg.protocol
joinReq.MemberID = g.memberID
joinReq.MemberID = g.memberGen.memberID()
joinReq.InstanceID = g.cfg.instanceID
joinReq.Protocols = g.joinGroupProtocols()
if joinWhy != "" {
Expand Down Expand Up @@ -1120,8 +1151,9 @@ start:

syncReq := kmsg.NewPtrSyncGroupRequest()
syncReq.Group = g.cfg.group
syncReq.Generation = g.generation
syncReq.MemberID = g.memberID
memberID, generation := g.memberGen.load()
syncReq.Generation = generation
syncReq.MemberID = memberID
syncReq.InstanceID = g.cfg.instanceID
syncReq.ProtocolType = &g.cfg.protocol
syncReq.Protocol = &protocol
@@ -1168,7 +1200,7 @@ start:
// and must trigger a rebalance.
if plan != nil && joinResp.SkipAssignment {
for _, assign := range plan {
if assign.MemberID == g.memberID {
if assign.MemberID == memberID {
if !bytes.Equal(assign.MemberAssignment, syncResp.MemberAssignment) {
g.rejoin("instance group leader restarted and was reassigned old plan, our topic interests changed and we must rejoin to force a rebalance")
}
@@ -1184,27 +1216,17 @@ func (g *groupConsumer) handleJoinResp(resp *kmsg.JoinGroupResponse) (restart bo
if err = kerr.ErrorForCode(resp.ErrorCode); err != nil {
switch err {
case kerr.MemberIDRequired:
g.mu.Lock()
g.memberID = resp.MemberID // KIP-394
g.mu.Unlock()
g.memberGen.storeMember(resp.MemberID) // KIP-394
g.cfg.logger.Log(LogLevelInfo, "join returned MemberIDRequired, rejoining with response's MemberID", "group", g.cfg.group, "member_id", resp.MemberID)
return true, "", nil, nil
case kerr.UnknownMemberID:
g.mu.Lock()
g.memberID = ""
g.mu.Unlock()
g.memberGen.storeMember("")
g.cfg.logger.Log(LogLevelInfo, "join returned UnknownMemberID, rejoining without a member id", "group", g.cfg.group)
return true, "", nil, nil
}
return // Request retries as necessary, so this must be a failure
}

// Concurrent committing, while erroneous to do at the moment, could
// race with this function. We need to lock setting these two fields.
g.mu.Lock()
g.memberID = resp.MemberID
g.generation = resp.Generation
g.mu.Unlock()
g.memberGen.store(resp.MemberID, resp.Generation)

if resp.Protocol != nil {
protocol = *resp.Protocol
Expand Down Expand Up @@ -1252,9 +1274,9 @@ func (g *groupConsumer) handleJoinResp(resp *kmsg.JoinGroupResponse) (restart bo
g.leader.Store(true)
g.cfg.logger.Log(LogLevelInfo, "joined, balancing group",
"group", g.cfg.group,
"member_id", g.memberID,
"member_id", resp.MemberID,
"instance_id", strptr{g.cfg.instanceID},
"generation", g.generation,
"generation", resp.Generation,
"balance_protocol", protocol,
"leader", true,
)
@@ -1263,18 +1285,18 @@ func (g *groupConsumer) handleJoinResp(resp *kmsg.JoinGroupResponse) (restart bo
g.leader.Store(true)
g.cfg.logger.Log(LogLevelInfo, "joined as leader but unable to balance group due to KIP-345 limitations",
"group", g.cfg.group,
"member_id", g.memberID,
"member_id", resp.MemberID,
"instance_id", strptr{g.cfg.instanceID},
"generation", g.generation,
"generation", resp.Generation,
"balance_protocol", protocol,
"leader", true,
)
} else {
g.cfg.logger.Log(LogLevelInfo, "joined",
"group", g.cfg.group,
"member_id", g.memberID,
"member_id", resp.MemberID,
"instance_id", strptr{g.cfg.instanceID},
"generation", g.generation,
"generation", resp.Generation,
"leader", false,
)
}
@@ -1427,7 +1449,6 @@ func (g *groupConsumer) joinGroupProtocols() []kmsg.JoinGroupRequestProtocol {
for t, ps := range g.lastAssigned {
lastDup[t] = append([]int32(nil), ps...) // deep copy to allow modifications
}
gen := g.generation

g.mu.Unlock()

Expand All @@ -1436,6 +1457,7 @@ func (g *groupConsumer) joinGroupProtocols() []kmsg.JoinGroupRequestProtocol {
sort.Slice(partitions, func(i, j int) bool { return partitions[i] < partitions[j] }) // same for partitions
}

gen := g.memberGen.generation()
var protos []kmsg.JoinGroupRequestProtocol
for _, balancer := range g.cfg.balancers {
proto := kmsg.NewJoinGroupRequestProtocol()
@@ -1931,7 +1953,7 @@ func (g *groupConsumer) updateCommitted(
g.mu.Lock()
defer g.mu.Unlock()

if req.Generation != g.generation {
if req.Generation != g.memberGen.generation() {
return
}
if g.uncommitted == nil {
@@ -2764,8 +2786,9 @@ func (g *groupConsumer) commit(

req := kmsg.NewPtrOffsetCommitRequest()
req.Group = g.cfg.group
req.Generation = g.generation
req.MemberID = g.memberID
memberID, generation := g.memberGen.load()
req.Generation = generation
req.MemberID = memberID
req.InstanceID = g.cfg.instanceID

if ctx.Done() != nil {
102 changes: 53 additions & 49 deletions pkg/kgo/txn.go
@@ -8,8 +8,9 @@ import (
"sync"
"time"

"github.com/twmb/franz-go/pkg/kerr"
"github.com/twmb/franz-go/pkg/kmsg"

"github.com/twmb/franz-go/pkg/kerr"
)

// TransactionEndTry is simply a named bool.
@@ -1060,7 +1061,13 @@ func (cl *Client) commitTransactionOffsets(
onDone(kmsg.NewPtrTxnOffsetCommitRequest(), kmsg.NewPtrTxnOffsetCommitResponse(), errNotGroup)
return nil
}
if len(uncommitted) == 0 {

req, err := g.prepareTxnOffsetCommit(ctx, uncommitted)
if err != nil {
onDone(req, kmsg.NewPtrTxnOffsetCommitResponse(), err)
return g
}
if len(req.Topics) == 0 {
onDone(kmsg.NewPtrTxnOffsetCommitRequest(), kmsg.NewPtrTxnOffsetCommitResponse(), nil)
return g
}
Expand Down Expand Up @@ -1088,7 +1095,7 @@ func (cl *Client) commitTransactionOffsets(
g.mu.Lock()
defer g.mu.Unlock()

g.commitTxn(ctx, uncommitted, unblockJoinSync)
g.commitTxn(ctx, req, unblockJoinSync)
return g
}
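
This reordering is the change the commit title describes: the TxnOffsetCommit request is now built, and the PreTxnCommitFnContext hook run, inside prepareTxnOffsetCommit before the emptiness check, so a hook can add topics to an otherwise-empty transactional commit (previously an empty uncommitted map short-circuited before the hook ever ran). A hedged sketch of what a caller might do; it assumes the hook signature func(*kmsg.TxnOffsetCommitRequest) error implied by the type assertion in prepareTxnOffsetCommit, assumes the context given to GroupTransactSession.End reaches the offset-commit path, and uses placeholder brokers, topics, and offsets:

```go
package main

import (
	"context"

	"github.com/twmb/franz-go/pkg/kgo"
	"github.com/twmb/franz-go/pkg/kmsg"
)

func main() {
	// Placeholder transactional session; produce/consume steps are elided.
	sess, err := kgo.NewGroupTransactSession(
		kgo.SeedBrokers("localhost:9092"),
		kgo.TransactionalID("my-txn-id"),
		kgo.ConsumerGroup("my-group"),
		kgo.ConsumeTopics("my-topic"),
	)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	// The hook now also fires for commits that start out empty, so it can
	// inject offsets tracked outside of kgo's own uncommitted bookkeeping.
	ctx := kgo.PreTxnCommitFnContext(context.Background(), func(req *kmsg.TxnOffsetCommitRequest) error {
		if len(req.Topics) == 0 {
			t := kmsg.NewTxnOffsetCommitRequestTopic()
			t.Topic = "side-channel-topic" // placeholder topic
			p := kmsg.NewTxnOffsetCommitRequestTopicPartition()
			p.Partition = 0
			p.Offset = 42 // placeholder offset
			t.Partitions = append(t.Partitions, p)
			req.Topics = append(req.Topics, t)
		}
		return nil
	})

	if err := sess.Begin(); err != nil {
		panic(err)
	}
	// ... transactional produces would go here ...
	if _, err := sess.End(ctx, kgo.TryCommit); err != nil {
		panic(err)
	}
}
```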

@@ -1139,18 +1146,10 @@ func (cl *Client) addOffsetsToTxn(ctx context.Context, group string) error {
// commitTxn is ALMOST EXACTLY THE SAME as commit, but changed for txn types
// and we avoid updateCommitted. We avoid updating because we manually
// SetOffsets when ending the transaction.
func (g *groupConsumer) commitTxn(
ctx context.Context,
uncommitted map[string]map[int32]EpochOffset,
onDone func(*kmsg.TxnOffsetCommitRequest, *kmsg.TxnOffsetCommitResponse, error),
) {
func (g *groupConsumer) commitTxn(ctx context.Context, req *kmsg.TxnOffsetCommitRequest, onDone func(*kmsg.TxnOffsetCommitRequest, *kmsg.TxnOffsetCommitResponse, error)) {
if onDone == nil { // note we must always call onDone
onDone = func(_ *kmsg.TxnOffsetCommitRequest, _ *kmsg.TxnOffsetCommitResponse, _ error) {}
}
if len(uncommitted) == 0 { // only empty if called thru autocommit / default revoke
onDone(kmsg.NewPtrTxnOffsetCommitRequest(), kmsg.NewPtrTxnOffsetCommitResponse(), nil)
return
}

if g.commitCancel != nil {
g.commitCancel() // cancel any prior commit
@@ -1169,21 +1168,6 @@ func (g *groupConsumer) commitTxn(
g.commitCancel = commitCancel
g.commitDone = commitDone

// We issue this request even if the producer ID is failed; the request
// will fail if it is.
//
// The id must have been set at least once by this point because of
// addOffsetsToTxn.
id, epoch, _ := g.cl.producerID()
req := kmsg.NewPtrTxnOffsetCommitRequest()
req.TransactionalID = *g.cl.cfg.txnID
req.Group = g.cfg.group
req.ProducerID = id
req.ProducerEpoch = epoch
req.Generation = g.generation
req.MemberID = g.memberID
req.InstanceID = g.cfg.instanceID

if ctx.Done() != nil {
go func() {
select {
@@ -1206,28 +1190,7 @@ func (g *groupConsumer) commitTxn(
<-priorDone // wait for any prior request to finish
}
}
g.cl.cfg.logger.Log(LogLevelDebug, "issuing txn offset commit", "uncommitted", uncommitted)

for topic, partitions := range uncommitted {
reqTopic := kmsg.NewTxnOffsetCommitRequestTopic()
reqTopic.Topic = topic
for partition, eo := range partitions {
reqPartition := kmsg.NewTxnOffsetCommitRequestTopicPartition()
reqPartition.Partition = partition
reqPartition.Offset = eo.Offset
reqPartition.LeaderEpoch = eo.Epoch
reqPartition.Metadata = &req.MemberID
reqTopic.Partitions = append(reqTopic.Partitions, reqPartition)
}
req.Topics = append(req.Topics, reqTopic)
}

if fn, ok := ctx.Value(txnCommitContextFn).(func(*kmsg.TxnOffsetCommitRequest) error); ok {
if err := fn(req); err != nil {
onDone(req, nil, err)
return
}
}
g.cl.cfg.logger.Log(LogLevelDebug, "issuing txn offset commit", "uncommitted", req)

var resp *kmsg.TxnOffsetCommitResponse
var err error
@@ -1241,3 +1204,44 @@
onDone(req, resp, nil)
}()
}

func (g *groupConsumer) prepareTxnOffsetCommit(ctx context.Context, uncommitted map[string]map[int32]EpochOffset) (*kmsg.TxnOffsetCommitRequest, error) {
req := kmsg.NewPtrTxnOffsetCommitRequest()

// We're now generating the producerID before addOffsetsToTxn.
// We will not make this request until after addOffsetsToTxn, but it's possible to fail here due to a failed producerID.
id, epoch, err := g.cl.producerID()
if err != nil {
return req, err
}

req.TransactionalID = *g.cl.cfg.txnID
req.Group = g.cfg.group
req.ProducerID = id
req.ProducerEpoch = epoch
memberID, generation := g.memberGen.load()
req.Generation = generation
req.MemberID = memberID
req.InstanceID = g.cfg.instanceID

for topic, partitions := range uncommitted {
reqTopic := kmsg.NewTxnOffsetCommitRequestTopic()
reqTopic.Topic = topic
for partition, eo := range partitions {
reqPartition := kmsg.NewTxnOffsetCommitRequestTopicPartition()
reqPartition.Partition = partition
reqPartition.Offset = eo.Offset
reqPartition.LeaderEpoch = eo.Epoch
reqPartition.Metadata = &req.MemberID
reqTopic.Partitions = append(reqTopic.Partitions, reqPartition)
}
req.Topics = append(req.Topics, reqTopic)
}

if fn, ok := ctx.Value(txnCommitContextFn).(func(*kmsg.TxnOffsetCommitRequest) error); ok {
if err := fn(req); err != nil {
return req, err
}
}
return req, nil
}
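
Because the hook now runs inside prepareTxnOffsetCommit and its error is returned, a hook can also veto or decorate the final request: commitTransactionOffsets hands any hook error straight to onDone without issuing the commit. A small illustrative helper under the same signature assumption as above; withCommitGuard and its parameters are hypothetical, not part of kgo:

```go
package main

import (
	"context"
	"fmt"

	"github.com/twmb/franz-go/pkg/kgo"
	"github.com/twmb/franz-go/pkg/kmsg"
)

// withCommitGuard stamps per-partition metadata onto a transactional offset
// commit and vetoes oversized requests by returning an error, which aborts
// the commit before any request is sent.
func withCommitGuard(ctx context.Context, meta string, maxPartitions int) context.Context {
	return kgo.PreTxnCommitFnContext(ctx, func(req *kmsg.TxnOffsetCommitRequest) error {
		n := 0
		for i := range req.Topics {
			for j := range req.Topics[i].Partitions {
				req.Topics[i].Partitions[j].Metadata = &meta
				n++
			}
		}
		if n > maxPartitions {
			return fmt.Errorf("refusing to commit %d partitions in one transaction", n)
		}
		return nil
	})
}

func main() {
	ctx := withCommitGuard(context.Background(), "svc=checkout", 1000)
	_ = ctx // pass to the transactional commit path, e.g. sess.End(ctx, kgo.TryCommit)
}
```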
