Skip to content

Commit

Permalink
fix: wrap db calls in transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
nsklikas committed Nov 6, 2024
1 parent 5b6cc1f commit a5bb44b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 25 deletions.
3 changes: 3 additions & 0 deletions consent/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package consent
import (
"context"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"

"github.com/ory/hydra/v2/client"
Expand Down Expand Up @@ -65,6 +66,8 @@ type (
GetDeviceUserAuthRequest(ctx context.Context, challenge string) (*flow.DeviceUserAuthRequest, error)
HandleDeviceUserAuthRequest(ctx context.Context, f *flow.Flow, challenge string, r *flow.HandledDeviceUserAuthRequest) (*flow.DeviceUserAuthRequest, error)
VerifyAndInvalidateDeviceUserAuthRequest(ctx context.Context, verifier string) (*flow.HandledDeviceUserAuthRequest, error)

Transaction(context.Context, func(ctx context.Context, c *pop.Connection) error) error
}

ManagerProvider interface {
Expand Down
54 changes: 31 additions & 23 deletions consent/strategy_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"strings"
"time"

"github.com/gobuffalo/pop/v6"
"github.com/gorilla/sessions"
"github.com/hashicorp/go-retryablehttp"
"github.com/pborman/uuid"
Expand All @@ -39,8 +40,6 @@ import (
"github.com/ory/x/urlx"
)

type ctxKey int

const (
DeviceVerificationPath = "/oauth2/device/verify"
CookieAuthenticationSIDName = "sid"
Expand Down Expand Up @@ -1159,21 +1158,11 @@ func (s *DefaultStrategy) HandleOAuth2AuthorizationRequest(
ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.HandleOAuth2AuthorizationRequest")
defer otelx.End(span, &err)

return s.handleOAuth2AuthorizationRequest(ctx, w, r, req, nil)
}

func (s *DefaultStrategy) handleOAuth2AuthorizationRequest(
ctx context.Context,
w http.ResponseWriter,
r *http.Request,
req fosite.AuthorizeRequester,
f *flow.Flow,
) (_ *flow.AcceptOAuth2ConsentRequest, _ *flow.Flow, err error) {
loginVerifier := strings.TrimSpace(r.URL.Query().Get("login_verifier"))
consentVerifier := strings.TrimSpace(r.URL.Query().Get("consent_verifier"))
if loginVerifier == "" && consentVerifier == "" {
// ok, we need to process this request and redirect to the original endpoint
return nil, nil, s.requestAuthentication(ctx, w, r, req, f)
return nil, nil, s.requestAuthentication(ctx, w, r, req, nil)
} else if loginVerifier != "" {
f, err := s.verifyAuthentication(ctx, w, r, req, loginVerifier)
if err != nil {
Expand All @@ -1197,7 +1186,10 @@ func (s *DefaultStrategy) HandleOAuth2DeviceAuthorizationRequest(
ctx context.Context,
w http.ResponseWriter,
r *http.Request,
) (*flow.AcceptOAuth2ConsentRequest, *flow.Flow, error) {
) (_ *flow.AcceptOAuth2ConsentRequest, _ *flow.Flow, err error) {
ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.HandleOAuth2DeviceAuthorizationRequest")
defer otelx.End(span, &err)

deviceVerifier := strings.TrimSpace(r.URL.Query().Get("device_verifier"))
loginVerifier := strings.TrimSpace(r.URL.Query().Get("login_verifier"))
consentVerifier := strings.TrimSpace(r.URL.Query().Get("consent_verifier"))
Expand Down Expand Up @@ -1235,16 +1227,32 @@ func (s *DefaultStrategy) HandleOAuth2DeviceAuthorizationRequest(
ar.RequestedAudience = fosite.Arguments(deviceFlow.RequestedAudience)
}

// TODO(nsklikas): wrap these 2 function calls in a transaction (one persists the flow and the other invalidates the user_code)
consentSession, f, err := s.handleOAuth2AuthorizationRequest(ctx, w, r, ar, deviceFlow)
if err != nil {
return nil, nil, err
}
err = s.r.OAuth2Storage().UpdateAndInvalidateUserCodeSessionByRequestID(r.Context(), string(f.DeviceCodeRequestID), f.ID)
if err != nil {
return nil, nil, err
if loginVerifier == "" && consentVerifier == "" {
// ok, we need to process this request and redirect to the authentication endpoint
return nil, nil, s.requestAuthentication(ctx, w, r, ar, deviceFlow)
} else if loginVerifier != "" {
f, err := s.verifyAuthentication(ctx, w, r, ar, loginVerifier)
if err != nil {
return nil, nil, err
}

// ok, we need to process this request and redirect to consent endpoint
return nil, f, s.requestConsent(ctx, w, r, ar, f)
}

var consentSession *flow.AcceptOAuth2ConsentRequest
var f *flow.Flow

err = s.r.ConsentManager().Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
consentSession, f, err = s.verifyConsent(ctx, w, r, consentVerifier)
if err != nil {
return err
}
err = s.r.OAuth2Storage().UpdateAndInvalidateUserCodeSessionByRequestID(ctx, string(f.DeviceCodeRequestID), f.ID)

return err
})

return consentSession, f, err
}

Expand Down Expand Up @@ -1325,7 +1333,7 @@ func (s *DefaultStrategy) forwardDeviceRequest(ctx context.Context, w http.Respo
}

func (s *DefaultStrategy) verifyDevice(ctx context.Context, _ http.ResponseWriter, r *http.Request, verifier string) (_ *flow.Flow, err error) {
ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.verifyAuthentication")
ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.verifyDevice")
defer otelx.End(span, &err)

// We decode the flow from the cookie again because VerifyAndInvalidateDeviceRequest does not return the flow
Expand Down
2 changes: 0 additions & 2 deletions oauth2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -747,14 +747,12 @@ func (h *Handler) performOAuth2DeviceVerificationFlow(w http.ResponseWriter, r *
return
}

// TODO(nsklikas): We need to add a db transaction here
req, err := h.r.OAuth2Storage().GetDeviceCodeSessionByRequestID(ctx, f.DeviceCodeRequestID.String(), &Session{})
if err != nil {
x.LogError(r, err, h.r.Logger())
h.r.Writer().WriteError(w, r, err)
return
}
// TODO(nsklika): Can we refactor this so we don't have to pass in the session?
session, err := h.updateSessionWithRequest(ctx, consentSession, f, r, req, req.GetSession().(*Session))
if err != nil {
h.r.Writer().WriteError(w, r, err)
Expand Down

0 comments on commit a5bb44b

Please sign in to comment.