Skip to content

Commit

Permalink
feat: support xa protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
dk-lockdown committed Aug 24, 2022
1 parent 2c40f80 commit a0d86fd
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 8 deletions.
23 changes: 23 additions & 0 deletions pkg/constant/constant.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* This file is part of the hptx distribution (https://github.com/cectc/htpx).
* Copyright 2022 CECTC, Inc.
*
* This program is free software: you can redistribute it and/or modify it under the terms
* of the GNU General Public License as published by the Free Software Foundation, either
* version 3 of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
* PARTICULAR PURPOSE. See the GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License along with this
* program. If not, see <https://www.gnu.org/licenses/>.
*/

package constant

const XID = keyXID("XID")

type (
keyXID string
)
9 changes: 2 additions & 7 deletions pkg/contrib/grpc/global_transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,10 @@ import (
"github.com/cectc/dbpack/pkg/log"
"google.golang.org/grpc"

"github.com/cectc/hptx/pkg/constant"
"github.com/cectc/hptx/pkg/core"
)

const XID = keyXID("XID")

type (
keyXID string
)

type GlobalTransactionInfo struct {
FullMethod string
Timeout int32
Expand All @@ -47,7 +42,7 @@ func GlobalTransactionInterceptor(globalTransactionInfos []*GlobalTransactionInf
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, XID, xid)
ctx = context.WithValue(ctx, constant.XID, xid)
resp, err = handler(ctx, req)
if err == nil {
_, commitErr := core.GetDistributedTransactionManager().Commit(ctx, xid)
Expand Down
6 changes: 5 additions & 1 deletion pkg/core/distributed_transaction_manger.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ func (manager *DistributedTransactionManager) branchCommit(ctx context.Context,
status, err = resource.GetTCCBranchResource().Commit(ctx, bs)
case api.AT:
status, err = resource.GetATBranchResource().Commit(ctx, bs)
case api.XA:
status, err = resource.GetXABranchResource().Commit(ctx, bs)
default:
return bs.Status, errors.New("should never happen!")
}
Expand All @@ -182,6 +184,8 @@ func (manager *DistributedTransactionManager) branchRollback(ctx context.Context
status, err = resource.GetTCCBranchResource().Rollback(ctx, bs)
case api.AT:
status, err = resource.GetATBranchResource().Rollback(ctx, bs)
case api.XA:
status, err = resource.GetXABranchResource().Rollback(ctx, bs)
default:
return bs.Status, errors.New("should never happen!")
}
Expand Down Expand Up @@ -350,7 +354,7 @@ func (manager *DistributedTransactionManager) processNextBranchSession(ctx conte
}
}
if bs.Status == api.PhaseTwoRollbacking {
if manager.IsRollingBackDead(bs) {
if bs.Type == api.AT && manager.IsRollingBackDead(bs) {
if manager.rollbackRetryTimeoutUnlockEnable {
if _, err := manager.storageDriver.ReleaseLockKeys(context.Background(), bs.ResourceID, []string{bs.LockKey}); err != nil {
log.Error(err)
Expand Down
8 changes: 8 additions & 0 deletions pkg/resource/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,18 @@ func GetATBranchResource() proto.BranchResource {
return branches[api.AT]
}

func GetXABranchResource() proto.BranchResource {
return branches[api.XA]
}

func InitTCCBranchResource(resource proto.BranchResource) {
branches[api.TCC] = resource
}

func InitATBranchResource(resource proto.BranchResource) {
branches[api.AT] = resource
}

func InitXABranchResource(resource proto.BranchResource) {
branches[api.XA] = resource
}
80 changes: 80 additions & 0 deletions pkg/xa/xa.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* This file is part of the hptx distribution (https://github.com/cectc/htpx).
* Copyright 2022 CECTC, Inc.
*
* This program is free software: you can redistribute it and/or modify it under the terms
* of the GNU General Public License as published by the Free Software Foundation, either
* version 3 of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
* PARTICULAR PURPOSE. See the GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License along with this
* program. If not, see <https://www.gnu.org/licenses/>.
*/

package xa

import (
"context"
"database/sql"
"fmt"

"github.com/cectc/dbpack/pkg/dt/api"
"github.com/cectc/dbpack/pkg/log"
"github.com/pkg/errors"

"github.com/cectc/hptx/pkg/constant"
"github.com/cectc/hptx/pkg/core"
"github.com/cectc/hptx/pkg/proto"
)

func HandleWithXA(ctx context.Context, db *sql.DB, appid string, businessFn func(conn *sql.Conn) error) (err error) {
xid := ctx.Value(constant.XID)
if xid == nil {
return errors.New("ctx must with value xid")
}

var branchID string
branchID, _, err = core.GetDistributedTransactionManager().BranchRegister(ctx, &proto.BranchRegister{
XID: xid.(string),
ResourceID: appid,
LockKey: "",
BranchType: api.XA,
ApplicationData: nil,
})
if err != nil {
log.Errorf("XA branch Register error, xid: %s", xid.(string))
return errors.WithStack(err)
}
defer func() {
if err != nil {
if reportErr := core.GetDistributedTransactionManager().BranchReport(ctx, branchID, api.PhaseOneFailed); reportErr != nil {
log.Error(reportErr)
}
}
}()

var conn *sql.Conn
conn, err = db.Conn(ctx)
if err != nil {
return err
}
defer conn.Close()
if _, err = conn.ExecContext(ctx, fmt.Sprintf("XA START '%s'", branchID)); err != nil {
return err
}
defer func() {
if err == nil {
_, err = conn.ExecContext(ctx, fmt.Sprintf("XA PREPARE '%s'", branchID))
}
}()
defer func() {
_, err = conn.ExecContext(ctx, fmt.Sprintf("XA END '%s'", branchID))
}()
if err = businessFn(conn); err != nil {
return err
}
return nil
}

0 comments on commit a0d86fd

Please sign in to comment.