Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Data race in semi-join #17417

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,3 +519,24 @@ func TestTimeZones(t *testing.T) {
})
}
}

// TestSemiJoin tests that the semi join works as intended.
func TestSemiJoin(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 22, "vtgate")
mcmp, closer := start(t)
defer closer()

for i := 1; i <= 1000; i++ {
mcmp.Exec(fmt.Sprintf("insert into t1(id1, id2) values (%d, %d)", i, 2*i))
mcmp.Exec(fmt.Sprintf("insert into tbl(id, unq_col, nonunq_col) values (%d, %d, %d)", i, 2*i, 3*i))
}

// Test that the semi join works as intended
for _, mode := range []string{"oltp", "olap"} {
mcmp.Run(mode, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode))

mcmp.Exec("select id1, id2 from t1 where exists (select id from tbl where nonunq_col = t1.id2) order by id1")
})
}
}
7 changes: 5 additions & 2 deletions go/vt/vtgate/engine/fake_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ type fakePrimitive struct {
// sendErr is sent at the end of the stream if it's set.
sendErr error

log []string
noLog bool
log []string

allResultsInOneCall bool

Expand Down Expand Up @@ -85,7 +86,9 @@ func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVar
}

func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields))
if !f.noLog {
f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields))
}
if f.results == nil {
return f.sendErr
}
Expand Down
13 changes: 8 additions & 5 deletions go/vt/vtgate/engine/semi_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package engine

import (
"context"
"sync/atomic"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -62,24 +63,26 @@ func (jn *SemiJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma

// TryStreamExecute performs a streaming exec.
func (jn *SemiJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
joinVars := make(map[string]*querypb.BindVariable)
err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error {
joinVars := make(map[string]*querypb.BindVariable)
result := &sqltypes.Result{Fields: lresult.Fields}
for _, lrow := range lresult.Rows {
for k, col := range jn.Vars {
joinVars[k] = sqltypes.ValueBindVariable(lrow[col])
}
rowAdded := false
var rowAdded atomic.Bool
err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), false, func(rresult *sqltypes.Result) error {
Comment on lines 66 to 74
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is another issue I mentioned: during a transaction, we might end up opening two connections, leaving one of them in limbo.
Do you have another PR that fixes it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will open a separate issue for it, with separate tests and PR for it, after this one gets merged.

if len(rresult.Rows) > 0 && !rowAdded {
result.Rows = append(result.Rows, lrow)
rowAdded = true
if len(rresult.Rows) > 0 {
rowAdded.Store(true)
}
return nil
})
if err != nil {
return err
}
if rowAdded.Load() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this one atomic? If this needs an atomic since it otherwise races, I imagine that the append on the next line also races and needs to be guarded with a lock then?

If this doesn't race, it doesn't need to be atomic?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes i know, I made the changes such that i could avoid using this atomic thing, but in the callback right above we are writing to it. We are only setting it to true but golang still complains in a -race test saying there is a concurrent write. So I had to make it atomic to avoid that. That being said, for strictly correctness purposes, we didn't need to make it atomic, but I couldn't figure out how to make the test work without making it atomic.

result.Rows = append(result.Rows, lrow)
}
}
return callback(result)
})
Expand Down
57 changes: 57 additions & 0 deletions go/vt/vtgate/engine/semi_join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,60 @@ func TestSemiJoinStreamExecute(t *testing.T) {
"4|d|dd",
))
}

// TestSemiJoinStreamExecuteParallelExecution tests SemiJoin stream execution with parallel execution
// to ensure we have no data races.
func TestSemiJoinStreamExecuteParallelExecution(t *testing.T) {
leftPrim := &fakePrimitive{
results: []*sqltypes.Result{
sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"1|a|aa",
"2|b|bb",
), sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"3|c|cc",
"4|d|dd",
),
},
async: true,
}
rightFields := sqltypes.MakeTestFields(
"col4|col5|col6",
"int64|varchar|varchar",
)
rightPrim := &fakePrimitive{
// we'll return non-empty results for rows 2 and 4
results: sqltypes.MakeTestStreamingResults(rightFields,
"4|d|dd",
"---",
"---",
"5|e|ee",
"6|f|ff",
"7|g|gg",
),
async: true,
noLog: true,
}

jn := &SemiJoin{
Left: leftPrim,
Right: rightPrim,
Vars: map[string]int{
"bv": 1,
},
}
err := jn.TryStreamExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{}, true, func(result *sqltypes.Result) error {
return nil
})
require.NoError(t, err)
leftPrim.ExpectLog(t, []string{
`StreamExecute true`,
})
}
Loading