diff --git a/pkg/keycloak/config/config.go b/pkg/keycloak/config/config.go index e7f5c71a..b12453a0 100644 --- a/pkg/keycloak/config/config.go +++ b/pkg/keycloak/config/config.go @@ -75,6 +75,8 @@ type Config struct { OpenIDProviderRetryCount int `env:"OPENID_PROVIDER_RETRY_COUNT" json:"openid-provider-retry-count" usage:"number of retries for retrieving openid configuration" yaml:"openid-provider-retry-count"` // OpenIDProviderHeaders OpenIDProviderHeaders map[string]string `json:"openid-provider-headers" usage:"http headers sent to idp provider" yaml:"openid-provider-headers"` + // UpstreamProxy proxy for upstream communication + UpstreamProxy string `env:"UPSTREAM_PROXY" json:"upstream-proxy" usage:"proxy for communication with upstream" yaml:"upstream-proxy"` // BaseURI is prepended to all the generated URIs BaseURI string `env:"BASE_URI" json:"base-uri" usage:"common prefix for all URIs" yaml:"base-uri"` // OAuthURI is the uri for the oauth endpoints for the proxy diff --git a/pkg/keycloak/proxy/server.go b/pkg/keycloak/proxy/server.go index 268df48c..859c8464 100644 --- a/pkg/keycloak/proxy/server.go +++ b/pkg/keycloak/proxy/server.go @@ -1259,7 +1259,6 @@ func (r *OauthProxy) createUpstreamProxy(upstream *url.URL) error { // and for refreshed cookies (htts://github.com/louketo/louketo-proxy/pulls/456]) proxy.KeepDestinationHeaders = true proxy.Logger = httplog.New(io.Discard, "", 0) - proxy.KeepDestinationHeaders = true r.Upstream = proxy // update the tls configuration of the reverse proxy @@ -1269,8 +1268,15 @@ func (r *OauthProxy) createUpstreamProxy(upstream *url.URL) error { return apperrors.ErrAssertionFailed } + var upstreamProxyFunc func(*http.Request) (*url.URL, error) + if r.Config.UpstreamProxy != "" { + upstreamProxyFunc = func(req *http.Request) (*url.URL, error) { + return url.Parse(r.Config.UpstreamProxy) + } + } upstreamProxy.Tr = &http.Transport{ Dial: dialer, + Proxy: upstreamProxyFunc, DisableKeepAlives: !r.Config.UpstreamKeepalives, ExpectContinueTimeout: r.Config.UpstreamExpectContinueTimeout, ResponseHeaderTimeout: r.Config.UpstreamResponseHeaderTimeout, diff --git a/pkg/testsuite/constant.go b/pkg/testsuite/constant.go index d29f458c..224dcccf 100644 --- a/pkg/testsuite/constant.go +++ b/pkg/testsuite/constant.go @@ -23,6 +23,10 @@ const ( FakeCertFilePrefix = "/gateadmin_crt_" FakePrivFilePrefix = "/gateadmin_priv_" FakeCaFilePrefix = "/gateadmin_ca_" + TestProxyHeaderKey = "X-GoProxy" + TestProxyHeaderVal = "yxorPoG-X" ) var ErrCreateFakeProxy = errors.New("failed to create fake proxy service") +var ErrRunHTTPServer = errors.New("failed to run http server") +var ErrShutHTTPServer = errors.New("failed to shutdown http server") diff --git a/pkg/testsuite/fake_upstream.go b/pkg/testsuite/fake_upstream.go index 6b707a0c..5613dc31 100644 --- a/pkg/testsuite/fake_upstream.go +++ b/pkg/testsuite/fake_upstream.go @@ -3,9 +3,11 @@ package testsuite import ( "encoding/json" "io" + "net" "net/http" "strings" + "github.com/elazarl/goproxy" "golang.org/x/net/websocket" ) @@ -69,3 +71,21 @@ func (f *FakeUpstreamService) ServeHTTP(wrt http.ResponseWriter, req *http.Reque _, _ = wrt.Write(content) } } + +func createTestProxy() (*http.Server, net.Listener, error) { + proxy := goproxy.NewProxyHttpServer() + proxy.OnRequest().DoFunc( + func(r *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) { + r.Header.Set(TestProxyHeaderKey, TestProxyHeaderVal) + return r, nil + }, + ) + proxyHTTPServer := &http.Server{ + Handler: proxy, + } + ln, err := net.Listen("tcp", randomLocalHost) + if err != nil { + return nil, nil, err + } + return proxyHTTPServer, ln, nil +} diff --git a/pkg/testsuite/server_test.go b/pkg/testsuite/server_test.go index e8db132b..5cafaf87 100644 --- a/pkg/testsuite/server_test.go +++ b/pkg/testsuite/server_test.go @@ -19,7 +19,9 @@ limitations under the License. package testsuite import ( + "context" "crypto/tls" + "errors" "fmt" "math/rand" "net/http" @@ -211,7 +213,37 @@ func TestAuthTokenHeader(t *testing.T) { } func TestForwardingProxy(t *testing.T) { - server := httptest.NewServer(&FakeUpstreamService{}) + errChan := make(chan error) + upProxy, lstn, err := createTestProxy() + upstreamProxyURL := fmt.Sprintf("http://%s", lstn.Addr().String()) + if err != nil { + t.Fatal(err) + } + + go func() { + errChan <- upProxy.Serve(lstn) + }() + + fakeUpstream := httptest.NewServer(&FakeUpstreamService{}) + upstreamConfig := newFakeKeycloakConfig() + upstreamConfig.EnableUma = true + upstreamConfig.NoRedirects = true + upstreamConfig.EnableDefaultDeny = true + upstreamConfig.ClientID = ValidUsername + upstreamConfig.ClientSecret = ValidPassword + upstreamConfig.PatRetryCount = 5 + upstreamConfig.PatRetryInterval = 2 * time.Second + upstreamConfig.Upstream = fakeUpstream.URL + // in newFakeProxy we are creating fakeauth server so, we will + // have two different fakeauth servers for upstream and forwarding, + // so we need to skip issuer check, but responses will be same + // so it is ok for this testing + upstreamConfig.SkipAccessTokenIssuerCheck = true + + upstreamProxy := newFakeProxy( + upstreamConfig, + &fakeAuthConfig{Expiration: 900 * time.Millisecond}, + ) testCases := []struct { Name string @@ -232,7 +264,7 @@ func TestForwardingProxy(t *testing.T) { }, ExecutionSettings: []fakeRequest{ { - URL: server.URL + FakeTestURL, + URL: upstreamProxy.getServiceURL() + FakeTestURL, ProxyRequest: true, ExpectedProxy: true, ExpectedCode: http.StatusOK, @@ -253,7 +285,7 @@ func TestForwardingProxy(t *testing.T) { }, ExecutionSettings: []fakeRequest{ { - URL: server.URL + FakeTestURL, + URL: upstreamProxy.getServiceURL() + FakeTestURL, ProxyRequest: true, ExpectedProxy: true, ExpectedCode: http.StatusOK, @@ -261,7 +293,7 @@ func TestForwardingProxy(t *testing.T) { ExpectedContentContains: "Bearer ey", }, { - URL: server.URL + FakeTestURL, + URL: upstreamProxy.getServiceURL() + FakeTestURL, ProxyRequest: true, ExpectedProxy: true, ExpectedCode: http.StatusOK, @@ -282,11 +314,21 @@ func TestForwardingProxy(t *testing.T) { }, ExecutionSettings: []fakeRequest{ { - URL: server.URL + FakeTestURL, + URL: upstreamProxy.getServiceURL() + FakeTestURL, ProxyRequest: true, ExpectedProxy: true, ExpectedCode: http.StatusOK, ExpectedContentContains: "Bearer ey", + Method: "POST", + FormValues: map[string]string{ + "Name": "Whatever", + }, + ExpectedContent: func(body string, testNum int) { + assert.Contains(t, body, FakeTestURL) + assert.Contains(t, body, "method") + assert.Contains(t, body, "Whatever") + assert.NotContains(t, body, TestProxyHeaderVal) + }, }, }, }, @@ -303,7 +345,7 @@ func TestForwardingProxy(t *testing.T) { }, ExecutionSettings: []fakeRequest{ { - URL: server.URL + FakeTestURL, + URL: upstreamProxy.getServiceURL() + FakeTestURL, ProxyRequest: true, ExpectedProxy: true, ExpectedCode: http.StatusOK, @@ -311,7 +353,7 @@ func TestForwardingProxy(t *testing.T) { ExpectedContentContains: "Bearer ey", }, { - URL: server.URL + FakeTestURL, + URL: upstreamProxy.getServiceURL() + FakeTestURL, ProxyRequest: true, ExpectedProxy: true, ExpectedCode: http.StatusOK, @@ -319,6 +361,39 @@ func TestForwardingProxy(t *testing.T) { }, }, }, + { + // forwardingProxy -> middleProxy -> our backend upstreamProxy + Name: "TestClientCredentialsGrantWithMiddleProxy", + ProxySettings: func(conf *config.Config) { + conf.EnableForwarding = true + conf.ForwardingDomains = []string{} + conf.ClientID = ValidUsername + conf.ClientSecret = ValidPassword + conf.ForwardingGrantType = configcore.GrantTypeClientCreds + conf.PatRetryCount = 5 + conf.PatRetryInterval = 2 * time.Second + conf.UpstreamProxy = upstreamProxyURL + }, + ExecutionSettings: []fakeRequest{ + { + URL: upstreamProxy.getServiceURL() + FakeTestURL, + ProxyRequest: true, + ExpectedProxy: true, + ExpectedCode: http.StatusOK, + ExpectedContentContains: "Bearer ey", + Method: "POST", + FormValues: map[string]string{ + "Name": "Whatever", + }, + ExpectedContent: func(body string, testNum int) { + assert.Contains(t, body, FakeTestURL) + assert.Contains(t, body, "method") + assert.Contains(t, body, "Whatever") + assert.Contains(t, body, TestProxyHeaderVal) + }, + }, + }, + }, } for _, testCase := range testCases { @@ -326,15 +401,33 @@ func TestForwardingProxy(t *testing.T) { t.Run( testCase.Name, func(t *testing.T) { - c := newFakeKeycloakConfig() - c.Upstream = server.URL - testCase.ProxySettings(c) - p := newFakeProxy(c, &fakeAuthConfig{Expiration: 900 * time.Millisecond}) + forwardingConfig := newFakeKeycloakConfig() + + testCase.ProxySettings(forwardingConfig) + forwardingProxy := newFakeProxy( + forwardingConfig, + &fakeAuthConfig{}, + ) + <-time.After(time.Duration(100) * time.Millisecond) - p.RunTests(t, testCase.ExecutionSettings) + forwardingProxy.RunTests(t, testCase.ExecutionSettings) }, ) } + + select { + case err = <-errChan: + if err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Fatal(errors.Join(ErrRunHTTPServer, err)) + } + default: + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + err = upProxy.Shutdown(ctx) + if err != nil { + t.Fatal(errors.Join(ErrShutHTTPServer, err)) + } + } } func TestUmaForwardingProxy(t *testing.T) { @@ -447,7 +540,6 @@ func TestUmaForwardingProxy(t *testing.T) { testCase.Name, func(t *testing.T) { forwardingConfig := newFakeKeycloakConfig() - forwardingConfig.Upstream = upstreamProxy.getServiceURL() testCase.ProxySettings(forwardingConfig) forwardingProxy := newFakeProxy( @@ -2055,3 +2147,100 @@ func TestCustomHTTPMethod(t *testing.T) { ) } } + +func TestUpstreamProxy(t *testing.T) { + errChan := make(chan error) + upstream := httptest.NewServer(&FakeUpstreamService{}) + upstreamProxy, lstn, err := createTestProxy() + upstreamProxyURL := fmt.Sprintf("http://%s", lstn.Addr().String()) + if err != nil { + t.Fatal(err) + } + + go func() { + errChan <- upstreamProxy.Serve(lstn) + }() + + testCases := []struct { + Name string + ProxySettings func(c *config.Config) + ExecutionSettings []fakeRequest + }{ + { + Name: "TestUpstreamProxy", + ProxySettings: func(c *config.Config) { + c.UpstreamProxy = upstreamProxyURL + c.Upstream = upstream.URL + }, + ExecutionSettings: []fakeRequest{ + { + URI: "/test", + Method: "POST", + FormValues: map[string]string{ + "Name": "Whatever", + }, + ExpectedProxy: true, + ExpectedCode: http.StatusOK, + ExpectedContentContains: "gzip", + ExpectedContent: func(body string, testNum int) { + assert.Contains(t, body, FakeTestURL) + assert.Contains(t, body, "method") + assert.Contains(t, body, "Whatever") + assert.Contains(t, body, TestProxyHeaderVal) + }, + }, + }, + }, + { + Name: "TestNoUpstreamProxy", + ProxySettings: func(c *config.Config) { + c.Upstream = upstream.URL + }, + ExecutionSettings: []fakeRequest{ + { + URI: FakeTestURL, + Method: "POST", + FormValues: map[string]string{ + "Name": "Whatever", + }, + ExpectedProxy: true, + ExpectedCode: http.StatusOK, + ExpectedContentContains: "gzip", + ExpectedContent: func(body string, testNum int) { + assert.Contains(t, body, FakeTestURL) + assert.Contains(t, body, "method") + assert.Contains(t, body, "Whatever") + assert.NotContains(t, body, TestProxyHeaderVal) + }, + }, + }, + }, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run( + testCase.Name, + func(t *testing.T) { + c := newFakeKeycloakConfig() + testCase.ProxySettings(c) + p := newFakeProxy(c, &fakeAuthConfig{}) + p.RunTests(t, testCase.ExecutionSettings) + }, + ) + } + + select { + case err = <-errChan: + if err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Fatal(errors.Join(ErrRunHTTPServer, err)) + } + default: + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + err = upstreamProxy.Shutdown(ctx) + if err != nil { + t.Fatal(errors.Join(ErrShutHTTPServer, err)) + } + } +}