Skip to content

Commit

Permalink
Merge pull request #141 from launchdarkly/eb/ch83491/relay-base
Browse files Browse the repository at this point in the history
(v6 - #9) create RelayCore and move most of the core logic into it
  • Loading branch information
eli-darkly authored Aug 4, 2020
2 parents a487574 + 8fd3f2f commit dd1fa96
Show file tree
Hide file tree
Showing 17 changed files with 872 additions and 414 deletions.
29 changes: 6 additions & 23 deletions client-side.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ import (
"net/http"
"net/http/httptest"
"strconv"
"strings"

"net/http/httputil"

"github.com/launchdarkly/ld-relay/v6/config"
"github.com/launchdarkly/ld-relay/v6/internal/cors"
"github.com/launchdarkly/ld-relay/v6/internal/events"
"github.com/launchdarkly/ld-relay/v6/internal/relayenv"
"github.com/launchdarkly/ld-relay/v6/internal/util"
Expand Down Expand Up @@ -57,7 +57,11 @@ func (m clientSideMux) selectClientByUrlParam(next http.Handler) http.Handler {
return
}

req = req.WithContext(context.WithValue(req.Context(), contextKey, clientCtx))
reqContext := context.WithValue(req.Context(), contextKey, clientCtx)
// Even though the clientCtx also serves as a CORSContext, we attach it separately here just to keep
// the CORS implementation less reliant on other unrelated implementation details
reqContext = cors.WithCORSContext(reqContext, clientCtx)
req = req.WithContext(reqContext)
next.ServeHTTP(w, req)
})
}
Expand All @@ -67,27 +71,6 @@ func (m clientSideMux) getGoals(w http.ResponseWriter, req *http.Request) {
clientCtx.proxy.ServeHTTP(w, req)
}

var allowedHeadersList = []string{
"Cache-Control",
"Content-Type",
"Content-Length",
"Accept-Encoding",
"X-LaunchDarkly-User-Agent",
"X-LaunchDarkly-Payload-ID",
"X-LaunchDarkly-Wrapper",
events.EventSchemaHeader,
}

var allowedHeaders = strings.Join(allowedHeadersList, ",")

func setCorsHeaders(w http.ResponseWriter, origin string) {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Credentials", "false")
w.Header().Set("Access-Control-Max-Age", "300")
w.Header().Set("Access-Control-Allow-Headers", allowedHeaders)
w.Header().Set("Access-Control-Expose-Headers", "Date")
}

const transparent1PixelImgBase64 = "R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7="

var transparent1PixelImg []byte
Expand Down
61 changes: 61 additions & 0 deletions internal/cors/cors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,67 @@
package cors

import (
"context"
"net/http"
"strings"

"github.com/launchdarkly/ld-relay/v6/internal/events"
)

const (
// The default origin string to use in CORS response headers.
DefaultAllowedOrigin = "*"
)

type contextKeyType string

const (
contextKey contextKeyType = "context"
maxAge string = "300"
)

var allowedHeadersList = []string{
"Cache-Control",
"Content-Type",
"Content-Length",
"Accept-Encoding",
"X-LaunchDarkly-User-Agent",
"X-LaunchDarkly-Payload-ID",
"X-LaunchDarkly-Wrapper",
events.EventSchemaHeader,
}

var allowedHeaders = strings.Join(allowedHeadersList, ",")

// RequestContext represents a scope that has a specific set of allowed origins for CORS requests. This
// can be attached to a request context with WithCORSContext().
type RequestContext interface {
AllowedOrigins() []string
}

// GetCORSContext returns the CORSContext that has been attached to this Context with WithCORSContext(),
// or nil if none.
func GetCORSContext(ctx context.Context) RequestContext {
if cc, ok := ctx.Value(contextKey).(RequestContext); ok {
return cc
}
return nil
}

// WithCORSContext returns a copy of the parent context with the specified CORSContext attached.
func WithCORSContext(parent context.Context, cc RequestContext) context.Context {
if cc == nil {
return parent
}
return context.WithValue(parent, contextKey, cc)
}

// SetCORSHeaders sets a standard set of CORS headers on an HTTP response. This is meant to be the same
// behavior that the LaunchDarkly service endpoints uses for client-side JS requests.
func SetCORSHeaders(w http.ResponseWriter, origin string) {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Credentials", "false")
w.Header().Set("Access-Control-Max-Age", maxAge)
w.Header().Set("Access-Control-Allow-Headers", allowedHeaders)
w.Header().Set("Access-Control-Expose-Headers", "Date")
}
43 changes: 43 additions & 0 deletions internal/cors/cors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package cors

