diff --git a/common.go b/common.go index ee75542..8cc6d71 100644 --- a/common.go +++ b/common.go @@ -168,7 +168,10 @@ func (a Auth) GetVerifier(options ...jwt.ValidateOption) (*jwtauth.JWTAuth, erro // findProjectClaim looks for the project_id/project claim in the JWT func findProjectClaim(r *http.Request) (uint64, error) { - raw := cmp.Or(jwtauth.TokenFromHeader(r)) + raw := jwtauth.TokenFromHeader(r) + if raw == "" { + return 0, nil + } token, err := jwt.ParseString(raw, jwt.WithVerify(false)) if err != nil { @@ -179,7 +182,7 @@ func findProjectClaim(r *http.Request) (uint64, error) { claim := cmp.Or(claims["project_id"], claims["project"]) if claim == nil { - return 0, fmt.Errorf("missing project claim") + return 0, nil } switch val := claim.(type) { diff --git a/middleware.go b/middleware.go index 5149243..e8e6452 100644 --- a/middleware.go +++ b/middleware.go @@ -65,23 +65,26 @@ func VerifyToken(cfg Options) func(next http.Handler) http.Handler { if cfg.ProjectStore != nil { projectID, err := findProjectClaim(r) if err != nil { - cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get project claim: %w", err)) + cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("find project claim: %w", err)) return } - project, _auth, err := cfg.ProjectStore.GetProject(ctx, projectID) - if err != nil { - cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get project: %w", err)) - return - } - if project == nil { - cfg.ErrHandler(r, w, proto.ErrProjectNotFound) - return - } - if _auth != nil { - auth = _auth + if projectID != 0 { + project, _auth, err := cfg.ProjectStore.GetProject(ctx, projectID) + if err != nil { + cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get project: %w", err)) + return + } + if project == nil { + cfg.ErrHandler(r, w, proto.ErrProjectNotFound) + return + } + if _auth != nil { + auth = _auth + } + ctx = WithProject(ctx, project) } - ctx = WithProject(ctx, project) + } jwtAuth, err := auth.GetVerifier(jwtOptions...)