Skip to content

Commit

Permalink
fix: cors allow headers option and env (#2016)
Browse files Browse the repository at this point in the history
Fixes #1986 

- Added `--allow-headers` option eg `--allow-headers
x-forwarded-capabilities`
- Works with env `FTL_CONTROLLER_ALLOW_HEADERS`
  • Loading branch information
gak authored Jul 9, 2024
1 parent aafdc2e commit a122f55
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 25 deletions.
14 changes: 13 additions & 1 deletion backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,20 @@ import (
// CommonConfig between the production controller and development server.
type CommonConfig struct {
AllowOrigins []*url.URL `help:"Allow CORS requests to ingress endpoints from these origins." env:"FTL_CONTROLLER_ALLOW_ORIGIN"`
AllowHeaders []string `help:"Allow these headers in CORS requests. (Requires AllowOrigins)" env:"FTL_CONTROLLER_ALLOW_HEADERS"`
NoConsole bool `help:"Disable the console."`
IdleRunners int `help:"Number of idle runners to keep around (not supported in production)." default:"3"`
WaitFor []string `help:"Wait for these modules to be deployed before becoming ready." placeholder:"MODULE"`
CronJobTimeout time.Duration `help:"Timeout for cron jobs." default:"5m"`
}

func (c *CommonConfig) Validate() error {
if len(c.AllowHeaders) > 0 && len(c.AllowOrigins) == 0 {
return fmt.Errorf("AllowOrigins must be set when AllowHeaders is used")
}
return nil
}

type Config struct {
Bind *url.URL `help:"Socket to bind to." default:"http://localhost:8892" env:"FTL_CONTROLLER_BIND"`
IngressBind *url.URL `help:"Socket to bind to for ingress." default:"http://localhost:8891" env:"FTL_CONTROLLER_INGRESS_BIND"`
Expand Down Expand Up @@ -139,7 +147,11 @@ func Start(ctx context.Context, config Config, runnerScaling scaling.RunnerScali

ingressHandler := http.Handler(svc)
if len(config.AllowOrigins) > 0 {
ingressHandler = cors.Middleware(slices.Map(config.AllowOrigins, func(u *url.URL) string { return u.String() }), ingressHandler)
ingressHandler = cors.Middleware(
slices.Map(config.AllowOrigins, func(u *url.URL) string { return u.String() }),
config.AllowHeaders,
ingressHandler,
)
}

g, ctx := errgroup.WithContext(ctx)
Expand Down
89 changes: 71 additions & 18 deletions backend/controller/ingress/ingress_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package ingress_test

import (
"net/http"
"os"
"testing"

"github.com/alecthomas/assert/v2"
Expand All @@ -16,7 +17,7 @@ func TestHttpIngress(t *testing.T) {
in.Run(t, "",
in.CopyModule("httpingress"),
in.Deploy("httpingress"),
in.HttpCall(http.MethodGet, "/users/123/posts/456", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodGet, "/users/123/posts/456", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"Header from FTL"}, resp.Headers["Get"])
assert.Equal(t, []string{"application/json; charset=utf-8"}, resp.Headers["Content-Type"])
Expand All @@ -31,96 +32,148 @@ func TestHttpIngress(t *testing.T) {
assert.True(t, ok, "good_stuff is not a string: %s", repr.String(resp.JsonBody))
assert.Equal(t, "This is good stuff", goodStuff)
}),
in.HttpCall(http.MethodPost, "/users", in.JsonData(t, in.Obj{"userId": 123, "postId": 345}), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodPost, "/users", nil, in.JsonData(t, in.Obj{"userId": 123, "postId": 345}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 201, resp.Status)
assert.Equal(t, []string{"Header from FTL"}, resp.Headers["Post"])
success, ok := resp.JsonBody["success"].(bool)
assert.True(t, ok, "success is not a bool: %s", repr.String(resp.JsonBody))
assert.True(t, success)
}),
// contains aliased field
in.HttpCall(http.MethodPost, "/users", in.JsonData(t, in.Obj{"user_id": 123, "postId": 345}), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodPost, "/users", nil, in.JsonData(t, in.Obj{"user_id": 123, "postId": 345}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 201, resp.Status)
}),
in.HttpCall(http.MethodPut, "/users/123", in.JsonData(t, in.Obj{"postId": "346"}), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodPut, "/users/123", nil, in.JsonData(t, in.Obj{"postId": "346"}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"Header from FTL"}, resp.Headers["Put"])
assert.Equal(t, map[string]any{}, resp.JsonBody)
}),
in.HttpCall(http.MethodDelete, "/users/123", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodDelete, "/users/123", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"Header from FTL"}, resp.Headers["Delete"])
assert.Equal(t, map[string]any{}, resp.JsonBody)
}),

