diff --git a/ast.go b/ast.go index 441e7c4..bf788cf 100644 --- a/ast.go +++ b/ast.go @@ -689,6 +689,7 @@ type UpsertClause struct { DoNothing bool // position of NOTHING keyword after DO DoUpdate bool // position of UPDATE keyword after DO + DuplicateKey bool // position of ON DUPLICATE KEY UPDATE keyword Assignments []*Assignment // list of column assignments UpdateWhereExpr Expr // optional conditional expression for DO UPDATE SET } @@ -696,40 +697,49 @@ type UpsertClause struct { // String returns the string representation of the clause. func (c *UpsertClause) String() string { var buf bytes.Buffer - buf.WriteString("ON CONFLICT") - - if len(c.Columns) != 0 { - buf.WriteString(" (") - for i, col := range c.Columns { + if c.DuplicateKey { + buf.WriteString("ON DUPLICATE KEY UPDATE ") + for i := range c.Assignments { if i != 0 { buf.WriteString(", ") } - buf.WriteString(col.String()) + buf.WriteString(c.Assignments[i].String()) } - buf.WriteString(")") + } else { + buf.WriteString("ON CONFLICT") - if c.WhereExpr != nil { - fmt.Fprintf(&buf, " WHERE %s", c.WhereExpr.String()) - } - } + if len(c.Columns) != 0 { + buf.WriteString(" (") + for i, col := range c.Columns { + if i != 0 { + buf.WriteString(", ") + } + buf.WriteString(col.String()) + } + buf.WriteString(")") - buf.WriteString(" DO") - if c.DoNothing { - buf.WriteString(" NOTHING") - } else { - buf.WriteString(" UPDATE SET ") - for i := range c.Assignments { - if i != 0 { - buf.WriteString(", ") + if c.WhereExpr != nil { + fmt.Fprintf(&buf, " WHERE %s", c.WhereExpr.String()) } - buf.WriteString(c.Assignments[i].String()) } - if c.UpdateWhereExpr != nil { - fmt.Fprintf(&buf, " WHERE %s", c.UpdateWhereExpr.String()) + buf.WriteString(" DO") + if c.DoNothing { + buf.WriteString(" NOTHING") + } else { + buf.WriteString(" UPDATE SET ") + for i := range c.Assignments { + if i != 0 { + buf.WriteString(", ") + } + buf.WriteString(c.Assignments[i].String()) + } + + if c.UpdateWhereExpr != nil { + fmt.Fprintf(&buf, " WHERE %s", c.UpdateWhereExpr.String()) + } } } - return buf.String() } diff --git a/parser.go b/parser.go index d11fef3..e308815 100644 --- a/parser.go +++ b/parser.go @@ -212,8 +212,12 @@ func (p *Parser) parseUpsertClause() (_ *UpsertClause, err error) { // Parse "ON CONFLICT" p.lex() - if p.peek() != CONFLICT { - return &clause, p.errorExpected(p.pos, p.tok, "CONFLICT") + switch p.peek() { + case CONFLICT: + case DUPLICATE: + clause.DuplicateKey = true + default: + return &clause, p.errorExpected(p.pos, p.tok, "CONFLICT or DUPLICATE") } p.lex() @@ -244,9 +248,11 @@ func (p *Parser) parseUpsertClause() (_ *UpsertClause, err error) { } } - // Parse "DO NOTHING" or "DO UPDATE SET". - if p.peek() != DO { + // Parse "DO NOTHING" or "DO UPDATE SET" or "ON DUPLICATE KEY". + if !clause.DuplicateKey && p.peek() != DO { return &clause, p.errorExpected(p.pos, p.tok, "DO") + } else if clause.DuplicateKey && p.peek() != KEY { + return &clause, p.errorExpected(p.pos, p.tok, "KEY") } p.lex() @@ -262,10 +268,12 @@ func (p *Parser) parseUpsertClause() (_ *UpsertClause, err error) { // Otherwise parse "UPDATE SET" p.lex() clause.DoUpdate = true - if p.peek() != SET { - return &clause, p.errorExpected(p.pos, p.tok, "SET") + if !clause.DuplicateKey { + if p.peek() != SET { + return &clause, p.errorExpected(p.pos, p.tok, "SET") + } + p.lex() } - p.lex() // Parse list of assignments. for { diff --git a/token.go b/token.go index c58b20a..89bfb57 100644 --- a/token.go +++ b/token.go @@ -233,6 +233,7 @@ const ( WINDOW WITH WITHOUT + DUPLICATE keyword_end ANY // ??? @@ -444,6 +445,7 @@ var tokens = [...]string{ WINDOW: "WINDOW", WITH: "WITH", WITHOUT: "WITHOUT", + DUPLICATE: "DUPLICATE", } func (tok Token) String() string {