From 5d24cfe267f1877345c471e24c366fd341d20121 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?And=C5=BEej=20Maciusovi=C4=8D?= Date: Fri, 22 Mar 2024 11:37:23 +0200 Subject: [PATCH] Fix grpc stream leak (#233) --- .../daemon/state/container_stats_pipeline.go | 35 ++------ cmd/agent/daemon/state/controller.go | 20 +---- cmd/agent/daemon/state/events_pipeline.go | 42 ++------- cmd/controller/state/delta/controller_test.go | 3 + go.mod | 2 + go.sum | 2 + pkg/castai/client_test.go | 5 ++ pkg/castai/errors.go | 19 ++++ pkg/castai/logs_exporter.go | 30 ++----- pkg/castai/write_stream.go | 88 +++++++++++++++++++ pkg/castai/write_stream_test.go | 86 ++++++++++++++++++ 11 files changed, 231 insertions(+), 101 deletions(-) create mode 100644 pkg/castai/errors.go create mode 100644 pkg/castai/write_stream.go create mode 100644 pkg/castai/write_stream_test.go diff --git a/cmd/agent/daemon/state/container_stats_pipeline.go b/cmd/agent/daemon/state/container_stats_pipeline.go index d4bb761b..b01d0111 100644 --- a/cmd/agent/daemon/state/container_stats_pipeline.go +++ b/cmd/agent/daemon/state/container_stats_pipeline.go @@ -9,12 +9,13 @@ import ( castpb "github.com/castai/kvisor/api/v1/runtime" "github.com/castai/kvisor/cmd/agent/daemon/netstats" + "github.com/castai/kvisor/pkg/castai" "github.com/castai/kvisor/pkg/containers" "github.com/castai/kvisor/pkg/ebpftracer" "github.com/castai/kvisor/pkg/metrics" "github.com/castai/kvisor/pkg/stats" "github.com/samber/lo" - "google.golang.org/grpc/codes" + "google.golang.org/grpc" "k8s.io/apimachinery/pkg/api/resource" ) @@ -22,19 +23,17 @@ func (c *Controller) runContainerStatsPipeline(ctx context.Context) error { c.log.Info("running container stats sink loop") defer c.log.Info("container stats sink loop done") - var writeStream castpb.RuntimeSecurityAgentAPI_ContainerStatsWriteStreamClient - var err error - defer func() { - if writeStream != nil { - _ = writeStream.CloseSend() - } - }() + ws := 
castai.NewWriteStream[*castpb.ContainerStatsBatch, *castpb.WriteStreamResponse](ctx, func(ctx context.Context) (grpc.ClientStream, error) { + return c.castClient.GRPC.ContainerStatsWriteStream(ctx) + }) + defer ws.Close() + ws.ReopenDelay = c.writeStreamCreateRetryDelay send := func(batch *castpb.ContainerStatsBatch) { c.log.Debugf("sending container cgroup stats, items=%d", len(batch.GetItems())) - if err := writeStream.Send(batch); err != nil { + if err := ws.Send(batch); err != nil { if errors.Is(err, io.EOF) { - writeStream = nil + return } c.log.Errorf("sending container cgroup stats: %v", err) return @@ -46,22 +45,6 @@ func (c *Controller) runContainerStatsPipeline(ctx context.Context) error { defer ticker.Stop() for { - // Create stream. - if writeStream == nil { - select { - case <-ctx.Done(): - return ctx.Err() - default: - writeStream, err = c.castClient.GRPC.ContainerStatsWriteStream(ctx) - if err != nil { - if !isGRPCError(err, codes.Unavailable, codes.Canceled) { - c.log.Warnf("create write stream: %v", err) - } - time.Sleep(c.writeStreamCreateRetryDelay) - continue - } - } - } select { case <-ctx.Done(): diff --git a/cmd/agent/daemon/state/controller.go b/cmd/agent/daemon/state/controller.go index 474d7d8e..4bb17633 100644 --- a/cmd/agent/daemon/state/controller.go +++ b/cmd/agent/daemon/state/controller.go @@ -6,21 +6,18 @@ import ( "sync" "time" + castpb "github.com/castai/kvisor/api/v1/runtime" "github.com/castai/kvisor/cmd/agent/daemon/conntrack" "github.com/castai/kvisor/cmd/agent/daemon/enrichment" "github.com/castai/kvisor/cmd/agent/daemon/netstats" "github.com/castai/kvisor/cmd/agent/kube" "github.com/castai/kvisor/pkg/castai" "github.com/castai/kvisor/pkg/cgroup" + "github.com/castai/kvisor/pkg/containers" "github.com/castai/kvisor/pkg/ebpftracer" "github.com/castai/kvisor/pkg/ebpftracer/signature" "github.com/castai/kvisor/pkg/logging" "golang.org/x/sync/errgroup" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - castpb 
"github.com/castai/kvisor/api/v1/runtime" - "github.com/castai/kvisor/pkg/containers" ) type Config struct { @@ -161,19 +158,6 @@ type syscallScrapePoint struct { syscalls map[ebpftracer.SyscallID]uint64 } -func isGRPCError(err error, codes ...codes.Code) bool { - st, ok := status.FromError(err) - if !ok { - return false - } - for _, code := range codes { - if st.Code() == code { - return true - } - } - return false -} - func (c *Controller) MuteNamespace(namespace string) error { c.mutedNamespacesMu.Lock() c.mutedNamespaces[namespace] = struct{}{} diff --git a/cmd/agent/daemon/state/events_pipeline.go b/cmd/agent/daemon/state/events_pipeline.go index 711ccfa8..4445bac6 100644 --- a/cmd/agent/daemon/state/events_pipeline.go +++ b/cmd/agent/daemon/state/events_pipeline.go @@ -2,15 +2,13 @@ package state import ( "context" - "errors" "fmt" - "io" - "time" castpb "github.com/castai/kvisor/api/v1/runtime" + "github.com/castai/kvisor/pkg/castai" "github.com/castai/kvisor/pkg/metrics" "github.com/prometheus/client_golang/prometheus" - "google.golang.org/grpc/codes" + "google.golang.org/grpc" "k8s.io/apimachinery/pkg/types" ) @@ -19,43 +17,19 @@ func (c *Controller) runEventsExportLoop(ctx context.Context) error { c.log.Info("running events sink loop") defer c.log.Info("events sink loop done") - var writeStream castpb.RuntimeSecurityAgentAPI_EventsWriteStreamClient - var err error - - defer func() { - if writeStream != nil { - _ = writeStream.CloseSend() - } - }() + ws := castai.NewWriteStream[*castpb.Event, *castpb.WriteStreamResponse](ctx, func(ctx context.Context) (grpc.ClientStream, error) { + return c.castClient.GRPC.EventsWriteStream(ctx) + }) + defer ws.Close() + ws.ReopenDelay = c.writeStreamCreateRetryDelay for { - // Create stream. 
- if writeStream == nil { - select { - case <-ctx.Done(): - return ctx.Err() - default: - writeStream, err = c.castClient.GRPC.EventsWriteStream(ctx) - if err != nil { - if !isGRPCError(err, codes.Unavailable, codes.Canceled) { - c.log.Warnf("create write stream: %v", err) - } - time.Sleep(c.writeStreamCreateRetryDelay) - continue - } - } - } - select { case <-ctx.Done(): return ctx.Err() case e := <-c.eventsExportQueue: c.enrichEvent(e) - if err := writeStream.Send(e); err != nil { - if errors.Is(err, io.EOF) { - writeStream = nil - } - c.log.Errorf("sending event: %v", err) + if err := ws.Send(e); err != nil { continue } metrics.AgentExportedEventsTotal.With(prometheus.Labels{metrics.EventTypeLabel: e.GetEventType().String()}).Inc() diff --git a/cmd/controller/state/delta/controller_test.go b/cmd/controller/state/delta/controller_test.go index 9990cdc7..ff564f4f 100644 --- a/cmd/controller/state/delta/controller_test.go +++ b/cmd/controller/state/delta/controller_test.go @@ -13,6 +13,7 @@ import ( "github.com/castai/kvisor/pkg/logging" "github.com/samber/lo" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/encoding/protojson" @@ -25,6 +26,8 @@ import ( ) func TestController(t *testing.T) { + defer goleak.VerifyNone(t) + ctx := context.Background() log := logging.NewTestLog() diff --git a/go.mod b/go.mod index f5f5bc2d..f6a426d4 100644 --- a/go.mod +++ b/go.mod @@ -34,6 +34,8 @@ require ( github.com/spf13/viper v1.18.2 github.com/stretchr/testify v1.8.4 github.com/vishvananda/netns v0.0.4 + go.uber.org/atomic v1.10.0 + go.uber.org/goleak v1.3.0 golang.org/x/net v0.20.0 golang.org/x/sync v0.6.0 golang.org/x/sys v0.16.0 diff --git a/go.sum b/go.sum index cbfb9d3a..4f169f13 100644 --- a/go.sum +++ b/go.sum @@ -1124,6 +1124,8 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI go.opentelemetry.io/proto/otlp v1.0.0/go.mod 
h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.starlark.net v0.0.0-20231121155337-90ade8b19d09 h1:hzy3LFnSN8kuQK8h9tHl4ndF6UruMj47OqwqsS+/Ai4= go.starlark.net v0.0.0-20231121155337-90ade8b19d09/go.mod h1:LcLNIzVOMp4oV+uusnpk+VU+SzXaJakUuBjoCSWH5dM= +go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= +go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= diff --git a/pkg/castai/client_test.go b/pkg/castai/client_test.go index 06677aeb..df9e4e95 100644 --- a/pkg/castai/client_test.go +++ b/pkg/castai/client_test.go @@ -106,6 +106,7 @@ func TestRemote(t *testing.T) { } type testServer struct { + eventsWriteStreamHandler func(server castaipb.RuntimeSecurityAgentAPI_EventsWriteStreamServer) error } func (t *testServer) KubeBenchReportIngest(ctx context.Context, report *castaipb.KubeBenchReport) (*castaipb.KubeBenchReportIngestResponse, error) { @@ -160,6 +161,10 @@ func (t *testServer) GetConfiguration(ctx context.Context, request *castaipb.Get } func (t *testServer) EventsWriteStream(server castaipb.RuntimeSecurityAgentAPI_EventsWriteStreamServer) error { + if t.eventsWriteStreamHandler != nil { + return t.eventsWriteStreamHandler(server) + } + md, ok := metadata.FromIncomingContext(server.Context()) if !ok { return errors.New("no metadata") diff --git a/pkg/castai/errors.go b/pkg/castai/errors.go new file mode 100644 index 00000000..66643ebc --- /dev/null +++ b/pkg/castai/errors.go @@ -0,0 +1,19 @@ +package castai + +import ( + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func IsGRPCError(err error, codes ...codes.Code) bool { + st, ok := status.FromError(err) + if !ok { + return false + } + for _, code := range codes { + if st.Code() == code { + return true + 
} + } + return false +} diff --git a/pkg/castai/logs_exporter.go b/pkg/castai/logs_exporter.go index 924466e2..25938d98 100644 --- a/pkg/castai/logs_exporter.go +++ b/pkg/castai/logs_exporter.go @@ -2,13 +2,12 @@ package castai import ( "context" - "errors" - "io" "log/slog" "time" castaipb "github.com/castai/kvisor/api/v1/runtime" "github.com/castai/kvisor/pkg/logging" + "google.golang.org/grpc" ) func NewLogsExporter(client *Client) *LogsExporter { @@ -24,33 +23,18 @@ type LogsExporter struct { } func (l *LogsExporter) Run(ctx context.Context) error { - var writeStream castaipb.RuntimeSecurityAgentAPI_LogsWriteStreamClient - var err error - - defer func() { - if writeStream != nil { - _ = writeStream.CloseSend() - } - }() + ws := NewWriteStream[*castaipb.LogEvent, *castaipb.SendLogsResponse](ctx, func(ctx context.Context) (grpc.ClientStream, error) { + return l.client.GRPC.LogsWriteStream(ctx) + }) + defer ws.Close() + ws.ReopenDelay = 1 * time.Second for { - if writeStream == nil { - writeStream, err = l.client.GRPC.LogsWriteStream(ctx) - if err != nil { - time.Sleep(1 * time.Second) - continue - } - } - select { case <-ctx.Done(): return ctx.Err() case e := <-l.logsChan: - if err := writeStream.Send(e); err != nil { - if errors.Is(err, io.EOF) { - writeStream = nil - } - } + _ = ws.Send(e) } } } diff --git a/pkg/castai/write_stream.go b/pkg/castai/write_stream.go new file mode 100644 index 00000000..4da43a2b --- /dev/null +++ b/pkg/castai/write_stream.go @@ -0,0 +1,88 @@ +package castai + +import ( + "context" + "errors" + "fmt" + "time" + + "google.golang.org/grpc" +) + +var ( + errNoActiveStream = errors.New("no active stream") +) + +func NewWriteStream[T, U any](ctx context.Context, createStreamFunc func(ctx context.Context) (grpc.ClientStream, error)) *WriteStream[T, U] { + return &WriteStream[T, U]{ + rootCtx: ctx, + createStreamFunc: createStreamFunc, + } +} + +// WriteStream wraps grpc client stream and handles stream reopen in case of send errors. 
+type WriteStream[T, U any] struct { + rootCtx context.Context + createStreamFunc func(ctx context.Context) (grpc.ClientStream, error) + activeStream grpc.ClientStream + activeStreamCtx context.Context + activeStreamCtxCancel context.CancelFunc + wasOpened bool + + ReopenDelay time.Duration +} + +func (w *WriteStream[T, U]) Send(m T) error { + if w.activeStream == nil { + if err := w.open(); err != nil { + return err + } + } + + if err := w.activeStream.SendMsg(m); err != nil { + w.close() + return err + } + return nil +} + +func (w *WriteStream[T, U]) Recv(m U) error { + if w.activeStream == nil { + return errNoActiveStream + } + return w.activeStream.RecvMsg(m) +} + +func (w *WriteStream[T, U]) Close() error { + if w.activeStream == nil { + return errNoActiveStream + } + err := w.activeStream.CloseSend() + w.close() + return err +} + +func (w *WriteStream[T, U]) open() error { + if w.wasOpened && w.ReopenDelay != 0 { + time.Sleep(w.ReopenDelay) + } + var err error + w.activeStreamCtx, w.activeStreamCtxCancel = context.WithCancel(w.rootCtx) + w.activeStream, err = w.createStreamFunc(w.activeStreamCtx) + if err != nil { + w.close() + return fmt.Errorf("open stream: %w", err) + } + w.wasOpened = true + return nil +} + +func (w *WriteStream[T, U]) close() { + // To properly close active stream we can cancel its context. 
+ // See https://github.com/grpc/grpc-go/blob/master/stream.go#L148 + if w.activeStreamCtxCancel != nil { + w.activeStreamCtxCancel() + } + w.activeStreamCtx = nil + w.activeStream = nil +} diff --git a/pkg/castai/write_stream_test.go b/pkg/castai/write_stream_test.go new file mode 100644 index 00000000..4b448244 --- /dev/null +++ b/pkg/castai/write_stream_test.go @@ -0,0 +1,86 @@ +package castai + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + castaipb "github.com/castai/kvisor/api/v1/runtime" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + "go.uber.org/goleak" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestWriteStream(t *testing.T) { + defer goleak.VerifyNone(t) + + r := require.New(t) + ctx := context.Background() + + // Setup grpc test server which implements castai api. + ports, err := allocatePorts(1) + r.NoError(err) + addr := fmt.Sprintf("localhost:%d", ports[0]) + lis, err := net.Listen("tcp", addr) + r.NoError(err) + defer lis.Close() + s := grpc.NewServer() + + serverStreamOpenCount := atomic.NewInt64(0) + srv := &testServer{ + eventsWriteStreamHandler: func(server castaipb.RuntimeSecurityAgentAPI_EventsWriteStreamServer) error { + serverStreamOpenCount.Add(1) + var count int + for { + _, _ = server.Recv() + count++ + if count > 10 { + return status.Error(codes.Internal, "internal error") + } + } + }, + } + castaipb.RegisterRuntimeSecurityAgentAPIServer(s, srv) + go s.Serve(lis) + + clusterID := uuid.NewString() + client, err := NewClient("test", Config{ + ClusterID: clusterID, + APIKey: "api-key", + APIGrpcAddr: addr, + }) + r.NoError(err) + defer client.Close() + + ws := NewWriteStream[*castaipb.Event, *castaipb.WriteStreamResponse](ctx, func(ctx context.Context) (grpc.ClientStream, error) { + return client.GRPC.EventsWriteStream(ctx) + }) + ws.ReopenDelay = 1 * time.Millisecond + + var errs []error + for i := 0; i < 100; i++ { + if 
err := ws.Send(&castaipb.Event{}); err != nil { + errs = append(errs, err) + } + time.Sleep(1 * time.Millisecond) + } + + timeout := time.After(2 * time.Second) + ticker := time.NewTicker(10 * time.Millisecond) + for { + select { + case <-timeout: + t.Fatal("timeout") + case <-ticker.C: + if serverStreamOpenCount.Load() > 0 && len(errs) > 0 { + return + } + } + } +}