From f22f3e34a7abd8743a3b4e28b38a92551f90bd19 Mon Sep 17 00:00:00 2001 From: Kamil Samigullin Date: Sat, 24 Nov 2018 21:48:13 +0300 Subject: [PATCH] fix #72: use "upsert" for register license logic --- env/client/rest.http | 2 +- pkg/storage/internal/postgres/license.go | 2 +- pkg/storage/protected.go | 65 ++++++++++++++++++------ pkg/storage/public.go | 6 +-- 4 files changed, 55 insertions(+), 20 deletions(-) diff --git a/env/client/rest.http b/env/client/rest.http index 334f3e7..4fa66d5 100644 --- a/env/client/rest.http +++ b/env/client/rest.http @@ -128,7 +128,7 @@ X-Request-ID: 10000000-2000-4000-8000-160000000000 ### Maintenance -POST http://localhost:8093/install +POST http://localhost:8093/api/v1/install Content-Type: application/json X-Request-ID: 10000000-2000-4000-8000-160000000000 diff --git a/pkg/storage/internal/postgres/license.go b/pkg/storage/internal/postgres/license.go index 579d0ff..e21ad1b 100644 --- a/pkg/storage/internal/postgres/license.go +++ b/pkg/storage/internal/postgres/license.go @@ -109,7 +109,7 @@ func (scope licenseManager) Update(token *types.Token, data query.UpdateLicense) "user %q of account %q with token %q tried to update license %q with new contract %s", token.UserID, token.User.AccountID, token.ID, entity.ID, after) } - if prev == nil || !prev.Equal(*entity.UpdatedAt) { + if entity.UpdatedAt != nil && (prev == nil || !prev.Equal(*entity.UpdatedAt)) { audit := `INSERT INTO "license_audit" ("license_id", "contract", "what", "when", "who", "with") VALUES ($1, $2, $3, $4, $5, $6)` if _, execErr := scope.conn.ExecContext(scope.ctx, audit, entity.ID, before, diff --git a/pkg/storage/protected.go b/pkg/storage/protected.go index a4abebf..a1e63bb 100644 --- a/pkg/storage/protected.go +++ b/pkg/storage/protected.go @@ -8,11 +8,42 @@ import ( "github.com/kamilsk/guard/pkg/storage/query" "github.com/kamilsk/guard/pkg/storage/types" + "github.com/pkg/errors" ) // RegisterLicense TODO issue#docs func (storage *Storage) RegisterLicense(ctx context.Context, id domain.Token, data query.RegisterLicense) (types.License, error) { - return storage.CreateLicense(ctx, id, query.CreateLicense{ID: &data.ID, Contract: data.Contract}) + var license types.License + + conn, closer, connErr := storage.connection(ctx) + if connErr != nil { + return license, connErr + } + defer func() { _ = closer() }() + + token, authErr := storage.exec.UserManager(ctx, conn).AccessToken(id) + if authErr != nil { + return license, authErr + } + + tx, txErr := conn.BeginTx(ctx, &sql.TxOptions{}) + if txErr != nil { + return license, txErr + } + defer func() { _ = tx.Rollback() }() + + manager := storage.exec.LicenseManager(ctx, conn) + license, execErr := manager.Read(token, query.ReadLicense{ID: data.ID}) + if execErr == nil { + license, execErr = manager.Update(token, query.UpdateLicense{ID: data.ID, Contract: data.Contract}) + } else if errors.Cause(execErr) == sql.ErrNoRows { + license, execErr = manager.Create(token, query.CreateLicense{ID: &data.ID, Contract: data.Contract}) + } + if execErr != nil { + return license, execErr + } + + return license, tx.Commit() } // CreateLicense TODO issue#docs @@ -23,7 +54,7 @@ func (storage *Storage) CreateLicense(ctx context.Context, id domain.Token, data if connErr != nil { return license, connErr } - defer closer() + defer func() { _ = closer() }() token, authErr := storage.exec.UserManager(ctx, conn).AccessToken(id) if authErr != nil { @@ -34,9 +65,10 @@ func (storage *Storage) CreateLicense(ctx context.Context, id domain.Token, data if txErr != nil { return license, txErr } + defer func() { _ = tx.Rollback() }() + license, execErr := storage.exec.LicenseManager(ctx, conn).Create(token, data) if execErr != nil { - _ = tx.Rollback() // TODO issue#composite return license, execErr } return license, tx.Commit() @@ -50,7 +82,7 @@ func (storage *Storage) ReadLicense(ctx context.Context, id domain.Token, data q if connErr != nil { return license, connErr } - defer closer() + defer func() { _ = closer() }() token, authErr := storage.exec.UserManager(ctx, conn).AccessToken(id) if authErr != nil { @@ -68,7 +100,7 @@ func (storage *Storage) UpdateLicense(ctx context.Context, id domain.Token, data if connErr != nil { return license, connErr } - defer closer() + defer func() { _ = closer() }() token, authErr := storage.exec.UserManager(ctx, conn).AccessToken(id) if authErr != nil { @@ -79,9 +111,10 @@ func (storage *Storage) UpdateLicense(ctx context.Context, id domain.Token, data if txErr != nil { return license, txErr } + defer func() { _ = tx.Rollback() }() + license, execErr := storage.exec.LicenseManager(ctx, conn).Update(token, data) if execErr != nil { - _ = tx.Rollback() // TODO issue#composite return license, execErr } return license, tx.Commit() @@ -95,7 +128,7 @@ func (storage *Storage) DeleteLicense(ctx context.Context, id domain.Token, data if connErr != nil { return license, connErr } - defer closer() + defer func() { _ = closer() }() token, authErr := storage.exec.UserManager(ctx, conn).AccessToken(id) if authErr != nil { @@ -106,9 +139,10 @@ func (storage *Storage) DeleteLicense(ctx context.Context, id domain.Token, data if txErr != nil { return license, txErr } + defer func() { _ = tx.Rollback() }() + license, execErr := storage.exec.LicenseManager(ctx, conn).Delete(token, data) if execErr != nil { - _ = tx.Rollback() // TODO issue#composite return license, execErr } return license, tx.Commit() @@ -122,7 +156,7 @@ func (storage *Storage) RestoreLicense(ctx context.Context, id domain.Token, dat if connErr != nil { return license, connErr } - defer closer() + defer func() { _ = closer() }() token, authErr := storage.exec.UserManager(ctx, conn).AccessToken(id) if authErr != nil { @@ -133,9 +167,10 @@ func (storage *Storage) RestoreLicense(ctx context.Context, id domain.Token, dat if txErr != nil { return license, txErr } + defer func() { _ = tx.Rollback() }() + license, execErr := storage.exec.LicenseManager(ctx, conn).Restore(token, data) if execErr != nil { - _ = tx.Rollback() // TODO issue#composite return license, execErr } return license, tx.Commit() @@ -149,7 +184,7 @@ func (storage *Storage) AddEmployee(ctx context.Context, id domain.Token, data q if connErr != nil { return connErr } - defer closer() + defer func() { _ = closer() }() token, authErr := storage.exec.UserManager(ctx, conn).AccessToken(id) if authErr != nil { @@ -165,7 +200,7 @@ func (storage *Storage) DeleteEmployee(ctx context.Context, id domain.Token, dat if connErr != nil { return connErr } - defer closer() + defer func() { _ = closer() }() token, authErr := storage.exec.UserManager(ctx, conn).AccessToken(id) if authErr != nil { @@ -181,7 +216,7 @@ func (storage *Storage) AddWorkplace(ctx context.Context, id domain.Token, data if connErr != nil { return connErr } - defer closer() + defer func() { _ = closer() }() token, authErr := storage.exec.UserManager(ctx, conn).AccessToken(id) if authErr != nil { @@ -197,7 +232,7 @@ func (storage *Storage) DeleteWorkplace(ctx context.Context, id domain.Token, da if connErr != nil { return connErr } - defer closer() + defer func() { _ = closer() }() token, authErr := storage.exec.UserManager(ctx, conn).AccessToken(id) if authErr != nil { @@ -213,7 +248,7 @@ func (storage *Storage) PushWorkplace(ctx context.Context, id domain.Token, data if connErr != nil { return connErr } - defer closer() + defer func() { _ = closer() }() token, authErr := storage.exec.UserManager(ctx, conn).AccessToken(id) if authErr != nil { diff --git a/pkg/storage/public.go b/pkg/storage/public.go index b19da99..57d943e 100644 --- a/pkg/storage/public.go +++ b/pkg/storage/public.go @@ -18,7 +18,7 @@ func (storage *Storage) LicenseByID(ctx context.Context, id domain.ID) (types.Li if connErr != nil { return license, connErr } - defer closer() + defer func() { _ = closer() }() return storage.exec.LicenseReader(ctx, conn).GetByID(query.GetLicenseWithID{ID: id}) } @@ -31,7 +31,7 @@ func (storage *Storage) LicenseByEmployee(ctx context.Context, employee domain.I if connErr != nil { return license, connErr } - defer closer() + defer func() { _ = closer() }() return storage.exec.LicenseReader(ctx, conn).GetByEmployee(query.GetEmployeeLicense{Employee: employee}) } @@ -46,7 +46,7 @@ func (storage *Storage) RegisterAccount(ctx context.Context, data *query.Registe if connErr != nil { return nil, connErr } - defer closer() + defer func() { _ = closer() }() tx, txErr := conn.BeginTx(ctx, &sql.TxOptions{}) if txErr != nil {