From a2ade2281a40ca9a745ee9fccfd099c0f3fa077e Mon Sep 17 00:00:00 2001 From: yb huang Date: Fri, 6 Sep 2024 15:46:49 +0800 Subject: [PATCH] support locking clause in select statement --- ast.go | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++ ast_test.go | 36 ++++++++++++++++++++++++++++++ parser.go | 50 +++++++++++++++++++++++++++++++++++++++++ parser_test.go | 38 ++++++++++++++++++++++++++++++++ token.go | 8 +++++++ 5 files changed, 192 insertions(+) diff --git a/ast.go b/ast.go index bf788cf..1f892f4 100644 --- a/ast.go +++ b/ast.go @@ -611,6 +611,60 @@ func (t *OrderingTerm) String() string { return buf.String() } +type LockStrength int + +const ( + Update LockStrength = iota + NoKeyUpdate + Share + KeyShare +) + +type LockOption int + +func (l LockOption) ToPtr() *LockOption { + return &l +} + +const ( + Nowait LockOption = iota + SkipLocked +) + +type LockingClause struct { + Strength LockStrength + + Option *LockOption +} + +func (c *LockingClause) String() string { + var buf bytes.Buffer + buf.Grow(30) + buf.WriteString("FOR") + + switch c.Strength { + case Update: + buf.WriteString(" UPDATE") + case NoKeyUpdate: + buf.WriteString(" NO KEY UPDATE") + case Share: + buf.WriteString(" SHARE") + case KeyShare: + buf.WriteString(" KEY SHARE") + } + + if c.Option != nil { + switch *c.Option { + case Nowait: + buf.WriteString(" NOWAIT") + case SkipLocked: + buf.WriteString(" SKIP LOCKED") + } + } + + return buf.String() +} + type ColumnArg interface { Node columnArg() @@ -928,6 +982,8 @@ type SelectStatement struct { Limit Expr Offset Expr // offset expression + Locking *LockingClause + Hint *Hint } @@ -1004,6 +1060,10 @@ func (s *SelectStatement) String() string { } } + if s.Locking != nil { + fmt.Fprintf(&buf, " %s", s.Locking.String()) + } + return buf.String() } diff --git a/ast_test.go b/ast_test.go index 572db50..1f42611 100644 --- a/ast_test.go +++ b/ast_test.go @@ -312,6 +312,42 @@ func TestSelectStatement_String(t *testing.T) { Y: &sqlparser.TableName{Name: &sqlparser.Ident{Name: "y"}}, }, }, `SELECT * FROM x CROSS JOIN y`) + + AssertStatementStringer(t, &sqlparser.SelectStatement{ + Distinct: true, + Columns: &sqlparser.OutputNames{&sqlparser.ResultColumn{ + Star: true, + }}, + Condition: &sqlparser.BinaryExpr{ + X: &sqlparser.Ident{Name: "ID"}, + Op: sqlparser.EQ, + Y: &sqlparser.NumberLit{Value: "1"}, + }, + FromItems: &sqlparser.TableName{ + Name: &sqlparser.Ident{Name: "tbl"}, + }, + Locking: &sqlparser.LockingClause{ + Strength: sqlparser.Update, + Option: sqlparser.Nowait.ToPtr(), + }}, `SELECT DISTINCT * FROM tbl WHERE ID = 1 FOR UPDATE NOWAIT`) + + AssertStatementStringer(t, &sqlparser.SelectStatement{ + Distinct: true, + Columns: &sqlparser.OutputNames{&sqlparser.ResultColumn{ + Star: true, + }}, + Condition: &sqlparser.BinaryExpr{ + X: &sqlparser.Ident{Name: "ID"}, + Op: sqlparser.EQ, + Y: &sqlparser.NumberLit{Value: "1"}, + }, + FromItems: &sqlparser.TableName{ + Name: &sqlparser.Ident{Name: "tbl"}, + }, + Locking: &sqlparser.LockingClause{ + Strength: sqlparser.NoKeyUpdate, + Option: sqlparser.SkipLocked.ToPtr(), + }}, `SELECT DISTINCT * FROM tbl WHERE ID = 1 FOR NO KEY UPDATE SKIP LOCKED`) } func TestUpdateStatement_String(t *testing.T) { diff --git a/parser.go b/parser.go index e308815..ea93260 100644 --- a/parser.go +++ b/parser.go @@ -616,6 +616,56 @@ func (p *Parser) parseSelectStatement(compounded bool) (_ *SelectStatement, err } } + if !compounded && p.peek() == FOR { + locking := &LockingClause{} + p.lex() + switch p.peek() { + case UPDATE: + locking.Strength = Update + p.lex() + case NO: + p.lex() + if p.peek() != KEY { + return &stmt, p.errorExpected(p.pos, p.tok, "KEY") + } + p.lex() + if p.peek() != UPDATE { + return &stmt, p.errorExpected(p.pos, p.tok, "UPDATE") + } + locking.Strength = NoKeyUpdate + p.lex() + case SHARE: + locking.Strength = Share + p.lex() + case KEY: + p.lex() + if p.peek() != SHARE { + return &stmt, p.errorExpected(p.pos, p.tok, "SHARE") + } + locking.Strength = KeyShare + p.lex() + default: + return &stmt, p.errorExpected(p.pos, p.tok, "UPDATE | NO | SHARE | KEY") + } + + switch p.peek() { + case NOWAIT: + locking.Option = Nowait.ToPtr() + p.lex() + case SKIP: + p.lex() + if p.peek() != LOCKED { + return &stmt, p.errorExpected(p.pos, p.tok, "LOCKED") + } + locking.Option = SkipLocked.ToPtr() + p.lex() + default: + + } + + stmt.Locking = locking + } + return &stmt, nil } diff --git a/parser_test.go b/parser_test.go index 73da06c..14c7290 100644 --- a/parser_test.go +++ b/parser_test.go @@ -343,6 +343,44 @@ func TestParser_ParseStatement(t *testing.T) { }, }) + AssertParseStatement(t, `SELECT DISTINCT * FROM tbl WHERE ID = 1 FOR UPDATE NOWAIT`, &sqlparser.SelectStatement{ + Distinct: true, + Columns: &sqlparser.OutputNames{&sqlparser.ResultColumn{ + Star: true, + }}, + Condition: &sqlparser.BinaryExpr{ + X: &sqlparser.Ident{Name: "ID"}, + Op: sqlparser.EQ, + Y: &sqlparser.NumberLit{Value: "1"}, + }, + FromItems: &sqlparser.TableName{ + Name: &sqlparser.Ident{Name: "tbl"}, + }, + Locking: &sqlparser.LockingClause{ + Strength: sqlparser.Update, + Option: sqlparser.Nowait.ToPtr(), + }, + }) + + AssertParseStatement(t, `SELECT DISTINCT * FROM tbl WHERE ID = 1 FOR NO KEY UPDATE SKIP LOCKED`, &sqlparser.SelectStatement{ + Distinct: true, + Columns: &sqlparser.OutputNames{&sqlparser.ResultColumn{ + Star: true, + }}, + Condition: &sqlparser.BinaryExpr{ + X: &sqlparser.Ident{Name: "ID"}, + Op: sqlparser.EQ, + Y: &sqlparser.NumberLit{Value: "1"}, + }, + FromItems: &sqlparser.TableName{ + Name: &sqlparser.Ident{Name: "tbl"}, + }, + Locking: &sqlparser.LockingClause{ + Strength: sqlparser.NoKeyUpdate, + Option: sqlparser.SkipLocked.ToPtr(), + }, + }) + AssertParseStatementError(t, `SELECT `, `1:7: expected expression, found 'EOF'`) AssertParseStatementError(t, `SELECT 1+`, `1:9: expected expression, found 'EOF'`) AssertParseStatementError(t, `SELECT foo,`, `1:11: expected expression, found 'EOF'`) diff --git a/token.go b/token.go index 89bfb57..093740e 100644 --- a/token.go +++ b/token.go @@ -234,6 +234,10 @@ const ( WITH WITHOUT DUPLICATE + SHARE + NOWAIT + SKIP + LOCKED keyword_end ANY // ??? @@ -446,6 +450,10 @@ var tokens = [...]string{ WITH: "WITH", WITHOUT: "WITHOUT", DUPLICATE: "DUPLICATE", + SHARE: "SHARE", + NOWAIT: "NOWAIT", + SKIP: "SKIP", + LOCKED: "LOCKED", } func (tok Token) String() string {