Skip to content

Commit

Permalink
Improve typing during query planning (#16310)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay authored Jul 4, 2024
1 parent cb2d0df commit d303601
Show file tree
Hide file tree
Showing 13 changed files with 476 additions and 87 deletions.
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) Operator {
a.Grouping[idx].ColOffset = offset
gb.ColOffset = offset
}
if gb.WSOffset != -1 || !ctx.SemTable.NeedsWeightString(gb.Inner) {
if gb.WSOffset != -1 || !ctx.NeedsWeightString(gb.Inner) {
continue
}

Expand Down Expand Up @@ -516,7 +516,7 @@ func (a *Aggregator) pushRemainingGroupingColumnsAndWeightStrings(ctx *planconte
a.Grouping[idx].ColOffset = offset
}

if gb.WSOffset != -1 || !ctx.SemTable.NeedsWeightString(gb.Inner) {
if gb.WSOffset != -1 || !ctx.NeedsWeightString(gb.Inner) {
continue
}

Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (d *Distinct) planOffsets(ctx *plancontext.PlanningContext) Operator {
for idx, col := range columns {
e := col.Expr
var wsCol *int
if ctx.SemTable.NeedsWeightString(e) {
if ctx.NeedsWeightString(e) {
offset := d.Source.AddWSColumn(ctx, idx, false)
wsCol = &offset
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/ordering.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (o *Ordering) planOffsets(ctx *plancontext.PlanningContext) Operator {
offset := o.Source.AddColumn(ctx, true, false, aeWrap(order.SimplifiedExpr))
o.Offset = append(o.Offset, offset)

if !ctx.SemTable.NeedsWeightString(order.SimplifiedExpr) {
if !ctx.NeedsWeightString(order.SimplifiedExpr) {
o.WOffset = append(o.WOffset, -1)
continue
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/queryprojection.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ type (
)

func (aggr Aggr) NeedsWeightString(ctx *plancontext.PlanningContext) bool {
return aggr.OpCode.NeedsComparableValues() && ctx.SemTable.NeedsWeightString(aggr.Func.GetArg())
return aggr.OpCode.NeedsComparableValues() && ctx.NeedsWeightString(aggr.Func.GetArg())
}

func (aggr Aggr) GetTypeCollation(ctx *plancontext.PlanningContext) evalengine.Type {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ func (r *Route) planOffsets(ctx *plancontext.PlanningContext) Operator {
WOffset: -1,
Direction: order.Inner.Direction,
}
if ctx.SemTable.NeedsWeightString(order.SimplifiedExpr) {
if ctx.NeedsWeightString(order.SimplifiedExpr) {
ws := weightStringFor(order.SimplifiedExpr)
offset := r.AddColumn(ctx, true, false, aeWrap(ws))
o.WOffset = offset
Expand Down
109 changes: 108 additions & 1 deletion go/vt/vtgate/planbuilder/plancontext/planning_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/engine/opcode"
"vitess.io/vitess/go/vt/vtgate/evalengine"
"vitess.io/vitess/go/vt/vtgate/semantics"
)
Expand Down Expand Up @@ -214,7 +215,12 @@ func (ctx *PlanningContext) RewriteDerivedTableExpression(expr sqlparser.Expr, t
func (ctx *PlanningContext) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) {
t, found := ctx.SemTable.TypeForExpr(e)
if !found {
return t, found
typ := ctx.calculateTypeFor(e)
if typ.Valid() {
ctx.SemTable.ExprTypes[e] = typ
return typ, true
}
return evalengine.NewUnknownType(), false
}
deps := ctx.SemTable.RecursiveDeps(e)
// If the expression is from an outer table, it should be nullable
Expand All @@ -226,6 +232,89 @@ func (ctx *PlanningContext) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool
return t, true
}

func (ctx *PlanningContext) calculateTypeFor(e sqlparser.Expr) evalengine.Type {
cfg := &evalengine.Config{
ResolveType: func(expr sqlparser.Expr) (evalengine.Type, bool) {
col, isCol := expr.(*sqlparser.ColName)
if !isCol {
return evalengine.NewUnknownType(), false
}
return ctx.SemTable.TypeForExpr(col)
},
Collation: ctx.SemTable.Collation,
Environment: ctx.VSchema.Environment(),
ResolveColumn: func(name *sqlparser.ColName) (int, error) {
// We don't need to resolve the column for type calculation
return 0, nil
},
}
env := evalengine.EmptyExpressionEnv(ctx.VSchema.Environment())

// We need to rewrite the aggregate functions to their corresponding types
// The evaluation engine compiler doesn't handle them, so we replace them with Arguments before
// asking the compiler for the type

// TODO: put this back in when we can calculate the aggregation types correctly
// expr, unknown := ctx.replaceAggrWithArg(e, cfg, env)
// if unknown {
// return evalengine.NewUnknownType()
// }

translatedExpr, err := evalengine.Translate(e, cfg)
if err != nil {
return evalengine.NewUnknownType()
}

typ, err := env.TypeOf(translatedExpr)
if err != nil {
return evalengine.NewUnknownType()
}
return typ
}

// replaceAggrWithArg replaces aggregate functions with Arguments in the given expression.
// this is to prepare for sending the expression to the evalengine compiler to figure out the type
func (ctx *PlanningContext) replaceAggrWithArg(e sqlparser.Expr, cfg *evalengine.Config, env *evalengine.ExpressionEnv) (expr sqlparser.Expr, unknown bool) {
expr = sqlparser.CopyOnRewrite(e, nil, func(cursor *sqlparser.CopyOnWriteCursor) {
agg, ok := cursor.Node().(sqlparser.AggrFunc)
if !ok {
return
}
code, ok := opcode.SupportedAggregates[agg.AggrName()]
if !ok {
// We don't know the type of this aggregate function
// The type calculation will be set to unknown
unknown = true
cursor.StopTreeWalk()
return
}
var inputType evalengine.Type
if arg := agg.GetArg(); arg != nil {
translatedExpr, err := evalengine.Translate(arg, cfg)
if err != nil {
unknown = true
cursor.StopTreeWalk()
return
}

inputType, err = env.TypeOf(translatedExpr)
if err != nil {
unknown = true
cursor.StopTreeWalk()
return
}
}
typ := code.ResolveType(inputType, ctx.VSchema.Environment().CollationEnv())
cursor.Replace(&sqlparser.Argument{
Name: "arg",
Type: typ.Type(),
Size: typ.Size(),
Scale: typ.Scale(),
})
}, nil).(sqlparser.Expr)
return expr, unknown
}

// SQLTypeForExpr returns the sql type of the given expression, with nullable set if the expression is from an outer table.
func (ctx *PlanningContext) SQLTypeForExpr(e sqlparser.Expr) sqltypes.Type {
t, found := ctx.TypeForExpr(e)
Expand All @@ -235,6 +324,24 @@ func (ctx *PlanningContext) SQLTypeForExpr(e sqlparser.Expr) sqltypes.Type {
return t.Type()
}

func (ctx *PlanningContext) NeedsWeightString(e sqlparser.Expr) bool {
switch e := e.(type) {
case *sqlparser.WeightStringFuncExpr, *sqlparser.Literal:
return false
default:
typ, found := ctx.TypeForExpr(e)
if !found {
return true
}

if !sqltypes.IsText(typ.Type()) {
return false
}

return !ctx.VSchema.Environment().CollationEnv().IsSupported(typ.Collation())
}
}

func (ctx *PlanningContext) IsAggr(e sqlparser.SQLNode) bool {
switch node := e.(type) {
case sqlparser.AggrFunc:
Expand Down
Loading

0 comments on commit d303601

Please sign in to comment.