diff --git a/go/vt/vtgate/engine/aggregations.go b/go/vt/vtgate/engine/aggregations.go index 3413234c84f..ea10267a7e6 100644 --- a/go/vt/vtgate/engine/aggregations.go +++ b/go/vt/vtgate/engine/aggregations.go @@ -91,9 +91,9 @@ func (ap *AggregateParams) String() string { func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type { if ap.OrigOpcode != AggregateUnassigned { - return ap.OrigOpcode.Type(inputType) + return ap.OrigOpcode.SQLType(inputType) } - return ap.Opcode.Type(inputType) + return ap.Opcode.SQLType(inputType) } type aggregator interface { diff --git a/go/vt/vtgate/engine/opcode/constants.go b/go/vt/vtgate/engine/opcode/constants.go index 8a70df79d0c..2fa0e9446a4 100644 --- a/go/vt/vtgate/engine/opcode/constants.go +++ b/go/vt/vtgate/engine/opcode/constants.go @@ -19,8 +19,10 @@ package opcode import ( "fmt" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) // PulloutOpcode is a number representing the opcode @@ -138,7 +140,7 @@ func (code AggregateOpcode) MarshalJSON() ([]byte, error) { } // Type returns the opcode return sql type, and a bool telling is we are sure about this type or not -func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type { +func (code AggregateOpcode) SQLType(typ querypb.Type) querypb.Type { switch code { case AggregateUnassigned: return sqltypes.Null @@ -169,6 +171,28 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type { } } +func (code AggregateOpcode) Nullable() bool { + switch code { + case AggregateCount, AggregateCountStar: + return false + default: + return true + } +} + +func (code AggregateOpcode) ResolveType(t evalengine.Type, env *collations.Environment) evalengine.Type { + sqltype := code.SQLType(t.Type()) + collation := collations.CollationForType(sqltype, env.DefaultConnectionCharset()) + nullable := code.Nullable() + size := t.Size() + + scale := t.Scale() + if code == AggregateAvg { + scale += 4 + } + return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale) +} + func (code AggregateOpcode) NeedsComparableValues() bool { switch code { case AggregateCountDistinct, AggregateSumDistinct, AggregateMin, AggregateMax: diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index e9714b6a8cb..77e07203476 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -158,10 +158,12 @@ func (p *Projection) evalFields(env *evalengine.ExpressionEnv, infields []*query fl |= uint32(querypb.MySqlFlag_NOT_NULL_FLAG) } fields = append(fields, &querypb.Field{ - Name: col, - Type: typ.Type(), - Charset: uint32(typ.Collation()), - Flags: fl, + Name: col, + Type: typ.Type(), + Charset: uint32(typ.Collation()), + ColumnLength: uint32(typ.Size()), + Decimals: uint32(typ.Scale()), + Flags: fl, }) } return fields, nil diff --git a/go/vt/vtgate/evalengine/expr_arithmetic.go b/go/vt/vtgate/evalengine/expr_arithmetic.go index 938803910cb..cf70a36e733 100644 --- a/go/vt/vtgate/evalengine/expr_arithmetic.go +++ b/go/vt/vtgate/evalengine/expr_arithmetic.go @@ -94,12 +94,12 @@ func (op *opArithAdd) compile(c *compiler, left, right IR) (ctype, error) { rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true) lt, rt, swap = c.compileNumericPriority(lt, rt) - var sumtype sqltypes.Type + ct := ctype{Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric, Scale: max(lt.Scale, rt.Scale)} switch lt.Type { case sqltypes.Int64: c.asm.Add_ii() - sumtype = sqltypes.Int64 + ct.Type = sqltypes.Int64 case sqltypes.Uint64: switch rt.Type { case sqltypes.Int64: @@ -107,7 +107,7 @@ func (op *opArithAdd) compile(c *compiler, left, right IR) (ctype, error) { case sqltypes.Uint64: c.asm.Add_uu() } - sumtype = sqltypes.Uint64 + ct.Type = sqltypes.Uint64 case sqltypes.Decimal: if swap { c.compileToDecimal(rt, 2) @@ -115,7 +115,8 @@ func (op *opArithAdd) compile(c *compiler, left, right IR) (ctype, error) { c.compileToDecimal(rt, 1) } c.asm.Add_dd() - sumtype = sqltypes.Decimal + ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) case sqltypes.Float64: if swap { c.compileToFloat(rt, 2) @@ -123,11 +124,11 @@ func (op *opArithAdd) compile(c *compiler, left, right IR) (ctype, error) { c.compileToFloat(rt, 1) } c.asm.Add_ff() - sumtype = sqltypes.Float64 + ct.Type = sqltypes.Float64 } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: sumtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil + return ct, nil } func (op *opArithSub) eval(left, right eval) (eval, error) { @@ -151,66 +152,68 @@ func (op *opArithSub) compile(c *compiler, left, right IR) (ctype, error) { lt = c.compileToNumeric(lt, 2, sqltypes.Float64, true) rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true) - var subtype sqltypes.Type - + ct := ctype{Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric, Scale: max(lt.Scale, rt.Scale)} switch lt.Type { case sqltypes.Int64: switch rt.Type { case sqltypes.Int64: c.asm.Sub_ii() - subtype = sqltypes.Int64 + ct.Type = sqltypes.Int64 case sqltypes.Uint64: c.asm.Sub_iu() - subtype = sqltypes.Uint64 + ct.Type = sqltypes.Uint64 case sqltypes.Float64: c.compileToFloat(lt, 2) c.asm.Sub_ff() - subtype = sqltypes.Float64 + ct.Type = sqltypes.Float64 case sqltypes.Decimal: c.compileToDecimal(lt, 2) c.asm.Sub_dd() - subtype = sqltypes.Decimal + ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) } case sqltypes.Uint64: switch rt.Type { case sqltypes.Int64: c.asm.Sub_ui() - subtype = sqltypes.Uint64 + ct.Type = sqltypes.Uint64 case sqltypes.Uint64: c.asm.Sub_uu() - subtype = sqltypes.Uint64 + ct.Type = sqltypes.Uint64 case sqltypes.Float64: c.compileToFloat(lt, 2) c.asm.Sub_ff() - subtype = sqltypes.Float64 + ct.Type = sqltypes.Float64 case sqltypes.Decimal: c.compileToDecimal(lt, 2) c.asm.Sub_dd() - subtype = sqltypes.Decimal + ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) } case sqltypes.Float64: c.compileToFloat(rt, 1) c.asm.Sub_ff() - subtype = sqltypes.Float64 + ct.Type = sqltypes.Float64 case sqltypes.Decimal: switch rt.Type { case sqltypes.Float64: c.compileToFloat(lt, 2) c.asm.Sub_ff() - subtype = sqltypes.Float64 + ct.Type = sqltypes.Float64 default: c.compileToDecimal(rt, 1) c.asm.Sub_dd() - subtype = sqltypes.Decimal + ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) } } - if subtype == 0 { + if ct.Type == 0 { panic("did not compile?") } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: subtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil + return ct, nil } func (op *opArithMul) eval(left, right eval) (eval, error) { @@ -237,12 +240,11 @@ func (op *opArithMul) compile(c *compiler, left, right IR) (ctype, error) { rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true) lt, rt, swap = c.compileNumericPriority(lt, rt) - var multype sqltypes.Type - + ct := ctype{Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric, Scale: lt.Scale + rt.Scale} switch lt.Type { case sqltypes.Int64: c.asm.Mul_ii() - multype = sqltypes.Int64 + ct.Type = sqltypes.Int64 case sqltypes.Uint64: switch rt.Type { case sqltypes.Int64: @@ -250,7 +252,7 @@ func (op *opArithMul) compile(c *compiler, left, right IR) (ctype, error) { case sqltypes.Uint64: c.asm.Mul_uu() } - multype = sqltypes.Uint64 + ct.Type = sqltypes.Uint64 case sqltypes.Float64: if swap { c.compileToFloat(rt, 2) @@ -258,7 +260,7 @@ func (op *opArithMul) compile(c *compiler, left, right IR) (ctype, error) { c.compileToFloat(rt, 1) } c.asm.Mul_ff() - multype = sqltypes.Float64 + ct.Type = sqltypes.Float64 case sqltypes.Decimal: if swap { c.compileToDecimal(rt, 2) @@ -266,11 +268,12 @@ func (op *opArithMul) compile(c *compiler, left, right IR) (ctype, error) { c.compileToDecimal(rt, 1) } c.asm.Mul_dd() - multype = sqltypes.Decimal + ct.Type = sqltypes.Decimal + ct.Scale = lt.Scale + rt.Scale } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: multype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil + return ct, nil } func (op *opArithDiv) eval(left, right eval) (eval, error) { @@ -306,6 +309,7 @@ func (op *opArithDiv) compile(c *compiler, left, right IR) (ctype, error) { c.compileToDecimal(lt, 2) c.compileToDecimal(rt, 1) c.asm.Div_dd() + ct.Scale = lt.Scale + divPrecisionIncrement } c.asm.jumpDestination(skip1, skip2) return ct, nil @@ -419,7 +423,7 @@ func (op *opArithMod) compile(c *compiler, left, right IR) (ctype, error) { lt = c.compileToNumeric(lt, 2, sqltypes.Float64, true) rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true) - ct := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable} + ct := ctype{Col: collationNumeric, Flag: flagNullable} switch lt.Type { case sqltypes.Int64: ct.Type = sqltypes.Int64 @@ -434,6 +438,7 @@ func (op *opArithMod) compile(c *compiler, left, right IR) (ctype, error) { c.asm.Mod_ff() case sqltypes.Decimal: ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) c.asm.Convert_xd(2, 0, 0) c.asm.Mod_dd() } @@ -450,6 +455,7 @@ func (op *opArithMod) compile(c *compiler, left, right IR) (ctype, error) { c.asm.Mod_ff() case sqltypes.Decimal: ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) c.asm.Convert_xd(2, 0, 0) c.asm.Mod_dd() } diff --git a/go/vt/vtgate/semantics/typer.go b/go/vt/vtgate/semantics/typer.go index a9c783cee18..54261339114 100644 --- a/go/vt/vtgate/semantics/typer.go +++ b/go/vt/vtgate/semantics/typer.go @@ -18,7 +18,6 @@ package semantics import ( "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine/opcode" "vitess.io/vitess/go/vt/vtgate/evalengine" @@ -55,16 +54,13 @@ func (t *typer) up(cursor *sqlparser.Cursor) error { if !ok { return nil } - inputType := sqltypes.Unknown + var inputType evalengine.Type if arg := node.GetArg(); arg != nil { if tt, ok := t.m[arg]; ok { - inputType = tt.Type() + inputType = tt } } - type_ := code.Type(inputType) - _, isCount := node.(*sqlparser.Count) - _, isCountStart := node.(*sqlparser.CountStar) - t.m[node] = evalengine.NewTypeEx(type_, collations.CollationForType(type_, t.collationEnv.DefaultConnectionCharset()), !(isCount || isCountStart), 0, 0) + t.m[node] = code.ResolveType(inputType, t.collationEnv) } return nil }