diff --git a/go/test/endtoend/utils/mysql.go b/go/test/endtoend/utils/mysql.go index de8ce40f992..571c4519da4 100644 --- a/go/test/endtoend/utils/mysql.go +++ b/go/test/endtoend/utils/mysql.go @@ -26,13 +26,21 @@ import ( "github.com/stretchr/testify/assert" +<<<<<<< HEAD "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/dbconfigs" "vitess.io/vitess/go/vt/sqlparser" +======= +>>>>>>> cd61d85130 (bugfix: wrong field type returned for SUM (#15192)) "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/test/endtoend/cluster" + "vitess.io/vitess/go/vt/dbconfigs" "vitess.io/vitess/go/vt/mysqlctl" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/sqlparser" ) // NewMySQL creates a new MySQL server using the local mysqld binary. The name of the database @@ -155,7 +163,9 @@ func prepareMySQLWithSchema(params mysql.ConnParams, sql string) error { return nil } -func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn, vtQr, mysqlQr *sqltypes.Result, compareColumns bool) error { +func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn, vtQr, mysqlQr *sqltypes.Result, compareColumnNames bool) error { + t.Helper() + if vtQr == nil && mysqlQr == nil { return nil } @@ -168,29 +178,34 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn return errors.New("MySQL result is 'nil' while Vitess' is not.\n") } - var errStr string - if compareColumns { - vtColCount := len(vtQr.Fields) - myColCount := len(mysqlQr.Fields) - if vtColCount > 0 && myColCount > 0 { - if vtColCount != myColCount { - t.Errorf("column count does not match: %d vs %d", vtColCount, myColCount) - errStr += fmt.Sprintf("column count does not match: %d vs %d\n", vtColCount, myColCount) - } - - var vtCols []string - var myCols []string - for i, vtField := range vtQr.Fields { - vtCols = append(vtCols, vtField.Name) - myCols = append(myCols, mysqlQr.Fields[i].Name) - } - if !assert.Equal(t, myCols, vtCols, "column names do not match - the expected values are what mysql produced") { - errStr += "column names do not match - the expected values are what mysql produced\n" - errStr += fmt.Sprintf("Not equal: \nexpected: %v\nactual: %v\n", myCols, vtCols) - } + vtColCount := len(vtQr.Fields) + myColCount := len(mysqlQr.Fields) + + if vtColCount != myColCount { + t.Errorf("column count does not match: %d vs %d", vtColCount, myColCount) + } + + if vtColCount > 0 { + var vtCols []string + var myCols []string + for i, vtField := range vtQr.Fields { + myField := mysqlQr.Fields[i] + checkFields(t, myField.Name, vtField, myField) + + vtCols = append(vtCols, vtField.Name) + myCols = append(myCols, myField.Name) + } + + if compareColumnNames && !assert.Equal(t, myCols, vtCols, "column names do not match - the expected values are what mysql produced") { + t.Errorf("column names do not match - the expected values are what mysql produced\nNot equal: \nexpected: %v\nactual: %v\n", myCols, vtCols) } } +<<<<<<< HEAD stmt, err := sqlparser.Parse(query) +======= + + stmt, err := sqlparser.NewTestParser().Parse(query) +>>>>>>> cd61d85130 (bugfix: wrong field type returned for SUM (#15192)) if err != nil { t.Error(err) return err @@ -204,7 +219,7 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn return nil } - errStr += "Query (" + query + ") results mismatched.\nVitess Results:\n" + errStr := "Query (" + query + ") results mismatched.\nVitess Results:\n" for _, row := range vtQr.Rows { errStr += fmt.Sprintf("%s\n", row) } @@ -224,6 +239,20 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn return errors.New(errStr) } +func checkFields(t *testing.T, columnName string, vtField, myField *querypb.Field) { + t.Helper() + if vtField.Type != myField.Type { + t.Errorf("for column %s field types do not match\nNot equal: \nMySQL: %v\nVitess: %v\n", columnName, myField.Type.String(), vtField.Type.String()) + } + + // starting in Vitess 20, decimal types are properly sized in their field information + if BinaryIsAtLeastAtVersion(20, "vtgate") && vtField.Type == sqltypes.Decimal { + if vtField.Decimals != myField.Decimals { + t.Errorf("for column %s field decimals count do not match\nNot equal: \nMySQL: %v\nVitess: %v\n", columnName, myField.Decimals, vtField.Decimals) + } + } +} + func compareVitessAndMySQLErrors(t *testing.T, vtErr, mysqlErr error) { if vtErr != nil && mysqlErr != nil || vtErr == nil && mysqlErr == nil { return diff --git a/go/vt/vtgate/engine/aggregations.go b/go/vt/vtgate/engine/aggregations.go index 63634adb87c..b9f504f7412 100644 --- a/go/vt/vtgate/engine/aggregations.go +++ b/go/vt/vtgate/engine/aggregations.go @@ -89,9 +89,9 @@ func (ap *AggregateParams) String() string { func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type { if ap.OrigOpcode != AggregateUnassigned { - return ap.OrigOpcode.Type(inputType) + return ap.OrigOpcode.SQLType(inputType) } - return ap.Opcode.Type(inputType) + return ap.Opcode.SQLType(inputType) } type aggregator interface { diff --git a/go/vt/vtgate/engine/opcode/constants.go b/go/vt/vtgate/engine/opcode/constants.go index 07a39020f8b..d30c39e2c97 100644 --- a/go/vt/vtgate/engine/opcode/constants.go +++ b/go/vt/vtgate/engine/opcode/constants.go @@ -19,8 +19,10 @@ package opcode import ( "fmt" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) // PulloutOpcode is a number representing the opcode @@ -134,7 +136,7 @@ func (code AggregateOpcode) MarshalJSON() ([]byte, error) { } // Type returns the opcode return sql type, and a bool telling is we are sure about this type or not -func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type { +func (code AggregateOpcode) SQLType(typ querypb.Type) querypb.Type { switch code { case AggregateUnassigned: return sqltypes.Null @@ -159,6 +161,28 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type { } } +func (code AggregateOpcode) Nullable() bool { + switch code { + case AggregateCount, AggregateCountStar: + return false + default: + return true + } +} + +func (code AggregateOpcode) ResolveType(t evalengine.Type, env *collations.Environment) evalengine.Type { + sqltype := code.SQLType(t.Type()) + collation := collations.CollationForType(sqltype, env.DefaultConnectionCharset()) + nullable := code.Nullable() + size := t.Size() + + scale := t.Scale() + if code == AggregateAvg { + scale += 4 + } + return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale) +} + func (code AggregateOpcode) NeedsComparableValues() bool { switch code { case AggregateCountDistinct, AggregateSumDistinct, AggregateMin, AggregateMax: diff --git a/go/vt/vtgate/engine/opcode/constants_test.go b/go/vt/vtgate/engine/opcode/constants_test.go index 50cfc49a71c..3299317ee9c 100644 --- a/go/vt/vtgate/engine/opcode/constants_test.go +++ b/go/vt/vtgate/engine/opcode/constants_test.go @@ -25,6 +25,137 @@ import ( func TestCheckAllAggrOpCodes(t *testing.T) { // This test is just checking that we never reach the panic when using Type() on valid opcodes for i := AggregateOpcode(0); i < _NumOfOpCodes; i++ { - i.Type(sqltypes.Null) + i.SQLType(sqltypes.Null) } } +<<<<<<< HEAD +======= + +func TestType(t *testing.T) { + tt := []struct { + opcode AggregateOpcode + typ querypb.Type + out querypb.Type + }{ + {AggregateUnassigned, sqltypes.VarChar, sqltypes.Null}, + {AggregateGroupConcat, sqltypes.VarChar, sqltypes.Text}, + {AggregateGroupConcat, sqltypes.Blob, sqltypes.Blob}, + {AggregateGroupConcat, sqltypes.Unknown, sqltypes.Unknown}, + {AggregateMax, sqltypes.Int64, sqltypes.Int64}, + {AggregateMax, sqltypes.Float64, sqltypes.Float64}, + {AggregateSumDistinct, sqltypes.Unknown, sqltypes.Unknown}, + {AggregateSumDistinct, sqltypes.Int64, sqltypes.Decimal}, + {AggregateSumDistinct, sqltypes.Decimal, sqltypes.Decimal}, + {AggregateCount, sqltypes.Int32, sqltypes.Int64}, + {AggregateCountStar, sqltypes.Int64, sqltypes.Int64}, + {AggregateGtid, sqltypes.VarChar, sqltypes.VarChar}, + } + + for _, tc := range tt { + t.Run(tc.opcode.String()+"_"+tc.typ.String(), func(t *testing.T) { + out := tc.opcode.SQLType(tc.typ) + assert.Equal(t, tc.out, out) + }) + } +} + +func TestType_Panic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + errMsg, ok := r.(string) + assert.True(t, ok, "Expected a string panic message") + assert.Contains(t, errMsg, "ERROR", "Expected panic message containing 'ERROR'") + } + }() + AggregateOpcode(999).SQLType(sqltypes.VarChar) +} + +func TestNeedsListArg(t *testing.T) { + tt := []struct { + opcode PulloutOpcode + out bool + }{ + {PulloutValue, false}, + {PulloutIn, true}, + {PulloutNotIn, true}, + {PulloutExists, false}, + {PulloutNotExists, false}, + } + + for _, tc := range tt { + t.Run(tc.opcode.String(), func(t *testing.T) { + out := tc.opcode.NeedsListArg() + assert.Equal(t, tc.out, out) + }) + } +} + +func TestPulloutOpcode_MarshalJSON(t *testing.T) { + tt := []struct { + opcode PulloutOpcode + out string + }{ + {PulloutValue, "\"PulloutValue\""}, + {PulloutIn, "\"PulloutIn\""}, + {PulloutNotIn, "\"PulloutNotIn\""}, + {PulloutExists, "\"PulloutExists\""}, + {PulloutNotExists, "\"PulloutNotExists\""}, + } + + for _, tc := range tt { + t.Run(tc.opcode.String(), func(t *testing.T) { + out, err := json.Marshal(tc.opcode) + require.NoError(t, err, "Unexpected error") + assert.Equal(t, tc.out, string(out)) + }) + } +} + +func TestAggregateOpcode_MarshalJSON(t *testing.T) { + tt := []struct { + opcode AggregateOpcode + out string + }{ + {AggregateCount, "\"count\""}, + {AggregateSum, "\"sum\""}, + {AggregateMin, "\"min\""}, + {AggregateMax, "\"max\""}, + {AggregateCountDistinct, "\"count_distinct\""}, + {AggregateSumDistinct, "\"sum_distinct\""}, + {AggregateGtid, "\"vgtid\""}, + {AggregateCountStar, "\"count_star\""}, + {AggregateGroupConcat, "\"group_concat\""}, + {AggregateAnyValue, "\"any_value\""}, + {AggregateAvg, "\"avg\""}, + {999, "\"ERROR\""}, + } + + for _, tc := range tt { + t.Run(tc.opcode.String(), func(t *testing.T) { + out, err := json.Marshal(tc.opcode) + require.NoError(t, err, "Unexpected error") + assert.Equal(t, tc.out, string(out)) + }) + } +} + +func TestNeedsComparableValues(t *testing.T) { + for i := AggregateOpcode(0); i < _NumOfOpCodes; i++ { + if i == AggregateCountDistinct || i == AggregateSumDistinct || i == AggregateMin || i == AggregateMax { + assert.True(t, i.NeedsComparableValues()) + } else { + assert.False(t, i.NeedsComparableValues()) + } + } +} + +func TestIsDistinct(t *testing.T) { + for i := AggregateOpcode(0); i < _NumOfOpCodes; i++ { + if i == AggregateCountDistinct || i == AggregateSumDistinct { + assert.True(t, i.IsDistinct()) + } else { + assert.False(t, i.IsDistinct()) + } + } +} +>>>>>>> cd61d85130 (bugfix: wrong field type returned for SUM (#15192)) diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index 85800187d1f..fb7f2ba87f5 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -158,10 +158,19 @@ func (p *Projection) evalFields(env *evalengine.ExpressionEnv, infields []*query fl |= uint32(querypb.MySqlFlag_NOT_NULL_FLAG) } fields = append(fields, &querypb.Field{ +<<<<<<< HEAD Name: col, Type: q, Charset: uint32(cs), Flags: fl, +======= + Name: col, + Type: typ.Type(), + Charset: uint32(typ.Collation()), + ColumnLength: uint32(typ.Size()), + Decimals: uint32(typ.Scale()), + Flags: fl, +>>>>>>> cd61d85130 (bugfix: wrong field type returned for SUM (#15192)) }) } return fields, nil diff --git a/go/vt/vtgate/evalengine/expr_arithmetic.go b/go/vt/vtgate/evalengine/expr_arithmetic.go index 3622a270e9a..dfa360830e1 100644 --- a/go/vt/vtgate/evalengine/expr_arithmetic.go +++ b/go/vt/vtgate/evalengine/expr_arithmetic.go @@ -158,12 +158,12 @@ func (op *opArithAdd) compile(c *compiler, left, right Expr) (ctype, error) { rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true) lt, rt, swap = c.compileNumericPriority(lt, rt) - var sumtype sqltypes.Type + ct := ctype{Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric} switch lt.Type { case sqltypes.Int64: c.asm.Add_ii() - sumtype = sqltypes.Int64 + ct.Type = sqltypes.Int64 case sqltypes.Uint64: switch rt.Type { case sqltypes.Int64: @@ -171,7 +171,7 @@ func (op *opArithAdd) compile(c *compiler, left, right Expr) (ctype, error) { case sqltypes.Uint64: c.asm.Add_uu() } - sumtype = sqltypes.Uint64 + ct.Type = sqltypes.Uint64 case sqltypes.Decimal: if swap { c.compileToDecimal(rt, 2) @@ -179,7 +179,8 @@ func (op *opArithAdd) compile(c *compiler, left, right Expr) (ctype, error) { c.compileToDecimal(rt, 1) } c.asm.Add_dd() - sumtype = sqltypes.Decimal + ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) case sqltypes.Float64: if swap { c.compileToFloat(rt, 2) @@ -187,11 +188,11 @@ func (op *opArithAdd) compile(c *compiler, left, right Expr) (ctype, error) { c.compileToFloat(rt, 1) } c.asm.Add_ff() - sumtype = sqltypes.Float64 + ct.Type = sqltypes.Float64 } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: sumtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil + return ct, nil } func (op *opArithSub) eval(left, right eval) (eval, error) { @@ -215,66 +216,68 @@ func (op *opArithSub) compile(c *compiler, left, right Expr) (ctype, error) { lt = c.compileToNumeric(lt, 2, sqltypes.Float64, true) rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true) - var subtype sqltypes.Type - + ct := ctype{Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric} switch lt.Type { case sqltypes.Int64: switch rt.Type { case sqltypes.Int64: c.asm.Sub_ii() - subtype = sqltypes.Int64 + ct.Type = sqltypes.Int64 case sqltypes.Uint64: c.asm.Sub_iu() - subtype = sqltypes.Uint64 + ct.Type = sqltypes.Uint64 case sqltypes.Float64: c.compileToFloat(lt, 2) c.asm.Sub_ff() - subtype = sqltypes.Float64 + ct.Type = sqltypes.Float64 case sqltypes.Decimal: c.compileToDecimal(lt, 2) c.asm.Sub_dd() - subtype = sqltypes.Decimal + ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) } case sqltypes.Uint64: switch rt.Type { case sqltypes.Int64: c.asm.Sub_ui() - subtype = sqltypes.Uint64 + ct.Type = sqltypes.Uint64 case sqltypes.Uint64: c.asm.Sub_uu() - subtype = sqltypes.Uint64 + ct.Type = sqltypes.Uint64 case sqltypes.Float64: c.compileToFloat(lt, 2) c.asm.Sub_ff() - subtype = sqltypes.Float64 + ct.Type = sqltypes.Float64 case sqltypes.Decimal: c.compileToDecimal(lt, 2) c.asm.Sub_dd() - subtype = sqltypes.Decimal + ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) } case sqltypes.Float64: c.compileToFloat(rt, 1) c.asm.Sub_ff() - subtype = sqltypes.Float64 + ct.Type = sqltypes.Float64 case sqltypes.Decimal: switch rt.Type { case sqltypes.Float64: c.compileToFloat(lt, 2) c.asm.Sub_ff() - subtype = sqltypes.Float64 + ct.Type = sqltypes.Float64 default: c.compileToDecimal(rt, 1) c.asm.Sub_dd() - subtype = sqltypes.Decimal + ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) } } - if subtype == 0 { + if ct.Type == 0 { panic("did not compile?") } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: subtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil + return ct, nil } func (op *opArithMul) eval(left, right eval) (eval, error) { @@ -301,12 +304,11 @@ func (op *opArithMul) compile(c *compiler, left, right Expr) (ctype, error) { rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true) lt, rt, swap = c.compileNumericPriority(lt, rt) - var multype sqltypes.Type - + ct := ctype{Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric} switch lt.Type { case sqltypes.Int64: c.asm.Mul_ii() - multype = sqltypes.Int64 + ct.Type = sqltypes.Int64 case sqltypes.Uint64: switch rt.Type { case sqltypes.Int64: @@ -314,7 +316,7 @@ func (op *opArithMul) compile(c *compiler, left, right Expr) (ctype, error) { case sqltypes.Uint64: c.asm.Mul_uu() } - multype = sqltypes.Uint64 + ct.Type = sqltypes.Uint64 case sqltypes.Float64: if swap { c.compileToFloat(rt, 2) @@ -322,7 +324,7 @@ func (op *opArithMul) compile(c *compiler, left, right Expr) (ctype, error) { c.compileToFloat(rt, 1) } c.asm.Mul_ff() - multype = sqltypes.Float64 + ct.Type = sqltypes.Float64 case sqltypes.Decimal: if swap { c.compileToDecimal(rt, 2) @@ -330,11 +332,12 @@ func (op *opArithMul) compile(c *compiler, left, right Expr) (ctype, error) { c.compileToDecimal(rt, 1) } c.asm.Mul_dd() - multype = sqltypes.Decimal + ct.Type = sqltypes.Decimal + ct.Scale = lt.Scale + rt.Scale } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: multype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil + return ct, nil } func (op *opArithDiv) eval(left, right eval) (eval, error) { @@ -370,6 +373,7 @@ func (op *opArithDiv) compile(c *compiler, left, right Expr) (ctype, error) { c.compileToDecimal(lt, 2) c.compileToDecimal(rt, 1) c.asm.Div_dd() + ct.Scale = lt.Scale + divPrecisionIncrement } c.asm.jumpDestination(skip1, skip2) return ct, nil @@ -483,7 +487,7 @@ func (op *opArithMod) compile(c *compiler, left, right Expr) (ctype, error) { lt = c.compileToNumeric(lt, 2, sqltypes.Float64, true) rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true) - ct := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable} + ct := ctype{Col: collationNumeric, Flag: flagNullable} switch lt.Type { case sqltypes.Int64: ct.Type = sqltypes.Int64 @@ -498,6 +502,7 @@ func (op *opArithMod) compile(c *compiler, left, right Expr) (ctype, error) { c.asm.Mod_ff() case sqltypes.Decimal: ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) c.asm.Convert_xd(2, 0, 0) c.asm.Mod_dd() } @@ -514,6 +519,7 @@ func (op *opArithMod) compile(c *compiler, left, right Expr) (ctype, error) { c.asm.Mod_ff() case sqltypes.Decimal: ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) c.asm.Convert_xd(2, 0, 0) c.asm.Mod_dd() } diff --git a/go/vt/vtgate/semantics/typer.go b/go/vt/vtgate/semantics/typer.go index 6652f1a476b..53f9df23af9 100644 --- a/go/vt/vtgate/semantics/typer.go +++ b/go/vt/vtgate/semantics/typer.go @@ -18,8 +18,11 @@ package semantics import ( "vitess.io/vitess/go/mysql/collations" +<<<<<<< HEAD "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" +======= +>>>>>>> cd61d85130 (bugfix: wrong field type returned for SUM (#15192)) "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine/opcode" ) @@ -55,6 +58,7 @@ func (t *typer) up(cursor *sqlparser.Cursor) error { if !ok { return nil } +<<<<<<< HEAD var inputType sqltypes.Type if arg := node.GetArg(); arg != nil { t, ok := t.exprTypes[arg] @@ -64,6 +68,15 @@ func (t *typer) up(cursor *sqlparser.Cursor) error { } type_ := code.Type(inputType) t.exprTypes[node] = Type{Type: type_, Collation: collations.DefaultCollationForType(type_)} +======= + var inputType evalengine.Type + if arg := node.GetArg(); arg != nil { + if tt, ok := t.m[arg]; ok { + inputType = tt + } + } + t.m[node] = code.ResolveType(inputType, t.collationEnv) +>>>>>>> cd61d85130 (bugfix: wrong field type returned for SUM (#15192)) } return nil }