Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve provider name handling #213

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ Follow to next steps for configuring on the Apple side:

After completing the previous steps, you can proceed with configuring the Apple auth provider. Here are the parameters for AppleConfig:

- _ClientID_ (**required**) - Service ID identifier which is used for Sign with Apple
- _ClientID_ (**required**) - Service ID (or App ID) which is used for Sign with Apple
- _TeamID_ (**required**) - Identifier a developer account (use as prefix for all App ID)
- _KeyID_ (**required**) - Identifier a generated key for Sign with Apple
- _ResponseMode_ - Response Mode, please see [documentation](https://developer.apple.com/documentation/sign_in_with_apple/request_an_authorization_to_the_sign_in_with_apple_server?changes=_1_2#4066168) for reference, default is `form_post`
Expand All @@ -541,7 +541,7 @@ After completing the previous steps, you can proceed with configuring the Apple
// apple config parameters
appleCfg := provider.AppleConfig{
TeamID: os.Getenv("AEXMPL_APPLE_TID"), // developer account identifier
ClientID: os.Getenv("AEXMPL_APPLE_CID"), // service identifier
ClientID: os.Getenv("AEXMPL_APPLE_CID"), // Service ID (or App ID)
KeyID: os.Getenv("AEXMPL_APPLE_KEYID"), // private key identifier
}
```
Expand Down
34 changes: 34 additions & 0 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package auth
import (
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
"time"

Expand Down Expand Up @@ -267,10 +269,42 @@ func (s *Service) addProviderByName(name string, p provider.Params) {
}

func (s *Service) addProvider(prov provider.Provider) {
if !s.isValidProviderName(prov.Name()) {
return
}
s.providers = append(s.providers, provider.NewService(prov))
s.authMiddleware.Providers = s.providers
}

func (s *Service) isValidProviderName(name string) bool {
if strings.TrimSpace(name) == "" {
s.logger.Logf("[ERROR] provider has been ignored because its name is empty")
return false
}

formatForbidden := func(name string) {
s.logger.Logf("[ERROR] provider has been ignored because its name contains forbidden characters: '%s'", name)
}

path, err := url.PathUnescape(name)
if err != nil || path != name {
formatForbidden(name)
return false
}
if name != url.PathEscape(name) {
formatForbidden(name)
return false
}
// net/url package does not escape everything (https://github.com/golang/go/issues/5684)
// It is better to reject all reserved characters from https://datatracker.ietf.org/doc/html/rfc3986#section-2.2
if regexp.MustCompile(`[:/?#\[\]@!$&'\(\)*+,;=]`).MatchString(name) {
formatForbidden(name)
return false
}

return true
}

// AddProvider adds provider for given name
func (s *Service) AddProvider(name, cid, csecret string) {
p := provider.Params{
Expand Down
36 changes: 32 additions & 4 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,34 @@ func TestIntegrationList(t *testing.T) {
assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b))
}

func TestIntegrationInvalidProviderNames(t *testing.T) {
invalidNames := []string{
"provider/with/slashes",
"provider with spaces",
" providerWithSpacesAround\t",
"providerWithReserved-$-Char",
"providerWithReserved-&-Char",
"providerWithReserved-+-Char",
"providerWithReserved-,-Char",
"providerWithReserved-:-Char",
"providerWithReserved-;-Char",
"providerWithReserved-=-Char",
"providerWithReserved-?-Char",
"providerWithReserved-@-Char",
"providerWith%2F-EscapedSequence",
"",
}
svc, teardown := prepService(t, func(svc *Service) {
for _, name := range invalidNames {
svc.AddCustomProvider(name, Client{"cid", "csecret"}, provider.CustomHandlerOpt{})
}
})
defer teardown()

require.Equal(t, 1, len(svc.Providers()))
require.Equal(t, "dev", svc.Providers()[0].Name())
}

