From de962ce76e0df6faccb3f3d44d2d5b308c5f819a Mon Sep 17 00:00:00 2001 From: zepatrik Date: Thu, 14 Nov 2024 12:08:23 +0100 Subject: [PATCH 1/4] fix: do not roll back transaction on partial insert error --- identity/handler_test.go | 42 ++++++++++++++----- identity/manager.go | 10 ++++- identity/test/pool.go | 42 +++++++++++++++++-- .../sql/identity/persister_identity.go | 13 ++++-- 4 files changed, 87 insertions(+), 20 deletions(-) diff --git a/identity/handler_test.go b/identity/handler_test.go index 7f1c5285a8cf..6ff1080ee256 100644 --- a/identity/handler_test.go +++ b/identity/handler_test.go @@ -754,7 +754,7 @@ func TestHandler(t *testing.T) { }) t.Run("suite=PATCH identities", func(t *testing.T) { - t.Run("case=fails on > 100 identities", func(t *testing.T) { + t.Run("case=fails with too many patches", func(t *testing.T) { tooMany := make([]*identity.BatchIdentityPatch, identity.BatchPatchIdentitiesLimit+1) for i := range tooMany { tooMany[i] = &identity.BatchIdentityPatch{Create: validCreateIdentityBody("too-many-patches", i)} @@ -767,8 +767,8 @@ func TestHandler(t *testing.T) { t.Run("case=fails some on a bad identity", func(t *testing.T) { // Test setup: we have a list of valid identitiy patches and a list of invalid ones. // Each run adds one invalid patch to the list and sends it to the server. - // --> we expect the server to fail all patches in the list. - // Finally, we send just the valid patches + // --> we expect the server to fail only the bad patches in the list. + // Finally, we send just valid patches // --> we expect the server to succeed all patches in the list. t.Run("case=invalid patches fail", func(t *testing.T) { @@ -782,24 +782,23 @@ func TestHandler(t *testing.T) { {Create: &identity.CreateIdentityBody{Traits: json.RawMessage(`"invalid traits"`)}}, // <-- invalid traits {Create: validCreateIdentityBody("valid", 4)}, } + expectedToPass := []*identity.BatchIdentityPatch{patches[0], patches[1], patches[3], patches[5], patches[7]} // Create unique IDs for each patch - var patchIDs []string + patchIDs := make([]string, len(patches)) for i, p := range patches { id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%d", i)) p.ID = &id - patchIDs = append(patchIDs, id.String()) + patchIDs[i] = id.String() } req := &identity.BatchPatchIdentitiesBody{Identities: patches} body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req) var actions []string - for _, a := range body.Get("identities.#.action").Array() { - actions = append(actions, a.String()) - } - assert.Equal(t, + require.NoErrorf(t, json.Unmarshal(([]byte)(body.Get("identities.#.action").Raw), &actions), "%s", body) + assert.Equalf(t, []string{"create", "create", "error", "create", "error", "create", "error", "create"}, - actions, body) + actions, "%s", body) // Check that all patch IDs are returned for i, gotPatchID := range body.Get("identities.#.patch_id").Array() { @@ -811,6 +810,27 @@ func TestHandler(t *testing.T) { assert.Equal(t, "Conflict", body.Get("identities.4.error.status").String()) assert.Equal(t, "Bad Request", body.Get("identities.6.error.status").String()) + var identityIDs []uuid.UUID + require.NoErrorf(t, json.Unmarshal(([]byte)(body.Get("identities.#.identity").Raw), &identityIDs), "%s", body) + + actualIdentities, _, err := reg.Persister().ListIdentities(ctx, identity.ListIdentityParameters{IdsFilter: identityIDs}) + require.NoError(t, err) + actualIdentityIDs := make([]uuid.UUID, len(actualIdentities)) + for i, id := range actualIdentities { + actualIdentityIDs[i] = id.ID + } + assert.ElementsMatchf(t, identityIDs, actualIdentityIDs, "%s", body) + + expectedTraits := make(map[string]string, len(expectedToPass)) + for i, p := range expectedToPass { + expectedTraits[identityIDs[i].String()] = string(p.Create.Traits) + } + actualTraits := make(map[string]string, len(actualIdentities)) + for _, id := range actualIdentities { + actualTraits[id.ID.String()] = string(id.Traits) + } + + assert.Equal(t, expectedTraits, actualTraits) }) t.Run("valid patches succeed", func(t *testing.T) { @@ -1928,7 +1948,7 @@ func validCreateIdentityBody(prefix string, i int) *identity.CreateIdentityBody identity.VerifiableAddressStatusCompleted, } - for j := 0; j < 4; j++ { + for j := range 4 { email := fmt.Sprintf("%s-%d-%d@ory.sh", prefix, i, j) traits.Emails = append(traits.Emails, email) verifiableAddresses = append(verifiableAddresses, identity.VerifiableAddress{ diff --git a/identity/manager.go b/identity/manager.go index 89c0259e6658..a09a08a778cd 100644 --- a/identity/manager.go +++ b/identity/manager.go @@ -333,6 +333,12 @@ type CreateIdentitiesError struct { failedIdentities map[*Identity]*herodot.DefaultError } +func NewCreateIdentitiesError(capacity int) *CreateIdentitiesError { + return &CreateIdentitiesError{ + failedIdentities: make(map[*Identity]*herodot.DefaultError, capacity), + } +} + func (e *CreateIdentitiesError) Error() string { e.init() return fmt.Sprintf("create identities error: %d identities failed", len(e.failedIdentities)) @@ -370,7 +376,7 @@ func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity { return nil } func (e *CreateIdentitiesError) ErrOrNil() error { - if len(e.failedIdentities) == 0 { + if e == nil || len(e.failedIdentities) == 0 { return nil } return e @@ -385,7 +391,7 @@ func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CreateIdentities") defer otelx.End(span, &err) - createIdentitiesError := &CreateIdentitiesError{} + createIdentitiesError := NewCreateIdentitiesError(len(identities)) validIdentities := make([]*Identity, 0, len(identities)) for _, ident := range identities { if ident.SchemaID == "" { diff --git a/identity/test/pool.go b/identity/test/pool.go index 4f898917449f..bfde144f82e0 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -350,10 +350,46 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager, assert.Equal(t, id.Credentials["password"].Identifiers, credFromDB.Identifiers) assert.WithinDuration(t, time.Now().UTC(), credFromDB.CreatedAt, time.Minute) assert.WithinDuration(t, time.Now().UTC(), credFromDB.UpdatedAt, time.Minute) - assert.WithinDuration(t, id.CreatedAt, idFromDB.CreatedAt, time.Second) - assert.WithinDuration(t, id.UpdatedAt, idFromDB.UpdatedAt, time.Second) + assert.Equal(t, id.CreatedAt, idFromDB.CreatedAt) + assert.Equal(t, id.UpdatedAt, idFromDB.UpdatedAt) + } + }) + + t.Run("create exactly the non-conflicting ones", func(t *testing.T) { + identities := make([]*identity.Identity, 100) + for i := range identities { + identities[i] = NewTestIdentity(4, "persister-create-multiple-2", i%60) + } + err := p.CreateIdentities(ctx, identities...) + errWithCtx := new(identity.CreateIdentitiesError) + require.ErrorAsf(t, err, &errWithCtx, "%#v", err) + + for _, id := range identities[:60] { + require.NotZero(t, id.ID) + + idFromDB, err := p.GetIdentity(ctx, id.ID, identity.ExpandEverything) + require.NoError(t, err) + + credFromDB := idFromDB.Credentials[identity.CredentialsTypePassword] + assert.Equal(t, id.ID, idFromDB.ID) + assert.Equal(t, id.SchemaID, idFromDB.SchemaID) + assert.Equal(t, id.SchemaURL, idFromDB.SchemaURL) + assert.Equal(t, id.State, idFromDB.State) + + // We test that the values are plausible in the handler test already. + assert.Equal(t, len(id.VerifiableAddresses), len(idFromDB.VerifiableAddresses)) + assert.Equal(t, len(id.RecoveryAddresses), len(idFromDB.RecoveryAddresses)) + + assert.Equal(t, id.Credentials["password"].Identifiers, credFromDB.Identifiers) + assert.WithinDuration(t, time.Now().UTC(), credFromDB.CreatedAt, time.Minute) + assert.WithinDuration(t, time.Now().UTC(), credFromDB.UpdatedAt, time.Minute) + assert.Equal(t, id.CreatedAt, idFromDB.CreatedAt) + assert.Equal(t, id.UpdatedAt, idFromDB.UpdatedAt) + } - require.NoError(t, p.DeleteIdentity(ctx, id.ID)) + for _, id := range identities[60:] { + failed := errWithCtx.Find(id) + assert.NotNil(t, failed) } }) }) diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 489b1fbb4360..5b29017779ca 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -561,7 +561,8 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... } }() - return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + var partialErr *identity.CreateIdentitiesError + if err := p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { conn := &batch.TracerConnection{ Tracer: p.r.Tracer(ctx), Connection: tx, @@ -569,6 +570,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... succeededIDs = make([]uuid.UUID, 0, len(identities)) failedIdentityIDs := make(map[uuid.UUID]struct{}) + partialErr = nil // Don't use batch.WithPartialInserts, because identities have no other // constraints other than the primary key that could cause conflicts. @@ -620,7 +622,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... // If any of the batch inserts failed on conflict, let's delete the corresponding // identities and return a list of failed identities in the error. if len(failedIdentityIDs) > 0 { - partialErr := &identity.CreateIdentitiesError{} + partialErr = identity.NewCreateIdentitiesError(len(failedIdentityIDs)) failedIDs := make([]uuid.UUID, 0, len(failedIdentityIDs)) for _, ident := range identities { @@ -637,7 +639,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... return sqlcon.HandleError(err) } - return partialErr + return nil } else { // No failures: report all identities as created. for _, ident := range identities { @@ -646,7 +648,10 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... } return nil - }) + }); err != nil { + return err + } + return partialErr.ErrOrNil() } func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *identity.Identity, expand identity.Expandables) (err error) { From 3f45a989ffbcfb743824327be4846dea7cfa24ea Mon Sep 17 00:00:00 2001 From: zepatrik Date: Thu, 14 Nov 2024 13:28:26 +0100 Subject: [PATCH 2/4] revert: assertion update --- identity/test/pool.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/identity/test/pool.go b/identity/test/pool.go index bfde144f82e0..47204860e61f 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -350,8 +350,9 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager, assert.Equal(t, id.Credentials["password"].Identifiers, credFromDB.Identifiers) assert.WithinDuration(t, time.Now().UTC(), credFromDB.CreatedAt, time.Minute) assert.WithinDuration(t, time.Now().UTC(), credFromDB.UpdatedAt, time.Minute) - assert.Equal(t, id.CreatedAt, idFromDB.CreatedAt) - assert.Equal(t, id.UpdatedAt, idFromDB.UpdatedAt) + // because of mysql precision + assert.WithinDuration(t, id.CreatedAt, idFromDB.CreatedAt, time.Second) + assert.WithinDuration(t, id.UpdatedAt, idFromDB.UpdatedAt, time.Second) } }) @@ -383,8 +384,9 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager, assert.Equal(t, id.Credentials["password"].Identifiers, credFromDB.Identifiers) assert.WithinDuration(t, time.Now().UTC(), credFromDB.CreatedAt, time.Minute) assert.WithinDuration(t, time.Now().UTC(), credFromDB.UpdatedAt, time.Minute) - assert.Equal(t, id.CreatedAt, idFromDB.CreatedAt) - assert.Equal(t, id.UpdatedAt, idFromDB.UpdatedAt) + // because of mysql precision + assert.WithinDuration(t, id.CreatedAt, idFromDB.CreatedAt, time.Second) + assert.WithinDuration(t, id.UpdatedAt, idFromDB.UpdatedAt, time.Second) } for _, id := range identities[60:] { From 32ceb02f36551b248e83332204f39def71f16197 Mon Sep 17 00:00:00 2001 From: zepatrik Date: Thu, 14 Nov 2024 14:19:08 +0100 Subject: [PATCH 3/4] test: mysql specific assertion when parts of patch are invalid --- identity/test/pool.go | 6 ++++++ internal/client-go/go.sum | 1 + 2 files changed, 7 insertions(+) diff --git a/identity/test/pool.go b/identity/test/pool.go index 47204860e61f..87e02372cabf 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -362,6 +362,12 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager, identities[i] = NewTestIdentity(4, "persister-create-multiple-2", i%60) } err := p.CreateIdentities(ctx, identities...) + if dbname == "mysql" { + // partial inserts are not supported on mysql + assert.ErrorIs(t, err, sqlcon.ErrUniqueViolation) + return + } + errWithCtx := new(identity.CreateIdentitiesError) require.ErrorAsf(t, err, &errWithCtx, "%#v", err) diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index c966c8ddfd0d..6cc3f5911d11 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= From 914f5723bdd3bbdec8782ad1a160eec273b8f96c Mon Sep 17 00:00:00 2001 From: zepatrik Date: Thu, 14 Nov 2024 15:34:13 +0100 Subject: [PATCH 4/4] test: delete identities to avoid side effects on other cases --- identity/test/pool.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/identity/test/pool.go b/identity/test/pool.go index 87e02372cabf..2e53fa2a53a2 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -353,6 +353,8 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager, // because of mysql precision assert.WithinDuration(t, id.CreatedAt, idFromDB.CreatedAt, time.Second) assert.WithinDuration(t, id.UpdatedAt, idFromDB.UpdatedAt, time.Second) + + require.NoError(t, p.DeleteIdentity(ctx, id.ID)) } }) @@ -393,6 +395,8 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager, // because of mysql precision assert.WithinDuration(t, id.CreatedAt, idFromDB.CreatedAt, time.Second) assert.WithinDuration(t, id.UpdatedAt, idFromDB.UpdatedAt, time.Second) + + require.NoError(t, p.DeleteIdentity(ctx, id.ID)) } for _, id := range identities[60:] {