Skip to content

Commit

Permalink
fix some of the faulty rewrites
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Aug 29, 2024
1 parent fc653a1 commit d8c214b
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 35 deletions.
3 changes: 3 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2391,6 +2391,9 @@ func AndExpressions(exprs ...Expr) Expr {
uniqueAdd(expr)
}
}
if len(unique) == 1 {
return unique[0]
}
return &AndExpr{Predicates: unique}
}
}
Expand Down
43 changes: 23 additions & 20 deletions go/vt/sqlparser/predicate_rewriting.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ func simplifyNot(expr *NotExpr) (Expr, bool) {
return expr, false
}

func createOrs(exprs ...Expr) Expr {
if len(exprs) == 1 {
return exprs[0]
}
return &OrExpr{Left: exprs[0], Right: createOrs(exprs[1:]...)}
}

func simplifyOr(or *OrExpr) (Expr, bool) {
res, rewritten := distinctOr(or)
if rewritten {
Expand All @@ -100,7 +107,7 @@ func simplifyOr(or *OrExpr) (Expr, bool) {
rand, rok := or.Right.(*AndExpr)

if lok && rok {
// (A AND B) OR (A AND C) => A OR (B AND C)
// (A AND B AND D) OR (A AND C AND D) => (A AND D) AND (B OR C)
var commonPredicates []Expr
var leftRemainder, rightRemainder []Expr

Expand All @@ -125,23 +132,14 @@ func simplifyOr(or *OrExpr) (Expr, bool) {

if len(commonPredicates) > 0 {
// Build the final AndExpr with common predicates and the OrExpr of remainders
var notCommon Expr
switch {
case len(leftRemainder) == 0 && len(rightRemainder) == 0:
// all expressions were common
return AndExpressions(commonPredicates...), true
case len(leftRemainder) == 0:
notCommon = AndExpressions(rightRemainder...)
case len(rightRemainder) == 0:
notCommon = AndExpressions(leftRemainder...)
default:
notCommon = &OrExpr{
Left: AndExpressions(leftRemainder...),
Right: AndExpressions(rightRemainder...),
}
nonCommonPredicates := append(leftRemainder, rightRemainder...)
commonPred := AndExpressions(commonPredicates...)
if len(nonCommonPredicates) == 0 {
return commonPred, true
}
return AndExpressions(append(commonPredicates, notCommon)...), true
return AndExpressions(commonPred, createOrs(nonCommonPredicates...)), true
}
return or, false
}
if !lok && !rok {
lftCmp, lok := or.Left.(*ComparisonExpr)
Expand Down Expand Up @@ -201,6 +199,10 @@ func simplifyXor(xor *XorExpr) (Expr, bool) {
}

func simplifyAnd(expr *AndExpr) (Expr, bool) {
if len(expr.Predicates) == 1 {
return expr.Predicates[0], true
}

res, rewritten := distinctAnd(expr)
if rewritten {
return res, true
Expand All @@ -210,6 +212,7 @@ func simplifyAnd(expr *AndExpr) (Expr, bool) {
simplified := false

// Loop over all predicates in the AndExpr
outer:
for i, andPred := range expr.Predicates {
if or, ok := andPred.(*OrExpr); ok {
// Check if we can simplify by matching with another predicate in the AndExpr
Expand All @@ -223,13 +226,13 @@ func simplifyAnd(expr *AndExpr) (Expr, bool) {
// Found a match, keep the simpler expression (otherPred)
simplifiedPredicates = append(simplifiedPredicates, otherPred)
simplified = true
break
continue outer
}
}
} else {
// No simplification possible, keep the original predicate
simplifiedPredicates = append(simplifiedPredicates, andPred)
}

// No simplification possible, keep the original predicate
simplifiedPredicates = append(simplifiedPredicates, andPred)
}

if simplified {
Expand Down
71 changes: 56 additions & 15 deletions go/vt/vtgate/planbuilder/predicate_rewrite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ import (
"fmt"
"math/rand/v2"
"strconv"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"vitess.io/vitess/go/slice"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql/collations"
Expand Down Expand Up @@ -82,38 +84,67 @@ func (tc testCase) createPredicate(lvl int) sqlparser.Expr {
panic("unexpected nodeType")
}

func TestOneRewriting(t *testing.T) {
venv := vtenv.NewTestEnv()

// Modify these
const numberOfColumns = 2
const expr = "n1 and n0 or n1 xor n1"

predicate, err := sqlparser.NewTestParser().ParseExpr(expr)
require.NoError(t, err)

simplified := sqlparser.RewritePredicate(predicate)

cfg := &evalengine.Config{
Environment: venv,
Collation: collations.MySQL8().DefaultConnectionCharset(),
ResolveColumn: resolveForFuzz,
}
original, err := evalengine.Translate(predicate, cfg)
require.NoError(t, err)
simpler, err := evalengine.Translate(simplified.(sqlparser.Expr), cfg)
require.NoError(t, err)

env := evalengine.EmptyExpressionEnv(venv)
env.Row = make([]sqltypes.Value, numberOfColumns)
for i := range env.Row {
env.Row[i] = sqltypes.NULL
}

testValues(t, env, 0, original, simpler)
}

func TestFuzzRewriting(t *testing.T) {
// This test, that runs for one second only, will produce lots of random boolean expressions,
// mixing AND, NOT, OR, XOR and column expressions.
// It then takes the predicate and simplifies it
// Finally, it runs both the original and simplified predicate with all combinations of column
// values - trying TRUE, FALSE and NULL. If the two expressions do not return the same value,
// this is considered a test failure.

venv := vtenv.NewTestEnv()
start := time.Now()
for time.Since(start) < 1*time.Second {
tc := testCase{
nodes: rand.IntN(4) + 1,
nodes: 2,
depth: rand.IntN(4) + 1,
}

predicate := tc.createPredicate(0)
name := sqlparser.String(predicate)
t.Run(name, func(t *testing.T) {
venv := vtenv.NewTestEnv()
simplified := sqlparser.RewritePredicate(predicate)

original, err := evalengine.Translate(predicate, &evalengine.Config{
Environment: venv,
Collation: collations.MySQL8().DefaultConnectionCharset(),
ResolveColumn: resolveForFuzz,
})
cfg := &evalengine.Config{
Environment: venv,
Collation: collations.MySQL8().DefaultConnectionCharset(),
ResolveColumn: resolveForFuzz,
NoConstantFolding: true,
NoCompilation: true,
}
original, err := evalengine.Translate(predicate, cfg)
require.NoError(t, err)
simpler, err := evalengine.Translate(simplified.(sqlparser.Expr), &evalengine.Config{
Environment: venv,
Collation: collations.MySQL8().DefaultConnectionCharset(),
ResolveColumn: resolveForFuzz,
})
simpler, err := evalengine.Translate(simplified.(sqlparser.Expr), cfg)
require.NoError(t, err)

env := evalengine.EmptyExpressionEnv(venv)
Expand Down Expand Up @@ -142,7 +173,17 @@ func testValues(t *testing.T, env *evalengine.ExpressionEnv, i int, original, si
require.NoError(t, err)
v2, err := env.Evaluate(simpler)
require.NoError(t, err)
assert.Equal(t, v1.Value(collations.MySQL8().DefaultConnectionCharset()), v2.Value(collations.MySQL8().DefaultConnectionCharset()))
v1Value := v1.Value(collations.MySQL8().DefaultConnectionCharset())
v2Value := v2.Value(collations.MySQL8().DefaultConnectionCharset())
row := strings.Join(slice.Map(env.Row, func(i sqltypes.Value) string {
return i.String()
}), " | ")
msg := fmt.Sprintf("original: %v (%s)\nsimplified: %v (%s)\nrow: %v", sqlparser.String(original), v1Value.String(), sqlparser.String(simpler), v2Value.String(), row)
require.True(
t,
v1Value.Equal(v2Value),
msg,
)
if len(env.Row) > i+1 {
testValues(t, env, i+1, original, simpler)
}
Expand Down

0 comments on commit d8c214b

Please sign in to comment.