func TestIntegrationUserInfo(t *testing.T) {
_, teardown := prepService(t)
defer teardown()
Expand Down Expand Up @@ -386,7 +414,7 @@ func TestDirectProvider(t *testing.T) {

func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) {
_, teardown := prepService(t, func(svc *Service) {
svc.AddDirectProviderWithUserIDFunc("directCustom",
svc.AddDirectProviderWithUserIDFunc("direct_custom",
provider.CredCheckerFunc(func(user, password string) (ok bool, err error) {
return user == "dev_direct" && password == "password", nil
}),
Expand All @@ -401,12 +429,12 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) {
jar, err := cookiejar.New(nil)
require.Nil(t, err)
client := &http.Client{Jar: jar, Timeout: 5 * time.Second}
resp, err := client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=bad")
resp, err := client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=bad")
require.Nil(t, err)
defer resp.Body.Close()
assert.Equal(t, 403, resp.StatusCode)

resp, err = client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=password")
resp, err = client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=password")
require.Nil(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
Expand All @@ -416,7 +444,7 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) {
t.Logf("resp %s", string(body))
t.Logf("headers: %+v", resp.Header)

assert.Contains(t, string(body), `"name":"dev_direct","id":"directCustom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`)
assert.Contains(t, string(body), `"name":"dev_direct","id":"direct_custom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`)

require.Equal(t, 2, len(resp.Cookies()))
assert.Equal(t, "JWT", resp.Cookies()[0].Name)
Expand Down
23 changes: 17 additions & 6 deletions middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler {
}

// check if user provider is allowed
if !a.isProviderAllowed(claims.User.ID) {
if !a.isProviderAllowed(&claims) {
onError(h, w, r, fmt.Errorf("user %s/%s provider is not allowed", claims.User.Name, claims.User.ID))
a.JWTService.Reset(w)
return
Expand All @@ -153,13 +153,24 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler {
return f
}

// isProviderAllowed checks if user provider is allowed, user id looks like "provider_1234567890"
// this check is needed to reject users from providers what are used to be allowed but not anymore.
// isProviderAllowed checks if user provider is allowed.
// If provider name is explicitly set in the token claims, then that provider is checked.
//
// If user id looks like "provider_1234567890",
// then there is an attempt to extract provider name from that user ID.
// Note that such read can fail if user id has multiple "_" separator symbols.
//
// This check is needed to reject users from providers what are used to be allowed but not anymore.
// Such users made token before the provider was disabled and should not be allowed to login anymore.
func (a *Authenticator) isProviderAllowed(userID string) bool {
userProvider := strings.Split(userID, "_")[0]
func (a *Authenticator) isProviderAllowed(claims *token.Claims) bool {
// TODO: remove this read when old tokens expire and all new tokens have a provider name in them
userIDProvider := strings.Split(claims.User.ID, "_")[0]
for _, p := range a.Providers {
if p.Name() == userProvider {
name := p.Name()
if claims.AuthProvider != nil && claims.AuthProvider.Name == name {
return true
}
if name == userIDProvider {
return true
}
}
Expand Down
55 changes: 26 additions & 29 deletions middleware/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ var testJwtNoUser = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI3ODkxOTE4Mj

var testJwtWithRole = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9LCJyb2xlIjoiZW1wbG95ZWUifX0.o95raB0aNl2TWUs43Tu6xyX5Y3Fa5wv6_6RFJuN-d6g"

var testJwtValidWithAuthProvider = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fSwiYXV0aF9wcm92aWRlciI6eyJuYW1lIjoicHJvdmlkZXIxIn19.iBKM9-lgejJNjcs-crj6gkEejnIJpavmaq8alenf0JA"

func TestAuthJWTCookie(t *testing.T) {
a := makeTestAuth(t)

Expand All @@ -51,56 +53,51 @@ func TestAuthJWTCookie(t *testing.T) {
client := &http.Client{Timeout: 5 * time.Second}
expiration := int(365 * 24 * time.Hour.Seconds()) //nolint

t.Run("valid token", func(t *testing.T) {
makeRequest := func(jwtCookie string, xsrfToken string) *http.Response {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
req.AddCookie(&http.Cookie{
Name: "JWT",
Value: jwtCookie,
HttpOnly: true,
Path: "/",
MaxAge: expiration,
Secure: false,
})
req.Header.Add("X-XSRF-TOKEN", xsrfToken)

resp, err := client.Do(req)
require.NoError(t, err)
return resp
}

t.Run("valid token", func(t *testing.T) {
resp := makeRequest(testJwtValid, "random id")
assert.Equal(t, 201, resp.StatusCode, "valid token user")
})

t.Run("valid token, wrong provider", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValidWrongProvider, HttpOnly: true, Path: "/",
MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
t.Run("valid token with auth_provider", func(t *testing.T) {
resp := makeRequest(testJwtValidWithAuthProvider, "random id")
assert.Equal(t, 201, resp.StatusCode, "valid token user")
})

resp, err := client.Do(req)
require.NoError(t, err)
t.Run("valid token, wrong provider", func(t *testing.T) {
resp := makeRequest(testJwtValidWrongProvider, "random id")
assert.Equal(t, 401, resp.StatusCode, "user name1/provider3_id1 provider is not allowed")
})

t.Run("xsrf mismatch", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "wrong id")
resp, err := client.Do(req)
require.NoError(t, err)
resp := makeRequest(testJwtValid, "wrong id")
assert.Equal(t, 401, resp.StatusCode, "xsrf mismatch")
})

t.Run("token expired and refreshed", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtExpired, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
resp, err := client.Do(req)
require.NoError(t, err)
resp := makeRequest(testJwtExpired, "random id")
assert.Equal(t, 201, resp.StatusCode, "token expired and refreshed")
})

t.Run("no user info in the token", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtNoUser, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
resp, err := client.Do(req)
require.NoError(t, err)
resp := makeRequest(testJwtNoUser, "random id")
assert.Equal(t, 401, resp.StatusCode, "no user info in the token")
})
}
Expand Down
6 changes: 6 additions & 0 deletions provider/apple.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ func (ah *AppleHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
ExpiresAt: time.Now().Add(30 * time.Minute).Unix(),
NotBefore: time.Now().Add(-1 * time.Minute).Unix(),
},
AuthProvider: &token.AuthProvider{
Name: ah.name,
},
}

if _, err = ah.JwtService.Set(w, claims); err != nil {
Expand Down Expand Up @@ -376,6 +379,9 @@ func (ah AppleHandler) AuthHandler(w http.ResponseWriter, r *http.Request) {
Audience: oauthClaims.Audience,
},
SessionOnly: false,
AuthProvider: &token.AuthProvider{
Name: ah.name,
},
}

if _, err = ah.JwtService.Set(w, claims); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion provider/apple_pubkeys.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ type appleKeySet struct {

// get return Apple public key with specific KeyID (kid)
func (aks *appleKeySet) get(kid string) (keys *applePublicKey, err error) {
if aks.keys == nil || len(aks.keys) == 0 {
if len(aks.keys) == 0 {
return nil, fmt.Errorf("failed to get key in appleKeySet, key set is nil or empty")
}

Expand Down
8 changes: 4 additions & 4 deletions provider/apple_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ import (
"crypto/rsa"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"log"
"math/big"
"net/http"
"net/http/cookiejar"
"net/url"
Expand Down Expand Up @@ -659,8 +659,8 @@ ODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy
n := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(publicKey.N.Bytes())

// convert exponent
eBuff := make([]byte, 4)
binary.LittleEndian.PutUint32(eBuff, uint32(publicKey.E))
require.Positive(t, publicKey.E, "RSA exponent must be positive")
eBuff := big.NewInt(int64(publicKey.E)).Bytes()
e := base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString(eBuff)

JWK := struct {
Expand All @@ -670,7 +670,7 @@ ODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy
Kid string `json:"kid"`
E string `json:"e"`
N string `json:"n"`
}{Alg: "RS256", Kty: "RSA", Use: "sig", Kid: "112233", N: n, E: e[:4]}
}{Alg: "RS256", Kty: "RSA", Use: "sig", Kid: "112233", N: n, E: e}

var buffJwk []byte
buffJwk, err = json.Marshal(JWK)
Expand Down
3 changes: 3 additions & 0 deletions provider/direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ func (p DirectHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
Audience: creds.Audience,
},
SessionOnly: sessOnly,
AuthProvider: &token.AuthProvider{
Name: p.ProviderName,
},
}

if _, err = p.TokenService.Set(w, claims); err != nil {
Expand Down
6 changes: 6 additions & 0 deletions provider/oauth1.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ func (h Oauth1Handler) LoginHandler(w http.ResponseWriter, r *http.Request) {
ExpiresAt: time.Now().Add(30 * time.Minute).Unix(),
NotBefore: time.Now().Add(-1 * time.Minute).Unix(),
},
AuthProvider: &token.AuthProvider{
Name: h.name,
},
}

if _, err = h.JwtService.Set(w, claims); err != nil {
Expand Down Expand Up @@ -146,6 +149,9 @@ func (h Oauth1Handler) AuthHandler(w http.ResponseWriter, r *http.Request) {
Audience: oauthClaims.Audience,
},
SessionOnly: oauthClaims.SessionOnly,
AuthProvider: &token.AuthProvider{
Name: h.name,
},
}

if _, err = h.JwtService.Set(w, claims); err != nil {
Expand Down
6 changes: 6 additions & 0 deletions provider/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ func (p Oauth2Handler) LoginHandler(w http.ResponseWriter, r *http.Request) {
NotBefore: time.Now().Add(-1 * time.Minute).Unix(),
},
NoAva: r.URL.Query().Get("noava") == "1",
AuthProvider: &token.AuthProvider{
Name: p.name,
},
}

if _, err := p.JwtService.Set(w, claims); err != nil {
Expand Down Expand Up @@ -215,6 +218,9 @@ func (p Oauth2Handler) AuthHandler(w http.ResponseWriter, r *http.Request) {
},
SessionOnly: oauthClaims.SessionOnly,
NoAva: oauthClaims.NoAva,
AuthProvider: &token.AuthProvider{
Name: p.name,
},
}

if _, err = p.JwtService.Set(w, claims); err != nil {
Expand Down
3 changes: 3 additions & 0 deletions provider/telegram.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,9 @@ func (th *TelegramHandler) LoginHandler(w http.ResponseWriter, r *http.Request)
NotBefore: time.Now().Add(-1 * time.Minute).Unix(),
},
SessionOnly: false, // TODO review?
AuthProvider: &authtoken.AuthProvider{
Name: th.ProviderName,
},
}

if _, err := th.TokenService.Set(w, claims); err != nil {
Expand Down
Loading
Loading