Skip to content

Commit

Permalink
[wip] propagate decimal precision in the evalengine
Browse files Browse the repository at this point in the history
Signed-off-by: Vicent Marti <[email protected]>
  • Loading branch information
vmg authored and systay committed Feb 10, 2024
1 parent 5878c08 commit a5d4474
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 43 deletions.
4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/aggregations.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ func (ap *AggregateParams) String() string {

func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type {
if ap.OrigOpcode != AggregateUnassigned {
return ap.OrigOpcode.Type(inputType)
return ap.OrigOpcode.SQLType(inputType)
}
return ap.Opcode.Type(inputType)
return ap.Opcode.SQLType(inputType)
}

type aggregator interface {
Expand Down
26 changes: 25 additions & 1 deletion go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ package opcode
import (
"fmt"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/vtgate/evalengine"
)

// PulloutOpcode is a number representing the opcode
Expand Down Expand Up @@ -138,7 +140,7 @@ func (code AggregateOpcode) MarshalJSON() ([]byte, error) {
}

// Type returns the opcode return sql type, and a bool telling is we are sure about this type or not
func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
func (code AggregateOpcode) SQLType(typ querypb.Type) querypb.Type {
switch code {
case AggregateUnassigned:
return sqltypes.Null
Expand Down Expand Up @@ -169,6 +171,28 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
}
}

func (code AggregateOpcode) Nullable() bool {
switch code {
case AggregateCount, AggregateCountStar:
return false
default:
return true
}
}

func (code AggregateOpcode) ResolveType(t evalengine.Type, env *collations.Environment) evalengine.Type {
sqltype := code.SQLType(t.Type())
collation := collations.CollationForType(sqltype, env.DefaultConnectionCharset())
nullable := code.Nullable()
size := t.Size()

scale := t.Scale()
if code == AggregateAvg {
scale += 4
}
return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale)
}

func (code AggregateOpcode) NeedsComparableValues() bool {
switch code {
case AggregateCountDistinct, AggregateSumDistinct, AggregateMin, AggregateMax:
Expand Down
10 changes: 6 additions & 4 deletions go/vt/vtgate/engine/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,12 @@ func (p *Projection) evalFields(env *evalengine.ExpressionEnv, infields []*query
fl |= uint32(querypb.MySqlFlag_NOT_NULL_FLAG)
}
fields = append(fields, &querypb.Field{
Name: col,
Type: typ.Type(),
Charset: uint32(typ.Collation()),
Flags: fl,
Name: col,
Type: typ.Type(),
Charset: uint32(typ.Collation()),
ColumnLength: uint32(typ.Size()),
Decimals: uint32(typ.Scale()),
Flags: fl,
})
}
return fields, nil
Expand Down
64 changes: 35 additions & 29 deletions go/vt/vtgate/evalengine/expr_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,40 +94,41 @@ func (op *opArithAdd) compile(c *compiler, left, right IR) (ctype, error) {
rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true)
lt, rt, swap = c.compileNumericPriority(lt, rt)

var sumtype sqltypes.Type
ct := ctype{Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric, Scale: max(lt.Scale, rt.Scale)}

switch lt.Type {
case sqltypes.Int64:
c.asm.Add_ii()
sumtype = sqltypes.Int64
ct.Type = sqltypes.Int64
case sqltypes.Uint64:
switch rt.Type {
case sqltypes.Int64:
c.asm.Add_ui(swap)
case sqltypes.Uint64:
c.asm.Add_uu()
}
sumtype = sqltypes.Uint64
ct.Type = sqltypes.Uint64
case sqltypes.Decimal:
if swap {
c.compileToDecimal(rt, 2)
} else {
c.compileToDecimal(rt, 1)
}
c.asm.Add_dd()
sumtype = sqltypes.Decimal
ct.Type = sqltypes.Decimal
ct.Scale = max(lt.Scale, rt.Scale)
case sqltypes.Float64:
if swap {
c.compileToFloat(rt, 2)
} else {
c.compileToFloat(rt, 1)
}
c.asm.Add_ff()
sumtype = sqltypes.Float64
ct.Type = sqltypes.Float64
}

c.asm.jumpDestination(skip1, skip2)
return ctype{Type: sumtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
return ct, nil
}

func (op *opArithSub) eval(left, right eval) (eval, error) {
Expand All @@ -151,66 +152,68 @@ func (op *opArithSub) compile(c *compiler, left, right IR) (ctype, error) {
lt = c.compileToNumeric(lt, 2, sqltypes.Float64, true)
rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true)

var subtype sqltypes.Type

ct := ctype{Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric, Scale: max(lt.Scale, rt.Scale)}
switch lt.Type {
case sqltypes.Int64:
switch rt.Type {
case sqltypes.Int64:
c.asm.Sub_ii()
subtype = sqltypes.Int64
ct.Type = sqltypes.Int64
case sqltypes.Uint64:
c.asm.Sub_iu()
subtype = sqltypes.Uint64
ct.Type = sqltypes.Uint64
case sqltypes.Float64:
c.compileToFloat(lt, 2)
c.asm.Sub_ff()
subtype = sqltypes.Float64
ct.Type = sqltypes.Float64
case sqltypes.Decimal:
c.compileToDecimal(lt, 2)
c.asm.Sub_dd()
subtype = sqltypes.Decimal
ct.Type = sqltypes.Decimal
ct.Scale = max(lt.Scale, rt.Scale)
}
case sqltypes.Uint64:
switch rt.Type {
case sqltypes.Int64:
c.asm.Sub_ui()
subtype = sqltypes.Uint64
ct.Type = sqltypes.Uint64
case sqltypes.Uint64:
c.asm.Sub_uu()
subtype = sqltypes.Uint64
ct.Type = sqltypes.Uint64
case sqltypes.Float64:
c.compileToFloat(lt, 2)
c.asm.Sub_ff()
subtype = sqltypes.Float64
ct.Type = sqltypes.Float64
case sqltypes.Decimal:
c.compileToDecimal(lt, 2)
c.asm.Sub_dd()
subtype = sqltypes.Decimal
ct.Type = sqltypes.Decimal
ct.Scale = max(lt.Scale, rt.Scale)
}
case sqltypes.Float64:
c.compileToFloat(rt, 1)
c.asm.Sub_ff()
subtype = sqltypes.Float64
ct.Type = sqltypes.Float64
case sqltypes.Decimal:
switch rt.Type {
case sqltypes.Float64:
c.compileToFloat(lt, 2)
c.asm.Sub_ff()
subtype = sqltypes.Float64
ct.Type = sqltypes.Float64
default:
c.compileToDecimal(rt, 1)
c.asm.Sub_dd()
subtype = sqltypes.Decimal
ct.Type = sqltypes.Decimal
ct.Scale = max(lt.Scale, rt.Scale)
}
}

if subtype == 0 {
if ct.Type == 0 {
panic("did not compile?")
}

c.asm.jumpDestination(skip1, skip2)
return ctype{Type: subtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
return ct, nil
}

func (op *opArithMul) eval(left, right eval) (eval, error) {
Expand All @@ -237,40 +240,40 @@ func (op *opArithMul) compile(c *compiler, left, right IR) (ctype, error) {
rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true)
lt, rt, swap = c.compileNumericPriority(lt, rt)

var multype sqltypes.Type

ct := ctype{Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric, Scale: lt.Scale + rt.Scale}
switch lt.Type {
case sqltypes.Int64:
c.asm.Mul_ii()
multype = sqltypes.Int64
ct.Type = sqltypes.Int64
case sqltypes.Uint64:
switch rt.Type {
case sqltypes.Int64:
c.asm.Mul_ui(swap)
case sqltypes.Uint64:
c.asm.Mul_uu()
}
multype = sqltypes.Uint64
ct.Type = sqltypes.Uint64
case sqltypes.Float64:
if swap {
c.compileToFloat(rt, 2)
} else {
c.compileToFloat(rt, 1)
}
c.asm.Mul_ff()
multype = sqltypes.Float64
ct.Type = sqltypes.Float64
case sqltypes.Decimal:
if swap {
c.compileToDecimal(rt, 2)
} else {
c.compileToDecimal(rt, 1)
}
c.asm.Mul_dd()
multype = sqltypes.Decimal
ct.Type = sqltypes.Decimal
ct.Scale = lt.Scale + rt.Scale
}

c.asm.jumpDestination(skip1, skip2)
return ctype{Type: multype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
return ct, nil
}

func (op *opArithDiv) eval(left, right eval) (eval, error) {
Expand Down Expand Up @@ -306,6 +309,7 @@ func (op *opArithDiv) compile(c *compiler, left, right IR) (ctype, error) {
c.compileToDecimal(lt, 2)
c.compileToDecimal(rt, 1)
c.asm.Div_dd()
ct.Scale = lt.Scale + divPrecisionIncrement
}
c.asm.jumpDestination(skip1, skip2)
return ct, nil
Expand Down Expand Up @@ -419,7 +423,7 @@ func (op *opArithMod) compile(c *compiler, left, right IR) (ctype, error) {
lt = c.compileToNumeric(lt, 2, sqltypes.Float64, true)
rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true)

ct := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable}
ct := ctype{Col: collationNumeric, Flag: flagNullable}
switch lt.Type {
case sqltypes.Int64:
ct.Type = sqltypes.Int64
Expand All @@ -434,6 +438,7 @@ func (op *opArithMod) compile(c *compiler, left, right IR) (ctype, error) {
c.asm.Mod_ff()
case sqltypes.Decimal:
ct.Type = sqltypes.Decimal
ct.Scale = max(lt.Scale, rt.Scale)
c.asm.Convert_xd(2, 0, 0)
c.asm.Mod_dd()
}
Expand All @@ -450,6 +455,7 @@ func (op *opArithMod) compile(c *compiler, left, right IR) (ctype, error) {
c.asm.Mod_ff()
case sqltypes.Decimal:
ct.Type = sqltypes.Decimal
ct.Scale = max(lt.Scale, rt.Scale)
c.asm.Convert_xd(2, 0, 0)
c.asm.Mod_dd()
}
Expand Down
10 changes: 3 additions & 7 deletions go/vt/vtgate/semantics/typer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package semantics

import (
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/engine/opcode"
"vitess.io/vitess/go/vt/vtgate/evalengine"
Expand Down Expand Up @@ -55,16 +54,13 @@ func (t *typer) up(cursor *sqlparser.Cursor) error {
if !ok {
return nil
}
inputType := sqltypes.Unknown
var inputType evalengine.Type
if arg := node.GetArg(); arg != nil {
if tt, ok := t.m[arg]; ok {
inputType = tt.Type()
inputType = tt
}
}
type_ := code.Type(inputType)
_, isCount := node.(*sqlparser.Count)
_, isCountStart := node.(*sqlparser.CountStar)
t.m[node] = evalengine.NewTypeEx(type_, collations.CollationForType(type_, t.collationEnv.DefaultConnectionCharset()), !(isCount || isCountStart), 0, 0)
t.m[node] = code.ResolveType(inputType, t.collationEnv)
}
return nil
}
Expand Down

0 comments on commit a5d4474

Please sign in to comment.