diff --git a/controller/cache.go b/controller/cache.go index e6932fc..37fc0a7 100644 --- a/controller/cache.go +++ b/controller/cache.go @@ -27,93 +27,50 @@ func (s Store) FilterByGroupKind(gk schema.GroupKind) []Object { }) } -type Cache interface { - List() Store - Add(obj Object) - Delete(obj Object) - Replace(Store) +func (s Store) DeepCopy() Store { + return lo.SliceToMap(lo.Keys(s), func(uid string) (string, Object) { + return uid, s[uid].DeepCopyObject().(Object) + }) +} + +func (s Store) Equal(other Store) bool { + return len(s) == len(other) && lo.EveryBy(lo.Keys(s), func(uid string) bool { + otherObj, ok := other[uid] + return ok && reflect.DeepEqual(s[uid], otherObj) + }) } -type cacheStore struct { +type CacheStore struct { sync.RWMutex - store Store + watchable.Map[string, Store] } -func (c *cacheStore) List() Store { +func (c *CacheStore) List(storeId string) Store { c.RLock() defer c.RUnlock() - - ret := make(Store, len(c.store)) - for k, v := range c.store { - ret[k] = v.DeepCopyObject().(Object) - } - return ret + store, _ := c.Load(storeId) + return store } -func (c *cacheStore) Add(obj Object) { +func (c *CacheStore) Add(storeId string, obj Object) { c.Lock() defer c.Unlock() - - c.store[string(obj.GetUID())] = obj + uid := string(obj.GetUID()) + store, _ := c.Load(storeId) + store[uid] = obj + c.Store(storeId, store) } -func (c *cacheStore) Delete(obj Object) { +func (c *CacheStore) Delete(storeId string, obj Object) { c.Lock() defer c.Unlock() - - delete(c.store, string(obj.GetUID())) + store, _ := c.Load(storeId) + delete(store, string(obj.GetUID())) + c.Store(storeId, store) } -func (c *cacheStore) Replace(store Store) { +func (c *CacheStore) Replace(storeId string, store Store) { c.Lock() defer c.Unlock() - - c.store = make(Store, len(store)) - for k, v := range store { - c.store[k] = v.DeepCopyObject().(Object) - } -} - -type watchableCacheStore struct { - watchable.Map[string, watchableCacheEntry] -} - -func (c *watchableCacheStore) List() Store { - entries := c.LoadAll() - store := make(Store, len(entries)) - for uid, obj := range entries { - store[uid] = obj.Object - } - return store -} - -func (c *watchableCacheStore) Add(obj Object) { - c.Store(string(obj.GetUID()), watchableCacheEntry{obj}) -} - -func (c *watchableCacheStore) Delete(obj Object) { - c.Map.Delete(string(obj.GetUID())) -} - -func (c *watchableCacheStore) Replace(store Store) { - for uid, obj := range store { - c.Store(uid, watchableCacheEntry{obj}) - } - for uid := range c.LoadAll() { - if _, ok := store[uid]; !ok { - c.Map.Delete(uid) - } - } -} - -type watchableCacheEntry struct { - Object -} - -func (e watchableCacheEntry) DeepCopy() watchableCacheEntry { - return watchableCacheEntry{e.DeepCopyObject().(Object)} -} - -func (e watchableCacheEntry) Equal(other watchableCacheEntry) bool { - return reflect.DeepEqual(e, other) + c.Store(storeId, store) } diff --git a/controller/controller.go b/controller/controller.go index f04dbcc..13f112b 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -3,12 +3,12 @@ package controller import ( "context" "fmt" + "reflect" "sync" "time" "github.com/go-logr/logr" "github.com/samber/lo" - "github.com/telepresenceio/watchable" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/dynamic" @@ -22,6 +22,8 @@ import ( "github.com/kuadrant/policy-machinery/machinery" ) +const resourceStoreId = "resources" + type ControllerOptions struct { name string logger logr.Logger @@ -124,7 +126,7 @@ func NewController(f ...ControllerOption) *Controller { logger: opts.logger, client: opts.client, manager: opts.manager, - cache: &watchableCacheStore{}, + cache: &CacheStore{}, topology: newGatewayAPITopologyBuilder(opts.policyKinds, opts.objectKinds, opts.objectLinks, opts.allowTopologyLoops), runnables: map[string]Runnable{}, reconcile: opts.reconcile, @@ -146,7 +148,7 @@ type Controller struct { logger logr.Logger client *dynamic.DynamicClient manager ctrlruntime.Manager - cache Cache + cache *CacheStore topology *gatewayAPITopologyBuilder runnables map[string]Runnable listFuncs []ListFunc @@ -159,7 +161,7 @@ func (c *Controller) Start(ctx context.Context) error { stopCh := make(chan struct{}) // subscribe to cache - c.subscribe() + c.subscribe(ctx) // start runnables for name := range c.runnables { @@ -217,7 +219,7 @@ func (c *Controller) Reconcile(ctx context.Context, _ ctrlruntimereconcile.Reque store[string(object.GetUID())] = object } } - c.cache.Replace(store) + c.cache.Replace(resourceStoreId, store) return ctrlruntimereconcile.Result{}, nil } @@ -234,25 +236,25 @@ func (c *Controller) add(obj Object) { c.Lock() defer c.Unlock() - c.cache.Add(obj) + c.cache.Add(resourceStoreId, obj) } func (c *Controller) update(_, newObj Object) { c.Lock() defer c.Unlock() - c.cache.Add(newObj) + c.cache.Add(resourceStoreId, newObj) } func (c *Controller) delete(obj Object) { c.Lock() defer c.Unlock() - c.cache.Delete(obj) + c.cache.Delete(resourceStoreId, obj) } func (c *Controller) propagate(resourceEvents []ResourceEvent) { - topology, err := c.topology.Build(c.cache.List()) + topology, err := c.topology.Build(c.cache.List(resourceStoreId)) if err != nil { c.logger.Error(err, "error building topology") } @@ -261,42 +263,53 @@ func (c *Controller) propagate(resourceEvents []ResourceEvent) { } } -func (c *Controller) subscribe() { - cache, ok := c.cache.(*watchableCacheStore) // should we add Subscribe(ctx) to the Cache interface or remove the interface altogether? - if !ok { - return - } - recent := make(Store) - subscription := cache.Subscribe(context.TODO()) +func (c *Controller) subscribe(ctx context.Context) { + oldObjs := make(Store) + subscription := c.cache.SubscribeSubset(ctx, func(storeId string, _ Store) bool { + return storeId == resourceStoreId + }) go func() { for snapshot := range subscription { c.Lock() - c.propagate(lo.FlatMap(snapshot.Updates, func(update watchable.Update[string, watchableCacheEntry], _ int) []ResourceEvent { - key := update.Key - obj := update.Value + newObjs := snapshot.State[resourceStoreId] + events := lo.FilterMap(lo.Keys(newObjs), func(uid string, _ int) (ResourceEvent, bool) { + newObj := newObjs[uid] event := ResourceEvent{ - Kind: obj.GetObjectKind().GroupVersionKind().GroupKind(), + Kind: newObj.GetObjectKind().GroupVersionKind().GroupKind(), + NewObject: newObj, + } + if oldObj, exists := oldObjs[uid]; !exists { + event.EventType = CreateEvent + oldObjs[uid] = newObj + return event, true + } else if !reflect.DeepEqual(oldObj, newObj) { + event.EventType = UpdateEvent + event.OldObject = oldObj + oldObjs[uid] = newObj + return event, true } + return event, false + }) - if update.Delete { - event.EventType = DeleteEvent - event.OldObject = obj - delete(recent, key) - } else { - if oldObj, ok := recent[key]; ok { - event.EventType = UpdateEvent - event.OldObject = oldObj - } else { - event.EventType = CreateEvent - } - event.NewObject = obj - recent[key] = obj + deleteEvents := lo.FilterMap(lo.Keys(oldObjs), func(uid string, _ int) (ResourceEvent, bool) { + oldObj := oldObjs[uid] + event := ResourceEvent{ + EventType: DeleteEvent, + Kind: oldObj.GetObjectKind().GroupVersionKind().GroupKind(), + OldObject: oldObj, + } + _, exists := newObjs[uid] + if !exists { + delete(oldObjs, uid) } + return event, !exists + }) + + events = append(events, deleteEvents...) - return []ResourceEvent{event} - })) + c.propagate(events) c.Unlock() } diff --git a/controller/controller_test.go b/controller/controller_test.go index 2f68acf..4042ebf 100644 --- a/controller/controller_test.go +++ b/controller/controller_test.go @@ -173,11 +173,6 @@ func TestNewController(t *testing.T) { if c.manager != tc.expected.manager { t.Errorf("expected manager %v, got %v", tc.expected.manager, c.manager) } - switch c.cache.(type) { - case *watchableCacheStore: - default: - t.Errorf("expected cache type *watchableCacheStore, got %T", c.cache) - } if len(c.topology.policyKinds) != len(tc.expected.policyKinds) || !lo.Every(c.topology.policyKinds, tc.expected.policyKinds) { t.Errorf("expected policyKinds %v, got %v", tc.expected.policyKinds, c.topology.policyKinds) } @@ -200,7 +195,7 @@ func TestControllerReconcile(t *testing.T) { &corev1.ConfigMap{ObjectMeta: metav1.ObjectMeta{Name: "test-configmap", UID: "aed148b1-285a-48ab-8839-fe99475bc6fc"}}, } objUIDs := lo.Map(objs, func(o Object, _ int) string { return string(o.GetUID()) }) - cache := &cacheStore{store: make(Store)} + cache := &CacheStore{} controller := &Controller{ logger: testLogger, cache: cache, @@ -209,7 +204,7 @@ func TestControllerReconcile(t *testing.T) { }, } controller.Reconcile(context.TODO(), ctrlruntimereconcile.Request{}) - cachedObjs := lo.Keys(cache.List()) + cachedObjs := lo.Keys(cache.List(resourceStoreId)) if len(cachedObjs) != 2 { t.Errorf("expected 2 objects, got %d", len(cachedObjs)) }