Skip to content

Commit

Permalink
feat: make json aggregation function implement aggfunc interface and …
Browse files Browse the repository at this point in the history
…add tests

Signed-off-by: Manan Gupta <[email protected]>
  • Loading branch information
GuptaManan100 committed Jun 24, 2024
1 parent 7f4f7dd commit 7c81d62
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 0 deletions.
12 changes: 12 additions & 0 deletions go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -777,3 +777,15 @@ func TestHavingQueries(t *testing.T) {
})
}
}

// TestJsonAggregation tests that json aggregation works for single sharded queries.
func TestJsonAggregation(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 21, "vtgate")
mcmp, closer := start(t)
defer closer()

mcmp.Exec("insert into t3(id5, id6, id7) values(1,2,1), (2,2,4), (3,2,4), (4,1,2), (5,2,1), (6,2,6), (7,1,7)")

mcmp.Exec("select count(1) from t3 where id6 = 2 group by id7 having json_arrayagg(id5+1) = json_array(2, 6)")
mcmp.Exec(`select count(1) from t3 where id6 = 2 group by id7 having json_objectagg(id5+1, id7) = json_object("2",1,"6",1)`)
}
20 changes: 20 additions & 0 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package sqlparser
import (
"vitess.io/vitess/go/mysql/datetime"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/vterrors"
)

