diff --git a/identity/handler.go b/identity/handler.go index cf85dc792c43..c4f90cec7619 100644 --- a/identity/handler.go +++ b/identity/handler.go @@ -617,13 +617,22 @@ func (h *Handler) batchPatchIdentities(w http.ResponseWriter, r *http.Request, _ } } - if err := h.r.IdentityManager().CreateIdentities(r.Context(), identities); err != nil { + err := h.r.IdentityManager().CreateIdentities(r.Context(), identities) + partialErr := new(CreateIdentitiesError) + if err != nil && !errors.As(err, &partialErr) { h.r.Writer().WriteError(w, r, err) return } for resIdx, identitiesIdx := range indexInIdentities { if identitiesIdx != nil { - res.Identities[resIdx].IdentityID = &identities[*identitiesIdx].ID + ident := identities[*identitiesIdx] + // Check if the identity was created successfully. + if failed := partialErr.Find(ident); failed != nil { + res.Identities[resIdx].Action = ActionError + res.Identities[resIdx].Error = failed.Error + } else { + res.Identities[resIdx].IdentityID = &ident.ID + } } } diff --git a/identity/handler_test.go b/identity/handler_test.go index 2f6714137411..33d6a46adf87 100644 --- a/identity/handler_test.go +++ b/identity/handler_test.go @@ -759,53 +759,63 @@ func TestHandler(t *testing.T) { assert.Contains(t, res.Get("error.reason").String(), strconv.Itoa(identity.BatchPatchIdentitiesLimit), "the error reason should contain the limit") }) - t.Run("case=fails all on a bad identity", func(t *testing.T) { + t.Run("case=fails some on a bad identity", func(t *testing.T) { // Test setup: we have a list of valid identitiy patches and a list of invalid ones. // Each run adds one invalid patch to the list and sends it to the server. - // --> we expectedIdentifiers the server to fail all patches in the list. + // --> we expect the server to fail all patches in the list. // Finally, we send just the valid patches - // --> we expectedIdentifiers the server to succeed all patches in the list. - validPatches := []*identity.BatchIdentityPatch{ - {Create: validCreateIdentityBody("valid-patch", 0)}, - {Create: validCreateIdentityBody("valid-patch", 1)}, - {Create: validCreateIdentityBody("valid-patch", 2)}, - {Create: validCreateIdentityBody("valid-patch", 3)}, - {Create: validCreateIdentityBody("valid-patch", 4)}, - } + // --> we expect the server to succeed all patches in the list. + + t.Run("case=invalid patches fail", func(t *testing.T) { + patches := []*identity.BatchIdentityPatch{ + {Create: validCreateIdentityBody("valid", 0)}, + {Create: validCreateIdentityBody("valid", 1)}, + {Create: &identity.CreateIdentityBody{}}, // <-- invalid: missing all fields + {Create: validCreateIdentityBody("valid", 2)}, + {Create: validCreateIdentityBody("valid", 0)}, // <-- duplicate + {Create: validCreateIdentityBody("valid", 3)}, + {Create: &identity.CreateIdentityBody{Traits: json.RawMessage(`"invalid traits"`)}}, // <-- invalid traits + {Create: validCreateIdentityBody("valid", 4)}, + } - for _, tt := range []struct { - name string - body *identity.CreateIdentityBody - expectStatus int - }{ - { - name: "missing all fields", - body: &identity.CreateIdentityBody{}, - expectStatus: http.StatusBadRequest, - }, - { - name: "duplicate identity", - body: validCreateIdentityBody("valid-patch", 0), - expectStatus: http.StatusConflict, - }, - { - name: "invalid traits", - body: &identity.CreateIdentityBody{ - Traits: json.RawMessage(`"invalid traits"`), - }, - expectStatus: http.StatusBadRequest, - }, - } { - t.Run("invalid because "+tt.name, func(t *testing.T) { - patches := append([]*identity.BatchIdentityPatch{}, validPatches...) - patches = append(patches, &identity.BatchIdentityPatch{Create: tt.body}) + // Create unique IDs for each patch + var patchIDs []string + for i, p := range patches { + id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%d", i)) + p.ID = &id + patchIDs = append(patchIDs, id.String()) + } - req := &identity.BatchPatchIdentitiesBody{Identities: patches} - send(t, adminTS, "PATCH", "/identities", tt.expectStatus, req) - }) - } + req := &identity.BatchPatchIdentitiesBody{Identities: patches} + body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req) + var actions []string + for _, a := range body.Get("identities.#.action").Array() { + actions = append(actions, a.String()) + } + assert.Equal(t, + []string{"create", "create", "error", "create", "error", "create", "error", "create"}, + actions, body) + + // Check that all patch IDs are returned + for i, gotPatchID := range body.Get("identities.#.patch_id").Array() { + assert.Equal(t, patchIDs[i], gotPatchID.String()) + } + + // Check specific errors + assert.Equal(t, "Bad Request", body.Get("identities.2.error.status").String()) + assert.Equal(t, "Conflict", body.Get("identities.4.error.status").String()) + assert.Equal(t, "Bad Request", body.Get("identities.6.error.status").String()) + + }) t.Run("valid patches succeed", func(t *testing.T) { + validPatches := []*identity.BatchIdentityPatch{ + {Create: validCreateIdentityBody("valid-patch", 0)}, + {Create: validCreateIdentityBody("valid-patch", 1)}, + {Create: validCreateIdentityBody("valid-patch", 2)}, + {Create: validCreateIdentityBody("valid-patch", 3)}, + {Create: validCreateIdentityBody("valid-patch", 4)}, + } req := &identity.BatchPatchIdentitiesBody{Identities: validPatches} send(t, adminTS, "PATCH", "/identities", http.StatusOK, req) }) diff --git a/identity/identity.go b/identity/identity.go index 0277772c4708..d21cadb36ab3 100644 --- a/identity/identity.go +++ b/identity/identity.go @@ -11,22 +11,17 @@ import ( "sync" "time" + "github.com/gofrs/uuid" + "github.com/pkg/errors" "github.com/samber/lo" - - "github.com/tidwall/sjson" - "github.com/tidwall/gjson" - - "github.com/ory/kratos/cipher" + "github.com/tidwall/sjson" "github.com/ory/herodot" + "github.com/ory/kratos/cipher" + "github.com/ory/kratos/driver/config" "github.com/ory/x/pagination/keysetpagination" "github.com/ory/x/sqlxx" - - "github.com/ory/kratos/driver/config" - - "github.com/gofrs/uuid" - "github.com/pkg/errors" ) // An Identity's State @@ -645,6 +640,9 @@ const ( // Create this identity. ActionCreate BatchPatchAction = "create" + // Error indicates that the patch failed. + ActionError BatchPatchAction = "error" + // Future actions: // // Delete this identity. @@ -677,4 +675,7 @@ type BatchIdentityPatchResponse struct { // The ID of this patch response, if an ID was specified in the patch. PatchID *uuid.UUID `json:"patch_id,omitempty"` + + // The error message, if the action was "error". + Error *herodot.DefaultError `json:"error,omitempty"` } diff --git a/identity/manager.go b/identity/manager.go index f35fd1468710..3bc5b08e0158 100644 --- a/identity/manager.go +++ b/identity/manager.go @@ -6,6 +6,7 @@ package identity import ( "context" "encoding/json" + "fmt" "reflect" "slices" "sort" @@ -323,26 +324,91 @@ func (e *ErrDuplicateCredentials) HasHints() bool { return len(e.availableCredentials) > 0 || len(e.availableOIDCProviders) > 0 || len(e.identifierHint) > 0 } +type FailedIdentity struct { + Identity *Identity + Error *herodot.DefaultError +} + +type CreateIdentitiesError struct { + failedIdentities map[*Identity]*herodot.DefaultError +} + +func (e *CreateIdentitiesError) Error() string { + e.init() + return fmt.Sprintf("create identities error: %d identities failed", len(e.failedIdentities)) +} +func (e *CreateIdentitiesError) Unwrap() []error { + e.init() + var errs []error + for _, err := range e.failedIdentities { + errs = append(errs, err) + } + return errs +} + +func (e *CreateIdentitiesError) AddFailedIdentity(ident *Identity, err *herodot.DefaultError) { + e.init() + e.failedIdentities[ident] = err +} +func (e *CreateIdentitiesError) Merge(other *CreateIdentitiesError) { + e.init() + for k, v := range other.failedIdentities { + e.failedIdentities[k] = v + } +} +func (e *CreateIdentitiesError) Contains(ident *Identity) bool { + e.init() + _, found := e.failedIdentities[ident] + return found +} +func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity { + e.init() + if err, found := e.failedIdentities[ident]; found { + return &FailedIdentity{Identity: ident, Error: err} + } + + return nil +} +func (e *CreateIdentitiesError) ErrOrNil() error { + if e.failedIdentities == nil || len(e.failedIdentities) == 0 { + return nil + } + return e +} +func (e *CreateIdentitiesError) init() { + if e.failedIdentities == nil { + e.failedIdentities = map[*Identity]*herodot.DefaultError{} + } +} + func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, opts ...ManagerOption) (err error) { ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CreateIdentities") defer otelx.End(span, &err) - for _, i := range identities { - if i.SchemaID == "" { - i.SchemaID = m.r.Config().DefaultIdentityTraitsSchemaID(ctx) + createIdentitiesError := &CreateIdentitiesError{} + validIdentities := make([]*Identity, 0, len(identities)) + for _, ident := range identities { + if ident.SchemaID == "" { + ident.SchemaID = m.r.Config().DefaultIdentityTraitsSchemaID(ctx) } o := newManagerOptions(opts) - if err := m.ValidateIdentity(ctx, i, o); err != nil { - return err + if err := m.ValidateIdentity(ctx, ident, o); err != nil { + createIdentitiesError.AddFailedIdentity(ident, herodot.ErrBadRequest.WithReasonf("%s", err).WithWrap(err)) + continue } + validIdentities = append(validIdentities, ident) } - if err := m.r.PrivilegedIdentityPool().CreateIdentities(ctx, identities...); err != nil { - return err + if err := m.r.PrivilegedIdentityPool().CreateIdentities(ctx, validIdentities...); err != nil { + if partialErr := new(CreateIdentitiesError); errors.As(err, &partialErr) { + createIdentitiesError.Merge(partialErr) + } else { + return err + } } - return nil + return createIdentitiesError.ErrOrNil() } func (m *Manager) requiresPrivilegedAccess(ctx context.Context, original, updated *Identity, o *ManagerOptions) (err error) { diff --git a/identity/pool.go b/identity/pool.go index 8a94aad3e075..86559f0a8a3f 100644 --- a/identity/pool.go +++ b/identity/pool.go @@ -61,9 +61,13 @@ type ( FindByCredentialsIdentifier(ctx context.Context, ct CredentialsType, match string) (*Identity, *Credentials, error) // DeleteIdentity removes an identity by its id. Will return an error - // if identity exists, backend connectivity is broken, or trait validation fails. + // if identity does not exists, or backend connectivity is broken. DeleteIdentity(context.Context, uuid.UUID) error + // DeleteIdentities removes identities by its id. Will return an error + // if any identity does not exists, or backend connectivity is broken. + DeleteIdentities(context.Context, []uuid.UUID) error + // UpdateVerifiableAddress updates an identity's verifiable address. UpdateVerifiableAddress(ctx context.Context, address *VerifiableAddress) error diff --git a/internal/client-go/model_identity_patch_response.go b/internal/client-go/model_identity_patch_response.go index 2ee305f7da81..f67224edad01 100644 --- a/internal/client-go/model_identity_patch_response.go +++ b/internal/client-go/model_identity_patch_response.go @@ -17,8 +17,9 @@ import ( // IdentityPatchResponse Response for a single identity patch type IdentityPatchResponse struct { - // The action for this specific patch create ActionCreate Create this identity. - Action *string `json:"action,omitempty"` + // The action for this specific patch create ActionCreate Create this identity. error ActionError Error indicates that the patch failed. + Action *string `json:"action,omitempty"` + Error interface{} `json:"error,omitempty"` // The identity ID payload of this patch Identity *string `json:"identity,omitempty"` // The ID of this patch response, if an ID was specified in the patch. @@ -74,6 +75,39 @@ func (o *IdentityPatchResponse) SetAction(v string) { o.Action = &v } +// GetError returns the Error field value if set, zero value otherwise (both if not set or set to explicit null). +func (o *IdentityPatchResponse) GetError() interface{} { + if o == nil { + var ret interface{} + return ret + } + return o.Error +} + +// GetErrorOk returns a tuple with the Error field value if set, nil otherwise +// and a boolean to check if the value has been set. +// NOTE: If the value is an explicit nil, `nil, true` will be returned +func (o *IdentityPatchResponse) GetErrorOk() (*interface{}, bool) { + if o == nil || o.Error == nil { + return nil, false + } + return &o.Error, true +} + +// HasError returns a boolean if a field has been set. +func (o *IdentityPatchResponse) HasError() bool { + if o != nil && o.Error != nil { + return true + } + + return false +} + +// SetError gets a reference to the given interface{} and assigns it to the Error field. +func (o *IdentityPatchResponse) SetError(v interface{}) { + o.Error = v +} + // GetIdentity returns the Identity field value if set, zero value otherwise. func (o *IdentityPatchResponse) GetIdentity() string { if o == nil || o.Identity == nil { @@ -143,6 +177,9 @@ func (o IdentityPatchResponse) MarshalJSON() ([]byte, error) { if o.Action != nil { toSerialize["action"] = o.Action } + if o.Error != nil { + toSerialize["error"] = o.Error + } if o.Identity != nil { toSerialize["identity"] = o.Identity } diff --git a/internal/httpclient/model_identity_patch_response.go b/internal/httpclient/model_identity_patch_response.go index 2ee305f7da81..f67224edad01 100644 --- a/internal/httpclient/model_identity_patch_response.go +++ b/internal/httpclient/model_identity_patch_response.go @@ -17,8 +17,9 @@ import ( // IdentityPatchResponse Response for a single identity patch type IdentityPatchResponse struct { - // The action for this specific patch create ActionCreate Create this identity. - Action *string `json:"action,omitempty"` + // The action for this specific patch create ActionCreate Create this identity. error ActionError Error indicates that the patch failed. + Action *string `json:"action,omitempty"` + Error interface{} `json:"error,omitempty"` // The identity ID payload of this patch Identity *string `json:"identity,omitempty"` // The ID of this patch response, if an ID was specified in the patch. @@ -74,6 +75,39 @@ func (o *IdentityPatchResponse) SetAction(v string) { o.Action = &v } +// GetError returns the Error field value if set, zero value otherwise (both if not set or set to explicit null). +func (o *IdentityPatchResponse) GetError() interface{} { + if o == nil { + var ret interface{} + return ret + } + return o.Error +} + +// GetErrorOk returns a tuple with the Error field value if set, nil otherwise +// and a boolean to check if the value has been set. +// NOTE: If the value is an explicit nil, `nil, true` will be returned +func (o *IdentityPatchResponse) GetErrorOk() (*interface{}, bool) { + if o == nil || o.Error == nil { + return nil, false + } + return &o.Error, true +} + +// HasError returns a boolean if a field has been set. +func (o *IdentityPatchResponse) HasError() bool { + if o != nil && o.Error != nil { + return true + } + + return false +} + +// SetError gets a reference to the given interface{} and assigns it to the Error field. +func (o *IdentityPatchResponse) SetError(v interface{}) { + o.Error = v +} + // GetIdentity returns the Identity field value if set, zero value otherwise. func (o *IdentityPatchResponse) GetIdentity() string { if o == nil || o.Identity == nil { @@ -143,6 +177,9 @@ func (o IdentityPatchResponse) MarshalJSON() ([]byte, error) { if o.Action != nil { toSerialize["action"] = o.Action } + if o.Error != nil { + toSerialize["error"] = o.Error + } if o.Identity != nil { toSerialize["identity"] = o.Identity } diff --git a/persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json b/persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json index 9c8e755cacd5..04c9db394e80 100644 --- a/persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json +++ b/persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json @@ -1,6 +1,6 @@ [ - "0001-01-01T00:00:00Z", - "0001-01-01T00:00:00Z", + "2023-01-01T00:00:00Z", + "2023-01-01T00:00:00Z", "string", 42, null, diff --git a/persistence/sql/batch/create.go b/persistence/sql/batch/create.go index 38254a3b2a80..30b3c9768a7a 100644 --- a/persistence/sql/batch/create.go +++ b/persistence/sql/batch/create.go @@ -8,23 +8,21 @@ import ( "database/sql" "fmt" "reflect" + "slices" "sort" "strings" "time" + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" "github.com/jmoiron/sqlx/reflectx" + "github.com/pkg/errors" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "github.com/ory/x/dbal" - - "github.com/gobuffalo/pop/v6" - "github.com/gofrs/uuid" - "github.com/pkg/errors" - "github.com/ory/x/otelx" "github.com/ory/x/sqlcon" - "github.com/ory/x/sqlxx" ) @@ -42,9 +40,32 @@ type ( Tracer *otelx.Tracer Connection *pop.Connection } + + // PartialConflictError represents a partial conflict during [Create]. It always + // wraps a [sqlcon.ErrUniqueViolation], so that the caller can either abort the + // whole transaction, or handle the partial success. + PartialConflictError[T any] struct { + Failed []*T + } ) -func buildInsertQueryArgs[T any](ctx context.Context, dialect string, mapper *reflectx.Mapper, quoter quoter, models []*T) insertQueryArgs { +func (p *PartialConflictError[T]) Error() string { + return fmt.Sprintf("partial conflict error: %d models failed to insert", len(p.Failed)) +} +func (p *PartialConflictError[T]) ErrOrNil() error { + if len(p.Failed) == 0 { + return nil + } + return p +} +func (p *PartialConflictError[T]) Unwrap() error { + if len(p.Failed) == 0 { + return nil + } + return sqlcon.ErrUniqueViolation +} + +func buildInsertQueryArgs[T any](ctx context.Context, models []*T, opts *createOpts) insertQueryArgs { var ( v T model = pop.NewModel(v, ctx) @@ -64,7 +85,7 @@ func buildInsertQueryArgs[T any](ctx context.Context, dialect string, mapper *re sort.Strings(columns) for _, col := range columns { - quotedColumns = append(quotedColumns, quoter.Quote(col)) + quotedColumns = append(quotedColumns, opts.quoter.Quote(col)) } // We generate a list (for every row one) of VALUE statements here that @@ -88,37 +109,36 @@ func buildInsertQueryArgs[T any](ctx context.Context, dialect string, mapper *re continue } - field := mapper.FieldByName(m, columns[k]) + field := opts.mapper.FieldByName(m, columns[k]) val, ok := field.Interface().(uuid.UUID) if !ok { continue } - if val == uuid.Nil && dialect == dbal.DriverCockroachDB { + if val == uuid.Nil && opts.dialect == dbal.DriverCockroachDB && !opts.partialInserts { pl[k] = "gen_random_uuid()" break } } - placeholders = append(placeholders, fmt.Sprintf("(%s)", strings.Join(pl, ", "))) } return insertQueryArgs{ - TableName: quoter.Quote(model.TableName()), + TableName: opts.quoter.Quote(model.TableName()), ColumnsDecl: strings.Join(quotedColumns, ", "), Columns: columns, Placeholders: strings.Join(placeholders, ",\n"), } } -func buildInsertQueryValues[T any](dialect string, mapper *reflectx.Mapper, columns []string, models []*T, nowFunc func() time.Time) (values []any, err error) { +func buildInsertQueryValues[T any](columns []string, models []*T, opts *createOpts) (values []any, err error) { for _, m := range models { m := reflect.ValueOf(m) - now := nowFunc() + now := opts.now() // Append model fields to args for _, c := range columns { - field := mapper.FieldByName(m, c) + field := opts.mapper.FieldByName(m, c) switch c { case "created_at": @@ -130,7 +150,7 @@ func buildInsertQueryValues[T any](dialect string, mapper *reflectx.Mapper, colu case "id": if field.Interface().(uuid.UUID) != uuid.Nil { break // breaks switch, not for - } else if dialect == dbal.DriverCockroachDB { + } else if opts.dialect == dbal.DriverCockroachDB && !opts.partialInserts { // This is a special case: // 1. We're using cockroach // 2. It's the primary key field ("ID") @@ -167,9 +187,51 @@ func buildInsertQueryValues[T any](dialect string, mapper *reflectx.Mapper, colu return values, nil } -// Create batch-inserts the given models into the database using a single INSERT statement. -// The models are either all created or none. -func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err error) { +type createOpts struct { + partialInserts bool + dialect string + mapper *reflectx.Mapper + quoter quoter + now func() time.Time +} + +type CreateOpts func(*createOpts) + +// WithPartialInserts allows to insert only the models that do not conflict with +// an existing record. WithPartialInserts will also generate the IDs for the +// models before inserting them, so that the successful inserts can be correlated +// with the input models. +// +// In particular, WithPartialInserts does not work with MySQL, because it does +// not support the "RETURNING" clause. +// +// WithPartialInserts does not work with CockroachDB and gen_random_uuid(), +// because then the successful inserts cannot be correlated with the input +// models. Note: gen_random_uuid() will skip the UNIQUE constraint check, which +// needs to hit all regions in a distributed setup. Therefore, WithPartialInserts +// should not be used to insert models for only a single identity. +var WithPartialInserts CreateOpts = func(o *createOpts) { + o.partialInserts = true +} + +func newCreateOpts(conn *pop.Connection, opts ...CreateOpts) *createOpts { + o := new(createOpts) + o.dialect = conn.Dialect.Name() + o.mapper = conn.TX.Mapper + o.quoter = conn.Dialect.(quoter) + o.now = func() time.Time { return time.Now().UTC().Truncate(time.Microsecond) } + for _, f := range opts { + f(o) + } + return o +} + +// Create batch-inserts the given models into the database using a single INSERT +// statement. By default, the models are either all created or none. If +// [WithPartialInserts] is passed as an option, partial inserts are supported, +// and the models that could not be inserted are returned in an +// [PartialConflictError]. +func Create[T any](ctx context.Context, p *TracerConnection, models []*T, opts ...CreateOpts) (err error) { ctx, span := p.Tracer.Tracer().Start(ctx, "persistence.sql.batch.Create", trace.WithAttributes(attribute.Int("count", len(models)))) defer otelx.End(span, &err) @@ -182,13 +244,10 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e model := pop.NewModel(v, ctx) conn := p.Connection - quoter, ok := conn.Dialect.(quoter) - if !ok { - return errors.Errorf("store is not a quoter: %T", conn.Store) - } + options := newCreateOpts(conn, opts...) - queryArgs := buildInsertQueryArgs(ctx, conn.Dialect.Name(), conn.TX.Mapper, quoter, models) - values, err := buildInsertQueryValues(conn.Dialect.Name(), conn.TX.Mapper, queryArgs.Columns, models, func() time.Time { return time.Now().UTC().Truncate(time.Microsecond) }) + queryArgs := buildInsertQueryArgs(ctx, models, options) + values, err := buildInsertQueryValues(queryArgs.Columns, models, options) if err != nil { return err } @@ -196,7 +255,11 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e var returningClause string if conn.Dialect.Name() != dbal.DriverMySQL { // PostgreSQL, CockroachDB, SQLite support RETURNING. - returningClause = fmt.Sprintf("RETURNING %s", model.IDField()) + if options.partialInserts { + returningClause = fmt.Sprintf("ON CONFLICT DO NOTHING RETURNING %s", model.IDField()) + } else { + returningClause = fmt.Sprintf("RETURNING %s", model.IDField()) + } } query := conn.Dialect.TranslateSQL(fmt.Sprintf( @@ -213,66 +276,84 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e } defer rows.Close() + // MySQL, which does not support RETURNING, also does not have ON CONFLICT DO + // NOTHING, meaning that MySQL will always fail the whole transaction on a single + // record conflict. + if conn.Dialect.Name() == dbal.DriverMySQL { + return nil + } + + if options.partialInserts { + return handlePartialInserts(queryArgs, values, models, rows) + } else { + return handleFullInserts(models, rows) + } + +} + +func handleFullInserts[T any](models []*T, rows *sql.Rows) error { // Hydrate the models from the RETURNING clause. - // - // Databases not supporting RETURNING will just return 0 rows. - count := 0 - for rows.Next() { - if err := rows.Err(); err != nil { - return sqlcon.HandleError(err) + for i := 0; rows.Next(); i++ { + var id uuid.UUID + if err := rows.Scan(&id); err != nil { + return errors.WithStack(err) } - - if err := setModelID(rows, pop.NewModel(models[count], ctx)); err != nil { + if err := setModelID(id, models[i]); err != nil { return err } - count++ } - if err := rows.Err(); err != nil { return sqlcon.HandleError(err) } - if err := rows.Close(); err != nil { - return sqlcon.HandleError(err) - } - - return sqlcon.HandleError(err) + return nil } -// setModelID was copy & pasted from pop. It basically sets -// the primary key to the given value read from the SQL row. -func setModelID(row *sql.Rows, model *pop.Model) error { - el := reflect.ValueOf(model.Value).Elem() - fbn := el.FieldByName("ID") - if !fbn.IsValid() { - return errors.New("model does not have a field named id") +func handlePartialInserts[T any](queryArgs insertQueryArgs, values []any, models []*T, rows *sql.Rows) error { + // Hydrate the models from the RETURNING clause. + idsInDB := make(map[uuid.UUID]struct{}) + for rows.Next() { + var id uuid.UUID + if err := rows.Scan(&id); err != nil { + return errors.WithStack(err) + } + idsInDB[id] = struct{}{} + } + if err := rows.Err(); err != nil { + return sqlcon.HandleError(err) } - pkt, err := model.PrimaryKeyType() - if err != nil { - return errors.WithStack(err) + idIdx := slices.Index(queryArgs.Columns, "id") + if idIdx == -1 { + return errors.New("id column not found") + } + var idValues []uuid.UUID + for i := idIdx; i < len(values); i += len(queryArgs.Columns) { + idValues = append(idValues, values[i].(uuid.UUID)) } - switch pkt { - case "UUID": - var id uuid.UUID - if err := row.Scan(&id); err != nil { - return errors.WithStack(err) - } - fbn.Set(reflect.ValueOf(id)) - default: - var id interface{} - if err := row.Scan(&id); err != nil { - return errors.WithStack(err) - } - v := reflect.ValueOf(id) - switch fbn.Kind() { - case reflect.Int, reflect.Int64: - fbn.SetInt(v.Int()) - default: - fbn.Set(reflect.ValueOf(id)) + var partialConflictError PartialConflictError[T] + for i, id := range idValues { + if _, ok := idsInDB[id]; !ok { + partialConflictError.Failed = append(partialConflictError.Failed, models[i]) + } else { + if err := setModelID(id, models[i]); err != nil { + return err + } } } + return partialConflictError.ErrOrNil() +} + +// setModelID sets the id field of the model to the id. +func setModelID(id uuid.UUID, model any) error { + el := reflect.ValueOf(model).Elem() + idField := el.FieldByName("ID") + if !idField.IsValid() { + return errors.New("model does not have a field named id") + } + idField.Set(reflect.ValueOf(id)) + return nil } diff --git a/persistence/sql/batch/create_test.go b/persistence/sql/batch/create_test.go index f5c81664a486..131589478e6e 100644 --- a/persistence/sql/batch/create_test.go +++ b/persistence/sql/batch/create_test.go @@ -9,14 +9,13 @@ import ( "testing" "time" - "github.com/ory/x/dbal" - "github.com/gofrs/uuid" "github.com/jmoiron/sqlx/reflectx" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ory/kratos/identity" + "github.com/ory/x/dbal" "github.com/ory/x/snapshotx" "github.com/ory/x/sqlxx" ) @@ -53,8 +52,11 @@ func Test_buildInsertQueryArgs(t *testing.T) { ctx := context.Background() t.Run("case=testModel", func(t *testing.T) { models := makeModels[testModel]() - mapper := reflectx.NewMapper("db") - args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models) + opts := &createOpts{ + dialect: "other", + quoter: testQuoter{}, + mapper: reflectx.NewMapper("db")} + args := buildInsertQueryArgs(ctx, models, opts) snapshotx.SnapshotT(t, args) query := fmt.Sprintf("INSERT INTO %s (%s) VALUES\n%s", args.TableName, args.ColumnsDecl, args.Placeholders) @@ -73,22 +75,31 @@ func Test_buildInsertQueryArgs(t *testing.T) { t.Run("case=Identities", func(t *testing.T) { models := makeModels[identity.Identity]() - mapper := reflectx.NewMapper("db") - args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models) + opts := &createOpts{ + dialect: "other", + quoter: testQuoter{}, + mapper: reflectx.NewMapper("db")} + args := buildInsertQueryArgs(ctx, models, opts) snapshotx.SnapshotT(t, args) }) t.Run("case=RecoveryAddress", func(t *testing.T) { models := makeModels[identity.RecoveryAddress]() - mapper := reflectx.NewMapper("db") - args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models) + opts := &createOpts{ + dialect: "other", + quoter: testQuoter{}, + mapper: reflectx.NewMapper("db")} + args := buildInsertQueryArgs(ctx, models, opts) snapshotx.SnapshotT(t, args) }) t.Run("case=RecoveryAddress", func(t *testing.T) { models := makeModels[identity.RecoveryAddress]() - mapper := reflectx.NewMapper("db") - args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models) + opts := &createOpts{ + dialect: "other", + quoter: testQuoter{}, + mapper: reflectx.NewMapper("db")} + args := buildInsertQueryArgs(ctx, models, opts) snapshotx.SnapshotT(t, args) }) @@ -99,8 +110,11 @@ func Test_buildInsertQueryArgs(t *testing.T) { models[k].ID = uuid.FromStringOrNil(fmt.Sprintf("ae0125a9-2786-4ada-82d2-d169cf75047%d", k)) } } - mapper := reflectx.NewMapper("db") - args := buildInsertQueryArgs(ctx, "cockroach", mapper, testQuoter{}, models) + opts := &createOpts{ + dialect: dbal.DriverCockroachDB, + quoter: testQuoter{}, + mapper: reflectx.NewMapper("db")} + args := buildInsertQueryArgs(ctx, models, opts) snapshotx.SnapshotT(t, args) }) } @@ -112,25 +126,38 @@ func Test_buildInsertQueryValues(t *testing.T) { Int: 42, Traits: []byte(`{"foo": "bar"}`), } - mapper := reflectx.NewMapper("db") - nowFunc := func() time.Time { - return time.Time{} + frozenTime := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + opts := &createOpts{ + mapper: reflectx.NewMapper("db"), + quoter: testQuoter{}, + now: func() time.Time { return frozenTime }, } + t.Run("case=cockroach", func(t *testing.T) { - values, err := buildInsertQueryValues(dbal.DriverCockroachDB, mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc) + opts.dialect = dbal.DriverCockroachDB + values, err := buildInsertQueryValues( + []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, + []*testModel{model}, + opts, + ) require.NoError(t, err) snapshotx.SnapshotT(t, values) }) t.Run("case=others", func(t *testing.T) { - values, err := buildInsertQueryValues("other", mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc) + opts.dialect = "other" + values, err := buildInsertQueryValues( + []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, + []*testModel{model}, + opts, + ) require.NoError(t, err) - assert.NotNil(t, model.CreatedAt) + assert.Equal(t, frozenTime, model.CreatedAt) assert.Equal(t, model.CreatedAt, values[0]) - assert.NotNil(t, model.UpdatedAt) + assert.Equal(t, frozenTime, model.UpdatedAt) assert.Equal(t, model.UpdatedAt, values[1]) assert.NotZero(t, model.ID) diff --git a/persistence/sql/batch/test_persister.go b/persistence/sql/batch/test_persister.go new file mode 100644 index 000000000000..be0e9ac7a4c5 --- /dev/null +++ b/persistence/sql/batch/test_persister.go @@ -0,0 +1,112 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package batch + +import ( + "context" + "errors" + "testing" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/identity" + "github.com/ory/kratos/persistence" + "github.com/ory/x/dbal" + "github.com/ory/x/otelx" + "github.com/ory/x/sqlcon" +) + +func TestPersister(ctx context.Context, tracer *otelx.Tracer, p persistence.Persister) func(t *testing.T) { + return func(t *testing.T) { + t.Run("method=batch.Create", func(t *testing.T) { + + ident1 := identity.NewIdentity("") + ident1.NID = p.NetworkID(ctx) + ident2 := identity.NewIdentity("") + ident2.NID = p.NetworkID(ctx) + + // Create two identities + _ = p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + conn := &TracerConnection{ + Tracer: tracer, + Connection: tx, + } + + err := Create(ctx, conn, []*identity.Identity{ident1, ident2}) + require.NoError(t, err) + + return nil + }) + + require.NotEqual(t, uuid.Nil, ident1.ID) + require.NotEqual(t, uuid.Nil, ident2.ID) + + // Create conflicting verifiable addresses + addresses := []*identity.VerifiableAddress{{ + Value: "foo.1@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "foo.2@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "conflict@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "foo.3@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "conflict@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "foo.4@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }} + + t.Run("case=fails all without partial inserts", func(t *testing.T) { + _ = p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + conn := &TracerConnection{ + Tracer: tracer, + Connection: tx, + } + err := Create(ctx, conn, addresses) + assert.ErrorIs(t, err, sqlcon.ErrUniqueViolation) + if partial := new(PartialConflictError[identity.VerifiableAddress]); errors.As(err, &partial) { + require.NoError(t, partial, "expected no partial error") + } + return err + }) + }) + + t.Run("case=return partial error with partial inserts", func(t *testing.T) { + _ = p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + conn := &TracerConnection{ + Tracer: tracer, + Connection: tx, + } + + err := Create(ctx, conn, addresses, WithPartialInserts) + assert.ErrorIs(t, err, sqlcon.ErrUniqueViolation) + + if conn.Connection.Dialect.Name() != dbal.DriverMySQL { + // MySQL does not support partial errors. + partialErr := new(PartialConflictError[identity.VerifiableAddress]) + require.ErrorAs(t, err, &partialErr) + assert.Len(t, partialErr.Failed, 1) + } + + return nil + }) + }) + }) + } +} diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 127167613b98..afa525184f03 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -326,6 +326,11 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn identifiers []*identity.CredentialIdentifier ) + var opts []batch.CreateOpts + if len(identities) > 1 { + opts = append(opts, batch.WithPartialInserts) + } + for _, ident := range identities { for k := range ident.Credentials { cred := ident.Credentials[k] @@ -351,7 +356,7 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn ident.Credentials[k] = cred } } - if err = batch.Create(ctx, traceConn, credentials); err != nil { + if err = batch.Create(ctx, traceConn, credentials, opts...); err != nil { return err } @@ -379,7 +384,7 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn } } - if err = batch.Create(ctx, traceConn, identifiers); err != nil { + if err = batch.Create(ctx, traceConn, identifiers, opts...); err != nil { return err } @@ -399,8 +404,12 @@ func (p *IdentityPersister) createVerifiableAddresses(ctx context.Context, conn work = append(work, &id.VerifiableAddresses[i]) } } + var opts []batch.CreateOpts + if len(identities) > 1 { + opts = append(opts, batch.WithPartialInserts) + } - return batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work) + return batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work, opts...) } func updateAssociation[T interface { @@ -511,7 +520,12 @@ func (p *IdentityPersister) createRecoveryAddresses(ctx context.Context, conn *p } } - return batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work) + var opts []batch.CreateOpts + if len(identities) > 1 { + opts = append(opts, batch.WithPartialInserts) + } + + return batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work, opts...) } func (p *IdentityPersister) CountIdentities(ctx context.Context) (n int64, err error) { @@ -576,21 +590,77 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... Connection: tx, } + // Don't use batch.WithPartialInserts, because identities have no other + // constraints other than the primary key that could cause conflicts. if err := batch.Create(ctx, conn, identities); err != nil { return sqlcon.HandleError(err) } p.normalizeAllAddressess(ctx, identities...) + failedIdentityIDs := make(map[uuid.UUID]struct{}) + if err = p.createVerifiableAddresses(ctx, tx, identities...); err != nil { - return sqlcon.HandleError(err) + if paritalErr := new(batch.PartialConflictError[identity.VerifiableAddress]); errors.As(err, &paritalErr) { + for _, k := range paritalErr.Failed { + failedIdentityIDs[k.IdentityID] = struct{}{} + } + } else { + return sqlcon.HandleError(err) + } } if err = p.createRecoveryAddresses(ctx, tx, identities...); err != nil { - return sqlcon.HandleError(err) + if paritalErr := new(batch.PartialConflictError[identity.RecoveryAddress]); errors.As(err, &paritalErr) { + for _, k := range paritalErr.Failed { + failedIdentityIDs[k.IdentityID] = struct{}{} + } + } else { + return sqlcon.HandleError(err) + } } if err = p.createIdentityCredentials(ctx, tx, identities...); err != nil { - return sqlcon.HandleError(err) + if paritalErr := new(batch.PartialConflictError[identity.Credentials]); errors.As(err, &paritalErr) { + for _, k := range paritalErr.Failed { + failedIdentityIDs[k.IdentityID] = struct{}{} + } + + } else if paritalErr := new(batch.PartialConflictError[identity.CredentialIdentifier]); errors.As(err, &paritalErr) { + for _, k := range paritalErr.Failed { + credID := k.IdentityCredentialsID + for _, ident := range identities { + for _, cred := range ident.Credentials { + if cred.ID == credID { + failedIdentityIDs[ident.ID] = struct{}{} + } + } + } + } + } else { + return sqlcon.HandleError(err) + } + } + + // If any of the batch inserts failed on conflict, let's delete the corresponding + // identities and return a list of failed identities in the error. + if len(failedIdentityIDs) > 0 { + partialErr := &identity.CreateIdentitiesError{} + failedIDs := make([]uuid.UUID, 0, len(failedIdentityIDs)) + for _, ident := range identities { + if _, ok := failedIdentityIDs[ident.ID]; ok { + partialErr.AddFailedIdentity(ident, sqlcon.ErrUniqueViolation) + failedIDs = append(failedIDs, ident.ID) + } + } + // Manually roll back by deleting the identities that were inserted before the + // error occurred. + if err := p.DeleteIdentities(ctx, failedIDs); err != nil { + return sqlcon.HandleError(err) + } + // Wrap the partial error with the first error that occurred, so that the caller + // can continue to handle the error either as a partial error or a full error. + return partialErr } + return nil }); err != nil { return err @@ -1052,6 +1122,44 @@ func (p *IdentityPersister) DeleteIdentity(ctx context.Context, id uuid.UUID) (e return nil } +func (p *IdentityPersister) DeleteIdentities(ctx context.Context, ids []uuid.UUID) (err error) { + stringIDs := make([]string, len(ids)) + for k, id := range ids { + stringIDs[k] = id.String() + } + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteIdentites", + trace.WithAttributes( + attribute.StringSlice("identity.ids", stringIDs), + attribute.Stringer("network.id", p.NetworkID(ctx)))) + defer otelx.End(span, &err) + + placeholders := strings.TrimSuffix(strings.Repeat("?, ", len(ids)), ", ") + args := make([]any, 0, len(ids)+1) + for _, id := range ids { + args = append(args, id) + } + args = append(args, p.NetworkID(ctx)) + + tableName := new(identity.Identity).TableName(ctx) + if p.c.Dialect.Name() == "cockroach" { + tableName += "@primary" + } + count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( + "DELETE FROM %s WHERE id IN (%s) AND nid = ?", + tableName, + placeholders, + ), + args..., + ).ExecWithCount() + if err != nil { + return sqlcon.HandleError(err) + } + if count != len(ids) { + return errors.WithStack(sqlcon.ErrNoRows) + } + return nil +} + func (p *IdentityPersister) GetIdentity(ctx context.Context, id uuid.UUID, expand identity.Expandables) (_ *identity.Identity, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetIdentity", trace.WithAttributes( diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go index 57593577c8c7..3029cdc51ef0 100644 --- a/persistence/sql/persister_test.go +++ b/persistence/sql/persister_test.go @@ -29,6 +29,7 @@ import ( "github.com/ory/kratos/internal" "github.com/ory/kratos/internal/testhelpers" "github.com/ory/kratos/persistence/sql" + "github.com/ory/kratos/persistence/sql/batch" sqltesthelpers "github.com/ory/kratos/persistence/sql/testhelpers" "github.com/ory/kratos/schema" errorx "github.com/ory/kratos/selfservice/errorx/test" @@ -264,6 +265,10 @@ func TestPersister(t *testing.T) { t.Parallel() continuity.TestPersister(ctx, p)(t) }) + t.Run("contract=batch.TestPersister", func(t *testing.T) { + t.Parallel() + batch.TestPersister(ctx, reg.Tracer(ctx), p)(t) + }) }) } } diff --git a/spec/api.json b/spec/api.json index c8f296913fa0..f0331bee92e5 100644 --- a/spec/api.json +++ b/spec/api.json @@ -1163,12 +1163,16 @@ "description": "Response for a single identity patch", "properties": { "action": { - "description": "The action for this specific patch\ncreate ActionCreate Create this identity.", + "description": "The action for this specific patch\ncreate ActionCreate Create this identity.\nerror ActionError Error indicates that the patch failed.", "enum": [ - "create" + "create", + "error" ], "type": "string", - "x-go-enum-desc": "create ActionCreate Create this identity." + "x-go-enum-desc": "create ActionCreate Create this identity.\nerror ActionError Error indicates that the patch failed." + }, + "error": { + "$ref": "#/components/schemas/DefaultError" }, "identity": { "description": "The identity ID payload of this patch", diff --git a/spec/swagger.json b/spec/swagger.json index 2952837d4c09..12ba5f44a9e6 100755 --- a/spec/swagger.json +++ b/spec/swagger.json @@ -4275,12 +4275,16 @@ "type": "object", "properties": { "action": { - "description": "The action for this specific patch\ncreate ActionCreate Create this identity.", + "description": "The action for this specific patch\ncreate ActionCreate Create this identity.\nerror ActionError Error indicates that the patch failed.", "type": "string", "enum": [ - "create" + "create", + "error" ], - "x-go-enum-desc": "create ActionCreate Create this identity." + "x-go-enum-desc": "create ActionCreate Create this identity.\nerror ActionError Error indicates that the patch failed." + }, + "error": { + "$ref": "#/definitions/DefaultError" }, "identity": { "description": "The identity ID payload of this patch", diff --git a/test/e2e/cypress/integration/profiles/email/logout/success.spec.ts b/test/e2e/cypress/integration/profiles/email/logout/success.spec.ts index f13358166afb..3926f21012b1 100644 --- a/test/e2e/cypress/integration/profiles/email/logout/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/email/logout/success.spec.ts @@ -75,12 +75,12 @@ context("Testing logout flows", () => { cy.visit(settings, { qs: { - return_to: "https://www.ory.sh", + return_to: "https://www.example.org", }, }) cy.get("a[href*='logout']").click() - cy.location("host").should("eq", "www.ory.sh") + cy.location("host").should("eq", "www.example.org") }) it("should be able to sign out on welcome page", () => { @@ -94,12 +94,12 @@ context("Testing logout flows", () => { cy.visit(welcome, { qs: { - return_to: "https://www.ory.sh", + return_to: "https://www.example.org", }, }) cy.get("a[href*='logout']").click() - cy.location("host").should("eq", "www.ory.sh") + cy.location("host").should("eq", "www.example.org") }) it("should be able to sign out at 2fa page", () => { @@ -122,7 +122,7 @@ context("Testing logout flows", () => { cy.logout() cy.visit(route, { qs: { - return_to: "https://www.ory.sh", + return_to: "https://www.example.org", }, }) @@ -135,7 +135,7 @@ context("Testing logout flows", () => { cy.get("a[href*='logout']").click() - cy.location("host").should("eq", "www.ory.sh") + cy.location("host").should("eq", "www.example.org") cy.useLookupSecrets(false) }) }) diff --git a/test/e2e/cypress/integration/profiles/mfa/lookup.spec.ts b/test/e2e/cypress/integration/profiles/mfa/lookup.spec.ts index d3a0e8720cb8..1bb4e2f94898 100644 --- a/test/e2e/cypress/integration/profiles/mfa/lookup.spec.ts +++ b/test/e2e/cypress/integration/profiles/mfa/lookup.spec.ts @@ -192,9 +192,9 @@ context("2FA lookup secrets", () => { cy.visit(settings) cy.get('button[name="lookup_secret_reveal"]').click() cy.getLookupSecrets().should((c) => { - expect(c[0]).not.to.equal(codes[0]) - expect(c[1]).not.to.equal(codes[1]) expect(c.slice(2)).to.eql(codes.slice(2)) + expect(c[0]).to.match(/(Secret was used at )|(Used)/g) + expect(c[1]).to.match(/(Secret was used at )|(Used)/g) }) // Regenerating the codes means the old one become invalid @@ -234,8 +234,8 @@ context("2FA lookup secrets", () => { cy.visit(settings) cy.get('button[name="lookup_secret_reveal"]').click() cy.getLookupSecrets().should((c) => { - expect(c[0]).not.to.equal(regenCodes[0]) expect(c.slice(1)).to.eql(regenCodes.slice(1)) + expect(c[0]).to.match(/(Secret was used at )|(Used)/g) }) }) diff --git a/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts b/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts index 54adac4bdd16..132845623f31 100644 --- a/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts @@ -194,9 +194,9 @@ context("Social Sign Up Successes", () => { app, email, website, - route: registration + "?return_to=https://www.ory.sh/", + route: registration + "?return_to=https://www.example.org/", }) - cy.location("href").should("eq", "https://www.ory.sh/") + cy.location("href").should("eq", "https://www.example.org/") cy.logout() }) diff --git a/test/e2e/cypress/integration/profiles/recovery/code/success.spec.ts b/test/e2e/cypress/integration/profiles/recovery/code/success.spec.ts index a27bca61d967..baafe5a389c7 100644 --- a/test/e2e/cypress/integration/profiles/recovery/code/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/recovery/code/success.spec.ts @@ -180,7 +180,7 @@ context("Account Recovery With Code Success", () => { const identity = gen.identityWithWebsite() cy.registerApi(identity) - cy.visit(express.recovery + "?return_to=https://www.ory.sh/") + cy.visit(express.recovery + "?return_to=https://www.example.org/") cy.get("input[name='email']").type(identity.email) cy.get("button[value='code']").click() cy.get('[data-testid="ui/message/1060003"]').should( @@ -196,6 +196,6 @@ context("Account Recovery With Code Success", () => { cy.get('input[name="password"]').clear().type(gen.password()) cy.get('button[value="password"]').click() - cy.url().should("eq", "https://www.ory.sh/") + cy.url().should("eq", "https://www.example.org/") }) }) diff --git a/test/e2e/cypress/integration/profiles/recovery/link/success.spec.ts b/test/e2e/cypress/integration/profiles/recovery/link/success.spec.ts index fc4137200b59..abfa089375b8 100644 --- a/test/e2e/cypress/integration/profiles/recovery/link/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/recovery/link/success.spec.ts @@ -109,7 +109,10 @@ context("Account Recovery Success", () => { const identity = gen.identityWithWebsite() cy.registerApi(identity) - cy.recoverApi({ email: identity.email, returnTo: "https://www.ory.sh/" }) + cy.recoverApi({ + email: identity.email, + returnTo: "https://www.example.org/", + }) cy.recoverEmail({ expect: identity }) @@ -120,7 +123,7 @@ context("Account Recovery Success", () => { .clear() .type(gen.password()) cy.get('button[value="password"]').click() - cy.url().should("eq", "https://www.ory.sh/") + cy.url().should("eq", "https://www.example.org/") }) it("should recover even if already logged into another account", () => { diff --git a/test/e2e/cypress/integration/profiles/recovery/return-to/success.spec.ts b/test/e2e/cypress/integration/profiles/recovery/return-to/success.spec.ts index 0fa3f12a9524..af5ffa0acd1c 100644 --- a/test/e2e/cypress/integration/profiles/recovery/return-to/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/recovery/return-to/success.spec.ts @@ -63,7 +63,7 @@ context("Recovery with `return_to`", () => { } it("should return to the `return_to` url after successful account recovery and settings update", () => { - cy.visit(recovery + "?return_to=https://www.ory.sh/") + cy.visit(recovery + "?return_to=https://www.example.org/") doRecovery() cy.get('[data-testid="ui/message/1060001"]', { timeout: 30000 }).should( @@ -80,7 +80,7 @@ context("Recovery with `return_to`", () => { .type(newPassword) cy.get('button[value="password"]').click() - cy.location("hostname").should("eq", "www.ory.sh") + cy.location("hostname").should("eq", "www.example.org") }) it("should return to the `return_to` url even with mfa enabled after successful account recovery and settings update", () => { @@ -108,7 +108,7 @@ context("Recovery with `return_to`", () => { cy.logout() cy.clearAllCookies() - cy.visit(recovery + "?return_to=https://www.ory.sh/") + cy.visit(recovery + "?return_to=https://www.example.org/") doRecovery() cy.shouldShow2FAScreen() @@ -122,7 +122,7 @@ context("Recovery with `return_to`", () => { .clear() .type(newPassword) cy.get('button[value="password"]').click() - cy.location("hostname").should("eq", "www.ory.sh") + cy.location("hostname").should("eq", "www.example.org") }) }) }) diff --git a/test/e2e/cypress/integration/profiles/two-steps/registration/oidc.spec.ts b/test/e2e/cypress/integration/profiles/two-steps/registration/oidc.spec.ts index cd4cf069c556..ca2d5e57078a 100644 --- a/test/e2e/cypress/integration/profiles/two-steps/registration/oidc.spec.ts +++ b/test/e2e/cypress/integration/profiles/two-steps/registration/oidc.spec.ts @@ -194,9 +194,9 @@ context("Social Sign Up Successes", () => { app, email, website, - route: registration + "?return_to=https://www.ory.sh/", + route: registration + "?return_to=https://www.example.org/", }) - cy.location("href").should("eq", "https://www.ory.sh/") + cy.location("href").should("eq", "https://www.example.org/") cy.logout() })