diff --git a/internal/services/controller/controller.go b/internal/services/controller/controller.go index adaa7980..e8700830 100644 --- a/internal/services/controller/controller.go +++ b/internal/services/controller/controller.go @@ -2,11 +2,7 @@ package controller import ( "context" - "encoding/base64" - "encoding/json" - "fmt" "reflect" - "strings" "sync" "time" @@ -15,6 +11,7 @@ import ( batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" storagev1 "k8s.io/api/storage/v1" + "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/informers" "k8s.io/client-go/tools/cache" @@ -29,6 +26,7 @@ import ( type Controller struct { log logrus.FieldLogger + clusterID string castaiclient castai.Client provider types.Provider queue workqueue.RateLimitingInterface @@ -36,7 +34,7 @@ type Controller struct { prepDuration time.Duration informers map[reflect.Type]cache.SharedInformer - delta *castai.Delta + delta *delta mu sync.Mutex spotCache map[string]bool agentVersion *config.AgentVersion @@ -74,11 +72,12 @@ func New( c := &Controller{ log: log, + clusterID: clusterID, castaiclient: castaiclient, provider: provider, interval: interval, prepDuration: prepDuration, - delta: &castai.Delta{ClusterID: clusterID, ClusterVersion: v.Full(), FullSnapshot: true}, + delta: newDelta(log, clusterID, v.Full()), spotCache: map[string]bool{}, queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "castai-agent"), informers: typeInformerMap, @@ -95,25 +94,25 @@ func New( if typ == reflect.TypeOf(&corev1.Node{}) { h = cache.ResourceEventHandlerFuncs{ AddFunc: func(obj interface{}) { - c.nodeAddHandler(log, castai.EventAdd, obj) + c.nodeAddHandler(log, eventAdd, obj) }, UpdateFunc: func(oldObj, newObj interface{}) { - c.nodeAddHandler(log, castai.EventUpdate, newObj) + c.nodeAddHandler(log, eventUpdate, newObj) }, DeleteFunc: func(obj interface{}) { - c.nodeDeleteHandler(log, castai.EventDelete, obj) + c.nodeDeleteHandler(log, 
eventDelete, obj) }, } } else { h = cache.ResourceEventHandlerFuncs{ AddFunc: func(obj interface{}) { - genericHandler(log, c.queue, typ, castai.EventAdd, obj) + genericHandler(log, c.queue, typ, eventAdd, obj) }, UpdateFunc: func(oldObj, newObj interface{}) { - genericHandler(log, c.queue, typ, castai.EventUpdate, newObj) + genericHandler(log, c.queue, typ, eventUpdate, newObj) }, DeleteFunc: func(obj interface{}) { - genericHandler(log, c.queue, typ, castai.EventDelete, obj) + genericHandler(log, c.queue, typ, eventDelete, obj) }, } } @@ -124,11 +123,7 @@ func New( return c } -func (c *Controller) nodeAddHandler( - log logrus.FieldLogger, - event castai.EventType, - obj interface{}, -) { +func (c *Controller) nodeAddHandler(log logrus.FieldLogger, event event, obj interface{}) { node, ok := obj.(*corev1.Node) if !ok { log.Errorf("expected to get *corev1.Node but got %T", obj) @@ -153,11 +148,7 @@ func (c *Controller) nodeAddHandler( genericHandler(log, c.queue, reflect.TypeOf(&corev1.Node{}), event, node) } -func (c *Controller) nodeDeleteHandler( - log logrus.FieldLogger, - event castai.EventType, - obj interface{}, -) { +func (c *Controller) nodeDeleteHandler(log logrus.FieldLogger, event event, obj interface{}) { node, ok := obj.(*corev1.Node) if !ok { log.Errorf("expected to get *corev1.Node but got %T", obj) @@ -173,7 +164,7 @@ func genericHandler( log logrus.FieldLogger, queue workqueue.RateLimitingInterface, expected reflect.Type, - event castai.EventType, + event event, obj interface{}, ) { if reflect.TypeOf(obj) != expected { @@ -181,31 +172,12 @@ func genericHandler( return } - typeName := expected.String() - kind := typeName[strings.LastIndex(typeName, ".")+1:] - - data, err := encode(obj) - if err != nil { - log.Errorf("failed to encode %T: %v", obj, err) - return - } - - queue.Add(&castai.DeltaItem{ - Event: event, - Kind: kind, - Data: data, - CreatedAt: time.Now().UTC(), + queue.Add(&item{ + obj: obj.(runtime.Object), + event: event, }) } -func 
encode(obj interface{}) (string, error) { - b, err := json.Marshal(obj) - if err != nil { - return "", fmt.Errorf("marshaling %T to json: %v", obj, err) - } - return base64.StdEncoding.EncodeToString(b), nil -} - func (c *Controller) Run(ctx context.Context) { defer c.queue.ShutDown() @@ -230,13 +202,13 @@ func (c *Controller) Run(ctx context.Context) { AgentVersion: c.agentVersion.Version, GitCommit: c.agentVersion.GitCommit, } - cfg, err := c.castaiclient.ExchangeAgentTelemetry(ctx, c.delta.ClusterID, req) + cfg, err := c.castaiclient.ExchangeAgentTelemetry(ctx, c.clusterID, req) if err != nil { c.log.Errorf("failed getting agent configuration: %v", err) return } // Resync only when at least one full snapshot has already been sent. - if cfg.Resync && !c.delta.FullSnapshot { + if cfg.Resync && !c.delta.fullSnapshot { c.log.Info("restarting controller to resync data") cancel() } @@ -262,19 +234,19 @@ func (c *Controller) Run(ctx context.Context) { func (c *Controller) pollQueueUntilDone() { for { - item, done := c.queue.Get() + i, done := c.queue.Get() if done { return } - di, ok := item.(*castai.DeltaItem) + di, ok := i.(*item) if !ok { - c.log.Errorf("expected queue item to be of type %T but got %T", &castai.DeltaItem{}, item) + c.log.Errorf("expected queue item to be of type %T but got %T", &item{}, i) continue } c.mu.Lock() - c.delta.Items = append(c.delta.Items, di) + c.delta.add(di) c.mu.Unlock() } } @@ -283,11 +255,10 @@ func (c *Controller) send(ctx context.Context) { c.mu.Lock() defer c.mu.Unlock() - if err := c.castaiclient.SendDelta(ctx, c.delta); err != nil { + if err := c.castaiclient.SendDelta(ctx, c.delta.toCASTAIRequest()); err != nil { c.log.Errorf("failed sending delta: %v", err) return } - c.delta.Items = nil - c.delta.FullSnapshot = false + c.delta.clear() } diff --git a/internal/services/controller/delta.go b/internal/services/controller/delta.go new file mode 100644 index 00000000..799edc27 --- /dev/null +++ 
b/internal/services/controller/delta.go @@ -0,0 +1,148 @@ +package controller + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "reflect" + "time" + + "github.com/sirupsen/logrus" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + + "castai-agent/internal/castai" +) + +// newDelta initializes the delta struct which is used to collect cluster deltas, debounce them and map to CASTAI +// requests. +func newDelta(log logrus.FieldLogger, clusterID, clusterVersion string) *delta { + return &delta{ + log: log, + clusterID: clusterID, + clusterVersion: clusterVersion, + fullSnapshot: true, + cache: map[string]*item{}, + } +} + +// delta is used to collect cluster deltas, debounce them and map to CASTAI requests. It holds a cache of queue items +// which is referenced any time a new item is added to debounce the items. +type delta struct { + log logrus.FieldLogger + clusterID string + clusterVersion string + fullSnapshot bool + cache map[string]*item +} + +// add will add an item to the delta cache. It will debounce the objects. +func (d *delta) add(i *item) { + key := mustKeyObject(i.obj) + + if other, ok := d.cache[key]; ok && other.event == eventAdd && i.event == eventDelete { + delete(d.cache, key) + } else if ok && other.event == eventAdd && i.event == eventUpdate { + i.event = eventAdd + d.cache[key] = i + } else if ok && other.event == eventDelete && (i.event == eventAdd || i.event == eventUpdate) { + i.event = eventUpdate + d.cache[key] = i + } else { + d.cache[key] = i + } +} + +// clear resets the delta cache and sets fullSnapshot to false. Should be called after toCASTAIRequest is successfully +// delivered. +func (d *delta) clear() { + d.fullSnapshot = false + d.cache = map[string]*item{} +} + +// toCASTAIRequest maps the collected delta cache to the castai.Delta type. 
+func (d *delta) toCASTAIRequest() *castai.Delta { + var items []*castai.DeltaItem + + for _, i := range d.cache { + data, err := encode(i.obj) + if err != nil { + d.log.Errorf("failed to encode %T: %v", i.obj, err) + continue + } + + kinds, _, err := scheme.ObjectKinds(i.obj) + if err != nil { + d.log.Errorf("failed to find object %T kind: %v", i.obj, err) + continue + } + if len(kinds) == 0 || kinds[0].Kind == "" { + d.log.Errorf("unknown object kind for object %T", i.obj) + continue + } + + items = append(items, &castai.DeltaItem{ + Event: i.event.toCASTAIEvent(), + Kind: kinds[0].Kind, + Data: data, + CreatedAt: time.Now().UTC(), + }) + } + + return &castai.Delta{ + ClusterID: d.clusterID, + ClusterVersion: d.clusterVersion, + FullSnapshot: d.fullSnapshot, + Items: items, + } +} + +func encode(obj interface{}) (string, error) { + b, err := json.Marshal(obj) + if err != nil { + return "", fmt.Errorf("marshaling %T to json: %v", obj, err) + } + return base64.StdEncoding.EncodeToString(b), nil +} + +type item struct { + obj runtime.Object + event event +} + +type event string + +const ( + eventAdd event = "add" + eventDelete event = "delete" + eventUpdate event = "update" +) + +func (e event) toCASTAIEvent() castai.EventType { + switch e { + case eventAdd: + return castai.EventAdd + case eventDelete: + return castai.EventDelete + case eventUpdate: + return castai.EventUpdate + } + return "" +} + +// keyObject generates a unique key for an object, for example: `*v1.Pod::namespace/name`. 
+func keyObject(obj runtime.Object) (string, error) { + metaObj, ok := obj.(metav1.Object) + if !ok { + return "", fmt.Errorf("expected object of type %T to implement metav1.Object", obj) + } + return fmt.Sprintf("%s::%s/%s", reflect.TypeOf(obj).String(), metaObj.GetNamespace(), metaObj.GetName()), nil +} + +func mustKeyObject(obj runtime.Object) string { + k, err := keyObject(obj) + if err != nil { + panic(fmt.Errorf("getting object key: %w", err)) + } + return k +} diff --git a/internal/services/controller/delta_test.go b/internal/services/controller/delta_test.go new file mode 100644 index 00000000..b88ce812 --- /dev/null +++ b/internal/services/controller/delta_test.go @@ -0,0 +1,222 @@ +package controller + +import ( + "testing" + + "github.com/google/uuid" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "castai-agent/internal/castai" +) + +func TestDelta(t *testing.T) { + clusterID := uuid.New().String() + version := "1.18" + + pod1 := &v1.Pod{ObjectMeta: metav1.ObjectMeta{Namespace: v1.NamespaceDefault, Name: "a"}} + pod1Updated := &v1.Pod{ObjectMeta: metav1.ObjectMeta{Namespace: v1.NamespaceDefault, Name: "a", Labels: map[string]string{"a": "b"}}} + + pod2 := &v1.Pod{ObjectMeta: metav1.ObjectMeta{Namespace: v1.NamespaceDefault, Name: "b"}} + + tests := []struct { + name string + items []*item + expected *castai.Delta + }{ + { + name: "empty items", + items: []*item{}, + expected: &castai.Delta{ + ClusterID: clusterID, + ClusterVersion: version, + FullSnapshot: true, + }, + }, + { + name: "multiple items", + items: []*item{ + { + obj: pod1, + event: eventAdd, + }, + { + obj: pod2, + event: eventAdd, + }, + }, + expected: &castai.Delta{ + ClusterID: clusterID, + ClusterVersion: version, + FullSnapshot: true, + Items: []*castai.DeltaItem{ + { + Event: castai.EventAdd, + Kind: "Pod", + Data: mustEncode(t, pod1), + }, + { + Event: castai.EventAdd, + Kind: "Pod", 
+ Data: mustEncode(t, pod2), + }, + }, + }, + }, + { + name: "debounce: override added item with updated data", + items: []*item{ + { + obj: pod1, + event: eventAdd, + }, + { + obj: pod1Updated, + event: eventUpdate, + }, + }, + expected: &castai.Delta{ + ClusterID: clusterID, + ClusterVersion: version, + FullSnapshot: true, + Items: []*castai.DeltaItem{ + { + Event: castai.EventAdd, + Kind: "Pod", + Data: mustEncode(t, pod1Updated), + }, + }, + }, + }, + { + name: "debounce: entirely remove added item when it is deleted", + items: []*item{ + { + obj: pod1, + event: eventAdd, + }, + { + obj: pod1, + event: eventDelete, + }, + }, + expected: &castai.Delta{ + ClusterID: clusterID, + ClusterVersion: version, + FullSnapshot: true, + }, + }, + { + name: "debounce: keep only delete event when an updated item is deleted", + items: []*item{ + { + obj: pod1, + event: eventUpdate, + }, + { + obj: pod1, + event: eventDelete, + }, + }, + expected: &castai.Delta{ + ClusterID: clusterID, + ClusterVersion: version, + FullSnapshot: true, + Items: []*castai.DeltaItem{ + { + Event: castai.EventDelete, + Kind: "Pod", + Data: mustEncode(t, pod1), + }, + }, + }, + }, + { + name: "debounce: override updated item with newer updated data", + items: []*item{ + { + obj: pod1, + event: eventUpdate, + }, + { + obj: pod1Updated, + event: eventUpdate, + }, + }, + expected: &castai.Delta{ + ClusterID: clusterID, + ClusterVersion: version, + FullSnapshot: true, + Items: []*castai.DeltaItem{ + { + Event: castai.EventUpdate, + Kind: "Pod", + Data: mustEncode(t, pod1Updated), + }, + }, + }, + }, + { + name: "debounce: change deleted item to updated when it is readded", + items: []*item{ + { + obj: pod1, + event: eventDelete, + }, + { + obj: pod1Updated, + event: eventAdd, + }, + }, + expected: &castai.Delta{ + ClusterID: clusterID, + ClusterVersion: version, + FullSnapshot: true, + Items: []*castai.DeltaItem{ + { + Event: castai.EventUpdate, + Kind: "Pod", + Data: mustEncode(t, pod1Updated), + }, + 
}, + }, + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + d := newDelta(logrus.New(), clusterID, version) + + for _, item := range test.items { + d.add(item) + } + + got := d.toCASTAIRequest() + + require.Equal(t, clusterID, got.ClusterID) + require.Equal(t, version, got.ClusterVersion) + require.True(t, got.FullSnapshot) + require.Equal(t, len(got.Items), len(test.expected.Items)) + for _, expectedItem := range test.expected.Items { + requireContains(t, got.Items, expectedItem) + } + }) + } +} + +func mustEncode(t *testing.T, obj interface{}) string { + data, err := encode(obj) + require.NoError(t, err) + return data +} + +func requireContains(t *testing.T, actual []*castai.DeltaItem, expected *castai.DeltaItem) { + for _, di := range actual { + if di.Kind == expected.Kind && di.Event == expected.Event && di.Data == expected.Data { + return + } + } + require.Failf(t, "failed", "expected %s to contain %s", actual, expected) +} diff --git a/internal/services/controller/register.go b/internal/services/controller/register.go new file mode 100644 index 00000000..862dd5ff --- /dev/null +++ b/internal/services/controller/register.go @@ -0,0 +1,22 @@ +package controller + +import ( + appsv1 "k8s.io/api/apps/v1" + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + storagev1 "k8s.io/api/storage/v1" + "k8s.io/apimachinery/pkg/runtime" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" +) + +var scheme = runtime.NewScheme() +var builder = runtime.SchemeBuilder{ + corev1.AddToScheme, + appsv1.AddToScheme, + storagev1.AddToScheme, + batchv1.AddToScheme, +} + +func init() { + utilruntime.Must(builder.AddToScheme(scheme)) +}