in.HttpCall(http.MethodGet, "/queryparams?foo=bar", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodGet, "/queryparams?foo=bar", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, "bar", string(resp.BodyBytes))
}),

in.HttpCall(http.MethodGet, "/queryparams", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodGet, "/queryparams", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, "No value", string(resp.BodyBytes))
}),

in.HttpCall(http.MethodGet, "/html", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodGet, "/html", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"text/html; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, "<html><body><h1>HTML Page From FTL 🚀!</h1></body></html>", string(resp.BodyBytes))
}),

in.HttpCall(http.MethodPost, "/bytes", []byte("Hello, World!"), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodPost, "/bytes", nil, []byte("Hello, World!"), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"application/octet-stream"}, resp.Headers["Content-Type"])
assert.Equal(t, []byte("Hello, World!"), resp.BodyBytes)
}),

in.HttpCall(http.MethodGet, "/empty", nil, func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodGet, "/empty", nil, nil, func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, nil, resp.Headers["Content-Type"])
assert.Equal(t, nil, resp.BodyBytes)
}),

in.HttpCall(http.MethodGet, "/string", []byte("Hello, World!"), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodGet, "/string", nil, []byte("Hello, World!"), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"text/plain; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, []byte("Hello, World!"), resp.BodyBytes)
}),

in.HttpCall(http.MethodGet, "/int", []byte("1234"), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodGet, "/int", nil, []byte("1234"), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"text/plain; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, []byte("1234"), resp.BodyBytes)
}),
in.HttpCall(http.MethodGet, "/float", []byte("1234.56789"), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodGet, "/float", nil, []byte("1234.56789"), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"text/plain; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, []byte("1234.56789"), resp.BodyBytes)
}),
in.HttpCall(http.MethodGet, "/bool", []byte("true"), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodGet, "/bool", nil, []byte("true"), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"text/plain; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, []byte("true"), resp.BodyBytes)
}),
in.HttpCall(http.MethodGet, "/error", nil, func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodGet, "/error", nil, nil, func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 500, resp.Status)
assert.Equal(t, []string{"text/plain; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, []byte("Error from FTL"), resp.BodyBytes)
}),
in.HttpCall(http.MethodGet, "/array/string", in.JsonData(t, []string{"hello", "world"}), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodGet, "/array/string", nil, in.JsonData(t, []string{"hello", "world"}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"application/json; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, in.JsonData(t, []string{"hello", "world"}), resp.BodyBytes)
}),
in.HttpCall(http.MethodPost, "/array/data", in.JsonData(t, []in.Obj{{"item": "a"}, {"item": "b"}}), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodPost, "/array/data", nil, in.JsonData(t, []in.Obj{{"item": "a"}, {"item": "b"}}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"application/json; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, in.JsonData(t, []in.Obj{{"item": "a"}, {"item": "b"}}), resp.BodyBytes)
}),
in.HttpCall(http.MethodGet, "/typeenum", in.JsonData(t, in.Obj{"name": "A", "value": "hello"}), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodGet, "/typeenum", nil, in.JsonData(t, in.Obj{"name": "A", "value": "hello"}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"application/json; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, in.JsonData(t, in.Obj{"name": "A", "value": "hello"}), resp.BodyBytes)
}),
// CORS preflight request without CORS middleware enabled
in.HttpCall(http.MethodOptions, "/typeenum", map[string][]string{
"Origin": {"http://localhost:8892"},
"Access-Control-Request-Method": {"GET"},
"Access-Control-Request-Headers": {"x-forwarded-capabilities"},
}, nil, func(t testing.TB, resp *in.HTTPResponse) {
// should not return access control headers because we have not set up cors in this controller
assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Origin"])
assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Methods"])
assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Headers"])
}),
)
}

