Skip to content

Commit

Permalink
schemadiff: beter nil check validation
Browse files Browse the repository at this point in the history
Signed-off-by: Shlomi Noach <[email protected]>
  • Loading branch information
shlomi-noach committed Mar 20, 2024
1 parent d20f3c5 commit 5bd6902
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 44 deletions.
9 changes: 8 additions & 1 deletion go/vt/schemadiff/annotations.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,16 @@ func unifiedAnnotated(from *TextualAnnotations, to *TextualAnnotations) *Textual
// annotatedDiff returns the annotated representations of the from and to entities, and their unified representation.
func annotatedDiff(diff EntityDiff, entityAnnotations *TextualAnnotations) (from *TextualAnnotations, to *TextualAnnotations, unified *TextualAnnotations) {
fromEntity, toEntity := diff.Entities()
// Handle the infamous golang interface is not-nil but underlying object is:
if fromEntity != nil && fromEntity.Create() == nil {
fromEntity = nil
}
if toEntity != nil && toEntity.Create() == nil {
toEntity = nil
}
switch {
case fromEntity == nil && toEntity == nil:
// Should never get here.
// Will only get here if using mockup entities, as generated by EntityDiffByStatement.
return nil, nil, nil
case fromEntity == nil:
// A new entity was created.
Expand Down
2 changes: 1 addition & 1 deletion go/vt/schemadiff/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func DiffSchemas(env *Environment, schema1 *Schema, schema2 *Schema, hints *Diff

// EntityDiffByStatement is a helper function that returns a simplified and incomplete EntityDiff based on the given SQL statement.
// It is useful for testing purposes as a quick mean to wrap a statement with a diff.
func EntityDiffByStatement(statement sqlparser.Statement) EntityDiff {
func EntityDiffByStatement(env *Environment, statement sqlparser.Statement) EntityDiff {
switch stmt := statement.(type) {
case *sqlparser.CreateTable:
return &CreateTableEntityDiff{createTable: stmt}
Expand Down
126 changes: 84 additions & 42 deletions go/vt/schemadiff/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func TestDiffTables(t *testing.T) {
assert.NoError(t, err)
require.NotNil(t, d)
require.False(t, d.IsEmpty())
{
t.Run("statement", func(t *testing.T) {
diff := d.StatementString()
assert.Equal(t, ts.diff, diff)
action, err := DDLActionStr(d)
Expand All @@ -389,8 +389,8 @@ func TestDiffTables(t *testing.T) {
if ts.toName != "" {
assert.Equal(t, ts.toName, eTo.Name())
}
}
{
})
t.Run("canonical", func(t *testing.T) {
canonicalDiff := d.CanonicalStatementString()
assert.Equal(t, ts.cdiff, canonicalDiff)
action, err := DDLActionStr(d)
Expand All @@ -400,13 +400,18 @@ func TestDiffTables(t *testing.T) {
// validate we can parse back the statement
_, err = env.Parser().ParseStrictDDL(canonicalDiff)
assert.NoError(t, err)
}
if ts.annotated != nil {
// Optional test for assorted scenarios.
_, _, unified := d.Annotated()
unifiedExport := unified.Export()
assert.Equal(t, ts.annotated, strings.Split(unifiedExport, "\n"))
}
})
t.Run("annotations", func(t *testing.T) {
from, to, unified := d.Annotated()
require.NotNil(t, from)
require.NotNil(t, to)
require.NotNil(t, unified)
if ts.annotated != nil {
// Optional test for assorted scenarios.
unifiedExport := unified.Export()
assert.Equal(t, ts.annotated, strings.Split(unifiedExport, "\n"))
}
})
// let's also check dq, and also validate that dq's statement is identical to d's
assert.NoError(t, dqerr)
require.NotNil(t, dq)
Expand Down Expand Up @@ -1162,39 +1167,76 @@ func TestSchemaApplyError(t *testing.T) {
func TestEntityDiffByStatement(t *testing.T) {
env := NewTestEnv()

{
queries := []string{
"create table t1(id int primary key)",
"alter table t1 add column i int",
"rename table t1 to t2",
"drop table t1",
"create view v1 as select * from t1",
"alter view v1 as select * from t2",
"drop view v1",
}
for _, query := range queries {
t.Run(query, func(t *testing.T) {
stmt, err := env.Parser().ParseStrictDDL(query)
require.NoError(t, err)
entityDiff := EntityDiffByStatement(stmt)
require.NotNil(t, entityDiff)
require.NotNil(t, entityDiff.Statement())
require.Equal(t, stmt, entityDiff.Statement())
})
}
tcases := []struct {
query string
valid bool
expectAnotated bool
}{
{
query: "create table t1(id int primary key)",
valid: true,
expectAnotated: true,
},
{
query: "alter table t1 add column i int",
valid: true,
},
{
query: "rename table t1 to t2",
valid: true,
},
{
query: "drop table t1",
valid: true,
},
{
query: "create view v1 as select * from t1",
valid: true,
expectAnotated: true,
},
{
query: "alter view v1 as select * from t2",
valid: true,
},
{
query: "drop view v1",
valid: true,
},
{
query: "drop database d1",
valid: false,
},
{
query: "optimize table t1",
valid: false,
},
}
{
queries := []string{
"drop database d1",
"optimize table t1",
}
for _, query := range queries {
t.Run(query, func(t *testing.T) {
stmt, err := env.Parser().ParseStrictDDL(query)
require.NoError(t, err)
entityDiff := EntityDiffByStatement(stmt)

for _, tcase := range tcases {
t.Run(tcase.query, func(t *testing.T) {
stmt, err := env.Parser().ParseStrictDDL(tcase.query)
require.NoError(t, err)
entityDiff := EntityDiffByStatement(env, stmt)
if !tcase.valid {
require.Nil(t, entityDiff)
})
}
return
}
require.NotNil(t, entityDiff)
require.NotNil(t, entityDiff.Statement())
require.Equal(t, stmt, entityDiff.Statement())

annotatedFrom, annotatedTo, annotatedUnified := entityDiff.Annotated()
// EntityDiffByStatement doesn't have real entities behind it, just a wrapper around a statement.
// Therefore, there are no annotations.
if tcase.expectAnotated {
assert.NotNil(t, annotatedFrom)
assert.NotNil(t, annotatedTo)
assert.NotNil(t, annotatedUnified)
} else {
assert.Nil(t, annotatedFrom)
assert.Nil(t, annotatedTo)
assert.Nil(t, annotatedUnified)
}
})
}
}
3 changes: 3 additions & 0 deletions go/vt/schemadiff/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -1798,6 +1798,9 @@ func (c *CreateTableEntity) primaryKeyColumns() []*sqlparser.IndexColumn {

// Create implements Entity interface
func (c *CreateTableEntity) Create() EntityDiff {
if c == nil {
return nil
}
return &CreateTableEntityDiff{to: c, createTable: c.CreateTable}
}

Expand Down
3 changes: 3 additions & 0 deletions go/vt/schemadiff/view.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,9 @@ func (c *CreateViewEntity) ViewDiff(other *CreateViewEntity, _ *DiffHints) (*Alt

// Create implements Entity interface
func (c *CreateViewEntity) Create() EntityDiff {
if c == nil {
return nil
}
return &CreateViewEntityDiff{createView: c.CreateView}
}

Expand Down

0 comments on commit 5bd6902

Please sign in to comment.