diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index 857339605f8..e43171b6701 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -462,6 +462,20 @@ func TestColumnAliases(t *testing.T) { mcmp.ExecWithColumnCompare(`select a as k from (select count(*) as a from t1) t`) } +func TestHandleNullableColumn(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 21, "vtgate") + require.NoError(t, + utils.WaitForAuthoritative(t, keyspaceName, "tbl", clusterInstance.VtgateProcess.ReadVSchema)) + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into t1(id1, id2) values (0,0), (1,1), (2,2)") + mcmp.Exec("insert into tbl(id, unq_col, nonunq_col) values (0,0,0), (1,1,6)") + // This query tests that we handle nullable columns correctly + // tbl.nonunq_col is not nullable according to the schema, but because of the left join, it can be NULL + mcmp.ExecWithColumnCompare(`select * from t1 left join tbl on t1.id2 = tbl.id where t1.id1 = 6 or tbl.nonunq_col = 6`) +} + func TestEnumSetVals(t *testing.T) { utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") diff --git a/go/test/endtoend/vtgate/queries/misc/schema.sql b/go/test/endtoend/vtgate/queries/misc/schema.sql index e0c0d1a36a7..6b860ae77ae 100644 --- a/go/test/endtoend/vtgate/queries/misc/schema.sql +++ b/go/test/endtoend/vtgate/queries/misc/schema.sql @@ -24,7 +24,7 @@ create table tbl ( id bigint, unq_col bigint, - nonunq_col bigint, + nonunq_col bigint not null, primary key (id), unique (unq_col) ) Engine = InnoDB; diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index d9de15aa571..bcb2281f1a6 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -156,6 +156,10 @@ func (t *Type) Nullable() bool { return true // nullable by default for unknown types } +func (t *Type) SetNullability(n bool) { + t.nullable = n +} + func (t *Type) Values() *EnumSetValues { return t.values } diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 76c4ddd476c..bec5cd28bb5 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -317,7 +317,7 @@ func transformAggregator(ctx *plancontext.PlanningContext, op *operators.Aggrega } for _, groupBy := range op.Grouping { - typ, _ := ctx.SemTable.TypeForExpr(groupBy.Inner) + typ, _ := ctx.TypeForExpr(groupBy.Inner) groupByKeys = append(groupByKeys, &engine.GroupByParams{ KeyCol: groupBy.ColOffset, WeightStringCol: groupBy.WSOffset, @@ -372,7 +372,7 @@ func createMemorySort(ctx *plancontext.PlanningContext, src engine.Primitive, or } for idx, order := range ordering.Order { - typ, _ := ctx.SemTable.TypeForExpr(order.SimplifiedExpr) + typ, _ := ctx.TypeForExpr(order.SimplifiedExpr) prim.OrderBy = append(prim.OrderBy, evalengine.OrderByParams{ Col: ordering.Offset[idx], WeightStringCol: ordering.WOffset[idx], @@ -438,7 +438,7 @@ func getEvalEngineExpr(ctx *plancontext.PlanningContext, pe *operators.ProjExpr) case *operators.EvalEngine: return e.EExpr, nil case operators.Offset: - typ, _ := ctx.SemTable.TypeForExpr(pe.EvalExpr) + typ, _ := ctx.TypeForExpr(pe.EvalExpr) return evalengine.NewColumn(int(e), typ, pe.EvalExpr), nil default: return nil, vterrors.VT13001("project not planned for: %s", pe.String()) @@ -590,7 +590,7 @@ func buildRoutePrimitive(ctx *plancontext.PlanningContext, op *operators.Route, } for _, order := range op.Ordering { - typ, _ := ctx.SemTable.TypeForExpr(order.AST) + typ, _ := ctx.TypeForExpr(order.AST) eroute.OrderBy = append(eroute.OrderBy, evalengine.OrderByParams{ Col: order.Offset, WeightStringCol: order.WOffset, @@ -907,11 +907,11 @@ func transformHashJoin(ctx *plancontext.PlanningContext, op *operators.HashJoin) var missingTypes []string - ltyp, found := ctx.SemTable.TypeForExpr(op.JoinComparisons[0].LHS) + ltyp, found := ctx.TypeForExpr(op.JoinComparisons[0].LHS) if !found { missingTypes = append(missingTypes, sqlparser.String(op.JoinComparisons[0].LHS)) } - rtyp, found := ctx.SemTable.TypeForExpr(op.JoinComparisons[0].RHS) + rtyp, found := ctx.TypeForExpr(op.JoinComparisons[0].RHS) if !found { missingTypes = append(missingTypes, sqlparser.String(op.JoinComparisons[0].RHS)) } @@ -949,7 +949,7 @@ func transformVindexPlan(ctx *plancontext.PlanningContext, op *operators.Vindex) expr, err := evalengine.Translate(op.Value, &evalengine.Config{ Collation: ctx.SemTable.Collation, - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Environment: ctx.VSchema.Environment(), }) if err != nil { diff --git a/go/vt/vtgate/planbuilder/operators/distinct.go b/go/vt/vtgate/planbuilder/operators/distinct.go index 9c893a878cd..f24d5b4978b 100644 --- a/go/vt/vtgate/planbuilder/operators/distinct.go +++ b/go/vt/vtgate/planbuilder/operators/distinct.go @@ -54,7 +54,7 @@ func (d *Distinct) planOffsets(ctx *plancontext.PlanningContext) Operator { offset := d.Source.AddWSColumn(ctx, idx, false) wsCol = &offset } - typ, _ := ctx.SemTable.TypeForExpr(e) + typ, _ := ctx.TypeForExpr(e) d.Columns = append(d.Columns, engine.CheckCol{ Col: idx, WsCol: wsCol, diff --git a/go/vt/vtgate/planbuilder/operators/filter.go b/go/vt/vtgate/planbuilder/operators/filter.go index babc309db72..19d864c0ada 100644 --- a/go/vt/vtgate/planbuilder/operators/filter.go +++ b/go/vt/vtgate/planbuilder/operators/filter.go @@ -127,7 +127,7 @@ func (f *Filter) Compact(*plancontext.PlanningContext) (Operator, *ApplyResult) func (f *Filter) planOffsets(ctx *plancontext.PlanningContext) Operator { cfg := &evalengine.Config{ - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Collation: ctx.SemTable.Collation, Environment: ctx.VSchema.Environment(), } diff --git a/go/vt/vtgate/planbuilder/operators/hash_join.go b/go/vt/vtgate/planbuilder/operators/hash_join.go index d2ba6522691..1928f4dda9e 100644 --- a/go/vt/vtgate/planbuilder/operators/hash_join.go +++ b/go/vt/vtgate/planbuilder/operators/hash_join.go @@ -358,7 +358,7 @@ func (hj *HashJoin) addColumn(ctx *plancontext.PlanningContext, in sqlparser.Exp rewrittenExpr := sqlparser.CopyOnRewrite(in, pre, r.post, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr) cfg := &evalengine.Config{ - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Collation: ctx.SemTable.Collation, Environment: ctx.VSchema.Environment(), } @@ -458,7 +458,7 @@ func (hj *HashJoin) addSingleSidedColumn( rewrittenExpr := sqlparser.CopyOnRewrite(in, pre, r.post, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr) cfg := &evalengine.Config{ - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Collation: ctx.SemTable.Collation, Environment: ctx.VSchema.Environment(), } diff --git a/go/vt/vtgate/planbuilder/operators/insert.go b/go/vt/vtgate/planbuilder/operators/insert.go index 7c6e242ae9c..75466500fe6 100644 --- a/go/vt/vtgate/planbuilder/operators/insert.go +++ b/go/vt/vtgate/planbuilder/operators/insert.go @@ -506,7 +506,7 @@ func insertRowsPlan(ctx *plancontext.PlanningContext, insOp *Insert, ins *sqlpar colNum, _ := findOrAddColumn(ins, col) for rowNum, row := range rows { innerpv, err := evalengine.Translate(row[colNum], &evalengine.Config{ - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Collation: ctx.SemTable.Collation, Environment: ctx.VSchema.Environment(), }) @@ -637,7 +637,7 @@ func modifyForAutoinc(ctx *plancontext.PlanningContext, ins *sqlparser.Insert, v } var err error gen.Values, err = evalengine.Translate(autoIncValues, &evalengine.Config{ - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Collation: ctx.SemTable.Collation, Environment: ctx.VSchema.Environment(), }) diff --git a/go/vt/vtgate/planbuilder/operators/join.go b/go/vt/vtgate/planbuilder/operators/join.go index d13f79e010f..71d2e5a8048 100644 --- a/go/vt/vtgate/planbuilder/operators/join.go +++ b/go/vt/vtgate/planbuilder/operators/join.go @@ -105,6 +105,9 @@ func createLeftOuterJoin(ctx *plancontext.PlanningContext, join *sqlparser.JoinT joinOp := &Join{LHS: lhs, RHS: rhs, JoinType: join.Join} + // mark the RHS as outer tables so we know which columns are nullable + ctx.OuterTables = ctx.OuterTables.Merge(TableID(rhs)) + // for outer joins we have to be careful with the predicates we use var op Operator subq, _ := getSubQuery(join.Condition.On) diff --git a/go/vt/vtgate/planbuilder/operators/projection.go b/go/vt/vtgate/planbuilder/operators/projection.go index 229dacc4ba2..1f380ba736b 100644 --- a/go/vt/vtgate/planbuilder/operators/projection.go +++ b/go/vt/vtgate/planbuilder/operators/projection.go @@ -631,7 +631,7 @@ func (p *Projection) planOffsets(ctx *plancontext.PlanningContext) Operator { // for everything else, we'll turn to the evalengine eexpr, err := evalengine.Translate(rewritten, &evalengine.Config{ - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Collation: ctx.SemTable.Collation, Environment: ctx.VSchema.Environment(), }) diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 56e0fe8d623..5729dbd0c2e 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -101,7 +101,7 @@ func (aggr Aggr) GetTypeCollation(ctx *plancontext.PlanningContext) evalengine.T } switch aggr.OpCode { case opcode.AggregateMin, opcode.AggregateMax, opcode.AggregateSumDistinct, opcode.AggregateCountDistinct: - typ, _ := ctx.SemTable.TypeForExpr(aggr.Func.GetArg()) + typ, _ := ctx.TypeForExpr(aggr.Func.GetArg()) return typ } diff --git a/go/vt/vtgate/planbuilder/operators/sharded_routing.go b/go/vt/vtgate/planbuilder/operators/sharded_routing.go index 29ee88787b5..1319b76f040 100644 --- a/go/vt/vtgate/planbuilder/operators/sharded_routing.go +++ b/go/vt/vtgate/planbuilder/operators/sharded_routing.go @@ -573,7 +573,7 @@ func (tr *ShardedRouting) planCompositeInOpArg( Key: right.String(), Index: idx, } - if typ, found := ctx.SemTable.TypeForExpr(col); found { + if typ, found := ctx.TypeForExpr(col); found { value.Type = typ.Type() value.Collation = typ.Collation() } @@ -687,7 +687,7 @@ func makeEvalEngineExpr(ctx *plancontext.PlanningContext, n sqlparser.Expr) eval for _, expr := range ctx.SemTable.GetExprAndEqualities(n) { ee, _ := evalengine.Translate(expr, &evalengine.Config{ Collation: ctx.SemTable.Collation, - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Environment: ctx.VSchema.Environment(), }) if ee != nil { diff --git a/go/vt/vtgate/planbuilder/operators/union_merging.go b/go/vt/vtgate/planbuilder/operators/union_merging.go index 81ca2f5623e..20c20673665 100644 --- a/go/vt/vtgate/planbuilder/operators/union_merging.go +++ b/go/vt/vtgate/planbuilder/operators/union_merging.go @@ -200,8 +200,8 @@ func createMergedUnion( continue } deps = deps.Merge(ctx.SemTable.RecursiveDeps(rae.Expr)) - rt, foundR := ctx.SemTable.TypeForExpr(rae.Expr) - lt, foundL := ctx.SemTable.TypeForExpr(lae.Expr) + rt, foundR := ctx.TypeForExpr(rae.Expr) + lt, foundL := ctx.TypeForExpr(lae.Expr) if foundR && foundL { collations := ctx.VSchema.Environment().CollationEnv() var typer evalengine.TypeAggregator diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index 4abf319ad08..f6dfa2280b2 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -1123,7 +1123,7 @@ func createAssignmentExpressions( } found = true pv, err := evalengine.Translate(assignment.Expr.EvalExpr, &evalengine.Config{ - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Collation: ctx.SemTable.Collation, Environment: ctx.VSchema.Environment(), }) diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 3c2a1c97434..90a6bdac6f8 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -20,6 +20,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/semantics" ) @@ -57,6 +58,10 @@ type PlanningContext struct { // Statement contains the originally parsed statement Statement sqlparser.Statement + + // OuterTables contains the tables that are outer to the current query + // Used to set the nullable flag on the columns + OuterTables semantics.TableSet } // CreatePlanningContext initializes a new PlanningContext with the given parameters. @@ -201,3 +206,19 @@ func (ctx *PlanningContext) RewriteDerivedTableExpression(expr sqlparser.Expr, t } return modifiedExpr } + +// TypeForExpr returns the type of the given expression, with nullable set if the expression is from an outer table. +func (ctx *PlanningContext) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) { + t, found := ctx.SemTable.TypeForExpr(e) + if !found { + return t, found + } + deps := ctx.SemTable.RecursiveDeps(e) + // If the expression is from an outer table, it should be nullable + // There are some exceptions to this, where an expression depending on the outer side + // will never return NULL, but it's better to be conservative here. + if deps.IsOverlapping(ctx.OuterTables) { + t.SetNullability(true) + } + return t, true +} diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context_test.go b/go/vt/vtgate/planbuilder/plancontext/planning_context_test.go new file mode 100644 index 00000000000..b47286abdb2 --- /dev/null +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context_test.go @@ -0,0 +1,108 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package plancontext + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/vtgate/evalengine" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +func TestOuterTableNullability(t *testing.T) { + // Tests that columns from outer tables are nullable, + // even though the semantic state says that they are not nullable. + // This is because the outer table may not have a matching row. + // All columns are marked as NOT NULL in the schema. + query := "select * from t1 left join t2 on t1.a = t2.a where t1.a+t2.a/abs(t2.boing)" + ctx, columns := prepareContextAndFindColumns(t, query) + + // Check if the columns are correctly marked as nullable. + for _, col := range columns { + colName := "column: " + sqlparser.String(col) + t.Run(colName, func(t *testing.T) { + // Extract the column type from the context and the semantic state. + // The context should mark the column as nullable. + ctxType, found := ctx.TypeForExpr(col) + require.True(t, found, colName) + stType, found := ctx.SemTable.TypeForExpr(col) + require.True(t, found, colName) + ctxNullable := ctxType.Nullable() + stNullable := stType.Nullable() + + switch col.Qualifier.Name.String() { + case "t1": + assert.False(t, ctxNullable, colName) + assert.False(t, stNullable, colName) + case "t2": + assert.True(t, ctxNullable, colName) + + // The semantic state says that the column is not nullable. Don't trust it. + assert.False(t, stNullable, colName) + } + }) + } +} + +func prepareContextAndFindColumns(t *testing.T, query string) (ctx *PlanningContext, columns []*sqlparser.ColName) { + parser := sqlparser.NewTestParser() + ast, err := parser.Parse(query) + require.NoError(t, err) + semTable := semantics.EmptySemTable() + t1 := semTable.NewTableId() + t2 := semTable.NewTableId() + stmt := ast.(*sqlparser.Select) + expr := stmt.Where.Expr + + // Instead of using the semantic analysis, we manually set the types for the columns. + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + col, ok := node.(*sqlparser.ColName) + if !ok { + return true, nil + } + + switch col.Qualifier.Name.String() { + case "t1": + semTable.Recursive[col] = t1 + case "t2": + semTable.Recursive[col] = t2 + } + + intNotNull := evalengine.NewType(sqltypes.Int64, collations.Unknown) + intNotNull.SetNullability(false) + semTable.ExprTypes[col] = intNotNull + columns = append(columns, col) + return false, nil + }, nil, expr) + + ctx = &PlanningContext{ + SemTable: semTable, + joinPredicates: map[sqlparser.Expr][]sqlparser.Expr{}, + skipPredicates: map[sqlparser.Expr]any{}, + ReservedArguments: map[sqlparser.Expr]string{}, + Statement: stmt, + OuterTables: t2, // t2 is the outer table. + } + return +} diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index f6f62a3eba5..47b8fa4595f 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -656,6 +656,7 @@ func (st *SemTable) AddExprs(tbl *sqlparser.AliasedTableExpr, cols sqlparser.Sel } // TypeForExpr returns the type of expressions in the query +// Note that PlanningContext has the same method, and you should use that if you have a PlanningContext func (st *SemTable) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) { if typ, found := st.ExprTypes[e]; found { return typ, true