diff --git a/consent/csrf.go b/consent/csrf.go index 42588390a52..1d9b202f984 100644 --- a/consent/csrf.go +++ b/consent/csrf.go @@ -45,7 +45,7 @@ func createCsrfSession(w http.ResponseWriter, r *http.Request, conf x.CookieConf return nil } -func validateCsrfSession(r *http.Request, conf x.CookieConfigProvider, store sessions.Store, name, expectedCSRF string) error { +func validateCsrfSession(r *http.Request, conf x.CookieConfigProvider, store sessions.Store, name, expectedCSRF string, _ []byte) error { if cookie, err := getCsrfSession(r, store, conf, name); err != nil { return errorsx.WithStack(fosite.ErrRequestForbidden.WithHint("CSRF session cookie could not be decoded.")) } else if csrf, err := mapx.GetString(cookie.Values, "csrf"); err != nil { diff --git a/consent/strategy_default.go b/consent/strategy_default.go index 0de9ac2b168..4a82ec1b2da 100644 --- a/consent/strategy_default.go +++ b/consent/strategy_default.go @@ -336,7 +336,7 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re } clientSpecificCookieNameLoginCSRF := fmt.Sprintf("%s_%d", s.r.Config().CookieNameLoginCSRF(ctx), murmur3.Sum32(session.LoginRequest.Client.ID.Bytes())) - if err := validateCsrfSession(r, s.r.Config(), store, clientSpecificCookieNameLoginCSRF, session.LoginRequest.CSRF); err != nil { + if err := validateCsrfSession(r, s.r.Config(), store, clientSpecificCookieNameLoginCSRF, session.LoginRequest.CSRF, session.Context); err != nil { return nil, err } @@ -598,7 +598,7 @@ func (s *DefaultStrategy) verifyConsent(ctx context.Context, w http.ResponseWrit } clientSpecificCookieNameConsentCSRF := fmt.Sprintf("%s_%d", s.r.Config().CookieNameConsentCSRF(ctx), murmur3.Sum32(session.ConsentRequest.Client.ID.Bytes())) - if err := validateCsrfSession(r, s.r.Config(), store, clientSpecificCookieNameConsentCSRF, session.ConsentRequest.CSRF); err != nil { + if err := validateCsrfSession(r, s.r.Config(), store, clientSpecificCookieNameConsentCSRF, session.ConsentRequest.CSRF, session.Context); err != nil { return nil, err } diff --git a/consent/types.go b/consent/types.go index 6a389e9d8bb..7a139d4da37 100644 --- a/consent/types.go +++ b/consent/types.go @@ -172,6 +172,11 @@ type AcceptOAuth2ConsentRequest struct { // the flow. WasHandled bool `json:"-"` + // Context is an optional object which can hold arbitrary data. The data will be made available when fetching the + // consent request under the "context" field. This is useful in scenarios where login and consent endpoints share + // data. + Context sqlxx.JSONRawMessage `json:"context"` + ConsentRequest *OAuth2ConsentRequest `json:"-"` Error *RequestDeniedError `json:"-"` RequestedAt time.Time `json:"-"` @@ -236,6 +241,11 @@ type OAuth2ConsentSession struct { // the flow. WasHandled bool `json:"-" db:"was_used"` + // Context is an optional object which can hold arbitrary data. The data will be made available when fetching the + // consent request under the "context" field. This is useful in scenarios where login and consent endpoints share + // data. + Context sqlxx.JSONRawMessage `json:"context"` + // Consent Request // // The consent request that lead to this consent session. diff --git a/flow/flow.go b/flow/flow.go index bbf2e36fec9..999a06ca1cd 100644 --- a/flow/flow.go +++ b/flow/flow.go @@ -295,7 +295,9 @@ func (f *Flow) HandleLoginRequest(h *consent.HandledLoginRequest) error { f.LoginExtendSessionLifespan = h.ExtendSessionLifespan f.ACR = h.ACR f.AMR = h.AMR - f.Context = h.Context + if h.Context == nil { + f.Context = h.Context + } f.LoginWasUsed = h.WasHandled f.LoginAuthenticatedAt = h.AuthenticatedAt return nil @@ -388,6 +390,10 @@ func (f *Flow) HandleConsentRequest(r *consent.AcceptOAuth2ConsentRequest) error f.ConsentWasHandled = r.WasHandled f.ConsentError = r.Error + if r.Context != nil { + f.Context = r.Context + } + if r.Session != nil { f.SessionIDToken = r.Session.IDToken f.SessionAccessToken = r.Session.AccessToken @@ -453,6 +459,7 @@ func (f *Flow) GetHandledConsentRequest() *consent.AcceptOAuth2ConsentRequest { AuthenticatedAt: f.LoginAuthenticatedAt, SessionIDToken: f.SessionIDToken, SessionAccessToken: f.SessionAccessToken, + Context: f.Context, } } diff --git a/flow/flow_test.go b/flow/flow_test.go index c00e7524b2e..97cee4c5fc9 100644 --- a/flow/flow_test.go +++ b/flow/flow_test.go @@ -88,6 +88,7 @@ func (f *Flow) setHandledConsentRequest(r consent.AcceptOAuth2ConsentRequest) { f.LoginAuthenticatedAt = r.AuthenticatedAt f.SessionIDToken = r.SessionIDToken f.SessionAccessToken = r.SessionAccessToken + f.Context = r.Context } func TestFlow_GetLoginRequest(t *testing.T) {