diff --git a/go/sqltypes/proto3.go b/go/sqltypes/proto3.go index 0ca03b153cf..5c4934f47a6 100644 --- a/go/sqltypes/proto3.go +++ b/go/sqltypes/proto3.go @@ -18,6 +18,7 @@ package sqltypes import ( "google.golang.org/protobuf/proto" + "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/vterrors" @@ -131,9 +132,11 @@ func Proto3ToResult(qr *querypb.QueryResult) *Result { // takes a separate fields input because not all QueryResults contain the field info. // In particular, only the first packet of streaming queries contain the field info. func CustomProto3ToResult(fields []*querypb.Field, qr *querypb.QueryResult) *Result { + log.Info("Building Proto to Result") if qr == nil { return nil } + log.Info("Result: %v, %v", qr.InsertId, qr.InsertIdChanged) return &Result{ Fields: qr.Fields, RowsAffected: qr.RowsAffected, diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index 05b96c33604..eaa2e21892f 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -135,7 +135,7 @@ func TestCast(t *testing.T) { func TestSetAndGetLastInsertID(t *testing.T) { notZero := 1 checkQuery := func(i string, workload string, tx bool, mcmp utils.MySQLCompare) { - for _, val := range []int{0, notZero} { + for _, val := range []int{notZero, 0, notZero + 99} { query := fmt.Sprintf(i, val) name := fmt.Sprintf("%s - %s", workload, query) if tx { @@ -154,18 +154,18 @@ func TestSetAndGetLastInsertID(t *testing.T) { } queries := []string{ - "select last_insert_id(%d)", + // "select last_insert_id(%d)", "select last_insert_id(%d), id1, id2 from t1 limit 1", - "select last_insert_id(%d), id1, id2 from t1 where 1 = 2", - "select 12 from t1 where last_insert_id(%d)", - "update t1 set id2 = last_insert_id(%d) where id1 = 1", - "update t1 set id2 = last_insert_id(%d) where id1 = 2", - "update t1 set id2 = 88 where id1 = last_insert_id(%d)", - "delete from t1 where id1 = last_insert_id(%d)", + // "select last_insert_id(%d), id1, id2 from t1 where 1 = 2", + // "select 12 from t1 where last_insert_id(%d)", + // "update t1 set id2 = last_insert_id(%d) where id1 = 1", + // "update t1 set id2 = last_insert_id(%d) where id1 = 2", + // "update t1 set id2 = 88 where id1 = last_insert_id(%d)", + // "delete from t1 where id1 = last_insert_id(%d)", } - for _, workload := range []string{"olap", "oltp"} { - for _, tx := range []bool{true, false} { + for _, workload := range []string{"olap"} { + for _, tx := range []bool{false} { mcmp, closer := start(t) _, err := mcmp.VtConn.ExecuteFetch(fmt.Sprintf("set workload = %s", workload), 1000, false) require.NoError(t, err) diff --git a/go/vt/vtgate/scatter_conn.go b/go/vt/vtgate/scatter_conn.go index b2f3ad8790d..1be26bf9dc9 100644 --- a/go/vt/vtgate/scatter_conn.go +++ b/go/vt/vtgate/scatter_conn.go @@ -387,6 +387,7 @@ func (stc *ScatterConn) StreamExecuteMulti( if reply != nil { resultsObserver.Observe(reply) } + log.Infof("result received: %v, %v", reply.InsertID, reply.InsertIDChanged) return callback(reply) } allErrors := stc.multiGoTransaction( diff --git a/go/vt/vttablet/grpcqueryservice/server.go b/go/vt/vttablet/grpcqueryservice/server.go index e3c179ce856..3fb839ab452 100644 --- a/go/vt/vttablet/grpcqueryservice/server.go +++ b/go/vt/vttablet/grpcqueryservice/server.go @@ -20,6 +20,7 @@ import ( "context" "google.golang.org/grpc" + "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/callerid" @@ -65,6 +66,7 @@ func (q *query) StreamExecute(request *querypb.StreamExecuteRequest, stream quer request.ImmediateCallerId, ) err = q.server.StreamExecute(ctx, request.Target, request.Query.Sql, request.Query.BindVariables, request.TransactionId, request.ReservedId, request.Options, func(reply *sqltypes.Result) error { + log.Infof("StreamExecute: (%v, %v)", reply.InsertID, reply.InsertIDChanged) return stream.Send(&querypb.StreamExecuteResponse{ Result: sqltypes.ResultToProto3(reply), }) diff --git a/go/vt/vttablet/grpctabletconn/conn.go b/go/vt/vttablet/grpctabletconn/conn.go index d2d5604d808..6bbbefeda27 100644 --- a/go/vt/vttablet/grpctabletconn/conn.go +++ b/go/vt/vttablet/grpctabletconn/conn.go @@ -23,6 +23,7 @@ import ( "github.com/spf13/pflag" "google.golang.org/grpc" + "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/netutil" "vitess.io/vitess/go/sqltypes" @@ -187,6 +188,7 @@ func (conn *gRPCQueryClient) StreamExecute(ctx context.Context, target *querypb. if fields == nil { fields = ser.Result.Fields } + log.Infof("StreamExecute result: %v, %v, %v", query, ser.Result.InsertId, ser.Result.InsertIdChanged) if err := callback(sqltypes.CustomProto3ToResult(fields, ser.Result)); err != nil { if err == io.EOF { return nil diff --git a/go/vt/vttablet/tabletserver/query_executor.go b/go/vt/vttablet/tabletserver/query_executor.go index 2253d22ec6e..3d035982d68 100644 --- a/go/vt/vttablet/tabletserver/query_executor.go +++ b/go/vt/vttablet/tabletserver/query_executor.go @@ -1202,16 +1202,8 @@ func (qre *QueryExecutor) fetchLastInsertID(ctx context.Context, conn *connpool. func (qre *QueryExecutor) execStreamSQL(conn *connpool.PooledConn, isTransaction bool, sql string, callback func(*sqltypes.Result) error) error { span, ctx := trace.NewSpan(qre.ctx, "QueryExecutor.execStreamSQL") + defer span.Finish() trace.AnnotateSQL(span, sqlparser.Preview(sql)) - callBackClosingSpan := func(result *sqltypes.Result) error { - defer span.Finish() - - // if err := qre.fetchLastInsertID(ctx, conn.Conn, result); err != nil { - // return err - // } - - return callback(result) - } start := time.Now() defer qre.logStats.AddRewrittenSQL(sql, start) @@ -1222,9 +1214,19 @@ func (qre *QueryExecutor) execStreamSQL(conn *connpool.PooledConn, isTransaction // This change will ensure that long-running streaming stateful queries get gracefully shutdown during ServingTypeChange // once their grace period is over. qd := NewQueryDetail(qre.logStats.Ctx, conn.Conn) - // if err := qre.resetLastInsertIDIfNeeded(ctx, conn.Conn); err != nil { - // return err - // } + log.Infof("Fetch Last Insert ID: %v, sql: %v", qre.options.GetFetchLastInsertId(), sql) + + if err := qre.resetLastInsertIDIfNeeded(ctx, conn.Conn); err != nil { + return err + } + + lastInsertIDSet := false + cb := func(result *sqltypes.Result) error { + if result != nil && result.InsertID != 0 { + lastInsertIDSet = true + } + return callback(result) + } if isTransaction { err := qre.tsv.statefulql.Add(qd) @@ -1232,7 +1234,7 @@ func (qre *QueryExecutor) execStreamSQL(conn *connpool.PooledConn, isTransaction return err } defer qre.tsv.statefulql.Remove(qd) - err = conn.Conn.StreamOnce(ctx, sql, callBackClosingSpan, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options)) + err = conn.Conn.StreamOnce(ctx, sql, cb, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options)) if err != nil { return err } @@ -1243,7 +1245,22 @@ func (qre *QueryExecutor) execStreamSQL(conn *connpool.PooledConn, isTransaction return err } defer qre.tsv.olapql.Remove(qd) - return conn.Conn.Stream(ctx, sql, callBackClosingSpan, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options)) + err = conn.Conn.Stream(ctx, sql, cb, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options)) + log.Infof("streaming complete") + if err != nil || lastInsertIDSet || !qre.options.GetFetchLastInsertId() { + return err + } + log.Infof("checking for last insert id change") + res := &sqltypes.Result{} + if err = qre.fetchLastInsertID(ctx, conn.Conn, res); err != nil { + return err + } + if res.InsertIDChanged { + log.Infof("done callback with values: (%v, %v)", res.InsertID, res.InsertIDChanged) + return callback(res) + } + log.Info("no changes for last insert ID") + return nil } func (qre *QueryExecutor) recordUserQuery(queryType string, duration int64) {