diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index 6d6616ad92f..c97f1750cfd 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -823,31 +823,29 @@ func tryPushUnion(ctx *plancontext.PlanningContext, op *Union) (Operator, *Apply func handleLastInsertIDColumns(ctx *plancontext.PlanningContext, output Operator) Operator { offset := -1 topLevel := false - var arg sqlparser.Expr + var foundFunc *sqlparser.FuncExpr for i, expr := range output.GetSelectExprs(ctx) { ae, ok := expr.(*sqlparser.AliasedExpr) if !ok { panic(vterrors.VT09015()) } - replaceFn := func(node sqlparser.Expr) (sqlparser.Expr, bool) { + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { fnc, ok := node.(*sqlparser.FuncExpr) if !ok || !fnc.Name.EqualString("last_insert_id") { - return node, false + return true, nil } if offset != -1 { panic(vterrors.VT12001("last_insert_id() found multiple times in select list")) } - arg = fnc.Exprs[0] + foundFunc = fnc if node == ae.Expr { topLevel = true } offset = i - return arg, true - } + return false, nil - newExpr := sqlparser.CopyAndReplaceExpr(ae.Expr, replaceFn) - ae.Expr = newExpr.(sqlparser.Expr) + }, ae.Expr) } if offset == -1 { panic(vterrors.VT12001("last_insert_id() only supported in the select list")) @@ -861,7 +859,7 @@ func handleLastInsertIDColumns(ctx *plancontext.PlanningContext, output Operator } } - offset = output.AddColumn(ctx, false, false, aeWrap(arg)) + offset = output.AddColumn(ctx, false, false, aeWrap(foundFunc)) return &SaveToSession{ unaryOperator: unaryOperator{ Source: output, @@ -890,25 +888,7 @@ func addTruncationOrProjectionToReturnOutput(ctx *plancontext.PlanningContext, s } func extractSelectExpressions(horizon *Horizon) sqlparser.SelectExprs { - sel := sqlparser.GetFirstSelect(horizon.Query) - // we handle last_insert_id with arguments separately - no need to send this down to mysql - selExprs := sqlparser.CopyAndReplaceExpr(sel.SelectExprs, func(node sqlparser.Expr) (sqlparser.Expr, bool) { - switch node := node.(type) { - case *sqlparser.FuncExpr: - if node.Name.EqualString("last_insert_id") && len(node.Exprs) == 1 { - return node.Exprs[0], true - } - return node, true - case sqlparser.Expr: - // we do this to make sure we get a clone of the expression - // if planning changes the expression, we should not change the original - return node, true - default: - return nil, false - } - }) - - return selExprs.(sqlparser.SelectExprs) + return sqlparser.CloneSelectExprs(sqlparser.GetFirstSelect(horizon.Query).SelectExprs) } func colNamesAlign(expected, actual sqlparser.SelectExprs) bool { diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index 6f4b8716ce0..b1134becef6 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -2177,8 +2177,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select 12 from dual where 1 != 1", - "Query": "select 12 from dual", + "FieldQuery": "select last_insert_id(12) from dual where 1 != 1", + "Query": "select last_insert_id(12) from dual", "Table": "dual" } ] @@ -2205,8 +2205,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select bar, 12, foo from `user` where 1 != 1", - "Query": "select bar, 12, foo from `user`", + "FieldQuery": "select bar, 12, last_insert_id(foo) from `user` where 1 != 1", + "Query": "select bar, 12, last_insert_id(foo) from `user`", "Table": "`user`" } ] @@ -5687,8 +5687,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id from `user` where 1 != 1", - "Query": "select id from `user`", + "FieldQuery": "select last_insert_id(id) from `user` where 1 != 1", + "Query": "select last_insert_id(id) from `user`", "Table": "`user`" } ]