Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse enum/set values with sqlparser #17133

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 41 additions & 74 deletions go/vt/schema/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@ limitations under the License.
package schema

import (
"fmt"
"regexp"
"strings"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtenv"
"vitess.io/vitess/go/vt/vterrors"
)

// NormalizedDDLQuery contains a query which is online-ddl -normalized
Expand Down Expand Up @@ -49,9 +53,6 @@ var (
// ALTER TABLE tbl something
regexp.MustCompile(alterTableBasicPattern + `([\S]+)\s+(.*$)`),
}

enumValuesRegexp = regexp.MustCompile("(?i)^enum[(](.*)[)]$")
setValuesRegexp = regexp.MustCompile("(?i)^set[(](.*)[)]$")
)

// ParseAlterTableOptions parses a ALTER ... TABLE... statement into:
Expand All @@ -77,91 +78,57 @@ func ParseAlterTableOptions(alterStatement string) (explicitSchema, explicitTabl
return explicitSchema, explicitTable, alterOptions
}

// ParseEnumValues parses the comma delimited part of an enum column definition
func ParseEnumValues(enumColumnType string) string {
if submatch := enumValuesRegexp.FindStringSubmatch(enumColumnType); len(submatch) > 0 {
return submatch[1]
}
return enumColumnType
}

// ParseSetValues parses the comma delimited part of a set column definition
func ParseSetValues(setColumnType string) string {
if submatch := setValuesRegexp.FindStringSubmatch(setColumnType); len(submatch) > 0 {
return submatch[1]
}
return setColumnType
}

