Skip to content

Commit

Permalink
fix: wrap token handler in transaction (#3730)
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl authored Mar 5, 2024
1 parent b47942c commit 67a85cc
Show file tree
Hide file tree
Showing 11 changed files with 495 additions and 77 deletions.
16 changes: 5 additions & 11 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,6 @@ go 1.21

toolchain go1.21.0

replace (
github.com/jackc/pcmock => github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65
github.com/jackc/pgconn => github.com/jackc/pgconn v1.14.1
github.com/mattn/go-sqlite3 => github.com/mattn/go-sqlite3 v1.14.16
)

replace github.com/ory/hydra-client-go/v2 => ./internal/httpclient

require (
Expand Down Expand Up @@ -44,7 +38,7 @@ require (
github.com/ory/hydra-client-go/v2 v2.1.1
github.com/ory/jsonschema/v3 v3.0.8
github.com/ory/kratos-client-go v0.13.1
github.com/ory/x v0.0.612-0.20240130132700-6275e3f1ad0d
github.com/ory/x v0.0.616
github.com/pborman/uuid v1.2.1
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.16.0
Expand All @@ -66,7 +60,7 @@ require (
go.opentelemetry.io/otel/sdk v1.21.0
go.opentelemetry.io/otel/trace v1.21.0
go.uber.org/automaxprocs v1.5.3
golang.org/x/crypto v0.17.0
golang.org/x/crypto v0.21.0
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
golang.org/x/oauth2 v0.14.0
golang.org/x/sync v0.5.0
Expand Down Expand Up @@ -181,7 +175,7 @@ require (
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/mattn/go-sqlite3 v1.14.16 // indirect
github.com/mattn/goveralls v0.0.12 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/microcosm-cc/bluemonday v1.0.26 // indirect
Expand Down Expand Up @@ -234,8 +228,8 @@ require (
go.opentelemetry.io/otel/metric v1.21.0 // indirect
go.opentelemetry.io/proto/otlp v1.0.0 // indirect
golang.org/x/mod v0.14.0 // indirect
golang.org/x/net v0.18.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/net v0.21.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect
google.golang.org/appengine v1.6.8 // indirect
Expand Down
73 changes: 65 additions & 8 deletions go.sum

Large diffs are not rendered by default.

347 changes: 347 additions & 0 deletions internal/httpclient/go.sum

Large diffs are not rendered by default.

95 changes: 50 additions & 45 deletions oauth2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package oauth2

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
Expand All @@ -13,6 +14,7 @@ import (
"strings"
"time"

"github.com/gobuffalo/pop/v6"
"github.com/tidwall/gjson"

"github.com/pborman/uuid"
Expand Down Expand Up @@ -958,68 +960,71 @@ func (h *Handler) oauth2TokenExchange(w http.ResponseWriter, r *http.Request) {
return
}

if accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeClientCredentials)) ||
accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeJWTBearer)) {
var accessTokenKeyID string
if h.c.AccessTokenStrategy(ctx, client.AccessTokenStrategySource(accessRequest.GetClient())) == "jwt" {
accessTokenKeyID, err = h.r.AccessTokenJWTStrategy().GetPublicKeyID(ctx)
if err != nil {
x.LogError(r, err, h.r.Logger())
h.r.OAuth2Provider().WriteAccessError(ctx, w, accessRequest, err)
events.Trace(ctx, events.TokenExchangeError, events.WithRequest(accessRequest))
return
err = h.r.Persister().Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error {
var err error

if accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeClientCredentials)) ||
accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeJWTBearer)) {
var accessTokenKeyID string
if h.c.AccessTokenStrategy(ctx, client.AccessTokenStrategySource(accessRequest.GetClient())) == "jwt" {
accessTokenKeyID, err = h.r.AccessTokenJWTStrategy().GetPublicKeyID(ctx)
if err != nil {
return err
}
}
}

// only for client_credentials, otherwise Authentication is included in session
if accessRequest.GetGrantTypes().ExactOne("client_credentials") {
session.Subject = accessRequest.GetClient().GetID()
}
session.ClientID = accessRequest.GetClient().GetID()
session.KID = accessTokenKeyID
session.DefaultSession.Claims.Issuer = h.c.IssuerURL(r.Context()).String()
session.DefaultSession.Claims.IssuedAt = time.Now().UTC()

scopes := accessRequest.GetRequestedScopes()
// only for client_credentials, otherwise Authentication is included in session
if accessRequest.GetGrantTypes().ExactOne("client_credentials") {
session.Subject = accessRequest.GetClient().GetID()
}
session.ClientID = accessRequest.GetClient().GetID()
session.KID = accessTokenKeyID
session.DefaultSession.Claims.Issuer = h.c.IssuerURL(r.Context()).String()
session.DefaultSession.Claims.IssuedAt = time.Now().UTC()

scopes := accessRequest.GetRequestedScopes()

// Added for compatibility with MITREid
if h.c.GrantAllClientCredentialsScopesPerDefault(r.Context()) && len(scopes) == 0 {
for _, scope := range accessRequest.GetClient().GetScopes() {
accessRequest.GrantScope(scope)
}
}

// Added for compatibility with MITREid
if h.c.GrantAllClientCredentialsScopesPerDefault(r.Context()) && len(scopes) == 0 {
for _, scope := range accessRequest.GetClient().GetScopes() {
accessRequest.GrantScope(scope)
for _, scope := range scopes {
if h.r.Config().GetScopeStrategy(ctx)(accessRequest.GetClient().GetScopes(), scope) {
accessRequest.GrantScope(scope)
}
}
}

for _, scope := range scopes {
if h.r.Config().GetScopeStrategy(ctx)(accessRequest.GetClient().GetScopes(), scope) {
accessRequest.GrantScope(scope)
for _, audience := range accessRequest.GetRequestedAudience() {
if h.r.AudienceStrategy()(accessRequest.GetClient().GetAudience(), []string{audience}) == nil {
accessRequest.GrantAudience(audience)
}
}
}

for _, audience := range accessRequest.GetRequestedAudience() {
if h.r.AudienceStrategy()(accessRequest.GetClient().GetAudience(), []string{audience}) == nil {
accessRequest.GrantAudience(audience)
for _, hook := range h.r.AccessRequestHooks() {
if err = hook(ctx, accessRequest); err != nil {
return err
}
}
}

for _, hook := range h.r.AccessRequestHooks() {
if err := hook(ctx, accessRequest); err != nil {
h.logOrAudit(err, r)
h.r.OAuth2Provider().WriteAccessError(ctx, w, accessRequest, err)
events.Trace(ctx, events.TokenExchangeError, events.WithRequest(accessRequest))
return
accessResponse, err := h.r.OAuth2Provider().NewAccessResponse(ctx, accessRequest)
if err != nil {
return err
}
}

accessResponse, err := h.r.OAuth2Provider().NewAccessResponse(ctx, accessRequest)
h.r.OAuth2Provider().WriteAccessResponse(ctx, w, accessRequest, accessResponse)

return nil
})

if err != nil {
h.logOrAudit(err, r)
h.r.OAuth2Provider().WriteAccessError(ctx, w, accessRequest, err)
events.Trace(ctx, events.TokenExchangeError, events.WithRequest(accessRequest))
return
events.Trace(ctx, events.TokenExchangeError)
}

h.r.OAuth2Provider().WriteAccessResponse(ctx, w, accessRequest, accessResponse)
}

// swagger:route GET /oauth2/auth oAuth2 oAuth2Authorize
Expand Down
2 changes: 2 additions & 0 deletions oauth2/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/ory/hydra/v2/consent"
"github.com/ory/hydra/v2/jwk"
"github.com/ory/hydra/v2/oauth2/trust"
"github.com/ory/hydra/v2/persistence"
"github.com/ory/hydra/v2/x"
)

Expand All @@ -21,6 +22,7 @@ type InternalRegistry interface {
x.RegistryWriter
x.RegistryLogger
consent.Registry
persistence.Provider
Registry
FlowCipher() *aead.XChaCha20Poly1305
}
Expand Down
1 change: 1 addition & 0 deletions persistence/definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type (
MigrateUp(context.Context) error
PrepareMigration(context.Context) error
Connection(context.Context) *pop.Connection
Transaction(context.Context, func(ctx context.Context, c *pop.Connection) error) error
Ping() error
Networker
}
Expand Down
20 changes: 16 additions & 4 deletions persistence/sql/persister.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,14 @@ var _ persistence.Persister = new(Persister)
var _ storage.Transactional = new(Persister)

var (
ErrTransactionOpen = errors.New("There is already a transaction in this context.")
ErrNoTransactionOpen = errors.New("There is no transaction in this context.")
ErrTransactionOpen = errors.New("There is already a Transaction in this context.")
ErrNoTransactionOpen = errors.New("There is no Transaction in this context.")
)

type skipCommitContextKey int

const skipCommitKey skipCommitContextKey = 0

type (
Persister struct {
conn *pop.Connection
Expand Down Expand Up @@ -65,7 +69,7 @@ func (p *Persister) BeginTX(ctx context.Context) (_ context.Context, err error)

fallback := &pop.Connection{TX: &pop.Tx{}}
if popx.GetConnection(ctx, fallback).TX != fallback.TX {
return ctx, errorsx.WithStack(ErrTransactionOpen)
return context.WithValue(ctx, skipCommitKey, true), nil // no-op
}

tx, err := p.conn.Store.TransactionContextOptions(ctx, &sql.TxOptions{
Expand All @@ -85,6 +89,10 @@ func (p *Persister) Commit(ctx context.Context) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Commit")
defer otelx.End(span, &err)

if skip, ok := ctx.Value(skipCommitKey).(bool); ok && skip {
return nil // we skipped BeginTX, so we also skip Commit
}

fallback := &pop.Connection{TX: &pop.Tx{}}
tx := popx.GetConnection(ctx, fallback)
if tx.TX == fallback.TX || tx.TX == nil {
Expand All @@ -98,6 +106,10 @@ func (p *Persister) Rollback(ctx context.Context) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Rollback")
defer otelx.End(span, &err)

if skip, ok := ctx.Value(skipCommitKey).(bool); ok && skip {
return nil // we skipped BeginTX, so we also skip Rollback
}

fallback := &pop.Connection{TX: &pop.Tx{}}
tx := popx.GetConnection(ctx, fallback)
if tx.TX == fallback.TX || tx.TX == nil {
Expand Down Expand Up @@ -184,6 +196,6 @@ func (p *Persister) mustSetNetwork(nid uuid.UUID, v interface{}) interface{} {
return v
}

func (p *Persister) transaction(ctx context.Context, f func(ctx context.Context, c *pop.Connection) error) error {
func (p *Persister) Transaction(ctx context.Context, f func(ctx context.Context, c *pop.Connection) error) error {
return popx.Transaction(ctx, p.conn, f)
}
2 changes: 1 addition & 1 deletion persistence/sql/persister_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (p *Persister) UpdateClient(ctx context.Context, cl *client.Client) (err er
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateClient")
defer otelx.End(span, &err)

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
o, err := p.GetConcreteClient(ctx, cl.GetID())
if err != nil {
return err
Expand Down
6 changes: 3 additions & 3 deletions persistence/sql/persister_consent.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ func (p *Persister) RevokeSubjectConsentSession(ctx context.Context, user string
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectConsentSession")
defer span.End()

return p.transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ?", user))
return p.Transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ?", user))
}

func (p *Persister) RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectClientConsentSession")
defer span.End()

return p.transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ? AND client_id = ?", user, client))
return p.Transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ? AND client_id = ?", user, client))
}

func (p *Persister) revokeConsentSession(whereStmt string, whereArgs ...interface{}) func(context.Context, *pop.Connection) error {
Expand Down Expand Up @@ -117,7 +117,7 @@ func (p *Persister) CreateForcedObfuscatedLoginSession(ctx context.Context, sess
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateForcedObfuscatedLoginSession")
defer span.End()

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
nid := p.NetworkID(ctx)
if err := c.RawQuery(
"DELETE FROM hydra_oauth2_obfuscated_authentication_session WHERE nid = ? AND client_id = ? AND subject = ?",
Expand Down
4 changes: 2 additions & 2 deletions persistence/sql/persister_grant_jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (p *Persister) CreateGrant(ctx context.Context, g trust.Grant, publicKey jo
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateGrant")
defer otelx.End(span, &err)

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
// add key, if it doesn't exist
if _, err := p.GetKey(ctx, g.PublicKey.Set, g.PublicKey.KeyID); err != nil {
if !errors.Is(err, sqlcon.ErrNoRows) {
Expand Down Expand Up @@ -59,7 +59,7 @@ func (p *Persister) DeleteGrant(ctx context.Context, id string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteGrant")
defer otelx.End(span, &err)

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
grant, err := p.GetConcreteGrant(ctx, id)
if err != nil {
return sqlcon.HandleError(err)
Expand Down
6 changes: 3 additions & 3 deletions persistence/sql/persister_jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (p *Persister) AddKeySet(ctx context.Context, set string, keys *jose.JSONWe
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.AddKey")
defer span.End()

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
for _, key := range keys.Keys {
out, err := json.Marshal(key)
if err != nil {
Expand Down Expand Up @@ -94,7 +94,7 @@ func (p *Persister) UpdateKey(ctx context.Context, set string, key *jose.JSONWeb
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateKey")
defer span.End()

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
if err := p.DeleteKey(ctx, set, key.KeyID); err != nil {
return err
}
Expand All @@ -110,7 +110,7 @@ func (p *Persister) UpdateKeySet(ctx context.Context, set string, keySet *jose.J
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateKeySet")
defer span.End()

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
if err := p.DeleteKeySet(ctx, set); err != nil {
return err
}
Expand Down

0 comments on commit 67a85cc

Please sign in to comment.