Skip to content

Commit

Permalink
feat: set request headers
Browse files Browse the repository at this point in the history
Signed-off-by: Utkarsh Saxena <[email protected]>
  • Loading branch information
utk-spartan committed May 27, 2024
1 parent d15beea commit 3070218
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 32 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package migration

import (
"database/sql"

"github.com/pressly/goose/v3"
)

func init() {
goose.AddMigration(Up20240525205304, Down20240525205304)
}

func Up20240525205304(tx *sql.Tx) error {
var err error

_, err = tx.Exec("ALTER TABLE `policies` ADD COLUMN `set_request_source` VARCHAR(255) DEFAULT '';")
if err != nil {
return err
}
return err
}

func Down20240525205304(tx *sql.Tx) error {
var err error

_, err = tx.Exec("ALTER TABLE `policies` DROP COLUMN `set_request_source`;")
if err != nil {
return err
}
return err
}
13 changes: 7 additions & 6 deletions internal/gatewayserver/models/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import "github.com/razorpay/trino-gateway/pkg/spine"
// policy model struct definition
type Policy struct {
spine.Model
RuleType string `json:"rule_type"`
RuleValue string `json:"rule_value"`
GroupId string `json:"group_id"`
FallbackGroupId *string `json:"fallback_group_id"`
IsEnabled *bool `json:"is_enabled" sql:"DEFAULT:true"`
IsAuthDelegated *bool `json:"is_auth_delegated" sql:"DEFAULT:false"`
RuleType string `json:"rule_type"`
RuleValue string `json:"rule_value"`
GroupId string `json:"group_id"`
FallbackGroupId *string `json:"fallback_group_id"`
IsEnabled *bool `json:"is_enabled" sql:"DEFAULT:true"`
IsAuthDelegated *bool `json:"is_auth_delegated" sql:"DEFAULT:false"`
SetRequestSource *string `json:"set_request_source"`
}

