Skip to content

Commit

Permalink
fix: aggreate should throw error when args is not right in mysql
Browse files Browse the repository at this point in the history
  • Loading branch information
taozhi8833998 committed Nov 28, 2023
1 parent 120bce0 commit 53d0b39
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 6 deletions.
20 changes: 19 additions & 1 deletion pegjs/mariadb.pegjs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,24 @@
'PERSIST_ONLY': true,
};

const reservedFunctionName = {
avg: true,
sum: true,
count: true,
max: true,
min: true,
group_concat: true,
std: true,
variance: true,
current_date: true,
current_time: true,
current_timestamp: true,
current_user: true,
user: true,
session_user: true,
system_user: true
}

function createUnaryExpr(op, e) {
return {
type: 'unary_expr',
Expand Down Expand Up @@ -2931,7 +2949,7 @@ func_call
over: up
}
}
/ name:proc_func_name &{ return name.toLowerCase() !== 'convert' } __ LPAREN __ l:or_and_where_expr? __ RPAREN __ bc:over_partition? {
/ name:proc_func_name &{ return name.toLowerCase() !== 'convert' && !reservedFunctionName[name.toLowerCase()] } __ LPAREN __ l:or_and_where_expr? __ RPAREN __ bc:over_partition? {
if (l && l.type !== 'expr_list') l = { type: 'expr_list', value: [l] }
if ((name.toUpperCase() === 'TIMESTAMPDIFF' || name.toUpperCase() === 'TIMESTAMPADD') && l.value && l.value[0]) l.value[0] = { type: 'origin', value: l.value[0].column }
return {
Expand Down
24 changes: 21 additions & 3 deletions pegjs/mysql.pegjs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,24 @@
'ZEROFILL': true,
};

const reservedFunctionName = {
avg: true,
sum: true,
count: true,
max: true,
min: true,
group_concat: true,
std: true,
variance: true,
current_date: true,
current_time: true,
current_timestamp: true,
current_user: true,
user: true,
session_user: true,
system_user: true
}

function createUnaryExpr(op, e) {
return {
type: 'unary_expr',
Expand Down Expand Up @@ -1564,8 +1582,7 @@ lock_stmt
}

call_stmt
= KW_CALL __
e: proc_func_call {
= KW_CALL __ e:proc_func_call {
return {
tableList: Array.from(tableList),
columnList: columnListTableAlias(columnList),
Expand Down Expand Up @@ -3191,6 +3208,7 @@ trim_func_clause
args,
};
}

func_call
= extract_func / trim_func_clause
/ 'convert'i __ LPAREN __ l:convert_args __ RPAREN __ ca:collate_expr? {
Expand All @@ -3216,7 +3234,7 @@ func_call
over: up
}
}
/ name:proc_func_name &{ return name.toLowerCase() !== 'convert' } __ LPAREN __ l:or_and_where_expr? __ RPAREN __ bc:over_partition? {
/ name:proc_func_name &{ return name.toLowerCase() !== 'convert' && !reservedFunctionName[name.toLowerCase()] } __ LPAREN __ l:or_and_where_expr? __ RPAREN __ bc:over_partition? {
if (l && l.type !== 'expr_list') l = { type: 'expr_list', value: [l] }
if ((name.toUpperCase() === 'TIMESTAMPDIFF' || name.toUpperCase() === 'TIMESTAMPADD') && l.value && l.value[0]) l.value[0] = { type: 'origin', value: l.value[0].column }
return {
Expand Down
6 changes: 4 additions & 2 deletions test/mysql-mariadb.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -897,10 +897,12 @@ describe('mysql', () => {
})
})

it('should throw error when covert args is not right', () => {
const sql = `select convert(json_unquote(json_extract('{"thing": "252"}', "$.thing")));`
it('should throw error when args is not right', () => {
let sql = `select convert(json_unquote(json_extract('{"thing": "252"}', "$.thing")));`
expect(parser.astify.bind(parser, sql)).to.throw('Expected "!=", "#", "%", "&", "&&", "*", "+", ",", "-", "--", "/", "/*", "<", "<<", "<=", "<>", "=", ">", ">=", ">>", "AND", "BETWEEN", "IN", "IS", "LIKE", "NOT", "ON", "OR", "OVER", "REGEXP", "RLIKE", "USING", "XOR", "^", "div", "|", "||", or [ \\t\\n\\r] but ")" found.')
expect(parser.astify.bind(parser, 'select convert("");')).to.throw('Expected "!=", "#", "%", "&", "&&", "*", "+", ",", "-", "--", "/", "/*", "<", "<<", "<=", "<>", "=", ">", ">=", ">>", "AND", "BETWEEN", "COLLATE", "IN", "IS", "LIKE", "NOT", "OR", "REGEXP", "RLIKE", "USING", "XOR", "^", "div", "|", "||", or [ \\t\\n\\r] but ")" found.')
sql = 'SELECT AVG(Quantity,age) FROM table1;'
expect(parser.astify.bind(parser, sql)).to.throw('Expected "#", "%", "&", "(", ")", "*", "+", "-", "--", "->", "->>", ".", "/", "/*", "<<", ">>", "^", "div", "|", "||", [ \\t\\n\\r], [A-Za-z0-9_$\\x80-￿], or [A-Za-z0-9_:] but "," found.')
})

it('should join multiple table with comma', () => {
Expand Down

0 comments on commit 53d0b39

Please sign in to comment.