From a5bb44b3e3138c9d607f99c52f7fd2efbfbd99a3 Mon Sep 17 00:00:00 2001 From: Nikos Date: Tue, 30 Jul 2024 14:56:56 +0300 Subject: [PATCH] fix: wrap db calls in transaction --- consent/manager.go | 3 +++ consent/strategy_default.go | 54 +++++++++++++++++++++---------------- oauth2/handler.go | 2 -- 3 files changed, 34 insertions(+), 25 deletions(-) diff --git a/consent/manager.go b/consent/manager.go index f09c803c06b..577fffa27f1 100644 --- a/consent/manager.go +++ b/consent/manager.go @@ -6,6 +6,7 @@ package consent import ( "context" + "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/ory/hydra/v2/client" @@ -65,6 +66,8 @@ type ( GetDeviceUserAuthRequest(ctx context.Context, challenge string) (*flow.DeviceUserAuthRequest, error) HandleDeviceUserAuthRequest(ctx context.Context, f *flow.Flow, challenge string, r *flow.HandledDeviceUserAuthRequest) (*flow.DeviceUserAuthRequest, error) VerifyAndInvalidateDeviceUserAuthRequest(ctx context.Context, verifier string) (*flow.HandledDeviceUserAuthRequest, error) + + Transaction(context.Context, func(ctx context.Context, c *pop.Connection) error) error } ManagerProvider interface { diff --git a/consent/strategy_default.go b/consent/strategy_default.go index a4a11ac38c5..a17a6412b1e 100644 --- a/consent/strategy_default.go +++ b/consent/strategy_default.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/gobuffalo/pop/v6" "github.com/gorilla/sessions" "github.com/hashicorp/go-retryablehttp" "github.com/pborman/uuid" @@ -39,8 +40,6 @@ import ( "github.com/ory/x/urlx" ) -type ctxKey int - const ( DeviceVerificationPath = "/oauth2/device/verify" CookieAuthenticationSIDName = "sid" @@ -1159,21 +1158,11 @@ func (s *DefaultStrategy) HandleOAuth2AuthorizationRequest( ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.HandleOAuth2AuthorizationRequest") defer otelx.End(span, &err) - return s.handleOAuth2AuthorizationRequest(ctx, w, r, req, nil) -} - -func (s *DefaultStrategy) handleOAuth2AuthorizationRequest( - ctx context.Context, - w http.ResponseWriter, - r *http.Request, - req fosite.AuthorizeRequester, - f *flow.Flow, -) (_ *flow.AcceptOAuth2ConsentRequest, _ *flow.Flow, err error) { loginVerifier := strings.TrimSpace(r.URL.Query().Get("login_verifier")) consentVerifier := strings.TrimSpace(r.URL.Query().Get("consent_verifier")) if loginVerifier == "" && consentVerifier == "" { // ok, we need to process this request and redirect to the original endpoint - return nil, nil, s.requestAuthentication(ctx, w, r, req, f) + return nil, nil, s.requestAuthentication(ctx, w, r, req, nil) } else if loginVerifier != "" { f, err := s.verifyAuthentication(ctx, w, r, req, loginVerifier) if err != nil { @@ -1197,7 +1186,10 @@ func (s *DefaultStrategy) HandleOAuth2DeviceAuthorizationRequest( ctx context.Context, w http.ResponseWriter, r *http.Request, -) (*flow.AcceptOAuth2ConsentRequest, *flow.Flow, error) { +) (_ *flow.AcceptOAuth2ConsentRequest, _ *flow.Flow, err error) { + ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.HandleOAuth2DeviceAuthorizationRequest") + defer otelx.End(span, &err) + deviceVerifier := strings.TrimSpace(r.URL.Query().Get("device_verifier")) loginVerifier := strings.TrimSpace(r.URL.Query().Get("login_verifier")) consentVerifier := strings.TrimSpace(r.URL.Query().Get("consent_verifier")) @@ -1235,16 +1227,32 @@ func (s *DefaultStrategy) HandleOAuth2DeviceAuthorizationRequest( ar.RequestedAudience = fosite.Arguments(deviceFlow.RequestedAudience) } - // TODO(nsklikas): wrap these 2 function calls in a transaction (one persists the flow and the other invalidates the user_code) - consentSession, f, err := s.handleOAuth2AuthorizationRequest(ctx, w, r, ar, deviceFlow) - if err != nil { - return nil, nil, err - } - err = s.r.OAuth2Storage().UpdateAndInvalidateUserCodeSessionByRequestID(r.Context(), string(f.DeviceCodeRequestID), f.ID) - if err != nil { - return nil, nil, err + if loginVerifier == "" && consentVerifier == "" { + // ok, we need to process this request and redirect to the authentication endpoint + return nil, nil, s.requestAuthentication(ctx, w, r, ar, deviceFlow) + } else if loginVerifier != "" { + f, err := s.verifyAuthentication(ctx, w, r, ar, loginVerifier) + if err != nil { + return nil, nil, err + } + + // ok, we need to process this request and redirect to consent endpoint + return nil, f, s.requestConsent(ctx, w, r, ar, f) } + var consentSession *flow.AcceptOAuth2ConsentRequest + var f *flow.Flow + + err = s.r.ConsentManager().Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + consentSession, f, err = s.verifyConsent(ctx, w, r, consentVerifier) + if err != nil { + return err + } + err = s.r.OAuth2Storage().UpdateAndInvalidateUserCodeSessionByRequestID(ctx, string(f.DeviceCodeRequestID), f.ID) + + return err + }) + return consentSession, f, err } @@ -1325,7 +1333,7 @@ func (s *DefaultStrategy) forwardDeviceRequest(ctx context.Context, w http.Respo } func (s *DefaultStrategy) verifyDevice(ctx context.Context, _ http.ResponseWriter, r *http.Request, verifier string) (_ *flow.Flow, err error) { - ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.verifyAuthentication") + ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.verifyDevice") defer otelx.End(span, &err) // We decode the flow from the cookie again because VerifyAndInvalidateDeviceRequest does not return the flow diff --git a/oauth2/handler.go b/oauth2/handler.go index 4822e81653b..b7250035e8c 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -747,14 +747,12 @@ func (h *Handler) performOAuth2DeviceVerificationFlow(w http.ResponseWriter, r * return } - // TODO(nsklikas): We need to add a db transaction here req, err := h.r.OAuth2Storage().GetDeviceCodeSessionByRequestID(ctx, f.DeviceCodeRequestID.String(), &Session{}) if err != nil { x.LogError(r, err, h.r.Logger()) h.r.Writer().WriteError(w, r, err) return } - // TODO(nsklika): Can we refactor this so we don't have to pass in the session? session, err := h.updateSessionWithRequest(ctx, consentSession, f, r, req, req.GetSession().(*Session)) if err != nil { h.r.Writer().WriteError(w, r, err)