Skip to content

Commit

Permalink
Handle encoding binary data separately (#16988)
Browse files Browse the repository at this point in the history
Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink authored Oct 23, 2024
1 parent 17607fa commit b0b7981
Show file tree
Hide file tree
Showing 31 changed files with 5,210 additions and 5,229 deletions.
75 changes: 49 additions & 26 deletions go/mysql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,11 @@ func (c *Conn) parseStmtArgs(data []byte, typ querypb.Type, pos int) (sqltypes.V
}
switch size {
case 0x00:
return sqltypes.NewVarChar(" "), pos, ok
out := []byte("0000-00-00")
if typ != sqltypes.Date {
out = append(out, []byte(" 00:00:00")...)
}
return sqltypes.MakeTrusted(typ, out), pos, ok
case 0x0b:
year, pos, ok := readUint16(data, pos)
if !ok {
Expand Down Expand Up @@ -743,15 +747,22 @@ func (c *Conn) parseStmtArgs(data []byte, typ querypb.Type, pos int) (sqltypes.V
if !ok {
return sqltypes.NULL, 0, false
}
val := strconv.Itoa(int(year)) + "-" +
strconv.Itoa(int(month)) + "-" +
strconv.Itoa(int(day)) + " " +
strconv.Itoa(int(hour)) + ":" +
strconv.Itoa(int(minute)) + ":" +
strconv.Itoa(int(second)) + "." +
fmt.Sprintf("%06d", microSecond)

return sqltypes.NewVarChar(val), pos, ok
val := strconv.AppendInt(nil, int64(year), 10)
val = append(val, '-')
val = strconv.AppendInt(val, int64(month), 10)
val = append(val, '-')
val = strconv.AppendInt(val, int64(day), 10)
if typ != sqltypes.Date {
val = append(val, ' ')
val = strconv.AppendInt(val, int64(hour), 10)
val = append(val, ':')
val = strconv.AppendInt(val, int64(minute), 10)
val = append(val, ':')
val = strconv.AppendInt(val, int64(second), 10)
val = append(val, '.')
val = append(val, fmt.Sprintf("%06d", microSecond)...)
}
return sqltypes.MakeTrusted(typ, val), pos, ok
case 0x07:
year, pos, ok := readUint16(data, pos)
if !ok {
Expand All @@ -777,14 +788,21 @@ func (c *Conn) parseStmtArgs(data []byte, typ querypb.Type, pos int) (sqltypes.V
if !ok {
return sqltypes.NULL, 0, false
}
val := strconv.Itoa(int(year)) + "-" +
strconv.Itoa(int(month)) + "-" +
strconv.Itoa(int(day)) + " " +
strconv.Itoa(int(hour)) + ":" +
strconv.Itoa(int(minute)) + ":" +
strconv.Itoa(int(second))

return sqltypes.NewVarChar(val), pos, ok
val := strconv.AppendInt(nil, int64(year), 10)
val = append(val, '-')
val = strconv.AppendInt(val, int64(month), 10)
val = append(val, '-')
val = strconv.AppendInt(val, int64(day), 10)
if typ != sqltypes.Date {
val = append(val, ' ')
val = strconv.AppendInt(val, int64(hour), 10)
val = append(val, ':')
val = strconv.AppendInt(val, int64(minute), 10)
val = append(val, ':')
val = strconv.AppendInt(val, int64(second), 10)
}

return sqltypes.MakeTrusted(typ, val), pos, ok
case 0x04:
year, pos, ok := readUint16(data, pos)
if !ok {
Expand All @@ -798,11 +816,16 @@ func (c *Conn) parseStmtArgs(data []byte, typ querypb.Type, pos int) (sqltypes.V
if !ok {
return sqltypes.NULL, 0, false
}
val := strconv.Itoa(int(year)) + "-" +
strconv.Itoa(int(month)) + "-" +
strconv.Itoa(int(day))
val := strconv.AppendInt(nil, int64(year), 10)
val = append(val, '-')
val = strconv.AppendInt(val, int64(month), 10)
val = append(val, '-')
val = strconv.AppendInt(val, int64(day), 10)
if typ != sqltypes.Date {
val = append(val, []byte(" 00:00:00")...)
}

return sqltypes.NewVarChar(val), pos, ok
return sqltypes.MakeTrusted(typ, val), pos, ok
default:
return sqltypes.NULL, 0, false
}
Expand All @@ -813,7 +836,7 @@ func (c *Conn) parseStmtArgs(data []byte, typ querypb.Type, pos int) (sqltypes.V
}
switch size {
case 0x00:
return sqltypes.NewVarChar("00:00:00"), pos, ok
return sqltypes.NewTime("00:00:00"), pos, ok
case 0x0c:
isNegative, pos, ok := readByte(data, pos)
if !ok {
Expand Down Expand Up @@ -852,7 +875,7 @@ func (c *Conn) parseStmtArgs(data []byte, typ querypb.Type, pos int) (sqltypes.V
strconv.Itoa(int(second)) + "." +
fmt.Sprintf("%06d", microSecond)

return sqltypes.NewVarChar(val), pos, ok
return sqltypes.NewTime(val), pos, ok
case 0x08:
isNegative, pos, ok := readByte(data, pos)
if !ok {
Expand Down Expand Up @@ -886,14 +909,14 @@ func (c *Conn) parseStmtArgs(data []byte, typ querypb.Type, pos int) (sqltypes.V
strconv.Itoa(int(minute)) + ":" +
strconv.Itoa(int(second))

return sqltypes.NewVarChar(val), pos, ok
return sqltypes.NewTime(val), pos, ok
default:
return sqltypes.NULL, 0, false
}
case sqltypes.Decimal, sqltypes.Text, sqltypes.Blob, sqltypes.VarChar, sqltypes.VarBinary, sqltypes.Year, sqltypes.Char,
sqltypes.Bit, sqltypes.Enum, sqltypes.Set, sqltypes.Geometry, sqltypes.Binary, sqltypes.TypeJSON, sqltypes.Vector:
val, pos, ok := readLenEncStringAsBytesCopy(data, pos)
return sqltypes.MakeTrusted(sqltypes.VarBinary, val), pos, ok
return sqltypes.MakeTrusted(typ, val), pos, ok
default:
return sqltypes.NULL, pos, false
}
Expand Down
45 changes: 22 additions & 23 deletions go/sqltypes/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package sqltypes

import (
"bytes"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
Expand Down Expand Up @@ -441,6 +440,8 @@ func (v Value) EncodeSQL(b BinWriter) {
switch {
case v.Type() == Null:
b.Write(NullBytes)
case v.IsBinary():
encodeBinarySQL(v.val, b)
case v.IsQuoted():
encodeBytesSQL(v.val, b)
case v.Type() == Bit:
Expand All @@ -456,6 +457,8 @@ func (v Value) EncodeSQLStringBuilder(b *strings.Builder) {
switch {
case v.Type() == Null:
b.Write(NullBytes)
case v.IsBinary():
encodeBinarySQLStringBuilder(v.val, b)
case v.IsQuoted():
encodeBytesSQLStringBuilder(v.val, b)
case v.Type() == Bit:
Expand All @@ -482,6 +485,8 @@ func (v Value) EncodeSQLBytes2(b *bytes2.Buffer) {
switch {
case v.Type() == Null:
b.Write(NullBytes)
case v.IsBinary():
encodeBinarySQLBytes2(v.val, b)
case v.IsQuoted():
encodeBytesSQLBytes2(v.val, b)
case v.Type() == Bit:
Expand All @@ -491,18 +496,6 @@ func (v Value) EncodeSQLBytes2(b *bytes2.Buffer) {
}
}

// EncodeASCII encodes the value using 7-bit clean ascii bytes.
func (v Value) EncodeASCII(b BinWriter) {
switch {
case v.Type() == Null:
b.Write(NullBytes)
case v.IsQuoted() || v.Type() == Bit:
encodeBytesASCII(v.val, b)
default:
b.Write(v.val)
}
}

// IsNull returns true if Value is null.
func (v Value) IsNull() bool {
return v.Type() == Null
Expand Down Expand Up @@ -758,6 +751,22 @@ func (v Value) TinyWeight() uint32 {
return v.tinyweight
}

func encodeBinarySQL(val []byte, b BinWriter) {
buf := &bytes2.Buffer{}
encodeBinarySQLBytes2(val, buf)
b.Write(buf.Bytes())
}

func encodeBinarySQLBytes2(val []byte, buf *bytes2.Buffer) {
buf.Write([]byte("_binary"))
encodeBytesSQLBytes2(val, buf)
}

func encodeBinarySQLStringBuilder(val []byte, buf *strings.Builder) {
buf.Write([]byte("_binary"))
encodeBytesSQLStringBuilder(val, buf)
}

func encodeBytesSQL(val []byte, b BinWriter) {
buf := &bytes2.Buffer{}
encodeBytesSQLBytes2(val, buf)
Expand Down Expand Up @@ -838,16 +847,6 @@ func encodeBytesSQLBits(val []byte, b BinWriter) {
fmt.Fprint(b, "'")
}

func encodeBytesASCII(val []byte, b BinWriter) {
buf := &bytes2.Buffer{}
buf.WriteByte('\'')
encoder := base64.NewEncoder(base64.StdEncoding, buf)
encoder.Write(val)
encoder.Close()
buf.WriteByte('\'')
b.Write(buf.Bytes())
}

// SQLEncodeMap specifies how to escape binary data with '\'.
// Complies to https://dev.mysql.com/doc/refman/5.7/en/string-literals.html
// Handling escaping of % and _ is different than other characters.
Expand Down
35 changes: 16 additions & 19 deletions go/sqltypes/value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,34 +367,28 @@ func TestEncode(t *testing.T) {
outSQL string
outASCII string
}{{
in: NULL,
outSQL: "null",
outASCII: "null",
in: NULL,
outSQL: "null",
}, {
in: TestValue(Int64, "1"),
outSQL: "1",
outASCII: "1",
in: TestValue(Int64, "1"),
outSQL: "1",
}, {
in: TestValue(VarChar, "foo"),
outSQL: "'foo'",
outASCII: "'Zm9v'",
in: TestValue(VarChar, "foo"),
outSQL: "'foo'",
}, {
in: TestValue(VarChar, "\x00'\"\b\n\r\t\x1A\\"),
outSQL: "'\\0\\'\"\\b\\n\\r\\t\\Z\\\\'",
outASCII: "'ACciCAoNCRpc'",
in: TestValue(VarChar, "\x00'\"\b\n\r\t\x1A\\"),
outSQL: "'\\0\\'\"\\b\\n\\r\\t\\Z\\\\'",
}, {
in: TestValue(VarBinary, "\x00'\"\b\n\r\t\x1A\\"),
outSQL: "_binary'\\0\\'\"\\b\\n\\r\\t\\Z\\\\'",
}, {
in: TestValue(Bit, "a"),
outSQL: "b'01100001'",
outASCII: "'YQ=='",
in: TestValue(Bit, "a"),
outSQL: "b'01100001'",
}}
for _, tcase := range testcases {
var buf strings.Builder
tcase.in.EncodeSQL(&buf)
assert.Equal(t, tcase.outSQL, buf.String())

buf.Reset()
tcase.in.EncodeASCII(&buf)
assert.Equal(t, tcase.outASCII, buf.String())
}
}

Expand Down Expand Up @@ -639,6 +633,9 @@ func TestEncodeSQLStringBuilder(t *testing.T) {
}, {
in: TestTuple(TestValue(Int64, "1"), TestValue(VarChar, "foo")),
outSQL: "(1, 'foo')",
}, {
in: TestValue(VarBinary, "foo"),
outSQL: "_binary'foo'",
}}
for _, tcase := range testcases {
var buf strings.Builder
Expand Down
3 changes: 1 addition & 2 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -200,7 +199,7 @@ func TestHighNumberOfParams(t *testing.T) {
var vals []any
var params []string
for i := 0; i < paramCount; i++ {
vals = append(vals, strconv.Itoa(i))
vals = append(vals, i)
params = append(params, "?")
}

Expand Down
40 changes: 23 additions & 17 deletions go/test/endtoend/vtgate/queries/vexplain/vexplain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package vexplain

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -75,24 +76,29 @@ func TestVtGateVExplain(t *testing.T) {
`vexplain queries insert into user (id,lookup,lookup_unique) values (4,'apa','foo'),(5,'apa','bar'),(6,'monkey','nobar')`,
"vexplain queries/all will actually run queries")

expected := `[
binaryPrefix := ""
if utils.BinaryIsAtLeastAtVersion(22, "vtgate") {
binaryPrefix = "_binary"
}

expected := fmt.Sprintf(`[
[VARCHAR("ks") VARCHAR("-40") VARCHAR("begin")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("insert into lookup(lookup, id, keyspace_id) values ('apa', 1, '\x16k@\xb4J\xbaK\xd6') on duplicate key update lookup = values(lookup), id = values(id), keyspace_id = values(keyspace_id)")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("insert into lookup(lookup, id, keyspace_id) values ('apa', 1, %s'\x16k@\xb4J\xbaK\xd6') on duplicate key update lookup = values(lookup), id = values(id), keyspace_id = values(keyspace_id)")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("begin")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("insert into lookup(lookup, id, keyspace_id) values ('monkey', 3, 'N\xb1\x90ɢ\xfa\x16\x9c') on duplicate key update lookup = values(lookup), id = values(id), keyspace_id = values(keyspace_id)")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("insert into lookup(lookup, id, keyspace_id) values ('monkey', 3, %s'N\xb1\x90ɢ\xfa\x16\x9c') on duplicate key update lookup = values(lookup), id = values(id), keyspace_id = values(keyspace_id)")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("commit")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("commit")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("begin")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("insert into lookup_unique(lookup_unique, keyspace_id) values ('monkey', 'N\xb1\x90ɢ\xfa\x16\x9c')")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("insert into lookup_unique(lookup_unique, keyspace_id) values ('monkey', %s'N\xb1\x90ɢ\xfa\x16\x9c')")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("begin")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("insert into lookup_unique(lookup_unique, keyspace_id) values ('apa', '\x16k@\xb4J\xbaK\xd6')")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("insert into lookup_unique(lookup_unique, keyspace_id) values ('apa', %s'\x16k@\xb4J\xbaK\xd6')")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("commit")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("commit")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("begin")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("insert into ` + "`user`" + `(id, lookup, lookup_unique) values (3, 'monkey', 'monkey')")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("insert into `+"`user`"+`(id, lookup, lookup_unique) values (3, 'monkey', 'monkey')")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("begin")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("insert into ` + "`user`" + `(id, lookup, lookup_unique) values (1, 'apa', 'apa')")]
]`
[VARCHAR("ks") VARCHAR("-40") VARCHAR("insert into `+"`user`"+`(id, lookup, lookup_unique) values (1, 'apa', 'apa')")]
]`, binaryPrefix, binaryPrefix, binaryPrefix, binaryPrefix)
assertVExplainEquals(t, conn, `vexplain /*vt+ EXECUTE_DML_QUERIES */ queries insert into user (id,lookup,lookup_unique) values (1,'apa','apa'),(3,'monkey','monkey')`, expected)

// Assert that the output of vexplain all doesn't have begin queries because they aren't explainable
Expand All @@ -109,27 +115,27 @@ func TestVtGateVExplain(t *testing.T) {

// transaction explicitly started to no commit in the end.
utils.Exec(t, conn, "begin")
expected = `[
expected = fmt.Sprintf(`[
[VARCHAR("ks") VARCHAR("-40") VARCHAR("begin")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("insert into lookup(lookup, id, keyspace_id) values ('apa', 4, '\xd2\xfd\x88g\xd5\\r-\xfe'), ('apa', 5, 'p\xbb\x02<\x81\f\xa8z') on duplicate key update lookup = values(lookup), id = values(id), keyspace_id = values(keyspace_id)")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("insert into lookup(lookup, id, keyspace_id) values ('apa', 4, %s'\xd2\xfd\x88g\xd5\\r-\xfe'), ('apa', 5, %s'p\xbb\x02<\x81\f\xa8z') on duplicate key update lookup = values(lookup), id = values(id), keyspace_id = values(keyspace_id)")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("begin")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("insert into lookup(lookup, id, keyspace_id) values ('monkey', 6, '\xf0\x98H\\n\xc4ľq') on duplicate key update lookup = values(lookup), id = values(id), keyspace_id = values(keyspace_id)")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("insert into lookup(lookup, id, keyspace_id) values ('monkey', 6, %s'\xf0\x98H\\n\xc4ľq') on duplicate key update lookup = values(lookup), id = values(id), keyspace_id = values(keyspace_id)")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("commit")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("commit")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("begin")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("insert into lookup_unique(lookup_unique, keyspace_id) values ('foo', '\xd2\xfd\x88g\xd5\\r-\xfe')")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("insert into lookup_unique(lookup_unique, keyspace_id) values ('foo', %s'\xd2\xfd\x88g\xd5\\r-\xfe')")]
[VARCHAR("ks") VARCHAR("80-c0") VARCHAR("begin")]
[VARCHAR("ks") VARCHAR("80-c0") VARCHAR("insert into lookup_unique(lookup_unique, keyspace_id) values ('bar', 'p\xbb\x02<\x81\f\xa8z')")]
[VARCHAR("ks") VARCHAR("80-c0") VARCHAR("insert into lookup_unique(lookup_unique, keyspace_id) values ('bar', %s'p\xbb\x02<\x81\f\xa8z')")]
[VARCHAR("ks") VARCHAR("c0-") VARCHAR("begin")]
[VARCHAR("ks") VARCHAR("c0-") VARCHAR("insert into lookup_unique(lookup_unique, keyspace_id) values ('nobar', '\xf0\x98H\\n\xc4ľq')")]
[VARCHAR("ks") VARCHAR("c0-") VARCHAR("insert into lookup_unique(lookup_unique, keyspace_id) values ('nobar', %s'\xf0\x98H\\n\xc4ľq')")]
[VARCHAR("ks") VARCHAR("-40") VARCHAR("commit")]
[VARCHAR("ks") VARCHAR("80-c0") VARCHAR("commit")]
[VARCHAR("ks") VARCHAR("c0-") VARCHAR("commit")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("begin")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("insert into ` + "`user`" + `(id, lookup, lookup_unique) values (5, 'apa', 'bar')")]
[VARCHAR("ks") VARCHAR("40-80") VARCHAR("insert into `+"`user`"+`(id, lookup, lookup_unique) values (5, 'apa', 'bar')")]
[VARCHAR("ks") VARCHAR("c0-") VARCHAR("begin")]
[VARCHAR("ks") VARCHAR("c0-") VARCHAR("insert into ` + "`user`" + `(id, lookup, lookup_unique) values (4, 'apa', 'foo'), (6, 'monkey', 'nobar')")]
]`
[VARCHAR("ks") VARCHAR("c0-") VARCHAR("insert into `+"`user`"+`(id, lookup, lookup_unique) values (4, 'apa', 'foo'), (6, 'monkey', 'nobar')")]
]`, binaryPrefix, binaryPrefix, binaryPrefix, binaryPrefix, binaryPrefix, binaryPrefix)
assertVExplainEquals(t, conn, `vexplain /*vt+ EXECUTE_DML_QUERIES */ queries insert into user (id,lookup,lookup_unique) values (4,'apa','foo'),(5,'apa','bar'),(6,'monkey','nobar')`, expected)

utils.Exec(t, conn, "rollback")
Expand Down
2 changes: 1 addition & 1 deletion go/vt/sqlparser/ast_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,7 @@ func (node *AssignmentExpr) Format(buf *TrackedBuffer) {
func (node *Literal) Format(buf *TrackedBuffer) {
switch node.Type {
case StrVal:
sqltypes.MakeTrusted(sqltypes.VarBinary, node.Bytes()).EncodeSQL(buf)
sqltypes.MakeTrusted(sqltypes.VarChar, node.Bytes()).EncodeSQL(buf)
case IntVal, FloatVal, DecimalVal, HexNum, BitNum:
buf.astPrintf(node, "%#s", node.Val)
case HexVal:
Expand Down
Loading

0 comments on commit b0b7981

Please sign in to comment.