From a122f5584656aa2fa37f53688aef5c1047b0b06a Mon Sep 17 00:00:00 2001 From: gak Date: Wed, 10 Jul 2024 06:50:26 +1000 Subject: [PATCH] fix: cors allow headers option and env (#2016) Fixes #1986 - Added `--allow-headers` option eg `--allow-headers x-forwarded-capabilities` - Works with env `FTL_CONTROLLER_ALLOW_HEADERS` --- backend/controller/controller.go | 14 ++- .../ingress/ingress_integration_test.go | 89 +++++++++++++++---- frontend/local.go | 2 +- frontend/release.go | 2 +- .../encoding/encoding_integration_test.go | 2 +- integration/actions.go | 7 +- internal/cors/cors.go | 7 +- 7 files changed, 98 insertions(+), 25 deletions(-) diff --git a/backend/controller/controller.go b/backend/controller/controller.go index 8ddd017a2c..db3d5bf5a1 100644 --- a/backend/controller/controller.go +++ b/backend/controller/controller.go @@ -66,12 +66,20 @@ import ( // CommonConfig between the production controller and development server. type CommonConfig struct { AllowOrigins []*url.URL `help:"Allow CORS requests to ingress endpoints from these origins." env:"FTL_CONTROLLER_ALLOW_ORIGIN"` + AllowHeaders []string `help:"Allow these headers in CORS requests. (Requires AllowOrigins)" env:"FTL_CONTROLLER_ALLOW_HEADERS"` NoConsole bool `help:"Disable the console."` IdleRunners int `help:"Number of idle runners to keep around (not supported in production)." default:"3"` WaitFor []string `help:"Wait for these modules to be deployed before becoming ready." placeholder:"MODULE"` CronJobTimeout time.Duration `help:"Timeout for cron jobs." default:"5m"` } +func (c *CommonConfig) Validate() error { + if len(c.AllowHeaders) > 0 && len(c.AllowOrigins) == 0 { + return fmt.Errorf("AllowOrigins must be set when AllowHeaders is used") + } + return nil +} + type Config struct { Bind *url.URL `help:"Socket to bind to." default:"http://localhost:8892" env:"FTL_CONTROLLER_BIND"` IngressBind *url.URL `help:"Socket to bind to for ingress." default:"http://localhost:8891" env:"FTL_CONTROLLER_INGRESS_BIND"` @@ -139,7 +147,11 @@ func Start(ctx context.Context, config Config, runnerScaling scaling.RunnerScali ingressHandler := http.Handler(svc) if len(config.AllowOrigins) > 0 { - ingressHandler = cors.Middleware(slices.Map(config.AllowOrigins, func(u *url.URL) string { return u.String() }), ingressHandler) + ingressHandler = cors.Middleware( + slices.Map(config.AllowOrigins, func(u *url.URL) string { return u.String() }), + config.AllowHeaders, + ingressHandler, + ) } g, ctx := errgroup.WithContext(ctx) diff --git a/backend/controller/ingress/ingress_integration_test.go b/backend/controller/ingress/ingress_integration_test.go index 2db1cd0d47..8ca60f4339 100644 --- a/backend/controller/ingress/ingress_integration_test.go +++ b/backend/controller/ingress/ingress_integration_test.go @@ -4,6 +4,7 @@ package ingress_test import ( "net/http" + "os" "testing" "github.com/alecthomas/assert/v2" @@ -16,7 +17,7 @@ func TestHttpIngress(t *testing.T) { in.Run(t, "", in.CopyModule("httpingress"), in.Deploy("httpingress"), - in.HttpCall(http.MethodGet, "/users/123/posts/456", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodGet, "/users/123/posts/456", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) assert.Equal(t, []string{"Header from FTL"}, resp.Headers["Get"]) assert.Equal(t, []string{"application/json; charset=utf-8"}, resp.Headers["Content-Type"]) @@ -31,7 +32,7 @@ func TestHttpIngress(t *testing.T) { assert.True(t, ok, "good_stuff is not a string: %s", repr.String(resp.JsonBody)) assert.Equal(t, "This is good stuff", goodStuff) }), - in.HttpCall(http.MethodPost, "/users", in.JsonData(t, in.Obj{"userId": 123, "postId": 345}), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodPost, "/users", nil, in.JsonData(t, in.Obj{"userId": 123, "postId": 345}), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 201, resp.Status) assert.Equal(t, []string{"Header from FTL"}, resp.Headers["Post"]) success, ok := resp.JsonBody["success"].(bool) @@ -39,88 +40,140 @@ func TestHttpIngress(t *testing.T) { assert.True(t, success) }), // contains aliased field - in.HttpCall(http.MethodPost, "/users", in.JsonData(t, in.Obj{"user_id": 123, "postId": 345}), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodPost, "/users", nil, in.JsonData(t, in.Obj{"user_id": 123, "postId": 345}), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 201, resp.Status) }), - in.HttpCall(http.MethodPut, "/users/123", in.JsonData(t, in.Obj{"postId": "346"}), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodPut, "/users/123", nil, in.JsonData(t, in.Obj{"postId": "346"}), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) assert.Equal(t, []string{"Header from FTL"}, resp.Headers["Put"]) assert.Equal(t, map[string]any{}, resp.JsonBody) }), - in.HttpCall(http.MethodDelete, "/users/123", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodDelete, "/users/123", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) assert.Equal(t, []string{"Header from FTL"}, resp.Headers["Delete"]) assert.Equal(t, map[string]any{}, resp.JsonBody) }), - in.HttpCall(http.MethodGet, "/queryparams?foo=bar", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodGet, "/queryparams?foo=bar", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) assert.Equal(t, "bar", string(resp.BodyBytes)) }), - in.HttpCall(http.MethodGet, "/queryparams", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodGet, "/queryparams", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) assert.Equal(t, "No value", string(resp.BodyBytes)) }), - in.HttpCall(http.MethodGet, "/html", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodGet, "/html", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) assert.Equal(t, []string{"text/html; charset=utf-8"}, resp.Headers["Content-Type"]) assert.Equal(t, "

HTML Page From FTL 🚀!

", string(resp.BodyBytes)) }), - in.HttpCall(http.MethodPost, "/bytes", []byte("Hello, World!"), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodPost, "/bytes", nil, []byte("Hello, World!"), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) assert.Equal(t, []string{"application/octet-stream"}, resp.Headers["Content-Type"]) assert.Equal(t, []byte("Hello, World!"), resp.BodyBytes) }), - in.HttpCall(http.MethodGet, "/empty", nil, func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodGet, "/empty", nil, nil, func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) assert.Equal(t, nil, resp.Headers["Content-Type"]) assert.Equal(t, nil, resp.BodyBytes) }), - in.HttpCall(http.MethodGet, "/string", []byte("Hello, World!"), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodGet, "/string", nil, []byte("Hello, World!"), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) assert.Equal(t, []string{"text/plain; charset=utf-8"}, resp.Headers["Content-Type"]) assert.Equal(t, []byte("Hello, World!"), resp.BodyBytes) }), - in.HttpCall(http.MethodGet, "/int", []byte("1234"), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodGet, "/int", nil, []byte("1234"), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) assert.Equal(t, []string{"text/plain; charset=utf-8"}, resp.Headers["Content-Type"]) assert.Equal(t, []byte("1234"), resp.BodyBytes) }), - in.HttpCall(http.MethodGet, "/float", []byte("1234.56789"), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodGet, "/float", nil, []byte("1234.56789"), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) assert.Equal(t, []string{"text/plain; charset=utf-8"}, resp.Headers["Content-Type"]) assert.Equal(t, []byte("1234.56789"), resp.BodyBytes) }), - in.HttpCall(http.MethodGet, "/bool", []byte("true"), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodGet, "/bool", nil, []byte("true"), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) assert.Equal(t, []string{"text/plain; charset=utf-8"}, resp.Headers["Content-Type"]) assert.Equal(t, []byte("true"), resp.BodyBytes) }), - in.HttpCall(http.MethodGet, "/error", nil, func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodGet, "/error", nil, nil, func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 500, resp.Status) assert.Equal(t, []string{"text/plain; charset=utf-8"}, resp.Headers["Content-Type"]) assert.Equal(t, []byte("Error from FTL"), resp.BodyBytes) }), - in.HttpCall(http.MethodGet, "/array/string", in.JsonData(t, []string{"hello", "world"}), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodGet, "/array/string", nil, in.JsonData(t, []string{"hello", "world"}), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) assert.Equal(t, []string{"application/json; charset=utf-8"}, resp.Headers["Content-Type"]) assert.Equal(t, in.JsonData(t, []string{"hello", "world"}), resp.BodyBytes) }), - in.HttpCall(http.MethodPost, "/array/data", in.JsonData(t, []in.Obj{{"item": "a"}, {"item": "b"}}), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodPost, "/array/data", nil, in.JsonData(t, []in.Obj{{"item": "a"}, {"item": "b"}}), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) assert.Equal(t, []string{"application/json; charset=utf-8"}, resp.Headers["Content-Type"]) assert.Equal(t, in.JsonData(t, []in.Obj{{"item": "a"}, {"item": "b"}}), resp.BodyBytes) }), - in.HttpCall(http.MethodGet, "/typeenum", in.JsonData(t, in.Obj{"name": "A", "value": "hello"}), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodGet, "/typeenum", nil, in.JsonData(t, in.Obj{"name": "A", "value": "hello"}), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) assert.Equal(t, []string{"application/json; charset=utf-8"}, resp.Headers["Content-Type"]) assert.Equal(t, in.JsonData(t, in.Obj{"name": "A", "value": "hello"}), resp.BodyBytes) }), + // CORS preflight request without CORS middleware enabled + in.HttpCall(http.MethodOptions, "/typeenum", map[string][]string{ + "Origin": {"http://localhost:8892"}, + "Access-Control-Request-Method": {"GET"}, + "Access-Control-Request-Headers": {"x-forwarded-capabilities"}, + }, nil, func(t testing.TB, resp *in.HTTPResponse) { + // should not return access control headers because we have not set up cors in this controller + assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Origin"]) + assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Methods"]) + assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Headers"]) + }), + ) +} + +// Run with CORS enabled via FTL_CONTROLLER_ALLOW_ORIGIN and FTL_CONTROLLER_ALLOW_HEADERS +// This test is similar to TestHttpIngress above with the addition of CORS enabled in the controller. +func TestHttpIngressWithCors(t *testing.T) { + os.Setenv("FTL_CONTROLLER_ALLOW_ORIGIN", "http://localhost:8892") + os.Setenv("FTL_CONTROLLER_ALLOW_HEADERS", "x-forwarded-capabilities") + in.Run(t, "", + in.CopyModule("httpingress"), + in.Deploy("httpingress"), + // A correct CORS preflight request + in.HttpCall(http.MethodOptions, "/typeenum", map[string][]string{ + "Origin": {"http://localhost:8892"}, + "Access-Control-Request-Method": {"GET"}, + "Access-Control-Request-Headers": {"x-forwarded-capabilities"}, + }, nil, func(t testing.TB, resp *in.HTTPResponse) { + assert.Equal(t, []string{"http://localhost:8892"}, resp.Headers["Access-Control-Allow-Origin"]) + assert.Equal(t, []string{"GET"}, resp.Headers["Access-Control-Allow-Methods"]) + assert.Equal(t, []string{"x-forwarded-capabilities"}, resp.Headers["Access-Control-Allow-Headers"]) + }), + // Not allowed headers + in.HttpCall(http.MethodOptions, "/typeenum", map[string][]string{ + "Origin": {"http://localhost:8892"}, + "Access-Control-Request-Method": {"GET"}, + "Access-Control-Request-Headers": {"moo"}, + }, nil, func(t testing.TB, resp *in.HTTPResponse) { + assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Origin"]) + assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Methods"]) + assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Headers"]) + }), + // Not allowed origin + in.HttpCall(http.MethodOptions, "/typeenum", map[string][]string{ + "Origin": {"http://localhost:4444"}, + "Access-Control-Request-Method": {"GET"}, + "Access-Control-Request-Headers": {"x-forwarded-capabilities"}, + }, nil, func(t testing.TB, resp *in.HTTPResponse) { + assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Origin"]) + assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Methods"]) + assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Headers"]) + }), ) } diff --git a/frontend/local.go b/frontend/local.go index 6ea258410a..f7961d662a 100644 --- a/frontend/local.go +++ b/frontend/local.go @@ -43,5 +43,5 @@ func Server(ctx context.Context, timestamp time.Time, publicURL *url.URL, allowO return proxy, nil } - return cors.Middleware([]string{allowOrigin.String()}, proxy), nil + return cors.Middleware([]string{allowOrigin.String()}, nil, proxy), nil } diff --git a/frontend/release.go b/frontend/release.go index d9115df8f8..913119ab0a 100644 --- a/frontend/release.go +++ b/frontend/release.go @@ -47,7 +47,7 @@ func Server(ctx context.Context, timestamp time.Time, publicURL *url.URL, allowO http.ServeContent(w, r, filePath, timestamp, f.(io.ReadSeeker)) }) if allowOrigin != nil { - handler = cors.Middleware([]string{allowOrigin.String()}, handler) + handler = cors.Middleware([]string{allowOrigin.String()}, nil, handler) } return handler, nil } diff --git a/go-runtime/encoding/encoding_integration_test.go b/go-runtime/encoding/encoding_integration_test.go index b3dea04d91..d8157ba364 100644 --- a/go-runtime/encoding/encoding_integration_test.go +++ b/go-runtime/encoding/encoding_integration_test.go @@ -14,7 +14,7 @@ func TestHttpEncodeOmitempty(t *testing.T) { in.Run(t, "", in.CopyModule("omitempty"), in.Deploy("omitempty"), - in.HttpCall(http.MethodGet, "/get", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) { + in.HttpCall(http.MethodGet, "/get", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) { assert.Equal(t, 200, resp.Status) _, ok := resp.JsonBody["mustset"] assert.True(t, ok) diff --git a/integration/actions.go b/integration/actions.go index 119593e04a..693a5760ca 100644 --- a/integration/actions.go +++ b/integration/actions.go @@ -403,7 +403,7 @@ func JsonData(t testing.TB, body interface{}) []byte { } // HttpCall makes an HTTP call to the running FTL ingress endpoint. -func HttpCall(method string, path string, body []byte, onResponse func(t testing.TB, resp *HTTPResponse)) Action { +func HttpCall(method string, path string, headers map[string][]string, body []byte, onResponse func(t testing.TB, resp *HTTPResponse)) Action { return func(t testing.TB, ic TestContext) { Infof("HTTP %s %s", method, path) baseURL, err := url.Parse(fmt.Sprintf("http://localhost:8891")) @@ -415,6 +415,11 @@ func HttpCall(method string, path string, body []byte, onResponse func(t testing assert.NoError(t, err) r.Header.Add("Content-Type", "application/json") + for k, vs := range headers { + for _, v := range vs { + r.Header.Add(k, v) + } + } client := http.Client{} resp, err := client.Do(r) diff --git a/internal/cors/cors.go b/internal/cors/cors.go index 7bb0672f88..a3c7cd485f 100644 --- a/internal/cors/cors.go +++ b/internal/cors/cors.go @@ -6,7 +6,10 @@ import ( "github.com/rs/cors" ) -func Middleware(allowOrigins []string, next http.Handler) http.Handler { - c := cors.New(cors.Options{AllowedOrigins: allowOrigins}) +func Middleware(allowOrigins []string, allowHeaders []string, next http.Handler) http.Handler { + c := cors.New(cors.Options{ + AllowedOrigins: allowOrigins, + AllowedHeaders: allowHeaders, + }) return c.Handler(next) }