Skip to content

Commit

Permalink
Fix panic in aggregation (#15728)
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <[email protected]>
  • Loading branch information
GuptaManan100 authored Apr 17, 2024
1 parent 1536314 commit f11de06
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 6 deletions.
8 changes: 4 additions & 4 deletions go/test/endtoend/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,12 @@ func AssertMatchesWithTimeout(t *testing.T, conn *mysql.Conn, query, expected st
}

// WaitForAuthoritative waits for a table to become authoritative
func WaitForAuthoritative(t *testing.T, ks, tbl string, readVSchema func() (*interface{}, error)) error {
func WaitForAuthoritative(t TestingT, ks, tbl string, readVSchema func() (*interface{}, error)) error {
timeout := time.After(60 * time.Second)
for {
select {
case <-timeout:
return fmt.Errorf("schema tracking didn't mark table t2 as authoritative until timeout")
return fmt.Errorf("schema tracking didn't mark table %v.%v as authoritative until timeout", ks, tbl)
default:
res, err := readVSchema()
require.NoError(t, err, res)
Expand Down Expand Up @@ -320,7 +320,7 @@ func WaitForTableDeletions(t *testing.T, vtgateProcess cluster.VtgateProcess, ks
}

// WaitForColumn waits for a table's column to be present
func WaitForColumn(t testing.TB, vtgateProcess cluster.VtgateProcess, ks, tbl, col string) error {
func WaitForColumn(t TestingT, vtgateProcess cluster.VtgateProcess, ks, tbl, col string) error {
timeout := time.After(60 * time.Second)
for {
select {
Expand Down Expand Up @@ -355,7 +355,7 @@ func WaitForColumn(t testing.TB, vtgateProcess cluster.VtgateProcess, ks, tbl, c
if !isMap {
break
}
if colName, exists := colDef["name"]; exists && colName == col {
if colName, exists := colDef["name"]; exists && strings.EqualFold(colName.(string), col) {
return nil
}
}
Expand Down
30 changes: 30 additions & 0 deletions go/test/endtoend/vtgate/queries/tpch/tpch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,36 @@ order by
l_returnflag,
l_linestatus;`,
},
{
name: "Q11",
query: `select
ps_partkey,
sum(ps_supplycost * ps_availqty) as value
from
partsupp,
supplier,
nation
where
ps_suppkey = s_suppkey
and s_nationkey = n_nationkey
and n_name = 'MOZAMBIQUE'
group by
ps_partkey having
sum(ps_supplycost * ps_availqty) > (
select
sum(ps_supplycost * ps_availqty) * 0.0001000000
from
partsupp,
supplier,
nation
where
ps_suppkey = s_suppkey
and s_nationkey = n_nationkey
and n_name = 'MOZAMBIQUE'
)
order by
value desc;`,
},
}

for _, testcase := range testcases {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/scalar_aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (sa *ScalarAggregate) NeedsTransaction() bool {

// TryExecute implements the Primitive interface
func (sa *ScalarAggregate) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
result, err := vcursor.ExecutePrimitive(ctx, sa.Input, bindVars, wantfields)
result, err := vcursor.ExecutePrimitive(ctx, sa.Input, bindVars, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -114,7 +114,7 @@ func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor
var fields []*querypb.Field
fieldsSent := !wantfields

err := vcursor.StreamExecutePrimitive(ctx, sa.Input, bindVars, wantfields, func(result *sqltypes.Result) error {
err := vcursor.StreamExecutePrimitive(ctx, sa.Input, bindVars, true, func(result *sqltypes.Result) error {
// as the underlying primitive call is not sync
// and here scalar aggregate is using shared variables we have to sync the callback
// for correct aggregation.
Expand Down

0 comments on commit f11de06

Please sign in to comment.