Skip to content

Commit

Permalink
refactor: only reconcile tls policies that were affected by events
Browse files Browse the repository at this point in the history
Signed-off-by: KevFan <[email protected]>
  • Loading branch information
KevFan committed Oct 16, 2024
1 parent 52ccba8 commit cd35907
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 83 deletions.
11 changes: 4 additions & 7 deletions controllers/effective_tls_policies_reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,11 @@ func (t *EffectiveTLSPoliciesReconciler) Subscription() *controller.Subscription
//+kubebuilder:rbac:groups="",resources=secrets,verbs=get;list;watch
//+kubebuilder:rbac:groups="cert-manager.io",resources=certificates,verbs=get;list;watch;create;update;patch;delete

func (t *EffectiveTLSPoliciesReconciler) Reconcile(ctx context.Context, _ []controller.ResourceEvent, topology *machinery.Topology, _ error, s *sync.Map) error {
func (t *EffectiveTLSPoliciesReconciler) Reconcile(ctx context.Context, events []controller.ResourceEvent, topology *machinery.Topology, _ error, s *sync.Map) error {
logger := controller.LoggerFromContext(ctx).WithName("EffectiveTLSPoliciesReconciler").WithName("Reconcile")

// Get all TLS Policies
policies := lo.Filter(topology.Policies().Items(), func(item machinery.Policy, index int) bool {
_, ok := item.(*kuadrantv1alpha1.TLSPolicy)
return ok
})
// Get affected TLS Policies
policies := GetTLSPoliciesByEvents(topology, events)

// Get all certs in topology for comparison with expected certs to determine orphaned certs later
certs := lo.FilterMap(topology.Objects().Items(), func(item machinery.Object, index int) (*certmanv1.Certificate, bool) {
Expand Down Expand Up @@ -135,7 +132,7 @@ func (t *EffectiveTLSPoliciesReconciler) Reconcile(ctx context.Context, _ []cont
continue
}
_, err = resource.Create(ctx, un, metav1.CreateOptions{})
if err != nil {
if err != nil && !apierrors.IsAlreadyExists(err) {
logger.Error(err, "unable to create certificate", "name", policy.Name, "namespace", policy.Namespace, "uid", policy.GetUID())
}

Expand Down
138 changes: 103 additions & 35 deletions controllers/tls_workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
gatewayapiv1alpha2 "sigs.k8s.io/gateway-api/apis/v1alpha2"

kuadrantv1alpha1 "github.com/kuadrant/kuadrant-operator/api/v1alpha1"
"github.com/kuadrant/kuadrant-operator/pkg/library/utils"
)

const (
Expand Down Expand Up @@ -104,24 +105,7 @@ func LinkGatewayToIssuerFunc(objs controller.Store) machinery.LinkFunc {
return p.Spec.IssuerRef.Name == issuer.GetName() && p.GetNamespace() == issuer.GetNamespace() && p.Spec.IssuerRef.Kind == certmanagerv1.IssuerKind
})

if len(linkedPolicies) == 0 {
return nil
}

// Can infer linked gateways through the policy
linkedGateways := lo.Filter(gateways, func(g *gwapiv1.Gateway, index int) bool {
for _, l := range linkedPolicies {
if string(l.Spec.TargetRef.Name) == g.GetName() && g.GetNamespace() == l.GetNamespace() {
return true
}
}

return false
})

return lo.Map(linkedGateways, func(item *gwapiv1.Gateway, index int) machinery.Object {
return &machinery.Gateway{Gateway: item}
})
return findLinkedGatewaysForIssuer(linkedPolicies, gateways)
},
}
}
Expand All @@ -142,26 +126,30 @@ func LinkGatewayToClusterIssuerFunc(objs controller.Store) machinery.LinkFunc {
return p.Spec.IssuerRef.Name == clusterIssuer.GetName() && p.Spec.IssuerRef.Kind == certmanagerv1.ClusterIssuerKind
})

if len(linkedPolicies) == 0 {
return nil
}
return findLinkedGatewaysForIssuer(linkedPolicies, gateways)
},
}
}

