From 1017c1e04842b5d964142937628fb8f1b9f67ca4 Mon Sep 17 00:00:00 2001
From: Tyler Treat
Date: Tue, 28 Dec 2021 17:04:09 -0700
Subject: [PATCH 1/2] Fix async error reset in PublishAsync mock

---
 v2/common_test.go | 1 -
 1 file changed, 1 deletion(-)

diff --git a/v2/common_test.go b/v2/common_test.go
index b7f5305..3b98d51 100644
--- a/v2/common_test.go
+++ b/v2/common_test.go
@@ -489,7 +489,6 @@ func (m *mockAPI) PublishAsync(stream proto.API_PublishAsyncServer) error {
 	m.mu.Lock()
 	if m.publishAsyncErr != nil {
 		err := m.publishAsyncErr
-		m.publishAsyncErr = nil
 		if m.autoClearError {
 			m.publishAsyncErr = nil
 		}

From 57c1575d9d17b4c17ecabf5882292eee45a2706b Mon Sep 17 00:00:00 2001
From: Tyler Treat
Date: Tue, 28 Dec 2021 17:05:14 -0700
Subject: [PATCH 2/2] Make PublishAsync more resilient

Make PublishAsync more resilient by attempting to reconnect the
PublishAsync streaming RPC if the connection is disrupted.
---
 v2/brokers.go     | 188 +++++++++++++++++++++++++++++++++++++---------
 v2/client.go      |  29 +++++---
 v2/client_test.go |  61 ++++++++++++++++
 3 files changed, 235 insertions(+), 43 deletions(-)

diff --git a/v2/brokers.go b/v2/brokers.go
index 9947363..1823561 100644
--- a/v2/brokers.go
+++ b/v2/brokers.go
@@ -12,6 +12,8 @@ import (
 	proto "github.com/liftbridge-io/liftbridge-api/go"
 	"github.com/serialx/hashring"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 )
 
 type ackReceivedFunc func(*proto.PublishResponse)
@@ -124,7 +126,7 @@ func (b *brokers) FromAddr(addr string) (proto.APIClient, error) {
 		return nil, fmt.Errorf("no broker found: %v", addr)
 	}
 
-	return broker.client, nil
+	return broker.grpcClient.Client(), nil
 }
 
 func (b *brokers) ChooseBroker(selectionCriteria SelectionCriteria) (proto.APIClient, error) {
@@ -144,48 +146,118 @@ func (b *brokers) ChooseBroker(selectionCriteria SelectionCriteria) (proto.APICl
 
 		// Find server with lowest latency
 		for i := 0; i < len(b.brokers); i++ {
+			st := b.brokers[i].Status()
 			if i == 0 {
-				minLatency = int(b.brokers[i].status.LastKnownLatency)
+				minLatency = int(st.LastKnownLatency)
 				broker = b.brokers[i]
 				continue
 			}
-			if int(b.brokers[i].status.LastKnownLatency) < minLatency {
-				minLatency = int(b.brokers[i].status.LastKnownLatency)
+			if int(st.LastKnownLatency) < minLatency {
+				minLatency = int(st.LastKnownLatency)
 				broker = b.brokers[i]
 			}
 		}
 
-		return broker.client, nil
+		return broker.grpcClient.Client(), nil
 
 	case Workload:
 		// Find server with lowest work load
 		minPartitionCount := -1
 		for i := 0; i < len(b.brokers); i++ {
+			st := b.brokers[i].Status()
 			if i == 0 {
-				minPartitionCount = int(b.brokers[i].status.PartitionCount)
+				minPartitionCount = int(st.PartitionCount)
 				broker = b.brokers[i]
 				continue
 			}
-			if int(b.brokers[i].status.PartitionCount) < minPartitionCount {
-				minPartitionCount = int(b.brokers[i].status.PartitionCount)
+			if int(st.PartitionCount) < minPartitionCount {
+				minPartitionCount = int(st.PartitionCount)
 				broker = b.brokers[i]
 			}
 		}
 
-		return broker.client, nil
+		return broker.grpcClient.Client(), nil
 
 	case Random:
 		// Return the current broker (randomly chosen)
-		return broker.client, nil
+		return broker.grpcClient.Client(), nil
 
 	default:
 		// Return the current broker (randomly chosen)
-		return broker.client, nil
+		return broker.grpcClient.Client(), nil
 	}
 }
 
-// PublicationStream returns a publication stream based on a stream name and a
-// partition.
-func (b *brokers) PublicationStream(stream string, partition int32) (proto.API_PublishAsyncClient, error) {
+// grpcClient wraps a gRPC APIClient and API_PublishAsyncClient.
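+// The underlying connection is guarded by a lock so that it can be
+// re-established via redial without invalidating references held by callers.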
+type grpcClient struct {
+	addr        string
+	dialOpts    []grpc.DialOption
+	conn        *grpc.ClientConn
+	client      proto.APIClient
+	asyncClient proto.API_PublishAsyncClient
+	mu          sync.RWMutex
+	closed      bool
+}
+
+func newGrpcClient(ctx context.Context, addr string, opts []grpc.DialOption) (*grpcClient, error) {
+	g := &grpcClient{addr: addr, dialOpts: opts}
+	if err := g.redial(ctx); err != nil {
+		return nil, err
+	}
+	return g, nil
+}
+
+func (g *grpcClient) redial(ctx context.Context) error {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+	if g.closed {
+		return errors.New("client was closed")
+	}
+	oldConn := g.conn
+	conn, err := dialBroker(ctx, g.addr, g.dialOpts)
+	if err != nil {
+		return err
+	}
+	newClient := proto.NewAPIClient(conn)
+	newAsyncClient, err := newClient.PublishAsync(ctx)
+	if err != nil {
+		conn.Close()
+		return err
+	}
+	g.conn = conn
+	g.client = newClient
+	g.asyncClient = newAsyncClient
+	if oldConn != nil {
+		oldConn.Close()
+	}
+	return nil
+}
+
+func (g *grpcClient) close() {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+	if g.closed {
+		return
+	}
+	g.conn.Close()
+	g.closed = true
+}
+
+func (g *grpcClient) Client() proto.APIClient {
+	g.mu.RLock()
+	defer g.mu.RUnlock()
+	return g.client
+}
+
+func (g *grpcClient) AsyncClient() proto.API_PublishAsyncClient {
+	g.mu.RLock()
+	defer g.mu.RUnlock()
+	return g.asyncClient
+}
+
+// GetGrpcClient returns a grpcClient based on a stream name and a partition.
+func (b *brokers) GetGrpcClient(stream string, partition int32) (*grpcClient, error) {
 	b.mu.RLock()
 	defer b.mu.RUnlock()
 
@@ -200,7 +272,7 @@ func (b *brokers) PublicationStream(stream string, partition int32) (proto.API_P
 		return nil, fmt.Errorf("broker not found: %v", addr)
 	}
 
-	return broker.stream, nil
+	return broker.grpcClient, nil
 }
 
 func brokerHashringKey(stream string, partition int32) string {
@@ -209,28 +281,25 @@ func brokerHashringKey(stream string, partition int32) string {
 // broker represents a connection to a broker.
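+// The closed channel signals the ack-dispatching goroutine to stop once the
+// broker has been shut down.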
 type broker struct {
-	conn   *grpc.ClientConn
-	client proto.APIClient
-	stream proto.API_PublishAsyncClient
-	wg     sync.WaitGroup
-	status *brokerStatus
+	grpcClient *grpcClient
+	wg         sync.WaitGroup
+	status     *brokerStatus
+	closed     chan struct{}
+	mu         sync.RWMutex
 }
 
 func newBroker(ctx context.Context, addr string, opts []grpc.DialOption, ackReceived ackReceivedFunc) (*broker, error) {
-	conn, err := dialBroker(ctx, addr, opts)
+	client, err := newGrpcClient(ctx, addr, opts)
 	if err != nil {
 		return nil, err
 	}
 
 	b := &broker{
-		conn:   conn,
-		client: proto.NewAPIClient(conn),
-		status: &brokerStatus{PartitionCount: 0, LastKnownLatency: 0},
-	}
-
-	if b.stream, err = b.client.PublishAsync(ctx); err != nil {
-		conn.Close()
-		return nil, err
+		grpcClient: client,
+		status:     &brokerStatus{PartitionCount: 0, LastKnownLatency: 0},
+		closed:     make(chan struct{}),
 	}
 
 	b.wg.Add(1)
@@ -250,7 +319,7 @@ func (b *broker) updateStatus(ctx context.Context, addr string) error {
 
 	// Measure instant server response time
 	start := time.Now()
-	resp, err := b.client.FetchMetadata(ctx, &proto.FetchMetadataRequest{})
+	resp, err := b.grpcClient.Client().FetchMetadata(ctx, &proto.FetchMetadataRequest{})
 
 	elapsed := time.Since(start)
 
@@ -259,10 +328,9 @@ func (b *broker) updateStatus(ctx context.Context, addr string) error {
 	}
 
 	// Parse broker status
-	b.status.LastKnownLatency = elapsed
+	updatedStatus := &brokerStatus{LastKnownLatency: elapsed}
 
 	// Count total number of partitions for this broker
-
 	for _, broker := range resp.Brokers {
 		brokerInfo := &BrokerInfo{
 			id:             broker.Id,
@@ -272,35 +340,83 @@ func (b *broker) updateStatus(ctx context.Context, addr string) error {
 			host:           broker.Host,
 			port:           broker.Port,
 			partitionCount: broker.PartitionCount,
 		}
 		if brokerInfo.Addr() == addr {
-			b.status.PartitionCount = brokerInfo.LeaderCount() + brokerInfo.PartitionCount()
+			updatedStatus.PartitionCount = brokerInfo.LeaderCount() + brokerInfo.PartitionCount()
 			break
 		}
 	}
+	b.mu.Lock()
+	b.status = updatedStatus
+	b.mu.Unlock()
 	return nil
+}
 
+func (b *broker) Status() *brokerStatus {
+	b.mu.RLock()
+	defer b.mu.RUnlock()
+	return b.status
 }
 
 func (b *broker) Close() {
-	b.conn.Close()
+	// If the broker was already closed, there's nothing to do.
+	select {
+	case <-b.closed:
+		return
+	default:
+	}
+	b.grpcClient.close()
+	close(b.closed)
 	b.wg.Wait()
 }
 
 func (b *broker) dispatchAcks(ackReceived ackReceivedFunc) {
+	stream := b.grpcClient.AsyncClient()
 	for {
-		resp, err := b.stream.Recv()
+		resp, err := stream.Recv()
 		if err == io.EOF {
 			return
 		}
 		if err != nil {
-			// TODO: reconnect?
+			// Check if the broker connection has been closed.
+			select {
+			case <-b.closed:
+				return
+			default:
+			}
+			if status.Code(err) == codes.Unavailable {
+				// Attempt to reconnect.
+				if err := b.reconnect(); err == nil {
+					stream = b.grpcClient.AsyncClient()
+					continue
+				}
+			}
 			return
 		}
 		ackReceived(resp)
 	}
 }
 
+// reconnect re-establishes the broker's gRPC connection and PublishAsync
+// stream, retrying a bounded number of times before giving up.
+func (b *broker) reconnect() error {
+	var (
+		err error
+		ctx = context.Background()
+	)
+	for i := 0; i < 5; i++ {
+		if err = b.grpcClient.redial(ctx); err != nil {
+			sleepContext(ctx, 50*time.Millisecond)
+			continue
+		}
+		return nil
+	}
+	return err
+}
+
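+// dialBroker opens a gRPC connection to the given broker address, blocking
+// until the connection is established when ctx carries a deadline.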
 func dialBroker(ctx context.Context, addr string, opts []grpc.DialOption) (*grpc.ClientConn, error) {
 	// Perform a blocking dial if a context with a deadline has been provided.
 	_, hasDeadline := ctx.Deadline()
diff --git a/v2/client.go b/v2/client.go
index 4dde300..7225231 100644
--- a/v2/client.go
+++ b/v2/client.go
@@ -13,6 +13,7 @@ import (
 	"crypto/tls"
 	"errors"
 	"fmt"
+	"io"
 	"math/rand"
 	"sync"
 	"time"
@@ -1249,7 +1250,7 @@ func (c *client) publishAsync(ctx context.Context, streamName string, value []by
 		return err
 	}
 
-	stream, err := c.brokers.PublicationStream(streamName, req.Partition)
+	grpcClient, err := c.brokers.GetGrpcClient(streamName, req.Partition)
 	if err != nil {
 		return fmt.Errorf("broker for stream: %w", err)
 	}
@@ -1280,15 +1281,29 @@ func (c *client) publishAsync(ctx context.Context, streamName string, value []by
 		c.mu.Unlock()
 	}
 
-	if err := stream.Send(req); err != nil {
-		c.removeAckContext(req.CorrelationId)
-		if status.Code(err) == codes.FailedPrecondition {
-			err = ErrReadonlyPartition
+	for i := 0; i < 5; i++ {
+		stream := grpcClient.AsyncClient()
+		if e := stream.Send(req); e != nil {
+			err = e
+			if e == io.EOF {
+				// We were disconnected, so attempt to use the reconnected
+				// stream (the dispatchAcks goroutine will attempt to
+				// reconnect).
+				sleepContext(ctx, 50*time.Millisecond)
+				continue
+			}
+			if status.Code(e) == codes.FailedPrecondition {
+				e = ErrReadonlyPartition
+			}
+			c.removeAckContext(req.CorrelationId)
+			return e
 		}
-		return err
+		return nil
 	}
 
-	return nil
+	// All retries failed; clean up the ack context and surface the last error.
+	c.removeAckContext(req.CorrelationId)
+	return err
 }
 
 // PublishToSubject publishes a new message to the NATS subject. Note that
diff --git a/v2/client_test.go b/v2/client_test.go
index 4cc483d..b9217e3 100644
--- a/v2/client_test.go
+++ b/v2/client_test.go
@@ -1094,6 +1094,67 @@ func TestPublishAsyncInternalError(t *testing.T) {
 	}
 }
 
+func TestPublishAsyncReconnect(t *testing.T) {
+	server := newMockServer()
+	server.SetAutoClearError()
+	defer server.Stop(t)
+	port := server.Start(t)
+
+	server.SetupMockFetchMetadataResponse(new(proto.FetchMetadataResponse))
+	// Have the mock server fail the first PublishAsync stream with Unavailable
+	// (cleared after one use) to force the client to reconnect.
+	server.SetupMockPublishAsyncError(status.Error(codes.Unavailable, "disconnected"))
+
+	client, err := Connect([]string{fmt.Sprintf("localhost:%d", port)})
+	require.NoError(t, err)
+	defer client.Close()
+
+	expectedAck := &proto.Ack{
+		Stream:           "foo",
+		PartitionSubject: "foo",
+		MsgSubject:       "foo",
+		Offset:           0,
+		AckInbox:         "ack",
+		AckPolicy:        proto.AckPolicy_LEADER,
+	}
+
+	server.SetupMockPublishAsyncResponse(&proto.PublishResponse{Ack: expectedAck})
+
+	ackC := make(chan *Ack)
+	err = client.PublishAsync(context.Background(), "foo", []byte("hello"),
+		func(ack *Ack, err error) {
+			require.NoError(t, err)
+			ackC <- ack
+		},
+	)
+	require.NoError(t, err)
+
+	select {
+	case ack := <-ackC:
+		require.Equal(t, expectedAck.Stream, ack.Stream())
+		require.Equal(t, expectedAck.PartitionSubject, ack.PartitionSubject())
+		require.Equal(t, expectedAck.MsgSubject, ack.MessageSubject())
+		require.Equal(t, expectedAck.Offset, ack.Offset())
+		require.Equal(t, expectedAck.AckInbox, ack.AckInbox())
+		require.Equal(t, expectedAck.CorrelationId, ack.CorrelationID())
+		require.Equal(t, AckPolicy(expectedAck.AckPolicy), ack.AckPolicy())
+	case <-time.After(time.Second):
+		t.Fatal("Did not receive expected ack")
+	}
+
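+	// The publish should have arrived at the mock server over the
+	// re-established stream.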
+	req := server.GetPublishAsyncRequests()[0]
+	require.Equal(t, []byte(nil), req.Key)
+	require.Equal(t, []byte("hello"), req.Value)
+	require.Equal(t, "foo", req.Stream)
+	require.Equal(t, int32(0), req.Partition)
+	require.Equal(t, map[string][]byte(nil), req.Headers)
+	require.Equal(t, "", req.AckInbox)
+	require.NotEqual(t, "", req.CorrelationId)
+	require.Equal(t, proto.AckPolicy_LEADER, req.AckPolicy)
+}
+
 func TestPublishToPartition(t *testing.T) {
 	server := newMockServer()
 	defer server.Stop(t)