Skip to content

Commit

Permalink
Add support for multi table update with non literal value (#15980)
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal authored May 22, 2024
1 parent 64ae1b7 commit 2283f6b
Show file tree
Hide file tree
Showing 18 changed files with 484 additions and 41 deletions.
31 changes: 31 additions & 0 deletions go/test/endtoend/vtgate/queries/dml/dml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,37 @@ func TestMultiTargetUpdate(t *testing.T) {
`[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("xyz")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("a")]]`)
}

// TestMultiTargetNonLiteralUpdate executed multi-target update queries with non-literal values.
func TestMultiTargetNonLiteralUpdate(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate")

mcmp, closer := start(t)
defer closer()

// initial rows
mcmp.Exec("insert into order_tbl(region_id, oid, cust_no) values (1,1,4), (1,2,2), (2,3,5), (2,4,55)")
mcmp.Exec("insert into oevent_tbl(oid, ename) values (1,'a'), (2,'b'), (3,'a'), (4,'c')")

// multi target update
qr := mcmp.Exec(`update order_tbl o join oevent_tbl ev on o.oid = ev.oid set ev.ename = o.cust_no where ev.oid > 3`)
assert.EqualValues(t, 1, qr.RowsAffected)

// check rows
mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid`,
`[[INT64(1) INT64(1) INT64(4)] [INT64(1) INT64(2) INT64(2)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(55)]]`)
mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`,
`[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("b")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("55")]]`)

qr = mcmp.Exec(`update order_tbl o, oevent_tbl ev set ev.ename = 'xyz', o.oid = ev.oid + 40 where o.cust_no = ev.oid and ev.ename = 'b'`)
assert.EqualValues(t, 2, qr.RowsAffected)

// check rows
mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid, region_id`,
`[[INT64(1) INT64(1) INT64(4)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(55)] [INT64(1) INT64(42) INT64(2)]]`)
mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`,
`[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("xyz")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("55")]]`)
}

