Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kgo: allow record ctx cancelation to propagate a bit more #792

Merged
merged 2 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions pkg/kgo/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func isRetryableBrokerErr(err error) bool {
}
// We could have a retryable producer ID failure, which then bubbled up
// as errProducerIDLoadFail so as to be retried later.
if errors.Is(err, errProducerIDLoadFail) {
if pe := (*errProducerIDLoadFail)(nil); errors.As(err, &pe) {
return true
}
// We could have chosen a broker, and then a concurrent metadata update
Expand Down Expand Up @@ -139,8 +139,6 @@ var (
// restart a new connection ourselves.
errSaslReauthLoop = errors.New("the broker is repeatedly giving us sasl lifetimes that are too short to write a request")

errProducerIDLoadFail = errors.New("unable to initialize a producer ID due to request failures")

// A temporary error returned when Kafka replies with a different
// correlation ID than we were expecting for the request the client
// issued.
Expand Down Expand Up @@ -224,6 +222,19 @@ type ErrFirstReadEOF struct {
err error
}

// errProducerIDLoadFail is the error value reported when a producer ID could
// not be initialized. It optionally carries the underlying failure so that
// callers can inspect it through errors.Is / errors.As.
type errProducerIDLoadFail struct {
	err error // underlying cause of the load failure; may be nil
}

// Error implements the error interface. The message is a fixed description
// of the failure, followed by ": <cause>" when an underlying error is set.
func (e *errProducerIDLoadFail) Error() string {
	const msg = "unable to initialize a producer ID due to request failures"
	if e.err != nil {
		return fmt.Sprintf(msg+": %v", e.err)
	}
	return msg
}

// Unwrap returns the underlying cause (possibly nil) so that the standard
// errors helpers can walk through this error.
func (e *errProducerIDLoadFail) Unwrap() error {
	return e.err
}

const (
firstReadSASL uint8 = iota
firstReadTLS
Expand Down
26 changes: 26 additions & 0 deletions pkg/kgo/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,32 @@ var (
npartitionsAt int64
)

// slowConn wraps a net.Conn and delays every Write and Read by 100ms before
// forwarding the call to the embedded connection. All other net.Conn methods
// are promoted from the embedded value unchanged.
type slowConn struct {
	net.Conn
}

// stall blocks the caller for the fixed per-operation delay.
func (*slowConn) stall() {
	time.Sleep(100 * time.Millisecond)
}

// Write waits for the fixed delay and then writes p to the wrapped
// connection, returning whatever the wrapped connection returns.
func (s *slowConn) Write(p []byte) (int, error) {
	s.stall()
	return s.Conn.Write(p)
}

// Read waits for the fixed delay and then reads into p from the wrapped
// connection, returning whatever the wrapped connection returns.
func (s *slowConn) Read(p []byte) (int, error) {
	s.stall()
	return s.Conn.Read(p)
}

// slowDialer dials with a net.Dialer and wraps every successfully dialed
// connection in a slowConn. The zero value is ready to use.
type slowDialer struct {
	d net.Dialer
}

// DialContext dials host on network using the embedded net.Dialer. On
// success the connection is returned wrapped in a slowConn; on failure the
// dial error is returned and the connection is nil.
func (s *slowDialer) DialContext(ctx context.Context, network, host string) (net.Conn, error) {
	conn, dialErr := s.d.DialContext(ctx, network, host)
	if dialErr == nil {
		return &slowConn{Conn: conn}, nil
	}
	return nil, dialErr
}

func init() {
var err error
if n, _ := strconv.Atoi(os.Getenv("KGO_TEST_RF")); n > 0 {
Expand Down
124 changes: 124 additions & 0 deletions pkg/kgo/produce_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"sync"
"sync/atomic"
"testing"
"time"

"github.com/twmb/franz-go/pkg/kbin"
"github.com/twmb/franz-go/pkg/kmsg"
Expand Down Expand Up @@ -71,6 +72,129 @@ func TestClient_Produce(t *testing.T) {
}
}

// TestIssue769 checks that a record whose own context is already canceled is
// failed promptly at each stage of producing that the test can reach: while
// the topic is still unknown, while the producer ID needs (re)loading, and
// after the producer ID is loaded but before a produce request is issued.
//
// The canceled produces below actually SUCCEED if the code for 769 is not
// working correctly. 769 is about a hanging produce not obeying a record
// cancelation, but we can simulate the same thing.
func TestIssue769(t *testing.T) {
	t.Parallel()

	topic, cleanup := tmpTopic(t)
	defer cleanup()

	cl, _ := newTestClient(
		DefaultProduceTopic(topic),
		UnknownTopicRetries(-1),
		Dialer(new(slowDialer).DialContext),
	)
	defer cl.Close()

	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	canceled := &Record{Value: []byte("foo"), Context: ctx}
	okay := &Record{Value: []byte("foo")}

	// produce hands r to the client and returns a channel that receives
	// the error the promise is called with. The channel is buffered so
	// that the promise never blocks, even if the test has stopped waiting.
	produce := func(r *Record) <-chan error {
		res := make(chan error, 1)
		cl.Produce(context.Background(), r, func(_ *Record, err error) {
			res <- err
		})
		return res
	}

	// produceWithin3s produces r and returns the promise's error, failing
	// the test if the promise has not been called after three seconds.
	produceWithin3s := func(r *Record) error {
		t.Helper()
		res := produce(r)
		timer := time.NewTimer(3 * time.Second)
		defer timer.Stop() // release the timer as soon as we have an answer
		select {
		case err := <-res:
			return err
		case <-timer.C:
			t.Fatal("expected record to fail within 3s")
		}
		return nil // not reached: t.Fatal stops the test goroutine
	}

	// First check: ensure that an already-canceled record bails right
	// away. This actually bails in the unknown-topic bit of logic,
	// although there is no way to surface that to the end user.
	if err := produceWithin3s(canceled); !errors.Is(err, context.Canceled) {
		t.Errorf("got %v != exp context.Canceled", err)
	}

	// We have to produce one record successfully to ensure the topic is
	// known, then we modify the guts of the client to forget the loaded
	// producer ID.
	if err := <-produce(okay); err != nil {
		t.Fatalf("unexpected error on the first produce: %v", err)
	}
	cl.producer.id.Store(&producerID{
		id:    -1,
		epoch: -1,
		err:   errReloadProducerID,
	})

	// With a loaded topic but forgotten producer ID, we now ensure that a
	// canceled record fails in the producer ID portion.
	{
		err := produceWithin3s(canceled)
		var pe *errProducerIDLoadFail
		// The wrapped cause must be present before it is inspected;
		// calling Error on a nil cause would panic. The text match is
		// kept from the original check -- presumably the cancelation
		// can surface as an error that does not wrap context.Canceled.
		if !errors.As(err, &pe) || pe.err == nil ||
			!(errors.Is(pe.err, context.Canceled) || strings.Contains(pe.err.Error(), "canceled")) {
			t.Errorf("got %v != exp errProducerIDLoadFail{context.Canceled}", err)
		}
	}

	// We now produce successfully again to ensure the next attempt fails
	// after the producer ID stage.
	{
		res := produce(okay)
		cl.Flush(context.Background())
		if err := <-res; err != nil {
			t.Fatalf("unexpected error on the second produce: %v", err)
		}
	}

	// This fails before the produce request is issued, which is the furthest we
	// can take the test. We do not use record contexts in issued produce requests.
	{
		err := produceWithin3s(canceled)
		var pe *errProducerIDLoadFail
		if errors.As(err, &pe) {
			t.Error("unexpectedly got errProducerIDLoadFail")
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("got %v != context.Canceled", err)
		}
	}
}

// This file contains golden tests against kmsg AppendTo's to ensure our custom
// encoding is correct.

Expand Down
26 changes: 17 additions & 9 deletions pkg/kgo/producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,11 @@ func (cl *Client) TryProduce(
// retries. If any of these conditions are hit and it is currently safe to fail
// records, all buffered records for the relevant partition are failed. Only
// the first record's context in a batch is considered when determining whether
// the batch should be canceled.
// the batch should be canceled. A record is not safe to fail if the client
// is idempotently producing and a request has been sent; in this case, the
// client cannot know if the broker actually processed the request (if so, then
// removing the records from the client will create errors the next time you
// produce).
//
// If the client is transactional and a transaction has not been begun, the
// promise is immediately called with an error corresponding to not being in a
Expand Down Expand Up @@ -679,7 +683,7 @@ func (cl *Client) ProducerID(ctx context.Context) (int64, int16, error) {

go func() {
defer close(done)
id, epoch, err = cl.producerID()
id, epoch, err = cl.producerID(ctx2fn(ctx))
}()

select {
Expand All @@ -701,7 +705,7 @@ var errReloadProducerID = errors.New("producer id needs reloading")
// producerID initializes the client's producer ID for idempotent
// producing only (no transactions, which are more special). After the first
// load, this clears all buffered unknown topics.
func (cl *Client) producerID() (int64, int16, error) {
func (cl *Client) producerID(ctxFn func() context.Context) (int64, int16, error) {
p := &cl.producer

id := p.id.Load().(*producerID)
Expand Down Expand Up @@ -730,7 +734,7 @@ func (cl *Client) producerID() (int64, int16, error) {
}
p.id.Store(id)
} else {
newID, keep := cl.doInitProducerID(id.id, id.epoch)
newID, keep := cl.doInitProducerID(ctxFn, id.id, id.epoch)
if keep {
id = newID
// Whenever we have a new producer ID, we need
Expand All @@ -748,7 +752,7 @@ func (cl *Client) producerID() (int64, int16, error) {
id = &producerID{
id: id.id,
epoch: id.epoch,
err: errProducerIDLoadFail,
err: &errProducerIDLoadFail{newID.err},
}
}
}
Expand Down Expand Up @@ -825,7 +829,7 @@ func (cl *Client) failProducerID(id int64, epoch int16, err error) {

// doInitProducerID inits the idempotent ID and potentially the transactional
// producer epoch, returning whether to keep the result.
func (cl *Client) doInitProducerID(lastID int64, lastEpoch int16) (*producerID, bool) {
func (cl *Client) doInitProducerID(ctxFn func() context.Context, lastID int64, lastEpoch int16) (*producerID, bool) {
cl.cfg.logger.Log(LogLevelInfo, "initializing producer id")
req := kmsg.NewPtrInitProducerIDRequest()
req.TransactionalID = cl.cfg.txnID
Expand All @@ -835,7 +839,8 @@ func (cl *Client) doInitProducerID(lastID int64, lastEpoch int16) (*producerID,
req.TransactionTimeoutMillis = int32(cl.cfg.txnTimeout.Milliseconds())
}

resp, err := req.RequestWith(cl.ctx, cl)
ctx := ctxFn()
resp, err := req.RequestWith(ctx, cl)
if err != nil {
if errors.Is(err, errUnknownRequestKey) || errors.Is(err, errBrokerTooOld) {
cl.cfg.logger.Log(LogLevelInfo, "unable to initialize a producer id because the broker is too old or the client is pinned to an old version, continuing without a producer id")
Expand Down Expand Up @@ -940,13 +945,14 @@ func (cl *Client) addUnknownTopicRecord(pr promisedRec) {
}
unknown.buffered = append(unknown.buffered, pr)
if len(unknown.buffered) == 1 {
go cl.waitUnknownTopic(pr.ctx, pr.Topic, unknown)
go cl.waitUnknownTopic(pr.ctx, pr.Record.Context, pr.Topic, unknown)
}
}

// waitUnknownTopic waits for a notification
func (cl *Client) waitUnknownTopic(
rctx context.Context,
pctx context.Context, // context passed to Produce
rctx context.Context, // context on the record itself
topic string,
unknown *unknownTopicProduces,
) {
Expand Down Expand Up @@ -974,6 +980,8 @@ func (cl *Client) waitUnknownTopic(

for err == nil {
select {
case <-pctx.Done():
err = pctx.Err()
case <-rctx.Done():
err = rctx.Err()
case <-cl.ctx.Done():
Expand Down
Loading