import (
"context"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

type mockCORSContext struct{}

func (m mockCORSContext) AllowedOrigins() []string {
return nil
}

func TestCORSContext(t *testing.T) {
t.Run("GetCORSContext when there is no RequestContext returns nil", func(t *testing.T) {
assert.Nil(t, GetCORSContext(context.Background()))
})

t.Run("WithCORSContext adds RequestContext to context", func(t *testing.T) {
m := mockCORSContext{}
ctx := WithCORSContext(context.Background(), m)
assert.Equal(t, m, GetCORSContext(ctx))
})

t.Run("WithCORSContext has no effect with nil parameter", func(t *testing.T) {
ctx := WithCORSContext(context.Background(), nil)
assert.Equal(t, context.Background(), ctx)
})

t.Run("SetCORSHeaders", func(t *testing.T) {
origin := "http://good.cat"
rr := httptest.ResponseRecorder{}
SetCORSHeaders(&rr, origin)
assert.Equal(t, origin, rr.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "false", rr.Header().Get("Access-Control-Allow-Credentials"))
assert.Equal(t, maxAge, rr.Header().Get("Access-Control-Max-Age"))
assert.Equal(t, allowedHeaders, rr.Header().Get("Access-Control-Allow-Headers"))
assert.Equal(t, "Date", rr.Header().Get("Access-Control-Expose-Headers"))
})
}
3 changes: 3 additions & 0 deletions internal/relayenv/env_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package relayenv

import (
"context"
"io"
"net/http"
"time"

Expand All @@ -19,6 +20,8 @@ import (
// connection may take a while, so it is possible for the client and store references to be nil if initialization
// is not yet complete.
type EnvContext interface {
io.Closer

// GetName returns the configured name of the environment.
GetName() string

Expand Down
5 changes: 5 additions & 0 deletions internal/relayenv/env_context_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,8 @@ func (c *envContextImpl) GetInitError() error {
func (c *envContextImpl) IsSecureMode() bool {
return c.secureMode
}

func (c *envContextImpl) Close() error {
// This currently isn't used, but will be used in the future when we can dynamically change environments
return nil
}
75 changes: 7 additions & 68 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,15 @@ import (
"encoding/json"
"errors"
"net/http"
"regexp"

"github.com/gorilla/mux"

"github.com/launchdarkly/ld-relay/v6/config"
"github.com/launchdarkly/ld-relay/v6/internal/cors"
"github.com/launchdarkly/ld-relay/v6/internal/metrics"
"github.com/launchdarkly/ld-relay/v6/internal/relayenv"
"github.com/launchdarkly/ld-relay/v6/internal/version"
"gopkg.in/launchdarkly/go-sdk-common.v2/lduser"
ld "gopkg.in/launchdarkly/go-server-sdk.v5"
)

var (
uuidHeaderPattern = regexp.MustCompile(`^(?:api_key )?((?:[a-z]{3}-)?[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89aAbB][a-f0-9]{3}-[a-f0-9]{12})$`)
)

type corsContext interface {
AllowedOrigins() []string
}

func chainMiddleware(middlewares ...mux.MiddlewareFunc) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
handler := next
Expand All @@ -37,56 +25,7 @@ func chainMiddleware(middlewares ...mux.MiddlewareFunc) mux.MiddlewareFunc {
}
}

type clientMux struct {
clientContextByKey map[config.SDKCredential]relayenv.EnvContext
}

func (m clientMux) getStatus(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
envs := make(map[string]environmentStatus)

healthy := true
for _, clientCtx := range m.clientContextByKey {
var status environmentStatus
creds := clientCtx.GetCredentials()
status.SdkKey = obscureKey(creds.SDKKey)
if mobileKey, ok := creds.MobileKey.Get(); ok {
status.MobileKey = obscureKey(mobileKey)
}
status.EnvId = creds.EnvironmentID.StringValue()
client := clientCtx.GetClient()
if client == nil || !client.Initialized() {
status.Status = "disconnected"
healthy = false
} else {
status.Status = "connected"
}
envs[clientCtx.GetName()] = status
}

resp := struct {
Environments map[string]environmentStatus `json:"environments"`
Status string `json:"status"`
Version string `json:"version"`
ClientVersion string `json:"clientVersion"`
}{
Environments: envs,
Version: version.Version,
ClientVersion: ld.Version,
}

if healthy {
resp.Status = "healthy"
} else {
resp.Status = "degraded"
}

data, _ := json.Marshal(resp)

w.Write(data)
}

func (m clientMux) selectClientByAuthorizationKey(sdkKind sdkKind) func(http.Handler) http.Handler {
func selectEnvironmentByAuthorizationKey(sdkKind sdkKind, envs RelayEnvironments) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
credential, err := sdkKind.getSDKCredential(req)
Expand All @@ -95,7 +34,7 @@ func (m clientMux) selectClientByAuthorizationKey(sdkKind sdkKind) func(http.Han
return
}

clientCtx := m.clientContextByKey[credential]
clientCtx := envs.GetEnvironment(credential)

if clientCtx == nil {
w.WriteHeader(http.StatusUnauthorized)
Expand Down Expand Up @@ -168,24 +107,24 @@ func withGauge(handler http.Handler, measure metrics.Measure) http.Handler {
func corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var domains []string
if context, ok := r.Context().Value(contextKey).(corsContext); ok {
domains = context.AllowedOrigins()
if corsContext := cors.GetCORSContext(r.Context()); corsContext != nil {
domains = corsContext.AllowedOrigins()
}
if len(domains) > 0 {
for _, d := range domains {
if r.Header.Get("Origin") == d {
setCorsHeaders(w, d)
cors.SetCORSHeaders(w, d)
return
}
}
// Not a valid origin, set allowed origin to any allowed origin
setCorsHeaders(w, domains[0])
cors.SetCORSHeaders(w, domains[0])
} else {
origin := cors.DefaultAllowedOrigin
if r.Header.Get("Origin") != "" {
origin = r.Header.Get("Origin")
}
setCorsHeaders(w, origin)
cors.SetCORSHeaders(w, origin)
}
next.ServeHTTP(w, r)
})
Expand Down
Loading

0 comments on commit dd1fa96

Please sign in to comment.