diff --git a/api/v1/merge_strategies.go b/api/v1/merge_strategies.go index f169c5f73..779c9f157 100644 --- a/api/v1/merge_strategies.go +++ b/api/v1/merge_strategies.go @@ -69,13 +69,11 @@ func AtomicDefaultsMergeStrategy(source, target machinery.Policy) machinery.Poli return source } - mergeableTargetPolicy := target.(MergeablePolicy) - - if !mergeableTargetPolicy.Empty() { - return mergeableTargetPolicy.DeepCopyObject().(machinery.Policy) + if mergeableTarget := target.(MergeablePolicy); !mergeableTarget.Empty() { + return copyMergeablePolicy(mergeableTarget) } - return source.(MergeablePolicy).DeepCopyObject().(machinery.Policy) + return copyMergeablePolicy(source.(MergeablePolicy)) } var _ machinery.MergeStrategy = AtomicDefaultsMergeStrategy @@ -86,7 +84,7 @@ func AtomicOverridesMergeStrategy(source, _ machinery.Policy) machinery.Policy { if source == nil { return nil } - return source.(MergeablePolicy).DeepCopyObject().(machinery.Policy) + return copyMergeablePolicy(source.(MergeablePolicy)) } var _ machinery.MergeStrategy = AtomicOverridesMergeStrategy @@ -105,12 +103,16 @@ func PolicyRuleDefaultsMergeStrategy(source, target machinery.Policy) machinery. targetMergeablePolicy := target.(MergeablePolicy) // copy rules from the target - rules := targetMergeablePolicy.Rules() + rules := lo.MapValues(targetMergeablePolicy.Rules(), mapRuleWithSourceFunc(target)) // add extra rules from the source for ruleID, rule := range sourceMergeablePolicy.Rules() { if _, ok := targetMergeablePolicy.Rules()[ruleID]; !ok { - rules[ruleID] = rule.WithSource(source.GetLocator()) + origin := rule.GetSource() + if origin == "" { + origin = source.GetLocator() + } + rules[ruleID] = rule.WithSource(origin) } } @@ -129,12 +131,16 @@ func PolicyRuleOverridesMergeStrategy(source, target machinery.Policy) machinery targetMergeablePolicy := target.(MergeablePolicy) // copy rules from the source - rules := sourceMergeablePolicy.Rules() + rules := lo.MapValues(sourceMergeablePolicy.Rules(), mapRuleWithSourceFunc(source)) // add extra rules from the target for ruleID, rule := range targetMergeablePolicy.Rules() { if _, ok := sourceMergeablePolicy.Rules()[ruleID]; !ok { - rules[ruleID] = rule + origin := rule.GetSource() + if origin == "" { + origin = target.GetLocator() + } + rules[ruleID] = rule.WithSource(origin) } } @@ -206,3 +212,15 @@ func PathID(path []machinery.Targetable) string { return strings.TrimPrefix(k8stypes.NamespacedName{Namespace: t.GetNamespace(), Name: t.GetName()}.String(), string(k8stypes.Separator)) }), "|") } + +func mapRuleWithSourceFunc(source machinery.Policy) func(MergeableRule, string) MergeableRule { + return func(rule MergeableRule, _ string) MergeableRule { + return rule.WithSource(source.GetLocator()) + } +} + +func copyMergeablePolicy(policy MergeablePolicy) MergeablePolicy { + dup := policy.DeepCopyObject().(MergeablePolicy) + dup.SetRules(lo.MapValues(dup.Rules(), mapRuleWithSourceFunc(policy))) + return dup +} diff --git a/api/v1beta3/authpolicy_types.go b/api/v1beta3/authpolicy_types.go index 27aba8cf3..8aae9cda9 100644 --- a/api/v1beta3/authpolicy_types.go +++ b/api/v1beta3/authpolicy_types.go @@ -189,6 +189,11 @@ func (p *AuthPolicy) Rules() map[string]kuadrantv1.MergeableRule { } func (p *AuthPolicy) SetRules(rules map[string]kuadrantv1.MergeableRule) { + // clear all rules of the policy before setting new ones + p.Spec.Proper().NamedPatterns = nil + p.Spec.Proper().Conditions = nil + p.Spec.Proper().AuthScheme = nil + ensureNamedPatterns := func() { if p.Spec.Proper().NamedPatterns == nil { p.Spec.Proper().NamedPatterns = make(map[string]MergeablePatternExpressions) diff --git a/api/v1beta3/ratelimitpolicy_types.go b/api/v1beta3/ratelimitpolicy_types.go index 796f911c7..0d058291b 100644 --- a/api/v1beta3/ratelimitpolicy_types.go +++ b/api/v1beta3/ratelimitpolicy_types.go @@ -132,7 +132,10 @@ func (p *RateLimitPolicy) Rules() map[string]kuadrantv1.MergeableRule { } func (p *RateLimitPolicy) SetRules(rules map[string]kuadrantv1.MergeableRule) { - if len(rules) > 0 && p.Spec.Proper().Limits == nil { + // clear all rules of the policy before setting new ones + p.Spec.Proper().Limits = nil + + if len(rules) > 0 { p.Spec.Proper().Limits = make(map[string]Limit) }