func (u *Policy) TableName() string {
Expand Down
50 changes: 37 additions & 13 deletions internal/gatewayserver/policyApi/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type ICore interface {

EvaluateGroupsForClient(ctx context.Context, c *EvaluateClientParams) ([]string, error)
EvaluateAuthDelegation(ctx context.Context, p int32) (bool, error)
EvaluateRequestSource(ctx context.Context, p int32) (string, error)
// EvaluatePolicy(ctx context.Context, group string) (string, error)
// FindPolicyForQuery(ctx context.Context, q string) (string, error)
}
Expand All @@ -36,23 +37,25 @@ func NewCore(policy repo.IPolicyRepo) *Core {

// CreateParams has attributes that are required for policy.Create()
type PolicyCreateParams struct {
ID string
RuleType string
RuleValue string
Group string
FallbackGroup string
IsEnabled bool
IsAuthDelegated bool
ID string
RuleType string
RuleValue string
Group string
FallbackGroup string
IsEnabled bool
IsAuthDelegated bool
SetRequestSource string
}

func (c *Core) CreateOrUpdatePolicy(ctx context.Context, params *PolicyCreateParams) error {
policy := models.Policy{
RuleType: params.RuleType,
RuleValue: params.RuleValue,
GroupId: params.Group,
FallbackGroupId: &params.FallbackGroup,
IsEnabled: &params.IsEnabled,
IsAuthDelegated: &params.IsAuthDelegated,
RuleType: params.RuleType,
RuleValue: params.RuleValue,
GroupId: params.Group,
FallbackGroupId: &params.FallbackGroup,
IsEnabled: &params.IsEnabled,
IsAuthDelegated: &params.IsAuthDelegated,
SetRequestSource: &params.SetRequestSource,
}
policy.ID = params.ID

Expand Down Expand Up @@ -240,6 +243,27 @@ func (c *Core) EvaluateAuthDelegation(ctx context.Context, port int32) (bool, er
return false, nil
}

func (c *Core) EvaluateRequestSource(ctx context.Context, port int32) (string, error) {
res, err := c.FindMany(
ctx,
&FindManyParams{
IsEnabled: true,
RuleType: "listening_port",
RuleValue: strconv.Itoa(int(port)),
})
if err != nil {
return "", err
}
provider.Logger(ctx).Debugw("Evaluate Request Source For Port", map[string]interface{}{
"listeningPort": port,
"matchingRules": res,
})
if len(res) > 0 {
return *res[0].SetRequestSource, nil
}
return "", nil
}

// Implementing "set" collection methods here, :)
func setIntersection(s1 map[string]struct{}, s2 map[string]struct{}) map[string]struct{} {
s_intersection := map[string]struct{}{}
Expand Down
49 changes: 36 additions & 13 deletions internal/gatewayserver/policyApi/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ func (s *Server) CreateOrUpdatePolicy(ctx context.Context, req *gatewayv1.Policy
})

createParams := PolicyCreateParams{
ID: req.GetId(),
RuleType: req.GetRule().GetType().Enum().String(),
RuleValue: req.GetRule().GetValue(),
Group: req.GetGroup(),
FallbackGroup: req.GetFallbackGroup(),
IsEnabled: req.GetIsEnabled(),
IsAuthDelegated: req.GetIsAuthDelegated(),
ID: req.GetId(),
RuleType: req.GetRule().GetType().Enum().String(),
RuleValue: req.GetRule().GetValue(),
Group: req.GetGroup(),
FallbackGroup: req.GetFallbackGroup(),
IsEnabled: req.GetIsEnabled(),
IsAuthDelegated: req.GetIsAuthDelegated(),
SetRequestSource: req.GetSetRequestSource(),
}

err := s.core.CreateOrUpdatePolicy(ctx, &createParams)
Expand Down Expand Up @@ -143,12 +144,13 @@ func toPolicyResponseProto(policy *models.Policy) (*gatewayv1.Policy, error) {
Value: policy.RuleValue,
}
response := gatewayv1.Policy{
Id: policy.ID,
Rule: &rule,
Group: policy.GroupId,
FallbackGroup: *policy.FallbackGroupId,
IsEnabled: *policy.IsEnabled,
IsAuthDelegated: *policy.IsAuthDelegated,
Id: policy.ID,
Rule: &rule,
Group: policy.GroupId,
FallbackGroup: *policy.FallbackGroupId,
IsEnabled: *policy.IsEnabled,
IsAuthDelegated: *policy.IsAuthDelegated,
SetRequestSource: *policy.SetRequestSource,
}

return &response, nil
Expand Down Expand Up @@ -198,3 +200,24 @@ func (s *Server) EvaluateAuthDelegationForClient(ctx context.Context, req *gatew
}
return &gatewayv1.EvaluateAuthDelegationResponse{IsAuthDelegated: result}, nil
}

func (s *Server) EvaluateRequestSourceForClient(ctx context.Context, req *gatewayv1.EvaluateRequestSourceRequest) (*gatewayv1.EvaluateRequestSourceResponse, error) {
provider.Logger(ctx).Debugw("EvaluateRequestSource", map[string]interface{}{
"request": req.String(),
})

if req.GetIncomingPort() == 0 {
err := errors.New("Invalid port defined in `incoming_port`.")
provider.Logger(ctx).WithError(err).Error(err.Error())
return &gatewayv1.EvaluateRequestSourceResponse{SetRequestSource: ""}, nil
}

result, err := s.core.EvaluateRequestSource(
ctx,
req.GetIncomingPort(),
)
if err != nil {
return nil, err
}
return &gatewayv1.EvaluateRequestSourceResponse{SetRequestSource: result}, nil
}
9 changes: 9 additions & 0 deletions internal/router/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,15 @@ func (r *RouterServer) prepareReqForRouting(ctx *context.Context, req *http.Requ
req.URL.Host = host
req.URL.Scheme = scheme
req.Host = host
sourceHeader, err := r.gatewayApiClient.Policy.EvaluateRequestSourceForClient(*ctx, &gatewayv1.EvaluateRequestSourceRequest{
IncomingPort: int32(r.port),
})
if err != nil {
return err
}
if s := sourceHeader.GetSetRequestSource(); s != "" {
req.Header.Set("X-Trino-Source", s)
}
// TODO - validate and refine parsing of X-Forwarded headers
req.Header.Set("X-Forwarded-Host", host)
provider.Logger(*ctx).Infow(
Expand Down
1 change: 1 addition & 0 deletions internal/router/trinoheaders/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const (
ConnectionProperties = "Connection-Properties"
TransactionId = "Transaction-Id"
Password = "Password"
Source = "Source"
)

var allowedPrefixes = [...]string{"Presto", "Trino"}
Expand Down
11 changes: 11 additions & 0 deletions rpc/gateway/service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ service PolicyApi {
};

rpc EvaluateAuthDelegationForClient(EvaluateAuthDelegationRequest) returns (EvaluateAuthDelegationResponse);

rpc EvaluateRequestSourceForClient(EvaluateRequestSourceRequest) returns (EvaluateRequestSourceResponse);
}

message Policy {
Expand All @@ -264,6 +266,7 @@ message Policy {
string fallback_group = 4;
bool is_enabled = 5;
bool is_auth_delegated = 6;
string set_request_source = 7;
}

message PolicyGetRequest {
Expand Down Expand Up @@ -309,6 +312,14 @@ message EvaluateAuthDelegationResponse {
bool is_auth_delegated = 1; // required
}

message EvaluateRequestSourceRequest {
int32 incoming_port = 1; // required
}

message EvaluateRequestSourceResponse {
string set_request_source = 1;
}

service QueryApi {
rpc CreateOrUpdateQuery (Query) returns (Empty);
rpc GetQuery (QueryGetRequest) returns (QueryGetResponse){
Expand Down

0 comments on commit 3070218

Please sign in to comment.