// TestDMLInUnique for update/delete statement using an IN clause with the Vindexes,
// the query is correctly split according to the corresponding values in the IN list.
func TestDMLInUnique(t *testing.T) {
Expand Down
23 changes: 22 additions & 1 deletion go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

73 changes: 62 additions & 11 deletions go/vt/vtgate/engine/dml_with_input.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type DMLWithInput struct {

DMLs []Primitive
OutputCols [][]int
BVList []map[string]int
}

func (dml *DMLWithInput) RouteType() string {
Expand Down Expand Up @@ -69,18 +70,16 @@ func (dml *DMLWithInput) TryExecute(ctx context.Context, vcursor VCursor, bindVa

var res *sqltypes.Result
for idx, prim := range dml.DMLs {
var bv *querypb.BindVariable
if len(dml.OutputCols[idx]) == 1 {
bv = getBVSingle(inputRes, dml.OutputCols[idx][0])
var qr *sqltypes.Result
if len(dml.BVList) == 0 || len(dml.BVList[idx]) == 0 {
qr, err = executeLiteralUpdate(ctx, vcursor, bindVars, prim, inputRes, dml.OutputCols[idx])
} else {
bv = getBVMulti(inputRes, dml.OutputCols[idx])
qr, err = executeNonLiteralUpdate(ctx, vcursor, bindVars, prim, inputRes, dml.OutputCols[idx], dml.BVList[idx])
}

bindVars[DmlVals] = bv
qr, err := vcursor.ExecutePrimitive(ctx, prim, bindVars, false)
if err != nil {
return nil, err
}

if res == nil {
res = qr
} else {
Expand All @@ -90,18 +89,32 @@ func (dml *DMLWithInput) TryExecute(ctx context.Context, vcursor VCursor, bindVa
return res, nil
}

func getBVSingle(res *sqltypes.Result, offset int) *querypb.BindVariable {
// executeLiteralUpdate executes the primitive that can be executed with a single bind variable from the input result.
// The column updated have same value for all rows in the input result.
func executeLiteralUpdate(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, prim Primitive, inputRes *sqltypes.Result, outputCols []int) (*sqltypes.Result, error) {
var bv *querypb.BindVariable
if len(outputCols) == 1 {
bv = getBVSingle(inputRes.Rows, outputCols[0])
} else {
bv = getBVMulti(inputRes.Rows, outputCols)
}

bindVars[DmlVals] = bv
return vcursor.ExecutePrimitive(ctx, prim, bindVars, false)
}

func getBVSingle(rows []sqltypes.Row, offset int) *querypb.BindVariable {
bv := &querypb.BindVariable{Type: querypb.Type_TUPLE}
for _, row := range res.Rows {
for _, row := range rows {
bv.Values = append(bv.Values, sqltypes.ValueToProto(row[offset]))
}
return bv
}

func getBVMulti(res *sqltypes.Result, offsets []int) *querypb.BindVariable {
func getBVMulti(rows []sqltypes.Row, offsets []int) *querypb.BindVariable {
bv := &querypb.BindVariable{Type: querypb.Type_TUPLE}
outputVals := make([]sqltypes.Value, 0, len(offsets))
for _, row := range res.Rows {
for _, row := range rows {
for _, offset := range offsets {
outputVals = append(outputVals, row[offset])
}
Expand All @@ -111,6 +124,34 @@ func getBVMulti(res *sqltypes.Result, offsets []int) *querypb.BindVariable {
return bv
}

// executeNonLiteralUpdate executes the primitive that needs to be executed per row from the input result.
// The column updated might have different value for each row in the input result.
func executeNonLiteralUpdate(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, prim Primitive, inputRes *sqltypes.Result, outputCols []int, vars map[string]int) (qr *sqltypes.Result, err error) {
var res *sqltypes.Result
for _, row := range inputRes.Rows {
var bv *querypb.BindVariable
if len(outputCols) == 1 {
bv = getBVSingle([]sqltypes.Row{row}, outputCols[0])
} else {
bv = getBVMulti([]sqltypes.Row{row}, outputCols)
}
bindVars[DmlVals] = bv
for k, v := range vars {
bindVars[k] = sqltypes.ValueBindVariable(row[v])
}
qr, err = vcursor.ExecutePrimitive(ctx, prim, bindVars, false)
if err != nil {
return nil, err
}
if res == nil {
res = qr
} else {
res.RowsAffected += res.RowsAffected
}
}
return res, nil
}

// TryStreamExecute performs a streaming exec.
func (dml *DMLWithInput) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
res, err := dml.TryExecute(ctx, vcursor, bindVars, wantfields)
Expand All @@ -133,6 +174,16 @@ func (dml *DMLWithInput) description() PrimitiveDescription {
other := map[string]any{
"Offset": offsets,
}
var bvList []string
for idx, vars := range dml.BVList {
if len(vars) == 0 {
continue
}
bvList = append(bvList, fmt.Sprintf("%d:[%s]", idx, orderedStringIntMap(vars)))
}
if len(bvList) > 0 {
other["BindVars"] = bvList
}
return PrimitiveDescription{
OperatorType: "DMLWithInput",
TargetTabletType: topodatapb.TabletType_PRIMARY,
Expand Down
9 changes: 9 additions & 0 deletions go/vt/vtgate/engine/plan_description.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/json"
"fmt"
"sort"
"strings"

"vitess.io/vitess/go/tools/graphviz"
"vitess.io/vitess/go/vt/key"
Expand Down Expand Up @@ -266,3 +267,11 @@ func (m orderedMap) MarshalJSON() ([]byte, error) {
buf.WriteString("}")
return buf.Bytes(), nil
}

func (m orderedMap) String() string {
var output []string
for _, val := range m {
output = append(output, fmt.Sprintf("%s:%v", val.key, val.val))
}
return strings.Join(output, " ")
}
2 changes: 2 additions & 0 deletions go/vt/vtgate/planbuilder/dml_with_input.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type dmlWithInput struct {
dmls []logicalPlan

outputCols [][]int
bvList []map[string]int
}

var _ logicalPlan = (*dmlWithInput)(nil)
Expand All @@ -40,5 +41,6 @@ func (d *dmlWithInput) Primitive() engine.Primitive {
DMLs: dels,
Input: inp,
OutputCols: d.outputCols,
BVList: d.bvList,
}
}
1 change: 1 addition & 0 deletions go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ func transformDMLWithInput(ctx *plancontext.PlanningContext, op *operators.DMLWi
input: input,
dmls: dmls,
outputCols: op.Offsets,
bvList: op.BvList,
}, nil
}

Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ func splitGroupingToLeftAndRight(
rhs.addGrouping(ctx, groupBy)
columns.addRight(groupBy.Inner)
case deps.IsSolvedBy(lhs.tableID.Merge(rhs.tableID)):
jc := breakExpressionInLHSandRHSForApplyJoin(ctx, groupBy.Inner, lhs.tableID)
jc := breakExpressionInLHSandRHS(ctx, groupBy.Inner, lhs.tableID)
for _, lhsExpr := range jc.LHSExprs {
e := lhsExpr.Expr
lhs.addGrouping(ctx, NewGroupBy(e))
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sql
rhs := aj.RHS
predicates := sqlparser.SplitAndExpression(nil, expr)
for _, pred := range predicates {
col := breakExpressionInLHSandRHSForApplyJoin(ctx, pred, TableID(aj.LHS))
col := breakExpressionInLHSandRHS(ctx, pred, TableID(aj.LHS))
aj.JoinPredicates.add(col)
ctx.AddJoinPredicates(pred, col.RHSExpr)
rhs = rhs.AddPredicate(ctx, col.RHSExpr)
Expand Down Expand Up @@ -202,7 +202,7 @@ func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sq
case deps.IsSolvedBy(rhs):
col.RHSExpr = e
case deps.IsSolvedBy(both):
col = breakExpressionInLHSandRHSForApplyJoin(ctx, e, TableID(aj.LHS))
col = breakExpressionInLHSandRHS(ctx, e, TableID(aj.LHS))
default:
panic(vterrors.VT13002(sqlparser.String(e)))
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/ast_to_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (jpc *joinPredicateCollector) inspectPredicate(
// then we can use this predicate to connect the subquery to the outer query
if !deps.IsSolvedBy(jpc.subqID) && deps.IsSolvedBy(jpc.totalID) {
jpc.predicates = append(jpc.predicates, predicate)
jc := breakExpressionInLHSandRHSForApplyJoin(ctx, predicate, jpc.outerID)
jc := breakExpressionInLHSandRHS(ctx, predicate, jpc.outerID)
jpc.joinColumns = append(jpc.joinColumns, jc)
pred = jc.RHSExpr
}
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/planbuilder/operators/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,9 @@ func createDeleteOpWithTarget(ctx *plancontext.PlanningContext, target semantics
Where: sqlparser.NewWhere(sqlparser.WhereClause, compExpr),
}
return dmlOp{
createOperatorFromDelete(ctx, del),
vTbl,
cols,
op: createOperatorFromDelete(ctx, del),
vTbl: vTbl,
cols: cols,
}
}

Expand Down
9 changes: 5 additions & 4 deletions go/vt/vtgate/planbuilder/operators/dml_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ type TargetTable struct {
Name sqlparser.TableName
}

// dmlOp stores intermediary value for Update/Delete Operator with the vindexes.Table for ordering.
// dmlOp stores intermediary value for Update/Delete Operator with the vindexes. Table for ordering.
type dmlOp struct {
op Operator
vTbl *vindexes.Table
cols []*sqlparser.ColName
op Operator
vTbl *vindexes.Table
cols []*sqlparser.ColName
updList updList
}

// sortDmlOps sort the operator based on sharding vindex type.
Expand Down
20 changes: 20 additions & 0 deletions go/vt/vtgate/planbuilder/operators/dml_with_input.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ type DMLWithInput struct {
cols [][]*sqlparser.ColName
Offsets [][]int

updList []updList
BvList []map[string]int

noColumns
noPredicates
}
Expand Down Expand Up @@ -86,6 +89,7 @@ func (d *DMLWithInput) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy {
}

func (d *DMLWithInput) planOffsets(ctx *plancontext.PlanningContext) Operator {
// go through the primary key columns to get offset from the input
offsets := make([][]int, len(d.cols))
for idx, columns := range d.cols {
for _, col := range columns {
Expand All @@ -94,6 +98,22 @@ func (d *DMLWithInput) planOffsets(ctx *plancontext.PlanningContext) Operator {
}
}
d.Offsets = offsets

// go through the update list and get offset for input columns
bvList := make([]map[string]int, len(d.updList))
for idx, ul := range d.updList {
vars := make(map[string]int)
for _, updCol := range ul {
for _, bvExpr := range updCol.jc.LHSExprs {
offset := d.Source.AddColumn(ctx, true, false, aeWrap(bvExpr.Expr))
vars[bvExpr.Name] = offset
}
}
if len(vars) > 0 {
bvList[idx] = vars
}
}
d.BvList = bvList
return d
}

Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import (
"vitess.io/vitess/go/vt/vtgate/semantics"
)

// breakExpressionInLHSandRHSForApplyJoin takes an expression and
// breakExpressionInLHSandRHS takes an expression and
// extracts the parts that are coming from one of the sides into `ColName`s that are needed
func breakExpressionInLHSandRHSForApplyJoin(
func breakExpressionInLHSandRHS(
ctx *plancontext.PlanningContext,
expr sqlparser.Expr,
lhs semantics.TableSet,
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/subquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (sq *SubQuery) GetJoinColumns(ctx *plancontext.PlanningContext, outer Opera
}
sq.outerID = outerID
mapper := func(in sqlparser.Expr) (applyJoinColumn, error) {
return breakExpressionInLHSandRHSForApplyJoin(ctx, in, outerID), nil
return breakExpressionInLHSandRHS(ctx, in, outerID), nil
}
joinPredicates, err := slice.MapWithError(sq.Predicates, mapper)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/subquery_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func extractLHSExpr(
lhs semantics.TableSet,
) func(expr sqlparser.Expr) sqlparser.Expr {
return func(expr sqlparser.Expr) sqlparser.Expr {
col := breakExpressionInLHSandRHSForApplyJoin(ctx, expr, lhs)
col := breakExpressionInLHSandRHS(ctx, expr, lhs)
if col.IsPureLeft() {
panic(vterrors.VT13001("did not expect to find any predicates that do not need data from the inner here"))
}
Expand Down
Loading

0 comments on commit 2283f6b

Please sign in to comment.