diff --git a/go.mod b/go.mod index a6adac3..b5e33d5 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/grpc-ecosystem/grpc-gateway v1.14.3 github.com/jinzhu/gorm v1.9.12 github.com/mailru/easyjson v0.7.1 // indirect - github.com/nilorg/oauth2 v0.2.3 + github.com/nilorg/oauth2 v0.2.4 github.com/nilorg/pkg v0.0.0-20200517083116-9b88f6e458df github.com/nilorg/protobuf v0.0.0-20200503084506-2b10c53bd0f9 github.com/nilorg/sdk v0.0.0-20200517085820-1f0160ffff7a diff --git a/go.sum b/go.sum index ad3a6a9..cd526d6 100644 --- a/go.sum +++ b/go.sum @@ -379,8 +379,8 @@ github.com/nats-io/nats.go v1.8.1/go.mod h1:BrFz9vVn0fU3AcH9Vn4Kd7W0NpJ651tD5omQ github.com/nats-io/nkeys v0.0.2/go.mod h1:dab7URMsZm6Z/jp9Z5UGa87Uutgc2mVpXLC4B7TDb/4= github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= -github.com/nilorg/oauth2 v0.2.3 h1:ASHq7e9vLNscp8A4ycUAOvE74xKS9ACX91GRqWCvoUs= -github.com/nilorg/oauth2 v0.2.3/go.mod h1:a3NN/02vXZmw054tbt8mGHmvp/lyF2w8H/hYqVgppOg= +github.com/nilorg/oauth2 v0.2.4 h1:IYztO/+ILiTXnByqC3dK3A3gY32VudPDHogTrXJ2z9o= +github.com/nilorg/oauth2 v0.2.4/go.mod h1:a3NN/02vXZmw054tbt8mGHmvp/lyF2w8H/hYqVgppOg= github.com/nilorg/pkg v0.0.0-20200517083116-9b88f6e458df h1:rz+qPqZNzEYrAe/VqlX+DEI2Mj+t/zu6e5AkOUSn/pE= github.com/nilorg/pkg v0.0.0-20200517083116-9b88f6e458df/go.mod h1:UjETEjErg9a2LREGwU34bJTfprTw9gkqqHq2Q+3URHI= github.com/nilorg/protobuf v0.0.0-20200503084506-2b10c53bd0f9 h1:GcX22XZzRMR1LGjiuQ/aSH+iDte5o3fYEAClVKXw8b4= diff --git a/internal/module/global/global.go b/internal/module/global/global.go index 70c8b41..0449f4a 100644 --- a/internal/module/global/global.go +++ b/internal/module/global/global.go @@ -45,6 +45,7 @@ func initPrivate() { logger.Fatalf("x509.ParsePKCS1PrivateKey Error: %s", err) return } + JwtPublicKey = &JwtPrivateKey.PublicKey } func initCert() { @@ -62,7 +63,6 @@ func initCert() { logger.Fatalln("failed to parse certificate: %s", err) return } - JwtPublicKey = JwtCertificates[0].PublicKey.(*rsa.PublicKey) } func initJwk() { diff --git a/internal/pkg/token/jwt.go b/internal/pkg/token/jwt.go index f1ac355..89a7b52 100644 --- a/internal/pkg/token/jwt.go +++ b/internal/pkg/token/jwt.go @@ -1,16 +1,26 @@ package token import ( + "strings" "time" "github.com/nilorg/oauth2" - "github.com/nilorg/sdk/strings" + sdkStrings "github.com/nilorg/sdk/strings" ) // NewGenerateAccessToken 创建默认生成AccessToken方法 func NewGenerateAccessToken(key interface{}, idTokenEnabled bool) oauth2.GenerateAccessTokenFunc { return func(issuer, clientID, scope, openID string) (token *oauth2.TokenResponse, err error) { - accessJwtClaims := oauth2.NewJwtClaims(issuer, clientID, scope, openID) + idTokenFlag := false + var newScopes []string + for _, s := range sdkStrings.Split(scope, " ") { + if s == "openid" { + idTokenFlag = true + } else { + newScopes = append(newScopes, s) + } + } + accessJwtClaims := oauth2.NewJwtClaims(issuer, clientID, strings.Join(newScopes, " "), openID) var tokenStr string tokenStr, err = oauth2.NewJwtToken(accessJwtClaims, "RS256", key) if err != nil { @@ -34,13 +44,6 @@ func NewGenerateAccessToken(key interface{}, idTokenEnabled bool) oauth2.Generat RefreshToken: refreshTokenStr, Scope: scope, } - idTokenFlag := false - for _, s := range strings.Split(scope, " ") { - if s == "openid" { - idTokenFlag = true - break - } - } if idTokenFlag && idTokenEnabled { idTokenJwtClaims := oauth2.JwtClaims{ JwtStandardClaims: oauth2.JwtStandardClaims{ @@ -50,7 +53,7 @@ func NewGenerateAccessToken(key interface{}, idTokenEnabled bool) oauth2.Generat ExpiresAt: currTime.Add(oauth2.AccessTokenExpire).Unix(), Audience: []string{clientID}, }, - Scope: scope, + Scope: "openid", } var idToken string idToken, err = oauth2.NewJwtClaimsToken(&idTokenJwtClaims, "RS256", key) diff --git a/internal/server/middleware/middleware.go b/internal/server/middleware/middleware.go index 47fb51b..c1838fe 100644 --- a/internal/server/middleware/middleware.go +++ b/internal/server/middleware/middleware.go @@ -51,7 +51,7 @@ func parseAuth(auth string) (token string, ok bool) { // AuthUserinfoRequired 身份验证 func AuthUserinfoRequired(key interface{}) gin.HandlerFunc { return func(ctx *gin.Context) { - accessToken, ok := parseAuth(ctx.GetHeader("Authorization")) + tok, ok := parseAuth(ctx.GetHeader("Authorization")) if !ok { ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "Authorization Is Empty", @@ -62,7 +62,7 @@ func AuthUserinfoRequired(key interface{}) gin.HandlerFunc { idTokenClaims *oauth2.JwtClaims err error ) - idTokenClaims, err = oauth2.ParseJwtClaimsToken(accessToken, key) + idTokenClaims, err = oauth2.ParseJwtClaimsToken(tok, key) if err != nil { logger.Errorf("oauth2.ParseJwtClaimsToken: %s", err) ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ @@ -83,7 +83,7 @@ func AuthUserinfoRequired(key interface{}) gin.HandlerFunc { }) return } - if idTokenClaims.VerifyScope("openid", false) { + if !idTokenClaims.VerifyScope("openid", false) { ctx.AbortWithStatusJSON(http.StatusForbidden, gin.H{ "error": oauth2.ErrInvalidScope.Error(), })