Skip to content

Commit

Permalink
feat: allow setting HTTP headers in hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik committed Sep 19, 2023
1 parent 9e9be2d commit 3f4e4f5
Show file tree
Hide file tree
Showing 11 changed files with 179 additions and 115 deletions.
46 changes: 40 additions & 6 deletions driver/config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ const (
KeyOAuth2GrantJWTIDOptional = "oauth2.grant.jwt.jti_optional"
KeyOAuth2GrantJWTIssuedDateOptional = "oauth2.grant.jwt.iat_optional"
KeyOAuth2GrantJWTMaxDuration = "oauth2.grant.jwt.max_ttl"
KeyRefreshTokenHookURL = "oauth2.refresh_token_hook" // #nosec G101
KeyTokenHookURL = "oauth2.token_hook" // #nosec G101
KeyRefreshTokenHook = "oauth2.refresh_token_hook" // #nosec G101
KeyTokenHook = "oauth2.token_hook" // #nosec G101
KeyDevelopmentMode = "dev"
)

Expand Down Expand Up @@ -467,12 +467,46 @@ func (p *DefaultProvider) AccessTokenStrategy(ctx context.Context, additionalSou
return s
}

func (p *DefaultProvider) TokenHookURL(ctx context.Context) *url.URL {
return p.getProvider(ctx).RequestURIF(KeyTokenHookURL, nil)
type HookConfig struct {
URL string `json:"url"`
Headers map[string]string `json:"headers"`
}

func (p *DefaultProvider) TokenRefreshHookURL(ctx context.Context) *url.URL {
return p.getProvider(ctx).RequestURIF(KeyRefreshTokenHookURL, nil)
func (p *DefaultProvider) getHookConfig(ctx context.Context, key string) *HookConfig {
if hookURL := p.getProvider(ctx).RequestURIF(key, nil); hookURL != nil {
return &HookConfig{
URL: hookURL.String(),
}
}

var hookConfig *HookConfig
if err := p.getProvider(ctx).Unmarshal(key, &hookConfig); err != nil {
p.l.WithError(errors.WithStack(err)).
Errorf("Configuration value from key %s could not be decoded.", key)
return nil
}
if hookConfig == nil {
return nil
}

// validate URL by parsing it
u, err := url.ParseRequestURI(hookConfig.URL)
if err != nil {
p.l.WithError(errors.WithStack(err)).
Errorf("Configuration value from key %s could not be decoded.", key)
return nil
}
hookConfig.URL = u.String()

return hookConfig
}

func (p *DefaultProvider) TokenHookConfig(ctx context.Context) *HookConfig {
return p.getHookConfig(ctx, KeyTokenHook)
}

func (p *DefaultProvider) TokenRefreshHookConfig(ctx context.Context) *HookConfig {
return p.getHookConfig(ctx, KeyRefreshTokenHook)
}

func (p *DefaultProvider) DbIgnoreUnknownTableColumns() bool {
Expand Down
31 changes: 24 additions & 7 deletions driver/config/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"testing"
Expand Down Expand Up @@ -425,17 +424,35 @@ func TestCookieSecure(t *testing.T) {
assert.True(t, c.CookieSecure(ctx))
}

func TestTokenRefreshHookURL(t *testing.T) {
func TestHookConfigs(t *testing.T) {
ctx := context.Background()
l := logrusx.New("", "")
l.Logrus().SetOutput(io.Discard)
c := MustNew(context.Background(), l, configx.SkipValidation())

assert.EqualValues(t, (*url.URL)(nil), c.TokenRefreshHookURL(ctx))
c.MustSet(ctx, KeyRefreshTokenHookURL, "")
assert.EqualValues(t, (*url.URL)(nil), c.TokenRefreshHookURL(ctx))
c.MustSet(ctx, KeyRefreshTokenHookURL, "http://localhost:8080/oauth/token_refresh")
assert.EqualValues(t, "http://localhost:8080/oauth/token_refresh", c.TokenRefreshHookURL(ctx).String())
for key, getFunc := range map[string]func(context.Context) *HookConfig{
KeyRefreshTokenHook: c.TokenRefreshHookConfig,
KeyTokenHook: c.TokenHookConfig,
} {
assert.Nil(t, getFunc(ctx))
c.MustSet(ctx, key, "")
assert.Nil(t, getFunc(ctx))
c.MustSet(ctx, key, "http://localhost:8080/hook")
hc := getFunc(ctx)
require.NotNil(t, hc)
assert.EqualValues(t, "http://localhost:8080/hook", hc.URL)

c.MustSet(ctx, key, map[string]any{
"url": "http://localhost:8080/hook2",
"header": map[string]any{
"My-Headers": "my-value",
},
})
hc = getFunc(ctx)
require.NotNil(t, hc)
assert.EqualValues(t, "http://localhost:8080/hook2", hc.URL)
assert.EqualValues(t, "my-value", hc.Headers["My-Headers"])
}
}

func TestJWTBearer(t *testing.T) {
Expand Down
37 changes: 0 additions & 37 deletions go.sum

Large diffs are not rendered by default.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion oauth2/oauth2_auth_code_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func BenchmarkAuthCode(b *testing.B) {
reg := internal.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer)
reg.Config().MustSet(ctx, config.KeyLogLevel, "error")
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
reg.Config().MustSet(ctx, config.KeyRefreshTokenHookURL, "")
reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "")
oauth2Keys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OAuth2JWTKeyName, "sig")
require.NoError(b, err)
oidcKeys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OpenIDConnectKeyName, "sig")
Expand Down
62 changes: 33 additions & 29 deletions oauth2/oauth2_auth_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
ctx := context.Background()
reg := internal.NewMockedRegistry(t, &contextx.Default{})
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
reg.Config().MustSet(ctx, config.KeyRefreshTokenHookURL, "")
reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "")
publicTS, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg)