// Can infer linked gateways through the policy
linkedGateways := lo.Filter(gateways, func(g *gwapiv1.Gateway, index int) bool {
for _, l := range linkedPolicies {
if string(l.Spec.TargetRef.Name) == g.GetName() && g.GetNamespace() == l.GetNamespace() {
return true
}
}
func findLinkedGatewaysForIssuer(linkedPolicies []*kuadrantv1alpha1.TLSPolicy, gateways []*gwapiv1.Gateway) []machinery.Object {
if len(linkedPolicies) == 0 {
return nil
}

return false
})
// Can infer linked gateways through the policy
linkedGateways := lo.Filter(gateways, func(g *gwapiv1.Gateway, index int) bool {
for _, l := range linkedPolicies {
if string(l.Spec.TargetRef.Name) == g.GetName() && g.GetNamespace() == l.GetNamespace() {
return true
}
}

return lo.Map(linkedGateways, func(item *gwapiv1.Gateway, index int) machinery.Object {
return &machinery.Gateway{Gateway: item}
})
},
}
return false
})

return lo.Map(linkedGateways, func(item *gwapiv1.Gateway, index int) machinery.Object {
return &machinery.Gateway{Gateway: item}
})
}

// Common functions used across multiple reconcilers
Expand All @@ -179,3 +167,83 @@ func IsTLSPolicyValid(ctx context.Context, s *sync.Map, policy *kuadrantv1alpha1

return isPolicyValidErrorMap[policy.GetLocator()] == nil, isPolicyValidErrorMap[policy.GetLocator()]
}

func GetTLSPoliciesByEvents(topology *machinery.Topology, events []controller.ResourceEvent) []machinery.Policy {
policies := lo.Filter(topology.Policies().Items(), func(item machinery.Policy, index int) bool {
_, ok := item.(*kuadrantv1alpha1.TLSPolicy)
return ok
})

var affectedPolicies []machinery.Policy
for _, event := range events {
if event.Kind == machinery.GatewayGroupKind {
ob := event.NewObject
if ob == nil {
ob = event.OldObject
}

g := machinery.Gateway{Gateway: ob.(*gwapiv1.Gateway)}

affectedPolicies = append(affectedPolicies, lo.Filter(policies, func(item machinery.Policy, index int) bool {
for _, tg := range item.GetTargetRefs() {
if g.GetLocator() == tg.GetLocator() {
return true
}
}
return false
})...)
}

if event.Kind == kuadrantv1alpha1.TLSPolicyGroupKind {
ob := event.NewObject
if ob == nil {
ob = event.OldObject
}

affectedPolicies = append(affectedPolicies, lo.Filter(policies, func(item machinery.Policy, index int) bool {
return item.GetName() == ob.GetName() && item.GetNamespace() == ob.GetNamespace()
})...)
}

if event.Kind == CertManagerCertificateKind {
ob := event.NewObject
if ob == nil {
ob = event.OldObject
}

affectedPolicies = append(affectedPolicies, lo.Filter(policies, func(item machinery.Policy, index int) bool {
p := item.(*kuadrantv1alpha1.TLSPolicy)
return utils.IsOwnedBy(ob, p)
})...)
}

if event.Kind == CertManagerIssuerKind {
ob := event.NewObject
if ob == nil {
ob = event.OldObject
}

affectedPolicies = append(affectedPolicies, lo.Filter(policies, func(item machinery.Policy, index int) bool {
p := item.(*kuadrantv1alpha1.TLSPolicy)

return ob.GetName() == p.Spec.IssuerRef.Name && lo.Contains([]string{"", certmanagerv1.IssuerKind}, p.Spec.IssuerRef.Kind) &&
item.GetNamespace() == ob.GetNamespace()
})...)
}

if event.Kind == CertManagerClusterIssuerKind {
ob := event.NewObject
if ob == nil {
ob = event.OldObject
}

affectedPolicies = append(affectedPolicies, lo.Filter(policies, func(item machinery.Policy, index int) bool {
p := item.(*kuadrantv1alpha1.TLSPolicy)
return ob.GetName() == p.Spec.IssuerRef.Name && p.Spec.IssuerRef.Kind == certmanagerv1.ClusterIssuerKind
})...)
}
}

// Return only unique policies as there can be duplicates from multiple events
return lo.Uniq(affectedPolicies)
}
10 changes: 4 additions & 6 deletions controllers/tlspolicies_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,15 @@ func (t *TLSPoliciesValidator) Subscription() *controller.Subscription {
}
}

