diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index c67a0951b35..6343789aef8 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "strings" + "sync/atomic" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -95,34 +96,36 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st // TryStreamExecute performs a streaming exec. func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - joinVars := make(map[string]*querypb.BindVariable) - err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error { + var fieldNeeded atomic.Bool + fieldNeeded.Store(wantfields) + err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, fieldNeeded.Load(), func(lresult *sqltypes.Result) error { + joinVars := make(map[string]*querypb.BindVariable) for _, lrow := range lresult.Rows { for k, col := range jn.Vars { joinVars[k] = sqltypes.ValueBindVariable(lrow[col]) } - rowSent := false - err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), wantfields, func(rresult *sqltypes.Result) error { + var rowSent atomic.Bool + err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), fieldNeeded.Load(), func(rresult *sqltypes.Result) error { result := &sqltypes.Result{} - if wantfields { + if fieldNeeded.Load() { // This code is currently unreachable because the first result // will always be just the field info, which will cause the outer // wantfields code path to be executed. But this may change in the future. - wantfields = false + fieldNeeded.Store(false) result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols) } for _, rrow := range rresult.Rows { result.Rows = append(result.Rows, joinRows(lrow, rrow, jn.Cols)) } if len(rresult.Rows) != 0 { - rowSent = true + rowSent.Store(true) } return callback(result) }) if err != nil { return err } - if jn.Opcode == LeftJoin && !rowSent { + if jn.Opcode == LeftJoin && !rowSent.Load() { result := &sqltypes.Result{} result.Rows = [][]sqltypes.Value{joinRows( lrow, @@ -132,8 +135,8 @@ func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars return callback(result) } } - if wantfields { - wantfields = false + if fieldNeeded.Load() { + fieldNeeded.Store(false) for k := range jn.Vars { joinVars[k] = sqltypes.NullBindVariable } diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 4aa51368be2..d934c7805da 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -3984,3 +3984,35 @@ func TestMain(m *testing.M) { _flag.ParseFlagsForTest() os.Exit(m.Run()) } + +func TestStreamJoinQuery(t *testing.T) { + // Special setup: Don't use createExecutorEnv. + cell := "aa" + hc := discovery.NewFakeHealthCheck(nil) + u := createSandbox(KsTestUnsharded) + s := createSandbox(KsTestSharded) + s.VSchema = executorVSchema + u.VSchema = unshardedVSchema + serv := newSandboxForCells([]string{cell}) + resolver := newTestResolver(hc, serv, cell) + shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"} + for _, shard := range shards { + _ = hc.AddTestTablet(cell, shard, 1, "TestExecutor", shard, topodatapb.TabletType_PRIMARY, true, 1, nil) + } + executor := createExecutor(serv, cell, resolver) + + sql := "select u.foo, u.apa, ue.bar, ue.apa from user u join user_extra ue on u.foo = ue.bar" + result, err := executorStream(executor, sql) + require.NoError(t, err) + wantResult := &sqltypes.Result{ + Fields: append(sandboxconn.SingleRowResult.Fields, sandboxconn.SingleRowResult.Fields...), + } + wantRow := append(sandboxconn.StreamRowResult.Rows[0], sandboxconn.StreamRowResult.Rows[0]...) + for i := 0; i < 64; i++ { + wantResult.Rows = append(wantResult.Rows, wantRow) + } + require.Equal(t, len(wantResult.Rows), len(result.Rows)) + for idx := 0; idx < 64; idx++ { + utils.MustMatch(t, wantResult.Rows[idx], result.Rows[idx], "mismatched on: ", strconv.Itoa(idx)) + } +}