Skip to content

Commit

Permalink
move safesession to executorcontext package
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal committed Dec 2, 2024
1 parent 6be7923 commit 27d8dc9
Show file tree
Hide file tree
Showing 25 changed files with 455 additions and 367 deletions.
3 changes: 2 additions & 1 deletion go/vt/vtexplain/vtexplain_vtgate.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate"
"vitess.io/vitess/go/vt/vtgate/engine"
econtext "vitess.io/vitess/go/vt/vtgate/executorcontext"
"vitess.io/vitess/go/vt/vtgate/logstats"
"vitess.io/vitess/go/vt/vtgate/vindexes"
"vitess.io/vitess/go/vt/vttablet/queryservice"
Expand Down Expand Up @@ -235,7 +236,7 @@ func (vte *VTExplain) vtgateExecute(sql string) ([]*engine.Plan, map[string]*Tab
// This will ensure that the commit/rollback order is predictable.
vte.sortShardSession()

_, err := vte.vtgateExecutor.Execute(context.Background(), nil, "VtexplainExecute", vtgate.NewSafeSession(vte.vtgateSession), sql, nil)
_, err := vte.vtgateExecutor.Execute(context.Background(), nil, "VtexplainExecute", econtext.NewSafeSession(vte.vtgateSession), sql, nil)
if err != nil {
for _, tc := range vte.explainTopo.TabletConns {
tc.tabletQueries = nil
Expand Down
12 changes: 6 additions & 6 deletions go/vt/vtgate/autocommit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import (
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqltypes"

querypb "vitess.io/vitess/go/vt/proto/query"
vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
econtext "vitess.io/vitess/go/vt/vtgate/executorcontext"
)

// This file contains tests for all the autocommit code paths
Expand Down Expand Up @@ -382,7 +382,7 @@ func TestAutocommitTransactionStarted(t *testing.T) {

// single shard query - no savepoint needed
sql := "update `user` set a = 2 where id = 1"
_, err := executor.Execute(context.Background(), nil, "TestExecute", NewSafeSession(session), sql, map[string]*querypb.BindVariable{})
_, err := executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{})
require.NoError(t, err)
require.Len(t, sbc1.Queries, 1)
require.Equal(t, sql, sbc1.Queries[0].Sql)
Expand All @@ -394,7 +394,7 @@ func TestAutocommitTransactionStarted(t *testing.T) {
// multi shard query - savepoint needed
sql = "update `user` set a = 2 where id in (1, 4)"
expectedSql := "update `user` set a = 2 where id in ::__vals"
_, err = executor.Execute(context.Background(), nil, "TestExecute", NewSafeSession(session), sql, map[string]*querypb.BindVariable{})
_, err = executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{})
require.NoError(t, err)
require.Len(t, sbc1.Queries, 2)
require.Contains(t, sbc1.Queries[0].Sql, "savepoint")
Expand All @@ -413,7 +413,7 @@ func TestAutocommitDirectTarget(t *testing.T) {
}
sql := "insert into `simple`(val) values ('val')"

_, err := executor.Execute(context.Background(), nil, "TestExecute", NewSafeSession(session), sql, map[string]*querypb.BindVariable{})
_, err := executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{})
require.NoError(t, err)

assertQueries(t, sbclookup, []*querypb.BoundQuery{{
Expand All @@ -434,7 +434,7 @@ func TestAutocommitDirectRangeTarget(t *testing.T) {
}
sql := "delete from sharded_user_msgs limit 1000"

_, err := executor.Execute(context.Background(), nil, "TestExecute", NewSafeSession(session), sql, map[string]*querypb.BindVariable{})
_, err := executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{})
require.NoError(t, err)

assertQueries(t, sbc1, []*querypb.BoundQuery{{
Expand All @@ -451,5 +451,5 @@ func autocommitExec(executor *Executor, sql string) (*sqltypes.Result, error) {
TransactionMode: vtgatepb.TransactionMode_MULTI,
}

return executor.Execute(context.Background(), nil, "TestExecute", NewSafeSession(session), sql, map[string]*querypb.BindVariable{})
return executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{})
}
60 changes: 28 additions & 32 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import (
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/engine"
"vitess.io/vitess/go/vt/vtgate/evalengine"
econtext "vitess.io/vitess/go/vt/vtgate/executorcontext"
"vitess.io/vitess/go/vt/vtgate/logstats"
"vitess.io/vitess/go/vt/vtgate/planbuilder"
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
Expand Down Expand Up @@ -223,7 +224,7 @@ func NewExecutor(
}

// Execute executes a non-streaming query.
func (e *Executor) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, method string, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (result *sqltypes.Result, err error) {
func (e *Executor) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, method string, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (result *sqltypes.Result, err error) {
span, ctx := trace.NewSpan(ctx, "executor.Execute")
span.Annotate("method", method)
trace.AnnotateSQL(span, sqlparser.Preview(sql))
Expand Down Expand Up @@ -286,7 +287,7 @@ func (e *Executor) StreamExecute(
ctx context.Context,
mysqlCtx vtgateservice.MySQLConnection,
method string,
safeSession *SafeSession,
safeSession *econtext.SafeSession,
sql string,
bindVars map[string]*querypb.BindVariable,
callback func(*sqltypes.Result) error,
Expand Down Expand Up @@ -411,12 +412,12 @@ func canReturnRows(stmtType sqlparser.StatementType) bool {
}
}

func saveSessionStats(safeSession *SafeSession, stmtType sqlparser.StatementType, rowsAffected, insertID uint64, rowsReturned int, err error) {
func saveSessionStats(safeSession *econtext.SafeSession, stmtType sqlparser.StatementType, rowsAffected, insertID uint64, rowsReturned int, err error) {
safeSession.RowCount = -1
if err != nil {
return
}
if !safeSession.foundRowsHandled {
if !safeSession.IsFoundRowsHandled() {
safeSession.FoundRows = uint64(rowsReturned)
}
if insertID > 0 {
Expand All @@ -430,7 +431,7 @@ func saveSessionStats(safeSession *SafeSession, stmtType sqlparser.StatementType
}
}

func (e *Executor) execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) (sqlparser.StatementType, *sqltypes.Result, error) {
func (e *Executor) execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) (sqlparser.StatementType, *sqltypes.Result, error) {
var err error
var qr *sqltypes.Result
var stmtType sqlparser.StatementType
Expand All @@ -448,7 +449,7 @@ func (e *Executor) execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConn
}

// addNeededBindVars adds bind vars that are needed by the plan
func (e *Executor) addNeededBindVars(vcursor *vcursorImpl, bindVarNeeds *sqlparser.BindVarNeeds, bindVars map[string]*querypb.BindVariable, session *SafeSession) error {
func (e *Executor) addNeededBindVars(vcursor *vcursorImpl, bindVarNeeds *sqlparser.BindVarNeeds, bindVars map[string]*querypb.BindVariable, session *econtext.SafeSession) error {
for _, funcName := range bindVarNeeds.NeedFunctionResult {
switch funcName {
case sqlparser.DBVarName:
Expand Down Expand Up @@ -572,21 +573,21 @@ func (e *Executor) addNeededBindVars(vcursor *vcursorImpl, bindVarNeeds *sqlpars
return nil
}

func ifOptionsExist(session *SafeSession, f func(*querypb.ExecuteOptions)) {
func ifOptionsExist(session *econtext.SafeSession, f func(*querypb.ExecuteOptions)) {
options := session.GetOptions()
if options != nil {
f(options)
}
}

func ifReadAfterWriteExist(session *SafeSession, f func(*vtgatepb.ReadAfterWrite)) {
func ifReadAfterWriteExist(session *econtext.SafeSession, f func(*vtgatepb.ReadAfterWrite)) {
raw := session.ReadAfterWrite
if raw != nil {
f(raw)
}
}

func (e *Executor) handleBegin(ctx context.Context, safeSession *SafeSession, logStats *logstats.LogStats, stmt sqlparser.Statement) (*sqltypes.Result, error) {
func (e *Executor) handleBegin(ctx context.Context, safeSession *econtext.SafeSession, logStats *logstats.LogStats, stmt sqlparser.Statement) (*sqltypes.Result, error) {
execStart := time.Now()
logStats.PlanTime = execStart.Sub(logStats.StartTime)

Expand All @@ -599,7 +600,7 @@ func (e *Executor) handleBegin(ctx context.Context, safeSession *SafeSession, lo
return &sqltypes.Result{}, err
}

func (e *Executor) handleCommit(ctx context.Context, safeSession *SafeSession, logStats *logstats.LogStats) (*sqltypes.Result, error) {
func (e *Executor) handleCommit(ctx context.Context, safeSession *econtext.SafeSession, logStats *logstats.LogStats) (*sqltypes.Result, error) {
execStart := time.Now()
logStats.PlanTime = execStart.Sub(logStats.StartTime)
logStats.ShardQueries = uint64(len(safeSession.ShardSessions))
Expand All @@ -611,11 +612,11 @@ func (e *Executor) handleCommit(ctx context.Context, safeSession *SafeSession, l
}

// Commit commits the existing transactions
func (e *Executor) Commit(ctx context.Context, safeSession *SafeSession) error {
func (e *Executor) Commit(ctx context.Context, safeSession *econtext.SafeSession) error {
return e.txConn.Commit(ctx, safeSession)
}

func (e *Executor) handleRollback(ctx context.Context, safeSession *SafeSession, logStats *logstats.LogStats) (*sqltypes.Result, error) {
func (e *Executor) handleRollback(ctx context.Context, safeSession *econtext.SafeSession, logStats *logstats.LogStats) (*sqltypes.Result, error) {
execStart := time.Now()
logStats.PlanTime = execStart.Sub(logStats.StartTime)
logStats.ShardQueries = uint64(len(safeSession.ShardSessions))
Expand All @@ -625,7 +626,7 @@ func (e *Executor) handleRollback(ctx context.Context, safeSession *SafeSession,
return &sqltypes.Result{}, err
}

func (e *Executor) handleSavepoint(ctx context.Context, safeSession *SafeSession, sql string, planType string, logStats *logstats.LogStats, nonTxResponse func(query string) (*sqltypes.Result, error), ignoreMaxMemoryRows bool) (*sqltypes.Result, error) {
func (e *Executor) handleSavepoint(ctx context.Context, safeSession *econtext.SafeSession, sql string, planType string, logStats *logstats.LogStats, nonTxResponse func(query string) (*sqltypes.Result, error), ignoreMaxMemoryRows bool) (*sqltypes.Result, error) {
execStart := time.Now()
logStats.PlanTime = execStart.Sub(logStats.StartTime)
logStats.ShardQueries = uint64(len(safeSession.ShardSessions))
Expand All @@ -637,15 +638,15 @@ func (e *Executor) handleSavepoint(ctx context.Context, safeSession *SafeSession
// If no transaction exists on any of the shard sessions,
// then savepoint does not need to be executed, it will be only stored in the session
// and later will be executed when a transaction is started.
if !safeSession.isTxOpen() {
if !safeSession.IsTxOpen() {
if safeSession.InTransaction() {
// Storing, as this needs to be executed just after starting transaction on the shard.
safeSession.StoreSavepoint(sql)
return &sqltypes.Result{}, nil
}
return nonTxResponse(sql)
}
orig := safeSession.commitOrder
orig := safeSession.GetCommitOrder()
qr, err := e.executeSPInAllSessions(ctx, safeSession, sql, ignoreMaxMemoryRows)
safeSession.SetCommitOrder(orig)
if err != nil {
Expand All @@ -657,15 +658,15 @@ func (e *Executor) handleSavepoint(ctx context.Context, safeSession *SafeSession

// executeSPInAllSessions function executes the savepoint query in all open shard sessions (pre, normal and post)
// which has non-zero transaction id (i.e. an open transaction on the shard connection).
func (e *Executor) executeSPInAllSessions(ctx context.Context, safeSession *SafeSession, sql string, ignoreMaxMemoryRows bool) (*sqltypes.Result, error) {
func (e *Executor) executeSPInAllSessions(ctx context.Context, safeSession *econtext.SafeSession, sql string, ignoreMaxMemoryRows bool) (*sqltypes.Result, error) {
var qr *sqltypes.Result
var errs []error
for _, co := range []vtgatepb.CommitOrder{vtgatepb.CommitOrder_PRE, vtgatepb.CommitOrder_NORMAL, vtgatepb.CommitOrder_POST} {
safeSession.SetCommitOrder(co)

var rss []*srvtopo.ResolvedShard
var queries []*querypb.BoundQuery
for _, shardSession := range safeSession.getSessions() {
for _, shardSession := range safeSession.GetSessions() {
// This will avoid executing savepoint on reserved connections
// which has no open transaction.
if shardSession.TransactionId == 0 {
Expand Down Expand Up @@ -718,7 +719,7 @@ func (e *Executor) handleKill(ctx context.Context, mysqlCtx vtgateservice.MySQLC

// CloseSession releases the current connection, which rollbacks open transactions and closes reserved connections.
// It is called then the MySQL servers closes the connection to its client.
func (e *Executor) CloseSession(ctx context.Context, safeSession *SafeSession) error {
func (e *Executor) CloseSession(ctx context.Context, safeSession *econtext.SafeSession) error {
return e.txConn.ReleaseAll(ctx, safeSession)
}

Expand Down Expand Up @@ -1088,11 +1089,6 @@ func (e *Executor) ParseDestinationTarget(targetString string) (string, topodata
return destKeyspace, destTabletType, dest, err
}

type iQueryOption interface {
cachePlan() bool
getSelectLimit() int
}

// getPlan computes the plan for the given query. If one is in
// the cache, it reuses it.
func (e *Executor) getPlan(
Expand Down Expand Up @@ -1136,7 +1132,7 @@ func (e *Executor) getPlan(
bindVars,
parameterize,
vcursor.keyspace,
vcursor.safeSession.getSelectLimit(),
vcursor.safeSession.GetSelectLimit(),
setVarComment,
vcursor.safeSession.SystemVariables,
vcursor.GetForeignKeyChecksState(),
Expand Down Expand Up @@ -1195,7 +1191,7 @@ func (e *Executor) cacheAndBuildStatement(
bindVarNeeds *sqlparser.BindVarNeeds,
logStats *logstats.LogStats,
) (*engine.Plan, error) {
planCachable := sqlparser.CachePlan(stmt) && vcursor.safeSession.cachePlan()
planCachable := sqlparser.CachePlan(stmt) && vcursor.safeSession.CachePlan()
if planCachable {
planKey := e.hashPlan(ctx, vcursor, query)

Expand Down Expand Up @@ -1354,7 +1350,7 @@ func isValidPayloadSize(query string) bool {
}

// Prepare executes a prepare statements.
func (e *Executor) Prepare(ctx context.Context, method string, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (fld []*querypb.Field, err error) {
func (e *Executor) Prepare(ctx context.Context, method string, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (fld []*querypb.Field, err error) {
logStats := logstats.NewLogStats(ctx, method, sql, safeSession.GetSessionUUID(), bindVars)
fld, err = e.prepare(ctx, safeSession, sql, bindVars, logStats)
logStats.Error = err
Expand All @@ -1373,7 +1369,7 @@ func (e *Executor) Prepare(ctx context.Context, method string, safeSession *Safe
return fld, err
}

func (e *Executor) prepare(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) ([]*querypb.Field, error) {
func (e *Executor) prepare(ctx context.Context, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) ([]*querypb.Field, error) {
// Start an implicit transaction if necessary.
if !safeSession.Autocommit && !safeSession.InTransaction() {
if err := e.txConn.Begin(ctx, safeSession, nil); err != nil {
Expand Down Expand Up @@ -1409,7 +1405,7 @@ func (e *Executor) prepare(ctx context.Context, safeSession *SafeSession, sql st
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] unrecognized prepare statement: %s", sql)
}

func (e *Executor) handlePrepare(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) ([]*querypb.Field, error) {
func (e *Executor) handlePrepare(ctx context.Context, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) ([]*querypb.Field, error) {
query, comments := sqlparser.SplitMarginComments(sql)
vcursor, _ := newVCursorImpl(safeSession, comments, e, logStats, e.vm, e.VSchema(), e.resolver.resolver, e.serv, e.warnShardedOnly, e.pv)

Expand Down Expand Up @@ -1460,17 +1456,17 @@ func parseAndValidateQuery(query string, parser *sqlparser.Parser) (sqlparser.St
}

// ExecuteMultiShard implements the IExecutor interface
func (e *Executor) ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *SafeSession, autocommit bool, ignoreMaxMemoryRows bool, resultsObserver resultsObserver) (qr *sqltypes.Result, errs []error) {
func (e *Executor) ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *econtext.SafeSession, autocommit bool, ignoreMaxMemoryRows bool, resultsObserver resultsObserver) (qr *sqltypes.Result, errs []error) {
return e.scatterConn.ExecuteMultiShard(ctx, primitive, rss, queries, session, autocommit, ignoreMaxMemoryRows, resultsObserver)
}

// StreamExecuteMulti implements the IExecutor interface
func (e *Executor) StreamExecuteMulti(ctx context.Context, primitive engine.Primitive, query string, rss []*srvtopo.ResolvedShard, vars []map[string]*querypb.BindVariable, session *SafeSession, autocommit bool, callback func(reply *sqltypes.Result) error, resultsObserver resultsObserver) []error {
func (e *Executor) StreamExecuteMulti(ctx context.Context, primitive engine.Primitive, query string, rss []*srvtopo.ResolvedShard, vars []map[string]*querypb.BindVariable, session *econtext.SafeSession, autocommit bool, callback func(reply *sqltypes.Result) error, resultsObserver resultsObserver) []error {
return e.scatterConn.StreamExecuteMulti(ctx, primitive, query, rss, vars, session, autocommit, callback, resultsObserver)
}

// ExecuteLock implements the IExecutor interface
func (e *Executor) ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedShard, query *querypb.BoundQuery, session *SafeSession, lockFuncType sqlparser.LockingFuncType) (*sqltypes.Result, error) {
func (e *Executor) ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedShard, query *querypb.BoundQuery, session *econtext.SafeSession, lockFuncType sqlparser.LockingFuncType) (*sqltypes.Result, error) {
return e.scatterConn.ExecuteLock(ctx, rs, query, session, lockFuncType)
}

Expand Down Expand Up @@ -1581,7 +1577,7 @@ func getTabletThrottlerStatus(tabletHostPort string) (string, error) {
}

// ReleaseLock implements the IExecutor interface
func (e *Executor) ReleaseLock(ctx context.Context, session *SafeSession) error {
func (e *Executor) ReleaseLock(ctx context.Context, session *econtext.SafeSession) error {
return e.txConn.ReleaseLock(ctx, session)
}

Expand Down
3 changes: 2 additions & 1 deletion go/vt/vtgate/executor_ddl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"testing"

vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
econtext "vitess.io/vitess/go/vt/vtgate/executorcontext"

"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -56,7 +57,7 @@ func TestDDLFlags(t *testing.T) {
for _, testcase := range testcases {
t.Run(fmt.Sprintf("%s-%v-%v", testcase.sql, testcase.enableDirectDDL, testcase.enableOnlineDDL), func(t *testing.T) {
executor, _, _, _, ctx := createExecutorEnv(t)
session := NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded})
session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded})
enableDirectDDL = testcase.enableDirectDDL
enableOnlineDDL = testcase.enableOnlineDDL
_, err := executor.Execute(ctx, nil, "TestDDLFlags", session, testcase.sql, nil)
Expand Down
Loading

0 comments on commit 27d8dc9

Please sign in to comment.