From c2e7b1bc1f69683165a48416e3f3a619af010f63 Mon Sep 17 00:00:00 2001 From: "vitess-bot[bot]" <108069721+vitess-bot[bot]@users.noreply.github.com> Date: Tue, 19 Sep 2023 16:05:05 +0530 Subject: [PATCH] [release-17.0] fix data race in join engine primitive olap streaming mode execution (#14012) (#14016) Signed-off-by: Harshit Gangal Co-authored-by: vitess-bot[bot] <108069721+vitess-bot[bot]@users.noreply.github.com> Co-authored-by: Harshit Gangal --- go/vt/vtgate/engine/join.go | 23 +++++++++++--------- go/vt/vtgate/executor_select_test.go | 32 ++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index a4c7f66b174..6723165a70b 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "strings" + "sync/atomic" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -114,34 +115,36 @@ func bindvarForType(t querypb.Type) *querypb.BindVariable { // TryStreamExecute performs a streaming exec. func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - joinVars := make(map[string]*querypb.BindVariable) - err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error { + var fieldNeeded atomic.Bool + fieldNeeded.Store(wantfields) + err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, fieldNeeded.Load(), func(lresult *sqltypes.Result) error { + joinVars := make(map[string]*querypb.BindVariable) for _, lrow := range lresult.Rows { for k, col := range jn.Vars { joinVars[k] = sqltypes.ValueBindVariable(lrow[col]) } - rowSent := false - err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), wantfields, func(rresult *sqltypes.Result) error { + var rowSent atomic.Bool + err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), fieldNeeded.Load(), func(rresult *sqltypes.Result) error { result := &sqltypes.Result{} - if wantfields { + if fieldNeeded.Load() { // This code is currently unreachable because the first result // will always be just the field info, which will cause the outer // wantfields code path to be executed. But this may change in the future. - wantfields = false + fieldNeeded.Store(false) result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols) } for _, rrow := range rresult.Rows { result.Rows = append(result.Rows, joinRows(lrow, rrow, jn.Cols)) } if len(rresult.Rows) != 0 { - rowSent = true + rowSent.Store(true) } return callback(result) }) if err != nil { return err } - if jn.Opcode == LeftJoin && !rowSent { + if jn.Opcode == LeftJoin && !rowSent.Load() { result := &sqltypes.Result{} result.Rows = [][]sqltypes.Value{joinRows( lrow, @@ -151,8 +154,8 @@ func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars return callback(result) } } - if wantfields { - wantfields = false + if fieldNeeded.Load() { + fieldNeeded.Store(false) for k := range jn.Vars { joinVars[k] = sqltypes.NullBindVariable } diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 28e1e157477..c7cd12a79e4 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -3987,3 +3987,35 @@ func TestMain(m *testing.M) { _flag.ParseFlagsForTest() os.Exit(m.Run()) } + +func TestStreamJoinQuery(t *testing.T) { + // Special setup: Don't use createExecutorEnv. + cell := "aa" + hc := discovery.NewFakeHealthCheck(nil) + u := createSandbox(KsTestUnsharded) + s := createSandbox(KsTestSharded) + s.VSchema = executorVSchema + u.VSchema = unshardedVSchema + serv := newSandboxForCells([]string{cell}) + resolver := newTestResolver(hc, serv, cell) + shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"} + for _, shard := range shards { + _ = hc.AddTestTablet(cell, shard, 1, "TestExecutor", shard, topodatapb.TabletType_PRIMARY, true, 1, nil) + } + executor := createExecutor(serv, cell, resolver) + + sql := "select u.foo, u.apa, ue.bar, ue.apa from user u join user_extra ue on u.foo = ue.bar" + result, err := executorStream(executor, sql) + require.NoError(t, err) + wantResult := &sqltypes.Result{ + Fields: append(sandboxconn.SingleRowResult.Fields, sandboxconn.SingleRowResult.Fields...), + } + wantRow := append(sandboxconn.StreamRowResult.Rows[0], sandboxconn.StreamRowResult.Rows[0]...) + for i := 0; i < 64; i++ { + wantResult.Rows = append(wantResult.Rows, wantRow) + } + require.Equal(t, len(wantResult.Rows), len(result.Rows)) + for idx := 0; idx < 64; idx++ { + utils.MustMatch(t, wantResult.Rows[idx], result.Rows[idx], "mismatched on: ", strconv.Itoa(idx)) + } +}