diff --git a/identity/handler_test.go b/identity/handler_test.go index 7f1c5285a8cf..6ff1080ee256 100644 --- a/identity/handler_test.go +++ b/identity/handler_test.go @@ -754,7 +754,7 @@ func TestHandler(t *testing.T) { }) t.Run("suite=PATCH identities", func(t *testing.T) { - t.Run("case=fails on > 100 identities", func(t *testing.T) { + t.Run("case=fails with too many patches", func(t *testing.T) { tooMany := make([]*identity.BatchIdentityPatch, identity.BatchPatchIdentitiesLimit+1) for i := range tooMany { tooMany[i] = &identity.BatchIdentityPatch{Create: validCreateIdentityBody("too-many-patches", i)} @@ -767,8 +767,8 @@ func TestHandler(t *testing.T) { t.Run("case=fails some on a bad identity", func(t *testing.T) { // Test setup: we have a list of valid identitiy patches and a list of invalid ones. // Each run adds one invalid patch to the list and sends it to the server. - // --> we expect the server to fail all patches in the list. - // Finally, we send just the valid patches + // --> we expect the server to fail only the bad patches in the list. + // Finally, we send just valid patches // --> we expect the server to succeed all patches in the list. t.Run("case=invalid patches fail", func(t *testing.T) { @@ -782,24 +782,23 @@ func TestHandler(t *testing.T) { {Create: &identity.CreateIdentityBody{Traits: json.RawMessage(`"invalid traits"`)}}, // <-- invalid traits {Create: validCreateIdentityBody("valid", 4)}, } + expectedToPass := []*identity.BatchIdentityPatch{patches[0], patches[1], patches[3], patches[5], patches[7]} // Create unique IDs for each patch - var patchIDs []string + patchIDs := make([]string, len(patches)) for i, p := range patches { id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%d", i)) p.ID = &id - patchIDs = append(patchIDs, id.String()) + patchIDs[i] = id.String() } req := &identity.BatchPatchIdentitiesBody{Identities: patches} body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req) var actions []string - for _, a := range body.Get("identities.#.action").Array() { - actions = append(actions, a.String()) - } - assert.Equal(t, + require.NoErrorf(t, json.Unmarshal(([]byte)(body.Get("identities.#.action").Raw), &actions), "%s", body) + assert.Equalf(t, []string{"create", "create", "error", "create", "error", "create", "error", "create"}, - actions, body) + actions, "%s", body) // Check that all patch IDs are returned for i, gotPatchID := range body.Get("identities.#.patch_id").Array() { @@ -811,6 +810,27 @@ func TestHandler(t *testing.T) { assert.Equal(t, "Conflict", body.Get("identities.4.error.status").String()) assert.Equal(t, "Bad Request", body.Get("identities.6.error.status").String()) + var identityIDs []uuid.UUID + require.NoErrorf(t, json.Unmarshal(([]byte)(body.Get("identities.#.identity").Raw), &identityIDs), "%s", body) + + actualIdentities, _, err := reg.Persister().ListIdentities(ctx, identity.ListIdentityParameters{IdsFilter: identityIDs}) + require.NoError(t, err) + actualIdentityIDs := make([]uuid.UUID, len(actualIdentities)) + for i, id := range actualIdentities { + actualIdentityIDs[i] = id.ID + } + assert.ElementsMatchf(t, identityIDs, actualIdentityIDs, "%s", body) + + expectedTraits := make(map[string]string, len(expectedToPass)) + for i, p := range expectedToPass { + expectedTraits[identityIDs[i].String()] = string(p.Create.Traits) + } + actualTraits := make(map[string]string, len(actualIdentities)) + for _, id := range actualIdentities { + actualTraits[id.ID.String()] = string(id.Traits) + } + + assert.Equal(t, expectedTraits, actualTraits) }) t.Run("valid patches succeed", func(t *testing.T) { @@ -1928,7 +1948,7 @@ func validCreateIdentityBody(prefix string, i int) *identity.CreateIdentityBody identity.VerifiableAddressStatusCompleted, } - for j := 0; j < 4; j++ { + for j := range 4 { email := fmt.Sprintf("%s-%d-%d@ory.sh", prefix, i, j) traits.Emails = append(traits.Emails, email) verifiableAddresses = append(verifiableAddresses, identity.VerifiableAddress{ diff --git a/identity/manager.go b/identity/manager.go index 89c0259e6658..a09a08a778cd 100644 --- a/identity/manager.go +++ b/identity/manager.go @@ -333,6 +333,12 @@ type CreateIdentitiesError struct { failedIdentities map[*Identity]*herodot.DefaultError } +func NewCreateIdentitiesError(capacity int) *CreateIdentitiesError { + return &CreateIdentitiesError{ + failedIdentities: make(map[*Identity]*herodot.DefaultError, capacity), + } +} + func (e *CreateIdentitiesError) Error() string { e.init() return fmt.Sprintf("create identities error: %d identities failed", len(e.failedIdentities)) @@ -370,7 +376,7 @@ func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity { return nil } func (e *CreateIdentitiesError) ErrOrNil() error { - if len(e.failedIdentities) == 0 { + if e == nil || len(e.failedIdentities) == 0 { return nil } return e @@ -385,7 +391,7 @@ func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CreateIdentities") defer otelx.End(span, &err) - createIdentitiesError := &CreateIdentitiesError{} + createIdentitiesError := NewCreateIdentitiesError(len(identities)) validIdentities := make([]*Identity, 0, len(identities)) for _, ident := range identities { if ident.SchemaID == "" { diff --git a/identity/test/pool.go b/identity/test/pool.go index 4f898917449f..2e53fa2a53a2 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -350,12 +350,60 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager, assert.Equal(t, id.Credentials["password"].Identifiers, credFromDB.Identifiers) assert.WithinDuration(t, time.Now().UTC(), credFromDB.CreatedAt, time.Minute) assert.WithinDuration(t, time.Now().UTC(), credFromDB.UpdatedAt, time.Minute) + // because of mysql precision assert.WithinDuration(t, id.CreatedAt, idFromDB.CreatedAt, time.Second) assert.WithinDuration(t, id.UpdatedAt, idFromDB.UpdatedAt, time.Second) require.NoError(t, p.DeleteIdentity(ctx, id.ID)) } }) + + t.Run("create exactly the non-conflicting ones", func(t *testing.T) { + identities := make([]*identity.Identity, 100) + for i := range identities { + identities[i] = NewTestIdentity(4, "persister-create-multiple-2", i%60) + } + err := p.CreateIdentities(ctx, identities...) + if dbname == "mysql" { + // partial inserts are not supported on mysql + assert.ErrorIs(t, err, sqlcon.ErrUniqueViolation) + return + } + + errWithCtx := new(identity.CreateIdentitiesError) + require.ErrorAsf(t, err, &errWithCtx, "%#v", err) + + for _, id := range identities[:60] { + require.NotZero(t, id.ID) + + idFromDB, err := p.GetIdentity(ctx, id.ID, identity.ExpandEverything) + require.NoError(t, err) + + credFromDB := idFromDB.Credentials[identity.CredentialsTypePassword] + assert.Equal(t, id.ID, idFromDB.ID) + assert.Equal(t, id.SchemaID, idFromDB.SchemaID) + assert.Equal(t, id.SchemaURL, idFromDB.SchemaURL) + assert.Equal(t, id.State, idFromDB.State) + + // We test that the values are plausible in the handler test already. + assert.Equal(t, len(id.VerifiableAddresses), len(idFromDB.VerifiableAddresses)) + assert.Equal(t, len(id.RecoveryAddresses), len(idFromDB.RecoveryAddresses)) + + assert.Equal(t, id.Credentials["password"].Identifiers, credFromDB.Identifiers) + assert.WithinDuration(t, time.Now().UTC(), credFromDB.CreatedAt, time.Minute) + assert.WithinDuration(t, time.Now().UTC(), credFromDB.UpdatedAt, time.Minute) + // because of mysql precision + assert.WithinDuration(t, id.CreatedAt, idFromDB.CreatedAt, time.Second) + assert.WithinDuration(t, id.UpdatedAt, idFromDB.UpdatedAt, time.Second) + + require.NoError(t, p.DeleteIdentity(ctx, id.ID)) + } + + for _, id := range identities[60:] { + failed := errWithCtx.Find(id) + assert.NotNil(t, failed) + } + }) }) t.Run("case=should error when the identity ID does not exist", func(t *testing.T) { diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index c966c8ddfd0d..6cc3f5911d11 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 489b1fbb4360..5b29017779ca 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -561,7 +561,8 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... } }() - return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + var partialErr *identity.CreateIdentitiesError + if err := p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { conn := &batch.TracerConnection{ Tracer: p.r.Tracer(ctx), Connection: tx, @@ -569,6 +570,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... succeededIDs = make([]uuid.UUID, 0, len(identities)) failedIdentityIDs := make(map[uuid.UUID]struct{}) + partialErr = nil // Don't use batch.WithPartialInserts, because identities have no other // constraints other than the primary key that could cause conflicts. @@ -620,7 +622,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... // If any of the batch inserts failed on conflict, let's delete the corresponding // identities and return a list of failed identities in the error. if len(failedIdentityIDs) > 0 { - partialErr := &identity.CreateIdentitiesError{} + partialErr = identity.NewCreateIdentitiesError(len(failedIdentityIDs)) failedIDs := make([]uuid.UUID, 0, len(failedIdentityIDs)) for _, ident := range identities { @@ -637,7 +639,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... return sqlcon.HandleError(err) } - return partialErr + return nil } else { // No failures: report all identities as created. for _, ident := range identities { @@ -646,7 +648,10 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... } return nil - }) + }); err != nil { + return err + } + return partialErr.ErrOrNil() } func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *identity.Identity, expand identity.Expandables) (err error) {