Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add savepoint support to atomic distributed transaction #16863

Merged
merged 8 commits into from
Oct 10, 2024
11 changes: 11 additions & 0 deletions go/test/endtoend/cluster/cluster_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,17 @@ func (shard *Shard) PrimaryTablet() *Vttablet {
return shard.Vttablets[0]
}

// FindPrimaryTablet finds the primary tablet in the shard.
func (shard *Shard) FindPrimaryTablet() *Vttablet {
for _, vttablet := range shard.Vttablets {
tabletType := vttablet.VttabletProcess.GetTabletType()
if tabletType == "primary" {
return vttablet
}
}
return nil
}

// Rdonly get the last tablet which is rdonly
func (shard *Shard) Rdonly() *Vttablet {
for idx, tablet := range shard.Vttablets {
Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/cluster/reshard.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (rw *ReshardWorkflow) WaitForVreplCatchup(timeToWait time.Duration) {
if !slices.Contains(targetShards, shard.Name) {
continue
}
vttablet := shard.PrimaryTablet().VttabletProcess
vttablet := shard.FindPrimaryTablet().VttabletProcess
vttablet.WaitForVReplicationToCatchup(rw.t, rw.workflowName, fmt.Sprintf("vt_%s", vttablet.Keyspace), "", timeToWait)
}
}
Expand Down
57 changes: 55 additions & 2 deletions go/test/endtoend/transaction/twopc/fuzz/fuzzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"os"
"path"
"slices"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -126,7 +127,18 @@ func TestTwoPCFuzzTest(t *testing.T) {
fz.start(t)

// Wait for the timeForTesting so that the threads continue to run.
time.Sleep(tt.timeForTesting)
timeout := time.After(tt.timeForTesting)
loop := true
for loop {
select {
case <-timeout:
loop = false
case <-time.After(1 * time.Second):
if t.Failed() {
loop = false
}
}
}

// Signal the fuzzer to stop.
fz.stop()
Expand Down Expand Up @@ -302,9 +314,11 @@ func (fz *fuzzer) generateAndExecuteTransaction(threadId int) {
// for each update set ordered by the auto increment column will not be true.
// That assertion depends on all the transactions running updates first to ensure that for any given update set,
// no two transactions are running the insert queries.
queries := []string{"begin"}
var queries []string
queries = append(queries, fz.generateUpdateQueries(updateSetVal, incrementVal)...)
queries = append(queries, fz.generateInsertQueries(updateSetVal, threadId)...)
queries = fz.addRandomSavePoints(queries)
queries = append([]string{"begin"}, queries...)
finalCommand := "commit"
for _, query := range queries {
_, err := conn.ExecuteFetch(query, 0, false)
Expand Down Expand Up @@ -377,6 +391,45 @@ func (fz *fuzzer) runClusterDisruption(t *testing.T) {
}
}

// addRandomSavePoints will add random savepoints and queries to the list of queries.
// It still ensures that all the new queries added are rolledback so that the assertions of queries
// don't change.
func (fz *fuzzer) addRandomSavePoints(queries []string) []string {
savePointCount := 1
for {
shouldAddSavePoint := rand.Intn(2)
if shouldAddSavePoint == 0 {
return queries
}

savePointQueries := []string{"SAVEPOINT sp" + strconv.Itoa(savePointCount)}
randomDmlCount := rand.Intn(2) + 1
for i := 0; i < randomDmlCount; i++ {
savePointQueries = append(savePointQueries, fz.randomDML())
}
savePointQueries = append(savePointQueries, "ROLLBACK TO sp"+strconv.Itoa(savePointCount))
savePointCount++

savePointPosition := rand.Intn(len(queries))
newQueries := slices.Clone(queries[:savePointPosition])
newQueries = append(newQueries, savePointQueries...)
newQueries = append(newQueries, queries[savePointPosition:]...)
queries = newQueries
}
}

// randomDML generates a random DML to be used.
func (fz *fuzzer) randomDML() string {
queryType := rand.Intn(2)
if queryType == 0 {
// Generate INSERT
return fmt.Sprintf(insertIntoFuzzInsert, updateRowBaseVals[rand.Intn(len(updateRowBaseVals))], rand.Intn(fz.updateSets), rand.Intn(fz.threads))
}
// Generate UPDATE
updateId := fz.updateRowsVals[rand.Intn(len(fz.updateRowsVals))][rand.Intn(len(updateRowBaseVals))]
return fmt.Sprintf(updateFuzzUpdate, rand.Intn(100000), updateId)
}

/*
Cluster Level Disruptions for the fuzzer
*/
Expand Down
77 changes: 73 additions & 4 deletions go/test/endtoend/transaction/twopc/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,19 @@ import (
"fmt"
"io"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/endtoend/utils"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/test/endtoend/transaction/twopc/utils"
twopcutil "vitess.io/vitess/go/test/endtoend/transaction/twopc/utils"
binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
querypb "vitess.io/vitess/go/vt/proto/query"
topodatapb "vitess.io/vitess/go/vt/proto/topodata"
Expand All @@ -42,6 +45,7 @@ import (

var (
clusterInstance *cluster.LocalProcessCluster
mysqlParams mysql.ConnParams
vtParams mysql.ConnParams
vtgateGrpcAddress string
keyspaceName = "ks"
Expand Down Expand Up @@ -81,6 +85,8 @@ func TestMain(m *testing.M) {
"--twopc_enable",
"--twopc_abandon_age", "1",
"--queryserver-config-transaction-cap", "3",
"--queryserver-config-transaction-timeout", "400s",
"--queryserver-config-query-timeout", "9000s",
)

// Start keyspace
Expand All @@ -102,6 +108,15 @@ func TestMain(m *testing.M) {
vtParams = clusterInstance.GetVTParams(keyspaceName)
vtgateGrpcAddress = fmt.Sprintf("%s:%d", clusterInstance.Hostname, clusterInstance.VtgateGrpcPort)

// create mysql instance and connection parameters
conn, closer, err := utils.NewMySQL(clusterInstance, keyspaceName, SchemaSQL)
if err != nil {
fmt.Println(err)
return 1
}
defer closer()
mysqlParams = conn

return m.Run()
}()
os.Exit(exitcode)
Expand All @@ -121,8 +136,29 @@ func start(t *testing.T) (*mysql.Conn, func()) {

func cleanup(t *testing.T) {
cluster.PanicHandler(t)
utils.ClearOutTable(t, vtParams, "twopc_user")
utils.ClearOutTable(t, vtParams, "twopc_t1")
twopcutil.ClearOutTable(t, vtParams, "twopc_user")
twopcutil.ClearOutTable(t, vtParams, "twopc_t1")
sm.reset()
}

func startWithMySQL(t *testing.T) (utils.MySQLCompare, func()) {
mcmp, err := utils.NewMySQLCompare(t, vtParams, mysqlParams)
require.NoError(t, err)

deleteAll := func() {
tables := []string{"twopc_user"}
for _, table := range tables {
_, _ = mcmp.ExecAndIgnore("delete from " + table)
}
}

deleteAll()

return mcmp, func() {
deleteAll()
mcmp.Close()
cluster.PanicHandler(t)
}
}

type extractInterestingValues func(dtidMap map[string]string, vals []sqltypes.Value) []sqltypes.Value
Expand All @@ -147,7 +183,8 @@ var tables = map[string]extractInterestingValues{
},
"ks.redo_statement": func(dtidMap map[string]string, vals []sqltypes.Value) (out []sqltypes.Value) {
dtid := getDTID(dtidMap, vals[0].ToString())
out = append([]sqltypes.Value{sqltypes.NewVarChar(dtid)}, vals[1:]...)
stmt := getStatement(vals[2].ToString())
out = append([]sqltypes.Value{sqltypes.NewVarChar(dtid)}, vals[1], sqltypes.TestValue(sqltypes.Blob, stmt))
return
},
"ks.twopc_user": func(_ map[string]string, vals []sqltypes.Value) []sqltypes.Value { return vals },
Expand All @@ -167,6 +204,28 @@ func getDTID(dtidMap map[string]string, dtKey string) string {
return dtid
}

func getStatement(stmt string) string {
var sKey string
var prefix string
switch {
case strings.HasPrefix(stmt, "savepoint"):
prefix = "savepoint-"
sKey = stmt[9:]
case strings.HasPrefix(stmt, "rollback to"):
prefix = "rollback-"
sKey = stmt[11:]
default:
return stmt
}

sid, exists := sm.stmt[sKey]
if !exists {
sid = fmt.Sprintf("%d", len(sm.stmt)+1)
sm.stmt[sKey] = sid
}
return prefix + sid
}

func runVStream(t *testing.T, ctx context.Context, ch chan *binlogdatapb.VEvent, vtgateConn *vtgateconn.VTGateConn) {
vgtid := &binlogdatapb.VGtid{
ShardGtids: []*binlogdatapb.ShardGtid{
Expand Down Expand Up @@ -272,3 +331,13 @@ func prettyPrint(v interface{}) string {
}
return string(b)
}

type stmtMapper struct {
stmt map[string]string
}

var sm = &stmtMapper{stmt: make(map[string]string)}

func (sm *stmtMapper) reset() {
sm.stmt = make(map[string]string)
}
Loading
Loading