// parseEnumOrSetTokens parses the comma delimited part of an enum/set column definition and
// returns the (unquoted) text values
// Expected input: `'x-small','small','medium','large','x-large'`
// Unexpected input: `enum('x-small','small','medium','large','x-large')`
func parseEnumOrSetTokens(enumOrSetValues string) []string {
// We need to track both the start of the current value and current
// position, since there might be quoted quotes inside the value
// which we need to handle.
start := 0
pos := 1
var tokens []string
for {
// If the input does not start with a quote, it's not a valid enum/set definition
if enumOrSetValues[start] != '\'' {
return nil
}
i := strings.IndexByte(enumOrSetValues[pos:], '\'')
// If there's no closing quote, we have invalid input
if i < 0 {
return nil
}
// We're at the end here of the last quoted value,
// so we add the last token and return them.
if i == len(enumOrSetValues[pos:])-1 {
tok, err := sqltypes.DecodeStringSQL(enumOrSetValues[start:])
if err != nil {
return nil
}
tokens = append(tokens, tok)
return tokens
}
// MySQL double quotes things as escape value, so if we see another
// single quote, we skip the character and remove it from the input.
if enumOrSetValues[pos+i+1] == '\'' {
pos = pos + i + 2
continue
}
// Next value needs to be a comma as a separator, otherwise
// the data is invalid so we return nil.
if enumOrSetValues[pos+i+1] != ',' {
return nil
}
// If we're at the end of the input here, it's invalid
// since we have a trailing comma which is not what MySQL
// returns.
if pos+i+1 == len(enumOrSetValues) {
return nil
}

tok, err := sqltypes.DecodeStringSQL(enumOrSetValues[start : pos+i+1])
func parseEnumOrSetTokens(env *vtenv.Environment, enumOrSetValues string) ([]string, error) {
// sqlparser cannot directly parse enum/set values, so we create a dummy query to parse it.
dummyQuery := fmt.Sprintf("alter table t add column e enum(%s)", enumOrSetValues)
ddlStmt, err := env.Parser().ParseStrictDDL(dummyQuery)
if err != nil {
return nil, err
}
unexpectedError := func() error {
return vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected error parsing enum values: %v", enumOrSetValues)
}
alterTable, ok := ddlStmt.(*sqlparser.AlterTable)
if !ok {
return nil, unexpectedError()
}
if len(alterTable.AlterOptions) != 1 {
return nil, unexpectedError()
}
addColumn, ok := alterTable.AlterOptions[0].(*sqlparser.AddColumns)
if !ok {
return nil, unexpectedError()
}
if len(addColumn.Columns) != 1 {
return nil, unexpectedError()
}
enumValues := addColumn.Columns[0].Type.EnumValues
decodedEnumValues := make([]string, len(enumValues))
for i := range enumValues {
val, err := sqltypes.DecodeStringSQL(enumValues[i])
if err != nil {
return nil
return nil, err
}

tokens = append(tokens, tok)
// We add 2 to the position to skip the closing quote & comma
start = pos + i + 2
pos = start + 1
decodedEnumValues[i] = val
}
return decodedEnumValues, nil
}

// ParseEnumOrSetTokensMap parses the comma delimited part of an enum column definition
// and returns a map where [1] is the first token, and [<n>] is the last.
func ParseEnumOrSetTokensMap(enumOrSetValues string) map[int]string {
tokens := parseEnumOrSetTokens(enumOrSetValues)
func ParseEnumOrSetTokensMap(env *vtenv.Environment, enumOrSetValues string) (map[int]string, error) {
tokens, err := parseEnumOrSetTokens(env, enumOrSetValues)
if err != nil {
return nil, err
}
tokensMap := map[int]string{}
for i, token := range tokens {
// SET and ENUM values are 1 indexed.
tokensMap[i+1] = token
}
return tokensMap
return tokensMap, nil
}
109 changes: 34 additions & 75 deletions go/vt/schema/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/vt/vtenv"
)

func TestParseAlterTableOptions(t *testing.T) {
Expand Down Expand Up @@ -48,108 +51,50 @@ func TestParseAlterTableOptions(t *testing.T) {
}
}

func TestParseEnumValues(t *testing.T) {
{
inputs := []string{
`enum('x-small','small','medium','large','x-large')`,
`ENUM('x-small','small','medium','large','x-large')`,
`'x-small','small','medium','large','x-large'`,
}
for _, input := range inputs {
enumValues := ParseEnumValues(input)
assert.Equal(t, `'x-small','small','medium','large','x-large'`, enumValues)
}
}
{
inputs := []string{
``,
`abc`,
`func('x-small','small','medium','large','x-large')`,
`set('x-small','small','medium','large','x-large')`,
}
for _, input := range inputs {
enumValues := ParseEnumValues(input)
assert.Equal(t, input, enumValues)
}
}

{
inputs := []string{
``,
`abc`,
`func('x small','small','medium','large','x large')`,
`set('x small','small','medium','large','x large')`,
}
for _, input := range inputs {
enumValues := ParseEnumValues(input)
assert.Equal(t, input, enumValues)
}
}
}

func TestParseSetValues(t *testing.T) {
{
inputs := []string{
`set('x-small','small','medium','large','x-large')`,
`SET('x-small','small','medium','large','x-large')`,
`'x-small','small','medium','large','x-large'`,
}
for _, input := range inputs {
setValues := ParseSetValues(input)
assert.Equal(t, `'x-small','small','medium','large','x-large'`, setValues)
}
}
{
inputs := []string{
``,
`abc`,
`func('x-small','small','medium','large','x-large')`,
`enum('x-small','small','medium','large','x-large')`,
`ENUM('x-small','small','medium','large','x-large')`,
}
for _, input := range inputs {
setValues := ParseSetValues(input)
assert.Equal(t, input, setValues)
}
}
}

func TestParseEnumTokens(t *testing.T) {
env := vtenv.NewTestEnv()
{
input := `'x-small','small','medium','large','x-large'`
enumTokens := parseEnumOrSetTokens(input)
enumTokens, err := parseEnumOrSetTokens(env, input)
require.NoError(t, err)
expect := []string{"x-small", "small", "medium", "large", "x-large"}
assert.Equal(t, expect, enumTokens)
}
{
input := `'x small','small','medium','large','x large'`
enumTokens := parseEnumOrSetTokens(input)
enumTokens, err := parseEnumOrSetTokens(env, input)
require.NoError(t, err)
expect := []string{"x small", "small", "medium", "large", "x large"}
assert.Equal(t, expect, enumTokens)
}
{
input := `'with '' quote','and \n newline'`
enumTokens := parseEnumOrSetTokens(input)
enumTokens, err := parseEnumOrSetTokens(env, input)
require.NoError(t, err)
expect := []string{"with ' quote", "and \n newline"}
assert.Equal(t, expect, enumTokens)
}
{
input := `enum('x-small','small','medium','large','x-large')`
enumTokens := parseEnumOrSetTokens(input)
enumTokens, err := parseEnumOrSetTokens(env, input)
assert.Error(t, err)
assert.Nil(t, enumTokens)
}
{
input := `set('x-small','small','medium','large','x-large')`
enumTokens := parseEnumOrSetTokens(input)
enumTokens, err := parseEnumOrSetTokens(env, input)
assert.Error(t, err)
assert.Nil(t, enumTokens)
}
}

func TestParseEnumTokensMap(t *testing.T) {
env := vtenv.NewTestEnv()
{
input := `'x-small','small','medium','large','x-large'`

enumTokensMap := ParseEnumOrSetTokensMap(input)
enumTokensMap, err := ParseEnumOrSetTokensMap(env, input)
require.NoError(t, err)
expect := map[int]string{
1: "x-small",
2: "small",
Expand All @@ -165,9 +110,23 @@ func TestParseEnumTokensMap(t *testing.T) {
`set('x-small','small','medium','large','x-large')`,
}
for _, input := range inputs {
enumTokensMap := ParseEnumOrSetTokensMap(input)
expect := map[int]string{}
assert.Equal(t, expect, enumTokensMap)
enumTokensMap, err := ParseEnumOrSetTokensMap(env, input)
assert.Error(t, err)
assert.Nil(t, enumTokensMap)
}
}
{
input := `'x-small','small','med''ium','large','x-large'`

enumTokensMap, err := ParseEnumOrSetTokensMap(env, input)
require.NoError(t, err)
expect := map[int]string{
1: "x-small",
2: "small",
3: "med'ium",
4: "large",
5: "x-large",
}
assert.Equal(t, expect, enumTokensMap)
}
}
Loading
Loading