diff --git a/go/test/endtoend/vtgate/queries/union/union_test.go b/go/test/endtoend/vtgate/queries/union/union_test.go index d91ea3c4073..03f98950f44 100644 --- a/go/test/endtoend/vtgate/queries/union/union_test.go +++ b/go/test/endtoend/vtgate/queries/union/union_test.go @@ -20,7 +20,6 @@ import ( "testing" "vitess.io/vitess/go/test/endtoend/cluster" - "vitess.io/vitess/go/test/endtoend/utils" "github.com/stretchr/testify/assert" @@ -111,9 +110,15 @@ func TestUnionAll(t *testing.T) { mcmp.AssertMatchesNoOrder("select tbl2.id1 FROM ((select id1 from t1 order by id1 limit 5) union all (select id1 from t1 order by id1 desc limit 5)) as tbl1 INNER JOIN t1 as tbl2 ON tbl1.id1 = tbl2.id1", "[[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]") - // union all between two select unique in tables - mcmp.AssertMatchesNoOrder("select id1 from t1 where id1 in (1, 2, 3, 4, 5, 6, 7, 8) union all select id1 from t1 where id1 in (1, 2, 3, 4, 5, 6, 7, 8)", - "[[INT64(1)] [INT64(2)] [INT64(1)] [INT64(2)]]") + // this test is quite good at uncovering races in the Concatenate engine primitive. make it run many times + // see: https://github.com/vitessio/vitess/issues/15434 + if utils.BinaryIsAtLeastAtVersion(20, "vtgate") { + for i := 0; i < 100; i++ { + // union all between two select unique in tables + mcmp.AssertMatchesNoOrder("select id1 from t1 where id1 in (1, 2, 3, 4, 5, 6, 7, 8) union all select id1 from t1 where id1 in (1, 2, 3, 4, 5, 6, 7, 8)", + "[[INT64(1)] [INT64(2)] [INT64(1)] [INT64(2)]]") + } + } // 4 tables union all mcmp.AssertMatchesNoOrder("select id1, id2 from t1 where id1 = 1 union all select id3,id4 from t2 where id3 = 3 union all select id1, id2 from t1 where id1 = 2 union all select id3,id4 from t2 where id3 = 4", diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 352b190fb1d..13727124e78 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -285,36 +285,35 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, currIndex, currSource := i, source wg.Go(func() error { err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, true, func(resultChunk *sqltypes.Result) error { - // Process fields when they arrive; coordinate field agreement across sources. - if resultChunk.Fields != nil { - muFields.Lock() + muFields.Lock() + // Process fields when they arrive; coordinate field agreement across sources. + if resultChunk.Fields != nil && rest[currIndex] == nil { // Capture the initial result chunk to determine field types later. - if rest[currIndex] == nil { - rest[currIndex] = resultChunk - - // If this was the last source to report its fields, derive the final output fields. - if !slices.Contains(rest, nil) { + rest[currIndex] = resultChunk + + // If this was the last source to report its fields, derive the final output fields. + if !slices.Contains(rest, nil) { + // We have received fields from all sources. We can now calculate the output types + var err error + resultChunk.Fields, fieldTypes, err = c.getFieldTypes(vcursor, rest) + if err != nil { muFields.Unlock() - - // We have received fields from all sources. We can now calculate the output types - var err error - resultChunk.Fields, fieldTypes, err = c.getFieldTypes(vcursor, rest) - if err != nil { - return err - } - - defer condFields.Broadcast() - return callback(resultChunk, currIndex) + return err } + + muFields.Unlock() + defer condFields.Broadcast() + return callback(resultChunk, currIndex) } - // Wait for fields from all sources. - for slices.Contains(rest, nil) { - condFields.Wait() - } - muFields.Unlock() } + // Wait for fields from all sources. + for slices.Contains(rest, nil) { + condFields.Wait() + } + muFields.Unlock() + // Context check to avoid extra work. if ctx.Err() != nil { return nil