// Run with CORS enabled via FTL_CONTROLLER_ALLOW_ORIGIN and FTL_CONTROLLER_ALLOW_HEADERS
// This test is similar to TestHttpIngress above with the addition of CORS enabled in the controller.
func TestHttpIngressWithCors(t *testing.T) {
os.Setenv("FTL_CONTROLLER_ALLOW_ORIGIN", "http://localhost:8892")
os.Setenv("FTL_CONTROLLER_ALLOW_HEADERS", "x-forwarded-capabilities")
in.Run(t, "",
in.CopyModule("httpingress"),
in.Deploy("httpingress"),
// A correct CORS preflight request
in.HttpCall(http.MethodOptions, "/typeenum", map[string][]string{
"Origin": {"http://localhost:8892"},
"Access-Control-Request-Method": {"GET"},
"Access-Control-Request-Headers": {"x-forwarded-capabilities"},
}, nil, func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, []string{"http://localhost:8892"}, resp.Headers["Access-Control-Allow-Origin"])
assert.Equal(t, []string{"GET"}, resp.Headers["Access-Control-Allow-Methods"])
assert.Equal(t, []string{"x-forwarded-capabilities"}, resp.Headers["Access-Control-Allow-Headers"])
}),
// Not allowed headers
in.HttpCall(http.MethodOptions, "/typeenum", map[string][]string{
"Origin": {"http://localhost:8892"},
"Access-Control-Request-Method": {"GET"},
"Access-Control-Request-Headers": {"moo"},
}, nil, func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Origin"])
assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Methods"])
assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Headers"])
}),
// Not allowed origin
in.HttpCall(http.MethodOptions, "/typeenum", map[string][]string{
"Origin": {"http://localhost:4444"},
"Access-Control-Request-Method": {"GET"},
"Access-Control-Request-Headers": {"x-forwarded-capabilities"},
}, nil, func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Origin"])
assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Methods"])
assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Headers"])
}),
)
}
2 changes: 1 addition & 1 deletion frontend/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,5 @@ func Server(ctx context.Context, timestamp time.Time, publicURL *url.URL, allowO
return proxy, nil
}

return cors.Middleware([]string{allowOrigin.String()}, proxy), nil
return cors.Middleware([]string{allowOrigin.String()}, nil, proxy), nil
}
2 changes: 1 addition & 1 deletion frontend/release.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func Server(ctx context.Context, timestamp time.Time, publicURL *url.URL, allowO
http.ServeContent(w, r, filePath, timestamp, f.(io.ReadSeeker))
})
if allowOrigin != nil {
handler = cors.Middleware([]string{allowOrigin.String()}, handler)
handler = cors.Middleware([]string{allowOrigin.String()}, nil, handler)
}
return handler, nil
}
2 changes: 1 addition & 1 deletion go-runtime/encoding/encoding_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func TestHttpEncodeOmitempty(t *testing.T) {
in.Run(t, "",
in.CopyModule("omitempty"),
in.Deploy("omitempty"),
in.HttpCall(http.MethodGet, "/get", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
in.HttpCall(http.MethodGet, "/get", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
_, ok := resp.JsonBody["mustset"]
assert.True(t, ok)
Expand Down
7 changes: 6 additions & 1 deletion integration/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ func JsonData(t testing.TB, body interface{}) []byte {
}

// HttpCall makes an HTTP call to the running FTL ingress endpoint.
func HttpCall(method string, path string, body []byte, onResponse func(t testing.TB, resp *HTTPResponse)) Action {
func HttpCall(method string, path string, headers map[string][]string, body []byte, onResponse func(t testing.TB, resp *HTTPResponse)) Action {
return func(t testing.TB, ic TestContext) {
Infof("HTTP %s %s", method, path)
baseURL, err := url.Parse(fmt.Sprintf("http://localhost:8891"))
Expand All @@ -415,6 +415,11 @@ func HttpCall(method string, path string, body []byte, onResponse func(t testing
assert.NoError(t, err)

r.Header.Add("Content-Type", "application/json")
for k, vs := range headers {
for _, v := range vs {
r.Header.Add(k, v)
}
}

client := http.Client{}
resp, err := client.Do(r)
Expand Down
7 changes: 5 additions & 2 deletions internal/cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ import (
"github.com/rs/cors"
)

func Middleware(allowOrigins []string, next http.Handler) http.Handler {
c := cors.New(cors.Options{AllowedOrigins: allowOrigins})
func Middleware(allowOrigins []string, allowHeaders []string, next http.Handler) http.Handler {
c := cors.New(cors.Options{
AllowedOrigins: allowOrigins,
AllowedHeaders: allowHeaders,
})
return c.Handler(next)
}

0 comments on commit a122f55

Please sign in to comment.