Skip to content

Commit

Permalink
rollback wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
NodudeWasTaken committed Oct 22, 2024
1 parent 45f2757 commit 9654813
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 38 deletions.
14 changes: 6 additions & 8 deletions pkg/sqlite/blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,17 +308,12 @@ func (qb *BlobStore) readFromDatabase(ctx context.Context, checksum string) (sql
// Delete marks a checksum as no longer in use by a single reference.
// If no references remain, the blob is deleted from the database and filesystem.
func (qb *BlobStore) Delete(ctx context.Context, checksum string) error {
rollid, err := savepoint(ctx)
if err != nil {
return fmt.Errorf("savepoint %s: %w", rollid, err)
}

// try to delete the blob from the database
if err := qb.delete(ctx, checksum); err != nil {
if qb.isConstraintError(err) {
// blob is still referenced - do not delete
logger.Debugf("Blob %s is still referenced - not deleting", checksum)
return rollbackToSavepoint(ctx, rollid)
return nil
}

// unexpected error
Expand Down Expand Up @@ -358,11 +353,14 @@ func (qb *BlobStore) delete(ctx context.Context, checksum string) error {

q := dialect.Delete(table).Where(goqu.C(blobChecksumColumn).Eq(checksum))

_, err := exec(ctx, q)
err := withSavepoint(ctx, func(ctx context.Context) error {
_, err := exec(ctx, q)
return err
})

if err != nil {
return fmt.Errorf("deleting from %s: %w", table, err)
}

return nil
}

Expand Down
9 changes: 6 additions & 3 deletions pkg/sqlite/performer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1155,9 +1155,12 @@ func TestPerformerQueryForAutoTag(t *testing.T) {
t.Errorf("Error finding performers: %s", err.Error())
}

assert.Len(t, performers, 2)
assert.Equal(t, strings.ToLower(performerNames[performerIdx1WithScene]), strings.ToLower(performers[0].Name))
assert.Equal(t, strings.ToLower(performerNames[performerIdx1WithScene]), strings.ToLower(performers[1].Name))
if assert.Len(t, performers, 2) {
assert.Equal(t, strings.ToLower(performerNames[performerIdx1WithScene]), strings.ToLower(performers[0].Name))
assert.Equal(t, strings.ToLower(performerNames[performerIdx1WithScene]), strings.ToLower(performers[1].Name))
} else {
t.Errorf("Skipping performer comparison as atleast 1 is missing")
}

return nil
})
Expand Down
27 changes: 0 additions & 27 deletions pkg/sqlite/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/jmoiron/sqlx"
"gopkg.in/guregu/null.v4"

"github.com/stashapp/stash/pkg/hash"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sliceutil"
Expand Down Expand Up @@ -1155,32 +1154,6 @@ func execID(ctx context.Context, stmt sqler) (*int64, error) {
return &id, nil
}

func savepoint(ctx context.Context) (string, error) {
tx, err := getTx(ctx)
if err != nil {
return "", err
}

// Generate savepoint
rnd, err := hash.GenerateRandomKey(64)
if err != nil {
return "", err
}

_, err = tx.QueryxContext(ctx, "SAVEPOINT "+rnd)
return rnd, err
}

func rollbackToSavepoint(ctx context.Context, id string) error {
tx, err := getTx(ctx)
if err != nil {
return err
}

_, err = tx.QueryxContext(ctx, "ROLLBACK TO SAVEPOINT "+id)
return err
}

func count(ctx context.Context, q *goqu.SelectDataset) (int, error) {
var count int
if err := querySimple(ctx, q, &count); err != nil {
Expand Down
41 changes: 41 additions & 0 deletions pkg/sqlite/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/hash"
"github.com/stashapp/stash/pkg/logger"
)

Expand Down Expand Up @@ -174,3 +175,43 @@ func (db *dbWrapperType) ExecStmt(ctx context.Context, stmt *stmt, args ...inter

return ret, sqlError(err, stmt.query, args...)
}

type SavepointAction func(ctx context.Context) error

func withSavepoint(ctx context.Context, action SavepointAction) error {
tx, err := getTx(ctx)
if err != nil {
return err
}

// Generate savepoint
rnd, err := hash.GenerateRandomKey(64)
if err != nil {
return err
}
rnd = "savepoint_" + rnd

// Create a savepoint
_, err = tx.Exec("SAVEPOINT " + rnd)
if err != nil {
return fmt.Errorf("failed to create savepoint: %w", err)
}

// Execute the action
err = action(ctx)
if err != nil {
// Rollback to savepoint on error
if _, rbErr := tx.Exec("ROLLBACK TO SAVEPOINT " + rnd); rbErr != nil {
return fmt.Errorf("action failed and rollback to savepoint failed: %w", rbErr)
}
return fmt.Errorf("action failed: %w", err)
}

// Release the savepoint on success
_, err = tx.Exec("RELEASE SAVEPOINT " + rnd)
if err != nil {
return fmt.Errorf("failed to release savepoint: %w", err)
}

return nil
}

0 comments on commit 9654813

Please sign in to comment.