Skip to content

Commit

Permalink
Merge pull request go-gorm#5455 from longbridgeapp/feat-support-trans…
Browse files Browse the repository at this point in the history
…action-calllback
  • Loading branch information
huacnlee authored Jul 1, 2022
2 parents c74bc57 + 2cb4088 commit 5c4016d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 21 deletions.
37 changes: 31 additions & 6 deletions callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gorm

import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
Expand All @@ -15,12 +16,13 @@ import (
func initializeCallbacks(db *DB) *callbacks {
return &callbacks{
processors: map[string]*processor{
"create": {db: db},
"query": {db: db},
"update": {db: db},
"delete": {db: db},
"row": {db: db},
"raw": {db: db},
"create": {db: db},
"query": {db: db},
"update": {db: db},
"delete": {db: db},
"row": {db: db},
"raw": {db: db},
"transaction": {db: db},
},
}
}
Expand Down Expand Up @@ -72,6 +74,29 @@ func (cs *callbacks) Raw() *processor {
return cs.processors["raw"]
}

func (cs *callbacks) Transaction() *processor {
return cs.processors["transaction"]
}

func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB {
var err error

switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
default:
err = ErrInvalidTransaction
}

if err != nil {
_ = tx.AddError(err)
}

return tx
}

func (p *processor) Execute(db *DB) *DB {
// call scopes
for len(db.Statement.scopes) > 0 {
Expand Down
16 changes: 1 addition & 15 deletions finisher_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -619,27 +619,13 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
// clone statement
tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
opt *sql.TxOptions
err error
)

if len(opts) > 0 {
opt = opts[0]
}

switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
default:
err = ErrInvalidTransaction
}

if err != nil {
tx.AddError(err)
}

return tx
return tx.callbacks.Transaction().Begin(tx, opt)
}

// Commit commit a transaction
Expand Down

0 comments on commit 5c4016d

Please sign in to comment.