/*
Expand Down Expand Up @@ -3404,6 +3405,8 @@ func (varP *VarPop) GetArg() Expr { return varP.Arg }
func (varS *VarSamp) GetArg() Expr { return varS.Arg }
func (variance *Variance) GetArg() Expr { return variance.Arg }
func (av *AnyValue) GetArg() Expr { return av.Arg }
func (jaa *JSONArrayAgg) GetArg() Expr { return jaa.Expr }
func (joa *JSONObjectAgg) GetArg() Expr { return joa.Key }

func (sum *Sum) GetArgs() Exprs { return Exprs{sum.Arg} }
func (min *Min) GetArgs() Exprs { return Exprs{min.Arg} }
Expand All @@ -3423,6 +3426,8 @@ func (varP *VarPop) GetArgs() Exprs { return Exprs{varP.Arg} }
func (varS *VarSamp) GetArgs() Exprs { return Exprs{varS.Arg} }
func (variance *Variance) GetArgs() Exprs { return Exprs{variance.Arg} }
func (av *AnyValue) GetArgs() Exprs { return Exprs{av.Arg} }
func (jaa *JSONArrayAgg) GetArgs() Exprs { return Exprs{jaa.Expr} }
func (joa *JSONObjectAgg) GetArgs() Exprs { return Exprs{joa.Key, joa.Value} }

func (min *Min) SetArg(expr Expr) { min.Arg = expr }
func (sum *Sum) SetArg(expr Expr) { sum.Arg = expr }
Expand All @@ -3442,6 +3447,10 @@ func (varP *VarPop) SetArg(expr Expr) { varP.Arg = expr }
func (varS *VarSamp) SetArg(expr Expr) { varS.Arg = expr }
func (variance *Variance) SetArg(expr Expr) { variance.Arg = expr }
func (av *AnyValue) SetArg(expr Expr) { av.Arg = expr }
func (jaa *JSONArrayAgg) SetArg(expr Expr) { jaa.Expr = expr }
func (joa *JSONObjectAgg) SetArg(expr Expr) {
joa.Key = getColNameForExpression(expr, "JSONObjectAgg")
}

func (min *Min) SetArgs(exprs Exprs) error { return setFuncArgs(min, exprs, "MIN") }
func (sum *Sum) SetArgs(exprs Exprs) error { return setFuncArgs(sum, exprs, "SUM") }
Expand All @@ -3459,6 +3468,15 @@ func (varP *VarPop) SetArgs(exprs Exprs) error { return setFuncArgs(varP,
func (varS *VarSamp) SetArgs(exprs Exprs) error { return setFuncArgs(varS, exprs, "VAR_SAMP") }
func (variance *Variance) SetArgs(exprs Exprs) error { return setFuncArgs(variance, exprs, "VARIANCE") }
func (av *AnyValue) SetArgs(exprs Exprs) error { return setFuncArgs(av, exprs, "ANY_VALUE") }
func (jaa *JSONArrayAgg) SetArgs(exprs Exprs) error { return setFuncArgs(jaa, exprs, "JSON_ARRAYARG") }
func (joa *JSONObjectAgg) SetArgs(exprs Exprs) error {
if len(exprs) != 2 {
return vterrors.VT13001("JSONObjectAgg takes in 2 expressions")
}
joa.Key = getColNameForExpression(exprs[0], "JSONObjectAgg")
joa.Value = getColNameForExpression(exprs[1], "JSONObjectAgg")
return nil
}

func (count *Count) SetArgs(exprs Exprs) error {
count.Args = exprs
Expand Down Expand Up @@ -3501,6 +3519,8 @@ func (*VarPop) AggrName() string { return "var_pop" }
func (*VarSamp) AggrName() string { return "var_samp" }
func (*Variance) AggrName() string { return "variance" }
func (*AnyValue) AggrName() string { return "any_value" }
func (*JSONArrayAgg) AggrName() string { return "json_arrayagg" }
func (*JSONObjectAgg) AggrName() string { return "json_objectagg" }

// Exprs represents a list of value expressions.
// It's not a valid expression because it's not parenthesized.
Expand Down
4 changes: 4 additions & 0 deletions go/vt/sqlparser/ast_clone.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions go/vt/sqlparser/ast_copy_on_rewrite.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions go/vt/sqlparser/ast_equals.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2184,6 +2184,15 @@ func ContainsAggregation(e SQLNode) bool {
return hasAggregates
}

// getColNameForExpression gets the column name for the given expression.
func getColNameForExpression(expr Expr, name string) *ColName {
colName, isColName := expr.(*ColName)
if !isColName {
panic(vterrors.VT13001(fmt.Sprintf("Column name required in %v", name)))
}
return colName
}

// setFuncArgs sets the arguments for the aggregation function, while checking that there is only one argument
func setFuncArgs(aggr AggrFunc, exprs Exprs, name string) error {
if len(exprs) != 1 {
Expand Down
4 changes: 4 additions & 0 deletions go/vt/sqlparser/ast_rewrite.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions go/vt/sqlparser/ast_visit.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/select_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,57 @@
]
}
},
{
"comment": "json_arrayagg in single sharded query",
"query": "select count(1) from user where id = 'abc' group by n_id having json_arrayagg(a_id) = '[]'",
"plan": {
"QueryType": "SELECT",
"Original": "select count(1) from user where id = 'abc' group by n_id having json_arrayagg(a_id) = '[]'",
"Instructions": {
"OperatorType": "Route",
"Variant": "EqualUnique",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select count(1) from `user` where 1 != 1 group by n_id",
"Query": "select count(1) from `user` where id = 'abc' group by n_id having json_arrayagg(a_id) = '[]'",
"Table": "`user`",
"Values": [
"'abc'"
],
"Vindex": "user_index"
},
"TablesUsed": [
"user.user"
]
}
},{
"comment": "json_objectagg in single sharded query",
"query": "select count(1) from user where id = 'abc' group by n_id having json_objectagg(a_id, b_id) = '[]'",
"plan": {
"QueryType": "SELECT",
"Original": "select count(1) from user where id = 'abc' group by n_id having json_objectagg(a_id, b_id) = '[]'",
"Instructions": {
"OperatorType": "Route",
"Variant": "EqualUnique",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select count(1) from `user` where 1 != 1 group by n_id",
"Query": "select count(1) from `user` where id = 'abc' group by n_id having json_objectagg(a_id, b_id) = '[]'",
"Table": "`user`",
"Values": [
"'abc'"
],
"Vindex": "user_index"
},
"TablesUsed": [
"user.user"
]
}
},
{
"comment": "Cannot auto-resolve for cross-shard joins",
"query": "select col from user join user_extra",
Expand Down

0 comments on commit 7c81d62

Please sign in to comment.