Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

evalengine: fix numeric coercibility #14473

Merged
merged 5 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go/sqltypes/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func IsNull(t querypb.Type) bool {
// switch statements for those who want to cover types
// by their category.
const (
Unknown = -1
Unknown = querypb.Type(-1)
Null = querypb.Type_NULL_TYPE
Int8 = querypb.Type_INT8
Uint8 = querypb.Type_UINT8
Expand Down
1 change: 1 addition & 0 deletions go/test/endtoend/vtgate/queries/random/random_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,4 +339,5 @@ func TestBuggyQueries(t *testing.T) {
mcmp.Exec("select count(tbl1.dname) as caggr1 from dept as tbl0 left join dept as tbl1 on tbl1.dname > tbl1.loc where tbl1.loc <=> tbl1.dname group by tbl1.dname order by tbl1.dname asc")
mcmp.Exec("select count(*) from (select count(*) from dept as tbl0) as tbl0")
mcmp.Exec("select count(*), count(*) from (select count(*) from dept as tbl0) as tbl0, dept as tbl1")
mcmp.Exec(`select distinct case max(tbl0.ename) when min(tbl0.job) then 'sole' else count(case when false then -27 when 'gazelle' then tbl0.deptno end) end as caggr0 from emp as tbl0`)
}
6 changes: 6 additions & 0 deletions go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,19 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
case AggregateUnassigned:
return sqltypes.Null
case AggregateGroupConcat:
if typ == sqltypes.Unknown {
return sqltypes.Unknown
}
if sqltypes.IsBinary(typ) {
return sqltypes.Blob
}
return sqltypes.Text
case AggregateMax, AggregateMin, AggregateAnyValue:
return typ
case AggregateSumDistinct, AggregateSum:
if typ == sqltypes.Unknown {
return sqltypes.Unknown
}
if sqltypes.IsIntegral(typ) || sqltypes.IsDecimal(typ) {
return sqltypes.Decimal
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/api_literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func NewColumn(offset int, typ Type, original sqlparser.Expr) *Column {
return &Column{
Offset: offset,
Type: typ.Type,
Collation: defaultCoercionCollation(typ.Coll),
Collation: typedCoercionCollation(typ.Type, typ.Coll),
Original: original,
dynamicTypeOffset: -1,
}
Expand Down
115 changes: 109 additions & 6 deletions go/vt/vtgate/evalengine/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,115 @@ limitations under the License.

package evalengine

import "vitess.io/vitess/go/mysql/collations"
import (
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/mysql/collations/colldata"
"vitess.io/vitess/go/sqltypes"
)

func defaultCoercionCollation(id collations.ID) collations.TypedCollation {
return collations.TypedCollation{
Collation: id,
Coercibility: collations.CoerceCoercible,
Repertoire: collations.RepertoireUnicode,
// typedCoercionCollation returns the typed collation to attach to a value of
// SQL type typ.
//
// NULL types yield collationNull, numeric and date/time types yield
// collationNumeric, and JSON yields collationJSON. Any other type keeps the
// caller-supplied collation id, marked as coercible with a Unicode repertoire.
func typedCoercionCollation(typ sqltypes.Type, id collations.ID) collations.TypedCollation {
	if sqltypes.IsNull(typ) {
		return collationNull
	}
	if sqltypes.IsNumber(typ) || sqltypes.IsDateOrTime(typ) {
		return collationNumeric
	}
	if typ == sqltypes.TypeJSON {
		return collationJSON
	}

	// Everything else carries the requested collation id.
	return collations.TypedCollation{
		Collation:    id,
		Coercibility: collations.CoerceCoercible,
		Repertoire:   collations.RepertoireUnicode,
	}
}

// evalCollation reports the typed collation associated with an evaluated value.
//
// A nil value maps to collationNull, numeric and temporal values to
// collationNumeric, JSON values to collationJSON, and byte values to the
// collation stored on the value itself. Every other kind of value falls back
// to collationBinary.
func evalCollation(e eval) collations.TypedCollation {
	if e == nil {
		return collationNull
	}

	switch v := e.(type) {
	case evalNumeric, *evalTemporal:
		return collationNumeric
	case *evalJSON:
		return collationJSON
	case *evalBytes:
		// Byte values carry their own collation.
		return v.col
	}
	return collationBinary
}

// mergeCollations resolves the collation to use when combining two operands
// with collations c1 and c2 and SQL types t1 and t2.
//
// It returns the resulting collation plus the coercion to apply to the left
// and right operand respectively (nil when no coercion is required):
//   - identical collation IDs: c1 is returned unchanged;
//   - only one operand is text/binary: that operand's collation is returned;
//   - neither operand is text/binary: collationBinary is returned;
//   - both operands are text/binary: the collations are merged through
//     colldata.Merge, allowing conversion to a superset and with coercion.
func mergeCollations(c1, c2 collations.TypedCollation, t1, t2 sqltypes.Type) (collations.TypedCollation, colldata.Coercion, colldata.Coercion, error) {
	if c1.Collation == c2.Collation {
		return c1, nil, nil, nil
	}

	leftIsString := sqltypes.IsText(t1) || sqltypes.IsBinary(t1)
	rightIsString := sqltypes.IsText(t2) || sqltypes.IsBinary(t2)

	switch {
	case leftIsString && rightIsString:
		// Two string-like operands with different collations need a real merge.
		return colldata.Merge(collations.Local(), c1, c2, colldata.CoercionOptions{
			ConvertToSuperset:   true,
			ConvertWithCoercion: true,
		})
	case leftIsString:
		return c1, nil, nil, nil
	case rightIsString:
		return c2, nil, nil, nil
	default:
		return collationBinary, nil, nil, nil
	}
}

// mergeAndCoerceCollations merges the collations of two evaluated values and,
// when the merge calls for it, returns copies of the values converted to the
// merged collation.
//
// If neither side needs a coercion, the original values are returned as-is
// together with the merged collation. Otherwise both values are rebuilt as raw
// byte values under the merged collation and the coercion functions are run
// on their bytes. Any error from the merge or from a coercion is returned with
// nil values and a zero collation.
func mergeAndCoerceCollations(left, right eval) (eval, eval, collations.TypedCollation, error) {
	leftType, rightType := left.SQLType(), right.SQLType()

	merged, leftCoercion, rightCoercion, err := mergeCollations(evalCollation(left), evalCollation(right), leftType, rightType)
	switch {
	case err != nil:
		return nil, nil, collations.TypedCollation{}, err
	case leftCoercion == nil && rightCoercion == nil:
		return left, right, merged, nil
	}

	// At least one side must be converted; both are re-wrapped so that they
	// share the merged collation.
	coercedLeft := newEvalRaw(leftType, left.(*evalBytes).bytes, merged)
	coercedRight := newEvalRaw(rightType, right.(*evalBytes).bytes, merged)

	if leftCoercion != nil {
		if coercedLeft.bytes, err = leftCoercion(nil, coercedLeft.bytes); err != nil {
			return nil, nil, collations.TypedCollation{}, err
		}
	}
	if rightCoercion != nil {
		if coercedRight.bytes, err = rightCoercion(nil, coercedRight.bytes); err != nil {
			return nil, nil, collations.TypedCollation{}, err
		}
	}
	return coercedLeft, coercedRight, merged, nil
}

// collationAggregation folds a sequence of typed collations into a single one.
type collationAggregation struct {
	// cur is the collation accumulated so far; its Collation field is
	// collations.Unknown until the first call to add.
	cur collations.TypedCollation
}

// add folds tc into the aggregated collation.
//
// The first collation added is adopted directly. Each later one is merged with
// the current result via colldata.Merge (superset conversion and coercion
// enabled), and any merge error is returned to the caller.
func (ca *collationAggregation) add(env *collations.Environment, tc collations.TypedCollation) error {
	if ca.cur.Collation == collations.Unknown {
		ca.cur = tc
		return nil
	}

	merged, _, _, err := colldata.Merge(env, ca.cur, tc, colldata.CoercionOptions{ConvertToSuperset: true, ConvertWithCoercion: true})
	// The merge result is stored even when err is non-nil.
	ca.cur = merged
	return err
}

// result returns the collation aggregated so far.
func (ca *collationAggregation) result() collations.TypedCollation {
	return ca.cur
}
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -4165,13 +4165,13 @@ func (asm *assembler) Fn_DATEADD_D(unit datetime.IntervalType, sub bool) {
}

tmp := env.vm.stack[env.vm.sp-2].(*evalTemporal)
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.TypedCollation{}, env.now)
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.Unknown, env.now)
env.vm.sp--
return 1
}, "FN DATEADD TEMPORAL(SP-2), INTERVAL(SP-1)")
}

func (asm *assembler) Fn_DATEADD_s(unit datetime.IntervalType, sub bool, col collations.TypedCollation) {
func (asm *assembler) Fn_DATEADD_s(unit datetime.IntervalType, sub bool, col collations.ID) {
asm.adjustStack(-1)
asm.emit(func(env *ExpressionEnv) int {
var interval *datetime.Interval
Expand Down
13 changes: 13 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ func TestCompilerSingle(t *testing.T) {
expression string
values []sqltypes.Value
result string
collation collations.ID
}{
{
expression: "1 + column0",
Expand Down Expand Up @@ -489,6 +490,12 @@ func TestCompilerSingle(t *testing.T) {
expression: `'2020-01-01' + interval month(date_sub(FROM_UNIXTIME(1234), interval 1 month))-1 month`,
result: `CHAR("2020-12-01")`,
},
{
expression: `case column0 when 1 then column1 else column2 end`,
values: []sqltypes.Value{sqltypes.NewInt64(42), sqltypes.NewVarChar("sole"), sqltypes.NewInt64(0)},
result: `VARCHAR("0")`,
collation: collations.CollationUtf8mb4ID,
},
}

tz, _ := time.LoadLocation("Europe/Madrid")
Expand Down Expand Up @@ -524,6 +531,9 @@ func TestCompilerSingle(t *testing.T) {
if expected.String() != tc.result {
t.Fatalf("bad evaluation from eval engine: got %s, want %s", expected.String(), tc.result)
}
if tc.collation != collations.Unknown && tc.collation != expected.Collation() {
t.Fatalf("bad collation evaluation from eval engine: got %d, want %d", expected.Collation(), tc.collation)
}

// re-run the same evaluation multiple times to ensure results are always consistent
for i := 0; i < 8; i++ {
Expand All @@ -535,6 +545,9 @@ func TestCompilerSingle(t *testing.T) {
if res.String() != tc.result {
t.Errorf("bad evaluation from compiler: got %s, want %s (iteration %d)", res, tc.result, i)
}
if tc.collation != collations.Unknown && tc.collation != res.Collation() {
t.Fatalf("bad collation evaluation from compiler: got %d, want %d", res.Collation(), tc.collation)
}
}
})
}
Expand Down
20 changes: 10 additions & 10 deletions go/vt/vtgate/evalengine/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
fval, _ := fastparse.ParseFloat64(v.RawStr())
return newEvalFloat(fval), nil
default:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand All @@ -265,7 +265,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
fval, _ := fastparse.ParseFloat64(v.RawStr())
dec = decimal.NewFromFloat(fval)
default:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand All @@ -285,7 +285,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
i, err := fastparse.ParseInt64(v.RawStr(), 10)
return newEvalInt64(i), err
default:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand All @@ -304,7 +304,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
u, err := fastparse.ParseUint64(v.RawStr(), 10)
return newEvalUint64(u), err
default:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand All @@ -315,15 +315,15 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
case sqltypes.IsText(typ) || sqltypes.IsBinary(typ):
switch {
case v.IsText() || v.IsBinary():
return newEvalRaw(v.Type(), v.Raw(), defaultCoercionCollation(collation)), nil
return newEvalRaw(v.Type(), v.Raw(), typedCoercionCollation(v.Type(), collation)), nil
case sqltypes.IsText(typ):
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
return evalToVarchar(e, collation, true)
default:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand All @@ -333,7 +333,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
case typ == sqltypes.TypeJSON:
return json.NewFromSQL(v)
case typ == sqltypes.Date:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand All @@ -344,7 +344,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
}
return d, nil
case typ == sqltypes.Datetime || typ == sqltypes.Timestamp:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand All @@ -355,7 +355,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
}
return dt, nil
case typ == sqltypes.Time:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand Down
10 changes: 5 additions & 5 deletions go/vt/vtgate/evalengine/eval_temporal.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (e *evalTemporal) isZero() bool {
return e.dt.IsZero()
}

func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collations.TypedCollation, now time.Time) eval {
func (e *evalTemporal) addInterval(interval *datetime.Interval, coll collations.ID, now time.Time) eval {
var tmp *evalTemporal
var ok bool

Expand All @@ -150,16 +150,16 @@ func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collatio
tmp.dt.Date, ok = e.dt.Date.AddInterval(interval)
case tt == sqltypes.Time && !interval.Unit().HasDateParts():
tmp = &evalTemporal{t: e.t}
tmp.dt.Time, tmp.prec, ok = e.dt.Time.AddInterval(interval, strcoll.Valid())
tmp.dt.Time, tmp.prec, ok = e.dt.Time.AddInterval(interval, coll != collations.Unknown)
case tt == sqltypes.Datetime || tt == sqltypes.Timestamp || (tt == sqltypes.Date && interval.Unit().HasTimeParts()) || (tt == sqltypes.Time && interval.Unit().HasDateParts()):
tmp = e.toDateTime(int(e.prec), now)
tmp.dt, tmp.prec, ok = e.dt.AddInterval(interval, strcoll.Valid())
tmp.dt, tmp.prec, ok = e.dt.AddInterval(interval, coll != collations.Unknown)
}
if !ok {
return nil
}
if strcoll.Valid() {
return newEvalRaw(sqltypes.Char, tmp.ToRawBytes(), strcoll)
if coll != collations.Unknown {
return newEvalRaw(sqltypes.Char, tmp.ToRawBytes(), typedCoercionCollation(sqltypes.Char, coll))
}
return tmp
}
Expand Down
8 changes: 4 additions & 4 deletions go/vt/vtgate/evalengine/expr_bvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) {

tuple := make([]eval, 0, len(bvar.Values))
for _, value := range bvar.Values {
e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), defaultCoercionCollation(collations.CollationForType(value.Type, bv.Collation)))
e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), typedCoercionCollation(value.Type, collations.CollationForType(value.Type, bv.Collation)))
if err != nil {
return nil, err
}
Expand All @@ -86,7 +86,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) {
if bv.typed() {
typ = bv.Type
}
return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), defaultCoercionCollation(collations.CollationForType(typ, bv.Collation)))
return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), typedCoercionCollation(typ, collations.CollationForType(typ, bv.Collation)))
}
}

Expand All @@ -110,7 +110,7 @@ func (bv *BindVariable) typeof(env *ExpressionEnv) (ctype, error) {
case sqltypes.BitNum:
return ctype{Type: sqltypes.VarBinary, Flag: flagBit, Col: collationNumeric}, nil
default:
return ctype{Type: tt, Flag: 0, Col: defaultCoercionCollation(collations.CollationForType(tt, bv.Collation))}, nil
return ctype{Type: tt, Flag: 0, Col: typedCoercionCollation(tt, collations.CollationForType(tt, bv.Collation))}, nil
}
}

Expand All @@ -119,7 +119,7 @@ func (bvar *BindVariable) compile(c *compiler) (ctype, error) {

if bvar.typed() {
typ.Type = bvar.Type
typ.Col = defaultCoercionCollation(collations.CollationForType(bvar.Type, bvar.Collation))
typ.Col = typedCoercionCollation(bvar.Type, collations.CollationForType(bvar.Type, bvar.Collation))
} else if c.dynamicTypes != nil {
typ = c.dynamicTypes[bvar.dynamicTypeOffset]
} else {
Expand Down
Loading
Loading