Skip to content

Commit

Permalink
feat: make sure Insert updates the result struct correctly
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Dec 20, 2024
1 parent 53e3fcc commit a32c012
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 36 deletions.
2 changes: 1 addition & 1 deletion go/mysql/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ func checkQueryInternal(t *testing.T, query string, sConn, cConn *Conn, result *
got = &sqltypes.Result{}
got.RowsAffected = result.RowsAffected
got.InsertID = result.InsertID
got.InsertIDChanged = result.InsertIDChanged
got.InsertIDChanged = result.InsertIDUpdated()
got.Fields, err = cConn.Fields()
if err != nil {
fatalError = fmt.Sprintf("Fields(%v) failed: %v", query, err)
Expand Down
2 changes: 1 addition & 1 deletion go/sqltypes/proto3.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func ResultToProto3(qr *Result) *querypb.QueryResult {
Fields: qr.Fields,
RowsAffected: qr.RowsAffected,
InsertId: qr.InsertID,
InsertIdChanged: qr.InsertIDChanged,
InsertIdChanged: qr.InsertIDUpdated(),
Rows: RowsToProto3(qr.Rows),
Info: qr.Info,
SessionStateChanges: qr.SessionStateChanges,
Expand Down
14 changes: 9 additions & 5 deletions go/sqltypes/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (result *Result) Copy() *Result {
out := &Result{
RowsAffected: result.RowsAffected,
InsertID: result.InsertID,
InsertIDChanged: result.InsertIDChanged,
InsertIDChanged: result.InsertIDUpdated(),
SessionStateChanges: result.SessionStateChanges,
StatusFlags: result.StatusFlags,
Info: result.Info,
Expand Down Expand Up @@ -132,7 +132,7 @@ func (result *Result) Metadata() *Result {
return &Result{
Fields: result.Fields,
InsertID: result.InsertID,
InsertIDChanged: result.InsertIDChanged,
InsertIDChanged: result.InsertIDUpdated(),
RowsAffected: result.RowsAffected,
Info: result.Info,
SessionStateChanges: result.SessionStateChanges,
Expand All @@ -157,7 +157,7 @@ func (result *Result) Truncate(l int) *Result {

out := &Result{
InsertID: result.InsertID,
InsertIDChanged: result.InsertIDChanged,
InsertIDChanged: result.InsertIDUpdated(),
RowsAffected: result.RowsAffected,
Info: result.Info,
SessionStateChanges: result.SessionStateChanges,
Expand Down Expand Up @@ -331,10 +331,10 @@ func (result *Result) StripMetadata(incl querypb.ExecuteOptions_IncludedFields)
// if two results have different fields.We will enhance this function.
func (result *Result) AppendResult(src *Result) {
result.RowsAffected += src.RowsAffected
if src.InsertID != 0 || src.InsertIDChanged {
if src.InsertIDUpdated() {
result.InsertID = src.InsertID
}
result.InsertIDChanged = result.InsertIDChanged || src.InsertIDChanged
result.InsertIDChanged = result.InsertIDUpdated() || src.InsertIDUpdated()
if result.Fields == nil {
result.Fields = src.Fields
}
Expand All @@ -355,3 +355,7 @@ func (result *Result) IsMoreResultsExists() bool {
func (result *Result) IsInTransaction() bool {
return result.StatusFlags&ServerStatusInTrans == ServerStatusInTrans
}

func (result *Result) InsertIDUpdated() bool {
return result.InsertIDChanged || result.InsertID > 0
}
1 change: 1 addition & 0 deletions go/vt/vtgate/engine/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ func (ins *Insert) executeInsertQueries(

if insertID != 0 {
result.InsertID = insertID
result.InsertIDChanged = true
}
return result, nil
}
Expand Down
1 change: 1 addition & 0 deletions go/vt/vtgate/engine/insert_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ func (ins *InsertCommon) executeUnshardedTableQuery(ctx context.Context, vcursor
// values, we don't return an error because this behavior
// is required to support migration.
if insertID != 0 {
qr.InsertIDChanged = true
qr.InsertID = insertID
}
return qr, nil
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,10 @@ func (s *streaminResultReceiver) storeResultStats(typ sqlparser.StatementType, q
defer s.mu.Unlock()
s.rowsAffected += qr.RowsAffected
s.rowsReturned += len(qr.Rows)
if qr.InsertID != 0 || qr.InsertIDChanged {
if qr.InsertIDUpdated() {
s.insertID = qr.InsertID
}
s.insertIDChanged = s.insertIDChanged || qr.InsertIDChanged
s.insertIDChanged = s.insertIDChanged || qr.InsertIDUpdated()
s.stmtType = typ
return s.callback(qr)
}
Expand Down
55 changes: 33 additions & 22 deletions go/vt/vtgate/executor_dml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1812,8 +1812,9 @@ func TestInsertGeneratorSharded(t *testing.T) {
Rows: [][]sqltypes.Value{{
sqltypes.NewInt64(1),
}},
RowsAffected: 1,
InsertID: 1,
RowsAffected: 1,
InsertIDChanged: true,
InsertID: 1,
}})
session := &vtgatepb.Session{
TargetString: "@primary",
Expand All @@ -1840,8 +1841,9 @@ func TestInsertGeneratorSharded(t *testing.T) {
}}
assertQueries(t, sbclookup, wantQueries)
wantResult := &sqltypes.Result{
InsertID: 1,
RowsAffected: 1,
InsertID: 1,
RowsAffected: 1,
InsertIDChanged: true,
}
utils.MustMatch(t, wantResult, result)
}
Expand All @@ -1854,8 +1856,9 @@ func TestInsertAutoincSharded(t *testing.T) {
Rows: [][]sqltypes.Value{{
sqltypes.NewInt64(1),
}},
RowsAffected: 1,
InsertID: 2,
RowsAffected: 1,
InsertID: 2,
InsertIDChanged: true,
}
sbc.SetResults([]*sqltypes.Result{wantResult})
session := &vtgatepb.Session{
Expand Down Expand Up @@ -1894,8 +1897,9 @@ func TestInsertGeneratorUnsharded(t *testing.T) {
}}
assertQueries(t, sbclookup, wantQueries)
wantResult := &sqltypes.Result{
InsertID: 1,
RowsAffected: 1,
InsertID: 1,
InsertIDChanged: true,
RowsAffected: 1,
}
utils.MustMatch(t, wantResult, result)
}
Expand All @@ -1912,8 +1916,9 @@ func TestInsertAutoincUnsharded(t *testing.T) {
Rows: [][]sqltypes.Value{{
sqltypes.NewInt64(1),
}},
RowsAffected: 1,
InsertID: 2,
RowsAffected: 1,
InsertID: 2,
InsertIDChanged: true,
}
sbclookup.SetResults([]*sqltypes.Result{wantResult})

Expand Down Expand Up @@ -1965,8 +1970,9 @@ func TestInsertLookupOwnedGenerator(t *testing.T) {
Rows: [][]sqltypes.Value{{
sqltypes.NewInt64(4),
}},
RowsAffected: 1,
InsertID: 1,
RowsAffected: 1,
InsertID: 1,
InsertIDChanged: true,
}})
session := &vtgatepb.Session{
TargetString: "@primary",
Expand All @@ -1993,8 +1999,9 @@ func TestInsertLookupOwnedGenerator(t *testing.T) {
}}
assertQueries(t, sbclookup, wantQueries)
wantResult := &sqltypes.Result{
InsertID: 4,
RowsAffected: 1,
InsertID: 4,
InsertIDChanged: true,
RowsAffected: 1,
}
utils.MustMatch(t, wantResult, result)
}
Expand Down Expand Up @@ -2226,8 +2233,9 @@ func TestMultiInsertGenerator(t *testing.T) {
Rows: [][]sqltypes.Value{{
sqltypes.NewInt64(1),
}},
RowsAffected: 1,
InsertID: 1,
RowsAffected: 1,
InsertIDChanged: true,
InsertID: 1,
}})
session := &vtgatepb.Session{
TargetString: "@primary",
Expand Down Expand Up @@ -2258,8 +2266,9 @@ func TestMultiInsertGenerator(t *testing.T) {
}}
assertQueries(t, sbclookup, wantQueries)
wantResult := &sqltypes.Result{
InsertID: 1,
RowsAffected: 1,
InsertIDChanged: true,
InsertID: 1,
RowsAffected: 1,
}
utils.MustMatch(t, wantResult, result)
}
Expand All @@ -2271,8 +2280,9 @@ func TestMultiInsertGeneratorSparse(t *testing.T) {
Rows: [][]sqltypes.Value{{
sqltypes.NewInt64(1),
}},
RowsAffected: 1,
InsertID: 1,
RowsAffected: 1,
InsertIDChanged: true,
InsertID: 1,
}})
session := &vtgatepb.Session{
TargetString: "@primary",
Expand Down Expand Up @@ -2307,8 +2317,9 @@ func TestMultiInsertGeneratorSparse(t *testing.T) {
}}
assertQueries(t, sbclookup, wantQueries)
wantResult := &sqltypes.Result{
InsertID: 1,
RowsAffected: 1,
InsertIDChanged: true,
InsertID: 1,
RowsAffected: 1,
}
utils.MustMatch(t, wantResult, result)
}
Expand Down
8 changes: 4 additions & 4 deletions go/vt/vtgate/executorcontext/vcursor_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -676,15 +676,15 @@ func (vc *VCursorImpl) ExecutePrimitiveStandalone(ctx context.Context, primitive
func (vc *VCursorImpl) wrapCallback(callback func(*sqltypes.Result) error, primitive engine.Primitive) func(*sqltypes.Result) error {
if vc.interOpStats == nil {
return func(r *sqltypes.Result) error {
if r.InsertIDChanged {
if r.InsertIDUpdated() {
vc.SafeSession.LastInsertId = r.InsertID
}
return callback(r)
}
}

return func(r *sqltypes.Result) error {
if r.InsertIDChanged {
if r.InsertIDUpdated() {
vc.SafeSession.LastInsertId = r.InsertID
}
vc.logOpTraffic(primitive, r)
Expand Down Expand Up @@ -772,7 +772,7 @@ func (vc *VCursorImpl) ExecuteMultiShard(ctx context.Context, primitive engine.P
qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, commentedShardQueries(queries, vc.marginComments), vc.SafeSession, canAutocommit, vc.ignoreMaxMemoryRows, vc.observer, fetchLastInsertID)
vc.setRollbackOnPartialExecIfRequired(len(errs) != len(rss), rollbackOnError)
vc.logShardsQueried(primitive, len(rss))
if qr.InsertIDChanged {
if qr.InsertIDUpdated() {
vc.SafeSession.LastInsertId = qr.InsertID
}
return qr, errs
Expand Down Expand Up @@ -814,7 +814,7 @@ func (vc *VCursorImpl) ExecuteStandalone(ctx context.Context, primitive engine.P
// execute DMLs through ExecuteStandalone.
qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, bqs, NewAutocommitSession(vc.SafeSession.Session), false /* autocommit */, vc.ignoreMaxMemoryRows, vc.observer, fetchLastInsertID)
vc.logShardsQueried(primitive, len(rss))
if qr.InsertIDChanged {
if qr.InsertIDUpdated() {
vc.SafeSession.LastInsertId = qr.InsertID
}
return qr, vterrors.Aggregate(errs)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vttablet/tabletserver/query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1251,7 +1251,7 @@ func (qre *QueryExecutor) execStreamSQL(conn *connpool.PooledConn, isTransaction
if err = qre.fetchLastInsertID(ctx, conn.Conn, res); err != nil {
return err
}
if res.InsertIDChanged {
if res.InsertIDUpdated() {
return callback(res)
}
return nil
Expand Down

0 comments on commit a32c012

Please sign in to comment.