From d34c5dac672b1a185ee6ef234fc1d4a2fb82887f Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Tue, 10 Sep 2024 15:23:42 -0700 Subject: [PATCH 1/4] recursive descent parser for simple INSERT/SELECT --- engine.go | 3 +- sql/rdparser/insert.go | 175 +++++++++++++++++++++++++++++++++++ sql/rdparser/parser.go | 58 ++++++++++++ sql/rdparser/parser_test.go | 46 +++++++++ sql/rdparser/select.go | 179 ++++++++++++++++++++++++++++++++++++ sql/rdparser/statement.go | 18 ++++ 6 files changed, 478 insertions(+), 1 deletion(-) create mode 100644 sql/rdparser/insert.go create mode 100644 sql/rdparser/parser.go create mode 100644 sql/rdparser/parser_test.go create mode 100644 sql/rdparser/select.go create mode 100644 sql/rdparser/statement.go diff --git a/engine.go b/engine.go index e19bb7342b..7a6902c3bb 100644 --- a/engine.go +++ b/engine.go @@ -16,6 +16,7 @@ package sqle import ( "fmt" + "github.com/dolthub/go-mysql-server/sql/rdparser" "os" "strconv" "strings" @@ -192,7 +193,7 @@ func New(a *analyzer.Analyzer, cfg *Config) *Engine { PreparedDataCache: NewPreparedDataCache(), mu: &sync.Mutex{}, EventScheduler: nil, - Parser: sql.NewMysqlParser(), + Parser: rdparser.NewParser(), } ret.ReadOnly.Store(cfg.IsReadOnly) return ret diff --git a/sql/rdparser/insert.go b/sql/rdparser/insert.go new file mode 100644 index 0000000000..9784b551aa --- /dev/null +++ b/sql/rdparser/insert.go @@ -0,0 +1,175 @@ +package rdparser + +import ( + "context" + ast "github.com/dolthub/vitess/go/vt/sqlparser" +) + +func (p *parser) insert(ctx context.Context) (ast.Statement, bool) { + id, cur := p.tok.Scan() + ins := new(ast.Insert) + ins.Action = ast.InsertStr + if id == ast.INTO { + id, cur = p.tok.Scan() + } else if id != ast.ID { + return nil, false + } + + p.push(id, cur) + + var ok bool + ins.Table, ok = p.tableIdent(ctx) + if !ok { + return nil, false + } + + // optional () + id, cur = p.pop() + if id == '(' { + ins.Columns, ok = p.columnList(ctx) + if !ok { + return nil, false + } + id, cur = p.pop() + if id != ')' { + return nil, false + } + } + + // VALUES or SELECT + if id != ast.VALUES { + return nil, false + } + + ins.Rows, ok = p.valueList(ctx) + if !ok { + return nil, false + } + return ins, true +} + +func (p *parser) push(id int, cur []byte) { + p.curId, p.cur, p.curOk = id, cur, true +} + +func (p *parser) pop() (int, []byte) { + if p.curOk { + p.curOk = false + return p.curId, p.cur + } else { + return p.tok.Scan() + } +} + +func (p *parser) tableIdent(ctx context.Context) (ast.TableName, bool) { + // schema.database.table + + id, firstTok := p.pop() + if id != ast.ID { + return ast.TableName{}, false + } + + id, tok := p.pop() + if id != '.' { + p.push(id, tok) + return ast.TableName{Name: ast.NewTableIdent(string(firstTok))}, true + } + + id, secondTok := p.tok.Scan() + if id != ast.ID { + p.push(id, tok) + return ast.TableName{ + Name: ast.NewTableIdent(string(firstTok)), + }, true + } + + id, tok = p.pop() + if id != '.' { + p.push(id, tok) + return ast.TableName{ + DbQualifier: ast.NewTableIdent(string(firstTok)), + Name: ast.NewTableIdent(string(secondTok)), + }, true + } + + id, thirdTok := p.tok.Scan() + if id != ast.ID { + p.push(id, tok) + return ast.TableName{ + SchemaQualifier: ast.NewTableIdent(string(firstTok)), + DbQualifier: ast.NewTableIdent(string(secondTok)), + }, true + } + + return ast.TableName{ + SchemaQualifier: ast.NewTableIdent(string(firstTok)), + DbQualifier: ast.NewTableIdent(string(secondTok)), + Name: ast.NewTableIdent(string(thirdTok)), + }, true +} + +func (p *parser) columnList(ctx context.Context) (ast.Columns, bool) { + // id, ... + var cols ast.Columns + id, tok := p.pop() + for { + if id != ast.ID { + break + } + cols = append(cols, ast.NewColIdent(string(tok))) + id, tok = p.tok.Scan() + if id != ',' { + break + } + } + p.push(id, tok) + return cols, true +} + +func (p *parser) valueList(ctx context.Context) (ast.InsertRows, bool) { + var rows ast.Values + id, tok := p.pop() + for { + if id != '(' { + break + } + var row ast.ValTuple + for { + id, tok = p.pop() + if id == ',' { + id, tok = p.pop() + } + if id == ')' { + break + } + value, ok := p.value(ctx, id, tok) + if !ok { + return nil, false + } + row = append(row, value) + } + rows = append(rows, row) + id, tok = p.tok.Scan() + if id != ',' { + break + } + id, tok = p.tok.Scan() + } + p.push(id, tok) + return &ast.AliasedValues{Values: rows}, true +} + +func (p *parser) value(ctx context.Context, id int, tok []byte) (ast.Expr, bool) { + switch id { + case ast.STRING: + return ast.NewStrVal(tok), true + case ast.INTEGRAL: + return ast.NewIntVal(tok), true + case ast.FLOAT: + return ast.NewFloatVal(tok), true + case ast.NULL: + return ast.NewStrVal(tok), true + default: + return nil, false + } +} diff --git a/sql/rdparser/parser.go b/sql/rdparser/parser.go new file mode 100644 index 0000000000..2d49e158b2 --- /dev/null +++ b/sql/rdparser/parser.go @@ -0,0 +1,58 @@ +package rdparser + +import ( + "context" + "github.com/dolthub/go-mysql-server/sql" + ast "github.com/dolthub/vitess/go/vt/sqlparser" +) + +type parser struct { + tok *ast.Tokenizer + curOk bool + curId int + cur []byte +} + +func NewParser() sql.Parser { + return &parser{} +} + +func (p *parser) parse(ctx context.Context, s string, options ast.ParserOptions) (ast.Statement, error) { + // get next token + p.tok = ast.NewStringTokenizer(s) + if options.AnsiQuotes { + p.tok = ast.NewStringTokenizerForAnsiQuotes(s) + } + + if prePlan, ok := p.statement(ctx); ok { + return prePlan, nil + } + + return ast.ParseWithOptions(ctx, s, options) +} + +var _ sql.Parser = (*parser)(nil) + +func (p *parser) ParseSimple(query string) (ast.Statement, error) { + return p.parse(context.Background(), query, ast.ParserOptions{}) +} + +func (p *parser) Parse(ctx *sql.Context, query string, multi bool) (ast.Statement, string, string, error) { + return p.ParseWithOptions(ctx, query, ';', multi, ast.ParserOptions{}) +} + +func (p *parser) ParseWithOptions(ctx context.Context, query string, delimiter rune, multi bool, options ast.ParserOptions) (ast.Statement, string, string, error) { + stmt, err := p.parse(context.Background(), query, options) + if err != nil { + return nil, "", "", nil + } + return stmt, "", "", nil +} + +func (p *parser) ParseOneWithOptions(ctx context.Context, s string, options ast.ParserOptions) (ast.Statement, int, error) { + ast, err := p.parse(ctx, s, options) + if err != nil { + return nil, 0, err + } + return ast, 0, nil +} diff --git a/sql/rdparser/parser_test.go b/sql/rdparser/parser_test.go new file mode 100644 index 0000000000..b681d206cd --- /dev/null +++ b/sql/rdparser/parser_test.go @@ -0,0 +1,46 @@ +package rdparser + +import ( + "context" + ast "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/stretchr/testify/require" + "testing" +) + +func TestParser(t *testing.T) { + tests := []struct { + q string + exp ast.Statement + ok bool + }{ + { + q: "insert into xy values (0,'0', .0), (1,'1', 1.0)", + ok: true, + }, + { + q: "insert into xy (x,y,z) values (0,'0', .0), (1,'1', 1.0)", + ok: true, + }, + { + q: "insert into db.xy values (0,'0', .0), (1,'1', 1.0)", + ok: true, + }, + { + q: "select * from xy where x = 1", + ok: true, + }, + { + q: "select id from sbtest1 where id = 1000", + ok: true, + }, + } + for _, tt := range tests { + t.Run(tt.q, func(t *testing.T) { + p := new(parser) + p.tok = ast.NewStringTokenizer(tt.q) + res, ok := p.statement(context.Background()) + require.Equal(t, tt.ok, ok) + require.Equal(t, tt.exp, res) + }) + } +} diff --git a/sql/rdparser/select.go b/sql/rdparser/select.go new file mode 100644 index 0000000000..6411cfdd1d --- /dev/null +++ b/sql/rdparser/select.go @@ -0,0 +1,179 @@ +package rdparser + +import ( + "context" + ast "github.com/dolthub/vitess/go/vt/sqlparser" +) + +func (p *parser) sel(ctx context.Context) (ast.Statement, bool) { + // SELECT FROM WHERE + sel := new(ast.Select) + var ok bool + sel.SelectExprs, ok = p.selExprs(ctx) + if !ok { + return nil, false + } + + tab, ok := p.tableIdent(ctx) + if !ok { + return nil, false + } + sel.From = []ast.TableExpr{&ast.AliasedTableExpr{Expr: tab}} + + sel.Where, ok = p.whereOpt(ctx) + if !ok { + return nil, false + } + return sel, true +} + +func (p *parser) selExprs(ctx context.Context) (ast.SelectExprs, bool) { + var exprs ast.SelectExprs + id, tok := p.tok.Scan() + for { + if id == ast.FROM { + break + } + // literal + var expr ast.SelectExpr + var toAlias ast.Expr + var ok bool + switch id { + case ast.ID: + p.push(id, tok) + toAlias, ok = p.colName(ctx) + if !ok { + return nil, false + } + case '*': + expr = &ast.StarExpr{} + case ast.STRING, ast.INTEGRAL, ast.FLOAT, ast.NULL: + toAlias, ok = p.value(ctx, id, tok) + if !ok { + return nil, false + } + default: + return nil, false + } + if toAlias != nil { + expr = &ast.AliasedExpr{Expr: toAlias} + id, tok = p.pop() + if id == ast.AS { + expr = &ast.AliasedExpr{As: ast.NewColIdent(string(tok)), Expr: toAlias} + } + } + exprs = append(exprs, expr) + } + return exprs, true +} + +func (p *parser) colName(ctx context.Context) (*ast.ColName, bool) { + id, firstTok := p.pop() + if id != ast.ID { + return nil, false + } + + id, tok := p.pop() + if id != '.' { + p.push(id, tok) + return &ast.ColName{Name: ast.NewColIdent(string(firstTok))}, true + } + + id, secondTok := p.tok.Scan() + if id != ast.ID { + return nil, false + } + + return &ast.ColName{ + Qualifier: ast.TableName{ + Name: ast.NewTableIdent(string(firstTok)), + }, + Name: ast.NewColIdent(string(secondTok)), + }, true + + //id, tok = p.pop() + //if id == ast.AS { + // return &ast.AliasedExpr{ + // As: ast.NewColIdent(string(tok)), + // Expr: &ast.ColName{ + // Qualifier: ast.TableName{ + // Name: ast.NewTableIdent(string(firstTok)), + // }, + // Name: ast.NewColIdent(string(secondTok)), + // }, + // }, true + //} + // + //p.push(id, tok) + //return &ast.AliasedExpr{ + // Expr: &ast.ColName{ + // Qualifier: ast.TableName{ + // Name: ast.NewTableIdent(string(firstTok)), + // }, + // Name: ast.NewColIdent(string(secondTok)), + // }, + //}, true + // + // db, schema, ... +} + +func (p *parser) whereOpt(ctx context.Context) (*ast.Where, bool) { + id, tok := p.pop() + if id != ast.WHERE { + p.push(id, tok) + return nil, true + } + ret := new(ast.Where) + ret.Type = ast.WhereStr + var ok bool + ret.Expr, ok = p.expr(ctx) + if !ok { + return nil, false + } + return ret, true +} + +func (p *parser) expr(ctx context.Context) (ast.Expr, bool) { + firstExpr, ok := p.exprHelper(ctx) + if !ok { + return nil, false + } + secondExpr, ok := p.exprHelper(ctx) + if !ok { + return firstExpr, true + } + switch e := secondExpr.(type) { + case *ast.ComparisonExpr: + thirdExpr, ok := p.exprHelper(ctx) + if !ok { + return nil, false + } + e.Left = firstExpr + e.Right = thirdExpr + return e, true + default: + return nil, false + } +} + +func (p *parser) exprHelper(ctx context.Context) (ast.Expr, bool) { + id, tok := p.pop() + var expr ast.Expr + var ok bool + switch id { + case ast.ID: + p.push(id, tok) + expr, ok = p.colName(ctx) + case ast.STRING, ast.INTEGRAL, ast.FLOAT, ast.NULL: + expr, ok = p.value(ctx, id, tok) + case '=': + expr = &ast.ComparisonExpr{Operator: ast.EqualStr} + ok = true + default: + return nil, false + } + if !ok { + return nil, false + } + return expr, true +} diff --git a/sql/rdparser/statement.go b/sql/rdparser/statement.go new file mode 100644 index 0000000000..21e6880fdb --- /dev/null +++ b/sql/rdparser/statement.go @@ -0,0 +1,18 @@ +package rdparser + +import ( + "context" + ast "github.com/dolthub/vitess/go/vt/sqlparser" +) + +func (p *parser) statement(ctx context.Context) (ast.Statement, bool) { + id, _ := p.tok.Scan() + switch id { + case ast.INSERT: + return p.insert(ctx) + case ast.SELECT: + return p.sel(ctx) + default: + return nil, false + } +} From 06c986379ad5939f8f8a78b876d100d5293b7abe Mon Sep 17 00:00:00 2001 From: max-hoffman Date: Tue, 10 Sep 2024 22:26:00 +0000 Subject: [PATCH 2/4] [ga-format-pr] Run ./format_repo.sh to fix formatting --- engine.go | 2 +- sql/rdparser/insert.go | 1 + sql/rdparser/parser.go | 4 +++- sql/rdparser/parser_test.go | 3 ++- sql/rdparser/select.go | 1 + sql/rdparser/statement.go | 1 + 6 files changed, 9 insertions(+), 3 deletions(-) diff --git a/engine.go b/engine.go index 7a6902c3bb..df8cfa258c 100644 --- a/engine.go +++ b/engine.go @@ -16,7 +16,6 @@ package sqle import ( "fmt" - "github.com/dolthub/go-mysql-server/sql/rdparser" "os" "strconv" "strings" @@ -35,6 +34,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/expression/function" "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/planbuilder" + "github.com/dolthub/go-mysql-server/sql/rdparser" "github.com/dolthub/go-mysql-server/sql/rowexec" "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/go-mysql-server/sql/types" diff --git a/sql/rdparser/insert.go b/sql/rdparser/insert.go index 9784b551aa..26092129aa 100644 --- a/sql/rdparser/insert.go +++ b/sql/rdparser/insert.go @@ -2,6 +2,7 @@ package rdparser import ( "context" + ast "github.com/dolthub/vitess/go/vt/sqlparser" ) diff --git a/sql/rdparser/parser.go b/sql/rdparser/parser.go index 2d49e158b2..fdec159f6b 100644 --- a/sql/rdparser/parser.go +++ b/sql/rdparser/parser.go @@ -2,8 +2,10 @@ package rdparser import ( "context" - "github.com/dolthub/go-mysql-server/sql" + ast "github.com/dolthub/vitess/go/vt/sqlparser" + + "github.com/dolthub/go-mysql-server/sql" ) type parser struct { diff --git a/sql/rdparser/parser_test.go b/sql/rdparser/parser_test.go index b681d206cd..9c234dd637 100644 --- a/sql/rdparser/parser_test.go +++ b/sql/rdparser/parser_test.go @@ -2,9 +2,10 @@ package rdparser import ( "context" + "testing" + ast "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/stretchr/testify/require" - "testing" ) func TestParser(t *testing.T) { diff --git a/sql/rdparser/select.go b/sql/rdparser/select.go index 6411cfdd1d..3ba87cd372 100644 --- a/sql/rdparser/select.go +++ b/sql/rdparser/select.go @@ -2,6 +2,7 @@ package rdparser import ( "context" + ast "github.com/dolthub/vitess/go/vt/sqlparser" ) diff --git a/sql/rdparser/statement.go b/sql/rdparser/statement.go index 21e6880fdb..1aa0f2d610 100644 --- a/sql/rdparser/statement.go +++ b/sql/rdparser/statement.go @@ -2,6 +2,7 @@ package rdparser import ( "context" + ast "github.com/dolthub/vitess/go/vt/sqlparser" ) From a78f8890cb1ad85aea8797a8840a595b223f25dd Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Wed, 11 Sep 2024 19:57:40 -0700 Subject: [PATCH 3/4] progress --- sql/rdparser/insert.go | 8 +- sql/rdparser/parser.go | 50 ++++- sql/rdparser/select.go | 373 ++++++++++++++++++++++++++++++++++---- sql/rdparser/statement.go | 2 +- sql/rdparser/subquery.go | 16 ++ 5 files changed, 403 insertions(+), 46 deletions(-) create mode 100644 sql/rdparser/subquery.go diff --git a/sql/rdparser/insert.go b/sql/rdparser/insert.go index 9784b551aa..c1cb24fdc2 100644 --- a/sql/rdparser/insert.go +++ b/sql/rdparser/insert.go @@ -18,7 +18,7 @@ func (p *parser) insert(ctx context.Context) (ast.Statement, bool) { p.push(id, cur) var ok bool - ins.Table, ok = p.tableIdent(ctx) + ins.Table, ok = p.tableIdent() if !ok { return nil, false } @@ -61,7 +61,7 @@ func (p *parser) pop() (int, []byte) { } } -func (p *parser) tableIdent(ctx context.Context) (ast.TableName, bool) { +func (p *parser) tableIdent() (ast.TableName, bool) { // schema.database.table id, firstTok := p.pop() @@ -142,7 +142,7 @@ func (p *parser) valueList(ctx context.Context) (ast.InsertRows, bool) { if id == ')' { break } - value, ok := p.value(ctx, id, tok) + value, ok := p.value(id, tok) if !ok { return nil, false } @@ -159,7 +159,7 @@ func (p *parser) valueList(ctx context.Context) (ast.InsertRows, bool) { return &ast.AliasedValues{Values: rows}, true } -func (p *parser) value(ctx context.Context, id int, tok []byte) (ast.Expr, bool) { +func (p *parser) value(id int, tok []byte) (ast.Expr, bool) { switch id { case ast.STRING: return ast.NewStrVal(tok), true diff --git a/sql/rdparser/parser.go b/sql/rdparser/parser.go index 2d49e158b2..272c588c3e 100644 --- a/sql/rdparser/parser.go +++ b/sql/rdparser/parser.go @@ -2,22 +2,36 @@ package rdparser import ( "context" + "fmt" "github.com/dolthub/go-mysql-server/sql" ast "github.com/dolthub/vitess/go/vt/sqlparser" ) type parser struct { - tok *ast.Tokenizer - curOk bool - curId int - cur []byte + tok *ast.Tokenizer + curOk bool + curId int + cur []byte + peekOk bool + peekId int + _peek []byte } func NewParser() sql.Parser { return &parser{} } -func (p *parser) parse(ctx context.Context, s string, options ast.ParserOptions) (ast.Statement, error) { +func (p *parser) parse(ctx context.Context, s string, options ast.ParserOptions) (ret ast.Statement, err error) { + defer func() { + if mes := recover(); mes != nil { + _, ok := mes.(parseErr) + if !ok { + err = fmt.Errorf("panic encountered while parsing: %s", mes) + return + } + ret, err = ast.ParseWithOptions(ctx, s, options) + } + }() // get next token p.tok = ast.NewStringTokenizer(s) if options.AnsiQuotes { @@ -31,6 +45,32 @@ func (p *parser) parse(ctx context.Context, s string, options ast.ParserOptions) return ast.ParseWithOptions(ctx, s, options) } +type parseErr struct { + str string +} + +func (p *parser) fail(s string) { + panic(parseErr{s}) +} + +func (p *parser) next() (int, []byte) { + if p.peekOk { + p.peekOk = false + p.curId, p.cur = p.peekId, p._peek + } + p.curOk = true + p.curId, p.cur = p.tok.Scan() + return p.curId, p.cur +} + +func (p *parser) peek() (int, []byte) { + if !p.peekOk { + p.peekOk = true + p.peekId, p._peek = p.tok.Scan() + } + return p.peekId, p._peek +} + var _ sql.Parser = (*parser)(nil) func (p *parser) ParseSimple(query string) (ast.Statement, error) { diff --git a/sql/rdparser/select.go b/sql/rdparser/select.go index 6411cfdd1d..3366a362ef 100644 --- a/sql/rdparser/select.go +++ b/sql/rdparser/select.go @@ -1,33 +1,32 @@ package rdparser import ( - "context" ast "github.com/dolthub/vitess/go/vt/sqlparser" ) -func (p *parser) sel(ctx context.Context) (ast.Statement, bool) { +func (p *parser) sel() (ast.Statement, bool) { // SELECT FROM
WHERE sel := new(ast.Select) var ok bool - sel.SelectExprs, ok = p.selExprs(ctx) + sel.SelectExprs, ok = p.selExprs() if !ok { return nil, false } - tab, ok := p.tableIdent(ctx) + tab, ok := p.tableIdent() if !ok { return nil, false } sel.From = []ast.TableExpr{&ast.AliasedTableExpr{Expr: tab}} - sel.Where, ok = p.whereOpt(ctx) + sel.Where, ok = p.whereOpt() if !ok { return nil, false } return sel, true } -func (p *parser) selExprs(ctx context.Context) (ast.SelectExprs, bool) { +func (p *parser) selExprs() (ast.SelectExprs, bool) { var exprs ast.SelectExprs id, tok := p.tok.Scan() for { @@ -41,14 +40,14 @@ func (p *parser) selExprs(ctx context.Context) (ast.SelectExprs, bool) { switch id { case ast.ID: p.push(id, tok) - toAlias, ok = p.colName(ctx) + toAlias, ok = p.columnName() if !ok { return nil, false } case '*': expr = &ast.StarExpr{} case ast.STRING, ast.INTEGRAL, ast.FLOAT, ast.NULL: - toAlias, ok = p.value(ctx, id, tok) + toAlias, ok = p.value(id, tok) if !ok { return nil, false } @@ -67,7 +66,7 @@ func (p *parser) selExprs(ctx context.Context) (ast.SelectExprs, bool) { return exprs, true } -func (p *parser) colName(ctx context.Context) (*ast.ColName, bool) { +func (p *parser) columnName() (*ast.ColName, bool) { id, firstTok := p.pop() if id != ast.ID { return nil, false @@ -117,7 +116,7 @@ func (p *parser) colName(ctx context.Context) (*ast.ColName, bool) { // db, schema, ... } -func (p *parser) whereOpt(ctx context.Context) (*ast.Where, bool) { +func (p *parser) whereOpt() (*ast.Where, bool) { id, tok := p.pop() if id != ast.WHERE { p.push(id, tok) @@ -126,54 +125,356 @@ func (p *parser) whereOpt(ctx context.Context) (*ast.Where, bool) { ret := new(ast.Where) ret.Type = ast.WhereStr var ok bool - ret.Expr, ok = p.expr(ctx) + ret.Expr, ok = p.expression() if !ok { return nil, false } return ret, true } -func (p *parser) expr(ctx context.Context) (ast.Expr, bool) { - firstExpr, ok := p.exprHelper(ctx) - if !ok { +func (p *parser) expression() (ret ast.Expr, ok bool) { + // condition + // NOT + // DEFAULT + // valueExpression + id, _ := p.peek() + if id == ast.NOT { + p.next() + c, ok := p.expression() + if !ok { + return nil, false + } + ret = &ast.NotExpr{Expr: c} + } else if id == ast.DEFAULT { + p.next() + var d string + id, _ = p.peek() + if id == '(' { + p.next() + if ident, ok := p.id(); ok { + id, _ = p.peek() + if id != ')' { + p.next() + ret = &ast.Default{ColName: ident} + } else { + p.fail("invalid default expression") + } + } + } + ret = &ast.Default{ColName: d} + } else if ret, ok = p.condition(); ok { + return ret, true + } else if ret, ok = p.valueExpression(); ok { + return ret, true + } else { return nil, false } - secondExpr, ok := p.exprHelper(ctx) + + id, _ = p.peek() + // AND OR XOR IS + switch id { + case ast.AND: + p.next() + right, ok := p.expression() + if ok { + ret = &ast.AndExpr{Left: ret, Right: right} + } + case ast.OR: + p.next() + if right, ok := p.expression(); ok { + ret = &ast.OrExpr{Left: ret, Right: right} + } + case ast.XOR: + p.next() + if right, ok := p.expression(); ok { + ret = &ast.XorExpr{Left: ret, Right: right} + } + case ast.IS: + p.next() + if is := p.isSuffix(); is != "" { + ret = &ast.IsExpr{Operator: is, Expr: ret} + } + } + return ret, true +} + +func (p *parser) id() (string, bool) { + id, tok := p.peek() + if id == ast.ID { + p.next() + return string(tok), true + } + return "", false +} + +func (p *parser) isSuffix() string { + // (NOT) NULL|TRUE|FALSE + id, _ := p.peek() + var not bool + if id == ast.NOT { + p.next() + not = true + } + id, _ = p.peek() + switch id { + case ast.TRUE: + if not { + return ast.IsNotTrueStr + } else { + return ast.IsTrueStr + } + case ast.FALSE: + if not { + return ast.IsNotFalseStr + } else { + return ast.IsFalseStr + } + case ast.NULL: + if not { + return ast.IsNotNullStr + } else { + return ast.IsNullStr + } + default: + return "" + } +} + +func (p *parser) condition() (ret ast.Expr, ok bool) { + id, _ := p.peek() + if id == ast.EXISTS { + if subq, ok := p.subquery(); ok { + return &ast.ExistsExpr{Subquery: subq}, true + } + p.fail("expected subquery after EXISTS") + } + + left, ok := p.valueExpression() if !ok { - return firstExpr, true + return nil, false } - switch e := secondExpr.(type) { - case *ast.ComparisonExpr: - thirdExpr, ok := p.exprHelper(ctx) + + id, _ = p.peek() + var not bool + if id == ast.NOT { + p.next() + id, _ = p.peek() + not = true + } + + if comp := p.compare(); comp != "" { + right, ok := p.valueExpression() if !ok { return nil, false } - e.Left = firstExpr - e.Right = thirdExpr - return e, true - default: + ret = &ast.ComparisonExpr{Operator: comp, Left: left, Right: right} + } else { + switch id { + case ast.IN: + right, ok := p.colTuple() + if !ok { + return nil, false + } + ret = &ast.ComparisonExpr{Operator: ast.InStr, Left: left, Right: right} + case ast.LIKE: + right, ok := p.valueExpression() + if !ok { + return nil, false + } + esc := p.likeEscapeOpt() + ret = &ast.ComparisonExpr{Operator: ast.LikeStr, Left: left, Right: right, Escape: esc} + case ast.REGEXP: + right, ok := p.valueExpression() + if !ok { + return nil, false + } + ret = &ast.ComparisonExpr{Operator: ast.RegexpStr, Left: left, Right: right} + case ast.BETWEEN: + from, ok := p.valueExpression() + if !ok { + return nil, false + } + id, _ = p.next() + if id != ast.AND { + p.fail("between expected AND") + } + to, ok := p.valueExpression() + if !ok { + return nil, false + } + ret = &ast.RangeCond{Operator: ast.BetweenStr, Left: left, From: from, To: to} + } + } + + if not { + ret = &ast.NotExpr{Expr: ret} + } + return ret, true +} + +func (p *parser) valueExpression() (ret ast.Expr, ok bool) { + // value + // ACCOUNT, FORMAT + // boolean_value + // column_name + // column_name_safe_keyword + // tuple_expression + // subquery + // BINARY + id, tok := p.peek() + + if id == ast.FORMAT || id == ast.ACCOUNT { + return &ast.ColName{Name: ast.NewColIdent(string(tok))}, true + } else if v, ok := p.value(id, tok); ok { + return v, true + } else if id == ast.TRUE || id == ast.FALSE { + return ast.BoolVal(id == ast.TRUE), true + } else if col, ok := p.columnName(); ok { + p.next() + id, _ = p.peek() + switch id { + case ast.JSON_EXTRACT_OP: + p.next() + id, tok := p.next() + val, ok := p.value(id, tok) + if !ok { + return nil, false + } + return &ast.BinaryExpr{Operator: ast.JSONExtractOp, Left: col, Right: val}, true + case ast.JSON_UNQUOTE_EXTRACT_OP: + p.next() + id, tok := p.next() + val, ok := p.value(id, tok) + if !ok { + return nil, false + } + return &ast.BinaryExpr{Operator: ast.JSONUnquoteExtractOp, Left: col, Right: val}, true + default: + return col, true + } + return col, true + } else if col, ok := p.columnNameSafeKeyWord(); ok { + return col, true + } else if tup, ok := p.tupleExpression(); ok { + return tup, true + } else if subq, ok := p.subquery(); ok { + return subq, true + } + + // underscore_charsets valueExpr UNARY + // + valueExpr UNARY + // - valueExpr UNARY + // ! valueExpr UNARY + // ~ valueExpr + // INTERVAL value_expression sql_id + + // function_call_generic + // function_call_keyword + // function_call_nonkeyword + // function_call_conflict + // function_call_window + // function_call_aggregate_with_window + + // valueExpr op valueExpr + left, ok := p.valueExpression() + if ok { return nil, false } + switch id { + case '+': + case '-': + case '*': + case '/': + case '^': + case '&': + case '|': + case ast.DIV: + case '%': + case ast.MOD: + case ast.SHIFT_LEFT: + case ast.SHIFT_RIGHT: + case ast.COLLATE: + } } -func (p *parser) exprHelper(ctx context.Context) (ast.Expr, bool) { - id, tok := p.pop() - var expr ast.Expr - var ok bool +func (p *parser) compare() (ret string) { + id, _ := p.peek() switch id { - case ast.ID: - p.push(id, tok) - expr, ok = p.colName(ctx) - case ast.STRING, ast.INTEGRAL, ast.FLOAT, ast.NULL: - expr, ok = p.value(ctx, id, tok) case '=': - expr = &ast.ComparisonExpr{Operator: ast.EqualStr} - ok = true + ret = ast.EqualStr + case '<': + ret = ast.LessThanStr + case '>': + ret = ast.GreaterThanStr + case ast.LE: + ret = ast.LessEqualStr + case ast.GE: + ret = ast.GreaterEqualStr + case ast.NE: + ret = ast.NotEqualStr + case ast.NULL_SAFE_EQUAL: + ret = ast.NullSafeEqualStr default: + return "" + } + p.next() + return ret +} + +func (p *parser) colTuple() (ast.Expr, bool) { + id, tok := p.peek() + if id == ast.LIST_ARG { + return ast.ListArg(tok), true + } else if id != '(' { return nil, false } + p.next() + id, _ = p.peek() + + if id == ast.SELECT { + selStmt, ok := p.selNoInto() + if !ok { + return nil, false + } + p.next() + id, _ = p.next() + if id != ')' { + p.fail("expected subquery to end with ')'") + } + p.next() + return &ast.Subquery{Select: selStmt}, true + } + + //expr list + var tup ast.ValTuple + for { + e, ok := p.expression() + if !ok { + return nil, false + } + tup = append(tup, e) + id, _ = p.next() + if id == ')' { + break + } else if id != ',' { + p.fail("invalid expression list") + } + } + return tup, true +} + +func (p *parser) selNoInto() (*ast.Select, bool) { +} + +func (p *parser) likeEscapeOpt() ast.Expr { + id, _ := p.peek() + if id != ast.ESCAPE { + return nil + } + p.next() + ret, ok := p.valueExpression() if !ok { - return nil, false + p.fail("expected value expression after ESCAPE") } - return expr, true + return ret } diff --git a/sql/rdparser/statement.go b/sql/rdparser/statement.go index 21e6880fdb..9393817b4f 100644 --- a/sql/rdparser/statement.go +++ b/sql/rdparser/statement.go @@ -11,7 +11,7 @@ func (p *parser) statement(ctx context.Context) (ast.Statement, bool) { case ast.INSERT: return p.insert(ctx) case ast.SELECT: - return p.sel(ctx) + return p.sel() default: return nil, false } diff --git a/sql/rdparser/subquery.go b/sql/rdparser/subquery.go new file mode 100644 index 0000000000..7ee71b4ddf --- /dev/null +++ b/sql/rdparser/subquery.go @@ -0,0 +1,16 @@ +package rdparser + +import ast "github.com/dolthub/vitess/go/vt/sqlparser" + +func (p *parser) subquery() (*ast.Subquery, bool) { + id, _ := p.peek() + if id != '(' { + return nil, false + } + p.next() + selNoInt, ok := p.sel() + if !ok { + return nil, false + } + return &ast.Subquery{Select: selNoInt}, true +} From 105d89652c41df43d987d09844d69ca53971e789 Mon Sep 17 00:00:00 2001 From: max-hoffman Date: Thu, 12 Sep 2024 03:00:07 +0000 Subject: [PATCH 4/4] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/rdparser/parser.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/rdparser/parser.go b/sql/rdparser/parser.go index 272c588c3e..2f4e62390a 100644 --- a/sql/rdparser/parser.go +++ b/sql/rdparser/parser.go @@ -3,8 +3,10 @@ package rdparser import ( "context" "fmt" - "github.com/dolthub/go-mysql-server/sql" + ast "github.com/dolthub/vitess/go/vt/sqlparser" + + "github.com/dolthub/go-mysql-server/sql" ) type parser struct {