Skip to content

Commit

Permalink
Add support for creating foreign key constraints using `create_constr…
Browse files Browse the repository at this point in the history
…aint` (#471)

This PR introduces a new constraint `type` to `create_constraint`
operation called `foreign_key`. Now it is possible to create FK
constraints on multiple columns.

### Examples

#### Foreign key

```json
{
  "name": "44_add_foreign_key_table_reference_constraint",
  "operations": [
    {
      "create_constraint": {
        "type": "foreign_key",
        "table": "tickets",
        "name": "fk_sellers",
        "columns": [
          "sellers_name",
          "sellers_zip"
        ],
        "references": {
          "table": "sellers",
          "columns": [
            "name",
            "zip"
          ],
          "on_delete": "CASCADE"
        },
        "up": {
          "sellers_name": "sellers_name",
          "sellers_zip": "sellers_zip"
        },
        "down": {
          "sellers_name": "sellers_name",
          "sellers_zip": "sellers_zip"
        }
      }
    }
  ]
}
```

Closes #81

---------

Co-authored-by: Andrew Farries <[email protected]>
  • Loading branch information
kvch and andrew-farries authored Nov 22, 2024
1 parent 81dd0a7 commit a9b2048
Show file tree
Hide file tree
Showing 13 changed files with 420 additions and 43 deletions.
13 changes: 10 additions & 3 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,7 @@ Example **create table** migrations:

A create constraint operation adds a new constraint to an existing table.

Only `UNIQUE` and `CHECK` constraints are supported.
`UNIQUE`, `CHECK` and `FOREIGN KEY` constraints are supported.

Required fields: `name`, `table`, `type`, `up`, `down`.

Expand All @@ -1114,7 +1114,14 @@ Required fields: `name`, `table`, `type`, `up`, `down`.
"table": "name of table",
"name": "my_unique_constraint",
"columns": ["col1", "col2"],
"type": "unique"
"type": "unique"| "check" | "foreign_key",
"check": "SQL expression for CHECK constraint",
"references": {
"name": "name of foreign key reference",
"table": "name of referenced table",
"columns": "[names of referenced columns]",
"on_delete": "ON DELETE behaviour, can be CASCADE, SET NULL, RESTRICT, or NO ACTION. Default is NO ACTION",
},
"up": {
"col1": "col1 || random()",
"col2": "col2 || random()"
Expand All @@ -1131,7 +1138,7 @@ Example **create constraint** migrations:

* [44_add_table_unique_constraint.json](../examples/44_add_table_unique_constraint.json)
* [45_add_table_check_constraint.json](../examples/45_add_table_check_constraint.json)

* [46_add_table_foreign_key_constraint.json](../examples/46_add_table_foreign_key_constraint.json)

### Drop column

Expand Down
1 change: 1 addition & 0 deletions examples/.ledger
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@
44_add_table_unique_constraint.json
45_add_table_check_constraint.json
46_alter_column_drop_default.json
47_add_table_foreign_key_constraint.json
31 changes: 31 additions & 0 deletions examples/47_add_table_foreign_key_constraint.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
{
"name": "47_add_table_foreign_key_constraint",
"operations": [
{
"create_constraint": {
"type": "foreign_key",
"table": "tickets",
"name": "fk_sellers",
"columns": [
"sellers_name",
"sellers_zip"
],
"references": {
"table": "sellers",
"columns": [
"name",
"zip"
]
},
"up": {
"sellers_name": "sellers_name",
"sellers_zip": "sellers_zip"
},
"down": {
"sellers_name": "sellers_name",
"sellers_zip": "sellers_zip"
}
}
}
]
}
64 changes: 32 additions & 32 deletions pkg/migrations/duplicate.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ const (
cCreateUniqueIndexSQL = `CREATE UNIQUE INDEX CONCURRENTLY %s ON %s (%s)`
cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`
cAlterTableAddCheckConstraintSQL = `ALTER TABLE %s ADD CONSTRAINT %s %s NOT VALID`
cAlterTableAddForeignKeySQL = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
)

// NewColumnDuplicator creates a new Duplicator for a column.
Expand Down Expand Up @@ -91,7 +92,6 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
colNames = append(colNames, name)

// Duplicate the column with the new type
// and check and fk constraints
if sql := d.stmtBuilder.duplicateColumn(c.column, c.asName, c.withoutNotNull, c.withType, d.withoutConstraint); sql != "" {
_, err := d.conn.ExecContext(ctx, sql)
if err != nil {
Expand All @@ -108,6 +108,7 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
}
}

// Duplicate the column's comment
if sql := d.stmtBuilder.duplicateComment(c.column, c.asName); sql != "" {
_, err := d.conn.ExecContext(ctx, sql)
if err != nil {
Expand All @@ -120,7 +121,6 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
// if the check constraint is not valid for the new column type, in which case
// the error is ignored.
for _, sql := range d.stmtBuilder.duplicateCheckConstraints(d.withoutConstraint, colNames...) {
// Update the check constraint expression to use the new column names if any of the columns are duplicated
_, err := d.conn.ExecContext(ctx, sql)
err = errorIgnoringErrorCode(err, undefinedFunctionErrorCode)
if err != nil {
Expand All @@ -132,12 +132,21 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
// The constraint is duplicated by adding a unique index on the column concurrently.
// The index is converted into a unique constraint on migration completion.
for _, sql := range d.stmtBuilder.duplicateUniqueConstraints(d.withoutConstraint, colNames...) {
// Update the unique constraint columns to use the new column names if any of the columns are duplicated
if _, err := d.conn.ExecContext(ctx, sql); err != nil {
return err
}
}

// Generate SQL to duplicate any foreign key constraints on the columns.
// If the foreign key constraint is not valid for a new column type, the error is ignored.
for _, sql := range d.stmtBuilder.duplicateForeignKeyConstraints(d.withoutConstraint, colNames...) {
_, err := d.conn.ExecContext(ctx, sql)
err = errorIgnoringErrorCode(err, dataTypeMismatchErrorCode)
if err != nil {
return err
}
}

return nil
}

Expand Down Expand Up @@ -175,6 +184,26 @@ func (d *duplicatorStmtBuilder) duplicateUniqueConstraints(withoutConstraint []s
return stmts
}

func (d *duplicatorStmtBuilder) duplicateForeignKeyConstraints(withoutConstraint []string, colNames ...string) []string {
stmts := make([]string, 0, len(d.table.ForeignKeys))
for _, fk := range d.table.ForeignKeys {
if slices.Contains(withoutConstraint, fk.Name) {
continue
}
if duplicatedMember, constraintColumns := d.allConstraintColumns(fk.Columns, colNames...); duplicatedMember {
stmts = append(stmts, fmt.Sprintf(cAlterTableAddForeignKeySQL,
pq.QuoteIdentifier(d.table.Name),
pq.QuoteIdentifier(DuplicationName(fk.Name)),
strings.Join(quoteColumnNames(constraintColumns), ", "),
pq.QuoteIdentifier(fk.ReferencedTable),
strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "),
fk.OnDelete,
))
}
}
return stmts
}

// duplicatedConstraintColumns returns a new slice of constraint columns with
// the columns that are duplicated replaced with temporary names.
func (d *duplicatorStmtBuilder) duplicatedConstraintColumns(constraintColumns []string, duplicatedColumns ...string) []string {
Expand Down Expand Up @@ -213,7 +242,6 @@ func (d *duplicatorStmtBuilder) duplicateColumn(
) string {
const (
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID`
)

Expand All @@ -232,23 +260,6 @@ func (d *duplicatorStmtBuilder) duplicateColumn(
)
}

// Generate SQL to duplicate any foreign key constraints on the column
for _, fk := range d.table.ForeignKeys {
if slices.Contains(withoutConstraint, fk.Name) {
continue
}

if slices.Contains(fk.Columns, column.Name) {
sql += fmt.Sprintf(", "+cAddForeignKeySQL,
pq.QuoteIdentifier(DuplicationName(fk.Name)),
strings.Join(quoteColumnNames(copyAndReplace(fk.Columns, column.Name, asName)), ", "),
pq.QuoteIdentifier(fk.ReferencedTable),
strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "),
fk.OnDelete,
)
}
}

return sql
}

