Skip to content

Commit

Permalink
Fix grpc stream leak (#233)
Browse files Browse the repository at this point in the history
  • Loading branch information
anjmao authored Mar 22, 2024
1 parent 17667f4 commit 5d24cfe
Show file tree
Hide file tree
Showing 11 changed files with 231 additions and 101 deletions.
35 changes: 9 additions & 26 deletions cmd/agent/daemon/state/container_stats_pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,31 @@ import (

castpb "github.com/castai/kvisor/api/v1/runtime"
"github.com/castai/kvisor/cmd/agent/daemon/netstats"
"github.com/castai/kvisor/pkg/castai"
"github.com/castai/kvisor/pkg/containers"
"github.com/castai/kvisor/pkg/ebpftracer"
"github.com/castai/kvisor/pkg/metrics"
"github.com/castai/kvisor/pkg/stats"
"github.com/samber/lo"
"google.golang.org/grpc/codes"
"google.golang.org/grpc"
"k8s.io/apimachinery/pkg/api/resource"
)

func (c *Controller) runContainerStatsPipeline(ctx context.Context) error {
c.log.Info("running container stats sink loop")
defer c.log.Info("container stats sink loop done")

var writeStream castpb.RuntimeSecurityAgentAPI_ContainerStatsWriteStreamClient
var err error
defer func() {
if writeStream != nil {
_ = writeStream.CloseSend()
}
}()
ws := castai.NewWriteStream[*castpb.ContainerStatsBatch, *castpb.WriteStreamResponse](ctx, func(ctx context.Context) (grpc.ClientStream, error) {
return c.castClient.GRPC.ContainerStatsWriteStream(ctx)
})
defer ws.Close()
ws.ReopenDelay = c.writeStreamCreateRetryDelay

send := func(batch *castpb.ContainerStatsBatch) {
c.log.Debugf("sending container cgroup stats, items=%d", len(batch.GetItems()))
if err := writeStream.Send(batch); err != nil {
if err := ws.Send(batch); err != nil {
if errors.Is(err, io.EOF) {
writeStream = nil
return
}
c.log.Errorf("sending container cgroup stats: %v", err)
return
Expand All @@ -46,22 +45,6 @@ func (c *Controller) runContainerStatsPipeline(ctx context.Context) error {
defer ticker.Stop()

for {
// Create stream.
if writeStream == nil {
select {
case <-ctx.Done():
return ctx.Err()
default:
writeStream, err = c.castClient.GRPC.ContainerStatsWriteStream(ctx)
if err != nil {
if !isGRPCError(err, codes.Unavailable, codes.Canceled) {
c.log.Warnf("create write stream: %v", err)
}
time.Sleep(c.writeStreamCreateRetryDelay)
continue
}
}
}

select {
case <-ctx.Done():
Expand Down
20 changes: 2 additions & 18 deletions cmd/agent/daemon/state/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,18 @@ import (
"sync"
"time"

castpb "github.com/castai/kvisor/api/v1/runtime"
"github.com/castai/kvisor/cmd/agent/daemon/conntrack"
"github.com/castai/kvisor/cmd/agent/daemon/enrichment"
"github.com/castai/kvisor/cmd/agent/daemon/netstats"
"github.com/castai/kvisor/cmd/agent/kube"
"github.com/castai/kvisor/pkg/castai"
"github.com/castai/kvisor/pkg/cgroup"
"github.com/castai/kvisor/pkg/containers"
"github.com/castai/kvisor/pkg/ebpftracer"
"github.com/castai/kvisor/pkg/ebpftracer/signature"
"github.com/castai/kvisor/pkg/logging"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

castpb "github.com/castai/kvisor/api/v1/runtime"
"github.com/castai/kvisor/pkg/containers"
)

type Config struct {
Expand Down Expand Up @@ -161,19 +158,6 @@ type syscallScrapePoint struct {
syscalls map[ebpftracer.SyscallID]uint64
}

func isGRPCError(err error, codes ...codes.Code) bool {
st, ok := status.FromError(err)
if !ok {
return false
}
for _, code := range codes {
if st.Code() == code {
return true
}
}
return false
}

func (c *Controller) MuteNamespace(namespace string) error {
c.mutedNamespacesMu.Lock()
c.mutedNamespaces[namespace] = struct{}{}
Expand Down
42 changes: 8 additions & 34 deletions cmd/agent/daemon/state/events_pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@ package state

import (
"context"
"errors"
"fmt"
"io"
"time"

castpb "github.com/castai/kvisor/api/v1/runtime"
"github.com/castai/kvisor/pkg/castai"
"github.com/castai/kvisor/pkg/metrics"
"github.com/prometheus/client_golang/prometheus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc"
"k8s.io/apimachinery/pkg/types"
)

Expand All @@ -19,43 +17,19 @@ func (c *Controller) runEventsExportLoop(ctx context.Context) error {
c.log.Info("running events sink loop")
defer c.log.Info("events sink loop done")

var writeStream castpb.RuntimeSecurityAgentAPI_EventsWriteStreamClient
var err error

defer func() {
if writeStream != nil {
_ = writeStream.CloseSend()
}
}()
ws := castai.NewWriteStream[*castpb.Event, *castpb.WriteStreamResponse](ctx, func(ctx context.Context) (grpc.ClientStream, error) {
return c.castClient.GRPC.EventsWriteStream(ctx)
})
defer ws.Close()
ws.ReopenDelay = c.writeStreamCreateRetryDelay

for {
// Create stream.
if writeStream == nil {
select {
case <-ctx.Done():
return ctx.Err()
default:
writeStream, err = c.castClient.GRPC.EventsWriteStream(ctx)
if err != nil {
if !isGRPCError(err, codes.Unavailable, codes.Canceled) {
c.log.Warnf("create write stream: %v", err)
}
time.Sleep(c.writeStreamCreateRetryDelay)
continue
}
}
}

select {
case <-ctx.Done():
return ctx.Err()
case e := <-c.eventsExportQueue:
c.enrichEvent(e)
if err := writeStream.Send(e); err != nil {
if errors.Is(err, io.EOF) {
writeStream = nil
}
c.log.Errorf("sending event: %v", err)
if err := ws.Send(e); err != nil {
continue
}
metrics.AgentExportedEventsTotal.With(prometheus.Labels{metrics.EventTypeLabel: e.GetEventType().String()}).Inc()
Expand Down
3 changes: 3 additions & 0 deletions cmd/controller/state/delta/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/castai/kvisor/pkg/logging"
"github.com/samber/lo"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/encoding/protojson"
Expand All @@ -25,6 +26,8 @@ import (
)

func TestController(t *testing.T) {
defer goleak.VerifyNone(t)

ctx := context.Background()
log := logging.NewTestLog()

Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ require (
github.com/spf13/viper v1.18.2
github.com/stretchr/testify v1.8.4
github.com/vishvananda/netns v0.0.4
go.uber.org/atomic v1.10.0
go.uber.org/goleak v1.3.0
golang.org/x/net v0.20.0
golang.org/x/sync v0.6.0
golang.org/x/sys v0.16.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,8 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.starlark.net v0.0.0-20231121155337-90ade8b19d09 h1:hzy3LFnSN8kuQK8h9tHl4ndF6UruMj47OqwqsS+/Ai4=
go.starlark.net v0.0.0-20231121155337-90ade8b19d09/go.mod h1:LcLNIzVOMp4oV+uusnpk+VU+SzXaJakUuBjoCSWH5dM=
go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ=
go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
Expand Down
5 changes: 5 additions & 0 deletions pkg/castai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ func TestRemote(t *testing.T) {
}

type testServer struct {
eventsWriteStreamHandler func(server castaipb.RuntimeSecurityAgentAPI_EventsWriteStreamServer) error
}

func (t *testServer) KubeBenchReportIngest(ctx context.Context, report *castaipb.KubeBenchReport) (*castaipb.KubeBenchReportIngestResponse, error) {
Expand Down Expand Up @@ -160,6 +161,10 @@ func (t *testServer) GetConfiguration(ctx context.Context, request *castaipb.Get
}

func (t *testServer) EventsWriteStream(server castaipb.RuntimeSecurityAgentAPI_EventsWriteStreamServer) error {
if t.eventsWriteStreamHandler != nil {
return t.eventsWriteStreamHandler(server)
}

md, ok := metadata.FromIncomingContext(server.Context())
if !ok {
return errors.New("no metadata")
Expand Down
19 changes: 19 additions & 0 deletions pkg/castai/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package castai

import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

func IsGRPCError(err error, codes ...codes.Code) bool {
st, ok := status.FromError(err)
if !ok {
return false
}
for _, code := range codes {
if st.Code() == code {
return true
}
}
return false
}
30 changes: 7 additions & 23 deletions pkg/castai/logs_exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@ package castai

import (
"context"
"errors"
"io"
"log/slog"
"time"

castaipb "github.com/castai/kvisor/api/v1/runtime"
"github.com/castai/kvisor/pkg/logging"
"google.golang.org/grpc"
)

func NewLogsExporter(client *Client) *LogsExporter {
Expand All @@ -24,33 +23,18 @@ type LogsExporter struct {
}

func (l *LogsExporter) Run(ctx context.Context) error {
var writeStream castaipb.RuntimeSecurityAgentAPI_LogsWriteStreamClient
var err error

defer func() {
if writeStream != nil {
_ = writeStream.CloseSend()
}
}()
ws := NewWriteStream[*castaipb.LogEvent, *castaipb.SendLogsResponse](ctx, func(ctx context.Context) (grpc.ClientStream, error) {
return l.client.GRPC.LogsWriteStream(ctx)
})
defer ws.Close()
ws.ReopenDelay = 1 * time.Second

for {
if writeStream == nil {
writeStream, err = l.client.GRPC.LogsWriteStream(ctx)
if err != nil {
time.Sleep(1 * time.Second)
continue
}
}

select {
case <-ctx.Done():
return ctx.Err()
case e := <-l.logsChan:
if err := writeStream.Send(e); err != nil {
if errors.Is(err, io.EOF) {
writeStream = nil
}
}
_ = ws.Send(e)
}
}
}
Expand Down
88 changes: 88 additions & 0 deletions pkg/castai/write_stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package castai

import (
"context"
"errors"
"fmt"
"time"

"google.golang.org/grpc"
)

var (
errNoActiveStream = errors.New("no active stream")
)

func NewWriteStream[T, U any](ctx context.Context, createStreamFunc func(ctx context.Context) (grpc.ClientStream, error)) *WriteStream[T, U] {
return &WriteStream[T, U]{
rootCtx: ctx,
createStreamFunc: createStreamFunc,
}
}

// WriteStream wraps grpc client stream and handles stream reopen in case of send errors.
type WriteStream[T, U any] struct {
rootCtx context.Context
createStreamFunc func(ctx context.Context) (grpc.ClientStream, error)
activeStream grpc.ClientStream
activeStreamCtx context.Context
activeStreamCtxCancel context.CancelFunc
wasOpened bool

ReopenDelay time.Duration
}

func (w *WriteStream[T, U]) Send(m T) error {
if w.activeStream == nil {
if err := w.open(); err != nil {
return err
}
}

if err := w.activeStream.SendMsg(m); err != nil {
w.close()
return err
}
return nil
}

func (w *WriteStream[T, U]) Recv(m T) error {
if w.activeStream == nil {
return errNoActiveStream
}
return w.activeStream.RecvMsg(m)
}

func (w *WriteStream[T, U]) Close() error {
if w.activeStream == nil {
return errNoActiveStream
}
err := w.activeStream.CloseSend()
w.close()
return err
}

func (w *WriteStream[T, U]) open() error {
if w.wasOpened && w.ReopenDelay != 0 {
time.Sleep(w.ReopenDelay)
}
var err error
w.activeStreamCtx, w.activeStreamCtxCancel = context.WithCancel(w.rootCtx)
w.activeStream, err = w.createStreamFunc(w.activeStreamCtx)
if err != nil {
w.close()
return fmt.Errorf("open stream: %w", err)
}
w.wasOpened = true
return nil
}

func (w *WriteStream[T, U]) close() {
// To properly close active stream we can cancel it's context.
// See https://github.com/grpc/grpc-go/blob/master/stream.go#L148
if w.activeStreamCtxCancel != nil {
w.activeStreamCtxCancel()
}
w.activeStreamCtx = nil
w.activeStream = nil
}
Loading

0 comments on commit 5d24cfe

Please sign in to comment.