From 67f6ebe702e7d13bc7ff0fd65a89a0c20da778cb Mon Sep 17 00:00:00 2001 From: "vitess-bot[bot]" <108069721+vitess-bot[bot]@users.noreply.github.com> Date: Fri, 14 Jun 2024 16:08:03 +0200 Subject: [PATCH] Cherry-pick 5a6f3868c56fb6e5290b153a615882c31aedfa6f with conflicts --- .../endtoend/vtgate/queries/misc/misc_test.go | 171 ++++++ .../endtoend/vtgate/queries/misc/schema.sql | 43 +- go/vt/vtgate/evalengine/compiler.go | 128 +++++ .../planbuilder/operator_transformers.go | 109 ++++ .../vtgate/planbuilder/operators/distinct.go | 4 + go/vt/vtgate/planbuilder/operators/filter.go | 2 +- .../vtgate/planbuilder/operators/hash_join.go | 494 ++++++++++++++++++ go/vt/vtgate/planbuilder/operators/insert.go | 20 + go/vt/vtgate/planbuilder/operators/join.go | 12 + .../planbuilder/operators/projection.go | 8 + .../planbuilder/operators/queryprojection.go | 5 + .../planbuilder/operators/sharded_routing.go | 42 ++ .../planbuilder/operators/union_merging.go | 17 + go/vt/vtgate/planbuilder/operators/update.go | 113 ++++ .../plancontext/planning_context.go | 31 ++ .../plancontext/planning_context_test.go | 108 ++++ go/vt/vtgate/semantics/semantic_state.go | 5 + 17 files changed, 1310 insertions(+), 2 deletions(-) create mode 100644 go/vt/vtgate/planbuilder/operators/hash_join.go create mode 100644 go/vt/vtgate/planbuilder/plancontext/planning_context_test.go diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index 309da1c5941..7cf89ed4499 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -276,3 +276,174 @@ func TestAnalyze(t *testing.T) { }) } } +<<<<<<< HEAD +======= + +// TestTransactionModeVar executes SELECT on `transaction_mode` variable +func TestTransactionModeVar(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") + + mcmp, closer := start(t) + defer closer() + + tcases := []struct { + setStmt string + expRes string + }{{ + expRes: `[[VARCHAR("MULTI")]]`, + }, { + setStmt: `set transaction_mode = single`, + expRes: `[[VARCHAR("SINGLE")]]`, + }, { + setStmt: `set transaction_mode = multi`, + expRes: `[[VARCHAR("MULTI")]]`, + }, { + setStmt: `set transaction_mode = twopc`, + expRes: `[[VARCHAR("TWOPC")]]`, + }} + + for _, tcase := range tcases { + mcmp.Run(tcase.setStmt, func(mcmp *utils.MySQLCompare) { + if tcase.setStmt != "" { + utils.Exec(t, mcmp.VtConn, tcase.setStmt) + } + utils.AssertMatches(t, mcmp.VtConn, "select @@transaction_mode", tcase.expRes) + }) + } +} + +// TestAliasesInOuterJoinQueries tests that aliases work in queries that have outer join clauses. +func TestAliasesInOuterJoinQueries(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + + mcmp, closer := start(t) + defer closer() + + // Insert data into the 2 tables + mcmp.Exec("insert into t1(id1, id2) values (1,2), (42,5), (5, 42)") + mcmp.Exec("insert into tbl(id, unq_col, nonunq_col) values (1,2,3), (2,5,3), (3, 42, 2)") + + // Check that the select query works as intended and verifying the column names as well. 
+ mcmp.ExecWithColumnCompare("select t1.id1 as t0, t1.id1 as t1, tbl.unq_col as col from t1 left outer join tbl on t1.id2 = tbl.nonunq_col") + mcmp.ExecWithColumnCompare("select t1.id1 as t0, t1.id1 as t1, tbl.unq_col as col from t1 left outer join tbl on t1.id2 = tbl.nonunq_col order by t1.id2 limit 2") + mcmp.ExecWithColumnCompare("select t1.id1 as t0, t1.id1 as t1, tbl.unq_col as col from t1 left outer join tbl on t1.id2 = tbl.nonunq_col order by t1.id2 limit 2 offset 2") + mcmp.ExecWithColumnCompare("select t1.id1 as t0, t1.id1 as t1, count(*) as leCount from t1 left outer join tbl on t1.id2 = tbl.nonunq_col group by 1, t1") + mcmp.ExecWithColumnCompare("select t.id1, t.id2, derived.unq_col from t1 t join (select id, unq_col, nonunq_col from tbl) as derived on t.id2 = derived.nonunq_col") +} + +func TestAlterTableWithView(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + mcmp, closer := start(t) + defer closer() + + // Test that create/alter view works and the output is as expected + mcmp.Exec(`use ks_misc`) + mcmp.Exec(`create view v1 as select * from t1`) + var viewDef string + utils.WaitForVschemaCondition(t, clusterInstance.VtgateProcess, keyspaceName, func(t *testing.T, ksMap map[string]any) bool { + views, ok := ksMap["views"] + if !ok { + return false + } + viewsMap := views.(map[string]any) + view, ok := viewsMap["v1"] + if ok { + viewDef = view.(string) + } + return ok + }, "Waiting for view creation") + mcmp.Exec(`insert into t1(id1, id2) values (1, 1)`) + mcmp.AssertMatches("select * from v1", `[[INT64(1) INT64(1)]]`) + + // alter table add column + mcmp.Exec(`alter table t1 add column test bigint`) + time.Sleep(10 * time.Second) + mcmp.Exec(`alter view v1 as select * from t1`) + + waitForChange := func(t *testing.T, ksMap map[string]any) bool { + // wait for the view definition to change + views := ksMap["views"] + viewsMap := views.(map[string]any) + newView := viewsMap["v1"] + if newView.(string) == viewDef { + return false + } + viewDef = newView.(string) + return true + } + utils.WaitForVschemaCondition(t, clusterInstance.VtgateProcess, keyspaceName, waitForChange, "Waiting for alter view") + + mcmp.AssertMatches("select * from v1", `[[INT64(1) INT64(1) NULL]]`) + + // alter table remove column + mcmp.Exec(`alter table t1 drop column test`) + mcmp.Exec(`alter view v1 as select * from t1`) + + utils.WaitForVschemaCondition(t, clusterInstance.VtgateProcess, keyspaceName, waitForChange, "Waiting for alter view") + + mcmp.AssertMatches("select * from v1", `[[INT64(1) INT64(1)]]`) +} + +// TestStraightJoin tests that Vitess respects the ordering of join in a STRAIGHT JOIN query. +func TestStraightJoin(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into tbl(id, unq_col, nonunq_col) values (1,0,10), (2,10,10), (3,4,20), (4,30,20), (5,40,10)") + mcmp.Exec(`insert into t1(id1, id2) values (10, 11), (20, 13)`) + + mcmp.AssertMatchesNoOrder("select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 join tbl where t1.id1 = tbl.nonunq_col", + `[[INT64(0) INT64(10) INT64(11)] [INT64(10) INT64(10) INT64(11)] [INT64(4) INT64(20) INT64(13)] [INT64(40) INT64(10) INT64(11)] [INT64(30) INT64(20) INT64(13)]]`, + ) + // Verify that in a normal join query, vitess joins tbl with t1. 
+ res, err := mcmp.VtConn.ExecuteFetch("vexplain plan select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 join tbl where t1.id1 = tbl.nonunq_col", 100, false) + require.NoError(t, err) + require.Contains(t, fmt.Sprintf("%v", res.Rows), "tbl_t1") + + // Test the same query with a straight join + mcmp.AssertMatchesNoOrder("select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 straight_join tbl where t1.id1 = tbl.nonunq_col", + `[[INT64(0) INT64(10) INT64(11)] [INT64(10) INT64(10) INT64(11)] [INT64(4) INT64(20) INT64(13)] [INT64(40) INT64(10) INT64(11)] [INT64(30) INT64(20) INT64(13)]]`, + ) + // Verify that in a straight join query, vitess joins t1 with tbl. + res, err = mcmp.VtConn.ExecuteFetch("vexplain plan select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 straight_join tbl where t1.id1 = tbl.nonunq_col", 100, false) + require.NoError(t, err) + require.Contains(t, fmt.Sprintf("%v", res.Rows), "t1_tbl") +} + +func TestColumnAliases(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into t1(id1, id2) values (0,0), (1,1)") + mcmp.ExecWithColumnCompare(`select a as k from (select count(*) as a from t1) t`) +} + +func TestHandleNullableColumn(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 21, "vtgate") + require.NoError(t, + utils.WaitForAuthoritative(t, keyspaceName, "tbl", clusterInstance.VtgateProcess.ReadVSchema)) + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into t1(id1, id2) values (0,0), (1,1), (2,2)") + mcmp.Exec("insert into tbl(id, unq_col, nonunq_col) values (0,0,0), (1,1,6)") + // This query tests that we handle nullable columns correctly + // tbl.nonunq_col is not nullable according to the schema, but because of the left join, it can be NULL + mcmp.ExecWithColumnCompare(`select * from t1 left join tbl on t1.id2 = tbl.id where t1.id1 = 6 or tbl.nonunq_col = 6`) +} + +func TestEnumSetVals(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + + mcmp, closer := start(t) + defer closer() + require.NoError(t, utils.WaitForAuthoritative(t, keyspaceName, "tbl_enum_set", clusterInstance.VtgateProcess.ReadVSchema)) + + mcmp.Exec("insert into tbl_enum_set(id, enum_col, set_col) values (1, 'medium', 'a,b,e'), (2, 'small', 'e,f,g'), (3, 'large', 'c'), (4, 'xsmall', 'a,b'), (5, 'medium', 'a,d')") + + mcmp.AssertMatches("select id, enum_col, cast(enum_col as signed) from tbl_enum_set order by enum_col, id", `[[INT64(4) ENUM("xsmall") INT64(1)] [INT64(2) ENUM("small") INT64(2)] [INT64(1) ENUM("medium") INT64(3)] [INT64(5) ENUM("medium") INT64(3)] [INT64(3) ENUM("large") INT64(4)]]`) + mcmp.AssertMatches("select id, set_col, cast(set_col as unsigned) from tbl_enum_set order by set_col, id", `[[INT64(4) SET("a,b") UINT64(3)] [INT64(3) SET("c") UINT64(4)] [INT64(5) SET("a,d") UINT64(9)] [INT64(1) SET("a,b,e") UINT64(19)] [INT64(2) SET("e,f,g") UINT64(112)]]`) +} +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) diff --git a/go/test/endtoend/vtgate/queries/misc/schema.sql b/go/test/endtoend/vtgate/queries/misc/schema.sql index ceac0c07e6d..eedfd607572 100644 --- a/go/test/endtoend/vtgate/queries/misc/schema.sql +++ b/go/test/endtoend/vtgate/queries/misc/schema.sql @@ -1,5 +1,46 @@ +<<<<<<< HEAD create table if not exists t1( id1 bigint, id2 bigint, primary key(id1) -) Engine=InnoDB; \ No newline at end of file +) Engine=InnoDB; +======= +create table t1 +( + id1 bigint, + id2 bigint, + primary key (id1) +) Engine=InnoDB; + +create table unq_idx +( + 
unq_col bigint, + keyspace_id varbinary(20), + primary key (unq_col) +) Engine = InnoDB; + +create table nonunq_idx +( + nonunq_col bigint, + id bigint, + keyspace_id varbinary(20), + primary key (nonunq_col, id) +) Engine = InnoDB; + +create table tbl +( + id bigint, + unq_col bigint, + nonunq_col bigint not null, + primary key (id), + unique (unq_col) +) Engine = InnoDB; + +create table tbl_enum_set +( + id bigint, + enum_col enum('xsmall', 'small', 'medium', 'large', 'xlarge'), + set_col set('a', 'b', 'c', 'd', 'e', 'f', 'g'), + primary key (id) +) Engine = InnoDB; +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index a065825166c..79d7876fc3d 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -50,7 +50,135 @@ type ctype struct { Col collations.TypedCollation } +<<<<<<< HEAD func (ct ctype) nullable() bool { +======= +type Type struct { + typ sqltypes.Type + collation collations.ID + nullable bool + init bool + size, scale int32 + values *EnumSetValues +} + +func (v *EnumSetValues) Equal(other *EnumSetValues) bool { + if v == nil && other == nil { + return true + } + if v == nil || other == nil { + return false + } + return slices.Equal(*v, *other) +} + +func NewType(t sqltypes.Type, collation collations.ID) Type { + // New types default to being nullable + return NewTypeEx(t, collation, true, 0, 0, nil) +} + +func NewTypeEx(t sqltypes.Type, collation collations.ID, nullable bool, size, scale int32, values *EnumSetValues) Type { + return Type{ + typ: t, + collation: collation, + nullable: nullable, + init: true, + size: size, + scale: scale, + values: values, + } +} + +func NewTypeFromField(f *querypb.Field) Type { + return Type{ + typ: f.Type, + collation: collations.ID(f.Charset), + nullable: f.Flags&uint32(querypb.MySqlFlag_NOT_NULL_FLAG) == 0, + init: true, + size: int32(f.ColumnLength), + scale: int32(f.Decimals), + } +} + +func (t *Type) ToField(name string) *querypb.Field { + // need to get the proper flags for the type; usually leaving flags + // to 0 is OK, because Vitess' MySQL client will generate the right + // ones for the column's type, but here we're also setting the NotNull + // flag, so it needs to be set with the full flags for the column + _, flags := sqltypes.TypeToMySQL(t.typ) + if !t.nullable { + flags |= int64(querypb.MySqlFlag_NOT_NULL_FLAG) + } + + f := &querypb.Field{ + Name: name, + Type: t.typ, + Charset: uint32(t.collation), + ColumnLength: uint32(t.size), + Decimals: uint32(t.scale), + Flags: uint32(flags), + } + return f +} + +func (t *Type) Type() sqltypes.Type { + if t.init { + return t.typ + } + return sqltypes.Unknown +} + +func (t *Type) Collation() collations.ID { + return t.collation +} + +func (t *Type) Size() int32 { + return t.size +} + +func (t *Type) Scale() int32 { + return t.scale +} + +func (t *Type) Nullable() bool { + if t.init { + return t.nullable + } + return true // nullable by default for unknown types +} + +func (t *Type) SetNullability(n bool) { + t.nullable = n +} + +func (t *Type) Values() *EnumSetValues { + return t.values +} + +func (t *Type) Valid() bool { + return t.init +} + +func (t *Type) Equal(other *Type) bool { + return t.typ == other.typ && + t.collation == other.collation && + t.nullable == other.nullable && + t.size == other.size && + t.scale == other.scale && + t.values.Equal(other.values) +} + +func (ct *ctype) equal(other ctype) bool { + return ct.Type == 
other.Type && + ct.Flag == other.Flag && + ct.Size == other.Size && + ct.Scale == other.Scale && + ct.Col == other.Col && + ct.Values.Equal(other.Values) +} + +func (ct *ctype) nullable() bool { +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) return ct.Flag&flagNullable != 0 } diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 127404dee9f..f30f9a1fb43 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -186,8 +186,13 @@ func transformAggregator(ctx *plancontext.PlanningContext, op *operators.Aggrega oa.aggregates = append(oa.aggregates, aggrParam) } for _, groupBy := range op.Grouping { +<<<<<<< HEAD typ, col, _ := ctx.SemTable.TypeForExpr(groupBy.SimplifiedExpr) oa.groupByKeys = append(oa.groupByKeys, &engine.GroupByParams{ +======= + typ, _ := ctx.TypeForExpr(groupBy.Inner) + groupByKeys = append(groupByKeys, &engine.GroupByParams{ +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) KeyCol: groupBy.ColOffset, WeightStringCol: groupBy.WSOffset, Expr: groupBy.AsAliasedExpr().Expr, @@ -230,6 +235,7 @@ func createMemorySort(ctx *plancontext.PlanningContext, src logicalPlan, orderin } for idx, order := range ordering.Order { +<<<<<<< HEAD typ, collationID, _ := ctx.SemTable.TypeForExpr(order.SimplifiedExpr) ms.eMemorySort.OrderBy = append(ms.eMemorySort.OrderBy, engine.OrderByParams{ Col: ordering.Offset[idx], @@ -238,6 +244,15 @@ func createMemorySort(ctx *plancontext.PlanningContext, src logicalPlan, orderin StarColFixedIndex: ordering.Offset[idx], Type: typ, CollationID: collationID, +======= + typ, _ := ctx.TypeForExpr(order.SimplifiedExpr) + prim.OrderBy = append(prim.OrderBy, evalengine.OrderByParams{ + Col: ordering.Offset[idx], + WeightStringCol: ordering.WOffset[idx], + Desc: order.Inner.Direction == sqlparser.DescOrder, + Type: typ, + CollationEnv: ctx.VSchema.Environment().CollationEnv(), +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) }) } @@ -292,8 +307,13 @@ func getEvalEngingeExpr(ctx *plancontext.PlanningContext, pe *operators.ProjExpr case *operators.EvalEngine: return e.EExpr, nil case operators.Offset: +<<<<<<< HEAD typ, col, _ := ctx.SemTable.TypeForExpr(pe.EvalExpr) return evalengine.NewColumn(int(e), typ, col), nil +======= + typ, _ := ctx.TypeForExpr(pe.EvalExpr) + return evalengine.NewColumn(int(e), typ, pe.EvalExpr), nil +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) default: return nil, vterrors.VT13001("project not planned for: %s", pe.String()) } @@ -440,8 +460,13 @@ func buildRouteLogicalPlan(ctx *plancontext.PlanningContext, op *operators.Route condition := getVindexPredicate(op) eroute, err := routeToEngineRoute(ctx, op) for _, order := range op.Ordering { +<<<<<<< HEAD typ, collation, _ := ctx.SemTable.TypeForExpr(order.AST) eroute.OrderBy = append(eroute.OrderBy, engine.OrderByParams{ +======= + typ, _ := ctx.TypeForExpr(order.AST) + eroute.OrderBy = append(eroute.OrderBy, evalengine.OrderByParams{ +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) Col: order.Offset, WeightStringCol: order.WOffset, Desc: order.Direction == sqlparser.DescOrder, @@ -707,5 +732,89 @@ func transformLimit(ctx *plancontext.PlanningContext, op *operators.Limit) (logi return nil, err } +<<<<<<< HEAD return createLimit(plan, op.AST) +======= + if len(op.LHSKeys) != 1 { + return nil, 
vterrors.VT12001("hash joins must have exactly one join predicate") + } + + joinOp := engine.InnerJoin + if op.LeftJoin { + joinOp = engine.LeftJoin + } + + var missingTypes []string + + ltyp, found := ctx.TypeForExpr(op.JoinComparisons[0].LHS) + if !found { + missingTypes = append(missingTypes, sqlparser.String(op.JoinComparisons[0].LHS)) + } + rtyp, found := ctx.TypeForExpr(op.JoinComparisons[0].RHS) + if !found { + missingTypes = append(missingTypes, sqlparser.String(op.JoinComparisons[0].RHS)) + } + + if len(missingTypes) > 0 { + return nil, vterrors.VT12001( + fmt.Sprintf("missing type information for [%s]", strings.Join(missingTypes, ", "))) + } + + comparisonType, err := evalengine.CoerceTypes(ltyp, rtyp, ctx.VSchema.Environment().CollationEnv()) + if err != nil { + return nil, err + } + + return &engine.HashJoin{ + Left: lhs, + Right: rhs, + Opcode: joinOp, + Cols: op.ColumnOffsets, + LHSKey: op.LHSKeys[0], + RHSKey: op.RHSKeys[0], + ASTPred: op.JoinPredicate(), + Collation: comparisonType.Collation(), + ComparisonType: comparisonType.Type(), + CollationEnv: ctx.VSchema.Environment().CollationEnv(), + Values: comparisonType.Values(), + }, nil +} + +func transformVindexPlan(ctx *plancontext.PlanningContext, op *operators.Vindex) (engine.Primitive, error) { + single, ok := op.Vindex.(vindexes.SingleColumn) + if !ok { + return nil, vterrors.VT12001("multi-column vindexes not supported") + } + + expr, err := evalengine.Translate(op.Value, &evalengine.Config{ + Collation: ctx.SemTable.Collation, + ResolveType: ctx.TypeForExpr, + Environment: ctx.VSchema.Environment(), + }) + if err != nil { + return nil, err + } + prim := &engine.VindexFunc{ + Opcode: op.OpCode, + Vindex: single, + Value: expr, + } + + for _, col := range op.Columns { + err := SupplyProjection(prim, &sqlparser.AliasedExpr{ + Expr: col, + As: sqlparser.IdentifierCI{}, + }, false) + if err != nil { + return nil, err + } + } + return prim, nil +} + +func generateQuery(statement sqlparser.Statement) string { + buf := sqlparser.NewTrackedBuffer(dmlFormatter) + statement.Format(buf) + return buf.String() +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) } diff --git a/go/vt/vtgate/planbuilder/operators/distinct.go b/go/vt/vtgate/planbuilder/operators/distinct.go index a521f12c8db..67706dc035f 100644 --- a/go/vt/vtgate/planbuilder/operators/distinct.go +++ b/go/vt/vtgate/planbuilder/operators/distinct.go @@ -62,7 +62,11 @@ func (d *Distinct) planOffsets(ctx *plancontext.PlanningContext) error { } wsCol = &offset } +<<<<<<< HEAD +======= + typ, _ := ctx.TypeForExpr(e) +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) d.Columns = append(d.Columns, engine.CheckCol{ Col: idx, WsCol: wsCol, diff --git a/go/vt/vtgate/planbuilder/operators/filter.go b/go/vt/vtgate/planbuilder/operators/filter.go index 874e799cf43..b44a90af990 100644 --- a/go/vt/vtgate/planbuilder/operators/filter.go +++ b/go/vt/vtgate/planbuilder/operators/filter.go @@ -125,7 +125,7 @@ func (f *Filter) Compact(*plancontext.PlanningContext) (ops.Operator, *rewrite.A func (f *Filter) planOffsets(ctx *plancontext.PlanningContext) error { cfg := &evalengine.Config{ - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Collation: ctx.SemTable.Collation, } diff --git a/go/vt/vtgate/planbuilder/operators/hash_join.go b/go/vt/vtgate/planbuilder/operators/hash_join.go new file mode 100644 index 00000000000..1928f4dda9e --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/hash_join.go @@ -0,0 
+1,494 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "fmt" + "slices" + "strings" + + "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +type ( + HashJoin struct { + LHS, RHS Operator + + // LeftJoin will be true in the case of an outer join + LeftJoin bool + + // Before offset planning + JoinComparisons []Comparison + + // These columns are the output columns of the hash join. While in operator mode we keep track of complex expression, + // but once we move to the engine primitives, the hash join only passes through column from either left or right. + // anything more complex will be solved by a projection on top of the hash join + columns *hashJoinColumns + + // After offset planning + + // Columns stores the column indexes of the columns coming from the left and right side + // negative value comes from LHS and positive from RHS + ColumnOffsets []int + + // These are the values that will be hashed together + LHSKeys, RHSKeys []int + + offset bool + } + + Comparison struct { + LHS, RHS sqlparser.Expr + } + + hashJoinColumn struct { + side joinSide + expr sqlparser.Expr + } + + joinSide int +) + +const ( + Unknown joinSide = iota + Left + Right +) + +var _ Operator = (*HashJoin)(nil) +var _ JoinOp = (*HashJoin)(nil) + +func NewHashJoin(lhs, rhs Operator, outerJoin bool) *HashJoin { + hj := &HashJoin{ + LHS: lhs, + RHS: rhs, + LeftJoin: outerJoin, + columns: &hashJoinColumns{}, + } + return hj +} + +func (hj *HashJoin) Clone(inputs []Operator) Operator { + kopy := *hj + kopy.LHS, kopy.RHS = inputs[0], inputs[1] + kopy.columns = hj.columns.clone() + kopy.LHSKeys = slices.Clone(hj.LHSKeys) + kopy.RHSKeys = slices.Clone(hj.RHSKeys) + kopy.JoinComparisons = slices.Clone(hj.JoinComparisons) + return &kopy +} + +func (hj *HashJoin) Inputs() []Operator { + return []Operator{hj.LHS, hj.RHS} +} + +func (hj *HashJoin) SetInputs(operators []Operator) { + hj.LHS, hj.RHS = operators[0], operators[1] +} + +func (hj *HashJoin) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { + return AddPredicate(ctx, hj, expr, false, newFilterSinglePredicate) +} + +func (hj *HashJoin) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, addToGroupBy bool, expr *sqlparser.AliasedExpr) int { + if reuseExisting { + offset := hj.FindCol(ctx, expr.Expr, false) + if offset >= 0 { + return offset + } + } + + hj.columns.add(expr.Expr) + return len(hj.columns.columns) - 1 +} + +func (hj *HashJoin) AddWSColumn(ctx *plancontext.PlanningContext, offset int, underRoute bool) int { + hj.planOffsets(ctx) + + if len(hj.ColumnOffsets) <= offset { + panic(vterrors.VT13001("offset out of range")) + } + + // check if it already exists + wsExpr := weightStringFor(hj.columns.columns[offset].expr) + if index := hj.FindCol(ctx, wsExpr, false); 
index != -1 { + return index + } + + i := hj.ColumnOffsets[offset] + out := 0 + if i < 0 { + out = hj.LHS.AddWSColumn(ctx, FromLeftOffset(i), underRoute) + out = ToLeftOffset(out) + } else { + out = hj.RHS.AddWSColumn(ctx, FromRightOffset(i), underRoute) + out = ToRightOffset(out) + } + hj.ColumnOffsets = append(hj.ColumnOffsets, out) + return len(hj.ColumnOffsets) - 1 +} + +func (hj *HashJoin) planOffsets(ctx *plancontext.PlanningContext) Operator { + if hj.offset { + return nil + } + hj.offset = true + for _, cmp := range hj.JoinComparisons { + lOffset := hj.LHS.AddColumn(ctx, true, false, aeWrap(cmp.LHS)) + hj.LHSKeys = append(hj.LHSKeys, lOffset) + rOffset := hj.RHS.AddColumn(ctx, true, false, aeWrap(cmp.RHS)) + hj.RHSKeys = append(hj.RHSKeys, rOffset) + } + + needsProj := false + lID := TableID(hj.LHS) + rID := TableID(hj.RHS) + eexprs := slice.Map(hj.columns.columns, func(in hashJoinColumn) *ProjExpr { + var column *ProjExpr + var pureOffset bool + + switch in.side { + case Unknown: + column, pureOffset = hj.addColumn(ctx, in.expr) + case Left: + column, pureOffset = hj.addSingleSidedColumn(ctx, in.expr, lID, hj.LHS, lhsOffset) + case Right: + column, pureOffset = hj.addSingleSidedColumn(ctx, in.expr, rID, hj.RHS, rhsOffset) + default: + panic("not expected") + } + if !pureOffset { + needsProj = true + } + return column + }) + + if !needsProj { + return nil + } + proj := newAliasedProjection(hj) + proj.addProjExpr(eexprs...) + return proj +} + +func (hj *HashJoin) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, _ bool) int { + for offset, col := range hj.columns.columns { + if ctx.SemTable.EqualsExprWithDeps(expr, col.expr) { + return offset + } + } + return -1 +} + +func (hj *HashJoin) GetColumns(*plancontext.PlanningContext) []*sqlparser.AliasedExpr { + return slice.Map(hj.columns.columns, func(from hashJoinColumn) *sqlparser.AliasedExpr { + return aeWrap(from.expr) + }) +} + +func (hj *HashJoin) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.SelectExprs { + return transformColumnsToSelectExprs(ctx, hj) +} + +func (hj *HashJoin) ShortDescription() string { + comparisons := slice.Map(hj.JoinComparisons, func(from Comparison) string { + return from.String() + }) + cmp := strings.Join(comparisons, " AND ") + + if len(hj.columns.columns) > 0 { + cols := slice.Map(hj.columns.columns, func(from hashJoinColumn) (result string) { + switch from.side { + case Unknown: + result = "U" + case Left: + result = "L" + case Right: + result = "R" + } + result += fmt.Sprintf("(%s)", sqlparser.String(from.expr)) + return + }) + return fmt.Sprintf("%s columns [%v]", cmp, strings.Join(cols, ", ")) + } + + return cmp +} + +func (hj *HashJoin) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy { + return nil // hash joins will never promise an output order +} + +func (hj *HashJoin) GetLHS() Operator { + return hj.LHS +} + +func (hj *HashJoin) GetRHS() Operator { + return hj.RHS +} + +func (hj *HashJoin) SetLHS(op Operator) { + hj.LHS = op +} + +func (hj *HashJoin) SetRHS(op Operator) { + hj.RHS = op +} + +func (hj *HashJoin) MakeInner() { + hj.LeftJoin = false +} + +func (hj *HashJoin) IsInner() bool { + return !hj.LeftJoin +} + +func (hj *HashJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) { + cmp, ok := expr.(*sqlparser.ComparisonExpr) + if !ok || !canBeSolvedWithHashJoin(cmp.Operator) { + panic(vterrors.VT12001(fmt.Sprintf("can't use [%s] with hash joins", sqlparser.String(expr)))) + } + lExpr := cmp.Left + lDeps := 
ctx.SemTable.RecursiveDeps(lExpr) + rExpr := cmp.Right + rDeps := ctx.SemTable.RecursiveDeps(rExpr) + lID := TableID(hj.LHS) + rID := TableID(hj.RHS) + if !lDeps.IsSolvedBy(lID) || !rDeps.IsSolvedBy(rID) { + // we'll switch and see if things work out then + lExpr, rExpr = rExpr, lExpr + lDeps, rDeps = rDeps, lDeps + } + + if !lDeps.IsSolvedBy(lID) || !rDeps.IsSolvedBy(rID) { + panic(vterrors.VT12001(fmt.Sprintf("can't use [%s] with hash joins", sqlparser.String(expr)))) + } + + hj.JoinComparisons = append(hj.JoinComparisons, Comparison{ + LHS: lExpr, + RHS: rExpr, + }) +} + +func canBeSolvedWithHashJoin(op sqlparser.ComparisonExprOperator) bool { + switch op { + case sqlparser.EqualOp, sqlparser.NullSafeEqualOp: + return true + default: + return false + } +} + +func (c Comparison) String() string { + return sqlparser.String(c.LHS) + " = " + sqlparser.String(c.RHS) +} +func lhsOffset(i int) int { return (i * -1) - 1 } +func rhsOffset(i int) int { return i + 1 } +func (hj *HashJoin) addColumn(ctx *plancontext.PlanningContext, in sqlparser.Expr) (*ProjExpr, bool) { + lId, rId := TableID(hj.LHS), TableID(hj.RHS) + r := new(replacer) // this is the expression we will put in instead of whatever we find there + pre := func(node, parent sqlparser.SQLNode) bool { + expr, ok := node.(sqlparser.Expr) + if !ok { + return true + } + deps := ctx.SemTable.RecursiveDeps(expr) + check := func(id semantics.TableSet, op Operator, offsetter func(int) int) int { + if !deps.IsSolvedBy(id) { + return -1 + } + inOffset := op.FindCol(ctx, expr, false) + if inOffset == -1 { + if !mustFetchFromInput(ctx, expr) { + return -1 + } + + // aha! this is an expression that we have to get from the input. let's force it in there + inOffset = op.AddColumn(ctx, false, false, aeWrap(expr)) + } + + // we turn the + internalOffset := offsetter(inOffset) + + // ok, we have an offset from the input operator. Let's check if we already have it + // in our list of incoming columns + + for idx, offset := range hj.ColumnOffsets { + if internalOffset == offset { + return idx + } + } + + hj.ColumnOffsets = append(hj.ColumnOffsets, internalOffset) + + return len(hj.ColumnOffsets) - 1 + } + + if lOffset := check(lId, hj.LHS, lhsOffset); lOffset >= 0 { + r.replaceExpr = sqlparser.NewOffset(lOffset, expr) + return false // we want to stop going down the expression tree and start coming back up again + } + + if rOffset := check(rId, hj.RHS, rhsOffset); rOffset >= 0 { + r.replaceExpr = sqlparser.NewOffset(rOffset, expr) + return false + } + + return true + } + + rewrittenExpr := sqlparser.CopyOnRewrite(in, pre, r.post, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr) + cfg := &evalengine.Config{ + ResolveType: ctx.TypeForExpr, + Collation: ctx.SemTable.Collation, + Environment: ctx.VSchema.Environment(), + } + eexpr, err := evalengine.Translate(rewrittenExpr, cfg) + if err != nil { + panic(err) + } + + _, isPureOffset := rewrittenExpr.(*sqlparser.Offset) + + return &ProjExpr{ + Original: aeWrap(in), + EvalExpr: rewrittenExpr, + ColExpr: rewrittenExpr, + Info: &EvalEngine{EExpr: eexpr}, + }, isPureOffset +} + +// JoinPredicate produces an AST representation of the join condition this join has +func (hj *HashJoin) JoinPredicate() sqlparser.Expr { + exprs := slice.Map(hj.JoinComparisons, func(from Comparison) sqlparser.Expr { + return &sqlparser.ComparisonExpr{ + Left: from.LHS, + Right: from.RHS, + } + }) + return sqlparser.AndExpressions(exprs...) 
+} + +type replacer struct { + replaceExpr sqlparser.Expr +} + +func (r *replacer) post(cursor *sqlparser.CopyOnWriteCursor) { + if r.replaceExpr != nil { + node := cursor.Node() + _, ok := node.(sqlparser.Expr) + if !ok { + panic(fmt.Sprintf("can't replace this node with an expression: %s", sqlparser.String(node))) + } + cursor.Replace(r.replaceExpr) + r.replaceExpr = nil + } +} + +func (hj *HashJoin) addSingleSidedColumn( + ctx *plancontext.PlanningContext, + in sqlparser.Expr, + tableID semantics.TableSet, + op Operator, + offsetter func(int) int, +) (*ProjExpr, bool) { + r := new(replacer) + pre := func(node, parent sqlparser.SQLNode) bool { + expr, ok := node.(sqlparser.Expr) + if !ok { + return true + } + deps := ctx.SemTable.RecursiveDeps(expr) + check := func(op Operator) int { + if !deps.IsSolvedBy(tableID) { + return -1 + } + inOffset := op.FindCol(ctx, expr, false) + if inOffset == -1 { + if !mustFetchFromInput(ctx, expr) { + return -1 + } + + // aha! this is an expression that we have to get from the input. let's force it in there + inOffset = op.AddColumn(ctx, false, false, aeWrap(expr)) + } + + // we have to turn the incoming offset to an outgoing offset of the columns this operator is exposing + internalOffset := offsetter(inOffset) + + // ok, we have an offset from the input operator. Let's check if we already have it + // in our list of incoming columns + for idx, offset := range hj.ColumnOffsets { + if internalOffset == offset { + return idx + } + } + + hj.ColumnOffsets = append(hj.ColumnOffsets, internalOffset) + + return len(hj.ColumnOffsets) - 1 + } + + if offset := check(op); offset >= 0 { + r.replaceExpr = sqlparser.NewOffset(offset, expr) + return false // we want to stop going down the expression tree and start coming back up again + } + + return true + } + + rewrittenExpr := sqlparser.CopyOnRewrite(in, pre, r.post, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr) + cfg := &evalengine.Config{ + ResolveType: ctx.TypeForExpr, + Collation: ctx.SemTable.Collation, + Environment: ctx.VSchema.Environment(), + } + eexpr, err := evalengine.Translate(rewrittenExpr, cfg) + if err != nil { + panic(err) + } + + _, isPureOffset := rewrittenExpr.(*sqlparser.Offset) + + return &ProjExpr{ + Original: aeWrap(in), + EvalExpr: rewrittenExpr, + ColExpr: rewrittenExpr, + Info: &EvalEngine{EExpr: eexpr}, + }, isPureOffset +} + +func FromLeftOffset(i int) int { + return -i - 1 +} + +func ToLeftOffset(i int) int { + return -i - 1 +} + +func FromRightOffset(i int) int { + return i - 1 +} + +func ToRightOffset(i int) int { + return i + 1 +} diff --git a/go/vt/vtgate/planbuilder/operators/insert.go b/go/vt/vtgate/planbuilder/operators/insert.go index 1925a0d9518..1749870f7fe 100644 --- a/go/vt/vtgate/planbuilder/operators/insert.go +++ b/go/vt/vtgate/planbuilder/operators/insert.go @@ -317,7 +317,15 @@ func insertRowsPlan(insOp *Insert, ins *sqlparser.Insert, rows sqlparser.Values) routeValues[vIdx][colIdx] = make([]evalengine.Expr, len(rows)) colNum, _ := findOrAddColumn(ins, col) for rowNum, row := range rows { +<<<<<<< HEAD innerpv, err := evalengine.Translate(row[colNum], nil) +======= + innerpv, err := evalengine.Translate(row[colNum], &evalengine.Config{ + ResolveType: ctx.TypeForExpr, + Collation: ctx.SemTable.Collation, + Environment: ctx.VSchema.Environment(), + }) +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) if err != nil { return nil, err } @@ -445,7 +453,19 @@ func modifyForAutoinc(ins *sqlparser.Insert, vTable *vindexes.Table) (*Generate, 
autoIncValues = append(autoIncValues, expr) row[colNum] = sqlparser.NewArgument(engine.SeqVarName + strconv.Itoa(rowNum)) } +<<<<<<< HEAD gen.Values = evalengine.NewTupleExpr(autoIncValues...) +======= + var err error + gen.Values, err = evalengine.Translate(autoIncValues, &evalengine.Config{ + ResolveType: ctx.TypeForExpr, + Collation: ctx.SemTable.Collation, + Environment: ctx.VSchema.Environment(), + }) + if err != nil { + panic(err) + } +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) } return gen, nil } diff --git a/go/vt/vtgate/planbuilder/operators/join.go b/go/vt/vtgate/planbuilder/operators/join.go index 693b7a75d8e..c8153f55e59 100644 --- a/go/vt/vtgate/planbuilder/operators/join.go +++ b/go/vt/vtgate/planbuilder/operators/join.go @@ -89,7 +89,19 @@ func createOuterJoin(tableExpr *sqlparser.JoinTableExpr, lhs, rhs ops.Operator) if tableExpr.Join == sqlparser.RightJoinType { lhs, rhs = rhs, lhs } +<<<<<<< HEAD subq, _ := getSubQuery(tableExpr.Condition.On) +======= + + joinOp := &Join{LHS: lhs, RHS: rhs, JoinType: join.Join} + + // mark the RHS as outer tables so we know which columns are nullable + ctx.OuterTables = ctx.OuterTables.Merge(TableID(rhs)) + + // for outer joins we have to be careful with the predicates we use + var op Operator + subq, _ := getSubQuery(join.Condition.On) +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) if subq != nil { return nil, vterrors.VT12001("subquery in outer join predicate") } diff --git a/go/vt/vtgate/planbuilder/operators/projection.go b/go/vt/vtgate/planbuilder/operators/projection.go index 686950ba56d..963467eef6a 100644 --- a/go/vt/vtgate/planbuilder/operators/projection.go +++ b/go/vt/vtgate/planbuilder/operators/projection.go @@ -600,7 +600,15 @@ func (p *Projection) planOffsets(ctx *plancontext.PlanningContext) error { } // for everything else, we'll turn to the evalengine +<<<<<<< HEAD eexpr, err := evalengine.Translate(rewritten, nil) +======= + eexpr, err := evalengine.Translate(rewritten, &evalengine.Config{ + ResolveType: ctx.TypeForExpr, + Collation: ctx.SemTable.Collation, + Environment: ctx.VSchema.Environment(), + }) +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) if err != nil { return err } diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 0630b09d459..3a0a2d1058e 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -121,8 +121,13 @@ func (aggr Aggr) GetTypeCollation(ctx *plancontext.PlanningContext) (sqltypes.Ty } switch aggr.OpCode { case opcode.AggregateMin, opcode.AggregateMax, opcode.AggregateSumDistinct, opcode.AggregateCountDistinct: +<<<<<<< HEAD typ, col, _ := ctx.SemTable.TypeForExpr(aggr.Func.GetArg()) return typ, col +======= + typ, _ := ctx.TypeForExpr(aggr.Func.GetArg()) + return typ +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) } return sqltypes.Unknown, collations.Unknown diff --git a/go/vt/vtgate/planbuilder/operators/sharded_routing.go b/go/vt/vtgate/planbuilder/operators/sharded_routing.go index 4965c5d18b5..19ae8ca515f 100644 --- a/go/vt/vtgate/planbuilder/operators/sharded_routing.go +++ b/go/vt/vtgate/planbuilder/operators/sharded_routing.go @@ -532,6 +532,43 @@ func (tr *ShardedRouting) planCompositeInOpRecursive( return foundVindex } +<<<<<<< HEAD +======= +func (tr *ShardedRouting) planCompositeInOpArg( + ctx 
*plancontext.PlanningContext, + cmp *sqlparser.ComparisonExpr, + left sqlparser.ValTuple, + right sqlparser.ListArg, +) bool { + foundVindex := false + for idx, expr := range left { + col, ok := expr.(*sqlparser.ColName) + if !ok { + continue + } + + // check if left col is a vindex + if !tr.hasVindex(col) { + continue + } + + value := &evalengine.TupleBindVariable{ + Key: right.String(), + Index: idx, + } + if typ, found := ctx.TypeForExpr(col); found { + value.Type = typ.Type() + value.Collation = typ.Collation() + } + + opcode := func(*vindexes.ColumnVindex) engine.Opcode { return engine.MultiEqual } + newVindex := tr.haveMatchingVindex(ctx, cmp, nil, col, value, opcode, justTheVindex) + foundVindex = newVindex || foundVindex + } + return foundVindex +} + +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) func (tr *ShardedRouting) hasVindex(column *sqlparser.ColName) bool { for _, v := range tr.VindexPreds { for _, col := range v.ColVindex.Columns { @@ -626,7 +663,12 @@ func makeEvalEngineExpr(ctx *plancontext.PlanningContext, n sqlparser.Expr) eval for _, expr := range ctx.SemTable.GetExprAndEqualities(n) { ee, _ := evalengine.Translate(expr, &evalengine.Config{ Collation: ctx.SemTable.Collation, +<<<<<<< HEAD ResolveType: ctx.SemTable.TypeForExpr, +======= + ResolveType: ctx.TypeForExpr, + Environment: ctx.VSchema.Environment(), +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) }) if ee != nil { return ee diff --git a/go/vt/vtgate/planbuilder/operators/union_merging.go b/go/vt/vtgate/planbuilder/operators/union_merging.go index 36fa7ec87e7..c47124c4033 100644 --- a/go/vt/vtgate/planbuilder/operators/union_merging.go +++ b/go/vt/vtgate/planbuilder/operators/union_merging.go @@ -207,11 +207,28 @@ func createMergedUnion( continue } deps = deps.Merge(ctx.SemTable.RecursiveDeps(rae.Expr)) +<<<<<<< HEAD rt, _, foundR := ctx.SemTable.TypeForExpr(rae.Expr) lt, _, foundL := ctx.SemTable.TypeForExpr(lae.Expr) if foundR && foundL && rt == lt { ctx.SemTable.CopySemanticInfo(rae.Expr, col) ctx.SemTable.CopySemanticInfo(lae.Expr, col) +======= + rt, foundR := ctx.TypeForExpr(rae.Expr) + lt, foundL := ctx.TypeForExpr(lae.Expr) + if foundR && foundL { + collations := ctx.VSchema.Environment().CollationEnv() + var typer evalengine.TypeAggregator + + if err := typer.Add(rt, collations); err != nil { + panic(err) + } + if err := typer.Add(lt, collations); err != nil { + panic(err) + } + + ctx.SemTable.ExprTypes[col] = typer.Type() +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) } ctx.SemTable.Recursive[col] = deps } diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index 5a7716bdbeb..28c7746ee79 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -689,3 +689,116 @@ func nullSafeNotInComparison(updateExprs sqlparser.UpdateExprs, cFk vindexes.Chi return finalExpr } +<<<<<<< HEAD +======= + +func buildChangedVindexesValues( + ctx *plancontext.PlanningContext, + update *sqlparser.Update, + table *vindexes.Table, + ksidCols []sqlparser.IdentifierCI, + assignments []SetExpr, +) (changedVindexes map[string]*engine.VindexValues, ovq *sqlparser.Select, subQueriesArgOnChangedVindex []string) { + changedVindexes = make(map[string]*engine.VindexValues) + selExprs, offset := initialQuery(ksidCols, table) + for i, vindex := range table.ColumnVindexes { + vindexValueMap := make(map[string]evalengine.Expr) + var 
compExprs []sqlparser.Expr + for _, vcol := range vindex.Columns { + subQueriesArgOnChangedVindex, compExprs = + createAssignmentExpressions(ctx, assignments, vcol, subQueriesArgOnChangedVindex, vindexValueMap, compExprs) + } + if len(vindexValueMap) == 0 { + // Vindex not changing, continue + continue + } + if i == 0 { + panic(vterrors.VT12001(fmt.Sprintf("you cannot UPDATE primary vindex columns; invalid update on vindex: %v", vindex.Name))) + } + if _, ok := vindex.Vindex.(vindexes.Lookup); !ok { + panic(vterrors.VT12001(fmt.Sprintf("you can only UPDATE lookup vindexes; invalid update on vindex: %v", vindex.Name))) + } + + // Checks done, let's actually add the expressions and the vindex map + selExprs = append(selExprs, aeWrap(sqlparser.AndExpressions(compExprs...))) + changedVindexes[vindex.Name] = &engine.VindexValues{ + EvalExprMap: vindexValueMap, + Offset: offset, + } + offset++ + } + if len(changedVindexes) == 0 { + return nil, nil, nil + } + // generate rest of the owned vindex query. + ovq = &sqlparser.Select{ + SelectExprs: selExprs, + OrderBy: update.OrderBy, + Limit: update.Limit, + Lock: sqlparser.ForUpdateLock, + } + return changedVindexes, ovq, subQueriesArgOnChangedVindex +} + +func initialQuery(ksidCols []sqlparser.IdentifierCI, table *vindexes.Table) (sqlparser.SelectExprs, int) { + var selExprs sqlparser.SelectExprs + offset := 0 + for _, col := range ksidCols { + selExprs = append(selExprs, aeWrap(sqlparser.NewColName(col.String()))) + offset++ + } + for _, cv := range table.Owned { + for _, column := range cv.Columns { + selExprs = append(selExprs, aeWrap(sqlparser.NewColName(column.String()))) + offset++ + } + } + return selExprs, offset +} + +func createAssignmentExpressions( + ctx *plancontext.PlanningContext, + assignments []SetExpr, + vcol sqlparser.IdentifierCI, + subQueriesArgOnChangedVindex []string, + vindexValueMap map[string]evalengine.Expr, + compExprs []sqlparser.Expr, +) ([]string, []sqlparser.Expr) { + // Searching in order of columns in colvindex. 
+ found := false + for _, assignment := range assignments { + if !vcol.Equal(assignment.Name.Name) { + continue + } + if found { + panic(vterrors.VT03015(assignment.Name.Name)) + } + found = true + pv, err := evalengine.Translate(assignment.Expr.EvalExpr, &evalengine.Config{ + ResolveType: ctx.TypeForExpr, + Collation: ctx.SemTable.Collation, + Environment: ctx.VSchema.Environment(), + }) + if err != nil { + panic(invalidUpdateExpr(assignment.Name.Name.String(), assignment.Expr.EvalExpr)) + } + + if assignment.Expr.Info != nil { + sqe, ok := assignment.Expr.Info.(SubQueryExpression) + if ok { + for _, sq := range sqe { + subQueriesArgOnChangedVindex = append(subQueriesArgOnChangedVindex, sq.ArgName) + } + } + } + + vindexValueMap[vcol.String()] = pv + compExprs = append(compExprs, sqlparser.NewComparisonExpr(sqlparser.EqualOp, assignment.Name, assignment.Expr.EvalExpr, nil)) + } + return subQueriesArgOnChangedVindex, compExprs +} + +func invalidUpdateExpr(upd string, expr sqlparser.Expr) error { + return vterrors.VT12001(fmt.Sprintf("only values are supported; invalid update on column: `%s` with expr: [%s]", upd, sqlparser.String(expr))) +} +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index d090a593a39..f8a94ad3655 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -19,6 +19,11 @@ package plancontext import ( querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" +<<<<<<< HEAD +======= + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) "vitess.io/vitess/go/vt/vtgate/semantics" ) @@ -54,6 +59,16 @@ type PlanningContext struct { // CurrentPhase keeps track of how far we've gone in the planning process // The type should be operators.Phase, but depending on that would lead to circular dependencies CurrentPhase int +<<<<<<< HEAD +======= + + // Statement contains the originally parsed statement + Statement sqlparser.Statement + + // OuterTables contains the tables that are outer to the current query + // Used to set the nullable flag on the columns + OuterTables semantics.TableSet +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) } func CreatePlanningContext(stmt sqlparser.Statement, @@ -116,3 +131,19 @@ func (ctx *PlanningContext) GetArgumentFor(expr sqlparser.Expr, f func() string) ctx.ReservedArguments[expr] = bvName return bvName } + +// TypeForExpr returns the type of the given expression, with nullable set if the expression is from an outer table. +func (ctx *PlanningContext) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) { + t, found := ctx.SemTable.TypeForExpr(e) + if !found { + return t, found + } + deps := ctx.SemTable.RecursiveDeps(e) + // If the expression is from an outer table, it should be nullable + // There are some exceptions to this, where an expression depending on the outer side + // will never return NULL, but it's better to be conservative here. 
+ if deps.IsOverlapping(ctx.OuterTables) { + t.SetNullability(true) + } + return t, true +} diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context_test.go b/go/vt/vtgate/planbuilder/plancontext/planning_context_test.go new file mode 100644 index 00000000000..b47286abdb2 --- /dev/null +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context_test.go @@ -0,0 +1,108 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package plancontext + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/vtgate/evalengine" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +func TestOuterTableNullability(t *testing.T) { + // Tests that columns from outer tables are nullable, + // even though the semantic state says that they are not nullable. + // This is because the outer table may not have a matching row. + // All columns are marked as NOT NULL in the schema. + query := "select * from t1 left join t2 on t1.a = t2.a where t1.a+t2.a/abs(t2.boing)" + ctx, columns := prepareContextAndFindColumns(t, query) + + // Check if the columns are correctly marked as nullable. + for _, col := range columns { + colName := "column: " + sqlparser.String(col) + t.Run(colName, func(t *testing.T) { + // Extract the column type from the context and the semantic state. + // The context should mark the column as nullable. + ctxType, found := ctx.TypeForExpr(col) + require.True(t, found, colName) + stType, found := ctx.SemTable.TypeForExpr(col) + require.True(t, found, colName) + ctxNullable := ctxType.Nullable() + stNullable := stType.Nullable() + + switch col.Qualifier.Name.String() { + case "t1": + assert.False(t, ctxNullable, colName) + assert.False(t, stNullable, colName) + case "t2": + assert.True(t, ctxNullable, colName) + + // The semantic state says that the column is not nullable. Don't trust it. + assert.False(t, stNullable, colName) + } + }) + } +} + +func prepareContextAndFindColumns(t *testing.T, query string) (ctx *PlanningContext, columns []*sqlparser.ColName) { + parser := sqlparser.NewTestParser() + ast, err := parser.Parse(query) + require.NoError(t, err) + semTable := semantics.EmptySemTable() + t1 := semTable.NewTableId() + t2 := semTable.NewTableId() + stmt := ast.(*sqlparser.Select) + expr := stmt.Where.Expr + + // Instead of using the semantic analysis, we manually set the types for the columns. 
+ _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + col, ok := node.(*sqlparser.ColName) + if !ok { + return true, nil + } + + switch col.Qualifier.Name.String() { + case "t1": + semTable.Recursive[col] = t1 + case "t2": + semTable.Recursive[col] = t2 + } + + intNotNull := evalengine.NewType(sqltypes.Int64, collations.Unknown) + intNotNull.SetNullability(false) + semTable.ExprTypes[col] = intNotNull + columns = append(columns, col) + return false, nil + }, nil, expr) + + ctx = &PlanningContext{ + SemTable: semTable, + joinPredicates: map[sqlparser.Expr][]sqlparser.Expr{}, + skipPredicates: map[sqlparser.Expr]any{}, + ReservedArguments: map[sqlparser.Expr]string{}, + Statement: stmt, + OuterTables: t2, // t2 is the outer table. + } + return +} diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 1ede4731edd..459f3d0a71b 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -325,7 +325,12 @@ func (st *SemTable) AddExprs(tbl *sqlparser.AliasedTableExpr, cols sqlparser.Sel } // TypeForExpr returns the type of expressions in the query +<<<<<<< HEAD func (st *SemTable) TypeForExpr(e sqlparser.Expr) (sqltypes.Type, collations.ID, bool) { +======= +// Note that PlanningContext has the same method, and you should use that if you have a PlanningContext +func (st *SemTable) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) { +>>>>>>> 5a6f3868c5 (Handle Nullability for Columns from Outer Tables (#16174)) if typ, found := st.ExprTypes[e]; found { return typ.Type, typ.Collation, true }
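
Note (editorial, not part of the patch): the expected values in TestEnumSetVals above follow from MySQL's integer representation of ENUM and SET columns, which is presumably why the evalengine Type added in compiler.go now carries EnumSetValues. An ENUM cast to SIGNED yields the 1-based position of the value in the column definition, so for enum('xsmall', 'small', 'medium', 'large', 'xlarge') the test expects xsmall=1, small=2, medium=3, large=4. A SET cast to UNSIGNED yields a bitmask in which the i-th member has weight 2^(i-1), so for set('a', 'b', 'c', 'd', 'e', 'f', 'g'): 'a,b' = 1+2 = 3, 'c' = 4, 'a,d' = 1+8 = 9, 'a,b,e' = 1+2+16 = 19, and 'e,f,g' = 16+32+64 = 112, matching the assertions in the test.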
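
Note (editorial, illustrative sketch, not part of the patch): hash_join.go above records in HashJoin.ColumnOffsets which input each output column comes from — negative values address the LHS and positive values the RHS, with a one-based shift so that column 0 is never ambiguous. The standalone sketch below mirrors the lhsOffset/rhsOffset and From/To{Left,Right}Offset helpers shown in the patch; the names encodeLHS, encodeRHS, and decode are invented here for illustration only.

package main

import "fmt"

// encodeLHS mirrors lhsOffset/ToLeftOffset: left-input column i is stored as -i-1.
func encodeLHS(i int) int { return -i - 1 }

// encodeRHS mirrors rhsOffset/ToRightOffset: right-input column i is stored as i+1.
func encodeRHS(i int) int { return i + 1 }

// decode recovers the input-side column offset and which side it came from,
// as FromLeftOffset/FromRightOffset do when the hash join reads its inputs.
func decode(enc int) (col int, fromLHS bool) {
	if enc < 0 {
		return -enc - 1, true
	}
	return enc - 1, false
}

func main() {
	fmt.Println(decode(encodeLHS(0))) // 0 true:  first column of the left input
	fmt.Println(decode(encodeRHS(2))) // 2 false: third column of the right input
}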