diff --git a/go.mod b/go.mod index a907ba2..51f55e8 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22 require ( github.com/coreos/go-oidc/v3 v3.11.0 github.com/go-ldap/ldap/v3 v3.4.8 + github.com/go-pkgz/expirable-cache/v3 v3.0.0 github.com/negasus/haproxy-spoe-go v1.0.5 github.com/sirupsen/logrus v1.9.3 github.com/spf13/viper v1.19.0 diff --git a/go.sum b/go.sum index fbe2bd1..bb3a5bf 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,8 @@ github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0 github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY= github.com/go-ldap/ldap/v3 v3.4.8 h1:loKJyspcRezt2Q3ZRMq2p/0v8iOurlmeXDPw6fikSvQ= github.com/go-ldap/ldap/v3 v3.4.8/go.mod h1:qS3Sjlu76eHfHGpUdWkAXQTw4beih+cHsco2jXlIXrk= +github.com/go-pkgz/expirable-cache/v3 v3.0.0 h1:u3/gcu3sabLYiTCevoRKv+WzjIn5oo7P8XtiXBeRDLw= +github.com/go-pkgz/expirable-cache/v3 v3.0.0/go.mod h1:2OQiDyEGQalYecLWmXprm3maPXeVb5/6/X7yRPYTzec= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= @@ -61,6 +63,9 @@ github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/C github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= diff --git a/internal/auth/authenticator_oidc.go b/internal/auth/authenticator_oidc.go index dbdf303..1a63b6b 100644 --- a/internal/auth/authenticator_oidc.go +++ b/internal/auth/authenticator_oidc.go @@ -15,6 +15,7 @@ import ( "github.com/vmihailenco/msgpack/v5" "github.com/coreos/go-oidc/v3/oidc" + cache "github.com/go-pkgz/expirable-cache/v3" action "github.com/negasus/haproxy-spoe-go/action" message "github.com/negasus/haproxy-spoe-go/message" @@ -74,7 +75,7 @@ type OIDCAuthenticator struct { signatureComputer *HmacSha256Computer encryptor *AESEncryptor - pkceVerifier string + pkceVerifierCache cache.Cache[string, string] options OIDCAuthenticatorOptions } @@ -120,7 +121,7 @@ func NewOIDCAuthenticator(options OIDCAuthenticatorOptions) *OIDCAuthenticator { options: options, signatureComputer: NewHmacSha256Computer(options.SignatureSecret), encryptor: NewAESEncryptor(options.EncryptionSecret), - pkceVerifier: oauth2.GenerateVerifier(), + pkceVerifierCache: cache.NewCache[string, string](), } go func() { @@ -396,8 +397,15 @@ func (oa *OIDCAuthenticator) buildAuthorizationURL(domain string, oauthArgs OAut } var authorizationURL string + pkceVerifier := oauth2.GenerateVerifier() + stateStr := base64.StdEncoding.EncodeToString(stateBytes) + cacheTTL := time.Second * 3600 + if oa.options.CookieTTL != 0 { + cacheTTL = oa.options.CookieTTL + } + oa.pkceVerifierCache.Set(stateStr, pkceVerifier, cacheTTL) err = oa.withOAuth2Config(domain, func(config oauth2.Config) error { - authorizationURL = config.AuthCodeURL(base64.StdEncoding.EncodeToString(stateBytes), oauth2.S256ChallengeOption(oa.pkceVerifier)) + authorizationURL = config.AuthCodeURL(stateStr, oauth2.S256ChallengeOption(pkceVerifier)) return nil }) if err != nil { @@ -435,9 +443,15 @@ func (oa *OIDCAuthenticator) handleOAuth2Callback(tmpl *template.Template, error domain := extractDomainFromHost(r.Host) + pkceVerifier, ok := oa.pkceVerifierCache.Get(stateB64Payload) + if !ok { + logrus.Error("cannot retrieve pkce verifier") + http.Error(w, "Bad request", http.StatusBadRequest) + return + } var oauth2Token *oauth2.Token err := oa.withOAuth2Config(domain, func(config oauth2.Config) error { - token, err := config.Exchange(r.Context(), r.URL.Query().Get("code"), oauth2.VerifierOption(oa.pkceVerifier)) + token, err := config.Exchange(r.Context(), r.URL.Query().Get("code"), oauth2.VerifierOption(pkceVerifier)) oauth2Token = token return err })