diff --git a/go/test/endtoend/vtgate/misc_test.go b/go/test/endtoend/vtgate/misc_test.go index bbcb338fa50..ff90e0f88d8 100644 --- a/go/test/endtoend/vtgate/misc_test.go +++ b/go/test/endtoend/vtgate/misc_test.go @@ -41,6 +41,19 @@ func TestInsertOnDuplicateKey(t *testing.T) { } +func TestLastInsertID(t *testing.T) { + conn, err := mysql.Connect(context.Background(), &vtParams) + require.NoError(t, err) + + _, err = conn.ExecuteFetch("select last_insert_id(12)", 1000, true) + require.NoError(t, err) + + qr, err := conn.ExecuteFetch("select last_insert_id()", 1000, true) + require.NoError(t, err) + + require.Equal(t, `[[VARCHAR("ks")]]`, fmt.Sprintf("%v", qr.Rows[0][0].String())) +} + func TestInsertNeg(t *testing.T) { conn, closer := start(t) defer closer() diff --git a/go/test/endtoend/vtgate/vitess_tester/aggregation/aggregation.test b/go/test/endtoend/vtgate/vitess_tester/aggregation/aggregation.test index 8b0997eed1a..7a02c7e85fa 100644 --- a/go/test/endtoend/vtgate/vitess_tester/aggregation/aggregation.test +++ b/go/test/endtoend/vtgate/vitess_tester/aggregation/aggregation.test @@ -64,4 +64,10 @@ from (select id, count(*) as num_segments from t1 group by 1 order by 2 desc lim join t2 u on u.id = t.id; select name -from (select name from t1 group by name having count(t1.id) > 1) t1; \ No newline at end of file +from (select name from t1 group by name having count(t1.id) > 1) t1; + +# this query uses last_insert_id with a column argument to show that this works well +select id, last_insert_id(count(*)) as num_segments from t1 group by id; + +# checking that we stored the correct value in the last_insert_id +select last_insert_id(); diff --git a/go/test/endtoend/vtgate/vitess_tester/expressions/expressions.test b/go/test/endtoend/vtgate/vitess_tester/expressions/expressions.test index 60c1e641463..f68023c78b6 100644 --- a/go/test/endtoend/vtgate/vitess_tester/expressions/expressions.test +++ b/go/test/endtoend/vtgate/vitess_tester/expressions/expressions.test @@ -28,3 +28,7 @@ SELECT (~ (1 || 0)) IS NULL; SELECT 1 WHERE (~ (1 || 0)) IS NULL; + +select last_insert_id(12); + +select last_insert_id(); diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 0bb47361f55..ad7501a9068 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -678,7 +678,7 @@ func (e *Executor) executeSPInAllSessions(ctx context.Context, safeSession *econ }) queries = append(queries, &querypb.BoundQuery{Sql: sql}) } - qr, errs = e.ExecuteMultiShard(ctx, nil, rss, queries, safeSession, false /*autocommit*/, ignoreMaxMemoryRows, nullResultsObserver{}) + qr, errs = e.ExecuteMultiShard(ctx, nil, rss, queries, safeSession, false /*autocommit*/, ignoreMaxMemoryRows, false, nullResultsObserver{}) err := vterrors.Aggregate(errs) if err != nil { return nil, err @@ -1484,8 +1484,8 @@ func parseAndValidateQuery(query string, parser *sqlparser.Parser) (sqlparser.St } // ExecuteMultiShard implements the IExecutor interface -func (e *Executor) ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *econtext.SafeSession, autocommit bool, ignoreMaxMemoryRows bool, resultsObserver econtext.ResultsObserver) (qr *sqltypes.Result, errs []error) { - return e.scatterConn.ExecuteMultiShard(ctx, primitive, rss, queries, session, autocommit, ignoreMaxMemoryRows, resultsObserver) +func (e *Executor) ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *econtext.SafeSession, autocommit, ignoreMaxMemoryRows, fetchLastInsertID bool, resultsObserver econtext.ResultsObserver) (qr *sqltypes.Result, errs []error) { + return e.scatterConn.ExecuteMultiShard(ctx, primitive, rss, queries, session, autocommit, ignoreMaxMemoryRows, resultsObserver, fetchLastInsertID) } // StreamExecuteMulti implements the IExecutor interface diff --git a/go/vt/vtgate/executorcontext/vcursor_impl.go b/go/vt/vtgate/executorcontext/vcursor_impl.go index 08fb89366c2..5f291112b88 100644 --- a/go/vt/vtgate/executorcontext/vcursor_impl.go +++ b/go/vt/vtgate/executorcontext/vcursor_impl.go @@ -95,7 +95,7 @@ type ( // vcursor_impl needs these facilities to be able to be able to execute queries for vindexes iExecute interface { Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, method string, session *SafeSession, s string, vars map[string]*querypb.BindVariable) (*sqltypes.Result, error) - ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *SafeSession, autocommit bool, ignoreMaxMemoryRows bool, resultsObserver ResultsObserver) (qr *sqltypes.Result, errs []error) + ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *SafeSession, autocommit, ignoreMaxMemoryRows, fetchLastInsertID bool, resultsObserver ResultsObserver) (qr *sqltypes.Result, errs []error) StreamExecuteMulti(ctx context.Context, primitive engine.Primitive, query string, rss []*srvtopo.ResolvedShard, vars []map[string]*querypb.BindVariable, session *SafeSession, autocommit bool, callback func(reply *sqltypes.Result) error, observer ResultsObserver) []error ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedShard, query *querypb.BoundQuery, session *SafeSession, lockFuncType sqlparser.LockingFuncType) (*sqltypes.Result, error) Commit(ctx context.Context, safeSession *SafeSession) error @@ -761,7 +761,7 @@ func (vc *VCursorImpl) ExecuteMultiShard(ctx context.Context, primitive engine.P return nil, []error{err} } - qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, commentedShardQueries(queries, vc.marginComments), vc.SafeSession, canAutocommit, vc.ignoreMaxMemoryRows, vc.observer) + qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, commentedShardQueries(queries, vc.marginComments), vc.SafeSession, canAutocommit, fetchLastInsertID, vc.ignoreMaxMemoryRows, vc.observer) vc.setRollbackOnPartialExecIfRequired(len(errs) != len(rss), rollbackOnError) vc.logShardsQueried(primitive, len(rss)) return qr, errs @@ -801,7 +801,7 @@ func (vc *VCursorImpl) ExecuteStandalone(ctx context.Context, primitive engine.P } // The autocommit flag is always set to false because we currently don't // execute DMLs through ExecuteStandalone. - qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, bqs, NewAutocommitSession(vc.SafeSession.Session), false /* autocommit */, vc.ignoreMaxMemoryRows, vc.observer) + qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, bqs, NewAutocommitSession(vc.SafeSession.Session), false /* autocommit */, vc.ignoreMaxMemoryRows, false, vc.observer) vc.logShardsQueried(primitive, len(rss)) return qr, vterrors.Aggregate(errs) } diff --git a/go/vt/vtgate/scatter_conn.go b/go/vt/vtgate/scatter_conn.go index 6e2cf9ad8ba..9d897c2cf58 100644 --- a/go/vt/vtgate/scatter_conn.go +++ b/go/vt/vtgate/scatter_conn.go @@ -152,6 +152,7 @@ func (stc *ScatterConn) ExecuteMultiShard( autocommit bool, ignoreMaxMemoryRows bool, resultsObserver econtext.ResultsObserver, + fetchLastInsertID bool, ) (qr *sqltypes.Result, errs []error) { if len(rss) != len(queries) { @@ -186,6 +187,9 @@ func (stc *ScatterConn) ExecuteMultiShard( if session != nil && session.Session != nil { opts = session.Session.Options } + if fetchLastInsertID { + opts.FetchLastInsertId = true + } if autocommit { // As this is auto-commit, the transactionID is supposed to be zero.