From 2604f9244e174d01cda9d9fbd67a9cd41afd67d4 Mon Sep 17 00:00:00 2001 From: thedae Date: Mon, 13 Jan 2025 17:16:10 +0100 Subject: [PATCH] Global param propagation for sequential merger Signed-off-by: thedae --- proxy/merging.go | 210 ++++++++++++++++++++++++++++-------------- proxy/merging_test.go | 27 +++++- 2 files changed, 166 insertions(+), 71 deletions(-) diff --git a/proxy/merging.go b/proxy/merging.go index 5d4aaf4c1..a130db9d7 100644 --- a/proxy/merging.go +++ b/proxy/merging.go @@ -5,6 +5,7 @@ package proxy import ( "context" "fmt" + "io" "net/http" "regexp" "strconv" @@ -27,7 +28,7 @@ func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.Endpoi } serviceTimeout := time.Duration(85*endpointConfig.Timeout.Nanoseconds()/100) * time.Nanosecond combiner := getResponseCombiner(endpointConfig.ExtraConfig) - isSequential := shouldRunSequentialMerger(endpointConfig) + isSequential, propagatedParams := sequentialMergerConfig(endpointConfig) logger.Debug( fmt.Sprintf( @@ -57,24 +58,86 @@ func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.Endpoi return parallelMerge(reqClone, serviceTimeout, combiner, next...) } - patterns := make([]string, len(endpointConfig.Backend)) + sequentialReplacements := make([][]sequentialBackendReplacement, totalBackends) + + var rePropagatedParams = regexp.MustCompile(`[Rr]esp(\d+)_?([\w-\.]+)?`) + var reUrlPatterns = regexp.MustCompile(`\{\{\.Resp(\d+)_([\w-\.]+)\}\}`) + destKeyGenerator := func(i string, t string) string { + key := "Resp" + i + if t != "" { + key += "_" + t + } + return key + } + for i, b := range endpointConfig.Backend { - patterns[i] = b.URLPattern + for _, match := range reUrlPatterns.FindAllStringSubmatch(b.URLPattern, -1) { + if len(match) > 1 { + backendIndex, err := strconv.Atoi(match[1]) + if err != nil { + continue + } + + sequentialReplacements[i] = append(sequentialReplacements[i], sequentialBackendReplacement{ + backendIndex: backendIndex, + destination: destKeyGenerator(match[1], match[2]), + source: strings.Split(match[2], "."), + fullResponse: len(match[2]) == 0, + }) + } + } + + if i > 0 { + for _, p := range propagatedParams { + for _, match := range rePropagatedParams.FindAllStringSubmatch(p, -1) { + if len(match) > 1 { + backendIndex, err := strconv.Atoi(match[1]) + if err != nil || backendIndex >= totalBackends { + continue + } + + sequentialReplacements[i] = append(sequentialReplacements[i], sequentialBackendReplacement{ + backendIndex: backendIndex, + destination: destKeyGenerator(match[1], match[2]), + source: strings.Split(match[2], "."), + fullResponse: len(match[2]) == 0, + }) + } + } + } + } } - return sequentialMerge(reqClone, patterns, serviceTimeout, combiner, next...) + + return sequentialMerge(reqClone, sequentialReplacements, serviceTimeout, combiner, next...) } } -func shouldRunSequentialMerger(cfg *config.EndpointConfig) bool { +type sequentialBackendReplacement struct { + backendIndex int + destination string + source []string + fullResponse bool +} + +func sequentialMergerConfig(cfg *config.EndpointConfig) (bool, []string) { + enabled := false + propagatedParams := []string{} if v, ok := cfg.ExtraConfig[Namespace]; ok { if e, ok := v.(map[string]interface{}); ok { if v, ok := e[isSequentialKey]; ok { c, ok := v.(bool) - return ok && c + enabled = ok && c + } + if v, ok := e[sequentialPropagateKey]; ok { + if a, ok := v.([]interface{}); ok { + for _, p := range a { + propagatedParams = append(propagatedParams, p.(string)) + } + } } } } - return false + return enabled, propagatedParams } func hasUnsafeBackends(cfg *config.EndpointConfig) bool { @@ -118,75 +181,92 @@ func parallelMerge(reqCloner func(*Request) *Request, timeout time.Duration, rc } } -var reMergeKey = regexp.MustCompile(`\{\{\.Resp(\d+)_([\w-\.]+)\}\}`) - -func sequentialMerge(reqCloner func(*Request) *Request, patterns []string, timeout time.Duration, rc ResponseCombiner, next ...Proxy) Proxy { +func sequentialMerge(reqCloner func(*Request) *Request, sequentialReplacements [][]sequentialBackendReplacement, timeout time.Duration, rc ResponseCombiner, next ...Proxy) Proxy { return func(ctx context.Context, request *Request) (*Response, error) { localCtx, cancel := context.WithTimeout(ctx, timeout) parts := make([]*Response, len(next)) out := make(chan *Response, 1) errCh := make(chan error, 1) + sequentialMergeRegistry := map[string]string{} acc := newIncrementalMergeAccumulator(len(next), rc) TxLoop: for i, n := range next { if i > 0 { - for _, match := range reMergeKey.FindAllStringSubmatch(patterns[i], -1) { - if len(match) > 1 { - rNum, err := strconv.Atoi(match[1]) - if err != nil || rNum >= i || parts[rNum] == nil { - continue - } - key := "Resp" + match[1] + "_" + match[2] - - var v interface{} - var ok bool - - data := parts[rNum].Data - keys := strings.Split(match[2], ".") - if len(keys) > 1 { - for _, k := range keys[:len(keys)-1] { - v, ok = data[k] - if !ok { - break - } - clean, ok := v.(map[string]interface{}) - if !ok { - break - } - data = clean + for _, r := range sequentialReplacements[i] { + if r.backendIndex >= i || parts[r.backendIndex] == nil { + continue + } + + var v interface{} + var ok bool + + data := parts[r.backendIndex].Data + if len(r.source) > 1 { + for _, k := range r.source[:len(r.source)-1] { + v, ok = data[k] + if !ok { + break } + clean, ok := v.(map[string]interface{}) + if !ok { + break + } + data = clean } + } - v, ok = data[keys[len(keys)-1]] - if !ok { + if found := sequentialMergeRegistry[r.destination]; found != "" { + request.Params[r.destination] = found + continue + } + + if r.fullResponse { + if parts[r.backendIndex].Io == nil { continue } - switch clean := v.(type) { - case []interface{}: - if len(clean) == 0 { - request.Params[key] = "" - continue - } - var b strings.Builder - for i := 0; i < len(clean)-1; i++ { - fmt.Fprintf(&b, "%v,", clean[i]) - } - fmt.Fprintf(&b, "%v", clean[len(clean)-1]) - request.Params[key] = b.String() - case string: - request.Params[key] = clean - case int: - request.Params[key] = strconv.Itoa(clean) - case float64: - request.Params[key] = strconv.FormatFloat(clean, 'E', -1, 32) - case bool: - request.Params[key] = strconv.FormatBool(clean) - default: - request.Params[key] = fmt.Sprintf("%v", v) + buf, err := io.ReadAll(parts[r.backendIndex].Io) + + if err == nil { + request.Params[r.destination] = string(buf) + sequentialMergeRegistry[r.destination] = string(buf) } + continue } + + v, ok = data[r.source[len(r.source)-1]] + if !ok { + continue + } + + var param string + + switch clean := v.(type) { + case []interface{}: + if len(clean) == 0 { + request.Params[r.destination] = "" + break + } + var b strings.Builder + for i := 0; i < len(clean)-1; i++ { + fmt.Fprintf(&b, "%v,", clean[i]) + } + fmt.Fprintf(&b, "%v", clean[len(clean)-1]) + param = b.String() + case string: + param = clean + case int: + param = strconv.Itoa(clean) + case float64: + param = strconv.FormatFloat(clean, 'E', -1, 32) + case bool: + param = strconv.FormatBool(clean) + default: + param = fmt.Sprintf("%v", v) + } + request.Params[r.destination] = param + sequentialMergeRegistry[r.destination] = param } } @@ -284,22 +364,18 @@ func requestPart(ctx context.Context, next Proxy, request *Request, out chan<- * } func sequentialRequestPart(ctx context.Context, next Proxy, request *Request, out chan<- *Response, failed chan<- error) { - localCtx, cancel := context.WithCancel(ctx) - copyRequest := CloneRequest(request) - in, err := next(localCtx, request) + in, err := next(ctx, request) *request = *copyRequest if err != nil { failed <- err - cancel() return } if in == nil { failed <- errNullResult - cancel() return } select { @@ -307,7 +383,6 @@ func sequentialRequestPart(ctx context.Context, next Proxy, request *Request, ou case <-ctx.Done(): failed <- ctx.Err() } - cancel() } func newMergeError(errs []error) error { @@ -342,9 +417,10 @@ func RegisterResponseCombiner(name string, f ResponseCombiner) { } const ( - mergeKey = "combiner" - isSequentialKey = "sequential" - defaultCombinerName = "default" + mergeKey = "combiner" + isSequentialKey = "sequential" + sequentialPropagateKey = "sequential_propagated_params" + defaultCombinerName = "default" ) var responseCombiners = initResponseCombiners() diff --git a/proxy/merging_test.go b/proxy/merging_test.go index 956465a7b..551b4e985 100644 --- a/proxy/merging_test.go +++ b/proxy/merging_test.go @@ -98,13 +98,16 @@ func TestNewMergeDataMiddleware_sequential(t *testing.T) { {URLPattern: "/"}, {URLPattern: "/aaa/{{.Resp0_array}}"}, {URLPattern: "/aaa/{{.Resp0_int}}/{{.Resp0_string}}/{{.Resp0_bool}}/{{.Resp0_float}}/{{.Resp0_struct.foo}}"}, - {URLPattern: "/aaa/{{.Resp0_int}}/{{.Resp0_string}}/{{.Resp0_bool}}/{{.Resp0_float}}/{{.Resp0_struct.foo}}?x={{.Resp1_tupu}}"}, + {URLPattern: "/aaa/{{.Resp0_int}}/{{.Resp0_string}}/{{.Resp0_bool}}/{{.Resp0_float}}/{{.Resp0_struct.foo}}?x={{.Resp1_tupu}}", Encoding: "noop"}, {URLPattern: "/aaa/{{.Resp0_struct.foo}}/{{.Resp0_struct.struct.foo}}/{{.Resp0_struct.struct.struct.foo}}"}, + {URLPattern: "/zzz", Encoding: "noop"}, + {URLPattern: "/hit-me"}, }, Timeout: time.Duration(timeout) * time.Millisecond, ExtraConfig: config.ExtraConfig{ Namespace: map[string]interface{}{ - isSequentialKey: true, + isSequentialKey: true, + sequentialPropagateKey: []interface{}{"resp0_propagated", "resp5"}, }, }, } @@ -144,11 +147,13 @@ func TestNewMergeDataMiddleware_sequential(t *testing.T) { }, }, }, - "array": []interface{}{"1", "2"}, + "array": []interface{}{"1", "2"}, + "propagated": "everywhere", }, IsComplete: true}), func(ctx context.Context, r *Request) (*Response, error) { checkBody(t, r) checkRequestParam(t, r, "Resp0_array", "1,2") + checkRequestParam(t, r, "Resp0_propagated", "everywhere") return &Response{Data: map[string]interface{}{"tupu": "foo"}, IsComplete: true}, nil }, func(ctx context.Context, r *Request) (*Response, error) { @@ -158,6 +163,7 @@ func TestNewMergeDataMiddleware_sequential(t *testing.T) { checkRequestParam(t, r, "Resp0_float", "3.14E+00") checkRequestParam(t, r, "Resp0_bool", "true") checkRequestParam(t, r, "Resp0_struct.foo", "bar") + checkRequestParam(t, r, "Resp0_propagated", "everywhere") return &Response{Data: map[string]interface{}{"tupu": "foo"}, IsComplete: true}, nil }, func(ctx context.Context, r *Request) (*Response, error) { @@ -168,6 +174,7 @@ func TestNewMergeDataMiddleware_sequential(t *testing.T) { checkRequestParam(t, r, "Resp0_bool", "true") checkRequestParam(t, r, "Resp0_struct.foo", "bar") checkRequestParam(t, r, "Resp1_tupu", "foo") + checkRequestParam(t, r, "Resp0_propagated", "everywhere") return &Response{Data: map[string]interface{}{"aaaa": []int{1, 2, 3}}, IsComplete: true}, nil }, func(ctx context.Context, r *Request) (*Response, error) { @@ -175,8 +182,20 @@ func TestNewMergeDataMiddleware_sequential(t *testing.T) { checkRequestParam(t, r, "Resp0_struct.foo", "bar") checkRequestParam(t, r, "Resp0_struct.struct.foo", "bar") checkRequestParam(t, r, "Resp0_struct.struct.struct.foo", "bar") + checkRequestParam(t, r, "Resp0_propagated", "everywhere") return &Response{Data: map[string]interface{}{"bbbb": []bool{true, false}}, IsComplete: true}, nil }, + func(ctx context.Context, r *Request) (*Response, error) { + checkBody(t, r) + checkRequestParam(t, r, "Resp0_propagated", "everywhere") + return &Response{Data: map[string]interface{}{}, Io: io.NopCloser(strings.NewReader("hello")), IsComplete: true}, nil + }, + func(ctx context.Context, r *Request) (*Response, error) { + checkBody(t, r) + checkRequestParam(t, r, "Resp0_propagated", "everywhere") + checkRequestParam(t, r, "Resp5", "hello") + return &Response{Data: map[string]interface{}{}, IsComplete: true}, nil + }, ) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) out, err := p(context.Background(), &Request{ @@ -194,7 +213,7 @@ func TestNewMergeDataMiddleware_sequential(t *testing.T) { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: - if len(out.Data) != 9 { + if len(out.Data) != 10 { t.Errorf("We weren't expecting a partial response but we got %v!\n", out) } if !out.IsComplete {