publicClient := hydra.NewAPIClient(hydra.NewConfiguration())
Expand Down Expand Up @@ -955,6 +955,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
return func(t *testing.T) {
hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8")
assert.Equal(t, r.Header.Get("Bearer"), "secret value")

var hookReq hydraoauth2.TokenHookRequest
require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq))
Expand All @@ -981,9 +982,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, &config.HookConfig{
URL: hs.URL,
Headers: map[string]string{"Bearer": "secret value"},
})

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

expectAud := "https://api.ory.sh/"
c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
Expand Down Expand Up @@ -1030,9 +1034,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

expectAud := "https://api.ory.sh/"
c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
Expand Down Expand Up @@ -1070,9 +1074,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

expectAud := "https://api.ory.sh/"
c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
Expand Down Expand Up @@ -1110,9 +1114,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

expectAud := "https://api.ory.sh/"
c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
Expand Down Expand Up @@ -1657,11 +1661,11 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
defer hs.Close()

if hookType == "legacy" {
conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil)
conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil)
} else {
conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHookURL, nil)
conf.MustSet(ctx, config.KeyTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHook, nil)
}

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
Expand Down Expand Up @@ -1699,11 +1703,11 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
defer hs.Close()

if hookType == "legacy" {
conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil)
conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil)
} else {
conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHookURL, nil)
conf.MustSet(ctx, config.KeyTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHook, nil)
}

origAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts)
Expand Down Expand Up @@ -1734,11 +1738,11 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
defer hs.Close()

if hookType == "legacy" {
conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil)
conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil)
} else {
conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHookURL, nil)
conf.MustSet(ctx, config.KeyTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHook, nil)
}

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
Expand All @@ -1764,11 +1768,11 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
defer hs.Close()

if hookType == "legacy" {
conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil)
conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil)
} else {
conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHookURL, nil)
conf.MustSet(ctx, config.KeyTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHook, nil)
}

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
Expand All @@ -1794,11 +1798,11 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
defer hs.Close()

if hookType == "legacy" {
conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil)
conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil)
} else {
conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHookURL, nil)
conf.MustSet(ctx, config.KeyTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHook, nil)
}

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
Expand Down
20 changes: 12 additions & 8 deletions oauth2/oauth2_client_credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ func TestClientCredentials(t *testing.T) {

hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8")
assert.Equal(t, r.Header.Get("Bearer"), "secret value")

expectedGrantedScopes := []string{"foobar"}
expectedGrantedAudience := []string{"https://api.ory.sh/"}
Expand Down Expand Up @@ -286,9 +287,12 @@ func TestClientCredentials(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, &config.HookConfig{
URL: hs.URL,
Headers: map[string]string{"Bearer": "secret value"},
})

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

secret := uuid.New().String()
cl, conf := newCustomClient(t, &hc.Client{
Expand Down Expand Up @@ -316,9 +320,9 @@ func TestClientCredentials(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

_, conf := newClient(t)

Expand All @@ -340,9 +344,9 @@ func TestClientCredentials(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

_, conf := newClient(t)

Expand All @@ -364,9 +368,9 @@ func TestClientCredentials(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

_, conf := newClient(t)

Expand Down
20 changes: 10 additions & 10 deletions oauth2/oauth2_jwt_bearer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,9 @@ func TestJWTBearer(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

conf := newConf(client)
conf.EndpointParams = url.Values{"grant_type": {grantType}, "assertion": {token}}
Expand Down Expand Up @@ -429,9 +429,9 @@ func TestJWTBearer(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

conf := newConf(client)
conf.AuthStyle = goauth2.AuthStyleInParams
Expand All @@ -457,9 +457,9 @@ func TestJWTBearer(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

token, _, err := signer.Generate(ctx, jwt.MapClaims{
"jti": uuid.NewString(),
Expand Down Expand Up @@ -492,9 +492,9 @@ func TestJWTBearer(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

token, _, err := signer.Generate(ctx, jwt.MapClaims{
"jti": uuid.NewString(),
Expand Down Expand Up @@ -527,9 +527,9 @@ func TestJWTBearer(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

token, _, err := signer.Generate(ctx, jwt.MapClaims{
"jti": uuid.NewString(),
Expand Down
Loading

0 comments on commit 3f4e4f5

Please sign in to comment.