Skip to content

Commit

Permalink
Merge pull request #2687 from Tharsanan1/airl
Browse files Browse the repository at this point in the history
Backend based AI Ratelimit
  • Loading branch information
Tharsanan1 authored Jan 17, 2025
2 parents dba5fa4 + c689293 commit 1fd214e
Show file tree
Hide file tree
Showing 14 changed files with 530 additions and 109 deletions.
4 changes: 2 additions & 2 deletions adapter/internal/oasparser/envoyconf/http_filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,12 @@ func getExtProcessHTTPFilter() *hcmv3.HttpFilter {
},
FailureModeAllow: true,
ProcessingMode: &ext_process.ProcessingMode{
// ResponseBodyMode: ext_process.ProcessingMode_BUFFERED,
ResponseBodyMode: ext_process.ProcessingMode_BUFFERED,
RequestHeaderMode: ext_process.ProcessingMode_SEND,
ResponseHeaderMode: ext_process.ProcessingMode_SEND,
// RequestHeaderMode: ext_process.ProcessingMode_SKIP,
// ResponseHeaderMode: ext_process.ProcessingMode_SKIP,
// RequestBodyMode: ext_process.ProcessingMode_BUFFERED,
RequestBodyMode: ext_process.ProcessingMode_BUFFERED,
},
MetadataOptions: &ext_process.MetadataOptions{
ForwardingNamespaces: &ext_process.MetadataOptions_MetadataNamespaces{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,3 @@
cluster: first-route-dest
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.basic_auth_first-route:
'@type': type.googleapis.com/envoy.config.route.v3.FilterConfig
config: {}
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,10 @@
cluster: httproute/default/httproute-1/rule/0
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.ext_authz_:
'@type': type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute
checkSettings:
contextExtensions:
route-name: httproute/default/httproute-1/rule/0/match/0/www_example_com
- match:
path: bar
name: httproute/default/httproute-2/rule/0/match/0/www_example_com
route:
cluster: httproute/default/httproute-2/rule/0
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.ext_authz_:
'@type': type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute
checkSettings:
contextExtensions:
route-name: httproute/default/httproute-2/rule/0/match/0/www_example_com
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@
cluster: httproute/default/httproute-1/rule/0
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.basic_auth_httproute/default/httproute-1/rule/0/match/0/www_foo_com:
'@type': type.googleapis.com/envoy.config.route.v3.FilterConfig
config: {}
- match:
pathSeparatedPrefix: /foo2
name: httproute/default/httproute-2/rule/0/match/0/www_foo_com
Expand All @@ -33,12 +29,6 @@
cluster: httproute/default/httproute-2/rule/0
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.ext_authz_securitypolicy/default/policy-for-http-route-2:
'@type': type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute
checkSettings:
contextExtensions:
route-name: httproute/default/httproute-2/rule/0/match/0/www_foo_com
- domains:
- www.bar.com
name: default/gateway-2/http/www_bar_com
Expand All @@ -50,7 +40,3 @@
cluster: httproute/default/httproute-3/rule/0
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.oauth2_httproute/default/httproute-3/rule/0/match/0/www_bar_com:
'@type': type.googleapis.com/envoy.config.route.v3.FilterConfig
config: {}
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,10 @@
cluster: first-route-dest
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.oauth2_first-route:
'@type': type.googleapis.com/envoy.config.route.v3.FilterConfig
config: {}
- match:
path: bar
name: second-route
route:
cluster: second-route-dest
upgradeConfigs:
- upgradeType: websocket
typedPerFilterConfig:
envoy.filters.http.oauth2_second-route:
'@type': type.googleapis.com/envoy.config.route.v3.FilterConfig
config: {}
4 changes: 2 additions & 2 deletions gateway/enforcer/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func main() {
client.InitiateEventingGRPCConnection()

// Create the XDS clients
xds.CreateXDSClients(cfg)
apiStore, _,_ := xds.CreateXDSClients(cfg)

// Start the external processing server
go extproc.StartExternalProcessingServer(cfg)
go extproc.StartExternalProcessingServer(cfg, apiStore)

// Wait forever
select {}
Expand Down
19 changes: 15 additions & 4 deletions gateway/enforcer/internal/datastore/api_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,19 @@ import (
"sync"

api "github.com/wso2/apk/adapter/pkg/discovery/api/wso2/discovery/api"
"github.com/wso2/apk/gateway/enforcer/internal/util"
)

// APIStore is a thread-safe store for APIs.
//
// APIs are indexed by the key that util.PrepareAPIKey builds from an API's
// vhost, base path and version. mu guards every access to apis.
type APIStore struct {
	apis map[string]*api.Api
	mu   sync.RWMutex
}

// NewAPIStore creates a new instance of APIStore.
func NewAPIStore() *APIStore {
return &APIStore{
apis: make([]*api.Api, 0),
// apis: make(map[string]*api.Api, 0),
}
}

Expand All @@ -41,13 +42,23 @@ func NewAPIStore() *APIStore {
func (s *APIStore) AddAPIs(apis []*api.Api) {
	// The store is replaced, not merged: any API that is absent from this
	// call is no longer retrievable afterwards.
	index := make(map[string]*api.Api, len(apis))
	for _, a := range apis {
		// A nil entry has no vhost/base path/version to build a key from;
		// skip it instead of panicking on the field access below.
		if a == nil {
			continue
		}
		index[util.PrepareAPIKey(a.Vhost, a.BasePath, a.Version)] = a
	}

	// The index is fully built before the lock is taken, so readers are only
	// held up for the swap itself.
	s.mu.Lock()
	defer s.mu.Unlock()
	s.apis = index
}

// GetAPIs retrieves the list of APIs from the store.
// This method is thread-safe.
func (s *APIStore) GetAPIs() []*api.Api {
func (s *APIStore) GetAPIs() map[string]*api.Api {
s.mu.RLock()
defer s.mu.RUnlock()
return s.apis
}

// GetMatchedAPI looks up the API stored under apiKey.
// It returns nil when no API is stored for that key, so callers must check
// the result before dereferencing it.
func (s *APIStore) GetMatchedAPI(apiKey string) *api.Api {
	s.mu.RLock()
	matched := s.apis[apiKey]
	s.mu.RUnlock()
	return matched
}
125 changes: 93 additions & 32 deletions gateway/enforcer/internal/extproc/ext_proc.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ import (

corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
envoy_service_proc_v3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
api "github.com/wso2/apk/adapter/pkg/discovery/api/wso2/discovery/api"
"github.com/wso2/apk/gateway/enforcer/internal/config"
"github.com/wso2/apk/gateway/enforcer/internal/datastore"
"github.com/wso2/apk/gateway/enforcer/internal/logging"
"github.com/wso2/apk/gateway/enforcer/internal/ratelimit"
"github.com/wso2/apk/gateway/enforcer/internal/util"

"net"
Expand All @@ -35,6 +38,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/prototext"
structpb "google.golang.org/protobuf/types/known/structpb"
)

Expand All @@ -43,6 +47,9 @@ import (
type ExternalProcessingServer struct {
log logging.Logger
externalProcessingEnvoyAttributes *ExternalProcessingEnvoyAttributes
matchedAPI *api.Api
apiStore *datastore.APIStore
ratelimitHelper *ratelimit.AIRatelimitHelper
}

// ExternalProcessingEnvoyAttributes represents the attributes extracted from the external processing request.
Expand Down Expand Up @@ -71,7 +78,7 @@ const (
)

// keyValuePattern matches one `key: "..." value { string_value: "..." }`
// entry and captures the key (group 1) and the string value (group 2).
// Keys that contain a dot are not matched.
//
// The key class excludes the double quote as well as the dot: the capture is
// greedy, and without that exclusion it could run past the key's closing
// quote and swallow later entries into a single bogus key.
const keyValuePattern = `key: "([^".]*)" value { string_value: "(.*?)" }`

// Pre-compile the regular expression
var re = regexp.MustCompile(keyValuePattern)
Expand All @@ -85,7 +92,7 @@ var re = regexp.MustCompile(keyValuePattern)
// public and private keys, and a logger instance.
//
// If there is an error during the creation of the gRPC server, the function will panic.
func StartExternalProcessingServer(cfg *config.Server) {
func StartExternalProcessingServer(cfg *config.Server, apiStore *datastore.APIStore) {
kaParams := keepalive.ServerParameters{
Time: time.Duration(cfg.ExternalProcessingKeepAliveTime) * time.Hour, // Ping the client if it is idle for 2 hours
Timeout: 20 * time.Second,
Expand All @@ -99,7 +106,8 @@ func StartExternalProcessingServer(cfg *config.Server) {
panic(err)
}

envoy_service_proc_v3.RegisterExternalProcessorServer(server, &ExternalProcessingServer{cfg.Logger, nil})
ratelimitHelper := ratelimit.NewAIRatelimitHelper(cfg)
envoy_service_proc_v3.RegisterExternalProcessorServer(server, &ExternalProcessingServer{cfg.Logger, nil, nil, apiStore, ratelimitHelper})
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", cfg.ExternalProcessingPort))
if err != nil {
cfg.Logger.Error(err, fmt.Sprintf("Failed to listen on port: %s", cfg.ExternalProcessingPort))
Expand Down Expand Up @@ -152,6 +160,7 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
s.log.Error(err, "failed to extract context attributes")
}
s.externalProcessingEnvoyAttributes = attributes
s.matchedAPI = s.apiStore.GetMatchedAPI(util.PrepareAPIKey(s.externalProcessingEnvoyAttributes.VHost, s.externalProcessingEnvoyAttributes.BasePath, s.externalProcessingEnvoyAttributes.APIVersion))

rhq := &envoy_service_proc_v3.HeadersResponse{
Response: &envoy_service_proc_v3.CommonResponse{
Expand Down Expand Up @@ -185,10 +194,25 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
ResponseHeaders: rhq,
},
}
// s.log.Info(fmt.Sprintf("Matched api: %s", s.matchedAPI))
if s.matchedAPI.Aiprovider != nil &&
s.matchedAPI.Aiprovider.CompletionToken != nil &&
s.externalProcessingEnvoyAttributes.EnableBackendBasedAIRatelimit == "true" &&
s.matchedAPI.Aiprovider.CompletionToken.In == "Header" {
s.log.Info("Backend based AI rate limit enabled using headers")
tokenCount, err := ratelimit.ExtractTokenCountFromExternalProcessingResponseHeaders(req.GetResponseHeaders().GetHeaders().GetHeaders(), s.matchedAPI.Aiprovider.PromptTokens.Value, s.matchedAPI.Aiprovider.CompletionToken.Value, s.matchedAPI.Aiprovider.CompletionToken.Value, s.matchedAPI.Aiprovider.Model.Value)
if err != nil {
s.log.Error(err, "failed to extract token count from response headers")
} else {
s.ratelimitHelper.DoAIRatelimit(tokenCount, true, false, s.externalProcessingEnvoyAttributes.BackendBasedAIRatelimitDescriptorValue)
}
}

break
case *envoy_service_proc_v3.ProcessingRequest_ResponseBody:
// httpBody := req.GetResponseBody()
// s.log.Info(fmt.Sprint("response body \n"))
s.log.Info(fmt.Sprintf("attribute %+v\n", s.externalProcessingEnvoyAttributes))

rbq := &envoy_service_proc_v3.BodyResponse{
Response: &envoy_service_proc_v3.CommonResponse{},
}
Expand All @@ -198,6 +222,19 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
},
}

if s.matchedAPI.Aiprovider != nil &&
s.matchedAPI.Aiprovider.CompletionToken != nil &&
s.externalProcessingEnvoyAttributes.EnableBackendBasedAIRatelimit == "true" &&
s.matchedAPI.Aiprovider.CompletionToken.In == "Body" {
s.log.Info("Backend based AI rate limit enabled using body")
tokenCount, err := ratelimit.ExtractTokenCountFromExternalProcessingResponseBody(req.GetResponseBody().Body, s.matchedAPI.Aiprovider.PromptTokens.Value, s.matchedAPI.Aiprovider.CompletionToken.Value, s.matchedAPI.Aiprovider.CompletionToken.Value, s.matchedAPI.Aiprovider.Model.Value)
if err != nil {
s.log.Error(err, "failed to extract token count from response body")
} else {
s.ratelimitHelper.DoAIRatelimit(tokenCount, true, false, s.externalProcessingEnvoyAttributes.BackendBasedAIRatelimitDescriptorValue)
}
}

case *envoy_service_proc_v3.ProcessingRequest_RequestBody:
// httpBody := req.GetRequestBody()
// s.log.Info(fmt.Sprint("request body"))
Expand All @@ -218,6 +255,7 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
}
}

// extractExternalProcessingAttributes extracts the external processing attributes from the given data.
func extractExternalProcessingAttributes(data map[string]*structpb.Struct) (*ExternalProcessingEnvoyAttributes, error) {

// Get the fields from the map
Expand All @@ -234,39 +272,62 @@ func extractExternalProcessingAttributes(data map[string]*structpb.Struct) (*Ext
if field, ok := fields["xds.route_metadata"]; ok {

filterMetadata := field.GetStringValue()
var structData corev3.Metadata
err := prototext.Unmarshal([]byte(filterMetadata), &structData)
if err != nil {
return nil, fmt.Errorf("failed to parse Protobuf text: %v", err)
}

matches := re.FindAllStringSubmatch(filterMetadata, -1)
// Extract values for predefined keys
extractedValues := make(map[string]string)

// Iterate over the matches and assign values to the struct
for _, match := range matches {
key, value := match[1], match[2]
keysToExtract := []string{
pathAttribute,
vHostAttribute,
basePathAttribute,
methodAttribute,
apiVersionAttribute,
apiNameAttribute,
clusterNameAttribute,
enableBackendBasedAIRatelimitAttribute,
backendBasedAIRatelimitDescriptorValueAttribute,
}

switch key {
case enableBackendBasedAIRatelimitAttribute:
attributes.EnableBackendBasedAIRatelimit = value
case backendBasedAIRatelimitDescriptorValueAttribute:
attributes.BackendBasedAIRatelimitDescriptorValue = value
case pathAttribute:
attributes.Path = value
case vHostAttribute:
attributes.VHost = value
case basePathAttribute:
attributes.BasePath = value
case methodAttribute:
attributes.Method = value
case apiNameAttribute:
attributes.APIName = value
case apiVersionAttribute:
attributes.APIVersion = value
case clusterNameAttribute:
attributes.ClusterName = value
default:
for _, key := range keysToExtract {
if field, exists := structData.FilterMetadata["envoy.filters.http.ext_proc"]; exists {
extractedValues[key] = field.Fields[key].GetStringValue()
// case condition to populate ExternalProcessingEnvoyAttributes
switch key {
case pathAttribute:
attributes.Path = extractedValues[key]
case vHostAttribute:
attributes.VHost = extractedValues[key]
case basePathAttribute:
attributes.BasePath = extractedValues[key]
case methodAttribute:
attributes.Method = extractedValues[key]
case apiVersionAttribute:
attributes.APIVersion = extractedValues[key]
case apiNameAttribute:
attributes.APIName = extractedValues[key]
case clusterNameAttribute:
attributes.ClusterName = extractedValues[key]
case enableBackendBasedAIRatelimitAttribute:
attributes.EnableBackendBasedAIRatelimit = extractedValues[key]
case backendBasedAIRatelimitDescriptorValueAttribute:
attributes.BackendBasedAIRatelimitDescriptorValue = extractedValues[key]
}
}
}
} else {
fmt.Println("Key 'xds.route_metadata' not found in fields")

// Print extracted values
for key, value := range extractedValues {
fmt.Printf("%s: %s\n", key, value)
}
// Return the populated struct
return attributes, nil
}

// Return the populated struct
return attributes, nil
// Key not found
return nil, fmt.Errorf("key xds.route_metadata not found")
}
Loading

0 comments on commit 1fd214e

Please sign in to comment.