Expand Down Expand Up @@ -295,17 +306,6 @@ func StripDuplicationPrefix(name string) string {
return strings.TrimPrefix(name, "_pgroll_dup_")
}

func copyAndReplace(xs []string, oldValue, newValue string) []string {
ys := slices.Clone(xs)

for i, c := range ys {
if c == oldValue {
ys[i] = newValue
}
}
return ys
}

func errorIgnoringErrorCode(err error, code pq.ErrorCode) error {
pqErr := &pq.Error{}
if ok := errors.As(err, &pqErr); ok {
Expand Down
53 changes: 53 additions & 0 deletions pkg/migrations/duplicate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ var table = &schema.Table{
"new_york_adults": {Name: "new_york_adults", Columns: []string{"city", "age"}, Definition: `"city" = 'New York' AND "age" > 21`},
"different_nick": {Name: "different_nick", Columns: []string{"name", "nick"}, Definition: `"name" != "nick"`},
},
ForeignKeys: map[string]schema.ForeignKey{
"fk_city": {Name: "fk_city", Columns: []string{"city"}, ReferencedTable: "cities", ReferencedColumns: []string{"id"}, OnDelete: "NO ACTION"},
"fk_name_nick": {Name: "fk_name_nick", Columns: []string{"name", "nick"}, ReferencedTable: "users", ReferencedColumns: []string{"name", "nick"}, OnDelete: "CASCADE"},
},
}

func TestDuplicateStmtBuilderCheckConstraints(t *testing.T) {
Expand Down Expand Up @@ -121,3 +125,52 @@ func TestDuplicateStmtBuilderUniqueConstraints(t *testing.T) {
})
}
}

