From 1445641b8ab9d88778c4b34299d05f1e61405a78 Mon Sep 17 00:00:00 2001 From: Tharsanan1 Date: Wed, 25 Sep 2024 14:39:02 +0530 Subject: [PATCH] Improve subscription validation value set logic --- .../operator/controllers/dp/api_controller.go | 10 +- .../dp/airatelimitpolicy_controller.go | 2 +- .../internal/xds/ratelimiter_cache.go | 174 +++++++++--------- .../ballerina/types.bal | 2 +- .../serviceAccount/agent-cluster-role.yaml | 18 ++ 5 files changed, 119 insertions(+), 87 deletions(-) diff --git a/adapter/internal/operator/controllers/dp/api_controller.go b/adapter/internal/operator/controllers/dp/api_controller.go index ff64f7e35..267c505d5 100644 --- a/adapter/internal/operator/controllers/dp/api_controller.go +++ b/adapter/internal/operator/controllers/dp/api_controller.go @@ -894,6 +894,14 @@ func (apiReconciler *APIReconciler) getAPIPolicyChildrenRefs(ctx context.Context backendJWTs := make(map[string]dpv1alpha1.BackendJWT) aiProvider := &dpv1alpha3.AIProvider{} subscriptionValidation := false + for _, apiPolicy := range allAPIPolicies { + if apiPolicy.Spec.Default != nil { + subscriptionValidation = subscriptionValidation || apiPolicy.Spec.Default.SubscriptionValidation + } + if apiPolicy.Spec.Override != nil { + subscriptionValidation = subscriptionValidation || apiPolicy.Spec.Override.SubscriptionValidation + } + } for _, apiPolicy := range allAPIPolicies { if apiPolicy.Spec.Default != nil { if len(apiPolicy.Spec.Default.RequestInterceptors) > 0 { @@ -925,7 +933,6 @@ func (apiReconciler *APIReconciler) getAPIPolicyChildrenRefs(ctx context.Context aiProvider = aiProviderPtr } } - subscriptionValidation = apiPolicy.Spec.Default.SubscriptionValidation } if apiPolicy.Spec.Override != nil { if len(apiPolicy.Spec.Override.RequestInterceptors) > 0 { @@ -957,7 +964,6 @@ func (apiReconciler *APIReconciler) getAPIPolicyChildrenRefs(ctx context.Context aiProvider = aiProviderPtr } } - subscriptionValidation = apiPolicy.Spec.Override.SubscriptionValidation } } return interceptorServices, backendJWTs, subscriptionValidation, aiProvider, nil diff --git a/common-controller/internal/operator/controllers/dp/airatelimitpolicy_controller.go b/common-controller/internal/operator/controllers/dp/airatelimitpolicy_controller.go index 98789f89a..235017395 100644 --- a/common-controller/internal/operator/controllers/dp/airatelimitpolicy_controller.go +++ b/common-controller/internal/operator/controllers/dp/airatelimitpolicy_controller.go @@ -108,7 +108,7 @@ func (r *AIRateLimitPolicyReconciler) Reconcile(ctx context.Context, req ctrl.Re if ratelimitPolicy.Spec.Override == nil { ratelimitPolicy.Spec.Override = ratelimitPolicy.Spec.Default } - if ratelimitPolicy.Spec.TargetRef.Name != "" { + if ratelimitPolicy.Spec.TargetRef.Kind == "Backend" { r.ods.AddorUpdateAIRatelimitToStore(ratelimitKey, ratelimitPolicy.Spec) xds.UpdateRateLimitXDSCacheForAIRatelimitPolicies(r.ods.GetAIRatelimitPolicySpecs()) xds.UpdateRateLimiterPolicies(conf.CommonController.Server.Label) diff --git a/common-controller/internal/xds/ratelimiter_cache.go b/common-controller/internal/xds/ratelimiter_cache.go index 035c6c91f..0b0d0c251 100644 --- a/common-controller/internal/xds/ratelimiter_cache.go +++ b/common-controller/internal/xds/ratelimiter_cache.go @@ -337,62 +337,66 @@ func (r *rateLimitPolicyCache) ProcessSubscriptionBasedAIRatelimitPolicySpecsAnd aiRlDescriptors := make([]*rls_config.RateLimitDescriptor, 0) for namespacedNameRl := range subscriptionEnabledAIRatelimitPolicies { if airl, exists := aiRatelimitPolicySpecs[namespacedNameRl]; exists { - // Add descriptor for RequestTokenCount - aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ - Key: DescriptorKeyForSubscriptionBasedAIRequestTokenCount, - Value: prepareSubscriptionBasedAIRatelimitIdentifier(airl.Override.Organization, namespacedNameRl), - Descriptors: []*rls_config.RateLimitDescriptor{ - { - Key: DescriptorKeyForSubscription, - RateLimit: &rls_config.RateLimitPolicy{ - Unit: getRateLimitUnit(airl.Override.TokenCount.Unit), - RequestsPerUnit: uint32(airl.Override.TokenCount.RequestTokenCount), + if airl.Override.TokenCount != nil { + // Add descriptor for RequestTokenCount + aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ + Key: DescriptorKeyForSubscriptionBasedAIRequestTokenCount, + Value: prepareSubscriptionBasedAIRatelimitIdentifier(airl.Override.Organization, namespacedNameRl), + Descriptors: []*rls_config.RateLimitDescriptor{ + { + Key: DescriptorKeyForSubscription, + RateLimit: &rls_config.RateLimitPolicy{ + Unit: getRateLimitUnit(airl.Override.TokenCount.Unit), + RequestsPerUnit: uint32(airl.Override.TokenCount.RequestTokenCount), + }, }, }, - }, - }) - // Add descriptor for ResponseTokenCount - aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ - Key: DescriptorKeyForSubscriptionBasedAIResponseTokenCount, - Value: prepareSubscriptionBasedAIRatelimitIdentifier(airl.Override.Organization, namespacedNameRl), - Descriptors: []*rls_config.RateLimitDescriptor{ - { - Key: DescriptorKeyForSubscription, - RateLimit: &rls_config.RateLimitPolicy{ - Unit: getRateLimitUnit(airl.Override.TokenCount.Unit), - RequestsPerUnit: uint32(airl.Override.TokenCount.ResponseTokenCount), + }) + // Add descriptor for ResponseTokenCount + aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ + Key: DescriptorKeyForSubscriptionBasedAIResponseTokenCount, + Value: prepareSubscriptionBasedAIRatelimitIdentifier(airl.Override.Organization, namespacedNameRl), + Descriptors: []*rls_config.RateLimitDescriptor{ + { + Key: DescriptorKeyForSubscription, + RateLimit: &rls_config.RateLimitPolicy{ + Unit: getRateLimitUnit(airl.Override.TokenCount.Unit), + RequestsPerUnit: uint32(airl.Override.TokenCount.ResponseTokenCount), + }, }, }, - }, - }) - // Add descriptor for TotalTokenCount - aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ - Key: DescriptorKeyForSubscriptionBasedAITotalTokenCount, - Value: prepareSubscriptionBasedAIRatelimitIdentifier(airl.Override.Organization, namespacedNameRl), - Descriptors: []*rls_config.RateLimitDescriptor{ - { - Key: DescriptorKeyForSubscription, - RateLimit: &rls_config.RateLimitPolicy{ - Unit: getRateLimitUnit(airl.Override.TokenCount.Unit), - RequestsPerUnit: uint32(airl.Override.TokenCount.TotalTokenCount), + }) + // Add descriptor for TotalTokenCount + aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ + Key: DescriptorKeyForSubscriptionBasedAITotalTokenCount, + Value: prepareSubscriptionBasedAIRatelimitIdentifier(airl.Override.Organization, namespacedNameRl), + Descriptors: []*rls_config.RateLimitDescriptor{ + { + Key: DescriptorKeyForSubscription, + RateLimit: &rls_config.RateLimitPolicy{ + Unit: getRateLimitUnit(airl.Override.TokenCount.Unit), + RequestsPerUnit: uint32(airl.Override.TokenCount.TotalTokenCount), + }, }, }, - }, - }) + }) + } // Add descriptor for RequestCount - aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ - Key: DescriptorKeyForSubscriptionBasedAIRequestCount, - Value: prepareSubscriptionBasedAIRatelimitIdentifier(airl.Override.Organization, namespacedNameRl), - Descriptors: []*rls_config.RateLimitDescriptor{ - { - Key: DescriptorKeyForSubscription, - RateLimit: &rls_config.RateLimitPolicy{ - Unit: getRateLimitUnit(airl.Override.TokenCount.Unit), - RequestsPerUnit: uint32(airl.Override.RequestCount.RequestsPerUnit), + if airl.Override.RequestCount != nil { + aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ + Key: DescriptorKeyForSubscriptionBasedAIRequestCount, + Value: prepareSubscriptionBasedAIRatelimitIdentifier(airl.Override.Organization, namespacedNameRl), + Descriptors: []*rls_config.RateLimitDescriptor{ + { + Key: DescriptorKeyForSubscription, + RateLimit: &rls_config.RateLimitPolicy{ + Unit: getRateLimitUnit(airl.Override.RequestCount.Unit), + RequestsPerUnit: uint32(airl.Override.RequestCount.RequestsPerUnit), + }, }, }, - }, - }) + }) + } } } r.subscriptionBasedAIRatelimitDescriptors = aiRlDescriptors @@ -402,42 +406,46 @@ func (r *rateLimitPolicyCache) ProcessSubscriptionBasedAIRatelimitPolicySpecsAnd func (r *rateLimitPolicyCache) ProcessAIRatelimitPolicySpecsAndUpdateCache(aiRateLimitPolicySpecs map[types.NamespacedName]*dpv1alpha3.AIRateLimitPolicySpec) { aiRlDescriptors := make([]*rls_config.RateLimitDescriptor, 0) for namespacedName, spec := range aiRateLimitPolicySpecs { - // Add descriptor for RequestTokenCount - aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ - Key: DescriptorKeyForAIRequestTokenCount, - Value: prepareAIRatelimitIdentifier(spec.Override.Organization, namespacedName, spec), - RateLimit: &rls_config.RateLimitPolicy{ - Unit: getRateLimitUnit(spec.Override.TokenCount.Unit), - RequestsPerUnit: uint32(spec.Override.TokenCount.RequestTokenCount), - }, - }) - // Add descriptor for ResponseTokenCount - aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ - Key: DescriptorKeyForAIResponseTokenCount, - Value: prepareAIRatelimitIdentifier(spec.Override.Organization, namespacedName, spec), - RateLimit: &rls_config.RateLimitPolicy{ - Unit: getRateLimitUnit(spec.Override.TokenCount.Unit), - RequestsPerUnit: uint32(spec.Override.TokenCount.ResponseTokenCount), - }, - }) - // Add descriptor for TotalTokenCount - aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ - Key: DescriptorKeyForAITotalTokenCount, - Value: prepareAIRatelimitIdentifier(spec.Override.Organization, namespacedName, spec), - RateLimit: &rls_config.RateLimitPolicy{ - Unit: getRateLimitUnit(spec.Override.TokenCount.Unit), - RequestsPerUnit: uint32(spec.Override.TokenCount.TotalTokenCount), - }, - }) - // Add descriptor for RequestCount - aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ - Key: DescriptorKeyForAIRequestCount, - Value: prepareAIRatelimitIdentifier(spec.Override.Organization, namespacedName, spec), - RateLimit: &rls_config.RateLimitPolicy{ - Unit: getRateLimitUnit(spec.Override.RequestCount.Unit), - RequestsPerUnit: uint32(spec.Override.RequestCount.RequestsPerUnit), - }, - }) + if spec.Override.TokenCount != nil { + // Add descriptor for RequestTokenCount + aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ + Key: DescriptorKeyForAIRequestTokenCount, + Value: prepareAIRatelimitIdentifier(spec.Override.Organization, namespacedName, spec), + RateLimit: &rls_config.RateLimitPolicy{ + Unit: getRateLimitUnit(spec.Override.TokenCount.Unit), + RequestsPerUnit: uint32(spec.Override.TokenCount.RequestTokenCount), + }, + }) + // Add descriptor for ResponseTokenCount + aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ + Key: DescriptorKeyForAIResponseTokenCount, + Value: prepareAIRatelimitIdentifier(spec.Override.Organization, namespacedName, spec), + RateLimit: &rls_config.RateLimitPolicy{ + Unit: getRateLimitUnit(spec.Override.TokenCount.Unit), + RequestsPerUnit: uint32(spec.Override.TokenCount.ResponseTokenCount), + }, + }) + // Add descriptor for TotalTokenCount + aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ + Key: DescriptorKeyForAITotalTokenCount, + Value: prepareAIRatelimitIdentifier(spec.Override.Organization, namespacedName, spec), + RateLimit: &rls_config.RateLimitPolicy{ + Unit: getRateLimitUnit(spec.Override.TokenCount.Unit), + RequestsPerUnit: uint32(spec.Override.TokenCount.TotalTokenCount), + }, + }) + } + if spec.Override.RequestCount != nil { + // Add descriptor for RequestCount + aiRlDescriptors = append(aiRlDescriptors, &rls_config.RateLimitDescriptor{ + Key: DescriptorKeyForAIRequestCount, + Value: prepareAIRatelimitIdentifier(spec.Override.Organization, namespacedName, spec), + RateLimit: &rls_config.RateLimitPolicy{ + Unit: getRateLimitUnit(spec.Override.RequestCount.Unit), + RequestsPerUnit: uint32(spec.Override.RequestCount.RequestsPerUnit), + }, + }) + } } r.aiRatelimitDescriptors = aiRlDescriptors } diff --git a/runtime/config-deployer-service/ballerina/types.bal b/runtime/config-deployer-service/ballerina/types.bal index ed2abbccd..de9b1aad8 100644 --- a/runtime/config-deployer-service/ballerina/types.bal +++ b/runtime/config-deployer-service/ballerina/types.bal @@ -299,7 +299,7 @@ public type EndpointConfigurations record { # + endpointSecurity - The security configuration for the endpoint. # + certificate - The certificate configuration for the endpoint. # + resiliency - The resiliency configuration for the endpoint. -# + AIRatelimit - The AIRatelimit configuration for the AI ratelimit. +# + aiRatelimit - The AIRatelimit configuration for the AI ratelimit. public type EndpointConfiguration record { string|K8sService endpoint; EndpointSecurity endpointSecurity?; diff --git a/test/apim-apk-agent-test/agent-helm-chart/templates/serviceAccount/agent-cluster-role.yaml b/test/apim-apk-agent-test/agent-helm-chart/templates/serviceAccount/agent-cluster-role.yaml index 86e208218..44fdf0481 100644 --- a/test/apim-apk-agent-test/agent-helm-chart/templates/serviceAccount/agent-cluster-role.yaml +++ b/test/apim-apk-agent-test/agent-helm-chart/templates/serviceAccount/agent-cluster-role.yaml @@ -122,4 +122,22 @@ rules: - apiGroups: ["dp.wso2.com"] resources: ["gqlroutes/status"] verbs: ["get","patch","update"] + - apiGroups: ["dp.wso2.com"] + resources: ["aiproviders"] + verbs: ["get","list","watch","update","delete","create"] + - apiGroups: ["dp.wso2.com"] + resources: ["aiproviders/finalizers"] + verbs: ["update"] + - apiGroups: ["dp.wso2.com"] + resources: ["aiproviders/status"] + verbs: ["get","patch","update"] + - apiGroups: ["dp.wso2.com"] + resources: ["airatelimitpolicies"] + verbs: ["get","list","watch","update","delete","create"] + - apiGroups: ["dp.wso2.com"] + resources: ["airatelimitpolicies/finalizers"] + verbs: ["update"] + - apiGroups: ["dp.wso2.com"] + resources: ["airatelimitpolicies/status"] + verbs: ["get","patch","update"] {{- end }} \ No newline at end of file