diff --git a/engine.go b/engine.go index e19bb7342b..df8cfa258c 100644 --- a/engine.go +++ b/engine.go @@ -34,6 +34,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/expression/function" "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/planbuilder" + "github.com/dolthub/go-mysql-server/sql/rdparser" "github.com/dolthub/go-mysql-server/sql/rowexec" "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/go-mysql-server/sql/types" @@ -192,7 +193,7 @@ func New(a *analyzer.Analyzer, cfg *Config) *Engine { PreparedDataCache: NewPreparedDataCache(), mu: &sync.Mutex{}, EventScheduler: nil, - Parser: sql.NewMysqlParser(), + Parser: rdparser.NewParser(), } ret.ReadOnly.Store(cfg.IsReadOnly) return ret diff --git a/sql/rdparser/insert.go b/sql/rdparser/insert.go new file mode 100644 index 0000000000..7f0c35502d --- /dev/null +++ b/sql/rdparser/insert.go @@ -0,0 +1,176 @@ +package rdparser + +import ( + "context" + + ast "github.com/dolthub/vitess/go/vt/sqlparser" +) + +func (p *parser) insert(ctx context.Context) (ast.Statement, bool) { + id, cur := p.tok.Scan() + ins := new(ast.Insert) + ins.Action = ast.InsertStr + if id == ast.INTO { + id, cur = p.tok.Scan() + } else if id != ast.ID { + return nil, false + } + + p.push(id, cur) + + var ok bool + ins.Table, ok = p.tableIdent() + if !ok { + return nil, false + } + + // optional () + id, cur = p.pop() + if id == '(' { + ins.Columns, ok = p.columnList(ctx) + if !ok { + return nil, false + } + id, cur = p.pop() + if id != ')' { + return nil, false + } + } + + // VALUES or SELECT + if id != ast.VALUES { + return nil, false + } + + ins.Rows, ok = p.valueList(ctx) + if !ok { + return nil, false + } + return ins, true +} + +func (p *parser) push(id int, cur []byte) { + p.curId, p.cur, p.curOk = id, cur, true +} + +func (p *parser) pop() (int, []byte) { + if p.curOk { + p.curOk = false + return p.curId, p.cur + } else { + return p.tok.Scan() + } +} + +func (p *parser) tableIdent() (ast.TableName, bool) { + // schema.database.table + + id, firstTok := p.pop() + if id != ast.ID { + return ast.TableName{}, false + } + + id, tok := p.pop() + if id != '.' { + p.push(id, tok) + return ast.TableName{Name: ast.NewTableIdent(string(firstTok))}, true + } + + id, secondTok := p.tok.Scan() + if id != ast.ID { + p.push(id, tok) + return ast.TableName{ + Name: ast.NewTableIdent(string(firstTok)), + }, true + } + + id, tok = p.pop() + if id != '.' { + p.push(id, tok) + return ast.TableName{ + DbQualifier: ast.NewTableIdent(string(firstTok)), + Name: ast.NewTableIdent(string(secondTok)), + }, true + } + + id, thirdTok := p.tok.Scan() + if id != ast.ID { + p.push(id, tok) + return ast.TableName{ + SchemaQualifier: ast.NewTableIdent(string(firstTok)), + DbQualifier: ast.NewTableIdent(string(secondTok)), + }, true + } + + return ast.TableName{ + SchemaQualifier: ast.NewTableIdent(string(firstTok)), + DbQualifier: ast.NewTableIdent(string(secondTok)), + Name: ast.NewTableIdent(string(thirdTok)), + }, true +} + +func (p *parser) columnList(ctx context.Context) (ast.Columns, bool) { + // id, ... + var cols ast.Columns + id, tok := p.pop() + for { + if id != ast.ID { + break + } + cols = append(cols, ast.NewColIdent(string(tok))) + id, tok = p.tok.Scan() + if id != ',' { + break + } + } + p.push(id, tok) + return cols, true +} + +func (p *parser) valueList(ctx context.Context) (ast.InsertRows, bool) { + var rows ast.Values + id, tok := p.pop() + for { + if id != '(' { + break + } + var row ast.ValTuple + for { + id, tok = p.pop() + if id == ',' { + id, tok = p.pop() + } + if id == ')' { + break + } + value, ok := p.value(id, tok) + if !ok { + return nil, false + } + row = append(row, value) + } + rows = append(rows, row) + id, tok = p.tok.Scan() + if id != ',' { + break + } + id, tok = p.tok.Scan() + } + p.push(id, tok) + return &ast.AliasedValues{Values: rows}, true +} + +func (p *parser) value(id int, tok []byte) (ast.Expr, bool) { + switch id { + case ast.STRING: + return ast.NewStrVal(tok), true + case ast.INTEGRAL: + return ast.NewIntVal(tok), true + case ast.FLOAT: + return ast.NewFloatVal(tok), true + case ast.NULL: + return ast.NewStrVal(tok), true + default: + return nil, false + } +} diff --git a/sql/rdparser/parser.go b/sql/rdparser/parser.go new file mode 100644 index 0000000000..2f4e62390a --- /dev/null +++ b/sql/rdparser/parser.go @@ -0,0 +1,100 @@ +package rdparser + +import ( + "context" + "fmt" + + ast "github.com/dolthub/vitess/go/vt/sqlparser" + + "github.com/dolthub/go-mysql-server/sql" +) + +type parser struct { + tok *ast.Tokenizer + curOk bool + curId int + cur []byte + peekOk bool + peekId int + _peek []byte +} + +func NewParser() sql.Parser { + return &parser{} +} + +func (p *parser) parse(ctx context.Context, s string, options ast.ParserOptions) (ret ast.Statement, err error) { + defer func() { + if mes := recover(); mes != nil { + _, ok := mes.(parseErr) + if !ok { + err = fmt.Errorf("panic encountered while parsing: %s", mes) + return + } + ret, err = ast.ParseWithOptions(ctx, s, options) + } + }() + // get next token + p.tok = ast.NewStringTokenizer(s) + if options.AnsiQuotes { + p.tok = ast.NewStringTokenizerForAnsiQuotes(s) + } + + if prePlan, ok := p.statement(ctx); ok { + return prePlan, nil + } + + return ast.ParseWithOptions(ctx, s, options) +} + +type parseErr struct { + str string +} + +func (p *parser) fail(s string) { + panic(parseErr{s}) +} + +func (p *parser) next() (int, []byte) { + if p.peekOk { + p.peekOk = false + p.curId, p.cur = p.peekId, p._peek + } + p.curOk = true + p.curId, p.cur = p.tok.Scan() + return p.curId, p.cur +} + +func (p *parser) peek() (int, []byte) { + if !p.peekOk { + p.peekOk = true + p.peekId, p._peek = p.tok.Scan() + } + return p.peekId, p._peek +} + +var _ sql.Parser = (*parser)(nil) + +func (p *parser) ParseSimple(query string) (ast.Statement, error) { + return p.parse(context.Background(), query, ast.ParserOptions{}) +} + +func (p *parser) Parse(ctx *sql.Context, query string, multi bool) (ast.Statement, string, string, error) { + return p.ParseWithOptions(ctx, query, ';', multi, ast.ParserOptions{}) +} + +func (p *parser) ParseWithOptions(ctx context.Context, query string, delimiter rune, multi bool, options ast.ParserOptions) (ast.Statement, string, string, error) { + stmt, err := p.parse(context.Background(), query, options) + if err != nil { + return nil, "", "", nil + } + return stmt, "", "", nil +} + +func (p *parser) ParseOneWithOptions(ctx context.Context, s string, options ast.ParserOptions) (ast.Statement, int, error) { + ast, err := p.parse(ctx, s, options) + if err != nil { + return nil, 0, err + } + return ast, 0, nil +} diff --git a/sql/rdparser/parser_test.go b/sql/rdparser/parser_test.go new file mode 100644 index 0000000000..9c234dd637 --- /dev/null +++ b/sql/rdparser/parser_test.go @@ -0,0 +1,47 @@ +package rdparser + +import ( + "context" + "testing" + + ast "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/stretchr/testify/require" +) + +func TestParser(t *testing.T) { + tests := []struct { + q string + exp ast.Statement + ok bool + }{ + { + q: "insert into xy values (0,'0', .0), (1,'1', 1.0)", + ok: true, + }, + { + q: "insert into xy (x,y,z) values (0,'0', .0), (1,'1', 1.0)", + ok: true, + }, + { + q: "insert into db.xy values (0,'0', .0), (1,'1', 1.0)", + ok: true, + }, + { + q: "select * from xy where x = 1", + ok: true, + }, + { + q: "select id from sbtest1 where id = 1000", + ok: true, + }, + } + for _, tt := range tests { + t.Run(tt.q, func(t *testing.T) { + p := new(parser) + p.tok = ast.NewStringTokenizer(tt.q) + res, ok := p.statement(context.Background()) + require.Equal(t, tt.ok, ok) + require.Equal(t, tt.exp, res) + }) + } +} diff --git a/sql/rdparser/select.go b/sql/rdparser/select.go new file mode 100644 index 0000000000..3366a362ef --- /dev/null +++ b/sql/rdparser/select.go @@ -0,0 +1,480 @@ +package rdparser + +import ( + ast "github.com/dolthub/vitess/go/vt/sqlparser" +) + +func (p *parser) sel() (ast.Statement, bool) { + // SELECT FROM WHERE + sel := new(ast.Select) + var ok bool + sel.SelectExprs, ok = p.selExprs() + if !ok { + return nil, false + } + + tab, ok := p.tableIdent() + if !ok { + return nil, false + } + sel.From = []ast.TableExpr{&ast.AliasedTableExpr{Expr: tab}} + + sel.Where, ok = p.whereOpt() + if !ok { + return nil, false + } + return sel, true +} + +func (p *parser) selExprs() (ast.SelectExprs, bool) { + var exprs ast.SelectExprs + id, tok := p.tok.Scan() + for { + if id == ast.FROM { + break + } + // literal + var expr ast.SelectExpr + var toAlias ast.Expr + var ok bool + switch id { + case ast.ID: + p.push(id, tok) + toAlias, ok = p.columnName() + if !ok { + return nil, false + } + case '*': + expr = &ast.StarExpr{} + case ast.STRING, ast.INTEGRAL, ast.FLOAT, ast.NULL: + toAlias, ok = p.value(id, tok) + if !ok { + return nil, false + } + default: + return nil, false + } + if toAlias != nil { + expr = &ast.AliasedExpr{Expr: toAlias} + id, tok = p.pop() + if id == ast.AS { + expr = &ast.AliasedExpr{As: ast.NewColIdent(string(tok)), Expr: toAlias} + } + } + exprs = append(exprs, expr) + } + return exprs, true +} + +func (p *parser) columnName() (*ast.ColName, bool) { + id, firstTok := p.pop() + if id != ast.ID { + return nil, false + } + + id, tok := p.pop() + if id != '.' { + p.push(id, tok) + return &ast.ColName{Name: ast.NewColIdent(string(firstTok))}, true + } + + id, secondTok := p.tok.Scan() + if id != ast.ID { + return nil, false + } + + return &ast.ColName{ + Qualifier: ast.TableName{ + Name: ast.NewTableIdent(string(firstTok)), + }, + Name: ast.NewColIdent(string(secondTok)), + }, true + + //id, tok = p.pop() + //if id == ast.AS { + // return &ast.AliasedExpr{ + // As: ast.NewColIdent(string(tok)), + // Expr: &ast.ColName{ + // Qualifier: ast.TableName{ + // Name: ast.NewTableIdent(string(firstTok)), + // }, + // Name: ast.NewColIdent(string(secondTok)), + // }, + // }, true + //} + // + //p.push(id, tok) + //return &ast.AliasedExpr{ + // Expr: &ast.ColName{ + // Qualifier: ast.TableName{ + // Name: ast.NewTableIdent(string(firstTok)), + // }, + // Name: ast.NewColIdent(string(secondTok)), + // }, + //}, true + // + // db, schema, ... +} + +func (p *parser) whereOpt() (*ast.Where, bool) { + id, tok := p.pop() + if id != ast.WHERE { + p.push(id, tok) + return nil, true + } + ret := new(ast.Where) + ret.Type = ast.WhereStr + var ok bool + ret.Expr, ok = p.expression() + if !ok { + return nil, false + } + return ret, true +} + +func (p *parser) expression() (ret ast.Expr, ok bool) { + // condition + // NOT + // DEFAULT + // valueExpression + id, _ := p.peek() + if id == ast.NOT { + p.next() + c, ok := p.expression() + if !ok { + return nil, false + } + ret = &ast.NotExpr{Expr: c} + } else if id == ast.DEFAULT { + p.next() + var d string + id, _ = p.peek() + if id == '(' { + p.next() + if ident, ok := p.id(); ok { + id, _ = p.peek() + if id != ')' { + p.next() + ret = &ast.Default{ColName: ident} + } else { + p.fail("invalid default expression") + } + } + } + ret = &ast.Default{ColName: d} + } else if ret, ok = p.condition(); ok { + return ret, true + } else if ret, ok = p.valueExpression(); ok { + return ret, true + } else { + return nil, false + } + + id, _ = p.peek() + // AND OR XOR IS + switch id { + case ast.AND: + p.next() + right, ok := p.expression() + if ok { + ret = &ast.AndExpr{Left: ret, Right: right} + } + case ast.OR: + p.next() + if right, ok := p.expression(); ok { + ret = &ast.OrExpr{Left: ret, Right: right} + } + case ast.XOR: + p.next() + if right, ok := p.expression(); ok { + ret = &ast.XorExpr{Left: ret, Right: right} + } + case ast.IS: + p.next() + if is := p.isSuffix(); is != "" { + ret = &ast.IsExpr{Operator: is, Expr: ret} + } + } + return ret, true +} + +func (p *parser) id() (string, bool) { + id, tok := p.peek() + if id == ast.ID { + p.next() + return string(tok), true + } + return "", false +} + +func (p *parser) isSuffix() string { + // (NOT) NULL|TRUE|FALSE + id, _ := p.peek() + var not bool + if id == ast.NOT { + p.next() + not = true + } + id, _ = p.peek() + switch id { + case ast.TRUE: + if not { + return ast.IsNotTrueStr + } else { + return ast.IsTrueStr + } + case ast.FALSE: + if not { + return ast.IsNotFalseStr + } else { + return ast.IsFalseStr + } + case ast.NULL: + if not { + return ast.IsNotNullStr + } else { + return ast.IsNullStr + } + default: + return "" + } +} + +func (p *parser) condition() (ret ast.Expr, ok bool) { + id, _ := p.peek() + if id == ast.EXISTS { + if subq, ok := p.subquery(); ok { + return &ast.ExistsExpr{Subquery: subq}, true + } + p.fail("expected subquery after EXISTS") + } + + left, ok := p.valueExpression() + if !ok { + return nil, false + } + + id, _ = p.peek() + var not bool + if id == ast.NOT { + p.next() + id, _ = p.peek() + not = true + } + + if comp := p.compare(); comp != "" { + right, ok := p.valueExpression() + if !ok { + return nil, false + } + ret = &ast.ComparisonExpr{Operator: comp, Left: left, Right: right} + } else { + switch id { + case ast.IN: + right, ok := p.colTuple() + if !ok { + return nil, false + } + ret = &ast.ComparisonExpr{Operator: ast.InStr, Left: left, Right: right} + case ast.LIKE: + right, ok := p.valueExpression() + if !ok { + return nil, false + } + esc := p.likeEscapeOpt() + ret = &ast.ComparisonExpr{Operator: ast.LikeStr, Left: left, Right: right, Escape: esc} + case ast.REGEXP: + right, ok := p.valueExpression() + if !ok { + return nil, false + } + ret = &ast.ComparisonExpr{Operator: ast.RegexpStr, Left: left, Right: right} + case ast.BETWEEN: + from, ok := p.valueExpression() + if !ok { + return nil, false + } + id, _ = p.next() + if id != ast.AND { + p.fail("between expected AND") + } + to, ok := p.valueExpression() + if !ok { + return nil, false + } + ret = &ast.RangeCond{Operator: ast.BetweenStr, Left: left, From: from, To: to} + } + } + + if not { + ret = &ast.NotExpr{Expr: ret} + } + return ret, true +} + +func (p *parser) valueExpression() (ret ast.Expr, ok bool) { + // value + // ACCOUNT, FORMAT + // boolean_value + // column_name + // column_name_safe_keyword + // tuple_expression + // subquery + // BINARY + id, tok := p.peek() + + if id == ast.FORMAT || id == ast.ACCOUNT { + return &ast.ColName{Name: ast.NewColIdent(string(tok))}, true + } else if v, ok := p.value(id, tok); ok { + return v, true + } else if id == ast.TRUE || id == ast.FALSE { + return ast.BoolVal(id == ast.TRUE), true + } else if col, ok := p.columnName(); ok { + p.next() + id, _ = p.peek() + switch id { + case ast.JSON_EXTRACT_OP: + p.next() + id, tok := p.next() + val, ok := p.value(id, tok) + if !ok { + return nil, false + } + return &ast.BinaryExpr{Operator: ast.JSONExtractOp, Left: col, Right: val}, true + case ast.JSON_UNQUOTE_EXTRACT_OP: + p.next() + id, tok := p.next() + val, ok := p.value(id, tok) + if !ok { + return nil, false + } + return &ast.BinaryExpr{Operator: ast.JSONUnquoteExtractOp, Left: col, Right: val}, true + default: + return col, true + } + return col, true + } else if col, ok := p.columnNameSafeKeyWord(); ok { + return col, true + } else if tup, ok := p.tupleExpression(); ok { + return tup, true + } else if subq, ok := p.subquery(); ok { + return subq, true + } + + // underscore_charsets valueExpr UNARY + // + valueExpr UNARY + // - valueExpr UNARY + // ! valueExpr UNARY + // ~ valueExpr + // INTERVAL value_expression sql_id + + // function_call_generic + // function_call_keyword + // function_call_nonkeyword + // function_call_conflict + // function_call_window + // function_call_aggregate_with_window + + // valueExpr op valueExpr + left, ok := p.valueExpression() + if ok { + return nil, false + } + switch id { + case '+': + case '-': + case '*': + case '/': + case '^': + case '&': + case '|': + case ast.DIV: + case '%': + case ast.MOD: + case ast.SHIFT_LEFT: + case ast.SHIFT_RIGHT: + case ast.COLLATE: + } +} + +func (p *parser) compare() (ret string) { + id, _ := p.peek() + switch id { + case '=': + ret = ast.EqualStr + case '<': + ret = ast.LessThanStr + case '>': + ret = ast.GreaterThanStr + case ast.LE: + ret = ast.LessEqualStr + case ast.GE: + ret = ast.GreaterEqualStr + case ast.NE: + ret = ast.NotEqualStr + case ast.NULL_SAFE_EQUAL: + ret = ast.NullSafeEqualStr + default: + return "" + } + p.next() + return ret +} + +func (p *parser) colTuple() (ast.Expr, bool) { + id, tok := p.peek() + if id == ast.LIST_ARG { + return ast.ListArg(tok), true + } else if id != '(' { + return nil, false + } + p.next() + id, _ = p.peek() + + if id == ast.SELECT { + selStmt, ok := p.selNoInto() + if !ok { + return nil, false + } + p.next() + id, _ = p.next() + if id != ')' { + p.fail("expected subquery to end with ')'") + } + p.next() + return &ast.Subquery{Select: selStmt}, true + } + + //expr list + var tup ast.ValTuple + for { + e, ok := p.expression() + if !ok { + return nil, false + } + tup = append(tup, e) + id, _ = p.next() + if id == ')' { + break + } else if id != ',' { + p.fail("invalid expression list") + } + } + return tup, true +} + +func (p *parser) selNoInto() (*ast.Select, bool) { +} + +func (p *parser) likeEscapeOpt() ast.Expr { + id, _ := p.peek() + if id != ast.ESCAPE { + return nil + } + p.next() + ret, ok := p.valueExpression() + if !ok { + p.fail("expected value expression after ESCAPE") + } + return ret +} diff --git a/sql/rdparser/statement.go b/sql/rdparser/statement.go new file mode 100644 index 0000000000..5774b0b257 --- /dev/null +++ b/sql/rdparser/statement.go @@ -0,0 +1,19 @@ +package rdparser + +import ( + "context" + + ast "github.com/dolthub/vitess/go/vt/sqlparser" +) + +func (p *parser) statement(ctx context.Context) (ast.Statement, bool) { + id, _ := p.tok.Scan() + switch id { + case ast.INSERT: + return p.insert(ctx) + case ast.SELECT: + return p.sel() + default: + return nil, false + } +} diff --git a/sql/rdparser/subquery.go b/sql/rdparser/subquery.go new file mode 100644 index 0000000000..7ee71b4ddf --- /dev/null +++ b/sql/rdparser/subquery.go @@ -0,0 +1,16 @@ +package rdparser + +import ast "github.com/dolthub/vitess/go/vt/sqlparser" + +func (p *parser) subquery() (*ast.Subquery, bool) { + id, _ := p.peek() + if id != '(' { + return nil, false + } + p.next() + selNoInt, ok := p.sel() + if !ok { + return nil, false + } + return &ast.Subquery{Select: selNoInt}, true +}