Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backend based AI Ratelimit #2687

Merged
merged 2 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions adapter/internal/oasparser/envoyconf/http_filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,12 @@ func getExtProcessHTTPFilter() *hcmv3.HttpFilter {
},
FailureModeAllow: true,
ProcessingMode: &ext_process.ProcessingMode{
// ResponseBodyMode: ext_process.ProcessingMode_BUFFERED,
ResponseBodyMode: ext_process.ProcessingMode_BUFFERED,
RequestHeaderMode: ext_process.ProcessingMode_SEND,
ResponseHeaderMode: ext_process.ProcessingMode_SEND,
// RequestHeaderMode: ext_process.ProcessingMode_SKIP,
// ResponseHeaderMode: ext_process.ProcessingMode_SKIP,
// RequestBodyMode: ext_process.ProcessingMode_BUFFERED,
RequestBodyMode: ext_process.ProcessingMode_BUFFERED,
},
MetadataOptions: &ext_process.MetadataOptions{
ForwardingNamespaces: &ext_process.MetadataOptions_MetadataNamespaces{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,3 @@
cluster: first-route-dest
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.basic_auth_first-route:
'@type': type.googleapis.com/envoy.config.route.v3.FilterConfig
config: {}
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,10 @@
cluster: httproute/default/httproute-1/rule/0
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.ext_authz_:
'@type': type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute
checkSettings:
contextExtensions:
route-name: httproute/default/httproute-1/rule/0/match/0/www_example_com
- match:
path: bar
name: httproute/default/httproute-2/rule/0/match/0/www_example_com
route:
cluster: httproute/default/httproute-2/rule/0
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.ext_authz_:
'@type': type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute
checkSettings:
contextExtensions:
route-name: httproute/default/httproute-2/rule/0/match/0/www_example_com
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@
cluster: httproute/default/httproute-1/rule/0
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.basic_auth_httproute/default/httproute-1/rule/0/match/0/www_foo_com:
'@type': type.googleapis.com/envoy.config.route.v3.FilterConfig
config: {}
- match:
pathSeparatedPrefix: /foo2
name: httproute/default/httproute-2/rule/0/match/0/www_foo_com
Expand All @@ -33,12 +29,6 @@
cluster: httproute/default/httproute-2/rule/0
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.ext_authz_securitypolicy/default/policy-for-http-route-2:
'@type': type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute
checkSettings:
contextExtensions:
route-name: httproute/default/httproute-2/rule/0/match/0/www_foo_com
- domains:
- www.bar.com
name: default/gateway-2/http/www_bar_com
Expand All @@ -50,7 +40,3 @@
cluster: httproute/default/httproute-3/rule/0
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.oauth2_httproute/default/httproute-3/rule/0/match/0/www_bar_com:
'@type': type.googleapis.com/envoy.config.route.v3.FilterConfig
config: {}
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,10 @@
cluster: first-route-dest
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.oauth2_first-route:
'@type': type.googleapis.com/envoy.config.route.v3.FilterConfig
config: {}
- match:
path: bar
name: second-route
route:
cluster: second-route-dest
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.oauth2_second-route:
'@type': type.googleapis.com/envoy.config.route.v3.FilterConfig
config: {}
4 changes: 2 additions & 2 deletions gateway/enforcer/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func main() {
client.InitiateEventingGRPCConnection()

// Create the XDS clients
xds.CreateXDSClients(cfg)
apiStore, _,_ := xds.CreateXDSClients(cfg)

// Start the external processing server
go extproc.StartExternalProcessingServer(cfg)
go extproc.StartExternalProcessingServer(cfg, apiStore)

// Wait forever
select {}
Expand Down
19 changes: 15 additions & 4 deletions gateway/enforcer/internal/datastore/api_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,19 @@ import (
"sync"

api "github.com/wso2/apk/adapter/pkg/discovery/api/wso2/discovery/api"
"github.com/wso2/apk/gateway/enforcer/internal/util"
)

// APIStore is a thread-safe store for APIs.
type APIStore struct {
apis []*api.Api
apis map[string]*api.Api
mu sync.RWMutex
}

// NewAPIStore creates a new instance of APIStore.
func NewAPIStore() *APIStore {
return &APIStore{
apis: make([]*api.Api, 0),
// apis: make(map[string]*api.Api, 0),
}
}

Expand All @@ -41,13 +42,23 @@ func NewAPIStore() *APIStore {
func (s *APIStore) AddAPIs(apis []*api.Api) {
s.mu.Lock()
defer s.mu.Unlock()
s.apis = apis
s.apis = make(map[string]*api.Api, len(apis))
for _, api := range apis {
s.apis[util.PrepareAPIKey(api.Vhost, api.BasePath, api.Version)] = api
}
}

// GetAPIs retrieves the list of APIs from the store.
// This method is thread-safe.
func (s *APIStore) GetAPIs() []*api.Api {
func (s *APIStore) GetAPIs() map[string]*api.Api {
s.mu.RLock()
defer s.mu.RUnlock()
return s.apis
}

// GetMatchedAPI retrieves the API that matches the given API key.
func (s *APIStore) GetMatchedAPI(apiKey string) *api.Api {
s.mu.RLock()
defer s.mu.RUnlock()
return s.apis[apiKey]
}
125 changes: 93 additions & 32 deletions gateway/enforcer/internal/extproc/ext_proc.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ import (

corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
envoy_service_proc_v3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
api "github.com/wso2/apk/adapter/pkg/discovery/api/wso2/discovery/api"
"github.com/wso2/apk/gateway/enforcer/internal/config"
"github.com/wso2/apk/gateway/enforcer/internal/datastore"
"github.com/wso2/apk/gateway/enforcer/internal/logging"
"github.com/wso2/apk/gateway/enforcer/internal/ratelimit"
"github.com/wso2/apk/gateway/enforcer/internal/util"

"net"
Expand All @@ -35,6 +38,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/prototext"
structpb "google.golang.org/protobuf/types/known/structpb"
)

Expand All @@ -43,6 +47,9 @@ import (
type ExternalProcessingServer struct {
log logging.Logger
externalProcessingEnvoyAttributes *ExternalProcessingEnvoyAttributes
matchedAPI *api.Api
apiStore *datastore.APIStore
ratelimitHelper *ratelimit.AIRatelimitHelper
}

// ExternalProcessingEnvoyAttributes represents the attributes extracted from the external processing request.
Expand Down Expand Up @@ -71,7 +78,7 @@ const (
)

// Define the regular expression as a constant
const keyValuePattern = `key: "(.*?)" value { string_value: "(.*?)" }`
const keyValuePattern = `key: "([^.]*)" value { string_value: "(.*?)" }`

// Pre-compile the regular expression
var re = regexp.MustCompile(keyValuePattern)
Expand All @@ -85,7 +92,7 @@ var re = regexp.MustCompile(keyValuePattern)
// public and private keys, and a logger instance.
//
// If there is an error during the creation of the gRPC server, the function will panic.
func StartExternalProcessingServer(cfg *config.Server) {
func StartExternalProcessingServer(cfg *config.Server, apiStore *datastore.APIStore) {
kaParams := keepalive.ServerParameters{
Time: time.Duration(cfg.ExternalProcessingKeepAliveTime) * time.Hour, // Ping the client if it is idle for 2 hours
Timeout: 20 * time.Second,
Expand All @@ -99,7 +106,8 @@ func StartExternalProcessingServer(cfg *config.Server) {
panic(err)
}

envoy_service_proc_v3.RegisterExternalProcessorServer(server, &ExternalProcessingServer{cfg.Logger, nil})
ratelimitHelper := ratelimit.NewAIRatelimitHelper(cfg)
envoy_service_proc_v3.RegisterExternalProcessorServer(server, &ExternalProcessingServer{cfg.Logger, nil, nil, apiStore, ratelimitHelper})
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", cfg.ExternalProcessingPort))
if err != nil {
cfg.Logger.Error(err, fmt.Sprintf("Failed to listen on port: %s", cfg.ExternalProcessingPort))
Expand Down Expand Up @@ -152,6 +160,7 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
s.log.Error(err, "failed to extract context attributes")
}
s.externalProcessingEnvoyAttributes = attributes
s.matchedAPI = s.apiStore.GetMatchedAPI(util.PrepareAPIKey(s.externalProcessingEnvoyAttributes.VHost, s.externalProcessingEnvoyAttributes.BasePath, s.externalProcessingEnvoyAttributes.APIVersion))

rhq := &envoy_service_proc_v3.HeadersResponse{
Response: &envoy_service_proc_v3.CommonResponse{
Expand Down Expand Up @@ -185,10 +194,25 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
ResponseHeaders: rhq,
},
}
// s.log.Info(fmt.Sprintf("Matched api: %s", s.matchedAPI))
if s.matchedAPI.Aiprovider != nil &&
s.matchedAPI.Aiprovider.CompletionToken != nil &&
s.externalProcessingEnvoyAttributes.EnableBackendBasedAIRatelimit == "true" &&
s.matchedAPI.Aiprovider.CompletionToken.In == "Header" {
s.log.Info("Backend based AI rate limit enabled using headers")
tokenCount, err := ratelimit.ExtractTokenCountFromExternalProcessingResponseHeaders(req.GetResponseHeaders().GetHeaders().GetHeaders(), s.matchedAPI.Aiprovider.PromptTokens.Value, s.matchedAPI.Aiprovider.CompletionToken.Value, s.matchedAPI.Aiprovider.CompletionToken.Value, s.matchedAPI.Aiprovider.Model.Value)
if err != nil {
s.log.Error(err, "failed to extract token count from response headers")
} else {
s.ratelimitHelper.DoAIRatelimit(tokenCount, true, false, s.externalProcessingEnvoyAttributes.BackendBasedAIRatelimitDescriptorValue)
}
}

break
case *envoy_service_proc_v3.ProcessingRequest_ResponseBody:
// httpBody := req.GetResponseBody()
// s.log.Info(fmt.Sprint("response body \n"))
s.log.Info(fmt.Sprintf("attribute %+v\n", s.externalProcessingEnvoyAttributes))

rbq := &envoy_service_proc_v3.BodyResponse{
Response: &envoy_service_proc_v3.CommonResponse{},
}
Expand All @@ -198,6 +222,19 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
},
}

if s.matchedAPI.Aiprovider != nil &&
s.matchedAPI.Aiprovider.CompletionToken != nil &&
s.externalProcessingEnvoyAttributes.EnableBackendBasedAIRatelimit == "true" &&
s.matchedAPI.Aiprovider.CompletionToken.In == "Body" {
s.log.Info("Backend based AI rate limit enabled using body")
tokenCount, err := ratelimit.ExtractTokenCountFromExternalProcessingResponseBody(req.GetResponseBody().Body, s.matchedAPI.Aiprovider.PromptTokens.Value, s.matchedAPI.Aiprovider.CompletionToken.Value, s.matchedAPI.Aiprovider.CompletionToken.Value, s.matchedAPI.Aiprovider.Model.Value)
if err != nil {
s.log.Error(err, "failed to extract token count from response body")
} else {
s.ratelimitHelper.DoAIRatelimit(tokenCount, true, false, s.externalProcessingEnvoyAttributes.BackendBasedAIRatelimitDescriptorValue)
}
}

case *envoy_service_proc_v3.ProcessingRequest_RequestBody:
// httpBody := req.GetRequestBody()
// s.log.Info(fmt.Sprint("request body"))
Expand All @@ -218,6 +255,7 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
}
}

// extractExternalProcessingAttributes extracts the external processing attributes from the given data.
func extractExternalProcessingAttributes(data map[string]*structpb.Struct) (*ExternalProcessingEnvoyAttributes, error) {

// Get the fields from the map
Expand All @@ -234,39 +272,62 @@ func extractExternalProcessingAttributes(data map[string]*structpb.Struct) (*Ext
if field, ok := fields["xds.route_metadata"]; ok {

filterMetadata := field.GetStringValue()
var structData corev3.Metadata
err := prototext.Unmarshal([]byte(filterMetadata), &structData)
if err != nil {
return nil, fmt.Errorf("failed to parse Protobuf text: %v", err)
}

matches := re.FindAllStringSubmatch(filterMetadata, -1)
// Extract values for predefined keys
extractedValues := make(map[string]string)

// Iterate over the matches and assign values to the struct
for _, match := range matches {
key, value := match[1], match[2]
keysToExtract := []string{
pathAttribute,
vHostAttribute,
basePathAttribute,
methodAttribute,
apiVersionAttribute,
apiNameAttribute,
clusterNameAttribute,
enableBackendBasedAIRatelimitAttribute,
backendBasedAIRatelimitDescriptorValueAttribute,
}

switch key {
case enableBackendBasedAIRatelimitAttribute:
attributes.EnableBackendBasedAIRatelimit = value
case backendBasedAIRatelimitDescriptorValueAttribute:
attributes.BackendBasedAIRatelimitDescriptorValue = value
case pathAttribute:
attributes.Path = value
case vHostAttribute:
attributes.VHost = value
case basePathAttribute:
attributes.BasePath = value
case methodAttribute:
attributes.Method = value
case apiNameAttribute:
attributes.APIName = value
case apiVersionAttribute:
attributes.APIVersion = value
case clusterNameAttribute:
attributes.ClusterName = value
default:
for _, key := range keysToExtract {
if field, exists := structData.FilterMetadata["envoy.filters.http.ext_proc"]; exists {
extractedValues[key] = field.Fields[key].GetStringValue()
// case condition to populate ExternalProcessingEnvoyAttributes
switch key {
case pathAttribute:
attributes.Path = extractedValues[key]
case vHostAttribute:
attributes.VHost = extractedValues[key]
case basePathAttribute:
attributes.BasePath = extractedValues[key]
case methodAttribute:
attributes.Method = extractedValues[key]
case apiVersionAttribute:
attributes.APIVersion = extractedValues[key]
case apiNameAttribute:
attributes.APIName = extractedValues[key]
case clusterNameAttribute:
attributes.ClusterName = extractedValues[key]
case enableBackendBasedAIRatelimitAttribute:
attributes.EnableBackendBasedAIRatelimit = extractedValues[key]
case backendBasedAIRatelimitDescriptorValueAttribute:
attributes.BackendBasedAIRatelimitDescriptorValue = extractedValues[key]
}
}
}
} else {
fmt.Println("Key 'xds.route_metadata' not found in fields")

// Print extracted values
for key, value := range extractedValues {
fmt.Printf("%s: %s\n", key, value)
}
// Return the populated struct
return attributes, nil
}

// Return the populated struct
return attributes, nil
// Key not found
return nil, fmt.Errorf("key xds.route_metadata not found")
}
Loading