func TestDuplicateStmtBuilderForeignKeyConstraints(t *testing.T) {
d := &duplicatorStmtBuilder{table}
for name, testCases := range map[string]struct {
columns []string
expectedStmts []string
}{
"duplicate single column with no FK constraint": {
columns: []string{"description"},
expectedStmts: []string{},
},
"single-column FK with single column duplicated": {
columns: []string{"city"},
expectedStmts: []string{
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_city" FOREIGN KEY ("_pgroll_new_city") REFERENCES "cities" ("id") ON DELETE NO ACTION`,
},
},
"single-column FK with multiple columns duplicated": {
columns: []string{"city", "description"},
expectedStmts: []string{
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_city" FOREIGN KEY ("_pgroll_new_city") REFERENCES "cities" ("id") ON DELETE NO ACTION`,
},
},
"multi-column FK with single column duplicated": {
columns: []string{"name"},
expectedStmts: []string{
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`,
},
},
"multi-column FK with multiple unrelated column duplicated": {
columns: []string{"name", "description"},
expectedStmts: []string{
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`,
},
},
"multi-column FK with multiple columns": {
columns: []string{"name", "nick"},
expectedStmts: []string{`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "_pgroll_new_nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`},
},
} {
t.Run(name, func(t *testing.T) {
stmts := d.duplicateForeignKeyConstraints(nil, testCases.columns...)
assert.Equal(t, len(testCases.expectedStmts), len(stmts))
for _, stmt := range stmts {
assert.Contains(t, testCases.expectedStmts, stmt)
}
})
}
}
2 changes: 1 addition & 1 deletion pkg/migrations/op_add_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ func (w ColumnSQLWriter) Write(col Column) (string, error) {
sql += fmt.Sprintf(" DEFAULT %s", d)
}
if col.References != nil {
onDelete := "NO ACTION"
onDelete := string(ForeignKeyReferenceOnDeleteNOACTION)
if col.References.OnDelete != "" {
onDelete = strings.ToUpper(string(col.References.OnDelete))
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/migrations/op_add_column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ func TestAddForeignKeyColumn(t *testing.T) {
Name: "fk_users_id",
Table: "users",
Column: "id",
OnDelete: "CASCADE",
OnDelete: migrations.ForeignKeyReferenceOnDeleteCASCADE,
},
},
},
Expand Down
48 changes: 48 additions & 0 deletions pkg/migrations/op_create_constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema
return table, o.addUniqueIndex(ctx, conn)
case OpCreateConstraintTypeCheck:
return table, o.addCheckConstraint(ctx, conn)
case OpCreateConstraintTypeForeignKey:
return table, o.addForeignKeyConstraint(ctx, conn)
}

return table, nil
Expand Down Expand Up @@ -97,6 +99,17 @@ func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTra
if err != nil {
return err
}
case OpCreateConstraintTypeForeignKey:
fkOp := &OpSetForeignKey{
Table: o.Table,
References: ForeignKeyReference{
Name: o.Name,
},
}
err := fkOp.Complete(ctx, conn, tr, s)
if err != nil {
return err
}
}

// remove old columns
Expand Down Expand Up @@ -198,6 +211,22 @@ func (o *OpCreateConstraint) Validate(ctx context.Context, s *schema.Schema) err
if o.Check == nil || *o.Check == "" {
return FieldRequiredError{Name: "check"}
}
case OpCreateConstraintTypeForeignKey:
if o.References == nil {
return FieldRequiredError{Name: "references"}
}
table := s.GetTable(o.References.Table)
if table == nil {
return TableDoesNotExistError{Name: o.References.Table}
}
for _, col := range o.References.Columns {
if table.GetColumn(col) == nil {
return ColumnDoesNotExistError{
Table: o.References.Table,
Name: col,
}
}
}
}

return nil
Expand All @@ -223,6 +252,25 @@ func (o *OpCreateConstraint) addCheckConstraint(ctx context.Context, conn db.DB)
return err
}

func (o *OpCreateConstraint) addForeignKeyConstraint(ctx context.Context, conn db.DB) error {
onDelete := "NO ACTION"
if o.References.OnDelete != "" {
onDelete = strings.ToUpper(string(o.References.OnDelete))
}

_, err := conn.ExecContext(ctx,
fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s NOT VALID",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Name),
strings.Join(quotedTemporaryNames(o.Columns), ","),
pq.QuoteIdentifier(o.References.Table),
strings.Join(quoteColumnNames(o.References.Columns), ","),
onDelete,
))

return err
}

func quotedTemporaryNames(columns []string) []string {
names := make([]string, len(columns))
for i, col := range columns {
Expand Down
Loading

0 comments on commit a9b2048

Please sign in to comment.