Skip to content

Commit

Permalink
fix: restore fallback querying by raw signature for access tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr committed Aug 7, 2023
1 parent 8381da3 commit 3d10f57
Showing 1 changed file with 66 additions and 26 deletions.
92 changes: 66 additions & 26 deletions persistence/sql/persister_oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func (p *Persister) SetClientAssertionJWTRaw(ctx context.Context, jti *oauth2.Bl
return sqlcon.HandleError(p.CreateWithNetwork(ctx, jti))
}

func (p *Persister) createSession(ctx context.Context, signature string, requester fosite.Requester, table tableName) (err error) {
func (p *Persister) createSession(ctx context.Context, signature string, requester fosite.Requester, table tableName) error {
req, err := p.sqlSchemaFromRequest(ctx, signature, requester, table)
if err != nil {
return err
Expand All @@ -228,21 +228,21 @@ func (p *Persister) createSession(ctx context.Context, signature string, request
return nil
}

func (p *Persister) findSessionBySignature(ctx context.Context, signature string, session fosite.Session, table tableName) (_ fosite.Requester, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findSessionBySignature")
defer otelx.End(span, &err)

func (p *Persister) findSessionBySignature(ctx context.Context, signature string, session fosite.Session, table tableName) (fosite.Requester, error) {
r := OAuth2RequestSQL{Table: table}
err = p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r)
err := p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r)
if errors.Is(err, sql.ErrNoRows) {
return nil, errorsx.WithStack(fosite.ErrNotFound)
} else if err != nil {
}
if err != nil {
return nil, sqlcon.HandleError(err)
} else if !r.Active {
}
if !r.Active {
fr, err := r.toRequest(ctx, session, p)
if err != nil {
return nil, err
} else if table == sqlTableCode {
}
if table == sqlTableCode {
return fr, errorsx.WithStack(fosite.ErrInvalidatedAuthorizeCode)
}
return fr, errorsx.WithStack(fosite.ErrInactiveToken)
Expand All @@ -251,38 +251,35 @@ func (p *Persister) findSessionBySignature(ctx context.Context, signature string
return r.toRequest(ctx, session, p)
}

func (p *Persister) deleteSessionBySignature(ctx context.Context, signature string, table tableName) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionBySignature")
defer otelx.End(span, &err)

err = sqlcon.HandleError(
func (p *Persister) deleteSessionBySignature(ctx context.Context, signature string, table tableName) error {
err := sqlcon.HandleError(
p.QueryWithNetwork(ctx).
Where("signature = ?", signature).
Delete(&OAuth2RequestSQL{Table: table}))

if errors.Is(err, sqlcon.ErrNoRows) {
return errorsx.WithStack(fosite.ErrNotFound)
} else if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
}
if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
} else if err != nil {
return err
}
return nil
return err
}

func (p *Persister) deleteSessionByRequestID(ctx context.Context, id string, table tableName) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionByRequestID")
defer otelx.End(span, &err)

/* #nosec G201 table is static */
if err := p.QueryWithNetwork(ctx).
err = p.QueryWithNetwork(ctx).
Where("request_id=?", id).
Delete(&OAuth2RequestSQL{Table: table}); errors.Is(err, sql.ErrNoRows) {
Delete(&OAuth2RequestSQL{Table: table})
if errors.Is(err, sql.ErrNoRows) {
return errorsx.WithStack(fosite.ErrNotFound)
} else if err := sqlcon.HandleError(err); err != nil {
}
if err := sqlcon.HandleError(err); err != nil {
if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
} else if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock?
}
if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock?
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
}
return err
Expand Down Expand Up @@ -355,13 +352,56 @@ func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature stri
func (p *Persister) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetAccessTokenSession")
defer otelx.End(span, &err)
return p.findSessionBySignature(ctx, SignatureHash(signature), session, sqlTableAccess)

r := OAuth2RequestSQL{Table: sqlTableAccess}
err = p.QueryWithNetwork(ctx).Where("signature = ?", SignatureHash(signature)).First(&r)
if errors.Is(err, sql.ErrNoRows) {
// Backwards compatibility: we previously did not always hash the
// signature before inserting. In case there are still very old (but
// valid) access tokens in the database, this should get them.
err = p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r)
if errors.Is(err, sql.ErrNoRows) {
return nil, errorsx.WithStack(fosite.ErrNotFound)
}
}
if err != nil {
return nil, sqlcon.HandleError(err)
}
if !r.Active {
fr, err := r.toRequest(ctx, session, p)
if err != nil {
return nil, err
}
return fr, errorsx.WithStack(fosite.ErrInactiveToken)
}

return r.toRequest(ctx, session, p)
}

func (p *Persister) DeleteAccessTokenSession(ctx context.Context, signature string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteAccessTokenSession")
defer otelx.End(span, &err)
return p.deleteSessionBySignature(ctx, SignatureHash(signature), sqlTableAccess)

err = sqlcon.HandleError(
p.QueryWithNetwork(ctx).
Where("signature = ?", SignatureHash(signature)).
Delete(&OAuth2RequestSQL{Table: sqlTableAccess}))
if errors.Is(err, sqlcon.ErrNoRows) {
// Backwards compatibility: we previously did not always hash the
// signature before inserting. In case there are still very old (but
// valid) access tokens in the database, this should get them.
err = sqlcon.HandleError(
p.QueryWithNetwork(ctx).
Where("signature = ?", signature).
Delete(&OAuth2RequestSQL{Table: sqlTableAccess}))
if errors.Is(err, sqlcon.ErrNoRows) {
return errorsx.WithStack(fosite.ErrNotFound)
}
}
if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
}
return err
}

func toEventOptions(requester fosite.Requester) []trace.EventOption {
Expand Down

0 comments on commit 3d10f57

Please sign in to comment.