Skip to content

Commit

Permalink
[release-14.0] Fix scalar aggregation engine primitive for column tru…
Browse files Browse the repository at this point in the history
…ncation (vitessio#12468) (vitessio#12471)

* Fix scalar aggregation engine primitive for column truncation (vitessio#12468)

* fix: scalar aggregation truncation

Signed-off-by: Harshit Gangal <[email protected]>

* test: added scalar aggr engine unit test

Signed-off-by: Harshit Gangal <[email protected]>

* remove onecase change

Signed-off-by: Harshit Gangal <[email protected]>

---------

Signed-off-by: Harshit Gangal <[email protected]>

* feat: fix tests

Signed-off-by: Manan Gupta <[email protected]>

---------

Signed-off-by: Harshit Gangal <[email protected]>
Signed-off-by: Manan Gupta <[email protected]>
Co-authored-by: Harshit Gangal <[email protected]>
  • Loading branch information
GuptaManan100 and harshit-gangal authored Feb 27, 2023
1 parent bf128d2 commit 77afd51
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 8 deletions.
34 changes: 34 additions & 0 deletions go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/test/endtoend/utils"
)
Expand Down Expand Up @@ -394,3 +395,36 @@ func TestAggregateLeftJoin(t *testing.T) {
mcmp.AssertMatches("SELECT count(*) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[INT64(2)]]`)
mcmp.AssertMatches("SELECT sum(t1.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1)]]`)
}

// TestScalarAggregate tests validates that only count is returned and no additional field is returned.gst
func TestScalarAggregate(t *testing.T) {
// disable schema tracking to have weight_string column added to query send down to mysql.
clusterInstance.VtGateExtraArgs = append(clusterInstance.VtGateExtraArgs, "--schema_change_signal=false")
require.NoError(t,
clusterInstance.RestartVtgate())

// update vtgate params
vtParams = mysql.ConnParams{
Host: clusterInstance.Hostname,
Port: clusterInstance.VtgateMySQLPort,
}

defer func() {
// roll it back
clusterInstance.VtGateExtraArgs = append(clusterInstance.VtGateExtraArgs, "--schema_change_signal")
require.NoError(t,
clusterInstance.RestartVtgate())
// update vtgate params
vtParams = mysql.ConnParams{
Host: clusterInstance.Hostname,
Port: clusterInstance.VtgateMySQLPort,
}

}()

mcmp, closer := start(t)
defer closer()

mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)")
mcmp.AssertMatches("select /*vt+ PLANNER=gen4 */ count(distinct val1) from aggr_test", `[[INT64(3)]]`)
}
2 changes: 1 addition & 1 deletion go/vt/vtgate/engine/scalar_aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (sa *ScalarAggregate) TryExecute(vcursor VCursor, bindVars map[string]*quer
}

out.Rows = [][]sqltypes.Value{resultRow}
return out, nil
return out.Truncate(sa.TruncateColumnCount), nil
}

// TryStreamExecute implements the Primitive interface
Expand Down
39 changes: 35 additions & 4 deletions go/vt/vtgate/engine/scalar_aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,16 @@ func TestEmptyRows(outer *testing.T) {
func TestScalarAggregateStreamExecute(t *testing.T) {
assert := assert.New(t)
fields := sqltypes.MakeTestFields(
"count(*)",
"uint64",
"col|weight_string(col)",
"uint64|varbinary",
)
fp := &fakePrimitive{
allResultsInOneCall: true,
results: []*sqltypes.Result{
sqltypes.MakeTestResult(fields,
"1",
"1|null",
), sqltypes.MakeTestResult(fields,
"3",
"3|null",
)},
}

Expand All @@ -140,3 +140,34 @@ func TestScalarAggregateStreamExecute(t *testing.T) {
got := fmt.Sprintf("%v", results[1].Rows)
assert.Equal("[[UINT64(4)]]", got)
}

// TestScalarAggregateExecuteTruncate checks if truncate works
func TestScalarAggregateExecuteTruncate(t *testing.T) {
assert := assert.New(t)
fields := sqltypes.MakeTestFields(
"col|weight_string(col)",
"uint64|varbinary",
)

fp := &fakePrimitive{
allResultsInOneCall: true,
results: []*sqltypes.Result{
sqltypes.MakeTestResult(fields,
"1|null", "3|null",
)},
}

oa := &ScalarAggregate{
Aggregates: []*AggregateParams{{
Opcode: AggregateSum,
Col: 0,
}},
Input: fp,
TruncateColumnCount: 1,
PreProcess: true,
}

qr, err := oa.TryExecute(&noopVCursor{}, nil, true)
assert.NoError(err)
assert.Equal("[[UINT64(4)]]", fmt.Sprintf("%v", qr.Rows))
}
6 changes: 3 additions & 3 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3594,7 +3594,7 @@ func TestSelectAggregationData(t *testing.T) {
}{
{
sql: `select count(distinct col) from user`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col", "int64"), "1", "2", "2", "3"),
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col|weight_string(col)", "int64|varbinary"), "1|NULL", "2|NULL", "2|NULL", "3|NULL"),
expSandboxQ: "select col, weight_string(col) from `user` group by col, weight_string(col) order by col asc",
expField: `[name:"count(distinct col)" type:INT64]`,
expRow: `[[INT64(3)]]`,
Expand All @@ -3608,14 +3608,14 @@ func TestSelectAggregationData(t *testing.T) {
},
{
sql: `select col, count(*) from user group by col`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col|count(*)", "int64|int64"), "1|3"),
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col|count(*)|weight_string(col)", "int64|int64|varbinary"), "1|3|NULL"),
expSandboxQ: "select col, count(*), weight_string(col) from `user` group by col, weight_string(col) order by col asc",
expField: `[name:"col" type:INT64 name:"count(*)" type:INT64]`,
expRow: `[[INT64(1) INT64(24)]]`,
},
{
sql: `select col, count(*) from user group by col limit 2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col|count(*)", "int64|int64"), "1|2", "2|1", "3|4"),
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col|count(*)|weight_string(col)", "int64|int64|varbinary"), "1|2|NULL", "2|1|NULL", "3|4|NULL"),
expSandboxQ: "select col, count(*), weight_string(col) from `user` group by col, weight_string(col) order by col asc limit :__upper_limit",
expField: `[name:"col" type:INT64 name:"count(*)" type:INT64]`,
expRow: `[[INT64(1) INT64(16)] [INT64(2) INT64(8)]]`,
Expand Down

0 comments on commit 77afd51

Please sign in to comment.