diff --git a/pkg/kgo/client.go b/pkg/kgo/client.go index d1368f7e..48708f85 100644 --- a/pkg/kgo/client.go +++ b/pkg/kgo/client.go @@ -1325,6 +1325,8 @@ func (cl *Client) shardedRequest(ctx context.Context, req kmsg.Request) ([]Respo *kmsg.ListGroupsRequest, // key 16 *kmsg.DeleteRecordsRequest, // key 21 *kmsg.OffsetForLeaderEpochRequest, // key 23 + *kmsg.AddPartitionsToTxnRequest, // key 24 + *kmsg.WriteTxnMarkersRequest, // key 27 *kmsg.DescribeConfigsRequest, // key 32 *kmsg.AlterConfigsRequest, // key 33 *kmsg.AlterReplicaLogDirsRequest, // key 34 @@ -1775,8 +1777,6 @@ func (cl *Client) handleCoordinatorReq(ctx context.Context, req kmsg.Request) Re // names, we delete no coordinator. coordinator, resp, err := cl.handleReqWithCoordinator(ctx, func() (*broker, error) { return cl.broker(), nil }, coordinatorTypeTxn, "", req) return shard(coordinator, req, resp, err) - case *kmsg.AddPartitionsToTxnRequest: - return cl.handleCoordinatorReqSimple(ctx, coordinatorTypeTxn, t.TransactionalID, req) case *kmsg.AddOffsetsToTxnRequest: return cl.handleCoordinatorReqSimple(ctx, coordinatorTypeTxn, t.TransactionalID, req) case *kmsg.EndTxnRequest: @@ -1840,10 +1840,6 @@ func (cl *Client) handleReqWithCoordinator( // TXN case *kmsg.InitProducerIDResponse: code = t.ErrorCode - case *kmsg.AddPartitionsToTxnResponse: - if len(t.Topics) > 0 && len(t.Topics[0].Partitions) > 0 { - code = t.Topics[0].Partitions[0].ErrorCode - } case *kmsg.AddOffsetsToTxnResponse: code = t.ErrorCode case *kmsg.EndTxnResponse: @@ -2080,6 +2076,8 @@ func (cl *Client) handleShardedReq(ctx context.Context, req kmsg.Request) ([]Res sharder = &deleteRecordsSharder{cl} case *kmsg.OffsetForLeaderEpochRequest: sharder = &offsetForLeaderEpochSharder{cl} + case *kmsg.AddPartitionsToTxnRequest: + sharder = &addPartitionsToTxnSharder{cl} case *kmsg.WriteTxnMarkersRequest: sharder = &writeTxnMarkersSharder{cl} case *kmsg.DescribeConfigsRequest: @@ -2767,9 +2765,16 @@ func (cl *offsetFetchSharder) shard(ctx context.Context, kreq kmsg.Request, last broker: id, }) } + } else if len(req.Groups) == 1 { + single := offsetFetchGroupToReq(req.RequireStable, req.Groups[0]) + single.Groups = req.Groups + issues = append(issues, issueShard{ + req: single, + broker: id, + }) } else { issues = append(issues, issueShard{ - req: &pinReq{Request: req, pinMin: true, min: 8}, + req: &pinReq{Request: req, pinMin: len(req.Groups) > 1, min: 8}, broker: id, }) } @@ -2791,7 +2796,7 @@ func (cl *offsetFetchSharder) shard(ctx context.Context, kreq kmsg.Request, last } func (cl *offsetFetchSharder) onResp(kreq kmsg.Request, kresp kmsg.Response) error { - req := kreq.(*kmsg.OffsetFetchRequest) // we always issue pinned requests + req := kreq.(*kmsg.OffsetFetchRequest) resp := kresp.(*kmsg.OffsetFetchResponse) switch len(resp.Groups) { @@ -2876,9 +2881,16 @@ func (*findCoordinatorSharder) shard(_ context.Context, kreq kmsg.Request, lastE for key := range uniq { req.CoordinatorKeys = append(req.CoordinatorKeys, key) } + if len(req.CoordinatorKeys) == 1 { + req.CoordinatorKey = req.CoordinatorKeys[0] + } splitReq := errors.Is(lastErr, errBrokerTooOld) if !splitReq { + // With only one key, we do not need to split nor pin this. + if len(req.CoordinatorKeys) == 0 { + return []issueShard{{req: req, any: true}}, false, nil + } return []issueShard{{ req: &pinReq{Request: req, pinMin: true, min: 4}, any: true, @@ -2899,7 +2911,7 @@ func (*findCoordinatorSharder) shard(_ context.Context, kreq kmsg.Request, lastE } func (*findCoordinatorSharder) onResp(kreq kmsg.Request, kresp kmsg.Response) error { - req := kreq.(*kmsg.FindCoordinatorRequest) // we always issue pinned requests + req := kreq.(*kmsg.FindCoordinatorRequest) resp := kresp.(*kmsg.FindCoordinatorResponse) switch len(resp.Coordinators) { @@ -3293,6 +3305,195 @@ func (*offsetForLeaderEpochSharder) merge(sresps []ResponseShard) (kmsg.Response return merged, firstErr } +// handle sharding AddPartitionsToTXn, where v4+ switched to batch requests +type addPartitionsToTxnSharder struct{ *Client } + +func addPartitionsReqToTxn(req *kmsg.AddPartitionsToTxnRequest) { + t := kmsg.NewAddPartitionsToTxnRequestTransaction() + t.TransactionalID = req.TransactionalID + t.ProducerID = req.ProducerID + t.ProducerEpoch = req.ProducerEpoch + for i := range req.Topics { + rt := &req.Topics[i] + tt := kmsg.NewAddPartitionsToTxnRequestTransactionTopic() + tt.Topic = rt.Topic + tt.Partitions = rt.Partitions + t.Topics = append(t.Topics, tt) + } + req.Transactions = append(req.Transactions, t) +} + +func addPartitionsTxnToReq(req *kmsg.AddPartitionsToTxnRequest) { + if len(req.Transactions) != 1 { + return + } + t0 := &req.Transactions[0] + req.TransactionalID = t0.TransactionalID + req.ProducerID = t0.ProducerID + req.ProducerEpoch = t0.ProducerEpoch + for _, tt := range t0.Topics { + rt := kmsg.NewAddPartitionsToTxnRequestTopic() + rt.Topic = tt.Topic + rt.Partitions = tt.Partitions + req.Topics = append(req.Topics, rt) + } +} + +func addPartitionsTxnToResp(resp *kmsg.AddPartitionsToTxnResponse) { + if len(resp.Transactions) == 0 { + return + } + t0 := &resp.Transactions[0] + for _, tt := range t0.Topics { + rt := kmsg.NewAddPartitionsToTxnResponseTopic() + rt.Topic = tt.Topic + for _, tp := range tt.Partitions { + rp := kmsg.NewAddPartitionsToTxnResponseTopicPartition() + rp.Partition = tp.Partition + rp.ErrorCode = tp.ErrorCode + rt.Partitions = append(rt.Partitions, rp) + } + resp.Topics = append(resp.Topics, rt) + } +} + +func (cl *addPartitionsToTxnSharder) shard(ctx context.Context, kreq kmsg.Request, _ error) ([]issueShard, bool, error) { + req := kreq.(*kmsg.AddPartitionsToTxnRequest) + + if len(req.Transactions) == 0 { + addPartitionsReqToTxn(req) + } + txnIDs := make([]string, 0, len(req.Transactions)) + for i := range req.Transactions { + txnIDs = append(txnIDs, req.Transactions[i].TransactionalID) + } + coordinators := cl.loadCoordinators(ctx, coordinatorTypeTxn, txnIDs...) + + type unkerr struct { + err error + txn kmsg.AddPartitionsToTxnRequestTransaction + } + var ( + brokerReqs = make(map[int32]*kmsg.AddPartitionsToTxnRequest) + kerrs = make(map[*kerr.Error][]kmsg.AddPartitionsToTxnRequestTransaction) + unkerrs []unkerr + ) + + newReq := func(txns ...kmsg.AddPartitionsToTxnRequestTransaction) *kmsg.AddPartitionsToTxnRequest { + req := kmsg.NewPtrAddPartitionsToTxnRequest() + req.Transactions = txns + addPartitionsTxnToReq(req) + return req + } + + for _, txn := range req.Transactions { + berr := coordinators[txn.TransactionalID] + var ke *kerr.Error + switch { + case berr.err == nil: + brokerReq := brokerReqs[berr.b.meta.NodeID] + if brokerReq == nil { + brokerReq = newReq(txn) + brokerReqs[berr.b.meta.NodeID] = brokerReq + } else { + brokerReq.Transactions = append(brokerReq.Transactions, txn) + } + case errors.As(berr.err, &ke): + kerrs[ke] = append(kerrs[ke], txn) + default: + unkerrs = append(unkerrs, unkerr{berr.err, txn}) + } + } + + var issues []issueShard + for id, req := range brokerReqs { + issues = append(issues, issueShard{ + req: req, + broker: id, + }) + } + for _, unkerr := range unkerrs { + issues = append(issues, issueShard{ + req: newReq(unkerr.txn), + err: unkerr.err, + }) + } + for kerr, txns := range kerrs { + issues = append(issues, issueShard{ + req: newReq(txns...), + err: kerr, + }) + } + + return issues, true, nil // reshardable to load correct coordinators +} + +func (cl *addPartitionsToTxnSharder) onResp(kreq kmsg.Request, kresp kmsg.Response) error { + req := kreq.(*kmsg.AddPartitionsToTxnRequest) + resp := kresp.(*kmsg.AddPartitionsToTxnResponse) + + // We default to the top level error, which is used in v4+. For v3 + // (case 0), we use the per-partition error, which is the same for + // every partition on not_coordinator errors. + code := resp.ErrorCode + if code == 0 && len(resp.Transactions) == 0 { + // Convert v3 and prior to v4+ + resptxn := kmsg.NewAddPartitionsToTxnResponseTransaction() + resptxn.TransactionalID = req.TransactionalID + for _, rt := range resp.Topics { + respt := kmsg.NewAddPartitionsToTxnResponseTransactionTopic() + respt.Topic = rt.Topic + for _, rp := range rt.Partitions { + respp := kmsg.NewAddPartitionsToTxnResponseTransactionTopicPartition() + respp.Partition = rp.Partition + respp.ErrorCode = rp.ErrorCode + code = rp.ErrorCode // v3 and prior has per-partition errors, not top level + respt.Partitions = append(respt.Partitions, respp) + } + resptxn.Topics = append(resptxn.Topics, respt) + } + resp.Transactions = append(resp.Transactions, resptxn) + } else { + // Convert v4 to v3 and prior: either we have a top level error + // code or we have at least one transaction. + // + // If the code is non-zero, we convert it to per-partition error + // codes; v3 does not have a top level err. + addPartitionsTxnToResp(resp) + if code != 0 { + for _, reqt := range req.Topics { + respt := kmsg.NewAddPartitionsToTxnResponseTopic() + respt.Topic = reqt.Topic + for _, reqp := range reqt.Partitions { + respp := kmsg.NewAddPartitionsToTxnResponseTopicPartition() + respp.Partition = reqp + respp.ErrorCode = resp.ErrorCode + respt.Partitions = append(respt.Partitions, respp) + } + resp.Topics = append(resp.Topics, respt) + } + } + } + if err := kerr.ErrorForCode(code); cl.maybeDeleteStaleCoordinator(req.TransactionalID, coordinatorTypeTxn, err) { + return err + } + return nil +} + +func (*addPartitionsToTxnSharder) merge(sresps []ResponseShard) (kmsg.Response, error) { + merged := kmsg.NewPtrAddPartitionsToTxnResponse() + + firstErr := firstErrMerger(sresps, func(kresp kmsg.Response) { + resp := kresp.(*kmsg.AddPartitionsToTxnResponse) + merged.Version = resp.Version + merged.ThrottleMillis = resp.ThrottleMillis + merged.ErrorCode = resp.ErrorCode + merged.Transactions = append(merged.Transactions, resp.Transactions...) + }) + addPartitionsTxnToResp(merged) + return merged, firstErr +} + // handle sharding WriteTxnMarkersRequest type writeTxnMarkersSharder struct{ *Client }