Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

schemadiff: support valid foreign key cycles #15431

Merged
merged 9 commits into from
Mar 11, 2024
66 changes: 53 additions & 13 deletions go/vt/graph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,28 @@ package graph

import (
"fmt"
"maps"
"slices"
"strings"
)

const (
white int = iota
grey
black
)

// Graph is a generic graph implementation.
type Graph[C comparable] struct {
edges map[C][]C
edges map[C][]C
orderedVertices []C
}

// NewGraph creates a new graph for the given comparable type.
func NewGraph[C comparable]() *Graph[C] {
return &Graph[C]{
edges: map[C][]C{},
edges: map[C][]C{},
orderedVertices: []C{},
}
}

Expand All @@ -41,6 +50,7 @@ func (gr *Graph[C]) AddVertex(vertex C) {
return
}
gr.edges[vertex] = []C{}
gr.orderedVertices = append(gr.orderedVertices, vertex)
}

// AddEdge adds an edge to the given Graph.
Expand Down Expand Up @@ -85,35 +95,65 @@ func (gr *Graph[C]) HasCycles() bool {
color := map[C]int{}
for vertex := range gr.edges {
// If any vertex is still white, we initiate a new DFS.
if color[vertex] == 0 {
if gr.hasCyclesDfs(color, vertex) {
if color[vertex] == white {
if hasCycle, _ := gr.hasCyclesDfs(color, vertex); hasCycle {
return true
}
}
}
return false
}

// GetCycles returns all known cycles in the graph.
// It returns a map of vertices to the cycle they are part of.
// We are using a well-known DFS based colouring algorithm to check for cycles.
// Look at https://cp-algorithms.com/graph/finding-cycle.html for more details on the algorithm.
func (gr *Graph[C]) GetCycles() (vertices map[C][]C) {
// If the graph is empty, then we don't need to check anything.
if gr.Empty() {
return nil
}
vertices = make(map[C][]C)
// Initialize the coloring map.
// 0 represents white.
// 1 represents grey.
// 2 represents black.
color := map[C]int{}
for _, vertex := range gr.orderedVertices {
// If any vertex is still white, we initiate a new DFS.
if color[vertex] == white {
// We clone the colors because we wnt full coverage for all vertices.
// Otherwise, the algorithm is optimal and stop more-or-less after the first cycle.
color := maps.Clone(color)
if hasCycle, cycle := gr.hasCyclesDfs(color, vertex); hasCycle {
vertices[vertex] = cycle
}
}
}
return vertices
}

// hasCyclesDfs is a utility function for checking for cycles in a graph.
// It runs a dfs from the given vertex marking each vertex as grey. During the dfs,
// if we encounter a grey vertex, we know we have a cycle. We mark the visited vertices black
// on finishing the dfs.
func (gr *Graph[C]) hasCyclesDfs(color map[C]int, vertex C) bool {
func (gr *Graph[C]) hasCyclesDfs(color map[C]int, vertex C) (bool, []C) {
// Mark the vertex grey.
color[vertex] = 1
color[vertex] = grey
result := []C{vertex}
// Go over all the edges.
for _, end := range gr.edges[vertex] {
// If we encounter a white vertex, we continue the dfs.
if color[end] == 0 {
if gr.hasCyclesDfs(color, end) {
return true
if color[end] == white {
if hasCycle, cycle := gr.hasCyclesDfs(color, end); hasCycle {
return true, append(result, cycle...)
}
} else if color[end] == 1 {
} else if color[end] == grey {
// We encountered a grey vertex, we have a cycle.
return true
return true, append(result, end)
}
}
// Mark the vertex black before finishing
color[vertex] = 2
return false
color[vertex] = black
return false, nil
}
16 changes: 16 additions & 0 deletions go/vt/graph/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ func TestStringGraph(t *testing.T) {
wantedGraph string
wantEmpty bool
wantHasCycles bool
wantCycles map[string][]string
}{
{
name: "empty graph",
Expand Down Expand Up @@ -137,6 +138,13 @@ E - F
F - A`,
wantEmpty: false,
wantHasCycles: true,
wantCycles: map[string][]string{
"A": {"A", "B", "E", "F", "A"},
"B": {"B", "E", "F", "A", "B"},
"D": {"D", "E", "F", "A", "B", "E"},
"E": {"E", "F", "A", "B", "E"},
"F": {"F", "A", "B", "E", "F"},
},
},
}
for _, tt := range testcases {
Expand All @@ -148,6 +156,14 @@ F - A`,
require.Equal(t, tt.wantedGraph, graph.PrintGraph())
require.Equal(t, tt.wantEmpty, graph.Empty())
require.Equal(t, tt.wantHasCycles, graph.HasCycles())
if tt.wantCycles == nil {
tt.wantCycles = map[string][]string{}
}
actualCycles := graph.GetCycles()
if actualCycles == nil {
actualCycles = map[string][]string{}
}
require.Equal(t, tt.wantCycles, actualCycles)
})
}
}
51 changes: 46 additions & 5 deletions go/vt/schemadiff/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func TestDiffTables(t *testing.T) {
for _, ts := range tt {
t.Run(ts.name, func(t *testing.T) {
var fromCreateTable *sqlparser.CreateTable
hints := &DiffHints{}
hints := EmptyDiffHints()
if ts.hints != nil {
hints = ts.hints
}
Expand Down Expand Up @@ -448,7 +448,7 @@ func TestDiffViews(t *testing.T) {
name: "none",
},
}
hints := &DiffHints{}
hints := EmptyDiffHints()
env := NewTestEnv()
for _, ts := range tt {
t.Run(ts.name, func(t *testing.T) {
Expand Down Expand Up @@ -545,6 +545,7 @@ func TestDiffSchemas(t *testing.T) {
cdiffs []string
expectError string
tableRename int
fkStrategy int
}{
{
name: "identical tables",
Expand Down Expand Up @@ -799,6 +800,45 @@ func TestDiffSchemas(t *testing.T) {
"CREATE TABLE `t5` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f5` (`i`),\n\tCONSTRAINT `f5` FOREIGN KEY (`i`) REFERENCES `t7` (`id`)\n)",
},
},
{
name: "create tables with foreign keys, with invalid fk reference",
from: "create table t (id int primary key)",
to: `
create table t (id int primary key);
create table t11 (id int primary key, i int, constraint f1101a foreign key (i) references t12 (id) on delete restrict);
create table t12 (id int primary key, i int, constraint f1201a foreign key (i) references t9 (id) on delete set null);
`,
expectError: "table `t12` foreign key references nonexistent table `t9`",
},
{
name: "create tables with foreign keys, with invalid fk reference",
from: "create table t (id int primary key)",
to: `
create table t (id int primary key);
create table t11 (id int primary key, i int, constraint f1101b foreign key (i) references t12 (id) on delete restrict);
create table t12 (id int primary key, i int, constraint f1201b foreign key (i) references t9 (id) on delete set null);
`,
expectError: "table `t12` foreign key references nonexistent table `t9`",
fkStrategy: ForeignKeyCheckStrategyIgnore,
},
{
name: "create tables with foreign keys, with valid cycle",
from: "create table t (id int primary key)",
to: `
create table t (id int primary key);
create table t11 (id int primary key, i int, constraint f1101c foreign key (i) references t12 (id) on delete restrict);
create table t12 (id int primary key, i int, constraint f1201c foreign key (i) references t11 (id) on delete set null);
`,
diffs: []string{
"create table t11 (\n\tid int,\n\ti int,\n\tprimary key (id),\n\tkey f1101c (i),\n\tconstraint f1101c foreign key (i) references t12 (id) on delete restrict\n)",
"create table t12 (\n\tid int,\n\ti int,\n\tprimary key (id),\n\tkey f1201c (i),\n\tconstraint f1201c foreign key (i) references t11 (id) on delete set null\n)",
},
cdiffs: []string{
"CREATE TABLE `t11` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f1101c` (`i`),\n\tCONSTRAINT `f1101c` FOREIGN KEY (`i`) REFERENCES `t12` (`id`) ON DELETE RESTRICT\n)",
"CREATE TABLE `t12` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f1201c` (`i`),\n\tCONSTRAINT `f1201c` FOREIGN KEY (`i`) REFERENCES `t11` (`id`) ON DELETE SET NULL\n)",
},
fkStrategy: ForeignKeyCheckStrategyIgnore,
},
{
name: "drop tables with foreign keys, expect specific order",
from: "create table t7(id int primary key); create table t5 (id int primary key, i int, constraint f5 foreign key (i) references t7(id)); create table t4 (id int primary key, i int, constraint f4 foreign key (i) references t7(id));",
Expand Down Expand Up @@ -932,14 +972,15 @@ func TestDiffSchemas(t *testing.T) {
for _, ts := range tt {
t.Run(ts.name, func(t *testing.T) {
hints := &DiffHints{
TableRenameStrategy: ts.tableRename,
TableRenameStrategy: ts.tableRename,
ForeignKeyCheckStrategy: ts.fkStrategy,
}
diff, err := DiffSchemasSQL(env, ts.from, ts.to, hints)
if ts.expectError != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), ts.expectError)
} else {
assert.NoError(t, err)
require.NoError(t, err)

diffs, err := diff.OrderedDiffs(ctx)
assert.NoError(t, err)
Expand Down Expand Up @@ -1024,7 +1065,7 @@ func TestSchemaApplyError(t *testing.T) {
to: "create table t(id int); create view v1 as select * from t; create view v2 as select * from t",
},
}
hints := &DiffHints{}
hints := EmptyDiffHints()
env := NewTestEnv()
for _, ts := range tt {
t.Run(ts.name, func(t *testing.T) {
Expand Down
25 changes: 20 additions & 5 deletions go/vt/schemadiff/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,16 +288,31 @@ func (e *ForeignKeyDependencyUnresolvedError) Error() string {

type ForeignKeyLoopError struct {
Table string
Loop []string
Loop []*ForeignKeyTableColumns
}

func (e *ForeignKeyLoopError) Error() string {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shlomi-noach Should we change ForeignKeyLoopError to have a different Loop type? So something like this maybe:

type ForeignKeyColumn struct {
	Table  string
	Column string
}

type ForeignKeyLoopError struct {
	Table string
	Loop  []ForeignKeyColumn
}

This because we now better track this at a column level, so this way we have the details of what the loop looks like?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I'll do so.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in d3b8c8e

tableIsInsideLoop := false

escaped := make([]string, len(e.Loop))
for i, t := range e.Loop {
escaped[i] = sqlescape.EscapeID(t)
if t == e.Table {
loop := e.Loop
// The tables in the loop could be e.g.:
// t1->t2->a->b->c->a
// In such case, the loop is a->b->c->a. The last item is always the head & tail of the loop.
// We want to distinguish between the case where the table is inside the loop and the case where it's outside,
// so we remove the prefix of the loop that doesn't participate in the actual cycle.
if len(loop) > 0 {
last := loop[len(loop)-1]
for i := range loop {
if loop[i].Table == last.Table {
loop = loop[i:]
break
}
}
}
escaped := make([]string, len(loop))
for i, fk := range loop {
escaped[i] = fk.Escaped()
if fk.Table == e.Table {
tableIsInsideLoop = true
}
}
Expand Down
Loading
Loading