From 4733045ecd77db7ef9840f7a43586c7fb56f6e13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Taylor?= Date: Thu, 23 Nov 2023 11:20:43 +0100 Subject: [PATCH] bug fix: stop all kinds of expressions from cnf-exploding (#14585) Signed-off-by: Andres Taylor --- go/vt/sqlparser/predicate_rewriting.go | 391 ++++++++------------ go/vt/sqlparser/predicate_rewriting_test.go | 9 +- 2 files changed, 157 insertions(+), 243 deletions(-) diff --git a/go/vt/sqlparser/predicate_rewriting.go b/go/vt/sqlparser/predicate_rewriting.go index 9dcd239f9eb..7bad1b3b82f 100644 --- a/go/vt/sqlparser/predicate_rewriting.go +++ b/go/vt/sqlparser/predicate_rewriting.go @@ -16,32 +16,14 @@ limitations under the License. package sqlparser -import ( - "vitess.io/vitess/go/vt/log" -) - -// This is the number of OR expressions in a predicate that will disable the CNF -// rewrite because we don't want to send large queries to MySQL -const CNFOrLimit = 5 - // RewritePredicate walks the input AST and rewrites any boolean logic into a simpler form // This simpler form is CNF plus logic for extracting predicates from OR, plus logic for turning ORs into IN -// Note: In order to re-plan, we need to empty the accumulated metadata in the AST, -// so ColName.Metadata will be nil:ed out as part of this rewrite func RewritePredicate(ast SQLNode) SQLNode { - count := 0 - _ = Walk(func(node SQLNode) (bool, error) { - if _, isExpr := node.(*OrExpr); isExpr { - count++ - } - - return true, nil - }, ast) - - allowCNF := count < CNFOrLimit + original := CloneSQLNode(ast) - for { - printExpr(ast) + // Beware: converting to CNF in this loop might cause exponential formula growth. + // We bail out early to prevent going overboard. + for loop := 0; loop < 15; loop++ { exprChanged := false stopOnChange := func(SQLNode, SQLNode) bool { return !exprChanged @@ -52,9 +34,8 @@ func RewritePredicate(ast SQLNode) SQLNode { return true } - rewritten, state := simplifyExpression(e, allowCNF) - if ch, isChange := state.(changed); isChange { - printRule(ch.rule, ch.exprMatched) + rewritten, changed := simplifyExpression(e) + if changed { exprChanged = true cursor.Replace(rewritten) } @@ -65,70 +46,44 @@ func RewritePredicate(ast SQLNode) SQLNode { return ast } } + + return original } -func simplifyExpression(expr Expr, allowCNF bool) (Expr, rewriteState) { +func simplifyExpression(expr Expr) (Expr, bool) { switch expr := expr.(type) { case *NotExpr: return simplifyNot(expr) case *OrExpr: - return simplifyOr(expr, allowCNF) + return simplifyOr(expr) case *XorExpr: return simplifyXor(expr) case *AndExpr: return simplifyAnd(expr) } - return expr, noChange{} + return expr, false } -func simplifyNot(expr *NotExpr) (Expr, rewriteState) { +func simplifyNot(expr *NotExpr) (Expr, bool) { switch child := expr.Expr.(type) { case *NotExpr: - return child.Expr, - newChange("NOT NOT A => A", f(expr)) + return child.Expr, true case *OrExpr: - return &AndExpr{Right: &NotExpr{Expr: child.Right}, Left: &NotExpr{Expr: child.Left}}, - newChange("NOT (A OR B) => NOT A AND NOT B", f(expr)) + // not(or(a,b)) => and(not(a),not(b)) + return &AndExpr{Right: &NotExpr{Expr: child.Right}, Left: &NotExpr{Expr: child.Left}}, true case *AndExpr: - return &OrExpr{Right: &NotExpr{Expr: child.Right}, Left: &NotExpr{Expr: child.Left}}, - newChange("NOT (A AND B) => NOT A OR NOT B", f(expr)) + // not(and(a,b)) => or(not(a), not(b)) + return &OrExpr{Right: &NotExpr{Expr: child.Right}, Left: &NotExpr{Expr: child.Left}}, true } - return expr, noChange{} + return expr, false } -// ExtractINFromOR will add additional predicated to an OR. -// this rewriter should not be used in a fixed point way, since it returns the original expression with additions, -// and it will therefor OOM before it stops rewriting -func ExtractINFromOR(expr *OrExpr) []Expr { - // we check if we have two comparisons on either side of the OR - // that we can add as an ANDed comparison. - // WHERE (a = 5 and B) or (a = 6 AND C) => - // WHERE (a = 5 AND B) OR (a = 6 AND C) AND a IN (5,6) - // This rewrite makes it possible to find a better route than Scatter if the `a` column has a helpful vindex - lftPredicates := SplitAndExpression(nil, expr.Left) - rgtPredicates := SplitAndExpression(nil, expr.Right) - var ins []Expr - for _, lft := range lftPredicates { - l, ok := lft.(*ComparisonExpr) - if !ok { - continue - } - for _, rgt := range rgtPredicates { - r, ok := rgt.(*ComparisonExpr) - if !ok { - continue - } - in, state := tryTurningOrIntoIn(l, r) - if state.changed() { - ins = append(ins, in) - } - } +func simplifyOr(expr *OrExpr) (Expr, bool) { + res, rewritten := distinctOr(expr) + if rewritten { + return res, true } - return uniquefy(ins) -} - -func simplifyOr(expr *OrExpr, allowCNF bool) (Expr, rewriteState) { or := expr // first we search for ANDs and see how they can be simplified @@ -137,25 +92,21 @@ func simplifyOr(expr *OrExpr, allowCNF bool) (Expr, rewriteState) { if lok && rok { // (<> AND <>) OR (<> AND <>) + // or(and(T1,T2), and(T2, T3)) => and(T1, or(T2, T2)) var a, b, c Expr - var change changed switch { case Equals.Expr(land.Left, rand.Left): - change = newChange("(A and B) or (A and C) => A AND (B OR C)", f(expr)) a, b, c = land.Left, land.Right, rand.Right - return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, change + return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, true case Equals.Expr(land.Left, rand.Right): - change = newChange("(A and B) or (C and A) => A AND (B OR C)", f(expr)) a, b, c = land.Left, land.Right, rand.Left - return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, change + return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, true case Equals.Expr(land.Right, rand.Left): - change = newChange("(B and A) or (A and C) => A AND (B OR C)", f(expr)) a, b, c = land.Right, land.Left, rand.Right - return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, change + return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, true case Equals.Expr(land.Right, rand.Right): - change = newChange("(B and A) or (C and A) => A AND (B OR C)", f(expr)) a, b, c = land.Right, land.Left, rand.Left - return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, change + return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, true } } @@ -163,31 +114,38 @@ func simplifyOr(expr *OrExpr, allowCNF bool) (Expr, rewriteState) { if lok { // Simplification if Equals.Expr(or.Right, land.Left) || Equals.Expr(or.Right, land.Right) { - return or.Right, newChange("(A AND B) OR A => A", f(expr)) + // or(and(a,b), c) => c where c=a or c=b + return or.Right, true } - if allowCNF { - // Distribution Law - return &AndExpr{Left: &OrExpr{Left: land.Left, Right: or.Right}, Right: &OrExpr{Left: land.Right, Right: or.Right}}, - newChange("(A AND B) OR C => (A OR C) AND (B OR C)", f(expr)) - } + // Distribution Law + // or(c, and(a,b)) => and(or(c,a), or(c,b)) + return &AndExpr{ + Left: &OrExpr{ + Left: land.Left, + Right: or.Right, + }, + Right: &OrExpr{ + Left: land.Right, + Right: or.Right, + }, + }, true } // <> OR (<> AND <>) if rok { // Simplification if Equals.Expr(or.Left, rand.Left) || Equals.Expr(or.Left, rand.Right) { - return or.Left, newChange("A OR (A AND B) => A", f(expr)) + // or(a,and(b,c)) => a + return or.Left, true } - if allowCNF { - // Distribution Law - return &AndExpr{ - Left: &OrExpr{Left: or.Left, Right: rand.Left}, - Right: &OrExpr{Left: or.Left, Right: rand.Right}, - }, - newChange("C OR (A AND B) => (C OR A) AND (C OR B)", f(expr)) - } + // Distribution Law + // or(and(a,b), c) => and(or(c,a), or(c,b)) + return &AndExpr{ + Left: &OrExpr{Left: or.Left, Right: rand.Left}, + Right: &OrExpr{Left: or.Left, Right: rand.Right}, + }, true } // next, we want to try to turn multiple ORs into an IN when possible @@ -195,63 +153,131 @@ func simplifyOr(expr *OrExpr, allowCNF bool) (Expr, rewriteState) { rgtCmp, rok := or.Right.(*ComparisonExpr) if lok && rok { newExpr, rewritten := tryTurningOrIntoIn(lftCmp, rgtCmp) - if rewritten.changed() { - return newExpr, rewritten + if rewritten { + // or(a=x,a=y) => in(a,[x,y]) + return newExpr, true } } // Try to make distinct - return distinctOr(expr) + result, changed := distinctOr(expr) + if changed { + return result, true + } + return result, false +} + +func simplifyXor(expr *XorExpr) (Expr, bool) { + // xor(a,b) => and(or(a,b), not(and(a,b)) + return &AndExpr{ + Left: &OrExpr{Left: expr.Left, Right: expr.Right}, + Right: &NotExpr{Expr: &AndExpr{Left: expr.Left, Right: expr.Right}}, + }, true } -func tryTurningOrIntoIn(l, r *ComparisonExpr) (Expr, rewriteState) { +func simplifyAnd(expr *AndExpr) (Expr, bool) { + res, rewritten := distinctAnd(expr) + if rewritten { + return res, true + } + and := expr + if or, ok := and.Left.(*OrExpr); ok { + // Simplification + // and(or(a,b),c) => c when c=a or c=b + if Equals.Expr(or.Left, and.Right) { + return and.Right, true + } + if Equals.Expr(or.Right, and.Right) { + return and.Right, true + } + } + if or, ok := and.Right.(*OrExpr); ok { + // Simplification + if Equals.Expr(or.Left, and.Left) { + return and.Left, true + } + if Equals.Expr(or.Right, and.Left) { + return and.Left, true + } + } + + return expr, false +} + +// ExtractINFromOR will add additional predicated to an OR. +// this rewriter should not be used in a fixed point way, since it returns the original expression with additions, +// and it will therefor OOM before it stops rewriting +func ExtractINFromOR(expr *OrExpr) []Expr { + // we check if we have two comparisons on either side of the OR + // that we can add as an ANDed comparison. + // WHERE (a = 5 and B) or (a = 6 AND C) => + // WHERE (a = 5 AND B) OR (a = 6 AND C) AND a IN (5,6) + // This rewrite makes it possible to find a better route than Scatter if the `a` column has a helpful vindex + lftPredicates := SplitAndExpression(nil, expr.Left) + rgtPredicates := SplitAndExpression(nil, expr.Right) + var ins []Expr + for _, lft := range lftPredicates { + l, ok := lft.(*ComparisonExpr) + if !ok { + continue + } + for _, rgt := range rgtPredicates { + r, ok := rgt.(*ComparisonExpr) + if !ok { + continue + } + in, changed := tryTurningOrIntoIn(l, r) + if changed { + ins = append(ins, in) + } + } + } + + return uniquefy(ins) +} + +func tryTurningOrIntoIn(l, r *ComparisonExpr) (Expr, bool) { // looks for A = X OR A = Y and turns them into A IN (X, Y) col, ok := l.Left.(*ColName) if !ok || !Equals.Expr(col, r.Left) { - return nil, noChange{} + return nil, false } var tuple ValTuple - var ruleStr string + switch l.Operator { case EqualOp: tuple = ValTuple{l.Right} - ruleStr = "A = <>" case InOp: lft, ok := l.Right.(ValTuple) if !ok { - return nil, noChange{} + return nil, false } tuple = lft - ruleStr = "A IN (<>, <>)" default: - return nil, noChange{} + return nil, false } - ruleStr += " OR " - switch r.Operator { case EqualOp: tuple = append(tuple, r.Right) - ruleStr += "A = <>" + case InOp: lft, ok := r.Right.(ValTuple) if !ok { - return nil, noChange{} + return nil, false } tuple = append(tuple, lft...) - ruleStr += "A IN (<>, <>)" + default: - return nil, noChange{} + return nil, false } - ruleStr += " => A IN (<>, <>)" - return &ComparisonExpr{ Operator: InOp, Left: col, Right: uniquefy(tuple), - }, newChange(ruleStr, f(&OrExpr{Left: l, Right: r})) + }, true } func uniquefy(tuple ValTuple) (output ValTuple) { @@ -267,44 +293,7 @@ outer: return } -func simplifyXor(expr *XorExpr) (Expr, rewriteState) { - // DeMorgan Rewriter - return &AndExpr{ - Left: &OrExpr{Left: expr.Left, Right: expr.Right}, - Right: &NotExpr{Expr: &AndExpr{Left: expr.Left, Right: expr.Right}}, - }, newChange("(A XOR B) => (A OR B) AND NOT (A AND B)", f(expr)) -} - -func simplifyAnd(expr *AndExpr) (Expr, rewriteState) { - res, rewritten := distinctAnd(expr) - if rewritten.changed() { - return res, rewritten - } - and := expr - if or, ok := and.Left.(*OrExpr); ok { - // Simplification - if Equals.Expr(or.Left, and.Right) { - return and.Right, newChange("(A OR B) AND A => A", f(expr)) - } - if Equals.Expr(or.Right, and.Right) { - return and.Right, newChange("(A OR B) AND B => B", f(expr)) - } - } - if or, ok := and.Right.(*OrExpr); ok { - // Simplification - if Equals.Expr(or.Left, and.Left) { - return and.Left, newChange("A AND (A OR B) => A", f(expr)) - } - if Equals.Expr(or.Right, and.Left) { - return and.Left, newChange("A AND (B OR A) => A", f(expr)) - } - } - - return expr, noChange{} -} - -func distinctOr(in *OrExpr) (Expr, rewriteState) { - var skipped []*OrExpr +func distinctOr(in *OrExpr) (result Expr, changed bool) { todo := []*OrExpr{in} var leaves []Expr for len(todo) > 0 { @@ -321,27 +310,23 @@ func distinctOr(in *OrExpr) (Expr, rewriteState) { addAnd(curr.Left) addAnd(curr.Right) } - original := len(leaves) + var predicates []Expr outer1: - for len(leaves) > 0 { - curr := leaves[0] - leaves = leaves[1:] + for _, curr := range leaves { for _, alreadyIn := range predicates { if Equals.Expr(alreadyIn, curr) { - if log.V(0) { - skipped = append(skipped, &OrExpr{Left: alreadyIn, Right: curr}) - } + changed = true continue outer1 } } predicates = append(predicates, curr) } - if original == len(predicates) { - return in, noChange{} + if !changed { + return in, false } - var result Expr + for i, curr := range predicates { if i == 0 { result = curr @@ -350,25 +335,10 @@ outer1: result = &OrExpr{Left: result, Right: curr} } - return result, newChange("A OR A => A", func() Expr { - var result Expr - for _, orExpr := range skipped { - if result == nil { - result = orExpr - continue - } - - result = &OrExpr{ - Left: result, - Right: orExpr, - } - } - return result - }) + return } -func distinctAnd(in *AndExpr) (Expr, rewriteState) { - var skipped []*AndExpr +func distinctAnd(in *AndExpr) (result Expr, changed bool) { todo := []*AndExpr{in} var leaves []Expr for len(todo) > 0 { @@ -384,25 +354,23 @@ func distinctAnd(in *AndExpr) (Expr, rewriteState) { addExpr(curr.Left) addExpr(curr.Right) } - original := len(leaves) var predicates []Expr outer1: for _, curr := range leaves { for _, alreadyIn := range predicates { if Equals.Expr(alreadyIn, curr) { - if log.V(0) { - skipped = append(skipped, &AndExpr{Left: alreadyIn, Right: curr}) - } + changed = true continue outer1 } } predicates = append(predicates, curr) } - if original == len(predicates) { - return in, noChange{} + + if !changed { + return in, false } - var result Expr + for i, curr := range predicates { if i == 0 { result = curr @@ -410,62 +378,5 @@ outer1: } result = &AndExpr{Left: result, Right: curr} } - return AndExpressions(leaves...), newChange("A AND A => A", func() Expr { - var result Expr - for _, andExpr := range skipped { - if result == nil { - result = andExpr - continue - } - - result = &AndExpr{ - Left: result, - Right: andExpr, - } - } - return result - }) -} - -type ( - rewriteState interface { - changed() bool - } - noChange struct{} - - // changed makes it possible to make sure we have a rule string for each change we do in the expression tree - changed struct { - rule string - - // ExprMatched is a function here so building of this expression can be paid only when we are debug logging - exprMatched func() Expr - } -) - -func (noChange) changed() bool { return false } -func (changed) changed() bool { return true } - -// f returns a function that returns the expression. It's short by design, so it interferes minimally -// used for logging -func f(e Expr) func() Expr { - return func() Expr { return e } -} - -func printRule(rule string, expr func() Expr) { - if log.V(10) { - log.Infof("Rule: %s ON %s", rule, String(expr())) - } -} - -func printExpr(expr SQLNode) { - if log.V(10) { - log.Infof("Current: %s", String(expr)) - } -} - -func newChange(rule string, exprMatched func() Expr) changed { - return changed{ - rule: rule, - exprMatched: exprMatched, - } + return AndExpressions(leaves...), true } diff --git a/go/vt/sqlparser/predicate_rewriting_test.go b/go/vt/sqlparser/predicate_rewriting_test.go index fba3d2f01dd..e106a56f1aa 100644 --- a/go/vt/sqlparser/predicate_rewriting_test.go +++ b/go/vt/sqlparser/predicate_rewriting_test.go @@ -91,8 +91,8 @@ func TestSimplifyExpression(in *testing.T) { expr, err := ParseExpr(tc.in) require.NoError(t, err) - expr, didRewrite := simplifyExpression(expr, true) - assert.True(t, didRewrite.changed()) + expr, changed := simplifyExpression(expr) + assert.True(t, changed) assert.Equal(t, tc.expected, String(expr)) }) } @@ -137,9 +137,12 @@ func TestRewritePredicate(in *testing.T) { in: "(a = 1 and b = 41) or (a = 2 and b = 42) or (a = 3 and b = 43)", expected: "a in (1, 2, 3) and (a in (1, 2) or b = 43) and ((a = 1 or b = 42 or a = 3) and (a = 1 or b = 42 or b = 43)) and ((b = 41 or a = 2 or a = 3) and (b = 41 or a = 2 or b = 43) and ((b in (41, 42) or a = 3) and b in (41, 42, 43)))", }, { - // this has too many OR expressions in it, so we don't even try the CNF rewriting + // the following two tests show some pathological cases that would grow too much, and so we abort the rewriting in: "a = 1 and b = 41 or a = 2 and b = 42 or a = 3 and b = 43 or a = 4 and b = 44 or a = 5 and b = 45 or a = 6 and b = 46", expected: "a = 1 and b = 41 or a = 2 and b = 42 or a = 3 and b = 43 or a = 4 and b = 44 or a = 5 and b = 45 or a = 6 and b = 46", + }, { + in: "not n0 xor not (n2 and n3) xor (not n2 and (n1 xor n1) xor (n0 xor n0 xor n2))", + expected: "not n0 xor not (n2 and n3) xor (not n2 and (n1 xor n1) xor (n0 xor n0 xor n2))", }} for _, tc := range tests {