kgo: allow PreTxnCommitFnContext to modify empty offsets #605

Merged · 2 commits · Oct 22, 2023 (diff shows changes from all commits)
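What the change does: commitTransactionOffsets used to return early when there were no uncommitted offsets, so a PreTxnCommitFnContext hook never ran and could not add offsets of its own. The request is now built first via a new prepareTxnOffsetCommit helper (which runs the hook), and the commit is skipped only if the request is still empty afterwards. A minimal usage sketch, assuming franz-go's PreTxnCommitFnContext and GroupTransactSession APIs; the topic name and offset are hypothetical:

```go
package main

import (
	"context"

	"github.com/twmb/franz-go/pkg/kgo"
	"github.com/twmb/franz-go/pkg/kmsg"
)

// endWithInjectedOffset ends a transact session; with this PR, the pre-commit
// hook runs even when nothing was consumed, so the offset it adds is committed.
func endWithInjectedOffset(sess *kgo.GroupTransactSession) (bool, error) {
	ctx := kgo.PreTxnCommitFnContext(context.Background(), func(req *kmsg.TxnOffsetCommitRequest) error {
		t := kmsg.NewTxnOffsetCommitRequestTopic()
		t.Topic = "progress-marker" // hypothetical topic
		p := kmsg.NewTxnOffsetCommitRequestTopicPartition()
		p.Partition = 0
		p.Offset = 42      // hypothetical offset to record transactionally
		p.LeaderEpoch = -1 // epoch unknown
		t.Partitions = append(t.Partitions, p)
		req.Topics = append(req.Topics, t)
		return nil
	})
	return sess.End(ctx, kgo.TryCommit)
}
```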
113 changes: 68 additions & 45 deletions pkg/kgo/consumer_group.go
@@ -126,14 +126,12 @@ type groupConsumer struct {
uncommitted uncommitted

// memberID and generation are written to in the join and sync loop,
// and mostly read within that loop. The reason these two are under the
// mutex is because they are read during commits, which can happen at
// any arbitrary moment. It is **recommended** to be done within the
// context of a group session, but (a) users may have some unique use
// cases, and (b) the onRevoke hook may take longer than a user
// and mostly read within that loop. This can be read during commits,
// which can happen at any time. It is **recommended** to be done within
// the context of a group session, but (a) users may have some unique
// use cases, and (b) the onRevoke hook may take longer than a user
// expects, which would rotate a session.
memberID string
generation int32
memberGen groupMemberGen

// commitCancel and commitDone are set under mu before firing off an
// async commit request. If another commit happens, it cancels the
@@ -155,6 +153,42 @@ type groupConsumer struct {
leaveErr error // set before left is closed
}

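// groupMemberGen stores the member ID and generation together as a single
// immutable value, so commits and heartbeats can read a consistent pair
// without holding the group mutex.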
type groupMemberGen struct {
v atomic.Value // *groupMemberGenT
}

type groupMemberGenT struct {
memberID string
generation int32
}

func (g *groupMemberGen) memberID() string {
memberID, _ := g.load()
return memberID
}

func (g *groupMemberGen) generation() int32 {
_, generation := g.load()
return generation
}

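// load returns the last-stored pair, or "" and -1 if nothing has been
// stored yet (i.e., the group has not joined).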
func (g *groupMemberGen) load() (memberID string, generation int32) {
v := g.v.Load()
if v == nil {
return "", -1
}
t := v.(*groupMemberGenT)
return t.memberID, t.generation
}

func (g *groupMemberGen) store(memberID string, generation int32) {
g.v.Store(&groupMemberGenT{memberID, generation})
}

func (g *groupMemberGen) storeMember(memberID string) {
g.store(memberID, g.generation())
}

// LeaveGroup leaves a group. Close automatically leaves the group, so this is
// only necessary to call if you plan to leave the group but continue to use
// the client. If a rebalance is in progress, this function waits for the
Expand Down Expand Up @@ -235,12 +269,7 @@ func (cl *Client) GroupMetadata() (string, int32) {
if g == nil {
return "", -1
}
g.mu.Lock()
defer g.mu.Unlock()
if g.memberID == "" {
return "", -1
}
return g.memberID, g.generation
return g.memberGen.load()
}

func (c *consumer) initGroup() {
@@ -488,17 +517,18 @@ func (g *groupConsumer) leave(ctx context.Context) {
return
}

memberID := g.memberGen.memberID()
g.cfg.logger.Log(LogLevelInfo, "leaving group",
"group", g.cfg.group,
"member_id", g.memberID, // lock not needed now since nothing can change it (manageDone)
"member_id", memberID,
)
// If we error when leaving, there is not much
// we can do. We may as well just return.
req := kmsg.NewPtrLeaveGroupRequest()
req.Group = g.cfg.group
req.MemberID = g.memberID
req.MemberID = memberID
member := kmsg.NewLeaveGroupRequestMember()
member.MemberID = g.memberID
member.MemberID = memberID
member.Reason = kmsg.StringPtr("client leaving group per normal operation")
req.Members = append(req.Members, member)

@@ -940,8 +970,9 @@ func (g *groupConsumer) heartbeat(fetchErrCh <-chan error, s *assignRevokeSessio
g.cfg.logger.Log(LogLevelDebug, "heartbeating", "group", g.cfg.group)
req := kmsg.NewPtrHeartbeatRequest()
req.Group = g.cfg.group
req.Generation = g.generation
req.MemberID = g.memberID
memberID, generation := g.memberGen.load()
req.Generation = generation
req.MemberID = memberID
req.InstanceID = g.cfg.instanceID
var resp *kmsg.HeartbeatResponse
if resp, err = req.RequestWith(g.ctx, g.cl); err == nil {
@@ -1075,7 +1106,7 @@ start:
joinReq.SessionTimeoutMillis = int32(g.cfg.sessionTimeout.Milliseconds())
joinReq.RebalanceTimeoutMillis = int32(g.cfg.rebalanceTimeout.Milliseconds())
joinReq.ProtocolType = g.cfg.protocol
joinReq.MemberID = g.memberID
joinReq.MemberID = g.memberGen.memberID()
joinReq.InstanceID = g.cfg.instanceID
joinReq.Protocols = g.joinGroupProtocols()
if joinWhy != "" {
@@ -1120,8 +1151,9 @@ start:

syncReq := kmsg.NewPtrSyncGroupRequest()
syncReq.Group = g.cfg.group
syncReq.Generation = g.generation
syncReq.MemberID = g.memberID
memberID, generation := g.memberGen.load()
syncReq.Generation = generation
syncReq.MemberID = memberID
syncReq.InstanceID = g.cfg.instanceID
syncReq.ProtocolType = &g.cfg.protocol
syncReq.Protocol = &protocol
@@ -1168,7 +1200,7 @@ start:
// and must trigger a rebalance.
if plan != nil && joinResp.SkipAssignment {
for _, assign := range plan {
if assign.MemberID == g.memberID {
if assign.MemberID == memberID {
if !bytes.Equal(assign.MemberAssignment, syncResp.MemberAssignment) {
g.rejoin("instance group leader restarted and was reassigned old plan, our topic interests changed and we must rejoin to force a rebalance")
}
@@ -1184,27 +1216,17 @@ func (g *groupConsumer) handleJoinResp(resp *kmsg.JoinGroupResponse) (restart bo
if err = kerr.ErrorForCode(resp.ErrorCode); err != nil {
switch err {
case kerr.MemberIDRequired:
g.mu.Lock()
g.memberID = resp.MemberID // KIP-394
g.mu.Unlock()
g.memberGen.storeMember(resp.MemberID) // KIP-394
g.cfg.logger.Log(LogLevelInfo, "join returned MemberIDRequired, rejoining with response's MemberID", "group", g.cfg.group, "member_id", resp.MemberID)
return true, "", nil, nil
case kerr.UnknownMemberID:
g.mu.Lock()
g.memberID = ""
g.mu.Unlock()
g.memberGen.storeMember("")
g.cfg.logger.Log(LogLevelInfo, "join returned UnknownMemberID, rejoining without a member id", "group", g.cfg.group)
return true, "", nil, nil
}
return // Request retries as necessary, so this must be a failure
}

// Concurrent committing, while erroneous to do at the moment, could
// race with this function. We need to lock setting these two fields.
g.mu.Lock()
g.memberID = resp.MemberID
g.generation = resp.Generation
g.mu.Unlock()
g.memberGen.store(resp.MemberID, resp.Generation)

if resp.Protocol != nil {
protocol = *resp.Protocol
@@ -1252,9 +1274,9 @@ func (g *groupConsumer) handleJoinResp(resp *kmsg.JoinGroupResponse) (restart bo
g.leader.Store(true)
g.cfg.logger.Log(LogLevelInfo, "joined, balancing group",
"group", g.cfg.group,
"member_id", g.memberID,
"member_id", resp.MemberID,
"instance_id", strptr{g.cfg.instanceID},
"generation", g.generation,
"generation", resp.Generation,
"balance_protocol", protocol,
"leader", true,
)
@@ -1263,18 +1285,18 @@ func (g *groupConsumer) handleJoinResp(resp *kmsg.JoinGroupResponse) (restart bo
g.leader.Store(true)
g.cfg.logger.Log(LogLevelInfo, "joined as leader but unable to balance group due to KIP-345 limitations",
"group", g.cfg.group,
"member_id", g.memberID,
"member_id", resp.MemberID,
"instance_id", strptr{g.cfg.instanceID},
"generation", g.generation,
"generation", resp.Generation,
"balance_protocol", protocol,
"leader", true,
)
} else {
g.cfg.logger.Log(LogLevelInfo, "joined",
"group", g.cfg.group,
"member_id", g.memberID,
"member_id", resp.MemberID,
"instance_id", strptr{g.cfg.instanceID},
"generation", g.generation,
"generation", resp.Generation,
"leader", false,
)
}
@@ -1427,7 +1449,6 @@ func (g *groupConsumer) joinGroupProtocols() []kmsg.JoinGroupRequestProtocol {
for t, ps := range g.lastAssigned {
lastDup[t] = append([]int32(nil), ps...) // deep copy to allow modifications
}
gen := g.generation

g.mu.Unlock()

@@ -1436,6 +1457,7 @@ func (g *groupConsumer) joinGroupProtocols() []kmsg.JoinGroupRequestProtocol {
sort.Slice(partitions, func(i, j int) bool { return partitions[i] < partitions[j] }) // same for partitions
}

gen := g.memberGen.generation()
var protos []kmsg.JoinGroupRequestProtocol
for _, balancer := range g.cfg.balancers {
proto := kmsg.NewJoinGroupRequestProtocol()
@@ -1931,7 +1953,7 @@ func (g *groupConsumer) updateCommitted(
g.mu.Lock()
defer g.mu.Unlock()

if req.Generation != g.generation {
if req.Generation != g.memberGen.generation() {
return
}
if g.uncommitted == nil {
@@ -2764,8 +2786,9 @@ func (g *groupConsumer) commit(

req := kmsg.NewPtrOffsetCommitRequest()
req.Group = g.cfg.group
req.Generation = g.generation
req.MemberID = g.memberID
memberID, generation := g.memberGen.load()
req.Generation = generation
req.MemberID = memberID
req.InstanceID = g.cfg.instanceID

if ctx.Done() != nil {
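The consumer_group.go changes above replace the mutex-guarded memberID/generation fields with a single atomic.Value holding an immutable pair, so commits, heartbeats, joins, and GroupMetadata read a consistent snapshot without taking g.mu. A self-contained sketch of that pattern (hypothetical names, not the franz-go types):

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// pair is stored behind an atomic.Value as an immutable pointer: writers
// swap in a whole new pair, so readers never see a half-updated ID/generation.
type pair struct {
	memberID   string
	generation int32
}

type memberGen struct{ v atomic.Value }

func (m *memberGen) load() (string, int32) {
	p, ok := m.v.Load().(*pair)
	if !ok {
		return "", -1 // same sentinel the PR returns before the first join
	}
	return p.memberID, p.generation
}

func (m *memberGen) store(id string, gen int32) { m.v.Store(&pair{id, gen}) }

func main() {
	var m memberGen
	fmt.Println(m.load()) // "" -1 (unset)
	m.store("member-abc", 7)
	fmt.Println(m.load()) // member-abc 7
}
```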
102 changes: 53 additions & 49 deletions pkg/kgo/txn.go
@@ -8,8 +8,9 @@ import (
"sync"
"time"

"github.com/twmb/franz-go/pkg/kerr"
"github.com/twmb/franz-go/pkg/kmsg"

"github.com/twmb/franz-go/pkg/kerr"
)

// TransactionEndTry is simply a named bool.
@@ -1060,7 +1061,13 @@ func (cl *Client) commitTransactionOffsets(
onDone(kmsg.NewPtrTxnOffsetCommitRequest(), kmsg.NewPtrTxnOffsetCommitResponse(), errNotGroup)
return nil
}
if len(uncommitted) == 0 {

req, err := g.prepareTxnOffsetCommit(ctx, uncommitted)
if err != nil {
onDone(req, kmsg.NewPtrTxnOffsetCommitResponse(), err)
return g
}
if len(req.Topics) == 0 {
onDone(kmsg.NewPtrTxnOffsetCommitRequest(), kmsg.NewPtrTxnOffsetCommitResponse(), nil)
return g
}
@@ -1088,7 +1095,7 @@ g.mu.Lock()
g.mu.Lock()
defer g.mu.Unlock()

g.commitTxn(ctx, uncommitted, unblockJoinSync)
g.commitTxn(ctx, req, unblockJoinSync)
return g
}

@@ -1139,18 +1146,10 @@ func (cl *Client) addOffsetsToTxn(ctx context.Context, group string) error {
// commitTxn is ALMOST EXACTLY THE SAME as commit, but changed for txn types
// and we avoid updateCommitted. We avoid updating because we manually
// SetOffsets when ending the transaction.
func (g *groupConsumer) commitTxn(
ctx context.Context,
uncommitted map[string]map[int32]EpochOffset,
onDone func(*kmsg.TxnOffsetCommitRequest, *kmsg.TxnOffsetCommitResponse, error),
) {
func (g *groupConsumer) commitTxn(ctx context.Context, req *kmsg.TxnOffsetCommitRequest, onDone func(*kmsg.TxnOffsetCommitRequest, *kmsg.TxnOffsetCommitResponse, error)) {
if onDone == nil { // note we must always call onDone
onDone = func(_ *kmsg.TxnOffsetCommitRequest, _ *kmsg.TxnOffsetCommitResponse, _ error) {}
}
if len(uncommitted) == 0 { // only empty if called thru autocommit / default revoke
onDone(kmsg.NewPtrTxnOffsetCommitRequest(), kmsg.NewPtrTxnOffsetCommitResponse(), nil)
return
}

if g.commitCancel != nil {
g.commitCancel() // cancel any prior commit
@@ -1169,21 +1168,6 @@ func (g *groupConsumer) commitTxn(
g.commitCancel = commitCancel
g.commitDone = commitDone

// We issue this request even if the producer ID is failed; the request
// will fail if it is.
//
// The id must have been set at least once by this point because of
// addOffsetsToTxn.
id, epoch, _ := g.cl.producerID()
req := kmsg.NewPtrTxnOffsetCommitRequest()
req.TransactionalID = *g.cl.cfg.txnID
req.Group = g.cfg.group
req.ProducerID = id
req.ProducerEpoch = epoch
req.Generation = g.generation
req.MemberID = g.memberID
req.InstanceID = g.cfg.instanceID

if ctx.Done() != nil {
go func() {
select {
@@ -1206,28 +1190,7 @@ func (g *groupConsumer) commitTxn(
<-priorDone // wait for any prior request to finish
}
}
g.cl.cfg.logger.Log(LogLevelDebug, "issuing txn offset commit", "uncommitted", uncommitted)

for topic, partitions := range uncommitted {
reqTopic := kmsg.NewTxnOffsetCommitRequestTopic()
reqTopic.Topic = topic
for partition, eo := range partitions {
reqPartition := kmsg.NewTxnOffsetCommitRequestTopicPartition()
reqPartition.Partition = partition
reqPartition.Offset = eo.Offset
reqPartition.LeaderEpoch = eo.Epoch
reqPartition.Metadata = &req.MemberID
reqTopic.Partitions = append(reqTopic.Partitions, reqPartition)
}
req.Topics = append(req.Topics, reqTopic)
}

if fn, ok := ctx.Value(txnCommitContextFn).(func(*kmsg.TxnOffsetCommitRequest) error); ok {
if err := fn(req); err != nil {
onDone(req, nil, err)
return
}
}
g.cl.cfg.logger.Log(LogLevelDebug, "issuing txn offset commit", "uncommitted", req)

var resp *kmsg.TxnOffsetCommitResponse
var err error
@@ -1241,3 +1204,44 @@ func (g *groupConsumer) commitTxn(
onDone(req, resp, nil)
}()
}

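// prepareTxnOffsetCommit builds the TxnOffsetCommitRequest and runs any
// context-attached PreTxnCommitFn hook; it runs before the empty-offsets
// check in commitTransactionOffsets, so the hook may add topics to an
// otherwise-empty request.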
func (g *groupConsumer) prepareTxnOffsetCommit(ctx context.Context, uncommitted map[string]map[int32]EpochOffset) (*kmsg.TxnOffsetCommitRequest, error) {
req := kmsg.NewPtrTxnOffsetCommitRequest()

// We're now generating the producerID before addOffsetsToTxn.
// We will not make this request until after addOffsetsToTxn, but it's possible to fail here due to a failed producerID.
id, epoch, err := g.cl.producerID()
if err != nil {
return req, err
}

req.TransactionalID = *g.cl.cfg.txnID
req.Group = g.cfg.group
req.ProducerID = id
req.ProducerEpoch = epoch
memberID, generation := g.memberGen.load()
req.Generation = generation
req.MemberID = memberID
req.InstanceID = g.cfg.instanceID

for topic, partitions := range uncommitted {
reqTopic := kmsg.NewTxnOffsetCommitRequestTopic()
reqTopic.Topic = topic
for partition, eo := range partitions {
reqPartition := kmsg.NewTxnOffsetCommitRequestTopicPartition()
reqPartition.Partition = partition
reqPartition.Offset = eo.Offset
reqPartition.LeaderEpoch = eo.Epoch
reqPartition.Metadata = &req.MemberID
reqTopic.Partitions = append(reqTopic.Partitions, reqPartition)
}
req.Topics = append(req.Topics, reqTopic)
}

if fn, ok := ctx.Value(txnCommitContextFn).(func(*kmsg.TxnOffsetCommitRequest) error); ok {
if err := fn(req); err != nil {
return req, err
}
}
return req, nil
}
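Finally, a minimal sketch of the context-carried hook mechanism that prepareTxnOffsetCommit checks for (hypothetical key and request types; franz-go keeps the real hook under its unexported txnCommitContextFn key):

```go
package main

import (
	"context"
	"fmt"
)

type hookKey struct{}
type request struct{ topics []string }

// withHook attaches an optional mutation hook to the context, mirroring how
// PreTxnCommitFnContext stores a func(*kmsg.TxnOffsetCommitRequest) error.
func withHook(ctx context.Context, fn func(*request) error) context.Context {
	return context.WithValue(ctx, hookKey{}, fn)
}

// prepare builds a request and, if a hook is attached, runs it before the
// caller decides whether the request is empty; a hook error aborts the commit.
func prepare(ctx context.Context) (*request, error) {
	req := &request{}
	if fn, ok := ctx.Value(hookKey{}).(func(*request) error); ok {
		if err := fn(req); err != nil {
			return req, err
		}
	}
	return req, nil
}

func main() {
	ctx := withHook(context.Background(), func(r *request) error {
		r.topics = append(r.topics, "injected-topic")
		return nil
	})
	req, err := prepare(ctx)
	fmt.Println(req.topics, err) // [injected-topic] <nil>
}
```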