From f13649ec06a910b5449f04a1bdb8470efb7ecec1 Mon Sep 17 00:00:00 2001 From: cyb3r4nt <104218001+cyb3r4nt@users.noreply.github.com> Date: Fri, 30 Aug 2024 00:07:23 +0300 Subject: [PATCH] improve provider name handling Add provider name into JWT token claims to allow provider names with multiple underscore "_" symbols. Forbid provider names containing URL reserved symbols. --- auth.go | 34 +++++++++++++++++++++++ auth_test.go | 36 ++++++++++++++++++++++--- middleware/auth.go | 23 +++++++++++----- middleware/auth_test.go | 55 ++++++++++++++++++-------------------- provider/apple.go | 6 +++++ provider/direct.go | 3 +++ provider/oauth1.go | 6 +++++ provider/oauth2.go | 6 +++++ provider/telegram.go | 3 +++ provider/verify.go | 6 +++++ token/jwt.go | 14 +++++++--- v2/auth.go | 34 +++++++++++++++++++++++ v2/auth_test.go | 36 ++++++++++++++++++++++--- v2/middleware/auth.go | 23 +++++++++++----- v2/middleware/auth_test.go | 55 ++++++++++++++++++-------------------- v2/provider/apple.go | 6 +++++ v2/provider/direct.go | 3 +++ v2/provider/oauth1.go | 6 +++++ v2/provider/oauth2.go | 6 +++++ v2/provider/telegram.go | 3 +++ v2/provider/verify.go | 6 +++++ v2/token/jwt.go | 14 +++++++--- 22 files changed, 298 insertions(+), 86 deletions(-) diff --git a/auth.go b/auth.go index 5bd9d879..7a6f2072 100644 --- a/auth.go +++ b/auth.go @@ -4,6 +4,8 @@ package auth import ( "fmt" "net/http" + "net/url" + "regexp" "strings" "time" @@ -267,10 +269,42 @@ func (s *Service) addProviderByName(name string, p provider.Params) { } func (s *Service) addProvider(prov provider.Provider) { + if !s.isValidProviderName(prov.Name()) { + return + } s.providers = append(s.providers, provider.NewService(prov)) s.authMiddleware.Providers = s.providers } +func (s *Service) isValidProviderName(name string) bool { + if strings.TrimSpace(name) == "" { + s.logger.Logf("[ERROR] provider has been ignored because its name is empty") + return false + } + + formatForbidden := func(name string) { + s.logger.Logf("[ERROR] provider has been ignored because its name contains forbidden characters: '%s'", name) + } + + path, err := url.PathUnescape(name) + if err != nil || path != name { + formatForbidden(name) + return false + } + if name != url.PathEscape(name) { + formatForbidden(name) + return false + } + // net/url package does not escape everything (https://github.com/golang/go/issues/5684) + // It is better to reject all reserved characters from https://datatracker.ietf.org/doc/html/rfc3986#section-2.2 + if regexp.MustCompile(`[:/?#\[\]@!$&'\(\)*+,;=]`).MatchString(name) { + formatForbidden(name) + return false + } + + return true +} + // AddProvider adds provider for given name func (s *Service) AddProvider(name, cid, csecret string) { p := provider.Params{ diff --git a/auth_test.go b/auth_test.go index 06a3b8aa..b7945b4f 100644 --- a/auth_test.go +++ b/auth_test.go @@ -246,6 +246,34 @@ func TestIntegrationList(t *testing.T) { assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b)) } +func TestIntegrationInvalidProviderNames(t *testing.T) { + invalidNames := []string{ + "provider/with/slashes", + "provider with spaces", + " providerWithSpacesAround\t", + "providerWithReserved-$-Char", + "providerWithReserved-&-Char", + "providerWithReserved-+-Char", + "providerWithReserved-,-Char", + "providerWithReserved-:-Char", + "providerWithReserved-;-Char", + "providerWithReserved-=-Char", + "providerWithReserved-?-Char", + "providerWithReserved-@-Char", + "providerWith%2F-EscapedSequence", + "", + } + svc, teardown := prepService(t, func(svc *Service) { + for _, name := range invalidNames { + svc.AddCustomProvider(name, Client{"cid", "csecret"}, provider.CustomHandlerOpt{}) + } + }) + defer teardown() + + require.Equal(t, 1, len(svc.Providers())) + require.Equal(t, "dev", svc.Providers()[0].Name()) +} + func TestIntegrationUserInfo(t *testing.T) { _, teardown := prepService(t) defer teardown() @@ -386,7 +414,7 @@ func TestDirectProvider(t *testing.T) { func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { _, teardown := prepService(t, func(svc *Service) { - svc.AddDirectProviderWithUserIDFunc("directCustom", + svc.AddDirectProviderWithUserIDFunc("direct_custom", provider.CredCheckerFunc(func(user, password string) (ok bool, err error) { return user == "dev_direct" && password == "password", nil }), @@ -401,12 +429,12 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { jar, err := cookiejar.New(nil) require.Nil(t, err) client := &http.Client{Jar: jar, Timeout: 5 * time.Second} - resp, err := client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=bad") + resp, err := client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=bad") require.Nil(t, err) defer resp.Body.Close() assert.Equal(t, 403, resp.StatusCode) - resp, err = client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=password") + resp, err = client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=password") require.Nil(t, err) defer resp.Body.Close() assert.Equal(t, 200, resp.StatusCode) @@ -416,7 +444,7 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { t.Logf("resp %s", string(body)) t.Logf("headers: %+v", resp.Header) - assert.Contains(t, string(body), `"name":"dev_direct","id":"directCustom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`) + assert.Contains(t, string(body), `"name":"dev_direct","id":"direct_custom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`) require.Equal(t, 2, len(resp.Cookies())) assert.Equal(t, "JWT", resp.Cookies()[0].Name) diff --git a/middleware/auth.go b/middleware/auth.go index 64d6c7ce..b44b90ef 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -129,7 +129,7 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler { } // check if user provider is allowed - if !a.isProviderAllowed(claims.User.ID) { + if !a.isProviderAllowed(&claims) { onError(h, w, r, fmt.Errorf("user %s/%s provider is not allowed", claims.User.Name, claims.User.ID)) a.JWTService.Reset(w) return @@ -153,13 +153,24 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler { return f } -// isProviderAllowed checks if user provider is allowed, user id looks like "provider_1234567890" -// this check is needed to reject users from providers what are used to be allowed but not anymore. +// isProviderAllowed checks if user provider is allowed. +// If provider name is explicitly set in the token claims, then that provider is checked. +// +// If user id looks like "provider_1234567890", +// then there is an attempt to extract provider name from that user ID. +// Note that such read can fail if user id has multiple "_" separator symbols. +// +// This check is needed to reject users from providers what are used to be allowed but not anymore. // Such users made token before the provider was disabled and should not be allowed to login anymore. -func (a *Authenticator) isProviderAllowed(userID string) bool { - userProvider := strings.Split(userID, "_")[0] +func (a *Authenticator) isProviderAllowed(claims *token.Claims) bool { + // TODO: remove this read when old tokens expire and all new tokens have a provider name in them + userIDProvider := strings.Split(claims.User.ID, "_")[0] for _, p := range a.Providers { - if p.Name() == userProvider { + name := p.Name() + if claims.AuthProvider != nil && claims.AuthProvider.Name == name { + return true + } + if name == userIDProvider { return true } } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 18208aee..302510d4 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -32,6 +32,8 @@ var testJwtNoUser = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI3ODkxOTE4Mj var testJwtWithRole = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9LCJyb2xlIjoiZW1wbG95ZWUifX0.o95raB0aNl2TWUs43Tu6xyX5Y3Fa5wv6_6RFJuN-d6g" +var testJwtValidWithAuthProvider = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fSwiYXV0aF9wcm92aWRlciI6eyJuYW1lIjoicHJvdmlkZXIxIn19.iBKM9-lgejJNjcs-crj6gkEejnIJpavmaq8alenf0JA" + func TestAuthJWTCookie(t *testing.T) { a := makeTestAuth(t) @@ -51,56 +53,51 @@ func TestAuthJWTCookie(t *testing.T) { client := &http.Client{Timeout: 5 * time.Second} expiration := int(365 * 24 * time.Hour.Seconds()) //nolint - t.Run("valid token", func(t *testing.T) { + makeRequest := func(jwtCookie string, xsrfToken string) *http.Response { req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") + req.AddCookie(&http.Cookie{ + Name: "JWT", + Value: jwtCookie, + HttpOnly: true, + Path: "/", + MaxAge: expiration, + Secure: false, + }) + req.Header.Add("X-XSRF-TOKEN", xsrfToken) resp, err := client.Do(req) require.NoError(t, err) + return resp + } + + t.Run("valid token", func(t *testing.T) { + resp := makeRequest(testJwtValid, "random id") assert.Equal(t, 201, resp.StatusCode, "valid token user") }) - t.Run("valid token, wrong provider", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValidWrongProvider, HttpOnly: true, Path: "/", - MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") + t.Run("valid token with auth_provider", func(t *testing.T) { + resp := makeRequest(testJwtValidWithAuthProvider, "random id") + assert.Equal(t, 201, resp.StatusCode, "valid token user") + }) - resp, err := client.Do(req) - require.NoError(t, err) + t.Run("valid token, wrong provider", func(t *testing.T) { + resp := makeRequest(testJwtValidWrongProvider, "random id") assert.Equal(t, 401, resp.StatusCode, "user name1/provider3_id1 provider is not allowed") }) t.Run("xsrf mismatch", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "wrong id") - resp, err := client.Do(req) - require.NoError(t, err) + resp := makeRequest(testJwtValid, "wrong id") assert.Equal(t, 401, resp.StatusCode, "xsrf mismatch") }) t.Run("token expired and refreshed", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtExpired, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") - resp, err := client.Do(req) - require.NoError(t, err) + resp := makeRequest(testJwtExpired, "random id") assert.Equal(t, 201, resp.StatusCode, "token expired and refreshed") }) t.Run("no user info in the token", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtNoUser, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") - resp, err := client.Do(req) - require.NoError(t, err) + resp := makeRequest(testJwtNoUser, "random id") assert.Equal(t, 401, resp.StatusCode, "no user info in the token") }) } diff --git a/provider/apple.go b/provider/apple.go index 2663cf11..ed73bd44 100644 --- a/provider/apple.go +++ b/provider/apple.go @@ -267,6 +267,9 @@ func (ah *AppleHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), NotBefore: time.Now().Add(-1 * time.Minute).Unix(), }, + AuthProvider: &token.AuthProvider{ + Name: ah.name, + }, } if _, err = ah.JwtService.Set(w, claims); err != nil { @@ -376,6 +379,9 @@ func (ah AppleHandler) AuthHandler(w http.ResponseWriter, r *http.Request) { Audience: oauthClaims.Audience, }, SessionOnly: false, + AuthProvider: &token.AuthProvider{ + Name: ah.name, + }, } if _, err = ah.JwtService.Set(w, claims); err != nil { diff --git a/provider/direct.go b/provider/direct.go index 742ebd5a..e1cef6c8 100644 --- a/provider/direct.go +++ b/provider/direct.go @@ -126,6 +126,9 @@ func (p DirectHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { Audience: creds.Audience, }, SessionOnly: sessOnly, + AuthProvider: &token.AuthProvider{ + Name: p.ProviderName, + }, } if _, err = p.TokenService.Set(w, claims); err != nil { diff --git a/provider/oauth1.go b/provider/oauth1.go index 4aec7f56..7af51e8a 100644 --- a/provider/oauth1.go +++ b/provider/oauth1.go @@ -61,6 +61,9 @@ func (h Oauth1Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), NotBefore: time.Now().Add(-1 * time.Minute).Unix(), }, + AuthProvider: &token.AuthProvider{ + Name: h.name, + }, } if _, err = h.JwtService.Set(w, claims); err != nil { @@ -146,6 +149,9 @@ func (h Oauth1Handler) AuthHandler(w http.ResponseWriter, r *http.Request) { Audience: oauthClaims.Audience, }, SessionOnly: oauthClaims.SessionOnly, + AuthProvider: &token.AuthProvider{ + Name: h.name, + }, } if _, err = h.JwtService.Set(w, claims); err != nil { diff --git a/provider/oauth2.go b/provider/oauth2.go index 6161357e..11fbd3e1 100644 --- a/provider/oauth2.go +++ b/provider/oauth2.go @@ -118,6 +118,9 @@ func (p Oauth2Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { NotBefore: time.Now().Add(-1 * time.Minute).Unix(), }, NoAva: r.URL.Query().Get("noava") == "1", + AuthProvider: &token.AuthProvider{ + Name: p.name, + }, } if _, err := p.JwtService.Set(w, claims); err != nil { @@ -215,6 +218,9 @@ func (p Oauth2Handler) AuthHandler(w http.ResponseWriter, r *http.Request) { }, SessionOnly: oauthClaims.SessionOnly, NoAva: oauthClaims.NoAva, + AuthProvider: &token.AuthProvider{ + Name: p.name, + }, } if _, err = p.JwtService.Set(w, claims); err != nil { diff --git a/provider/telegram.go b/provider/telegram.go index 177c1c02..5868048e 100644 --- a/provider/telegram.go +++ b/provider/telegram.go @@ -310,6 +310,9 @@ func (th *TelegramHandler) LoginHandler(w http.ResponseWriter, r *http.Request) NotBefore: time.Now().Add(-1 * time.Minute).Unix(), }, SessionOnly: false, // TODO review? + AuthProvider: &authtoken.AuthProvider{ + Name: th.Name(), + }, } if _, err := th.TokenService.Set(w, claims); err != nil { diff --git a/provider/verify.go b/provider/verify.go index 8b0a03dd..4575055e 100644 --- a/provider/verify.go +++ b/provider/verify.go @@ -117,6 +117,9 @@ func (e VerifyHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { Audience: confClaims.Audience, }, SessionOnly: sessOnly, + AuthProvider: &token.AuthProvider{ + Name: e.ProviderName, + }, } if _, err = e.TokenService.Set(w, claims); err != nil { @@ -152,6 +155,9 @@ func (e VerifyHandler) sendConfirmation(w http.ResponseWriter, r *http.Request) NotBefore: time.Now().Add(-1 * time.Minute).Unix(), Issuer: e.Issuer, }, + AuthProvider: &token.AuthProvider{ + Name: e.ProviderName, + }, } tkn, err := e.TokenService.Token(claims) diff --git a/token/jwt.go b/token/jwt.go index 73cc1c2c..e0396beb 100644 --- a/token/jwt.go +++ b/token/jwt.go @@ -21,10 +21,11 @@ type Service struct { // Claims stores user info for token and state & from from login type Claims struct { jwt.StandardClaims - User *User `json:"user,omitempty"` // user info - SessionOnly bool `json:"sess_only,omitempty"` - Handshake *Handshake `json:"handshake,omitempty"` // used for oauth handshake - NoAva bool `json:"no-ava,omitempty"` // disable avatar, always use identicon + User *User `json:"user,omitempty"` // user info + SessionOnly bool `json:"sess_only,omitempty"` + Handshake *Handshake `json:"handshake,omitempty"` // used for oauth handshake + NoAva bool `json:"no-ava,omitempty"` // disable avatar, always use identicon + AuthProvider *AuthProvider `json:"auth_provider,omitempty"` // auth provider info } // Handshake used for oauth handshake @@ -34,6 +35,11 @@ type Handshake struct { ID string `json:"id,omitempty"` } +// AuthProvider stores attributes of provider which has created a JWT token +type AuthProvider struct { + Name string `json:"name,omitempty"` +} + const ( // default names for cookies and headers defaultJWTCookieName = "JWT" diff --git a/v2/auth.go b/v2/auth.go index 30d4a613..23ff69d8 100644 --- a/v2/auth.go +++ b/v2/auth.go @@ -4,6 +4,8 @@ package auth import ( "fmt" "net/http" + "net/url" + "regexp" "strings" "time" @@ -267,10 +269,42 @@ func (s *Service) addProviderByName(name string, p provider.Params) { } func (s *Service) addProvider(prov provider.Provider) { + if !s.isValidProviderName(prov.Name()) { + return + } s.providers = append(s.providers, provider.NewService(prov)) s.authMiddleware.Providers = s.providers } +func (s *Service) isValidProviderName(name string) bool { + if strings.TrimSpace(name) == "" { + s.logger.Logf("[ERROR] provider has been ignored because its name is empty") + return false + } + + formatForbidden := func(name string) { + s.logger.Logf("[ERROR] provider has been ignored because its name contains forbidden characters: '%s'", name) + } + + path, err := url.PathUnescape(name) + if err != nil || path != name { + formatForbidden(name) + return false + } + if name != url.PathEscape(name) { + formatForbidden(name) + return false + } + // net/url package does not escape everything (https://github.com/golang/go/issues/5684) + // It is better to reject all reserved characters from https://datatracker.ietf.org/doc/html/rfc3986#section-2.2 + if regexp.MustCompile(`[:/?#\[\]@!$&'\(\)*+,;=]`).MatchString(name) { + formatForbidden(name) + return false + } + + return true +} + // AddProvider adds provider for given name func (s *Service) AddProvider(name, cid, csecret string) { p := provider.Params{ diff --git a/v2/auth_test.go b/v2/auth_test.go index 81655bb6..b8dbf724 100644 --- a/v2/auth_test.go +++ b/v2/auth_test.go @@ -246,6 +246,34 @@ func TestIntegrationList(t *testing.T) { assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b)) } +func TestIntegrationInvalidProviderNames(t *testing.T) { + invalidNames := []string{ + "provider/with/slashes", + "provider with spaces", + " providerWithSpacesAround\t", + "providerWithReserved-$-Char", + "providerWithReserved-&-Char", + "providerWithReserved-+-Char", + "providerWithReserved-,-Char", + "providerWithReserved-:-Char", + "providerWithReserved-;-Char", + "providerWithReserved-=-Char", + "providerWithReserved-?-Char", + "providerWithReserved-@-Char", + "providerWith%2F-EscapedSequence", + "", + } + svc, teardown := prepService(t, func(svc *Service) { + for _, name := range invalidNames { + svc.AddCustomProvider(name, Client{"cid", "csecret"}, provider.CustomHandlerOpt{}) + } + }) + defer teardown() + + require.Equal(t, 1, len(svc.Providers())) + require.Equal(t, "dev", svc.Providers()[0].Name()) +} + func TestIntegrationUserInfo(t *testing.T) { _, teardown := prepService(t) defer teardown() @@ -386,7 +414,7 @@ func TestDirectProvider(t *testing.T) { func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { _, teardown := prepService(t, func(svc *Service) { - svc.AddDirectProviderWithUserIDFunc("directCustom", + svc.AddDirectProviderWithUserIDFunc("direct_custom", provider.CredCheckerFunc(func(user, password string) (ok bool, err error) { return user == "dev_direct" && password == "password", nil }), @@ -401,12 +429,12 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { jar, err := cookiejar.New(nil) require.Nil(t, err) client := &http.Client{Jar: jar, Timeout: 5 * time.Second} - resp, err := client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=bad") + resp, err := client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=bad") require.Nil(t, err) defer resp.Body.Close() assert.Equal(t, 403, resp.StatusCode) - resp, err = client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=password") + resp, err = client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=password") require.Nil(t, err) defer resp.Body.Close() assert.Equal(t, 200, resp.StatusCode) @@ -416,7 +444,7 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { t.Logf("resp %s", string(body)) t.Logf("headers: %+v", resp.Header) - assert.Contains(t, string(body), `"name":"dev_direct","id":"directCustom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`) + assert.Contains(t, string(body), `"name":"dev_direct","id":"direct_custom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`) require.Equal(t, 2, len(resp.Cookies())) assert.Equal(t, "JWT", resp.Cookies()[0].Name) diff --git a/v2/middleware/auth.go b/v2/middleware/auth.go index 6e8c5ed8..6aafa033 100644 --- a/v2/middleware/auth.go +++ b/v2/middleware/auth.go @@ -129,7 +129,7 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler { } // check if user provider is allowed - if !a.isProviderAllowed(claims.User.ID) { + if !a.isProviderAllowed(&claims) { onError(h, w, r, fmt.Errorf("user %s/%s provider is not allowed", claims.User.Name, claims.User.ID)) a.JWTService.Reset(w) return @@ -153,13 +153,24 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler { return f } -// isProviderAllowed checks if user provider is allowed, user id looks like "provider_1234567890" -// this check is needed to reject users from providers what are used to be allowed but not anymore. +// isProviderAllowed checks if user provider is allowed. +// If provider name is explicitly set in the token claims, then that provider is checked. +// +// If user id looks like "provider_1234567890", +// then there is an attempt to extract provider name from that user ID. +// Note that such read can fail if user id has multiple "_" separator symbols. +// +// This check is needed to reject users from providers what are used to be allowed but not anymore. // Such users made token before the provider was disabled and should not be allowed to login anymore. -func (a *Authenticator) isProviderAllowed(userID string) bool { - userProvider := strings.Split(userID, "_")[0] +func (a *Authenticator) isProviderAllowed(claims *token.Claims) bool { + // TODO: remove this read when old tokens expire and all new tokens have a provider name in them + userIDProvider := strings.Split(claims.User.ID, "_")[0] for _, p := range a.Providers { - if p.Name() == userProvider { + name := p.Name() + if claims.AuthProvider != nil && claims.AuthProvider.Name == name { + return true + } + if name == userIDProvider { return true } } diff --git a/v2/middleware/auth_test.go b/v2/middleware/auth_test.go index d71e0c9f..9ce53229 100644 --- a/v2/middleware/auth_test.go +++ b/v2/middleware/auth_test.go @@ -32,6 +32,8 @@ var testJwtNoUser = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI3ODkxOTE4Mj var testJwtWithRole = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9LCJyb2xlIjoiZW1wbG95ZWUifX0.o95raB0aNl2TWUs43Tu6xyX5Y3Fa5wv6_6RFJuN-d6g" +var testJwtValidWithAuthProvider = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fSwiYXV0aF9wcm92aWRlciI6eyJuYW1lIjoicHJvdmlkZXIxIn19.iBKM9-lgejJNjcs-crj6gkEejnIJpavmaq8alenf0JA" + func TestAuthJWTCookie(t *testing.T) { a := makeTestAuth(t) @@ -51,56 +53,51 @@ func TestAuthJWTCookie(t *testing.T) { client := &http.Client{Timeout: 5 * time.Second} expiration := int(365 * 24 * time.Hour.Seconds()) //nolint - t.Run("valid token", func(t *testing.T) { + makeRequest := func(jwtCookie string, xsrfToken string) *http.Response { req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") + req.AddCookie(&http.Cookie{ + Name: "JWT", + Value: jwtCookie, + HttpOnly: true, + Path: "/", + MaxAge: expiration, + Secure: false, + }) + req.Header.Add("X-XSRF-TOKEN", xsrfToken) resp, err := client.Do(req) require.NoError(t, err) + return resp + } + + t.Run("valid token", func(t *testing.T) { + resp := makeRequest(testJwtValid, "random id") assert.Equal(t, 201, resp.StatusCode, "valid token user") }) - t.Run("valid token, wrong provider", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValidWrongProvider, HttpOnly: true, Path: "/", - MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") + t.Run("valid token with auth_provider", func(t *testing.T) { + resp := makeRequest(testJwtValidWithAuthProvider, "random id") + assert.Equal(t, 201, resp.StatusCode, "valid token user") + }) - resp, err := client.Do(req) - require.NoError(t, err) + t.Run("valid token, wrong provider", func(t *testing.T) { + resp := makeRequest(testJwtValidWrongProvider, "random id") assert.Equal(t, 401, resp.StatusCode, "user name1/provider3_id1 provider is not allowed") }) t.Run("xsrf mismatch", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "wrong id") - resp, err := client.Do(req) - require.NoError(t, err) + resp := makeRequest(testJwtValid, "wrong id") assert.Equal(t, 401, resp.StatusCode, "xsrf mismatch") }) t.Run("token expired and refreshed", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtExpired, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") - resp, err := client.Do(req) - require.NoError(t, err) + resp := makeRequest(testJwtExpired, "random id") assert.Equal(t, 201, resp.StatusCode, "token expired and refreshed") }) t.Run("no user info in the token", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtNoUser, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") - resp, err := client.Do(req) - require.NoError(t, err) + resp := makeRequest(testJwtNoUser, "random id") assert.Equal(t, 401, resp.StatusCode, "no user info in the token") }) } diff --git a/v2/provider/apple.go b/v2/provider/apple.go index 4a2fbcd0..97005833 100644 --- a/v2/provider/apple.go +++ b/v2/provider/apple.go @@ -267,6 +267,9 @@ func (ah *AppleHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { ExpiresAt: jwt.NewNumericDate(time.Now().Add(30 * time.Minute)), NotBefore: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), }, + AuthProvider: &token.AuthProvider{ + Name: ah.name, + }, } if _, err = ah.JwtService.Set(w, claims); err != nil { @@ -376,6 +379,9 @@ func (ah AppleHandler) AuthHandler(w http.ResponseWriter, r *http.Request) { Audience: oauthClaims.Audience, }, SessionOnly: false, + AuthProvider: &token.AuthProvider{ + Name: ah.name, + }, } if _, err = ah.JwtService.Set(w, claims); err != nil { diff --git a/v2/provider/direct.go b/v2/provider/direct.go index 7d940177..9844989c 100644 --- a/v2/provider/direct.go +++ b/v2/provider/direct.go @@ -126,6 +126,9 @@ func (p DirectHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { Audience: []string{creds.Audience}, }, SessionOnly: sessOnly, + AuthProvider: &token.AuthProvider{ + Name: p.ProviderName, + }, } if _, err = p.TokenService.Set(w, claims); err != nil { diff --git a/v2/provider/oauth1.go b/v2/provider/oauth1.go index 6a51b0da..8484a4f0 100644 --- a/v2/provider/oauth1.go +++ b/v2/provider/oauth1.go @@ -61,6 +61,9 @@ func (h Oauth1Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { ExpiresAt: jwt.NewNumericDate(time.Now().Add(30 * time.Minute)), NotBefore: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), }, + AuthProvider: &token.AuthProvider{ + Name: h.name, + }, } if _, err = h.JwtService.Set(w, claims); err != nil { @@ -146,6 +149,9 @@ func (h Oauth1Handler) AuthHandler(w http.ResponseWriter, r *http.Request) { Audience: oauthClaims.Audience, }, SessionOnly: oauthClaims.SessionOnly, + AuthProvider: &token.AuthProvider{ + Name: h.name, + }, } if _, err = h.JwtService.Set(w, claims); err != nil { diff --git a/v2/provider/oauth2.go b/v2/provider/oauth2.go index a985cd98..273a1906 100644 --- a/v2/provider/oauth2.go +++ b/v2/provider/oauth2.go @@ -118,6 +118,9 @@ func (p Oauth2Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { NotBefore: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), }, NoAva: r.URL.Query().Get("noava") == "1", + AuthProvider: &token.AuthProvider{ + Name: p.name, + }, } if _, err := p.JwtService.Set(w, claims); err != nil { @@ -215,6 +218,9 @@ func (p Oauth2Handler) AuthHandler(w http.ResponseWriter, r *http.Request) { }, SessionOnly: oauthClaims.SessionOnly, NoAva: oauthClaims.NoAva, + AuthProvider: &token.AuthProvider{ + Name: p.name, + }, } if _, err = p.JwtService.Set(w, claims); err != nil { diff --git a/v2/provider/telegram.go b/v2/provider/telegram.go index 906fe59f..5b884eb9 100644 --- a/v2/provider/telegram.go +++ b/v2/provider/telegram.go @@ -310,6 +310,9 @@ func (th *TelegramHandler) LoginHandler(w http.ResponseWriter, r *http.Request) NotBefore: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), }, SessionOnly: false, // TODO review? + AuthProvider: &authtoken.AuthProvider{ + Name: th.ProviderName, + }, } if _, err := th.TokenService.Set(w, claims); err != nil { diff --git a/v2/provider/verify.go b/v2/provider/verify.go index 0fc9ca6e..1c459532 100644 --- a/v2/provider/verify.go +++ b/v2/provider/verify.go @@ -117,6 +117,9 @@ func (e VerifyHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { Audience: confClaims.Audience, }, SessionOnly: sessOnly, + AuthProvider: &token.AuthProvider{ + Name: e.ProviderName, + }, } if _, err = e.TokenService.Set(w, claims); err != nil { @@ -152,6 +155,9 @@ func (e VerifyHandler) sendConfirmation(w http.ResponseWriter, r *http.Request) NotBefore: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), Issuer: e.Issuer, }, + AuthProvider: &token.AuthProvider{ + Name: e.ProviderName, + }, } tkn, err := e.TokenService.Token(claims) diff --git a/v2/token/jwt.go b/v2/token/jwt.go index 8350ae85..352fcb3a 100644 --- a/v2/token/jwt.go +++ b/v2/token/jwt.go @@ -23,10 +23,11 @@ type Service struct { // Claims stores user info for token and state & from from login type Claims struct { jwt.RegisteredClaims - User *User `json:"user,omitempty"` // user info - SessionOnly bool `json:"sess_only,omitempty"` - Handshake *Handshake `json:"handshake,omitempty"` // used for oauth handshake - NoAva bool `json:"no-ava,omitempty"` // disable avatar, always use identicon + User *User `json:"user,omitempty"` // user info + SessionOnly bool `json:"sess_only,omitempty"` + Handshake *Handshake `json:"handshake,omitempty"` // used for oauth handshake + NoAva bool `json:"no-ava,omitempty"` // disable avatar, always use identicon + AuthProvider *AuthProvider `json:"auth_provider,omitempty"` // auth provider info } // Handshake used for oauth handshake @@ -36,6 +37,11 @@ type Handshake struct { ID string `json:"id,omitempty"` } +// AuthProvider stores attributes of provider which has created a JWT token +type AuthProvider struct { + Name string `json:"name,omitempty"` +} + const ( // default names for cookies and headers defaultJWTCookieName = "JWT"