diff --git a/adapter/internal/oasparser/envoyconf/http_filters.go b/adapter/internal/oasparser/envoyconf/http_filters.go index 4b57e7ba6..5ab28d2a5 100644 --- a/adapter/internal/oasparser/envoyconf/http_filters.go +++ b/adapter/internal/oasparser/envoyconf/http_filters.go @@ -241,12 +241,12 @@ func getExtProcessHTTPFilter() *hcmv3.HttpFilter { }, FailureModeAllow: true, ProcessingMode: &ext_process.ProcessingMode{ - // ResponseBodyMode: ext_process.ProcessingMode_BUFFERED, + ResponseBodyMode: ext_process.ProcessingMode_BUFFERED, RequestHeaderMode: ext_process.ProcessingMode_SEND, ResponseHeaderMode: ext_process.ProcessingMode_SEND, // RequestHeaderMode: ext_process.ProcessingMode_SKIP, // ResponseHeaderMode: ext_process.ProcessingMode_SKIP, - // RequestBodyMode: ext_process.ProcessingMode_BUFFERED, + RequestBodyMode: ext_process.ProcessingMode_BUFFERED, }, MetadataOptions: &ext_process.MetadataOptions{ ForwardingNamespaces: &ext_process.MetadataOptions_MetadataNamespaces{ diff --git a/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/basic-auth.routes.yaml b/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/basic-auth.routes.yaml index f87be1147..75d30a059 100644 --- a/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/basic-auth.routes.yaml +++ b/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/basic-auth.routes.yaml @@ -12,7 +12,3 @@ cluster: first-route-dest upgradeConfigs: - upgradeType: websocket - typedPerFilterConfig: - envoy.filters.http.basic_auth_first-route: - '@type': type.googleapis.com/envoy.config.route.v3.FilterConfig - config: {} diff --git a/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/ext-auth.routes.yaml b/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/ext-auth.routes.yaml index 349a4edec..7d65845ce 100644 --- a/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/ext-auth.routes.yaml +++ b/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/ext-auth.routes.yaml @@ -12,12 +12,6 @@ cluster: httproute/default/httproute-1/rule/0 upgradeConfigs: - upgradeType: websocket - typedPerFilterConfig: - envoy.filters.http.ext_authz_: - '@type': type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute - checkSettings: - contextExtensions: - route-name: httproute/default/httproute-1/rule/0/match/0/www_example_com - match: path: bar name: httproute/default/httproute-2/rule/0/match/0/www_example_com @@ -25,9 +19,3 @@ cluster: httproute/default/httproute-2/rule/0 upgradeConfigs: - upgradeType: websocket - typedPerFilterConfig: - envoy.filters.http.ext_authz_: - '@type': type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute - checkSettings: - contextExtensions: - route-name: httproute/default/httproute-2/rule/0/match/0/www_example_com diff --git a/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/multiple-listeners-same-port-with-different-filters.routes.yaml b/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/multiple-listeners-same-port-with-different-filters.routes.yaml index cd35bd956..eac6583e3 100755 --- a/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/multiple-listeners-same-port-with-different-filters.routes.yaml +++ b/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/multiple-listeners-same-port-with-different-filters.routes.yaml @@ -17,10 +17,6 @@ cluster: httproute/default/httproute-1/rule/0 upgradeConfigs: - upgradeType: websocket - typedPerFilterConfig: - envoy.filters.http.basic_auth_httproute/default/httproute-1/rule/0/match/0/www_foo_com: - '@type': type.googleapis.com/envoy.config.route.v3.FilterConfig - config: {} - match: pathSeparatedPrefix: /foo2 name: httproute/default/httproute-2/rule/0/match/0/www_foo_com @@ -33,12 +29,6 @@ cluster: httproute/default/httproute-2/rule/0 upgradeConfigs: - upgradeType: websocket - typedPerFilterConfig: - envoy.filters.http.ext_authz_securitypolicy/default/policy-for-http-route-2: - '@type': type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute - checkSettings: - contextExtensions: - route-name: httproute/default/httproute-2/rule/0/match/0/www_foo_com - domains: - www.bar.com name: default/gateway-2/http/www_bar_com @@ -50,7 +40,3 @@ cluster: httproute/default/httproute-3/rule/0 upgradeConfigs: - upgradeType: websocket - typedPerFilterConfig: - envoy.filters.http.oauth2_httproute/default/httproute-3/rule/0/match/0/www_bar_com: - '@type': type.googleapis.com/envoy.config.route.v3.FilterConfig - config: {} diff --git a/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/oidc.routes.yaml b/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/oidc.routes.yaml index a093d6967..6f96996d0 100644 --- a/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/oidc.routes.yaml +++ b/adapter/internal/operator/gateway-api/translator/testdata/out/xds-ir/oidc.routes.yaml @@ -12,10 +12,6 @@ cluster: first-route-dest upgradeConfigs: - upgradeType: websocket - typedPerFilterConfig: - envoy.filters.http.oauth2_first-route: - '@type': type.googleapis.com/envoy.config.route.v3.FilterConfig - config: {} - match: path: bar name: second-route @@ -23,7 +19,3 @@ cluster: second-route-dest upgradeConfigs: - upgradeType: websocket - typedPerFilterConfig: - envoy.filters.http.oauth2_second-route: - '@type': type.googleapis.com/envoy.config.route.v3.FilterConfig - config: {} diff --git a/gateway/enforcer/cmd/main.go b/gateway/enforcer/cmd/main.go index 7c07d7c43..19076b413 100644 --- a/gateway/enforcer/cmd/main.go +++ b/gateway/enforcer/cmd/main.go @@ -32,10 +32,10 @@ func main() { client.InitiateEventingGRPCConnection() // Create the XDS clients - xds.CreateXDSClients(cfg) + apiStore, _,_ := xds.CreateXDSClients(cfg) // Start the external processing server - go extproc.StartExternalProcessingServer(cfg) + go extproc.StartExternalProcessingServer(cfg, apiStore) // Wait forever select {} diff --git a/gateway/enforcer/internal/datastore/api_store.go b/gateway/enforcer/internal/datastore/api_store.go index c8f085422..073f2ce0b 100644 --- a/gateway/enforcer/internal/datastore/api_store.go +++ b/gateway/enforcer/internal/datastore/api_store.go @@ -21,18 +21,19 @@ import ( "sync" api "github.com/wso2/apk/adapter/pkg/discovery/api/wso2/discovery/api" + "github.com/wso2/apk/gateway/enforcer/internal/util" ) // APIStore is a thread-safe store for APIs. type APIStore struct { - apis []*api.Api + apis map[string]*api.Api mu sync.RWMutex } // NewAPIStore creates a new instance of APIStore. func NewAPIStore() *APIStore { return &APIStore{ - apis: make([]*api.Api, 0), + // apis: make(map[string]*api.Api, 0), } } @@ -41,13 +42,23 @@ func NewAPIStore() *APIStore { func (s *APIStore) AddAPIs(apis []*api.Api) { s.mu.Lock() defer s.mu.Unlock() - s.apis = apis + s.apis = make(map[string]*api.Api, len(apis)) + for _, api := range apis { + s.apis[util.PrepareAPIKey(api.Vhost, api.BasePath, api.Version)] = api + } } // GetAPIs retrieves the list of APIs from the store. // This method is thread-safe. -func (s *APIStore) GetAPIs() []*api.Api { +func (s *APIStore) GetAPIs() map[string]*api.Api { s.mu.RLock() defer s.mu.RUnlock() return s.apis } + +// GetMatchedAPI retrieves the API that matches the given API key. +func (s *APIStore) GetMatchedAPI(apiKey string) *api.Api { + s.mu.RLock() + defer s.mu.RUnlock() + return s.apis[apiKey] +} diff --git a/gateway/enforcer/internal/extproc/ext_proc.go b/gateway/enforcer/internal/extproc/ext_proc.go index 6b2f67889..a27b14f61 100644 --- a/gateway/enforcer/internal/extproc/ext_proc.go +++ b/gateway/enforcer/internal/extproc/ext_proc.go @@ -23,8 +23,11 @@ import ( corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_service_proc_v3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + api "github.com/wso2/apk/adapter/pkg/discovery/api/wso2/discovery/api" "github.com/wso2/apk/gateway/enforcer/internal/config" + "github.com/wso2/apk/gateway/enforcer/internal/datastore" "github.com/wso2/apk/gateway/enforcer/internal/logging" + "github.com/wso2/apk/gateway/enforcer/internal/ratelimit" "github.com/wso2/apk/gateway/enforcer/internal/util" "net" @@ -35,6 +38,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/prototext" structpb "google.golang.org/protobuf/types/known/structpb" ) @@ -43,6 +47,9 @@ import ( type ExternalProcessingServer struct { log logging.Logger externalProcessingEnvoyAttributes *ExternalProcessingEnvoyAttributes + matchedAPI *api.Api + apiStore *datastore.APIStore + ratelimitHelper *ratelimit.AIRatelimitHelper } // ExternalProcessingEnvoyAttributes represents the attributes extracted from the external processing request. @@ -71,7 +78,7 @@ const ( ) // Define the regular expression as a constant -const keyValuePattern = `key: "(.*?)" value { string_value: "(.*?)" }` +const keyValuePattern = `key: "([^.]*)" value { string_value: "(.*?)" }` // Pre-compile the regular expression var re = regexp.MustCompile(keyValuePattern) @@ -85,7 +92,7 @@ var re = regexp.MustCompile(keyValuePattern) // public and private keys, and a logger instance. // // If there is an error during the creation of the gRPC server, the function will panic. -func StartExternalProcessingServer(cfg *config.Server) { +func StartExternalProcessingServer(cfg *config.Server, apiStore *datastore.APIStore) { kaParams := keepalive.ServerParameters{ Time: time.Duration(cfg.ExternalProcessingKeepAliveTime) * time.Hour, // Ping the client if it is idle for 2 hours Timeout: 20 * time.Second, @@ -99,7 +106,8 @@ func StartExternalProcessingServer(cfg *config.Server) { panic(err) } - envoy_service_proc_v3.RegisterExternalProcessorServer(server, &ExternalProcessingServer{cfg.Logger, nil}) + ratelimitHelper := ratelimit.NewAIRatelimitHelper(cfg) + envoy_service_proc_v3.RegisterExternalProcessorServer(server, &ExternalProcessingServer{cfg.Logger, nil, nil, apiStore, ratelimitHelper}) listener, err := net.Listen("tcp", fmt.Sprintf(":%s", cfg.ExternalProcessingPort)) if err != nil { cfg.Logger.Error(err, fmt.Sprintf("Failed to listen on port: %s", cfg.ExternalProcessingPort)) @@ -152,6 +160,7 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro s.log.Error(err, "failed to extract context attributes") } s.externalProcessingEnvoyAttributes = attributes + s.matchedAPI = s.apiStore.GetMatchedAPI(util.PrepareAPIKey(s.externalProcessingEnvoyAttributes.VHost, s.externalProcessingEnvoyAttributes.BasePath, s.externalProcessingEnvoyAttributes.APIVersion)) rhq := &envoy_service_proc_v3.HeadersResponse{ Response: &envoy_service_proc_v3.CommonResponse{ @@ -185,10 +194,25 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro ResponseHeaders: rhq, }, } + // s.log.Info(fmt.Sprintf("Matched api: %s", s.matchedAPI)) + if s.matchedAPI.Aiprovider != nil && + s.matchedAPI.Aiprovider.CompletionToken != nil && + s.externalProcessingEnvoyAttributes.EnableBackendBasedAIRatelimit == "true" && + s.matchedAPI.Aiprovider.CompletionToken.In == "Header" { + s.log.Info("Backend based AI rate limit enabled using headers") + tokenCount, err := ratelimit.ExtractTokenCountFromExternalProcessingResponseHeaders(req.GetResponseHeaders().GetHeaders().GetHeaders(), s.matchedAPI.Aiprovider.PromptTokens.Value, s.matchedAPI.Aiprovider.CompletionToken.Value, s.matchedAPI.Aiprovider.CompletionToken.Value, s.matchedAPI.Aiprovider.Model.Value) + if err != nil { + s.log.Error(err, "failed to extract token count from response headers") + } else { + s.ratelimitHelper.DoAIRatelimit(tokenCount, true, false, s.externalProcessingEnvoyAttributes.BackendBasedAIRatelimitDescriptorValue) + } + } + break case *envoy_service_proc_v3.ProcessingRequest_ResponseBody: // httpBody := req.GetResponseBody() - // s.log.Info(fmt.Sprint("response body \n")) + s.log.Info(fmt.Sprintf("attribute %+v\n", s.externalProcessingEnvoyAttributes)) + rbq := &envoy_service_proc_v3.BodyResponse{ Response: &envoy_service_proc_v3.CommonResponse{}, } @@ -198,6 +222,19 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro }, } + if s.matchedAPI.Aiprovider != nil && + s.matchedAPI.Aiprovider.CompletionToken != nil && + s.externalProcessingEnvoyAttributes.EnableBackendBasedAIRatelimit == "true" && + s.matchedAPI.Aiprovider.CompletionToken.In == "Body" { + s.log.Info("Backend based AI rate limit enabled using body") + tokenCount, err := ratelimit.ExtractTokenCountFromExternalProcessingResponseBody(req.GetResponseBody().Body, s.matchedAPI.Aiprovider.PromptTokens.Value, s.matchedAPI.Aiprovider.CompletionToken.Value, s.matchedAPI.Aiprovider.CompletionToken.Value, s.matchedAPI.Aiprovider.Model.Value) + if err != nil { + s.log.Error(err, "failed to extract token count from response body") + } else { + s.ratelimitHelper.DoAIRatelimit(tokenCount, true, false, s.externalProcessingEnvoyAttributes.BackendBasedAIRatelimitDescriptorValue) + } + } + case *envoy_service_proc_v3.ProcessingRequest_RequestBody: // httpBody := req.GetRequestBody() // s.log.Info(fmt.Sprint("request body")) @@ -218,6 +255,7 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro } } +// extractExternalProcessingAttributes extracts the external processing attributes from the given data. func extractExternalProcessingAttributes(data map[string]*structpb.Struct) (*ExternalProcessingEnvoyAttributes, error) { // Get the fields from the map @@ -234,39 +272,62 @@ func extractExternalProcessingAttributes(data map[string]*structpb.Struct) (*Ext if field, ok := fields["xds.route_metadata"]; ok { filterMetadata := field.GetStringValue() + var structData corev3.Metadata + err := prototext.Unmarshal([]byte(filterMetadata), &structData) + if err != nil { + return nil, fmt.Errorf("failed to parse Protobuf text: %v", err) + } - matches := re.FindAllStringSubmatch(filterMetadata, -1) + // Extract values for predefined keys + extractedValues := make(map[string]string) - // Iterate over the matches and assign values to the struct - for _, match := range matches { - key, value := match[1], match[2] + keysToExtract := []string{ + pathAttribute, + vHostAttribute, + basePathAttribute, + methodAttribute, + apiVersionAttribute, + apiNameAttribute, + clusterNameAttribute, + enableBackendBasedAIRatelimitAttribute, + backendBasedAIRatelimitDescriptorValueAttribute, + } - switch key { - case enableBackendBasedAIRatelimitAttribute: - attributes.EnableBackendBasedAIRatelimit = value - case backendBasedAIRatelimitDescriptorValueAttribute: - attributes.BackendBasedAIRatelimitDescriptorValue = value - case pathAttribute: - attributes.Path = value - case vHostAttribute: - attributes.VHost = value - case basePathAttribute: - attributes.BasePath = value - case methodAttribute: - attributes.Method = value - case apiNameAttribute: - attributes.APIName = value - case apiVersionAttribute: - attributes.APIVersion = value - case clusterNameAttribute: - attributes.ClusterName = value - default: + for _, key := range keysToExtract { + if field, exists := structData.FilterMetadata["envoy.filters.http.ext_proc"]; exists { + extractedValues[key] = field.Fields[key].GetStringValue() + // case condition to populate ExternalProcessingEnvoyAttributes + switch key { + case pathAttribute: + attributes.Path = extractedValues[key] + case vHostAttribute: + attributes.VHost = extractedValues[key] + case basePathAttribute: + attributes.BasePath = extractedValues[key] + case methodAttribute: + attributes.Method = extractedValues[key] + case apiVersionAttribute: + attributes.APIVersion = extractedValues[key] + case apiNameAttribute: + attributes.APIName = extractedValues[key] + case clusterNameAttribute: + attributes.ClusterName = extractedValues[key] + case enableBackendBasedAIRatelimitAttribute: + attributes.EnableBackendBasedAIRatelimit = extractedValues[key] + case backendBasedAIRatelimitDescriptorValueAttribute: + attributes.BackendBasedAIRatelimitDescriptorValue = extractedValues[key] + } } } - } else { - fmt.Println("Key 'xds.route_metadata' not found in fields") + + // Print extracted values + for key, value := range extractedValues { + fmt.Printf("%s: %s\n", key, value) + } + // Return the populated struct + return attributes, nil } - // Return the populated struct - return attributes, nil + // Key not found + return nil, fmt.Errorf("key xds.route_metadata not found") } diff --git a/gateway/enforcer/internal/ratelimit/ai_ratelimit.go b/gateway/enforcer/internal/ratelimit/ai_ratelimit.go new file mode 100644 index 000000000..ad2afa397 --- /dev/null +++ b/gateway/enforcer/internal/ratelimit/ai_ratelimit.go @@ -0,0 +1,245 @@ +package ratelimit + +import ( + "encoding/json" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + + v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + "github.com/wso2/apk/gateway/enforcer/internal/config" + "github.com/wso2/apk/gateway/enforcer/internal/util" +) + +// AIRatelimitHelper is a helper struct for managing AI rate limiting. +type AIRatelimitHelper struct { + rlClient *client +} + +// TokenCountAndModel is a struct that holds the prompt, completion, and total token counts. +type TokenCountAndModel struct { + promt int + completion int + total int + model string +} + +const ( + // DescriptorKeyForAIPromtTokenCount is the descriptor key for the AI prompt token count. + DescriptorKeyForAIPromtTokenCount = "airequesttokencount" + // DescriptorKeyForAICompletionTokenCount is the descriptor key for the AI completion token count. + DescriptorKeyForAICompletionTokenCount = "airesponsetokencount" + // DescriptorKeyForAITotalTokenCount is the descriptor key for the AI total token count. + DescriptorKeyForAITotalTokenCount = "aitotaltokencount" +) + +// NewAIRatelimitHelper creates a new instance of the AIRatelimitHelper. +func NewAIRatelimitHelper(cfg *config.Server) *AIRatelimitHelper { + client := newClient(cfg) + client.start() + return &AIRatelimitHelper{ + rlClient: client, + } +} + +// DoAIRatelimit performs AI rate limiting. +func (airl *AIRatelimitHelper) DoAIRatelimit(tokenCount *TokenCountAndModel, doBackendBasedAIRatelimit bool, doSubscriptionBasedAIRatelimit bool, backendBasedAIRatelimitDescriptorValue string) { + go func() { + configs := []*keyValueHitsAddend{} + if doBackendBasedAIRatelimit { + // For promt token count + configs = append(configs, &keyValueHitsAddend{ + Key: DescriptorKeyForAIPromtTokenCount, + Value: backendBasedAIRatelimitDescriptorValue, + HitsAddend: tokenCount.promt, + }) + // For completion token count + configs = append(configs, &keyValueHitsAddend{ + Key: DescriptorKeyForAICompletionTokenCount, + Value: backendBasedAIRatelimitDescriptorValue, + HitsAddend: tokenCount.completion, + }) + // For total token count + configs = append(configs, &keyValueHitsAddend{ + Key: DescriptorKeyForAITotalTokenCount, + Value: backendBasedAIRatelimitDescriptorValue, + HitsAddend: tokenCount.total, + }) + airl.rlClient.shouldRatelimit(configs) + } + }() +} + +// ExtractTokenCountFromExternalProcessingResponseHeaders extracts token counts from external processing response headers. +func ExtractTokenCountFromExternalProcessingResponseHeaders(headerValues []*v3.HeaderValue, promptHeader, completionHeader, totalHeader, modelHeader string) (*TokenCountAndModel, error) { + tokenCount := &TokenCountAndModel{} + promtFlag, completionFlag, totalFlag := false, false, false + for _, headerValue := range headerValues { + if headerValue.Key == promptHeader { + if headerValue.Value != "" { + value, err := util.ConvertStringToInt(headerValue.Value) + if err != nil { + tokenCount.promt = value + promtFlag = true + } else { + return nil, err + } + } else if len(headerValue.RawValue) != 0 { + value, err := util.ConvertBytesToInt(headerValue.RawValue) + if err != nil { + tokenCount.promt = value + promtFlag = true + } else { + return nil, err + } + } + + } else if headerValue.Key == completionHeader { + if headerValue.Value != "" { + value, err := strconv.Atoi(headerValue.Value) + if err != nil { + tokenCount.completion = value + completionFlag = true + } else { + return nil, err + } + } else if len(headerValue.RawValue) != 0 { + value, err := util.ConvertBytesToInt(headerValue.RawValue) + if err != nil { + tokenCount.completion = value + completionFlag = true + } else { + return nil, err + } + } + } else if headerValue.Key == totalHeader { + if headerValue.Value != "" { + value, err := strconv.Atoi(headerValue.Value) + if err != nil { + tokenCount.total = value + totalFlag = true + } else { + return nil, err + } + } else if len(headerValue.RawValue) != 0 { + value, err := util.ConvertBytesToInt(headerValue.RawValue) + if err != nil { + tokenCount.total = value + totalFlag = true + } else { + return nil, err + } + } + } else if headerValue.Key == modelHeader { + if headerValue.Value != "" { + tokenCount.model = headerValue.Value + } else if len(headerValue.RawValue) != 0 { + tokenCount.model = string(headerValue.RawValue) + } + } + } + if !promtFlag || !completionFlag || !totalFlag { + return nil, fmt.Errorf("missing token headers from the AI response headers") + } + return tokenCount, nil +} + +// ExtractTokenCountFromExternalProcessingResponseBody extracts token counts from external processing response body. +func ExtractTokenCountFromExternalProcessingResponseBody(body []byte, promptPath, completionPath, totalPath, modelPath string) (*TokenCountAndModel, error) { + bodyStr := string(body) + sanitizedBody := sanitize(bodyStr) + tokenCount, err := extractUsageFromBody(sanitizedBody, promptPath, completionPath, totalPath, "model") + if err != nil { + return nil, fmt.Errorf("failed to extract token count from the AI response body: %w", err) + } + return tokenCount, nil + +} + +func sanitize(input string) string { + // Define a regex to match all newline characters and tabs + re := regexp.MustCompile(`[\t\n\r]+`) + // Replace matches with a space and trim the result + return strings.TrimSpace(re.ReplaceAllString(input, " ")) +} + +// extractValueFromPath extracts a value from a nested JSON structure based on a dot-separated path. +func extractValueFromPath(data map[string]interface{}, path string) (interface{}, error) { + keys := strings.Split(path, ".") + if len(keys) > 0 && keys[0] == "$" { + keys = keys[1:] + } + + var current interface{} = data + for _, key := range keys { + if node, ok := current.(map[string]interface{}); ok { + if val, exists := node[key]; exists { + current = val + } else { + return nil, errors.New("key not found: " + key) + } + } else { + return nil, errors.New("invalid structure for key: " + key) + } + } + return current, nil +} + +// extractUsageFromBody extracts usage data from the JSON body based on the provided paths. +func extractUsageFromBody(body, completionTokenPath, promptTokenPath, totalTokenPath, modelPath string) (*TokenCountAndModel, error) { + body = sanitize(body) + var rootNode map[string]interface{} + if err := json.Unmarshal([]byte(body), &rootNode); err != nil { + return nil, fmt.Errorf("failed to parse JSON: %w", err) + } + + usage := &TokenCountAndModel{} + + // Extract prompt tokens + promt, err := extractValueFromPath(rootNode, promptTokenPath) + if err != nil { + return nil, fmt.Errorf("failed to extract prompt tokens: %w", err) + } + if pt, ok := promt.(float64); ok { // JSON numbers are decoded as float64 + usage.promt = int(pt) + } else { + return nil, errors.New("invalid type for prompt tokens") + } + + // Extract completion tokens + completion, err := extractValueFromPath(rootNode, completionTokenPath) + if err != nil { + return nil, fmt.Errorf("failed to extract completion tokens: %w", err) + } + if ct, ok := completion.(float64); ok { + usage.completion = int(ct) + } else { + return nil, errors.New("invalid type for completion tokens") + } + + // Extract total tokens + total, err := extractValueFromPath(rootNode, totalTokenPath) + if err != nil { + return nil, fmt.Errorf("failed to extract total tokens: %w", err) + } + if tt, ok := total.(float64); ok { + usage.total = int(tt) + } else { + return nil, errors.New("invalid type for total tokens") + } + + // Extract model + // model, err := extractValueFromPath(rootNode, modelPath) + // if err != nil { + // return nil, fmt.Errorf("failed to extract model: %w", err) + // } + // if m, ok := model.(string); ok { + // usage.model = m + // } else { + // return nil, errors.New("invalid type for model") + // } + + return usage, nil +} diff --git a/gateway/enforcer/internal/ratelimit/rl_client.go b/gateway/enforcer/internal/ratelimit/rl_client.go new file mode 100644 index 000000000..8c1df8e3d --- /dev/null +++ b/gateway/enforcer/internal/ratelimit/rl_client.go @@ -0,0 +1,99 @@ +package ratelimit + +import ( + "context" + "fmt" + "time" + + v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/common/ratelimit/v3" + rls_svc "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/wso2/apk/gateway/enforcer/internal/config" + "github.com/wso2/apk/gateway/enforcer/internal/logging" + "github.com/wso2/apk/gateway/enforcer/internal/util" +) + +// client is a client for managing gRPC connections to the Rate Limit Service (RLS). +type client struct { + log logging.Logger + cfg *config.Server + rlsClient rls_svc.RateLimitServiceClient +} + +// keyValueHitsAddend is a struct that holds the key, value, and hits addend for a rate limit. +type keyValueHitsAddend struct { + Key string + Value string + HitsAddend int + KeyValueHitsAddend *keyValueHitsAddend +} + +// newClient creates a new instance of the Rate Limit Client. +func newClient(cfg *config.Server) *client { + return &client{ + log: cfg.Logger, + cfg: cfg, + } +} + +// start initializes the Rate Limit Client by creating a gRPC connection to the RLS. +func (r *client) start() { + r.log.Info("Starting the rate limit client") + + clientCert, err := util.LoadCertificates(r.cfg.EnforcerPublicKeyPath, r.cfg.EnforcerPrivateKeyPath) + if err != nil { + panic(err) + } + + // Load the trusted CA certificates + certPool, err := util.LoadCACertificates(r.cfg.TrustedAdapterCertsPath) + if err != nil { + panic(err) + } + + // Create the TLS configuration + tlsConfig := util.CreateTLSConfig(clientCert, certPool) + grpcConn := util.CreateGRPCConnectionWithRetryAndPanic(nil, r.cfg.RatelimiterHost, r.cfg.RatelimiterPort, tlsConfig, r.cfg.XdsMaxRetries, time.Duration(r.cfg.XdsRetryPeriod)*time.Millisecond) + r.rlsClient = rls_svc.NewRateLimitServiceClient(grpcConn) + r.log.Info("Rate limit client started successfully") +} + +// shouldRatelimit checks if the request should be rate limited based on the given configurations. +func (r *client) shouldRatelimit(configs []*keyValueHitsAddend) { + for _, config := range configs { + descriptorEntries := []*v3.RateLimitDescriptor_Entry{ + { + Key: config.Key, + Value: config.Value, + }, + } + + internalConfig := config.KeyValueHitsAddend + hitsAddend := config.HitsAddend + for internalConfig != nil { + descriptorEntries = append(descriptorEntries, &v3.RateLimitDescriptor_Entry{ + Key: internalConfig.Key, + Value: internalConfig.Value, + }) + hitsAddend = internalConfig.HitsAddend + internalConfig = internalConfig.KeyValueHitsAddend + } + + rateLimitRequest := &rls_svc.RateLimitRequest{ + Descriptors: []*v3.RateLimitDescriptor{ + { + Entries: descriptorEntries, + }, + }, + Domain: "Default", + HitsAddend: uint32(hitsAddend), + } + + response, err := r.rlsClient.ShouldRateLimit(context.Background(), rateLimitRequest) + if err != nil { + r.log.Info(fmt.Sprintf("Error while calling rate limiter: %v", err)) + continue + } + + r.log.Info(fmt.Sprintf("Rate limit response: %v", response)) + } +} diff --git a/gateway/enforcer/internal/util/api.go b/gateway/enforcer/internal/util/api.go new file mode 100644 index 000000000..bacbe43e8 --- /dev/null +++ b/gateway/enforcer/internal/util/api.go @@ -0,0 +1,10 @@ +package util + +import ( + "fmt" +) + +// PrepareAPIKey prepares the API key using the given vhost, basePath, and version. +func PrepareAPIKey(vhost, basePath, version string) string { + return fmt.Sprintf("%s:%s:%s", vhost, basePath, version) +} \ No newline at end of file diff --git a/gateway/enforcer/internal/util/conversion.go b/gateway/enforcer/internal/util/conversion.go new file mode 100644 index 000000000..a134d4835 --- /dev/null +++ b/gateway/enforcer/internal/util/conversion.go @@ -0,0 +1,32 @@ +package util + +import ( + "fmt" + "strconv" +) + + +// ConvertBytesToInt converts a []byte to an int. +// It assumes the []byte contains a valid numeric string (e.g., "123"). +func ConvertBytesToInt(data []byte) (int, error) { + // Convert the []byte to string + str := string(data) + + // Convert the string to int + num, err := strconv.Atoi(str) + if err != nil { + return 0, fmt.Errorf("invalid numeric value: %s, error: %w", str, err) + } + + return num, nil +} + +// ConvertStringToInt converts a string to an integer. +// Returns the converted int and an error if the input is invalid. +func ConvertStringToInt(input string) (int, error) { + num, err := strconv.Atoi(input) + if err != nil { + return 0, fmt.Errorf("invalid input: %s, error: %w", input, err) + } + return num, nil +} diff --git a/gateway/enforcer/internal/xds/client_manager.go b/gateway/enforcer/internal/xds/client_manager.go index 818c184f9..c07d529d4 100644 --- a/gateway/enforcer/internal/xds/client_manager.go +++ b/gateway/enforcer/internal/xds/client_manager.go @@ -32,7 +32,7 @@ import ( // CreateXDSClients initializes and establishes connections for multiple XDS clients, // including API XDS, Config XDS, and JWT Issuer XDS clients. // It handles TLS configuration, certificate loading, and connection setup. -func CreateXDSClients(cfg *config.Server) { +func CreateXDSClients(cfg *config.Server) (*datastore.APIStore, *datastore.ConfigStore, *datastore.JWTIssuerStore) { clientCert, err := util.LoadCertificates(cfg.EnforcerPublicKeyPath, cfg.EnforcerPrivateKeyPath) if err != nil { panic(err) @@ -57,6 +57,7 @@ func CreateXDSClients(cfg *config.Server) { configXDSClient.InitiateConfigXDSConnection() jwtIssuerXDSClient.InitiateSubscriptionXDSConnection() cfg.Logger.Info("XDS clients initiated successfully") + return apiDatastore, configDatastore, jwtIssuerDatastore } // CreateNode creates a new Node object with the given node ID and instance identifier. diff --git a/helm-charts/templates/data-plane/gateway-components/gateway-runtime/gateway-runtime-deployment.yaml b/helm-charts/templates/data-plane/gateway-components/gateway-runtime/gateway-runtime-deployment.yaml index 4dab4e670..3b7227327 100644 --- a/helm-charts/templates/data-plane/gateway-components/gateway-runtime/gateway-runtime-deployment.yaml +++ b/helm-charts/templates/data-plane/gateway-components/gateway-runtime/gateway-runtime-deployment.yaml @@ -274,18 +274,18 @@ spec: subPath: ca.crt {{- end }} {{ end }} - readinessProbe: - exec: - command: [ "sh", "check_health.sh" ] - initialDelaySeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.enforcer.readinessProbe.initialDelaySeconds }} - periodSeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.enforcer.readinessProbe.periodSeconds }} - failureThreshold: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.enforcer.readinessProbe.failureThreshold }} - livenessProbe: - exec: - command: [ "sh", "check_health.sh" ] - initialDelaySeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.enforcer.livenessProbe.initialDelaySeconds }} - periodSeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.enforcer.livenessProbe.periodSeconds }} - failureThreshold: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.enforcer.livenessProbe.failureThreshold }} + # readinessProbe: + # exec: + # command: [ "sh", "check_health.sh" ] + # initialDelaySeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.enforcer.readinessProbe.initialDelaySeconds }} + # periodSeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.enforcer.readinessProbe.periodSeconds }} + # failureThreshold: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.enforcer.readinessProbe.failureThreshold }} + # livenessProbe: + # exec: + # command: [ "sh", "check_health.sh" ] + # initialDelaySeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.enforcer.livenessProbe.initialDelaySeconds }} + # periodSeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.enforcer.livenessProbe.periodSeconds }} + # failureThreshold: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.enforcer.livenessProbe.failureThreshold }} securityContext: allowPrivilegeEscalation: false capabilities: @@ -391,30 +391,30 @@ spec: subPath: ca.crt {{- end }} {{ end }} - livenessProbe: - exec: - command: [ "sh", "router_check_health.sh", "health" ] - initialDelaySeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.livenessProbe.initialDelaySeconds }} - periodSeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.livenessProbe.periodSeconds }} - failureThreshold: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.livenessProbe.failureThreshold }} - readinessProbe: - exec: - command: [ "sh", "router_check_health.sh", "ready" ] - initialDelaySeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.readinessProbe.initialDelaySeconds }} - periodSeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.readinessProbe.periodSeconds }} - failureThreshold: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.readinessProbe.failureThreshold }} + # livenessProbe: + # exec: + # command: [ "sh", "router_check_health.sh", "health" ] + # initialDelaySeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.livenessProbe.initialDelaySeconds }} + # periodSeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.livenessProbe.periodSeconds }} + # failureThreshold: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.livenessProbe.failureThreshold }} + # readinessProbe: + # exec: + # command: [ "sh", "router_check_health.sh", "ready" ] + # initialDelaySeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.readinessProbe.initialDelaySeconds }} + # periodSeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.readinessProbe.periodSeconds }} + # failureThreshold: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.readinessProbe.failureThreshold }} securityContext: allowPrivilegeEscalation: false capabilities: drop: ["ALL"] readOnlyRootFilesystem: true runAsNonRoot: true - startupProbe: - exec: - command: [ "sh", "router_check_health.sh", "ready" ] - initialDelaySeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.readinessProbe.initialDelaySeconds }} - periodSeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.readinessProbe.periodSeconds }} - failureThreshold: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.readinessProbe.failureThreshold }} + # startupProbe: + # exec: + # command: [ "sh", "router_check_health.sh", "ready" ] + # initialDelaySeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.readinessProbe.initialDelaySeconds }} + # periodSeconds: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.readinessProbe.periodSeconds }} + # failureThreshold: {{ .Values.wso2.apk.dp.gatewayRuntime.deployment.router.readinessProbe.failureThreshold }} {{- if and .Values.wso2.subscription .Values.wso2.subscription.imagePullSecrets}} imagePullSecrets: - name: {{ .Values.wso2.subscription.imagePullSecrets }}