diff --git a/go/mysql/json/helpers.go b/go/mysql/json/helpers.go index 1df38b2d769..760d59c5624 100644 --- a/go/mysql/json/helpers.go +++ b/go/mysql/json/helpers.go @@ -106,6 +106,10 @@ func NewFromSQL(v sqltypes.Value) (*Value, error) { return NewDate(v.RawStr()), nil case v.IsTime(): return NewTime(v.RawStr()), nil + case v.IsEnum(): + return NewString(v.RawStr()), nil + case v.IsSet(): + return NewString(v.RawStr()), nil default: return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot coerce %v as a JSON type", v) } diff --git a/go/sqltypes/testing.go b/go/sqltypes/testing.go index 2fd9ee9c2be..f67cd1c6deb 100644 --- a/go/sqltypes/testing.go +++ b/go/sqltypes/testing.go @@ -279,6 +279,12 @@ var RandomGenerators = map[Type]RandomGenerator{ } return v }, + Enum: func() Value { + return MakeTrusted(Enum, randEnum()) + }, + Set: func() Value { + return MakeTrusted(Set, randSet()) + }, } func randTime() time.Time { @@ -289,3 +295,33 @@ func randTime() time.Time { sec := rand.Int64N(delta) + min return time.Unix(sec, 0) } + +func randEnum() []byte { + enums := []string{ + "xxsmall", + "xsmall", + "small", + "medium", + "large", + "xlarge", + "xxlarge", + } + return []byte(enums[rand.IntN(len(enums))]) +} + +func randSet() []byte { + set := []string{ + "a", + "b", + "c", + "d", + "e", + "f", + "g", + } + rand.Shuffle(len(set), func(i, j int) { + set[i], set[j] = set[j], set[i] + }) + set = set[:rand.IntN(len(set))] + return []byte(strings.Join(set, ",")) +} diff --git a/go/sqltypes/type.go b/go/sqltypes/type.go index 964dd6b5d83..4090dd0107a 100644 --- a/go/sqltypes/type.go +++ b/go/sqltypes/type.go @@ -119,6 +119,16 @@ func IsNull(t querypb.Type) bool { return t == Null } +// IsEnum returns true if the type is Enum type +func IsEnum(t querypb.Type) bool { + return t == Enum +} + +// IsSet returns true if the type is Set type +func IsSet(t querypb.Type) bool { + return t == Set +} + // Vitess data types. These are idiomatically named synonyms for the querypb.Type values. // Although these constants are interchangeable, they should be treated as different from querypb.Type. // Use the synonyms only to refer to the type in Value. For proto variables, use the querypb.Type constants instead. diff --git a/go/sqltypes/value.go b/go/sqltypes/value.go index b8f05e02db3..bb4e26d15e3 100644 --- a/go/sqltypes/value.go +++ b/go/sqltypes/value.go @@ -568,6 +568,16 @@ func (v Value) IsDecimal() bool { return IsDecimal(v.Type()) } +// IsEnum returns true if Value is enum. +func (v Value) IsEnum() bool { + return v.Type() == querypb.Type_ENUM +} + +// IsSet returns true if Value is set. +func (v Value) IsSet() bool { + return v.Type() == querypb.Type_SET +} + // IsComparable returns true if the Value is null safe comparable without collation information. func (v *Value) IsComparable() bool { if v.Type() == Null || IsNumber(v.Type()) || IsBinary(v.Type()) { diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index c10cb4c9b71..d0c610084cd 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -37,7 +37,7 @@ func start(t *testing.T) (utils.MySQLCompare, func()) { require.NoError(t, err) deleteAll := func() { - tables := []string{"t1", "tbl", "unq_idx", "nonunq_idx", "uks.unsharded"} + tables := []string{"t1", "tbl", "unq_idx", "nonunq_idx", "tbl_enum_set", "uks.unsharded"} for _, table := range tables { _, _ = mcmp.ExecAndIgnore("delete from " + table) } @@ -452,3 +452,16 @@ func TestStraightJoin(t *testing.T) { require.NoError(t, err) require.Contains(t, fmt.Sprintf("%v", res.Rows), "t1_tbl") } + +func TestEnumSetVals(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + + mcmp, closer := start(t) + defer closer() + require.NoError(t, utils.WaitForAuthoritative(t, keyspaceName, "tbl_enum_set", clusterInstance.VtgateProcess.ReadVSchema)) + + mcmp.Exec("insert into tbl_enum_set(id, enum_col, set_col) values (1, 'medium', 'a,b,e'), (2, 'small', 'e,f,g'), (3, 'large', 'c'), (4, 'xsmall', 'a,b'), (5, 'medium', 'a,d')") + + mcmp.AssertMatches("select id, enum_col, cast(enum_col as signed) from tbl_enum_set order by enum_col, id", `[[INT64(4) ENUM("xsmall") INT64(1)] [INT64(2) ENUM("small") INT64(2)] [INT64(1) ENUM("medium") INT64(3)] [INT64(5) ENUM("medium") INT64(3)] [INT64(3) ENUM("large") INT64(4)]]`) + mcmp.AssertMatches("select id, set_col, cast(set_col as unsigned) from tbl_enum_set order by set_col, id", `[[INT64(4) SET("a,b") UINT64(3)] [INT64(3) SET("c") UINT64(4)] [INT64(5) SET("a,d") UINT64(9)] [INT64(1) SET("a,b,e") UINT64(19)] [INT64(2) SET("e,f,g") UINT64(112)]]`) +} diff --git a/go/test/endtoend/vtgate/queries/misc/schema.sql b/go/test/endtoend/vtgate/queries/misc/schema.sql index 6fd57b9183d..685500ec809 100644 --- a/go/test/endtoend/vtgate/queries/misc/schema.sql +++ b/go/test/endtoend/vtgate/queries/misc/schema.sql @@ -27,3 +27,11 @@ create table tbl primary key (id), unique (unq_col) ) Engine = InnoDB; + +create table tbl_enum_set +( + id bigint, + enum_col enum('xsmall', 'small', 'medium', 'large', 'xlarge'), + set_col set('a', 'b', 'c', 'd', 'e', 'f', 'g'), + primary key (id) +) Engine = InnoDB; diff --git a/go/test/endtoend/vtgate/queries/misc/vschema.json b/go/test/endtoend/vtgate/queries/misc/vschema.json index f56b1fc1b36..d3d7c3b7935 100644 --- a/go/test/endtoend/vtgate/queries/misc/vschema.json +++ b/go/test/endtoend/vtgate/queries/misc/vschema.json @@ -53,6 +53,14 @@ } ] }, + "tbl_enum_set": { + "column_vindexes": [ + { + "column": "id", + "name": "hash" + } + ] + }, "unq_idx": { "column_vindexes": [ { diff --git a/go/vt/vtexplain/vtexplain_vttablet.go b/go/vt/vtexplain/vtexplain_vttablet.go index 53e09445c17..6f28cd99ec0 100644 --- a/go/vt/vtexplain/vtexplain_vttablet.go +++ b/go/vt/vtexplain/vtexplain_vttablet.go @@ -755,7 +755,7 @@ func (t *explainTablet) analyzeWhere(selStmt *sqlparser.Select, tableColumnMap m // Check if we have a duplicate value isNewValue := true for _, v := range inVal { - result, err := evalengine.NullsafeCompare(v, value, t.collationEnv, t.collationEnv.DefaultConnectionCharset()) + result, err := evalengine.NullsafeCompare(v, value, t.collationEnv, t.collationEnv.DefaultConnectionCharset(), nil) if err != nil { return "", nil, 0, nil, err } diff --git a/go/vt/vtgate/engine/aggregations.go b/go/vt/vtgate/engine/aggregations.go index ea10267a7e6..4673a2717e5 100644 --- a/go/vt/vtgate/engine/aggregations.go +++ b/go/vt/vtgate/engine/aggregations.go @@ -107,6 +107,7 @@ type aggregatorDistinct struct { last sqltypes.Value coll collations.ID collationEnv *collations.Environment + values *evalengine.EnumSetValues } func (a *aggregatorDistinct) shouldReturn(row []sqltypes.Value) (bool, error) { @@ -115,7 +116,7 @@ func (a *aggregatorDistinct) shouldReturn(row []sqltypes.Value) (bool, error) { next := row[a.column] if !last.IsNull() { if last.TinyWeightCmp(next) == 0 { - cmp, err := evalengine.NullsafeCompare(last, next, a.collationEnv, a.coll) + cmp, err := evalengine.NullsafeCompare(last, next, a.collationEnv, a.coll, a.values) if err != nil { return true, err } @@ -386,6 +387,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg column: distinct, coll: aggr.Type.Collation(), collationEnv: aggr.CollationEnv, + values: aggr.Type.Values(), }, } @@ -405,6 +407,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg column: distinct, coll: aggr.Type.Collation(), collationEnv: aggr.CollationEnv, + values: aggr.Type.Values(), }, } @@ -412,7 +415,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg ag = &aggregatorMin{ aggregatorMinMax{ from: aggr.Col, - minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation()), + minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation(), aggr.Type.Values()), }, } @@ -420,7 +423,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg ag = &aggregatorMax{ aggregatorMinMax{ from: aggr.Col, - minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation()), + minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation(), aggr.Type.Values()), }, } diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 5ff7a7c96ce..22b3a38a990 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -37,6 +37,8 @@ func (cached *AggregateParams) CachedSize(alloc bool) int64 { if alloc { size += int64(112) } + // field Type vitess.io/vitess/go/vt/vtgate/evalengine.Type + size += cached.Type.CachedSize(false) // field Alias string size += hack.RuntimeAllocSize(int64(len(cached.Alias))) // field Expr vitess.io/vitess/go/vt/sqlparser.Expr @@ -73,6 +75,8 @@ func (cached *CheckCol) CachedSize(alloc bool) int64 { } // field WsCol *int size += hack.RuntimeAllocSize(int64(8)) + // field Type vitess.io/vitess/go/vt/vtgate/evalengine.Type + size += cached.Type.CachedSize(false) // field CollationEnv *vitess.io/vitess/go/mysql/collations.Environment size += cached.CollationEnv.CachedSize(true) return size @@ -235,7 +239,7 @@ func (cached *Distinct) CachedSize(alloc bool) int64 { } // field CheckCols []vitess.io/vitess/go/vt/vtgate/engine.CheckCol { - size += hack.RuntimeAllocSize(int64(cap(cached.CheckCols)) * int64(40)) + size += hack.RuntimeAllocSize(int64(cap(cached.CheckCols)) * int64(48)) for _, elem := range cached.CheckCols { size += elem.CachedSize(false) } @@ -382,12 +386,14 @@ func (cached *GroupByParams) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(64) + size += int64(80) } // field Expr vitess.io/vitess/go/vt/sqlparser.Expr if cc, ok := cached.Expr.(cachedObject); ok { size += cc.CachedSize(true) } + // field Type vitess.io/vitess/go/vt/vtgate/evalengine.Type + size += cached.Type.CachedSize(false) // field CollationEnv *vitess.io/vitess/go/mysql/collations.Environment size += cached.CollationEnv.CachedSize(true) return size @@ -398,7 +404,7 @@ func (cached *HashJoin) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(112) + size += int64(128) } // field Left vitess.io/vitess/go/vt/vtgate/engine.Primitive if cc, ok := cached.Left.(cachedObject); ok { @@ -418,6 +424,14 @@ func (cached *HashJoin) CachedSize(alloc bool) int64 { } // field CollationEnv *vitess.io/vitess/go/mysql/collations.Environment size += cached.CollationEnv.CachedSize(true) + // field Values *vitess.io/vitess/go/vt/vtgate/evalengine.EnumSetValues + if cached.Values != nil { + size += int64(24) + size += hack.RuntimeAllocSize(int64(cap(*cached.Values)) * int64(16)) + for _, elem := range *cached.Values { + size += hack.RuntimeAllocSize(int64(len(elem))) + } + } return size } func (cached *Insert) CachedSize(alloc bool) int64 { @@ -657,7 +671,7 @@ func (cached *MemorySort) CachedSize(alloc bool) int64 { } // field OrderBy vitess.io/vitess/go/vt/vtgate/evalengine.Comparison { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(48)) + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(56)) for _, elem := range cached.OrderBy { size += elem.CachedSize(false) } @@ -687,7 +701,7 @@ func (cached *MergeSort) CachedSize(alloc bool) int64 { } // field OrderBy vitess.io/vitess/go/vt/vtgate/evalengine.Comparison { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(48)) + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(56)) for _, elem := range cached.OrderBy { size += elem.CachedSize(false) } @@ -897,7 +911,7 @@ func (cached *Route) CachedSize(alloc bool) int64 { size += hack.RuntimeAllocSize(int64(len(cached.FieldQuery))) // field OrderBy vitess.io/vitess/go/vt/vtgate/evalengine.Comparison { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(48)) + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(56)) for _, elem := range cached.OrderBy { size += elem.CachedSize(false) } diff --git a/go/vt/vtgate/engine/distinct.go b/go/vt/vtgate/engine/distinct.go index c47cf6be8d1..189440611c3 100644 --- a/go/vt/vtgate/engine/distinct.go +++ b/go/vt/vtgate/engine/distinct.go @@ -74,14 +74,14 @@ func (pt *probeTable) hashCodeForRow(inputRow sqltypes.Row) (vthash.Hash, error) return vthash.Hash{}, vterrors.VT13001("index out of range in row when creating the DISTINCT hash code") } col := inputRow[checkCol.Col] - err := evalengine.NullsafeHashcode128(&hasher, col, checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode) + err := evalengine.NullsafeHashcode128(&hasher, col, checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode, checkCol.Type.Values()) if err != nil { if err != evalengine.UnsupportedCollationHashError || checkCol.WsCol == nil { return vthash.Hash{}, err } checkCol = checkCol.SwitchToWeightString() pt.checkCols[i] = checkCol - err = evalengine.NullsafeHashcode128(&hasher, inputRow[checkCol.Col], checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode) + err = evalengine.NullsafeHashcode128(&hasher, inputRow[checkCol.Col], checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode, checkCol.Type.Values()) if err != nil { return vthash.Hash{}, err } diff --git a/go/vt/vtgate/engine/distinct_test.go b/go/vt/vtgate/engine/distinct_test.go index cb414d8de28..d7fe8786158 100644 --- a/go/vt/vtgate/engine/distinct_test.go +++ b/go/vt/vtgate/engine/distinct_test.go @@ -90,7 +90,7 @@ func TestDistinct(t *testing.T) { } checkCols = append(checkCols, CheckCol{ Col: i, - Type: evalengine.NewTypeEx(tc.inputs.Fields[i].Type, collID, false, 0, 0), + Type: evalengine.NewTypeEx(tc.inputs.Fields[i].Type, collID, false, 0, 0, nil), CollationEnv: collations.MySQL8(), }) } diff --git a/go/vt/vtgate/engine/hash_join.go b/go/vt/vtgate/engine/hash_join.go index f7c9d87e1fb..6ac34e1ab79 100644 --- a/go/vt/vtgate/engine/hash_join.go +++ b/go/vt/vtgate/engine/hash_join.go @@ -67,6 +67,9 @@ type ( ComparisonType querypb.Type CollationEnv *collations.Environment + + // Values for enum and set types + Values *evalengine.EnumSetValues } hashJoinProbeTable struct { @@ -78,6 +81,7 @@ type ( cols []int hasher vthash.Hasher sqlmode evalengine.SQLMode + values *evalengine.EnumSetValues } probeTableEntry struct { @@ -94,7 +98,7 @@ func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma return nil, err } - pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols) + pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols, hj.Values) // build the probe table from the LHS result for _, row := range lresult.Rows { err := pt.addLeftRow(row) @@ -130,7 +134,7 @@ func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma // TryStreamExecute implements the Primitive interface func (hj *HashJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { // build the probe table from the LHS result - pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols) + pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols, hj.Values) var lfields []*querypb.Field var mu sync.Mutex err := vcursor.StreamExecutePrimitive(ctx, hj.Left, bindVars, wantfields, func(result *sqltypes.Result) error { @@ -260,7 +264,7 @@ func (hj *HashJoin) description() PrimitiveDescription { } } -func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey int, cols []int) *hashJoinProbeTable { +func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey int, cols []int, values *evalengine.EnumSetValues) *hashJoinProbeTable { return &hashJoinProbeTable{ innerMap: map[vthash.Hash]*probeTableEntry{}, coll: coll, @@ -269,6 +273,7 @@ func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey rhsKey: rhsKey, cols: cols, hasher: vthash.New(), + values: values, } } @@ -286,7 +291,7 @@ func (pt *hashJoinProbeTable) addLeftRow(r sqltypes.Row) error { } func (pt *hashJoinProbeTable) hash(val sqltypes.Value) (vthash.Hash, error) { - err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ, pt.sqlmode) + err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ, pt.sqlmode, pt.values) if err != nil { return vthash.Hash{}, err } diff --git a/go/vt/vtgate/engine/opcode/constants.go b/go/vt/vtgate/engine/opcode/constants.go index 1bdbe61fd65..28c09de0fd6 100644 --- a/go/vt/vtgate/engine/opcode/constants.go +++ b/go/vt/vtgate/engine/opcode/constants.go @@ -180,7 +180,7 @@ func (code AggregateOpcode) ResolveType(t evalengine.Type, env *collations.Envir if code == AggregateAvg { scale += 4 } - return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale) + return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale, t.Values()) } func (code AggregateOpcode) NeedsComparableValues() bool { diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index ade8cd00299..5a72bdf4501 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -344,14 +344,14 @@ func (oa *OrderedAggregate) nextGroupBy(currentKey, nextRow []sqltypes.Value) (n return nextRow, true, nil } - cmp, err := evalengine.NullsafeCompare(v1, v2, oa.CollationEnv, gb.Type.Collation()) + cmp, err := evalengine.NullsafeCompare(v1, v2, oa.CollationEnv, gb.Type.Collation(), gb.Type.Values()) if err != nil { _, isCollationErr := err.(evalengine.UnsupportedCollationError) if !isCollationErr || gb.WeightStringCol == -1 { return nil, false, err } gb.KeyCol = gb.WeightStringCol - cmp, err = evalengine.NullsafeCompare(currentKey[gb.WeightStringCol], nextRow[gb.WeightStringCol], oa.CollationEnv, gb.Type.Collation()) + cmp, err = evalengine.NullsafeCompare(currentKey[gb.WeightStringCol], nextRow[gb.WeightStringCol], oa.CollationEnv, gb.Type.Collation(), gb.Type.Values()) if err != nil { return nil, false, err } diff --git a/go/vt/vtgate/evalengine/api_aggregation.go b/go/vt/vtgate/evalengine/api_aggregation.go index 0566f477a3c..78ab8335d6d 100644 --- a/go/vt/vtgate/evalengine/api_aggregation.go +++ b/go/vt/vtgate/evalengine/api_aggregation.go @@ -448,6 +448,7 @@ type aggregationMinMax struct { current sqltypes.Value collation collations.ID collationEnv *collations.Environment + values *EnumSetValues } func (a *aggregationMinMax) minmax(value sqltypes.Value, max bool) (err error) { @@ -458,7 +459,7 @@ func (a *aggregationMinMax) minmax(value sqltypes.Value, max bool) (err error) { a.current = value return nil } - n, err := compare(a.current, value, a.collationEnv, a.collation) + n, err := compare(a.current, value, a.collationEnv, a.collation, a.values) if err != nil { return err } @@ -484,7 +485,7 @@ func (a *aggregationMinMax) Reset() { a.current = sqltypes.NULL } -func NewAggregationMinMax(typ sqltypes.Type, collationEnv *collations.Environment, collation collations.ID) MinMax { +func NewAggregationMinMax(typ sqltypes.Type, collationEnv *collations.Environment, collation collations.ID, values *EnumSetValues) MinMax { switch { case sqltypes.IsSigned(typ): return &aggregationInt{t: typ} @@ -495,6 +496,6 @@ func NewAggregationMinMax(typ sqltypes.Type, collationEnv *collations.Environmen case sqltypes.IsDecimal(typ): return &aggregationDecimal{} default: - return &aggregationMinMax{collation: collation, collationEnv: collationEnv} + return &aggregationMinMax{collation: collation, collationEnv: collationEnv, values: values} } } diff --git a/go/vt/vtgate/evalengine/api_aggregation_test.go b/go/vt/vtgate/evalengine/api_aggregation_test.go index e5dae47017e..05884b4bb4b 100644 --- a/go/vt/vtgate/evalengine/api_aggregation_test.go +++ b/go/vt/vtgate/evalengine/api_aggregation_test.go @@ -137,7 +137,7 @@ func TestMinMax(t *testing.T) { for i, tcase := range tcases { t.Run(strconv.Itoa(i), func(t *testing.T) { t.Run("Min", func(t *testing.T) { - agg := NewAggregationMinMax(tcase.type_, collations.MySQL8(), tcase.coll) + agg := NewAggregationMinMax(tcase.type_, collations.MySQL8(), tcase.coll, nil) for _, v := range tcase.values { err := agg.Min(v) @@ -153,7 +153,7 @@ func TestMinMax(t *testing.T) { }) t.Run("Max", func(t *testing.T) { - agg := NewAggregationMinMax(tcase.type_, collations.MySQL8(), tcase.coll) + agg := NewAggregationMinMax(tcase.type_, collations.MySQL8(), tcase.coll, nil) for _, v := range tcase.values { err := agg.Max(v) diff --git a/go/vt/vtgate/evalengine/api_coerce.go b/go/vt/vtgate/evalengine/api_coerce.go index 907c578df8a..eef83c58422 100644 --- a/go/vt/vtgate/evalengine/api_coerce.go +++ b/go/vt/vtgate/evalengine/api_coerce.go @@ -24,7 +24,7 @@ import ( ) func CoerceTo(value sqltypes.Value, typ Type, sqlmode SQLMode) (sqltypes.Value, error) { - cast, err := valueToEvalCast(value, value.Type(), collations.Unknown, sqlmode) + cast, err := valueToEvalCast(value, value.Type(), collations.Unknown, typ.values, sqlmode) if err != nil { return sqltypes.Value{}, err } @@ -33,7 +33,7 @@ func CoerceTo(value sqltypes.Value, typ Type, sqlmode SQLMode) (sqltypes.Value, // CoerceTypes takes two input types, and decides how they should be coerced before compared func CoerceTypes(v1, v2 Type, collationEnv *collations.Environment) (out Type, err error) { - if v1 == v2 { + if v1.Equal(&v2) { return v1, nil } if sqltypes.IsNull(v1.Type()) || sqltypes.IsNull(v2.Type()) { diff --git a/go/vt/vtgate/evalengine/api_compare.go b/go/vt/vtgate/evalengine/api_compare.go index c6278264a47..6873ad40143 100644 --- a/go/vt/vtgate/evalengine/api_compare.go +++ b/go/vt/vtgate/evalengine/api_compare.go @@ -43,7 +43,7 @@ func (err UnsupportedCollationError) Error() string { // UnsupportedCollationHashError is returned when we try to get the hash value and are missing the collation to use var UnsupportedCollationHashError = vterrors.Errorf(vtrpcpb.Code_INTERNAL, "text type with an unknown/unsupported collation cannot be hashed") -func compare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collationID collations.ID) (int, error) { +func compare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collationID collations.ID, values *EnumSetValues) (int, error) { v1t := v1.Type() // We have a fast path here for the case where both values are @@ -115,7 +115,7 @@ func compare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collat Collation: collationID, Coercibility: collations.CoerceImplicit, Repertoire: collations.RepertoireUnicode, - }) + }, values) if err != nil { return 0, err } @@ -124,7 +124,7 @@ func compare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collat Collation: collationID, Coercibility: collations.CoerceImplicit, Repertoire: collations.RepertoireUnicode, - }) + }, values) if err != nil { return 0, err } @@ -147,7 +147,7 @@ func compare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collat // numeric, then a numeric comparison is performed after // necessary conversions. If none are numeric, then it's // a simple binary comparison. Uncomparable values return an error. -func NullsafeCompare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collationID collations.ID) (int, error) { +func NullsafeCompare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collationID collations.ID, values *EnumSetValues) (int, error) { // Based on the categorization defined for the types, // we're going to allow comparison of the following: // Null, isNumber, IsBinary. This will exclude IsQuoted @@ -161,7 +161,7 @@ func NullsafeCompare(v1, v2 sqltypes.Value, collationEnv *collations.Environment if v2.IsNull() { return 1, nil } - return compare(v1, v2, collationEnv, collationID) + return compare(v1, v2, collationEnv, collationID, values) } // OrderByParams specifies the parameters for ordering. @@ -213,7 +213,7 @@ func (obp *OrderByParams) Compare(r1, r2 []sqltypes.Value) int { if cmp == 0 { var err error - cmp, err = NullsafeCompare(v1, v2, obp.CollationEnv, obp.Type.Collation()) + cmp, err = NullsafeCompare(v1, v2, obp.CollationEnv, obp.Type.Collation(), obp.Type.values) if err != nil { _, isCollationErr := err.(UnsupportedCollationError) if !isCollationErr || obp.WeightStringCol == -1 { @@ -222,7 +222,7 @@ func (obp *OrderByParams) Compare(r1, r2 []sqltypes.Value) int { // in case of a comparison or collation error switch to using the weight string column for ordering obp.Col = obp.WeightStringCol obp.WeightStringCol = -1 - cmp, err = NullsafeCompare(r1[obp.Col], r2[obp.Col], obp.CollationEnv, obp.Type.Collation()) + cmp, err = NullsafeCompare(r1[obp.Col], r2[obp.Col], obp.CollationEnv, obp.Type.Collation(), obp.Type.values) if err != nil { panic(err) } diff --git a/go/vt/vtgate/evalengine/api_compare_test.go b/go/vt/vtgate/evalengine/api_compare_test.go index aa039537240..106b111cafc 100644 --- a/go/vt/vtgate/evalengine/api_compare_test.go +++ b/go/vt/vtgate/evalengine/api_compare_test.go @@ -30,14 +30,12 @@ import ( "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vterrors" - - "vitess.io/vitess/go/sqltypes" - - querypb "vitess.io/vitess/go/vt/proto/query" ) type testCase struct { @@ -1109,11 +1107,12 @@ func TestNullComparisons(t *testing.T) { } func TestNullsafeCompare(t *testing.T) { - collation := collationEnv.LookupByName("utf8mb4_general_ci") + collation := collations.ID(collations.CollationUtf8mb4ID) tcases := []struct { v1, v2 sqltypes.Value out int err error + values *EnumSetValues }{ { v1: NULL, @@ -1140,23 +1139,60 @@ func TestNullsafeCompare(t *testing.T) { v2: TestValue(sqltypes.VarChar, " 6736380880502626304.000000 aa"), out: -1, }, + { + v1: TestValue(sqltypes.Enum, "foo"), + v2: TestValue(sqltypes.Enum, "bar"), + out: -1, + values: &EnumSetValues{"'foo'", "'bar'"}, + }, { v1: TestValue(sqltypes.Enum, "foo"), v2: TestValue(sqltypes.Enum, "bar"), out: 1, }, + { + v1: TestValue(sqltypes.Enum, "foo"), + v2: TestValue(sqltypes.VarChar, "bar"), + out: 1, + values: &EnumSetValues{"'foo'", "'bar'"}, + }, + { + v1: TestValue(sqltypes.VarChar, "foo"), + v2: TestValue(sqltypes.Enum, "bar"), + out: 1, + }, + { + v1: TestValue(sqltypes.Set, "bar"), + v2: TestValue(sqltypes.Set, "foo,bar"), + out: -1, + values: &EnumSetValues{"'foo'", "'bar'"}, + }, + { + v1: TestValue(sqltypes.Set, "bar"), + v2: TestValue(sqltypes.Set, "foo,bar"), + out: -1, + }, + { + v1: TestValue(sqltypes.VarChar, "bar"), + v2: TestValue(sqltypes.Set, "foo,bar"), + out: -1, + values: &EnumSetValues{"'foo'", "'bar'"}, + }, + { + v1: TestValue(sqltypes.Set, "bar"), + v2: TestValue(sqltypes.VarChar, "foo,bar"), + out: -1, + }, } for _, tcase := range tcases { t.Run(fmt.Sprintf("%v/%v", tcase.v1, tcase.v2), func(t *testing.T) { - got, err := NullsafeCompare(tcase.v1, tcase.v2, collations.MySQL8(), collation) + got, err := NullsafeCompare(tcase.v1, tcase.v2, collations.MySQL8(), collation, tcase.values) if tcase.err != nil { require.EqualError(t, err, tcase.err.Error()) return } require.NoError(t, err) - if got != tcase.out { - t.Errorf("NullsafeCompare(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), got, tcase.out) - } + assert.Equal(t, tcase.out, got) }) } } @@ -1237,7 +1273,7 @@ func TestNullsafeCompareCollate(t *testing.T) { } for _, tcase := range tcases { t.Run(fmt.Sprintf("%v/%v", tcase.v1, tcase.v2), func(t *testing.T) { - got, err := NullsafeCompare(TestValue(sqltypes.VarChar, tcase.v1), TestValue(sqltypes.VarChar, tcase.v2), collations.MySQL8(), tcase.collation) + got, err := NullsafeCompare(TestValue(sqltypes.VarChar, tcase.v1), TestValue(sqltypes.VarChar, tcase.v2), collations.MySQL8(), tcase.collation, nil) if tcase.err == nil { require.NoError(t, err) } else { @@ -1288,7 +1324,7 @@ func BenchmarkNullSafeComparison(b *testing.B) { for i := 0; i < b.N; i++ { for _, lhs := range inputs { for _, rhs := range inputs { - _, _ = NullsafeCompare(lhs, rhs, collations.MySQL8(), collid) + _, _ = NullsafeCompare(lhs, rhs, collations.MySQL8(), collid, nil) } } } @@ -1318,7 +1354,7 @@ func BenchmarkNullSafeComparison(b *testing.B) { for i := 0; i < b.N; i++ { for _, lhs := range inputs { for _, rhs := range inputs { - _, _ = NullsafeCompare(lhs, rhs, collations.MySQL8(), collations.CollationUtf8mb4ID) + _, _ = NullsafeCompare(lhs, rhs, collations.MySQL8(), collations.CollationUtf8mb4ID, nil) } } } diff --git a/go/vt/vtgate/evalengine/api_hash.go b/go/vt/vtgate/evalengine/api_hash.go index 2d3bc2d3b56..a5e5d1778dd 100644 --- a/go/vt/vtgate/evalengine/api_hash.go +++ b/go/vt/vtgate/evalengine/api_hash.go @@ -34,8 +34,8 @@ type HashCode = uint64 // NullsafeHashcode returns an int64 hashcode that is guaranteed to be the same // for two values that are considered equal by `NullsafeCompare`. -func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType sqltypes.Type, sqlmode SQLMode) (HashCode, error) { - e, err := valueToEvalCast(v, coerceType, collation, sqlmode) +func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType sqltypes.Type, sqlmode SQLMode, values *EnumSetValues) (HashCode, error) { + e, err := valueToEvalCast(v, coerceType, collation, values, sqlmode) if err != nil { return 0, err } @@ -75,7 +75,7 @@ var ErrHashCoercionIsNotExact = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, " // for two values that are considered equal by `NullsafeCompare`. // This can be used to avoid having to do comparison checks after a hash, // since we consider the 128 bits of entropy enough to guarantee uniqueness. -func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode) error { +func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode, values *EnumSetValues) error { switch { case v.IsNull(), sqltypes.IsNull(coerceTo): hash.Write16(hashPrefixNil) @@ -97,7 +97,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat case v.IsText(), v.IsBinary(): f, _ = fastparse.ParseFloat64(v.RawStr()) default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode, values) } if err != nil { return err @@ -137,7 +137,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat } neg = i < 0 default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode, values) } if err != nil { return err @@ -180,7 +180,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat u, err = uint64(fval), nil } default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode, values) } if err != nil { return err @@ -223,20 +223,20 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat fval, _ := fastparse.ParseFloat64(v.RawStr()) dec = decimal.NewFromFloat(fval) default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode, values) } hash.Write16(hashPrefixDecimal) dec.Hash(hash) default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode, values) } return nil } -func nullsafeHashcode128Default(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode) error { +func nullsafeHashcode128Default(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode, values *EnumSetValues) error { // Slow path to handle all other types. This uses the generic // logic for value casting to ensure we match MySQL here. - e, err := valueToEvalCast(v, coerceTo, collation, sqlmode) + e, err := valueToEvalCast(v, coerceTo, collation, values, sqlmode) if err != nil { return err } diff --git a/go/vt/vtgate/evalengine/api_hash_test.go b/go/vt/vtgate/evalengine/api_hash_test.go index 7a680892712..bb2652ec6f2 100644 --- a/go/vt/vtgate/evalengine/api_hash_test.go +++ b/go/vt/vtgate/evalengine/api_hash_test.go @@ -52,14 +52,14 @@ func TestHashCodes(t *testing.T) { for _, tc := range cases { t.Run(fmt.Sprintf("%v %s %v", tc.static, equality(tc.equal).Operator(), tc.dynamic), func(t *testing.T) { - cmp, err := NullsafeCompare(tc.static, tc.dynamic, collations.MySQL8(), collations.CollationUtf8mb4ID) + cmp, err := NullsafeCompare(tc.static, tc.dynamic, collations.MySQL8(), collations.CollationUtf8mb4ID, nil) require.NoError(t, err) require.Equalf(t, tc.equal, cmp == 0, "got %v %s %v (expected %s)", tc.static, equality(cmp == 0).Operator(), tc.dynamic, equality(tc.equal)) - h1, err := NullsafeHashcode(tc.static, collations.CollationUtf8mb4ID, tc.static.Type(), 0) + h1, err := NullsafeHashcode(tc.static, collations.CollationUtf8mb4ID, tc.static.Type(), 0, nil) require.NoError(t, err) - h2, err := NullsafeHashcode(tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type(), 0) + h2, err := NullsafeHashcode(tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type(), 0, nil) require.ErrorIs(t, err, tc.err) assert.Equalf(t, tc.equal, h1 == h2, "HASH(%v) %s HASH(%v) (expected %s)", tc.static, equality(h1 == h2).Operator(), tc.dynamic, equality(tc.equal)) @@ -77,14 +77,14 @@ func TestHashCodesRandom(t *testing.T) { for time.Now().Before(endTime) { tested++ v1, v2 := sqltypes.TestRandomValues() - cmp, err := NullsafeCompare(v1, v2, collations.MySQL8(), collation) + cmp, err := NullsafeCompare(v1, v2, collations.MySQL8(), collation, nil) require.NoErrorf(t, err, "%s compared with %s", v1.String(), v2.String()) typ, err := coerceTo(v1.Type(), v2.Type()) require.NoError(t, err) - hash1, err := NullsafeHashcode(v1, collation, typ, 0) + hash1, err := NullsafeHashcode(v1, collation, typ, 0, nil) require.NoError(t, err) - hash2, err := NullsafeHashcode(v2, collation, typ, 0) + hash2, err := NullsafeHashcode(v2, collation, typ, 0, nil) require.NoError(t, err) if cmp == 0 { equal++ @@ -137,16 +137,16 @@ func TestHashCodes128(t *testing.T) { for _, tc := range cases { t.Run(fmt.Sprintf("%v %s %v", tc.static, equality(tc.equal).Operator(), tc.dynamic), func(t *testing.T) { - cmp, err := NullsafeCompare(tc.static, tc.dynamic, collations.MySQL8(), collations.CollationUtf8mb4ID) + cmp, err := NullsafeCompare(tc.static, tc.dynamic, collations.MySQL8(), collations.CollationUtf8mb4ID, nil) require.NoError(t, err) require.Equalf(t, tc.equal, cmp == 0, "got %v %s %v (expected %s)", tc.static, equality(cmp == 0).Operator(), tc.dynamic, equality(tc.equal)) hasher1 := vthash.New() - err = NullsafeHashcode128(&hasher1, tc.static, collations.CollationUtf8mb4ID, tc.static.Type(), 0) + err = NullsafeHashcode128(&hasher1, tc.static, collations.CollationUtf8mb4ID, tc.static.Type(), 0, nil) require.NoError(t, err) hasher2 := vthash.New() - err = NullsafeHashcode128(&hasher2, tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type(), 0) + err = NullsafeHashcode128(&hasher2, tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type(), 0, nil) require.ErrorIs(t, err, tc.err) h1 := hasher1.Sum128() @@ -166,16 +166,16 @@ func TestHashCodesRandom128(t *testing.T) { for time.Now().Before(endTime) { tested++ v1, v2 := sqltypes.TestRandomValues() - cmp, err := NullsafeCompare(v1, v2, collations.MySQL8(), collation) + cmp, err := NullsafeCompare(v1, v2, collations.MySQL8(), collation, nil) require.NoErrorf(t, err, "%s compared with %s", v1.String(), v2.String()) typ, err := coerceTo(v1.Type(), v2.Type()) require.NoError(t, err) hasher1 := vthash.New() - err = NullsafeHashcode128(&hasher1, v1, collation, typ, 0) + err = NullsafeHashcode128(&hasher1, v1, collation, typ, 0, nil) require.NoError(t, err) hasher2 := vthash.New() - err = NullsafeHashcode128(&hasher2, v2, collation, typ, 0) + err = NullsafeHashcode128(&hasher2, v2, collation, typ, 0, nil) require.NoError(t, err) if cmp == 0 { equal++ diff --git a/go/vt/vtgate/evalengine/api_literal.go b/go/vt/vtgate/evalengine/api_literal.go index 64d0cf5c1c3..16897650362 100644 --- a/go/vt/vtgate/evalengine/api_literal.go +++ b/go/vt/vtgate/evalengine/api_literal.go @@ -228,6 +228,7 @@ func NewColumn(offset int, typ Type, original sqlparser.Expr) *Column { Collation: typedCoercionCollation(typ.Type(), typ.Collation()), Original: original, Nullable: typ.nullable, + Values: typ.values, dynamicTypeOffset: -1, } } diff --git a/go/vt/vtgate/evalengine/api_type_aggregation.go b/go/vt/vtgate/evalengine/api_type_aggregation.go index 326f1397369..04622e5a212 100644 --- a/go/vt/vtgate/evalengine/api_type_aggregation.go +++ b/go/vt/vtgate/evalengine/api_type_aggregation.go @@ -80,7 +80,7 @@ func (ta *TypeAggregator) Type() Type { if ta.invalid > 0 || ta.types.empty() { return Type{} } - return NewTypeEx(ta.types.result(), ta.collations.result().Collation, ta.types.nullable, ta.size, ta.scale) + return NewTypeEx(ta.types.result(), ta.collations.result().Collation, ta.types.nullable, ta.size, ta.scale, nil) } func (ta *TypeAggregator) Field(name string) *query.Field { diff --git a/go/vt/vtgate/evalengine/arena.go b/go/vt/vtgate/evalengine/arena.go index 590dc3b02c7..ccfe63f514f 100644 --- a/go/vt/vtgate/evalengine/arena.go +++ b/go/vt/vtgate/evalengine/arena.go @@ -32,6 +32,8 @@ type Arena struct { aFloat64 []evalFloat aDecimal []evalDecimal aBytes []evalBytes + aEnum []evalEnum + aSet []evalSet } func (a *Arena) reset() { @@ -40,6 +42,8 @@ func (a *Arena) reset() { a.aFloat64 = a.aFloat64[:0] a.aDecimal = a.aDecimal[:0] a.aBytes = a.aBytes[:0] + a.aEnum = a.aEnum[:0] + a.aSet = a.aSet[:0] } func (a *Arena) newEvalDecimalWithPrec(dec decimal.Decimal, prec int32) *evalDecimal { @@ -61,6 +65,32 @@ func (a *Arena) newEvalDecimal(dec decimal.Decimal, m, d int32) *evalDecimal { return a.newEvalDecimalWithPrec(dec.Clamp(m-d, d), d) } +func (a *Arena) newEvalEnum(raw []byte, values *EnumSetValues) *evalEnum { + if cap(a.aEnum) > len(a.aEnum) { + a.aEnum = a.aEnum[:len(a.aEnum)+1] + } else { + a.aEnum = append(a.aEnum, evalEnum{}) + } + val := &a.aEnum[len(a.aInt64)-1] + s := string(raw) + val.string = s + val.value = valueIdx(values, s) + return val +} + +func (a *Arena) newEvalSet(raw []byte, values *EnumSetValues) *evalSet { + if cap(a.aSet) > len(a.aSet) { + a.aSet = a.aSet[:len(a.aSet)+1] + } else { + a.aSet = append(a.aSet, evalSet{}) + } + val := &a.aSet[len(a.aInt64)-1] + s := string(raw) + val.string = s + val.set = evalSetBits(values, s) + return val +} + func (a *Arena) newEvalBool(b bool) *evalInt64 { if b { return a.newEvalInt64(1) diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index f7aca0509bd..65f0bd37d12 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -165,6 +165,14 @@ func (cached *Column) CachedSize(alloc bool) int64 { if cc, ok := cached.Original.(cachedObject); ok { size += cc.CachedSize(true) } + // field Values *vitess.io/vitess/go/vt/vtgate/evalengine.EnumSetValues + if cached.Values != nil { + size += int64(24) + size += hack.RuntimeAllocSize(int64(cap(*cached.Values)) * int64(16)) + for _, elem := range *cached.Values { + size += hack.RuntimeAllocSize(int64(len(elem))) + } + } return size } func (cached *ComparisonExpr) CachedSize(alloc bool) int64 { @@ -195,6 +203,8 @@ func (cached *CompiledExpr) CachedSize(alloc bool) int64 { { size += hack.RuntimeAllocSize(int64(cap(cached.code)) * int64(8)) } + // field typed vitess.io/vitess/go/vt/vtgate/evalengine.ctype + size += cached.typed.CachedSize(false) // field ir vitess.io/vitess/go/vt/vtgate/evalengine.IR if cc, ok := cached.ir.(cachedObject); ok { size += cc.CachedSize(true) @@ -361,8 +371,10 @@ func (cached *OrderByParams) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(48) + size += int64(64) } + // field Type vitess.io/vitess/go/vt/vtgate/evalengine.Type + size += cached.Type.CachedSize(false) // field CollationEnv *vitess.io/vitess/go/mysql/collations.Environment size += cached.CollationEnv.CachedSize(true) return size @@ -379,6 +391,24 @@ func (cached *TupleBindVariable) CachedSize(alloc bool) int64 { size += hack.RuntimeAllocSize(int64(len(cached.Key))) return size } +func (cached *Type) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(24) + } + // field values *vitess.io/vitess/go/vt/vtgate/evalengine.EnumSetValues + if cached.values != nil { + size += int64(24) + size += hack.RuntimeAllocSize(int64(cap(*cached.values)) * int64(16)) + for _, elem := range *cached.values { + size += hack.RuntimeAllocSize(int64(len(elem))) + } + } + return size +} func (cached *UnaryExpr) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -1911,6 +1941,24 @@ func (cached *builtinYearWeek) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *ctype) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(32) + } + // field Values *vitess.io/vitess/go/vt/vtgate/evalengine.EnumSetValues + if cached.Values != nil { + size += int64(24) + size += hack.RuntimeAllocSize(int64(cap(*cached.Values)) * int64(16)) + for _, elem := range *cached.Values { + size += hack.RuntimeAllocSize(int64(len(elem))) + } + } + return size +} func (cached *evalBytes) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -1937,6 +1985,18 @@ func (cached *evalDecimal) CachedSize(alloc bool) int64 { size += cached.dec.CachedSize(false) return size } +func (cached *evalEnum) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(24) + } + // field string string + size += hack.RuntimeAllocSize(int64(len(cached.string))) + return size +} func (cached *evalFloat) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -1957,6 +2017,18 @@ func (cached *evalInt64) CachedSize(alloc bool) int64 { } return size } +func (cached *evalSet) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(24) + } + // field string string + size += hack.RuntimeAllocSize(int64(len(cached.string))) + return size +} func (cached *evalTemporal) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -2006,7 +2078,10 @@ func (cached *typedExpr) CachedSize(alloc bool) int64 { } // field types []vitess.io/vitess/go/vt/vtgate/evalengine.ctype { - size += hack.RuntimeAllocSize(int64(cap(cached.types)) * int64(20)) + size += hack.RuntimeAllocSize(int64(cap(cached.types)) * int64(32)) + for _, elem := range cached.types { + size += elem.CachedSize(false) + } } // field compiled *vitess.io/vitess/go/vt/vtgate/evalengine.CompiledExpr size += cached.compiled.CachedSize(true) diff --git a/go/vt/vtgate/evalengine/compare.go b/go/vt/vtgate/evalengine/compare.go index 102d6142321..836ca7c5043 100644 --- a/go/vt/vtgate/evalengine/compare.go +++ b/go/vt/vtgate/evalengine/compare.go @@ -18,6 +18,7 @@ package evalengine import ( "bytes" + "strings" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/collations/colldata" @@ -122,6 +123,42 @@ func compareDates(l, r *evalTemporal) int { return l.dt.Compare(r.dt) } +func compareEnums(l, r *evalEnum) int { + if l.value == -1 || r.value == -1 { + // If the values are equal normally the strings + // are equal too. In case we didn't find the proper + // value in the enum we return the string comparison. + // This is not always correct, but a best effort and still + // works for the cases where we only care about + // equality. + return strings.Compare(l.string, r.string) + } + if l.value == r.value { + return 0 + } + if l.value < r.value { + return -1 + } + return 1 +} + +func compareSets(l, r *evalSet) int { + if l.set == r.set { + if l.set == 0 && (len(l.string) != 0 || len(r.string) != 0) { + // In this case we didn't have the proper values passed + // in when creating the evalSet. We can't compare the set + // values then, but fall back to string comparison to at + // least compare something and to handle equality checks. + return strings.Compare(l.string, r.string) + } + return 0 + } + if l.set < r.set { + return -1 + } + return 1 +} + func compareDateAndString(l, r eval) int { if tt, ok := l.(*evalTemporal); ok { return tt.dt.Compare(r.(*evalBytes).toDateBestEffort()) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 21d13119804..d9de15aa571 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -17,6 +17,8 @@ limitations under the License. package evalengine import ( + "slices" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/collations/charset" "vitess.io/vitess/go/mysql/collations/colldata" @@ -50,11 +52,14 @@ type compiledCoercion struct { right colldata.Coercion } +type EnumSetValues []string + type ctype struct { Type sqltypes.Type Flag typeFlag Size, Scale int32 Col collations.TypedCollation + Values *EnumSetValues } type Type struct { @@ -63,14 +68,25 @@ type Type struct { nullable bool init bool size, scale int32 + values *EnumSetValues +} + +func (v *EnumSetValues) Equal(other *EnumSetValues) bool { + if v == nil && other == nil { + return true + } + if v == nil || other == nil { + return false + } + return slices.Equal(*v, *other) } func NewType(t sqltypes.Type, collation collations.ID) Type { // New types default to being nullable - return NewTypeEx(t, collation, true, 0, 0) + return NewTypeEx(t, collation, true, 0, 0, nil) } -func NewTypeEx(t sqltypes.Type, collation collations.ID, nullable bool, size, scale int32) Type { +func NewTypeEx(t sqltypes.Type, collation collations.ID, nullable bool, size, scale int32, values *EnumSetValues) Type { return Type{ typ: t, collation: collation, @@ -78,6 +94,7 @@ func NewTypeEx(t sqltypes.Type, collation collations.ID, nullable bool, size, sc init: true, size: size, scale: scale, + values: values, } } @@ -139,19 +156,41 @@ func (t *Type) Nullable() bool { return true // nullable by default for unknown types } +func (t *Type) Values() *EnumSetValues { + return t.values +} + func (t *Type) Valid() bool { return t.init } -func (ct ctype) nullable() bool { +func (t *Type) Equal(other *Type) bool { + return t.typ == other.typ && + t.collation == other.collation && + t.nullable == other.nullable && + t.size == other.size && + t.scale == other.scale && + t.values.Equal(other.values) +} + +func (ct *ctype) equal(other ctype) bool { + return ct.Type == other.Type && + ct.Flag == other.Flag && + ct.Size == other.Size && + ct.Scale == other.Scale && + ct.Col == other.Col && + ct.Values.Equal(other.Values) +} + +func (ct *ctype) nullable() bool { return ct.Flag&flagNullable != 0 } -func (ct ctype) isTextual() bool { +func (ct *ctype) isTextual() bool { return sqltypes.IsTextOrBinary(ct.Type) } -func (ct ctype) isHexOrBitLiteral() bool { +func (ct *ctype) isHexOrBitLiteral() bool { return ct.Flag&flagBit != 0 || ct.Flag&flagHex != 0 } diff --git a/go/vt/vtgate/evalengine/compiler_asm_push.go b/go/vt/vtgate/evalengine/compiler_asm_push.go index ab1371f1e11..87d2ee9af9b 100644 --- a/go/vt/vtgate/evalengine/compiler_asm_push.go +++ b/go/vt/vtgate/evalengine/compiler_asm_push.go @@ -105,6 +105,18 @@ func push_d(env *ExpressionEnv, raw []byte) int { return 1 } +func push_enum(env *ExpressionEnv, raw []byte, values *EnumSetValues) int { + env.vm.stack[env.vm.sp] = env.vm.arena.newEvalEnum(raw, values) + env.vm.sp++ + return 1 +} + +func push_set(env *ExpressionEnv, raw []byte, values *EnumSetValues) int { + env.vm.stack[env.vm.sp] = env.vm.arena.newEvalSet(raw, values) + env.vm.sp++ + return 1 +} + func (asm *assembler) PushColumn_d(offset int) { asm.adjustStack(1) @@ -117,6 +129,30 @@ func (asm *assembler) PushColumn_d(offset int) { }, "PUSH DECIMAL(:%d)", offset) } +func (asm *assembler) PushColumn_enum(offset int, values *EnumSetValues) { + asm.adjustStack(1) + + asm.emit(func(env *ExpressionEnv) int { + col := env.Row[offset] + if col.IsNull() { + return push_null(env) + } + return push_enum(env, col.Raw(), values) + }, "PUSH ENUM(:%d)", offset) +} + +func (asm *assembler) PushColumn_set(offset int, values *EnumSetValues) { + asm.adjustStack(1) + + asm.emit(func(env *ExpressionEnv) int { + col := env.Row[offset] + if col.IsNull() { + return push_null(env) + } + return push_set(env, col.Raw(), values) + }, "PUSH SET(:%d)", offset) +} + func (asm *assembler) PushBVar_d(key string) { asm.adjustStack(1) diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index 36ce482d967..90b1add541a 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -212,7 +212,7 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, all } } -func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.ID, sqlmode SQLMode) (eval, error) { +func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.ID, values *EnumSetValues, sqlmode SQLMode) (eval, error) { switch { case typ == sqltypes.Null: return nil, nil @@ -232,7 +232,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I fval, _ := fastparse.ParseFloat64(v.RawStr()) return newEvalFloat(fval), nil default: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -259,7 +259,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I fval, _ := fastparse.ParseFloat64(v.RawStr()) dec = decimal.NewFromFloat(fval) default: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -279,7 +279,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I i, err := fastparse.ParseInt64(v.RawStr(), 10) return newEvalInt64(i), err default: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -298,7 +298,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I u, err := fastparse.ParseUint64(v.RawStr(), 10) return newEvalUint64(u), err default: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -311,13 +311,13 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I case v.IsText() || v.IsBinary(): return newEvalRaw(v.Type(), v.Raw(), typedCoercionCollation(v.Type(), collation)), nil case sqltypes.IsText(typ): - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } return evalToVarchar(e, collation, true) default: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -327,7 +327,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I case typ == sqltypes.TypeJSON: return json.NewFromSQL(v) case typ == sqltypes.Date: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -338,7 +338,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I } return d, nil case typ == sqltypes.Datetime || typ == sqltypes.Timestamp: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -349,7 +349,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I } return dt, nil case typ == sqltypes.Time: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -359,11 +359,15 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I return nil, nil } return t, nil + case typ == sqltypes.Enum: + return newEvalEnum(v.Raw(), values), nil + case typ == sqltypes.Set: + return newEvalSet(v.Raw(), values), nil } return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value: %v", v) } -func valueToEval(value sqltypes.Value, collation collations.TypedCollation) (eval, error) { +func valueToEval(value sqltypes.Value, collation collations.TypedCollation, values *EnumSetValues) (eval, error) { wrap := func(err error) error { if err == nil { return nil @@ -384,6 +388,10 @@ func valueToEval(value sqltypes.Value, collation collations.TypedCollation) (eva case tt == sqltypes.Decimal: dec, err := decimal.NewFromMySQL(value.Raw()) return newEvalDecimal(dec, 0, 0), wrap(err) + case tt == sqltypes.Enum: + return newEvalEnum(value.Raw(), values), nil + case tt == sqltypes.Set: + return newEvalSet(value.Raw(), values), nil case sqltypes.IsText(tt): if tt == sqltypes.HexNum { raw, err := parseHexNumber(value.Raw()) diff --git a/go/vt/vtgate/evalengine/eval_enum.go b/go/vt/vtgate/evalengine/eval_enum.go new file mode 100644 index 00000000000..a0d349314da --- /dev/null +++ b/go/vt/vtgate/evalengine/eval_enum.go @@ -0,0 +1,40 @@ +package evalengine + +import ( + "vitess.io/vitess/go/hack" + "vitess.io/vitess/go/sqltypes" +) + +type evalEnum struct { + value int + string string +} + +func newEvalEnum(val []byte, values *EnumSetValues) *evalEnum { + s := string(val) + return &evalEnum{ + value: valueIdx(values, s), + string: s, + } +} + +func (e *evalEnum) ToRawBytes() []byte { + return hack.StringBytes(e.string) +} + +func (e *evalEnum) SQLType() sqltypes.Type { + return sqltypes.Enum +} + +func valueIdx(values *EnumSetValues, value string) int { + if values == nil { + return -1 + } + for i, v := range *values { + v, _ = sqltypes.DecodeStringSQL(v) + if v == value { + return i + } + } + return -1 +} diff --git a/go/vt/vtgate/evalengine/eval_numeric.go b/go/vt/vtgate/evalengine/eval_numeric.go index fb34caab85d..64f5477a3fc 100644 --- a/go/vt/vtgate/evalengine/eval_numeric.go +++ b/go/vt/vtgate/evalengine/eval_numeric.go @@ -149,6 +149,10 @@ func evalToNumeric(e eval, preciseDatetime bool) evalNumeric { return newEvalDecimalWithPrec(e.toDecimal(), int32(e.prec)) } return &evalFloat{f: e.toFloat()} + case *evalEnum: + return &evalFloat{f: float64(e.value)} + case *evalSet: + return &evalFloat{f: float64(e.set)} default: panic("unsupported") } @@ -205,6 +209,10 @@ func evalToFloat(e eval) (*evalFloat, bool) { } case *evalTemporal: return &evalFloat{f: e.toFloat()}, true + case *evalEnum: + return &evalFloat{f: float64(e.value)}, e.value != -1 + case *evalSet: + return &evalFloat{f: float64(e.set)}, true default: panic(fmt.Sprintf("unsupported type %T", e)) } @@ -269,6 +277,10 @@ func evalToDecimal(e eval, m, d int32) *evalDecimal { } case *evalTemporal: return newEvalDecimal(e.toDecimal(), m, d) + case *evalEnum: + return newEvalDecimal(decimal.NewFromInt(int64(e.value)), m, d) + case *evalSet: + return newEvalDecimal(decimal.NewFromUint(e.set), m, d) default: panic("unsupported") } @@ -332,6 +344,10 @@ func evalToInt64(e eval) *evalInt64 { } case *evalTemporal: return newEvalInt64(e.toInt64()) + case *evalEnum: + return newEvalInt64(int64(e.value)) + case *evalSet: + return newEvalInt64(int64(e.set)) default: panic(fmt.Sprintf("unsupported type: %T", e)) } diff --git a/go/vt/vtgate/evalengine/eval_set.go b/go/vt/vtgate/evalengine/eval_set.go new file mode 100644 index 00000000000..6a9de2eff14 --- /dev/null +++ b/go/vt/vtgate/evalengine/eval_set.go @@ -0,0 +1,49 @@ +package evalengine + +import ( + "strings" + + "vitess.io/vitess/go/hack" + "vitess.io/vitess/go/sqltypes" +) + +type evalSet struct { + set uint64 + string string +} + +func newEvalSet(val []byte, values *EnumSetValues) *evalSet { + value := string(val) + + return &evalSet{ + set: evalSetBits(values, value), + string: value, + } +} + +func (e *evalSet) ToRawBytes() []byte { + return hack.StringBytes(e.string) +} + +func (e *evalSet) SQLType() sqltypes.Type { + return sqltypes.Set +} + +func evalSetBits(values *EnumSetValues, value string) uint64 { + if values != nil && len(*values) > 64 { + // This never would happen as MySQL limits SET + // to 64 elements. Safeguard here just in case though. + panic("too many values for set") + } + + set := uint64(0) + for _, val := range strings.Split(value, ",") { + idx := valueIdx(values, val) + if idx == -1 { + continue + } + set |= 1 << idx + } + + return set +} diff --git a/go/vt/vtgate/evalengine/expr_bvar.go b/go/vt/vtgate/evalengine/expr_bvar.go index b21ded90189..0fffe3140a2 100644 --- a/go/vt/vtgate/evalengine/expr_bvar.go +++ b/go/vt/vtgate/evalengine/expr_bvar.go @@ -70,7 +70,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) { tuple := make([]eval, 0, len(bvar.Values)) for _, value := range bvar.Values { - e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), typedCoercionCollation(value.Type, collations.CollationForType(value.Type, bv.Collation))) + e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), typedCoercionCollation(value.Type, collations.CollationForType(value.Type, bv.Collation)), nil) if err != nil { return nil, err } @@ -86,7 +86,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) { if bv.typed() { typ = bv.Type } - return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), typedCoercionCollation(typ, collations.CollationForType(typ, bv.Collation))) + return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), typedCoercionCollation(typ, collations.CollationForType(typ, bv.Collation)), nil) } } diff --git a/go/vt/vtgate/evalengine/expr_column.go b/go/vt/vtgate/evalengine/expr_column.go index 8663370f819..d53585ceb8b 100644 --- a/go/vt/vtgate/evalengine/expr_column.go +++ b/go/vt/vtgate/evalengine/expr_column.go @@ -34,6 +34,7 @@ type ( Collation collations.TypedCollation Original sqlparser.Expr Nullable bool + Values *EnumSetValues // For ENUM and SET types // dynamicTypeOffset is set when the type of this column cannot be calculated // at translation time. Since expressions with dynamic types cannot be compiled ahead of time, @@ -54,7 +55,7 @@ func (c *Column) IsExpr() {} // eval implements the expression interface func (c *Column) eval(env *ExpressionEnv) (eval, error) { - return valueToEval(env.Row[c.Offset], c.Collation) + return valueToEval(env.Row[c.Offset], c.Collation, c.Values) } func (c *Column) typeof(env *ExpressionEnv) (ctype, error) { @@ -63,7 +64,7 @@ func (c *Column) typeof(env *ExpressionEnv) (ctype, error) { if c.Nullable { nullable = flagNullable } - return ctype{Type: c.Type, Size: c.Size, Scale: c.Scale, Flag: nullable, Col: c.Collation}, nil + return ctype{Type: c.Type, Size: c.Size, Scale: c.Scale, Flag: nullable, Col: c.Collation, Values: c.Values}, nil } if c.Offset < len(env.Fields) { field := env.Fields[c.Offset] @@ -83,7 +84,7 @@ func (c *Column) typeof(env *ExpressionEnv) (ctype, error) { } if c.Offset < len(env.Row) { value := env.Row[c.Offset] - return ctype{Type: value.Type(), Flag: 0, Col: c.Collation}, nil + return ctype{Type: value.Type(), Flag: 0, Col: c.Collation, Values: c.Values}, nil } return ctype{}, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "no column at offset %d", c.Offset) } @@ -99,6 +100,7 @@ func (column *Column) compile(c *compiler) (ctype, error) { } typ.Size = column.Size typ.Scale = column.Scale + typ.Values = column.Values } else if c.dynamicTypes != nil { typ = c.dynamicTypes[column.dynamicTypeOffset] } else { @@ -121,6 +123,10 @@ func (column *Column) compile(c *compiler) (ctype, error) { typ.Type = sqltypes.Float64 case sqltypes.IsDecimal(tt): c.asm.PushColumn_d(column.Offset) + case tt == sqltypes.Enum: + c.asm.PushColumn_enum(column.Offset, column.Values) + case tt == sqltypes.Set: + c.asm.PushColumn_set(column.Offset, column.Values) case sqltypes.IsText(tt): if tt == sqltypes.HexNum { c.asm.PushColumn_hexnum(column.Offset) diff --git a/go/vt/vtgate/evalengine/expr_compare.go b/go/vt/vtgate/evalengine/expr_compare.go index ca4cdd75f74..f3bd44588ee 100644 --- a/go/vt/vtgate/evalengine/expr_compare.go +++ b/go/vt/vtgate/evalengine/expr_compare.go @@ -114,7 +114,7 @@ func (compareNullSafeEQ) compare(collationEnv *collations.Environment, left, rig } func typeIsTextual(tt sqltypes.Type) bool { - return sqltypes.IsTextOrBinary(tt) || tt == sqltypes.Time + return sqltypes.IsTextOrBinary(tt) || tt == sqltypes.Time || tt == sqltypes.Enum || tt == sqltypes.Set } func compareAsStrings(l, r sqltypes.Type) bool { @@ -143,6 +143,14 @@ func compareAsDates(l, r sqltypes.Type) bool { return sqltypes.IsDateOrTime(l) && sqltypes.IsDateOrTime(r) } +func compareAsEnums(l, r sqltypes.Type) bool { + return sqltypes.IsEnum(l) && sqltypes.IsEnum(r) +} + +func compareAsSets(l, r sqltypes.Type) bool { + return sqltypes.IsSet(l) && sqltypes.IsSet(r) +} + func compareAsDateAndString(l, r sqltypes.Type) bool { return (sqltypes.IsDate(l) && typeIsTextual(r)) || (typeIsTextual(l) && sqltypes.IsDate(r)) } @@ -223,6 +231,10 @@ func evalCompare(left, right eval, collationEnv *collations.Environment) (comp i switch { case compareAsDates(lt, rt): return compareDates(left.(*evalTemporal), right.(*evalTemporal)), nil + case compareAsEnums(lt, rt): + return compareEnums(left.(*evalEnum), right.(*evalEnum)), nil + case compareAsSets(lt, rt): + return compareSets(left.(*evalSet), right.(*evalSet)), nil case compareAsStrings(lt, rt): return compareStrings(left, right, collationEnv) case compareAsSameNumericType(lt, rt) || compareAsDecimal(lt, rt): diff --git a/go/vt/vtgate/evalengine/expr_env.go b/go/vt/vtgate/evalengine/expr_env.go index 6e09b03cffb..38a65f9b4e0 100644 --- a/go/vt/vtgate/evalengine/expr_env.go +++ b/go/vt/vtgate/evalengine/expr_env.go @@ -104,7 +104,7 @@ func (env *ExpressionEnv) TypeOf(expr Expr) (Type, error) { if err != nil { return Type{}, err } - return NewTypeEx(ty.Type, ty.Col.Collation, ty.Flag&flagNullable != 0, ty.Size, ty.Scale), nil + return NewTypeEx(ty.Type, ty.Col.Collation, ty.Flag&flagNullable != 0, ty.Size, ty.Scale, ty.Values), nil } func (env *ExpressionEnv) SetTime(now time.Time) { diff --git a/go/vt/vtgate/evalengine/expr_tuple_bvar.go b/go/vt/vtgate/evalengine/expr_tuple_bvar.go index 3b2553f25ba..14cfbd95a8b 100644 --- a/go/vt/vtgate/evalengine/expr_tuple_bvar.go +++ b/go/vt/vtgate/evalengine/expr_tuple_bvar.go @@ -71,7 +71,7 @@ func (bv *TupleBindVariable) eval(env *ExpressionEnv) (eval, error) { return } found = true - e, err := valueToEval(val, typedCoercionCollation(val.Type(), collations.CollationForType(val.Type(), bv.Collation))) + e, err := valueToEval(val, typedCoercionCollation(val.Type(), collations.CollationForType(val.Type(), bv.Collation)), nil) if err != nil { evalErr = err return diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index d1c32b113c2..99ffd956513 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -686,7 +686,9 @@ func (u *UntypedExpr) loadTypedExpression(env *ExpressionEnv) (*typedExpr, error defer u.mu.Unlock() for _, typed := range u.typed { - if slices.Equal(typed.types, dynamicTypes) { + if slices.EqualFunc(typed.types, dynamicTypes, func(a, b ctype) bool { + return a.equal(b) + }) { return typed, nil } } diff --git a/go/vt/vtgate/evalengine/weights.go b/go/vt/vtgate/evalengine/weights.go index 2a9d6c9f93e..3eb9aa290c5 100644 --- a/go/vt/vtgate/evalengine/weights.go +++ b/go/vt/vtgate/evalengine/weights.go @@ -41,11 +41,11 @@ import ( // externally communicates with the `WEIGHT_STRING` function, so that we // can also use this to order / sort other types like Float and Decimal // as well. -func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, sqlmode SQLMode) ([]byte, bool, error) { +func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, values *EnumSetValues, sqlmode SQLMode) ([]byte, bool, error) { // We optimize here for the case where we already have the desired type. // Otherwise, we fall back to the general evalengine conversion logic. if v.Type() != coerceTo { - return fallbackWeightString(dst, v, coerceTo, col, length, precision, sqlmode) + return fallbackWeightString(dst, v, coerceTo, col, length, precision, values, sqlmode) } switch { @@ -116,13 +116,17 @@ func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col coll return dst, false, err } return j.WeightString(dst), false, nil + case coerceTo == sqltypes.Enum: + return evalWeightString(dst, newEvalEnum(v.Raw(), values), length, precision) + case coerceTo == sqltypes.Set: + return evalWeightString(dst, newEvalSet(v.Raw(), values), length, precision) default: - return fallbackWeightString(dst, v, coerceTo, col, length, precision, sqlmode) + return fallbackWeightString(dst, v, coerceTo, col, length, precision, values, sqlmode) } } -func fallbackWeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, sqlmode SQLMode) ([]byte, bool, error) { - e, err := valueToEvalCast(v, coerceTo, col, sqlmode) +func fallbackWeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, values *EnumSetValues, sqlmode SQLMode) ([]byte, bool, error) { + e, err := valueToEvalCast(v, coerceTo, col, values, sqlmode) if err != nil { return dst, false, err } @@ -174,6 +178,14 @@ func evalWeightString(dst []byte, e eval, length, precision int) ([]byte, bool, return e.dt.WeightString(dst), true, nil case *evalJSON: return e.WeightString(dst), false, nil + case *evalEnum: + raw := uint64(e.value) + raw = raw ^ (1 << 63) + return binary.BigEndian.AppendUint64(dst, raw), true, nil + case *evalSet: + raw := e.set + raw = raw ^ (1 << 63) + return binary.BigEndian.AppendUint64(dst, raw), true, nil } return dst, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected type %v", e.SQLType()) @@ -192,7 +204,7 @@ func TinyWeighter(f *querypb.Field, collation collations.ID) func(v *sqltypes.Va case sqltypes.IsNull(f.Type): return nil - case sqltypes.IsSigned(f.Type): + case sqltypes.IsSigned(f.Type), f.Type == sqltypes.Enum, f.Type == sqltypes.Set: return func(v *sqltypes.Value) { i, err := v.ToInt64() if err != nil { @@ -301,7 +313,6 @@ func TinyWeighter(f *querypb.Field, collation collations.ID) func(v *sqltypes.Va copy(w32[:4], j.WeightString(nil)) v.SetTinyWeight(binary.BigEndian.Uint32(w32[:4])) } - default: return nil } diff --git a/go/vt/vtgate/evalengine/weights_test.go b/go/vt/vtgate/evalengine/weights_test.go index 9a34e6e9e81..95764d3c3a4 100644 --- a/go/vt/vtgate/evalengine/weights_test.go +++ b/go/vt/vtgate/evalengine/weights_test.go @@ -32,11 +32,12 @@ func TestTinyWeightStrings(t *testing.T) { const Length = 10000 var cases = []struct { - typ sqltypes.Type - gen func() sqltypes.Value - col collations.ID - len int - prec int + typ sqltypes.Type + gen func() sqltypes.Value + col collations.ID + len int + prec int + values *EnumSetValues }{ {typ: sqltypes.Int32, gen: sqltypes.RandomGenerators[sqltypes.Int32], col: collations.CollationBinaryID}, {typ: sqltypes.Int64, gen: sqltypes.RandomGenerators[sqltypes.Int64], col: collations.CollationBinaryID}, @@ -47,6 +48,8 @@ func TestTinyWeightStrings(t *testing.T) { {typ: sqltypes.VarBinary, gen: sqltypes.RandomGenerators[sqltypes.VarBinary], col: collations.CollationBinaryID}, {typ: sqltypes.Decimal, gen: sqltypes.RandomGenerators[sqltypes.Decimal], col: collations.CollationBinaryID, len: 20, prec: 10}, {typ: sqltypes.TypeJSON, gen: sqltypes.RandomGenerators[sqltypes.TypeJSON], col: collations.CollationBinaryID}, + {typ: sqltypes.Enum, gen: sqltypes.RandomGenerators[sqltypes.Enum], col: collations.CollationBinaryID, values: &EnumSetValues{"'xxsmall'", "'xsmall'", "'small'", "'medium'", "'large'", "'xlarge'", "'xxlarge'"}}, + {typ: sqltypes.Set, gen: sqltypes.RandomGenerators[sqltypes.Set], col: collations.CollationBinaryID, values: &EnumSetValues{"'a'", "'b'", "'c'", "'d'", "'e'", "'f'", "'g'"}}, } for _, tc := range cases { @@ -77,7 +80,7 @@ func TestTinyWeightStrings(t *testing.T) { return cmp } - cmp, err := NullsafeCompare(a, b, collations.MySQL8(), tc.col) + cmp, err := NullsafeCompare(a, b, collations.MySQL8(), tc.col, tc.values) require.NoError(t, err) fullComparisons++ @@ -88,7 +91,7 @@ func TestTinyWeightStrings(t *testing.T) { a := items[i] b := items[i+1] - cmp, err := NullsafeCompare(a, b, collations.MySQL8(), tc.col) + cmp, err := NullsafeCompare(a, b, collations.MySQL8(), tc.col, tc.values) require.NoError(t, err) if cmp > 0 { @@ -110,12 +113,13 @@ func TestWeightStrings(t *testing.T) { } var cases = []struct { - name string - gen func() sqltypes.Value - types []sqltypes.Type - col collations.ID - len int - prec int + name string + gen func() sqltypes.Value + types []sqltypes.Type + col collations.ID + len int + prec int + values *EnumSetValues }{ {name: "int64", gen: sqltypes.RandomGenerators[sqltypes.Int64], types: []sqltypes.Type{sqltypes.Int64, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, {name: "uint64", gen: sqltypes.RandomGenerators[sqltypes.Uint64], types: []sqltypes.Type{sqltypes.Uint64, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, @@ -128,6 +132,8 @@ func TestWeightStrings(t *testing.T) { {name: "datetime", gen: sqltypes.RandomGenerators[sqltypes.Datetime], types: []sqltypes.Type{sqltypes.Datetime, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, {name: "timestamp", gen: sqltypes.RandomGenerators[sqltypes.Timestamp], types: []sqltypes.Type{sqltypes.Timestamp, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, {name: "time", gen: sqltypes.RandomGenerators[sqltypes.Time], types: []sqltypes.Type{sqltypes.Time, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, + {name: "enum", gen: sqltypes.RandomGenerators[sqltypes.Enum], types: []sqltypes.Type{sqltypes.Enum, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID, values: &EnumSetValues{"'xxsmall'", "'xsmall'", "'small'", "'medium'", "'large'", "'xlarge'", "'xxlarge'"}}, + {name: "set", gen: sqltypes.RandomGenerators[sqltypes.Set], types: []sqltypes.Type{sqltypes.Set, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID, values: &EnumSetValues{"'a'", "'b'", "'c'", "'d'", "'e'", "'f'", "'g'"}}, } for _, tc := range cases { @@ -136,7 +142,7 @@ func TestWeightStrings(t *testing.T) { items := make([]item, 0, Length) for i := 0; i < Length; i++ { v := tc.gen() - w, _, err := WeightString(nil, v, typ, tc.col, tc.len, tc.prec, 0) + w, _, err := WeightString(nil, v, typ, tc.col, tc.len, tc.prec, tc.values, 0) require.NoError(t, err) items = append(items, item{value: v, weight: string(w)}) @@ -156,9 +162,9 @@ func TestWeightStrings(t *testing.T) { a := items[i] b := items[i+1] - v1, err := valueToEvalCast(a.value, typ, tc.col, 0) + v1, err := valueToEvalCast(a.value, typ, tc.col, tc.values, 0) require.NoError(t, err) - v2, err := valueToEvalCast(b.value, typ, tc.col, 0) + v2, err := valueToEvalCast(b.value, typ, tc.col, tc.values, 0) require.NoError(t, err) cmp, err := evalCompareNullSafe(v1, v2, collations.MySQL8()) diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 572afa42f72..2a7f37a258f 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -918,6 +918,7 @@ func transformHashJoin(ctx *plancontext.PlanningContext, op *operators.HashJoin) Collation: comparisonType.Collation(), ComparisonType: comparisonType.Type(), CollationEnv: ctx.VSchema.Environment().CollationEnv(), + Values: comparisonType.Values(), }, }, nil } diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 6c6e495b33d..6c89b2bb999 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -667,7 +667,7 @@ func (st *SemTable) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) { ws, isWS := e.(*sqlparser.WeightStringFuncExpr) if isWS { wt, _ := st.TypeForExpr(ws.Expr) - return evalengine.NewTypeEx(sqltypes.VarBinary, collations.CollationBinaryID, wt.Nullable(), 0, 0), true + return evalengine.NewTypeEx(sqltypes.VarBinary, collations.CollationBinaryID, wt.Nullable(), 0, 0, nil), true } return evalengine.Type{}, false diff --git a/go/vt/vtgate/vindexes/consistent_lookup.go b/go/vt/vtgate/vindexes/consistent_lookup.go index f32adc0f772..d231f358a37 100644 --- a/go/vt/vtgate/vindexes/consistent_lookup.go +++ b/go/vt/vtgate/vindexes/consistent_lookup.go @@ -412,7 +412,7 @@ func (lu *clCommon) Delete(ctx context.Context, vcursor VCursor, rowsColValues [ func (lu *clCommon) Update(ctx context.Context, vcursor VCursor, oldValues []sqltypes.Value, ksid []byte, newValues []sqltypes.Value) error { equal := true for i := range oldValues { - result, err := evalengine.NullsafeCompare(oldValues[i], newValues[i], vcursor.Environment().CollationEnv(), vcursor.ConnCollation()) + result, err := evalengine.NullsafeCompare(oldValues[i], newValues[i], vcursor.Environment().CollationEnv(), vcursor.ConnCollation(), nil) // errors from NullsafeCompare can be ignored. if they are real problems, we'll see them in the Create/Update if err != nil || result != 0 { equal = false diff --git a/go/vt/vtgate/vindexes/vschema.go b/go/vt/vtgate/vindexes/vschema.go index 8dc889fc848..8e5e8b547a6 100644 --- a/go/vt/vtgate/vindexes/vschema.go +++ b/go/vt/vtgate/vindexes/vschema.go @@ -25,6 +25,7 @@ import ( "strings" "time" + "vitess.io/vitess/go/ptr" "vitess.io/vitess/go/vt/topotools" "vitess.io/vitess/go/json2" @@ -233,7 +234,7 @@ func (col *Column) ToEvalengineType(collationEnv *collations.Environment) evalen } else { collation = collations.CollationForType(col.Type, collationEnv.DefaultConnectionCharset()) } - return evalengine.NewTypeEx(col.Type, collation, col.Nullable, col.Size, col.Scale) + return evalengine.NewTypeEx(col.Type, collation, col.Nullable, col.Size, col.Scale, ptr.Of(evalengine.EnumSetValues(col.Values))) } // KeyspaceSchema contains the schema(table) for a keyspace. diff --git a/go/vt/vttablet/tabletmanager/vdiff/table_differ.go b/go/vt/vttablet/tabletmanager/vdiff/table_differ.go index 142e79c40d0..12ab13b0e42 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/table_differ.go +++ b/go/vt/vttablet/tabletmanager/vdiff/table_differ.go @@ -701,7 +701,7 @@ func (td *tableDiffer) compare(sourceRow, targetRow []sqltypes.Value, cols []com if collationID == collations.Unknown { collationID = collations.CollationBinaryID } - c, err = evalengine.NullsafeCompare(sourceRow[compareIndex], targetRow[compareIndex], td.wd.collationEnv, collationID) + c, err = evalengine.NullsafeCompare(sourceRow[compareIndex], targetRow[compareIndex], td.wd.collationEnv, collationID, nil) if err != nil { return 0, err } diff --git a/go/vt/vttablet/tabletmanager/vreplication/replicator_plan.go b/go/vt/vttablet/tabletmanager/vreplication/replicator_plan.go index 424daad4871..d4b733b4c0b 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/replicator_plan.go +++ b/go/vt/vttablet/tabletmanager/vreplication/replicator_plan.go @@ -303,7 +303,7 @@ func (tp *TablePlan) isOutsidePKRange(bindvars map[string]*querypb.BindVariable, rowVal, _ := sqltypes.BindVariableToValue(bindvar) // TODO(king-11) make collation aware - result, err := evalengine.NullsafeCompare(rowVal, tp.Lastpk.Rows[0][0], tp.CollationEnv, collations.Unknown) + result, err := evalengine.NullsafeCompare(rowVal, tp.Lastpk.Rows[0][0], tp.CollationEnv, collations.Unknown, nil) // If rowVal is > last pk, transaction will be a noop, so don't apply this statement if err == nil && result > 0 { tp.Stats.NoopQueryCount.Add(stmtType, 1) diff --git a/go/vt/vttablet/tabletserver/schema/load_table.go b/go/vt/vttablet/tabletserver/schema/load_table.go index e4e464f3fce..6022f8724eb 100644 --- a/go/vt/vttablet/tabletserver/schema/load_table.go +++ b/go/vt/vttablet/tabletserver/schema/load_table.go @@ -215,7 +215,7 @@ func getSpecifiedMessageFields(tableFields []*querypb.Field, specifiedCols []str fields := make([]*querypb.Field, 0, len(specifiedCols)) for _, col := range specifiedCols { for _, field := range tableFields { - if res, _ := evalengine.NullsafeCompare(sqltypes.NewVarChar(field.Name), sqltypes.NewVarChar(strings.TrimSpace(col)), collationEnv, collationEnv.DefaultConnectionCharset()); res == 0 { + if res, _ := evalengine.NullsafeCompare(sqltypes.NewVarChar(field.Name), sqltypes.NewVarChar(strings.TrimSpace(col)), collationEnv, collationEnv.DefaultConnectionCharset(), nil); res == 0 { fields = append(fields, field) break } diff --git a/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go b/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go index c3e1975c0a1..ad2f218f8d1 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go +++ b/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go @@ -172,7 +172,7 @@ func compare(comparison Opcode, columnValue, filterValue sqltypes.Value, collati } // at this point neither values can be null // NullsafeCompare returns 0 if values match, -1 if columnValue < filterValue, 1 if columnValue > filterValue - result, err := evalengine.NullsafeCompare(columnValue, filterValue, collationEnv, charset) + result, err := evalengine.NullsafeCompare(columnValue, filterValue, collationEnv, charset, nil) if err != nil { return false, err } diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index 8145d1c9e51..2e9529070f6 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -31,6 +31,7 @@ import ( "vitess.io/vitess/go/mysql/replication" "vitess.io/vitess/go/mysql/sqlerror" + "vitess.io/vitess/go/ptr" "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/mysql/collations" @@ -116,9 +117,10 @@ type vdiff struct { // compareColInfo contains the metadata for a column of the table being diffed type compareColInfo struct { - colIndex int // index of the column in the filter's select - collation collations.ID // is the collation of the column, if any - isPK bool // is this column part of the primary key + colIndex int // index of the column in the filter's select + collation collations.ID // is the collation of the column, if any + values *evalengine.EnumSetValues // is the list of enum or set values for the column, if any + isPK bool // is this column part of the primary key } // tableDiffer performs a diff for one table in the workflow. @@ -492,7 +494,7 @@ func (df *vdiff) buildVDiffPlan(filter *binlogdatapb.Filter, schm *tabletmanager // findPKs identifies PKs, determines any collations to be used for // them, and removes them from the columns used for data comparison. func findPKs(env *vtenv.Environment, table *tabletmanagerdatapb.TableDefinition, targetSelect *sqlparser.Select, td *tableDiffer) (sqlparser.OrderBy, error) { - columnCollations, err := getColumnCollations(env, table) + columnCollations, columnValues, err := getColumnCollations(env, table) if err != nil { return nil, err } @@ -513,6 +515,7 @@ func findPKs(env *vtenv.Environment, table *tabletmanagerdatapb.TableDefinition, if strings.EqualFold(pk, colname) { td.compareCols[i].isPK = true td.compareCols[i].collation = columnCollations[strings.ToLower(colname)] + td.compareCols[i].values = columnValues[strings.ToLower(colname)] td.comparePKs = append(td.comparePKs, td.compareCols[i]) td.selectPks = append(td.selectPks, i) // We'll be comparing pks separately. So, remove them from compareCols. @@ -536,19 +539,19 @@ func findPKs(env *vtenv.Environment, table *tabletmanagerdatapb.TableDefinition, // getColumnCollations determines the proper collation to use for each // column in the table definition leveraging MySQL's collation inheritance // rules. -func getColumnCollations(venv *vtenv.Environment, table *tabletmanagerdatapb.TableDefinition) (map[string]collations.ID, error) { +func getColumnCollations(venv *vtenv.Environment, table *tabletmanagerdatapb.TableDefinition) (map[string]collations.ID, map[string]*evalengine.EnumSetValues, error) { createstmt, err := venv.Parser().Parse(table.Schema) if err != nil { - return nil, err + return nil, nil, err } createtable, ok := createstmt.(*sqlparser.CreateTable) if !ok { - return nil, vterrors.Wrapf(err, "invalid table schema %s for table %s", table.Schema, table.Name) + return nil, nil, vterrors.Wrapf(err, "invalid table schema %s for table %s", table.Schema, table.Name) } env := schemadiff.NewEnv(venv, venv.CollationEnv().DefaultConnectionCharset()) tableschema, err := schemadiff.NewCreateTableEntity(env, createtable) if err != nil { - return nil, vterrors.Wrapf(err, "invalid table schema %s for table %s", table.Schema, table.Name) + return nil, nil, vterrors.Wrapf(err, "invalid table schema %s for table %s", table.Schema, table.Name) } tableCharset := tableschema.GetCharset() tableCollation := tableschema.GetCollation() @@ -579,6 +582,7 @@ func getColumnCollations(venv *vtenv.Environment, table *tabletmanagerdatapb.Tab } columnCollations := make(map[string]collations.ID) + columnValues := make(map[string]*evalengine.EnumSetValues) for _, column := range tableschema.TableSpec.Columns { // If it's not a character based type then no collation is used. if !sqltypes.IsQuoted(column.Type.SQLType()) { @@ -586,8 +590,12 @@ func getColumnCollations(venv *vtenv.Environment, table *tabletmanagerdatapb.Tab continue } columnCollations[column.Name.Lowered()] = getColumnCollation(column) + if len(column.Type.EnumValues) == 0 { + continue + } + columnValues[column.Name.Lowered()] = ptr.Of(evalengine.EnumSetValues(column.Type.EnumValues)) } - return columnCollations, nil + return columnCollations, columnValues, nil } // If SourceTimeZone is defined in the BinlogSource, the VReplication workflow would have converted the datetime @@ -1318,7 +1326,7 @@ func (td *tableDiffer) compare(sourceRow, targetRow []sqltypes.Value, cols []com if col.collation == collations.Unknown { collationID = collations.CollationBinaryID } - c, err = evalengine.NullsafeCompare(sourceRow[compareIndex], targetRow[compareIndex], td.collationEnv, collationID) + c, err = evalengine.NullsafeCompare(sourceRow[compareIndex], targetRow[compareIndex], td.collationEnv, collationID, col.values) if err != nil { return 0, err } diff --git a/go/vt/wrangler/vdiff_test.go b/go/vt/wrangler/vdiff_test.go index 1b0071ebed7..87988c5fd7e 100644 --- a/go/vt/wrangler/vdiff_test.go +++ b/go/vt/wrangler/vdiff_test.go @@ -18,7 +18,6 @@ package wrangler import ( "context" - "reflect" "strings" "testing" "time" @@ -35,6 +34,7 @@ import ( "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vtgate/engine" "vitess.io/vitess/go/vt/vtgate/engine/opcode" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) func TestVDiffPlanSuccess(t *testing.T) { @@ -94,12 +94,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -113,12 +113,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -132,12 +132,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -151,12 +151,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c2, c1 from t1 order by c1 asc", targetExpression: "select c2, c1 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, false}, {1, collations.Unknown, true}}, - comparePKs: []compareColInfo{{1, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, false}, {1, collations.Unknown, nil, true}}, + comparePKs: []compareColInfo{{1, collations.Unknown, nil, true}}, pkCols: []int{1}, selectPks: []int{1}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -170,12 +170,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c0 as c1, c2 from t2 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -190,12 +190,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "nonpktext", sourceExpression: "select c1, textcol from nonpktext order by c1 asc", targetExpression: "select c1, textcol from nonpktext order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -210,12 +210,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "nonpktext", sourceExpression: "select textcol, c1 from nonpktext order by c1 asc", targetExpression: "select textcol, c1 from nonpktext order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, false}, {1, collations.Unknown, true}}, - comparePKs: []compareColInfo{{1, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, false}, {1, collations.Unknown, nil, true}}, + comparePKs: []compareColInfo{{1, collations.Unknown, nil, true}}, pkCols: []int{1}, selectPks: []int{1}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -230,12 +230,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "pktext", sourceExpression: "select textcol, c2 from pktext order by textcol asc", targetExpression: "select textcol, c2 from pktext order by textcol asc", - compareCols: []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), true}}, + compareCols: []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), false}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), false}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -250,12 +250,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "pktext", sourceExpression: "select c2, textcol from pktext order by textcol asc", targetExpression: "select c2, textcol from pktext order by textcol asc", - compareCols: []compareColInfo{{0, collations.Unknown, false}, {1, collationEnv.DefaultConnectionCharset(), true}}, - comparePKs: []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, false}, {1, collationEnv.DefaultConnectionCharset(), nil, true}}, + comparePKs: []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, true}}, pkCols: []int{1}, selectPks: []int{1}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), false}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), false}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -270,12 +270,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "pktext", sourceExpression: "select c2, a + b as textcol from pktext order by textcol asc", targetExpression: "select c2, textcol from pktext order by textcol asc", - compareCols: []compareColInfo{{0, collations.Unknown, false}, {1, collationEnv.DefaultConnectionCharset(), true}}, - comparePKs: []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, false}, {1, collationEnv.DefaultConnectionCharset(), nil, true}}, + comparePKs: []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, true}}, pkCols: []int{1}, selectPks: []int{1}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), false}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), false}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -288,12 +288,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "multipk", sourceExpression: "select c1, c2 from multipk order by c1 asc, c2 asc", targetExpression: "select c1, c2 from multipk order by c1 asc, c2 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, true}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, true}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, true}}, pkCols: []int{0, 1}, selectPks: []int{0, 1}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -308,12 +308,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -329,12 +329,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 where c2 = 2 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -350,12 +350,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 where c2 = 2 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -371,12 +371,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 where c2 = 2 and c1 = 1 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -392,12 +392,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 where c2 = 2 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -412,12 +412,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 group by c1 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -432,8 +432,8 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "aggr", sourceExpression: "select c1, c2, count(*) as c3, sum(c4) as c4 from t1 group by c1 order by c1 asc", targetExpression: "select c1, c2, c3, c4 from aggr order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}, {2, collations.Unknown, false}, {3, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}, {2, collations.Unknown, nil, false}, {3, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, sourcePrimitive: &engine.OrderedAggregate{ @@ -442,10 +442,10 @@ func TestVDiffPlanSuccess(t *testing.T) { engine.NewAggregateParam(opcode.AggregateSum, 3, "", collationEnv), }, GroupByKeys: []*engine.GroupByParams{{KeyCol: 0, WeightStringCol: -1, CollationEnv: collations.MySQL8()}}, - Input: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + Input: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), CollationEnv: collationEnv, }, - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -459,12 +459,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "datze", sourceExpression: "select id, dt from datze order by id asc", targetExpression: "select id, convert_tz(dt, 'UTC', 'US/Pacific') as dt from datze order by id asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -1078,13 +1078,13 @@ func TestVDiffFindPKs(t *testing.T) { }, }, tdIn: &tableDiffer{ - compareCols: []compareColInfo{{0, collations.Unknown, false}, {1, collations.Unknown, false}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, false}, {1, collations.Unknown, nil, false}}, comparePKs: []compareColInfo{}, pkCols: []int{}, }, tdOut: &tableDiffer{ - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, }, @@ -1106,13 +1106,13 @@ func TestVDiffFindPKs(t *testing.T) { }, }, tdIn: &tableDiffer{ - compareCols: []compareColInfo{{0, collations.Unknown, false}, {1, collations.Unknown, false}, {2, collations.Unknown, false}, {3, collations.Unknown, false}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, false}, {1, collations.Unknown, nil, false}, {2, collations.Unknown, nil, false}, {3, collations.Unknown, nil, false}}, comparePKs: []compareColInfo{}, pkCols: []int{}, }, tdOut: &tableDiffer{ - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}, {2, collations.Unknown, false}, {3, collations.Unknown, true}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}, {3, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}, {2, collations.Unknown, nil, false}, {3, collations.Unknown, nil, true}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}, {3, collations.Unknown, nil, true}}, pkCols: []int{0, 3}, selectPks: []int{0, 3}, }, @@ -1184,10 +1184,11 @@ func TestVDiffPlanInclude(t *testing.T) { func TestGetColumnCollations(t *testing.T) { collationEnv := collations.MySQL8() tests := []struct { - name string - table *tabletmanagerdatapb.TableDefinition - want map[string]collations.ID - wantErr bool + name string + table *tabletmanagerdatapb.TableDefinition + wantCols map[string]collations.ID + wantValues map[string]*evalengine.EnumSetValues + wantErr bool }{ { name: "invalid schema", @@ -1201,94 +1202,128 @@ func TestGetColumnCollations(t *testing.T) { table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 int, name varchar(10), primary key(c1))", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collations.Unknown, "name": collationEnv.DefaultConnectionCharset(), }, + wantValues: map[string]*evalengine.EnumSetValues{}, }, { name: "char pk with global default collation", table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 varchar(10), name varchar(10), primary key(c1))", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collationEnv.DefaultConnectionCharset(), "name": collationEnv.DefaultConnectionCharset(), }, + wantValues: map[string]*evalengine.EnumSetValues{}, }, { name: "compound char int pk with global default collation", table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 int, name varchar(10), primary key(c1, name))", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collations.Unknown, "name": collationEnv.DefaultConnectionCharset(), }, + wantValues: map[string]*evalengine.EnumSetValues{}, }, { name: "char pk with table default charset", table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 varchar(10), name varchar(10), primary key(c1)) default character set ucs2", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collationEnv.DefaultCollationForCharset("ucs2"), "name": collationEnv.DefaultCollationForCharset("ucs2"), }, + wantValues: map[string]*evalengine.EnumSetValues{}, }, { name: "char pk with table default collation", table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 varchar(10), name varchar(10), primary key(c1)) charset=utf32 collate=utf32_icelandic_ci", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collationEnv.LookupByName("utf32_icelandic_ci"), "name": collationEnv.LookupByName("utf32_icelandic_ci"), }, + wantValues: map[string]*evalengine.EnumSetValues{}, }, { name: "char pk with column charset override", table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 varchar(10) charset sjis, name varchar(10), primary key(c1)) character set=utf8", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collationEnv.DefaultCollationForCharset("sjis"), "name": collationEnv.DefaultCollationForCharset("utf8mb3"), }, + wantValues: map[string]*evalengine.EnumSetValues{}, }, { name: "char pk with column collation override", table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 varchar(10) collate hebrew_bin, name varchar(10), primary key(c1)) charset=hebrew", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collationEnv.LookupByName("hebrew_bin"), "name": collationEnv.DefaultCollationForCharset("hebrew"), }, + wantValues: map[string]*evalengine.EnumSetValues{}, }, { name: "compound char int pk with column collation override", table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 varchar(10) collate utf16_turkish_ci, c2 int, name varchar(10), primary key(c1, c2)) charset=utf16 collate=utf16_icelandic_ci", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collationEnv.LookupByName("utf16_turkish_ci"), "c2": collations.Unknown, "name": collationEnv.LookupByName("utf16_icelandic_ci"), }, + wantValues: map[string]*evalengine.EnumSetValues{}, + }, + { + name: "col with enum values", + table: &tabletmanagerdatapb.TableDefinition{ + Schema: "create table t1 (c1 varchar(10), size enum('small', 'medium', 'large'), primary key(c1))", + }, + wantCols: map[string]collations.ID{ + "c1": collationEnv.DefaultConnectionCharset(), + "size": collationEnv.DefaultConnectionCharset(), + }, + wantValues: map[string]*evalengine.EnumSetValues{ + "size": {"'small'", "'medium'", "'large'"}, + }, + }, + { + name: "col with set values", + table: &tabletmanagerdatapb.TableDefinition{ + Schema: "create table t1 (c1 varchar(10), size set('small', 'medium', 'large'), primary key(c1))", + }, + wantCols: map[string]collations.ID{ + "c1": collationEnv.DefaultConnectionCharset(), + "size": collationEnv.DefaultConnectionCharset(), + }, + wantValues: map[string]*evalengine.EnumSetValues{ + "size": {"'small'", "'medium'", "'large'"}, + }, }, } env := vtenv.NewTestEnv() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := getColumnCollations(env, tt.table) - if (err != nil) != tt.wantErr { - t.Errorf("getColumnCollations() error = %v, wantErr = %t", err, tt.wantErr) + gotCols, gotValues, err := getColumnCollations(env, tt.table) + if tt.wantErr { + require.Error(t, err) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("getColumnCollations() = %+v, want %+v", got, tt.want) - } + require.NoError(t, err) + require.Equal(t, tt.wantCols, gotCols) + require.Equal(t, tt.wantValues, gotValues) }) } }