diff --git a/README.md b/README.md index 9d07ff2..2ddb757 100644 --- a/README.md +++ b/README.md @@ -532,7 +532,7 @@ Follow to next steps for configuring on the Apple side: After completing the previous steps, you can proceed with configuring the Apple auth provider. Here are the parameters for AppleConfig: -- _ClientID_ (**required**) - Service ID identifier which is used for Sign with Apple +- _ClientID_ (**required**) - Service ID (or App ID) which is used for Sign with Apple - _TeamID_ (**required**) - Identifier a developer account (use as prefix for all App ID) - _KeyID_ (**required**) - Identifier a generated key for Sign with Apple - _ResponseMode_ - Response Mode, please see [documentation](https://developer.apple.com/documentation/sign_in_with_apple/request_an_authorization_to_the_sign_in_with_apple_server?changes=_1_2#4066168) for reference, default is `form_post` @@ -541,7 +541,7 @@ After completing the previous steps, you can proceed with configuring the Apple // apple config parameters appleCfg := provider.AppleConfig{ TeamID: os.Getenv("AEXMPL_APPLE_TID"), // developer account identifier - ClientID: os.Getenv("AEXMPL_APPLE_CID"), // service identifier + ClientID: os.Getenv("AEXMPL_APPLE_CID"), // Service ID (or App ID) KeyID: os.Getenv("AEXMPL_APPLE_KEYID"), // private key identifier } ``` diff --git a/auth.go b/auth.go index 5bd9d87..7a6f207 100644 --- a/auth.go +++ b/auth.go @@ -4,6 +4,8 @@ package auth import ( "fmt" "net/http" + "net/url" + "regexp" "strings" "time" @@ -267,10 +269,42 @@ func (s *Service) addProviderByName(name string, p provider.Params) { } func (s *Service) addProvider(prov provider.Provider) { + if !s.isValidProviderName(prov.Name()) { + return + } s.providers = append(s.providers, provider.NewService(prov)) s.authMiddleware.Providers = s.providers } +func (s *Service) isValidProviderName(name string) bool { + if strings.TrimSpace(name) == "" { + s.logger.Logf("[ERROR] provider has been ignored because its name is empty") + return false + } + + formatForbidden := func(name string) { + s.logger.Logf("[ERROR] provider has been ignored because its name contains forbidden characters: '%s'", name) + } + + path, err := url.PathUnescape(name) + if err != nil || path != name { + formatForbidden(name) + return false + } + if name != url.PathEscape(name) { + formatForbidden(name) + return false + } + // net/url package does not escape everything (https://github.com/golang/go/issues/5684) + // It is better to reject all reserved characters from https://datatracker.ietf.org/doc/html/rfc3986#section-2.2 + if regexp.MustCompile(`[:/?#\[\]@!$&'\(\)*+,;=]`).MatchString(name) { + formatForbidden(name) + return false + } + + return true +} + // AddProvider adds provider for given name func (s *Service) AddProvider(name, cid, csecret string) { p := provider.Params{ diff --git a/auth_test.go b/auth_test.go index 06a3b8a..b7945b4 100644 --- a/auth_test.go +++ b/auth_test.go @@ -246,6 +246,34 @@ func TestIntegrationList(t *testing.T) { assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b)) } +func TestIntegrationInvalidProviderNames(t *testing.T) { + invalidNames := []string{ + "provider/with/slashes", + "provider with spaces", + " providerWithSpacesAround\t", + "providerWithReserved-$-Char", + "providerWithReserved-&-Char", + "providerWithReserved-+-Char", + "providerWithReserved-,-Char", + "providerWithReserved-:-Char", + "providerWithReserved-;-Char", + "providerWithReserved-=-Char", + "providerWithReserved-?-Char", + "providerWithReserved-@-Char", + "providerWith%2F-EscapedSequence", + "", + } + svc, teardown := prepService(t, func(svc *Service) { + for _, name := range invalidNames { + svc.AddCustomProvider(name, Client{"cid", "csecret"}, provider.CustomHandlerOpt{}) + } + }) + defer teardown() + + require.Equal(t, 1, len(svc.Providers())) + require.Equal(t, "dev", svc.Providers()[0].Name()) +} + func TestIntegrationUserInfo(t *testing.T) { _, teardown := prepService(t) defer teardown() @@ -386,7 +414,7 @@ func TestDirectProvider(t *testing.T) { func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { _, teardown := prepService(t, func(svc *Service) { - svc.AddDirectProviderWithUserIDFunc("directCustom", + svc.AddDirectProviderWithUserIDFunc("direct_custom", provider.CredCheckerFunc(func(user, password string) (ok bool, err error) { return user == "dev_direct" && password == "password", nil }), @@ -401,12 +429,12 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { jar, err := cookiejar.New(nil) require.Nil(t, err) client := &http.Client{Jar: jar, Timeout: 5 * time.Second} - resp, err := client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=bad") + resp, err := client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=bad") require.Nil(t, err) defer resp.Body.Close() assert.Equal(t, 403, resp.StatusCode) - resp, err = client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=password") + resp, err = client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=password") require.Nil(t, err) defer resp.Body.Close() assert.Equal(t, 200, resp.StatusCode) @@ -416,7 +444,7 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { t.Logf("resp %s", string(body)) t.Logf("headers: %+v", resp.Header) - assert.Contains(t, string(body), `"name":"dev_direct","id":"directCustom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`) + assert.Contains(t, string(body), `"name":"dev_direct","id":"direct_custom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`) require.Equal(t, 2, len(resp.Cookies())) assert.Equal(t, "JWT", resp.Cookies()[0].Name) diff --git a/middleware/auth.go b/middleware/auth.go index 64d6c7c..b44b90e 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -129,7 +129,7 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler { } // check if user provider is allowed - if !a.isProviderAllowed(claims.User.ID) { + if !a.isProviderAllowed(&claims) { onError(h, w, r, fmt.Errorf("user %s/%s provider is not allowed", claims.User.Name, claims.User.ID)) a.JWTService.Reset(w) return @@ -153,13 +153,24 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler { return f } -// isProviderAllowed checks if user provider is allowed, user id looks like "provider_1234567890" -// this check is needed to reject users from providers what are used to be allowed but not anymore. +// isProviderAllowed checks if user provider is allowed. +// If provider name is explicitly set in the token claims, then that provider is checked. +// +// If user id looks like "provider_1234567890", +// then there is an attempt to extract provider name from that user ID. +// Note that such read can fail if user id has multiple "_" separator symbols. +// +// This check is needed to reject users from providers what are used to be allowed but not anymore. // Such users made token before the provider was disabled and should not be allowed to login anymore. -func (a *Authenticator) isProviderAllowed(userID string) bool { - userProvider := strings.Split(userID, "_")[0] +func (a *Authenticator) isProviderAllowed(claims *token.Claims) bool { + // TODO: remove this read when old tokens expire and all new tokens have a provider name in them + userIDProvider := strings.Split(claims.User.ID, "_")[0] for _, p := range a.Providers { - if p.Name() == userProvider { + name := p.Name() + if claims.AuthProvider != nil && claims.AuthProvider.Name == name { + return true + } + if name == userIDProvider { return true } } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 18208ae..302510d 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -32,6 +32,8 @@ var testJwtNoUser = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI3ODkxOTE4Mj var testJwtWithRole = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9LCJyb2xlIjoiZW1wbG95ZWUifX0.o95raB0aNl2TWUs43Tu6xyX5Y3Fa5wv6_6RFJuN-d6g" +var testJwtValidWithAuthProvider = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fSwiYXV0aF9wcm92aWRlciI6eyJuYW1lIjoicHJvdmlkZXIxIn19.iBKM9-lgejJNjcs-crj6gkEejnIJpavmaq8alenf0JA" + func TestAuthJWTCookie(t *testing.T) { a := makeTestAuth(t) @@ -51,56 +53,51 @@ func TestAuthJWTCookie(t *testing.T) { client := &http.Client{Timeout: 5 * time.Second} expiration := int(365 * 24 * time.Hour.Seconds()) //nolint - t.Run("valid token", func(t *testing.T) { + makeRequest := func(jwtCookie string, xsrfToken string) *http.Response { req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") + req.AddCookie(&http.Cookie{ + Name: "JWT", + Value: jwtCookie, + HttpOnly: true, + Path: "/", + MaxAge: expiration, + Secure: false, + }) + req.Header.Add("X-XSRF-TOKEN", xsrfToken) resp, err := client.Do(req) require.NoError(t, err) + return resp + } + + t.Run("valid token", func(t *testing.T) { + resp := makeRequest(testJwtValid, "random id") assert.Equal(t, 201, resp.StatusCode, "valid token user") }) - t.Run("valid token, wrong provider", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValidWrongProvider, HttpOnly: true, Path: "/", - MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") + t.Run("valid token with auth_provider", func(t *testing.T) { + resp := makeRequest(testJwtValidWithAuthProvider, "random id") + assert.Equal(t, 201, resp.StatusCode, "valid token user") + }) - resp, err := client.Do(req) - require.NoError(t, err) + t.Run("valid token, wrong provider", func(t *testing.T) { + resp := makeRequest(testJwtValidWrongProvider, "random id") assert.Equal(t, 401, resp.StatusCode, "user name1/provider3_id1 provider is not allowed") }) t.Run("xsrf mismatch", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "wrong id") - resp, err := client.Do(req) - require.NoError(t, err) + resp := makeRequest(testJwtValid, "wrong id") assert.Equal(t, 401, resp.StatusCode, "xsrf mismatch") }) t.Run("token expired and refreshed", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtExpired, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") - resp, err := client.Do(req) - require.NoError(t, err) + resp := makeRequest(testJwtExpired, "random id") assert.Equal(t, 201, resp.StatusCode, "token expired and refreshed") }) t.Run("no user info in the token", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtNoUser, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") - resp, err := client.Do(req) - require.NoError(t, err) + resp := makeRequest(testJwtNoUser, "random id") assert.Equal(t, 401, resp.StatusCode, "no user info in the token") }) } diff --git a/provider/apple.go b/provider/apple.go index 2663cf1..ed73bd4 100644 --- a/provider/apple.go +++ b/provider/apple.go @@ -267,6 +267,9 @@ func (ah *AppleHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), NotBefore: time.Now().Add(-1 * time.Minute).Unix(), }, + AuthProvider: &token.AuthProvider{ + Name: ah.name, + }, } if _, err = ah.JwtService.Set(w, claims); err != nil { @@ -376,6 +379,9 @@ func (ah AppleHandler) AuthHandler(w http.ResponseWriter, r *http.Request) { Audience: oauthClaims.Audience, }, SessionOnly: false, + AuthProvider: &token.AuthProvider{ + Name: ah.name, + }, } if _, err = ah.JwtService.Set(w, claims); err != nil { diff --git a/provider/apple_pubkeys.go b/provider/apple_pubkeys.go index ce0ccde..5c3563e 100644 --- a/provider/apple_pubkeys.go +++ b/provider/apple_pubkeys.go @@ -151,7 +151,7 @@ type appleKeySet struct { // get return Apple public key with specific KeyID (kid) func (aks *appleKeySet) get(kid string) (keys *applePublicKey, err error) { - if aks.keys == nil || len(aks.keys) == 0 { + if len(aks.keys) == 0 { return nil, fmt.Errorf("failed to get key in appleKeySet, key set is nil or empty") } diff --git a/provider/apple_test.go b/provider/apple_test.go index f448ae7..51994a8 100644 --- a/provider/apple_test.go +++ b/provider/apple_test.go @@ -5,11 +5,11 @@ import ( "crypto/rsa" "crypto/sha1" "encoding/base64" - "encoding/binary" "encoding/json" "fmt" "io" "log" + "math/big" "net/http" "net/http/cookiejar" "net/url" @@ -659,8 +659,8 @@ ODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy n := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(publicKey.N.Bytes()) // convert exponent - eBuff := make([]byte, 4) - binary.LittleEndian.PutUint32(eBuff, uint32(publicKey.E)) + require.Positive(t, publicKey.E, "RSA exponent must be positive") + eBuff := big.NewInt(int64(publicKey.E)).Bytes() e := base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString(eBuff) JWK := struct { @@ -670,7 +670,7 @@ ODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy Kid string `json:"kid"` E string `json:"e"` N string `json:"n"` - }{Alg: "RS256", Kty: "RSA", Use: "sig", Kid: "112233", N: n, E: e[:4]} + }{Alg: "RS256", Kty: "RSA", Use: "sig", Kid: "112233", N: n, E: e} var buffJwk []byte buffJwk, err = json.Marshal(JWK) diff --git a/provider/direct.go b/provider/direct.go index 742ebd5..e1cef6c 100644 --- a/provider/direct.go +++ b/provider/direct.go @@ -126,6 +126,9 @@ func (p DirectHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { Audience: creds.Audience, }, SessionOnly: sessOnly, + AuthProvider: &token.AuthProvider{ + Name: p.ProviderName, + }, } if _, err = p.TokenService.Set(w, claims); err != nil { diff --git a/provider/oauth1.go b/provider/oauth1.go index 4aec7f5..7af51e8 100644 --- a/provider/oauth1.go +++ b/provider/oauth1.go @@ -61,6 +61,9 @@ func (h Oauth1Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), NotBefore: time.Now().Add(-1 * time.Minute).Unix(), }, + AuthProvider: &token.AuthProvider{ + Name: h.name, + }, } if _, err = h.JwtService.Set(w, claims); err != nil { @@ -146,6 +149,9 @@ func (h Oauth1Handler) AuthHandler(w http.ResponseWriter, r *http.Request) { Audience: oauthClaims.Audience, }, SessionOnly: oauthClaims.SessionOnly, + AuthProvider: &token.AuthProvider{ + Name: h.name, + }, } if _, err = h.JwtService.Set(w, claims); err != nil { diff --git a/provider/oauth2.go b/provider/oauth2.go index 6161357..11fbd3e 100644 --- a/provider/oauth2.go +++ b/provider/oauth2.go @@ -118,6 +118,9 @@ func (p Oauth2Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { NotBefore: time.Now().Add(-1 * time.Minute).Unix(), }, NoAva: r.URL.Query().Get("noava") == "1", + AuthProvider: &token.AuthProvider{ + Name: p.name, + }, } if _, err := p.JwtService.Set(w, claims); err != nil { @@ -215,6 +218,9 @@ func (p Oauth2Handler) AuthHandler(w http.ResponseWriter, r *http.Request) { }, SessionOnly: oauthClaims.SessionOnly, NoAva: oauthClaims.NoAva, + AuthProvider: &token.AuthProvider{ + Name: p.name, + }, } if _, err = p.JwtService.Set(w, claims); err != nil { diff --git a/provider/telegram.go b/provider/telegram.go index 177c1c0..80031db 100644 --- a/provider/telegram.go +++ b/provider/telegram.go @@ -310,6 +310,9 @@ func (th *TelegramHandler) LoginHandler(w http.ResponseWriter, r *http.Request) NotBefore: time.Now().Add(-1 * time.Minute).Unix(), }, SessionOnly: false, // TODO review? + AuthProvider: &authtoken.AuthProvider{ + Name: th.ProviderName, + }, } if _, err := th.TokenService.Set(w, claims); err != nil { diff --git a/provider/telegram_test.go b/provider/telegram_test.go index 12afe07..f56839b 100644 --- a/provider/telegram_test.go +++ b/provider/telegram_test.go @@ -89,22 +89,37 @@ func TestTelegramUnconfirmedRequest(t *testing.T) { func TestTelegramConfirmedRequest(t *testing.T) { var servedToken string - var mu sync.Mutex + + // is set when token becomes used, + // no sync is required because only a single goroutine in TelegramHandler.Run() reads and writes it + var tokenAlreadyUsed bool + + var wgToken sync.WaitGroup + wgToken.Add(1) + defer func() { + if t.Failed() && servedToken == "" { + wgToken.Done() // for the case when test fails before token is generated + } + }() m := &TelegramAPIMock{ GetUpdatesFunc: func(ctx context.Context) (*telegramUpdate, error) { - var upd telegramUpdate + wgToken.Wait() - mu.Lock() - defer mu.Unlock() - if servedToken != "" { - resp := fmt.Sprintf(getUpdatesResp, servedToken) + if tokenAlreadyUsed || t.Failed() { + return nil, fmt.Errorf("token %s has been already used", servedToken) + } - err := json.Unmarshal([]byte(resp), &upd) - if err != nil { - t.Fatal(err) - } + var upd telegramUpdate + resp := fmt.Sprintf(getUpdatesResp, servedToken) + err := json.Unmarshal([]byte(resp), &upd) + if err != nil { + t.Fatal(err) } + + // token is served only once + tokenAlreadyUsed = true + return &upd, nil }, AvatarFunc: func(ctx context.Context, userID int) (string, error) { @@ -147,10 +162,10 @@ func TestTelegramConfirmedRequest(t *testing.T) { err := json.Unmarshal(w.Body.Bytes(), &resp) assert.NoError(t, err) assert.Equal(t, "my_auth_bot", resp.Bot) + assert.NotEmpty(t, resp.Token) - mu.Lock() servedToken = resp.Token - mu.Unlock() + wgToken.Done() // Check the token confirmation assert.Eventually(t, func() bool { diff --git a/provider/verify.go b/provider/verify.go index 8b0a03d..4575055 100644 --- a/provider/verify.go +++ b/provider/verify.go @@ -117,6 +117,9 @@ func (e VerifyHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { Audience: confClaims.Audience, }, SessionOnly: sessOnly, + AuthProvider: &token.AuthProvider{ + Name: e.ProviderName, + }, } if _, err = e.TokenService.Set(w, claims); err != nil { @@ -152,6 +155,9 @@ func (e VerifyHandler) sendConfirmation(w http.ResponseWriter, r *http.Request) NotBefore: time.Now().Add(-1 * time.Minute).Unix(), Issuer: e.Issuer, }, + AuthProvider: &token.AuthProvider{ + Name: e.ProviderName, + }, } tkn, err := e.TokenService.Token(claims) diff --git a/token/jwt.go b/token/jwt.go index 73cc1c2..e0396be 100644 --- a/token/jwt.go +++ b/token/jwt.go @@ -21,10 +21,11 @@ type Service struct { // Claims stores user info for token and state & from from login type Claims struct { jwt.StandardClaims - User *User `json:"user,omitempty"` // user info - SessionOnly bool `json:"sess_only,omitempty"` - Handshake *Handshake `json:"handshake,omitempty"` // used for oauth handshake - NoAva bool `json:"no-ava,omitempty"` // disable avatar, always use identicon + User *User `json:"user,omitempty"` // user info + SessionOnly bool `json:"sess_only,omitempty"` + Handshake *Handshake `json:"handshake,omitempty"` // used for oauth handshake + NoAva bool `json:"no-ava,omitempty"` // disable avatar, always use identicon + AuthProvider *AuthProvider `json:"auth_provider,omitempty"` // auth provider info } // Handshake used for oauth handshake @@ -34,6 +35,11 @@ type Handshake struct { ID string `json:"id,omitempty"` } +// AuthProvider stores attributes of provider which has created a JWT token +type AuthProvider struct { + Name string `json:"name,omitempty"` +} + const ( // default names for cookies and headers defaultJWTCookieName = "JWT" diff --git a/v2/auth.go b/v2/auth.go index 30d4a61..23ff69d 100644 --- a/v2/auth.go +++ b/v2/auth.go @@ -4,6 +4,8 @@ package auth import ( "fmt" "net/http" + "net/url" + "regexp" "strings" "time" @@ -267,10 +269,42 @@ func (s *Service) addProviderByName(name string, p provider.Params) { } func (s *Service) addProvider(prov provider.Provider) { + if !s.isValidProviderName(prov.Name()) { + return + } s.providers = append(s.providers, provider.NewService(prov)) s.authMiddleware.Providers = s.providers } +func (s *Service) isValidProviderName(name string) bool { + if strings.TrimSpace(name) == "" { + s.logger.Logf("[ERROR] provider has been ignored because its name is empty") + return false + } + + formatForbidden := func(name string) { + s.logger.Logf("[ERROR] provider has been ignored because its name contains forbidden characters: '%s'", name) + } + + path, err := url.PathUnescape(name) + if err != nil || path != name { + formatForbidden(name) + return false + } + if name != url.PathEscape(name) { + formatForbidden(name) + return false + } + // net/url package does not escape everything (https://github.com/golang/go/issues/5684) + // It is better to reject all reserved characters from https://datatracker.ietf.org/doc/html/rfc3986#section-2.2 + if regexp.MustCompile(`[:/?#\[\]@!$&'\(\)*+,;=]`).MatchString(name) { + formatForbidden(name) + return false + } + + return true +} + // AddProvider adds provider for given name func (s *Service) AddProvider(name, cid, csecret string) { p := provider.Params{ diff --git a/v2/auth_test.go b/v2/auth_test.go index 81655bb..b8dbf72 100644 --- a/v2/auth_test.go +++ b/v2/auth_test.go @@ -246,6 +246,34 @@ func TestIntegrationList(t *testing.T) { assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b)) } +func TestIntegrationInvalidProviderNames(t *testing.T) { + invalidNames := []string{ + "provider/with/slashes", + "provider with spaces", + " providerWithSpacesAround\t", + "providerWithReserved-$-Char", + "providerWithReserved-&-Char", + "providerWithReserved-+-Char", + "providerWithReserved-,-Char", + "providerWithReserved-:-Char", + "providerWithReserved-;-Char", + "providerWithReserved-=-Char", + "providerWithReserved-?-Char", + "providerWithReserved-@-Char", + "providerWith%2F-EscapedSequence", + "", + } + svc, teardown := prepService(t, func(svc *Service) { + for _, name := range invalidNames { + svc.AddCustomProvider(name, Client{"cid", "csecret"}, provider.CustomHandlerOpt{}) + } + }) + defer teardown() + + require.Equal(t, 1, len(svc.Providers())) + require.Equal(t, "dev", svc.Providers()[0].Name()) +} + func TestIntegrationUserInfo(t *testing.T) { _, teardown := prepService(t) defer teardown() @@ -386,7 +414,7 @@ func TestDirectProvider(t *testing.T) { func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { _, teardown := prepService(t, func(svc *Service) { - svc.AddDirectProviderWithUserIDFunc("directCustom", + svc.AddDirectProviderWithUserIDFunc("direct_custom", provider.CredCheckerFunc(func(user, password string) (ok bool, err error) { return user == "dev_direct" && password == "password", nil }), @@ -401,12 +429,12 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { jar, err := cookiejar.New(nil) require.Nil(t, err) client := &http.Client{Jar: jar, Timeout: 5 * time.Second} - resp, err := client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=bad") + resp, err := client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=bad") require.Nil(t, err) defer resp.Body.Close() assert.Equal(t, 403, resp.StatusCode) - resp, err = client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=password") + resp, err = client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=password") require.Nil(t, err) defer resp.Body.Close() assert.Equal(t, 200, resp.StatusCode) @@ -416,7 +444,7 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { t.Logf("resp %s", string(body)) t.Logf("headers: %+v", resp.Header) - assert.Contains(t, string(body), `"name":"dev_direct","id":"directCustom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`) + assert.Contains(t, string(body), `"name":"dev_direct","id":"direct_custom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`) require.Equal(t, 2, len(resp.Cookies())) assert.Equal(t, "JWT", resp.Cookies()[0].Name) diff --git a/v2/middleware/auth.go b/v2/middleware/auth.go index 6e8c5ed..6aafa03 100644 --- a/v2/middleware/auth.go +++ b/v2/middleware/auth.go @@ -129,7 +129,7 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler { } // check if user provider is allowed - if !a.isProviderAllowed(claims.User.ID) { + if !a.isProviderAllowed(&claims) { onError(h, w, r, fmt.Errorf("user %s/%s provider is not allowed", claims.User.Name, claims.User.ID)) a.JWTService.Reset(w) return @@ -153,13 +153,24 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler { return f } -// isProviderAllowed checks if user provider is allowed, user id looks like "provider_1234567890" -// this check is needed to reject users from providers what are used to be allowed but not anymore. +// isProviderAllowed checks if user provider is allowed. +// If provider name is explicitly set in the token claims, then that provider is checked. +// +// If user id looks like "provider_1234567890", +// then there is an attempt to extract provider name from that user ID. +// Note that such read can fail if user id has multiple "_" separator symbols. +// +// This check is needed to reject users from providers what are used to be allowed but not anymore. // Such users made token before the provider was disabled and should not be allowed to login anymore. -func (a *Authenticator) isProviderAllowed(userID string) bool { - userProvider := strings.Split(userID, "_")[0] +func (a *Authenticator) isProviderAllowed(claims *token.Claims) bool { + // TODO: remove this read when old tokens expire and all new tokens have a provider name in them + userIDProvider := strings.Split(claims.User.ID, "_")[0] for _, p := range a.Providers { - if p.Name() == userProvider { + name := p.Name() + if claims.AuthProvider != nil && claims.AuthProvider.Name == name { + return true + } + if name == userIDProvider { return true } } diff --git a/v2/middleware/auth_test.go b/v2/middleware/auth_test.go index d71e0c9..9ce5322 100644 --- a/v2/middleware/auth_test.go +++ b/v2/middleware/auth_test.go @@ -32,6 +32,8 @@ var testJwtNoUser = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI3ODkxOTE4Mj var testJwtWithRole = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9LCJyb2xlIjoiZW1wbG95ZWUifX0.o95raB0aNl2TWUs43Tu6xyX5Y3Fa5wv6_6RFJuN-d6g" +var testJwtValidWithAuthProvider = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fSwiYXV0aF9wcm92aWRlciI6eyJuYW1lIjoicHJvdmlkZXIxIn19.iBKM9-lgejJNjcs-crj6gkEejnIJpavmaq8alenf0JA" + func TestAuthJWTCookie(t *testing.T) { a := makeTestAuth(t) @@ -51,56 +53,51 @@ func TestAuthJWTCookie(t *testing.T) { client := &http.Client{Timeout: 5 * time.Second} expiration := int(365 * 24 * time.Hour.Seconds()) //nolint - t.Run("valid token", func(t *testing.T) { + makeRequest := func(jwtCookie string, xsrfToken string) *http.Response { req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") + req.AddCookie(&http.Cookie{ + Name: "JWT", + Value: jwtCookie, + HttpOnly: true, + Path: "/", + MaxAge: expiration, + Secure: false, + }) + req.Header.Add("X-XSRF-TOKEN", xsrfToken) resp, err := client.Do(req) require.NoError(t, err) + return resp + } + + t.Run("valid token", func(t *testing.T) { + resp := makeRequest(testJwtValid, "random id") assert.Equal(t, 201, resp.StatusCode, "valid token user") }) - t.Run("valid token, wrong provider", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValidWrongProvider, HttpOnly: true, Path: "/", - MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") + t.Run("valid token with auth_provider", func(t *testing.T) { + resp := makeRequest(testJwtValidWithAuthProvider, "random id") + assert.Equal(t, 201, resp.StatusCode, "valid token user") + }) - resp, err := client.Do(req) - require.NoError(t, err) + t.Run("valid token, wrong provider", func(t *testing.T) { + resp := makeRequest(testJwtValidWrongProvider, "random id") assert.Equal(t, 401, resp.StatusCode, "user name1/provider3_id1 provider is not allowed") }) t.Run("xsrf mismatch", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "wrong id") - resp, err := client.Do(req) - require.NoError(t, err) + resp := makeRequest(testJwtValid, "wrong id") assert.Equal(t, 401, resp.StatusCode, "xsrf mismatch") }) t.Run("token expired and refreshed", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtExpired, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") - resp, err := client.Do(req) - require.NoError(t, err) + resp := makeRequest(testJwtExpired, "random id") assert.Equal(t, 201, resp.StatusCode, "token expired and refreshed") }) t.Run("no user info in the token", func(t *testing.T) { - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtNoUser, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") - resp, err := client.Do(req) - require.NoError(t, err) + resp := makeRequest(testJwtNoUser, "random id") assert.Equal(t, 401, resp.StatusCode, "no user info in the token") }) } diff --git a/v2/provider/apple.go b/v2/provider/apple.go index 4a2fbcd..9700583 100644 --- a/v2/provider/apple.go +++ b/v2/provider/apple.go @@ -267,6 +267,9 @@ func (ah *AppleHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { ExpiresAt: jwt.NewNumericDate(time.Now().Add(30 * time.Minute)), NotBefore: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), }, + AuthProvider: &token.AuthProvider{ + Name: ah.name, + }, } if _, err = ah.JwtService.Set(w, claims); err != nil { @@ -376,6 +379,9 @@ func (ah AppleHandler) AuthHandler(w http.ResponseWriter, r *http.Request) { Audience: oauthClaims.Audience, }, SessionOnly: false, + AuthProvider: &token.AuthProvider{ + Name: ah.name, + }, } if _, err = ah.JwtService.Set(w, claims); err != nil { diff --git a/v2/provider/apple_pubkeys.go b/v2/provider/apple_pubkeys.go index c1975a2..364503d 100644 --- a/v2/provider/apple_pubkeys.go +++ b/v2/provider/apple_pubkeys.go @@ -151,7 +151,7 @@ type appleKeySet struct { // get return Apple public key with specific KeyID (kid) func (aks *appleKeySet) get(kid string) (keys *applePublicKey, err error) { - if aks.keys == nil || len(aks.keys) == 0 { + if len(aks.keys) == 0 { return nil, fmt.Errorf("failed to get key in appleKeySet, key set is nil or empty") } diff --git a/v2/provider/apple_test.go b/v2/provider/apple_test.go index 19d1089..4584777 100644 --- a/v2/provider/apple_test.go +++ b/v2/provider/apple_test.go @@ -5,11 +5,11 @@ import ( "crypto/rsa" "crypto/sha1" "encoding/base64" - "encoding/binary" "encoding/json" "fmt" "io" "log" + "math/big" "net/http" "net/http/cookiejar" "net/url" @@ -659,8 +659,8 @@ ODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy n := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(publicKey.N.Bytes()) // convert exponent - eBuff := make([]byte, 4) - binary.LittleEndian.PutUint32(eBuff, uint32(publicKey.E)) + require.Positive(t, publicKey.E, "RSA exponent must be positive") + eBuff := big.NewInt(int64(publicKey.E)).Bytes() e := base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString(eBuff) JWK := struct { @@ -670,7 +670,7 @@ ODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy Kid string `json:"kid"` E string `json:"e"` N string `json:"n"` - }{Alg: "RS256", Kty: "RSA", Use: "sig", Kid: "112233", N: n, E: e[:4]} + }{Alg: "RS256", Kty: "RSA", Use: "sig", Kid: "112233", N: n, E: e} var buffJwk []byte buffJwk, err = json.Marshal(JWK) diff --git a/v2/provider/direct.go b/v2/provider/direct.go index 7d94017..9844989 100644 --- a/v2/provider/direct.go +++ b/v2/provider/direct.go @@ -126,6 +126,9 @@ func (p DirectHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { Audience: []string{creds.Audience}, }, SessionOnly: sessOnly, + AuthProvider: &token.AuthProvider{ + Name: p.ProviderName, + }, } if _, err = p.TokenService.Set(w, claims); err != nil { diff --git a/v2/provider/oauth1.go b/v2/provider/oauth1.go index 6a51b0d..8484a4f 100644 --- a/v2/provider/oauth1.go +++ b/v2/provider/oauth1.go @@ -61,6 +61,9 @@ func (h Oauth1Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { ExpiresAt: jwt.NewNumericDate(time.Now().Add(30 * time.Minute)), NotBefore: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), }, + AuthProvider: &token.AuthProvider{ + Name: h.name, + }, } if _, err = h.JwtService.Set(w, claims); err != nil { @@ -146,6 +149,9 @@ func (h Oauth1Handler) AuthHandler(w http.ResponseWriter, r *http.Request) { Audience: oauthClaims.Audience, }, SessionOnly: oauthClaims.SessionOnly, + AuthProvider: &token.AuthProvider{ + Name: h.name, + }, } if _, err = h.JwtService.Set(w, claims); err != nil { diff --git a/v2/provider/oauth2.go b/v2/provider/oauth2.go index a985cd9..273a190 100644 --- a/v2/provider/oauth2.go +++ b/v2/provider/oauth2.go @@ -118,6 +118,9 @@ func (p Oauth2Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { NotBefore: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), }, NoAva: r.URL.Query().Get("noava") == "1", + AuthProvider: &token.AuthProvider{ + Name: p.name, + }, } if _, err := p.JwtService.Set(w, claims); err != nil { @@ -215,6 +218,9 @@ func (p Oauth2Handler) AuthHandler(w http.ResponseWriter, r *http.Request) { }, SessionOnly: oauthClaims.SessionOnly, NoAva: oauthClaims.NoAva, + AuthProvider: &token.AuthProvider{ + Name: p.name, + }, } if _, err = p.JwtService.Set(w, claims); err != nil { diff --git a/v2/provider/telegram.go b/v2/provider/telegram.go index 906fe59..5b884eb 100644 --- a/v2/provider/telegram.go +++ b/v2/provider/telegram.go @@ -310,6 +310,9 @@ func (th *TelegramHandler) LoginHandler(w http.ResponseWriter, r *http.Request) NotBefore: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), }, SessionOnly: false, // TODO review? + AuthProvider: &authtoken.AuthProvider{ + Name: th.ProviderName, + }, } if _, err := th.TokenService.Set(w, claims); err != nil { diff --git a/v2/provider/telegram_test.go b/v2/provider/telegram_test.go index 228ae23..7e0cbf0 100644 --- a/v2/provider/telegram_test.go +++ b/v2/provider/telegram_test.go @@ -89,22 +89,36 @@ func TestTelegramUnconfirmedRequest(t *testing.T) { func TestTelegramConfirmedRequest(t *testing.T) { var servedToken string - var mu sync.Mutex + // is set when token becomes used, + // no sync is required because only a single goroutine in TelegramHandler.Run() reads and writes it + var tokenAlreadyUsed bool + + var wgToken sync.WaitGroup + wgToken.Add(1) + defer func() { + if t.Failed() && servedToken == "" { + wgToken.Done() // for the case when test fails before token is generated + } + }() m := &TelegramAPIMock{ GetUpdatesFunc: func(ctx context.Context) (*telegramUpdate, error) { - var upd telegramUpdate + wgToken.Wait() - mu.Lock() - defer mu.Unlock() - if servedToken != "" { - resp := fmt.Sprintf(getUpdatesResp, servedToken) + if tokenAlreadyUsed || t.Failed() { + return nil, fmt.Errorf("token %s has been already used", servedToken) + } - err := json.Unmarshal([]byte(resp), &upd) - if err != nil { - t.Fatal(err) - } + var upd telegramUpdate + resp := fmt.Sprintf(getUpdatesResp, servedToken) + err := json.Unmarshal([]byte(resp), &upd) + if err != nil { + t.Fatal(err) } + + // token is served only once + tokenAlreadyUsed = true + return &upd, nil }, AvatarFunc: func(ctx context.Context, userID int) (string, error) { @@ -147,10 +161,10 @@ func TestTelegramConfirmedRequest(t *testing.T) { err := json.Unmarshal(w.Body.Bytes(), &resp) assert.NoError(t, err) assert.Equal(t, "my_auth_bot", resp.Bot) + assert.NotEmpty(t, resp.Token) - mu.Lock() servedToken = resp.Token - mu.Unlock() + wgToken.Done() // Check the token confirmation assert.Eventually(t, func() bool { diff --git a/v2/provider/verify.go b/v2/provider/verify.go index 0fc9ca6..1c45953 100644 --- a/v2/provider/verify.go +++ b/v2/provider/verify.go @@ -117,6 +117,9 @@ func (e VerifyHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { Audience: confClaims.Audience, }, SessionOnly: sessOnly, + AuthProvider: &token.AuthProvider{ + Name: e.ProviderName, + }, } if _, err = e.TokenService.Set(w, claims); err != nil { @@ -152,6 +155,9 @@ func (e VerifyHandler) sendConfirmation(w http.ResponseWriter, r *http.Request) NotBefore: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), Issuer: e.Issuer, }, + AuthProvider: &token.AuthProvider{ + Name: e.ProviderName, + }, } tkn, err := e.TokenService.Token(claims) diff --git a/v2/token/jwt.go b/v2/token/jwt.go index 8350ae8..352fcb3 100644 --- a/v2/token/jwt.go +++ b/v2/token/jwt.go @@ -23,10 +23,11 @@ type Service struct { // Claims stores user info for token and state & from from login type Claims struct { jwt.RegisteredClaims - User *User `json:"user,omitempty"` // user info - SessionOnly bool `json:"sess_only,omitempty"` - Handshake *Handshake `json:"handshake,omitempty"` // used for oauth handshake - NoAva bool `json:"no-ava,omitempty"` // disable avatar, always use identicon + User *User `json:"user,omitempty"` // user info + SessionOnly bool `json:"sess_only,omitempty"` + Handshake *Handshake `json:"handshake,omitempty"` // used for oauth handshake + NoAva bool `json:"no-ava,omitempty"` // disable avatar, always use identicon + AuthProvider *AuthProvider `json:"auth_provider,omitempty"` // auth provider info } // Handshake used for oauth handshake @@ -36,6 +37,11 @@ type Handshake struct { ID string `json:"id,omitempty"` } +// AuthProvider stores attributes of provider which has created a JWT token +type AuthProvider struct { + Name string `json:"name,omitempty"` +} + const ( // default names for cookies and headers defaultJWTCookieName = "JWT"