func (t *TLSPoliciesValidator) Validate(ctx context.Context, _ []controller.ResourceEvent, topology *machinery.Topology, _ error, s *sync.Map) error {
func (t *TLSPoliciesValidator) Validate(ctx context.Context, events []controller.ResourceEvent, topology *machinery.Topology, _ error, s *sync.Map) error {
logger := controller.LoggerFromContext(ctx).WithName("TLSPoliciesValidator").WithName("Validate")

policies := lo.FilterMap(topology.Policies().Items(), func(item machinery.Policy, index int) (*kuadrantv1alpha1.TLSPolicy, bool) {
p, ok := item.(*kuadrantv1alpha1.TLSPolicy)
return p, ok
})
policies := GetTLSPoliciesByEvents(topology, events)

isPolicyValidErrorMap := make(map[string]error, len(policies))

for _, p := range policies {
for _, policy := range policies {
p := policy.(*kuadrantv1alpha1.TLSPolicy)
if p.DeletionTimestamp != nil {
logger.V(1).Info("tls policy is marked for deletion, skipping", "name", p.Name, "namespace", p.Namespace)
continue
Expand Down
68 changes: 33 additions & 35 deletions controllers/tlspolicy_status_updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,45 +46,43 @@ func (t *TLSPolicyStatusUpdater) Subscription() *controller.Subscription {
}
}

func (t *TLSPolicyStatusUpdater) UpdateStatus(ctx context.Context, _ []controller.ResourceEvent, topology *machinery.Topology, _ error, s *sync.Map) error {
func (t *TLSPolicyStatusUpdater) UpdateStatus(ctx context.Context, events []controller.ResourceEvent, topology *machinery.Topology, _ error, s *sync.Map) error {
logger := controller.LoggerFromContext(ctx).WithName("TLSPolicyStatusUpdater").WithName("UpdateStatus")

policies := lo.FilterMap(topology.Policies().Items(), func(item machinery.Policy, index int) (*kuadrantv1alpha1.TLSPolicy, bool) {
p, ok := item.(*kuadrantv1alpha1.TLSPolicy)
return p, ok
})
policies := GetTLSPoliciesByEvents(topology, events)

for _, policy := range policies {
if policy.DeletionTimestamp != nil {
logger.V(1).Info("tls policy is marked for deletion, skipping", "name", policy.GetName(), "namespace", policy.GetNamespace(), "uid", policy.GetUID())
p := policy.(*kuadrantv1alpha1.TLSPolicy)
if p.DeletionTimestamp != nil {
logger.V(1).Info("tls policy is marked for deletion, skipping", "name", policy.GetName(), "namespace", policy.GetNamespace(), "uid", p.GetUID())
continue
}

newStatus := &kuadrantv1alpha1.TLSPolicyStatus{
// Copy initial conditions. Otherwise, status will always be updated
Conditions: slices.Clone(policy.Status.Conditions),
ObservedGeneration: policy.Status.ObservedGeneration,
Conditions: slices.Clone(p.Status.Conditions),
ObservedGeneration: p.Status.ObservedGeneration,
}

_, err := IsTLSPolicyValid(ctx, s, policy)
meta.SetStatusCondition(&newStatus.Conditions, *kuadrant.AcceptedCondition(policy, err))
_, err := IsTLSPolicyValid(ctx, s, p)
meta.SetStatusCondition(&newStatus.Conditions, *kuadrant.AcceptedCondition(p, err))

// Do not set enforced condition if Accepted condition is false
if meta.IsStatusConditionFalse(newStatus.Conditions, string(gatewayapiv1alpha2.PolicyReasonAccepted)) {
meta.RemoveStatusCondition(&newStatus.Conditions, string(kuadrant.PolicyConditionEnforced))
} else {
enforcedCond := t.enforcedCondition(ctx, policy, topology)
enforcedCond := t.enforcedCondition(ctx, p, topology)
meta.SetStatusCondition(&newStatus.Conditions, *enforcedCond)
}

// Nothing to do
equalStatus := equality.Semantic.DeepEqual(newStatus, policy.Status)
if equalStatus && policy.Generation == policy.Status.ObservedGeneration {
equalStatus := equality.Semantic.DeepEqual(newStatus, p.Status)
if equalStatus && p.Generation == p.Status.ObservedGeneration {
logger.V(1).Info("policy status unchanged, skipping update")
continue
}
newStatus.ObservedGeneration = policy.Generation
policy.Status = *newStatus
newStatus.ObservedGeneration = p.Generation
p.Status = *newStatus

resource := t.Client.Resource(kuadrantv1alpha1.TLSPoliciesResource).Namespace(policy.GetNamespace())
un, err := controller.Destruct(policy)
Expand All @@ -95,26 +93,26 @@ func (t *TLSPolicyStatusUpdater) UpdateStatus(ctx context.Context, _ []controlle

_, err = resource.UpdateStatus(ctx, un, metav1.UpdateOptions{})
if err != nil {
logger.Error(err, "unable to update status for TLSPolicy", "name", policy.GetName(), "namespace", policy.GetNamespace(), "uid", policy.GetUID())
logger.Error(err, "unable to update status for TLSPolicy", "name", policy.GetName(), "namespace", policy.GetNamespace(), "uid", p.GetUID())
}
}

return nil
}

func (t *TLSPolicyStatusUpdater) enforcedCondition(ctx context.Context, tlsPolicy *kuadrantv1alpha1.TLSPolicy, topology *machinery.Topology) *metav1.Condition {
if err := t.isIssuerReady(ctx, tlsPolicy, topology); err != nil {
return kuadrant.EnforcedCondition(tlsPolicy, kuadrant.NewErrUnknown(tlsPolicy.Kind(), err), false)
func (t *TLSPolicyStatusUpdater) enforcedCondition(ctx context.Context, policy *kuadrantv1alpha1.TLSPolicy, topology *machinery.Topology) *metav1.Condition {
if err := t.isIssuerReady(ctx, policy, topology); err != nil {
return kuadrant.EnforcedCondition(policy, kuadrant.NewErrUnknown(policy.Kind(), err), false)
}

if err := t.isCertificatesReady(tlsPolicy, topology); err != nil {
return kuadrant.EnforcedCondition(tlsPolicy, kuadrant.NewErrUnknown(tlsPolicy.Kind(), err), false)
if err := t.isCertificatesReady(policy, topology); err != nil {
return kuadrant.EnforcedCondition(policy, kuadrant.NewErrUnknown(policy.Kind(), err), false)
}

return kuadrant.EnforcedCondition(tlsPolicy, nil, true)
return kuadrant.EnforcedCondition(policy, nil, true)
}

func (t *TLSPolicyStatusUpdater) isIssuerReady(ctx context.Context, tlsPolicy *kuadrantv1alpha1.TLSPolicy, topology *machinery.Topology) error {
func (t *TLSPolicyStatusUpdater) isIssuerReady(ctx context.Context, policy *kuadrantv1alpha1.TLSPolicy, topology *machinery.Topology) error {
logger := controller.LoggerFromContext(ctx).WithName("TLSPolicyStatusUpdater").WithName("isIssuerReady")

// Get all gateways
Expand All @@ -125,26 +123,26 @@ func (t *TLSPolicyStatusUpdater) isIssuerReady(ctx context.Context, tlsPolicy *k

// Find gateway defined by target ref
gw, ok := lo.Find(gws, func(item *machinery.Gateway) bool {
if item.GetName() == string(tlsPolicy.GetTargetRef().Name) && item.GetNamespace() == tlsPolicy.GetNamespace() {
if item.GetName() == string(policy.GetTargetRef().Name) && item.GetNamespace() == policy.GetNamespace() {
return true
}
return false
})

if !ok {
return fmt.Errorf("unable to find target ref %s for policy %s in ns %s in topology", tlsPolicy.GetTargetRef(), tlsPolicy.Name, tlsPolicy.Namespace)
return fmt.Errorf("unable to find target ref %s for policy %s in ns %s in topology", policy.GetTargetRef(), policy.Name, policy.Namespace)
}

var conditions []certmanagerv1.IssuerCondition

switch tlsPolicy.Spec.IssuerRef.Kind {
switch policy.Spec.IssuerRef.Kind {
case "", certmanagerv1.IssuerKind:
objs := topology.Objects().Children(gw)
obj, ok := lo.Find(objs, func(o machinery.Object) bool {
return o.GroupVersionKind().GroupKind() == CertManagerIssuerKind && o.GetNamespace() == tlsPolicy.GetNamespace() && o.GetName() == tlsPolicy.Spec.IssuerRef.Name
return o.GroupVersionKind().GroupKind() == CertManagerIssuerKind && o.GetNamespace() == policy.GetNamespace() && o.GetName() == policy.Spec.IssuerRef.Name
})
if !ok {
err := fmt.Errorf("%s \"%s\" not found", tlsPolicy.Spec.IssuerRef.Kind, tlsPolicy.Spec.IssuerRef.Name)
err := fmt.Errorf("%s \"%s\" not found", policy.Spec.IssuerRef.Kind, policy.Spec.IssuerRef.Name)
logger.Error(err, "error finding object in topology")
return err
}
Expand All @@ -155,33 +153,33 @@ func (t *TLSPolicyStatusUpdater) isIssuerReady(ctx context.Context, tlsPolicy *k
case certmanagerv1.ClusterIssuerKind:
objs := topology.Objects().Children(gw)
obj, ok := lo.Find(objs, func(o machinery.Object) bool {
return o.GroupVersionKind().GroupKind() == CertManagerClusterIssuerKind && o.GetName() == tlsPolicy.Spec.IssuerRef.Name
return o.GroupVersionKind().GroupKind() == CertManagerClusterIssuerKind && o.GetName() == policy.Spec.IssuerRef.Name
})
if !ok {
err := fmt.Errorf("%s \"%s\" not found", tlsPolicy.Spec.IssuerRef.Kind, tlsPolicy.Spec.IssuerRef.Name)
err := fmt.Errorf("%s \"%s\" not found", policy.Spec.IssuerRef.Kind, policy.Spec.IssuerRef.Name)
logger.Error(err, "error finding object in topology")
return err
}

issuer := obj.(*controller.RuntimeObject).Object.(*certmanagerv1.ClusterIssuer)
conditions = issuer.Status.Conditions
default:
return fmt.Errorf(`invalid value %q for issuerRef.kind. Must be empty, %q or %q`, tlsPolicy.Spec.IssuerRef.Kind, certmanagerv1.IssuerKind, certmanagerv1.ClusterIssuerKind)
return fmt.Errorf(`invalid value %q for issuerRef.kind. Must be empty, %q or %q`, policy.Spec.IssuerRef.Kind, certmanagerv1.IssuerKind, certmanagerv1.ClusterIssuerKind)
}

transformedCond := utils.Map(conditions, func(c certmanagerv1.IssuerCondition) metav1.Condition {
return metav1.Condition{Reason: c.Reason, Status: metav1.ConditionStatus(c.Status), Type: string(c.Type), Message: c.Message}
})

if !meta.IsStatusConditionTrue(transformedCond, string(certmanagerv1.IssuerConditionReady)) {
return fmt.Errorf("%s not ready", tlsPolicy.Spec.IssuerRef.Kind)
return fmt.Errorf("%s not ready", policy.Spec.IssuerRef.Kind)
}

return nil
}

func (t *TLSPolicyStatusUpdater) isCertificatesReady(p machinery.Policy, topology *machinery.Topology) error {
tlsPolicy, ok := p.(*kuadrantv1alpha1.TLSPolicy)
policy, ok := p.(*kuadrantv1alpha1.TLSPolicy)
if !ok {
return errors.New("invalid policy")
}
Expand All @@ -204,7 +202,7 @@ func (t *TLSPolicyStatusUpdater) isCertificatesReady(p machinery.Policy, topolog
continue
}

expectedCertificates := expectedCertificatesForListener(l, tlsPolicy)
expectedCertificates := expectedCertificatesForListener(l, policy)

for _, cert := range expectedCertificates {
objs := topology.Objects().Children(l)
Expand Down

0 comments on commit cd35907

Please sign in to comment.