From a9b20489093bfc938efe68ccd70a3d5c49ac478d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?No=C3=A9mi=20V=C3=A1nyi?= Date: Fri, 22 Nov 2024 09:50:43 +0100 Subject: [PATCH] Add support for creating foreign key constraints using `create_constraint` (#471) This PR introduces a new constraint `type` to `create_constraint` operation called `foreign_key`. Now it is possible to create FK constraints on multiple columns. ### Examples #### Foreign key ```json { "name": "44_add_foreign_key_table_reference_constraint", "operations": [ { "create_constraint": { "type": "foreign_key", "table": "tickets", "name": "fk_sellers", "columns": [ "sellers_name", "sellers_zip" ], "references": { "table": "sellers", "columns": [ "name", "zip" ], "on_delete": "CASCADE" }, "up": { "sellers_name": "sellers_name", "sellers_zip": "sellers_zip" }, "down": { "sellers_name": "sellers_name", "sellers_zip": "sellers_zip" } } } ] } ``` Closes https://github.com/xataio/pgroll/issues/81 --------- Co-authored-by: Andrew Farries --- docs/README.md | 13 +- examples/.ledger | 1 + .../47_add_table_foreign_key_constraint.json | 31 +++ pkg/migrations/duplicate.go | 64 +++--- pkg/migrations/duplicate_test.go | 53 +++++ pkg/migrations/op_add_column.go | 2 +- pkg/migrations/op_add_column_test.go | 2 +- pkg/migrations/op_create_constraint.go | 48 +++++ pkg/migrations/op_create_constraint_test.go | 191 ++++++++++++++++++ pkg/migrations/op_set_fk_test.go | 6 +- pkg/migrations/rename.go | 2 + pkg/migrations/types.go | 16 ++ schema.json | 34 +++- 13 files changed, 420 insertions(+), 43 deletions(-) create mode 100644 examples/47_add_table_foreign_key_constraint.json diff --git a/docs/README.md b/docs/README.md index 10cfabd6..2eef3acd 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1102,7 +1102,7 @@ Example **create table** migrations: A create constraint operation adds a new constraint to an existing table. -Only `UNIQUE` and `CHECK` constraints are supported. +`UNIQUE`, `CHECK` and `FOREIGN KEY` constraints are supported. Required fields: `name`, `table`, `type`, `up`, `down`. @@ -1114,7 +1114,14 @@ Required fields: `name`, `table`, `type`, `up`, `down`. "table": "name of table", "name": "my_unique_constraint", "columns": ["col1", "col2"], - "type": "unique" + "type": "unique"| "check" | "foreign_key", + "check": "SQL expression for CHECK constraint", + "references": { + "name": "name of foreign key reference", + "table": "name of referenced table", + "columns": "[names of referenced columns]", + "on_delete": "ON DELETE behaviour, can be CASCADE, SET NULL, RESTRICT, or NO ACTION. Default is NO ACTION", + }, "up": { "col1": "col1 || random()", "col2": "col2 || random()" @@ -1131,7 +1138,7 @@ Example **create constraint** migrations: * [44_add_table_unique_constraint.json](../examples/44_add_table_unique_constraint.json) * [45_add_table_check_constraint.json](../examples/45_add_table_check_constraint.json) - +* [46_add_table_foreign_key_constraint.json](../examples/46_add_table_foreign_key_constraint.json) ### Drop column diff --git a/examples/.ledger b/examples/.ledger index 4999789e..689b88b3 100644 --- a/examples/.ledger +++ b/examples/.ledger @@ -44,3 +44,4 @@ 44_add_table_unique_constraint.json 45_add_table_check_constraint.json 46_alter_column_drop_default.json +47_add_table_foreign_key_constraint.json diff --git a/examples/47_add_table_foreign_key_constraint.json b/examples/47_add_table_foreign_key_constraint.json new file mode 100644 index 00000000..ebd46b11 --- /dev/null +++ b/examples/47_add_table_foreign_key_constraint.json @@ -0,0 +1,31 @@ +{ + "name": "47_add_table_foreign_key_constraint", + "operations": [ + { + "create_constraint": { + "type": "foreign_key", + "table": "tickets", + "name": "fk_sellers", + "columns": [ + "sellers_name", + "sellers_zip" + ], + "references": { + "table": "sellers", + "columns": [ + "name", + "zip" + ] + }, + "up": { + "sellers_name": "sellers_name", + "sellers_zip": "sellers_zip" + }, + "down": { + "sellers_name": "sellers_name", + "sellers_zip": "sellers_zip" + } + } + } + ] +} diff --git a/pkg/migrations/duplicate.go b/pkg/migrations/duplicate.go index a9f1e308..f6b3da0b 100644 --- a/pkg/migrations/duplicate.go +++ b/pkg/migrations/duplicate.go @@ -43,6 +43,7 @@ const ( cCreateUniqueIndexSQL = `CREATE UNIQUE INDEX CONCURRENTLY %s ON %s (%s)` cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s` cAlterTableAddCheckConstraintSQL = `ALTER TABLE %s ADD CONSTRAINT %s %s NOT VALID` + cAlterTableAddForeignKeySQL = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s` ) // NewColumnDuplicator creates a new Duplicator for a column. @@ -91,7 +92,6 @@ func (d *Duplicator) Duplicate(ctx context.Context) error { colNames = append(colNames, name) // Duplicate the column with the new type - // and check and fk constraints if sql := d.stmtBuilder.duplicateColumn(c.column, c.asName, c.withoutNotNull, c.withType, d.withoutConstraint); sql != "" { _, err := d.conn.ExecContext(ctx, sql) if err != nil { @@ -108,6 +108,7 @@ func (d *Duplicator) Duplicate(ctx context.Context) error { } } + // Duplicate the column's comment if sql := d.stmtBuilder.duplicateComment(c.column, c.asName); sql != "" { _, err := d.conn.ExecContext(ctx, sql) if err != nil { @@ -120,7 +121,6 @@ func (d *Duplicator) Duplicate(ctx context.Context) error { // if the check constraint is not valid for the new column type, in which case // the error is ignored. for _, sql := range d.stmtBuilder.duplicateCheckConstraints(d.withoutConstraint, colNames...) { - // Update the check constraint expression to use the new column names if any of the columns are duplicated _, err := d.conn.ExecContext(ctx, sql) err = errorIgnoringErrorCode(err, undefinedFunctionErrorCode) if err != nil { @@ -132,12 +132,21 @@ func (d *Duplicator) Duplicate(ctx context.Context) error { // The constraint is duplicated by adding a unique index on the column concurrently. // The index is converted into a unique constraint on migration completion. for _, sql := range d.stmtBuilder.duplicateUniqueConstraints(d.withoutConstraint, colNames...) { - // Update the unique constraint columns to use the new column names if any of the columns are duplicated if _, err := d.conn.ExecContext(ctx, sql); err != nil { return err } } + // Generate SQL to duplicate any foreign key constraints on the columns. + // If the foreign key constraint is not valid for a new column type, the error is ignored. + for _, sql := range d.stmtBuilder.duplicateForeignKeyConstraints(d.withoutConstraint, colNames...) { + _, err := d.conn.ExecContext(ctx, sql) + err = errorIgnoringErrorCode(err, dataTypeMismatchErrorCode) + if err != nil { + return err + } + } + return nil } @@ -175,6 +184,26 @@ func (d *duplicatorStmtBuilder) duplicateUniqueConstraints(withoutConstraint []s return stmts } +func (d *duplicatorStmtBuilder) duplicateForeignKeyConstraints(withoutConstraint []string, colNames ...string) []string { + stmts := make([]string, 0, len(d.table.ForeignKeys)) + for _, fk := range d.table.ForeignKeys { + if slices.Contains(withoutConstraint, fk.Name) { + continue + } + if duplicatedMember, constraintColumns := d.allConstraintColumns(fk.Columns, colNames...); duplicatedMember { + stmts = append(stmts, fmt.Sprintf(cAlterTableAddForeignKeySQL, + pq.QuoteIdentifier(d.table.Name), + pq.QuoteIdentifier(DuplicationName(fk.Name)), + strings.Join(quoteColumnNames(constraintColumns), ", "), + pq.QuoteIdentifier(fk.ReferencedTable), + strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "), + fk.OnDelete, + )) + } + } + return stmts +} + // duplicatedConstraintColumns returns a new slice of constraint columns with // the columns that are duplicated replaced with temporary names. func (d *duplicatorStmtBuilder) duplicatedConstraintColumns(constraintColumns []string, duplicatedColumns ...string) []string { @@ -213,7 +242,6 @@ func (d *duplicatorStmtBuilder) duplicateColumn( ) string { const ( cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s` - cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s` cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID` ) @@ -232,23 +260,6 @@ func (d *duplicatorStmtBuilder) duplicateColumn( ) } - // Generate SQL to duplicate any foreign key constraints on the column - for _, fk := range d.table.ForeignKeys { - if slices.Contains(withoutConstraint, fk.Name) { - continue - } - - if slices.Contains(fk.Columns, column.Name) { - sql += fmt.Sprintf(", "+cAddForeignKeySQL, - pq.QuoteIdentifier(DuplicationName(fk.Name)), - strings.Join(quoteColumnNames(copyAndReplace(fk.Columns, column.Name, asName)), ", "), - pq.QuoteIdentifier(fk.ReferencedTable), - strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "), - fk.OnDelete, - ) - } - } - return sql } @@ -295,17 +306,6 @@ func StripDuplicationPrefix(name string) string { return strings.TrimPrefix(name, "_pgroll_dup_") } -func copyAndReplace(xs []string, oldValue, newValue string) []string { - ys := slices.Clone(xs) - - for i, c := range ys { - if c == oldValue { - ys[i] = newValue - } - } - return ys -} - func errorIgnoringErrorCode(err error, code pq.ErrorCode) error { pqErr := &pq.Error{} if ok := errors.As(err, &pqErr); ok { diff --git a/pkg/migrations/duplicate_test.go b/pkg/migrations/duplicate_test.go index 13d7dc16..68da00e7 100644 --- a/pkg/migrations/duplicate_test.go +++ b/pkg/migrations/duplicate_test.go @@ -32,6 +32,10 @@ var table = &schema.Table{ "new_york_adults": {Name: "new_york_adults", Columns: []string{"city", "age"}, Definition: `"city" = 'New York' AND "age" > 21`}, "different_nick": {Name: "different_nick", Columns: []string{"name", "nick"}, Definition: `"name" != "nick"`}, }, + ForeignKeys: map[string]schema.ForeignKey{ + "fk_city": {Name: "fk_city", Columns: []string{"city"}, ReferencedTable: "cities", ReferencedColumns: []string{"id"}, OnDelete: "NO ACTION"}, + "fk_name_nick": {Name: "fk_name_nick", Columns: []string{"name", "nick"}, ReferencedTable: "users", ReferencedColumns: []string{"name", "nick"}, OnDelete: "CASCADE"}, + }, } func TestDuplicateStmtBuilderCheckConstraints(t *testing.T) { @@ -121,3 +125,52 @@ func TestDuplicateStmtBuilderUniqueConstraints(t *testing.T) { }) } } + +func TestDuplicateStmtBuilderForeignKeyConstraints(t *testing.T) { + d := &duplicatorStmtBuilder{table} + for name, testCases := range map[string]struct { + columns []string + expectedStmts []string + }{ + "duplicate single column with no FK constraint": { + columns: []string{"description"}, + expectedStmts: []string{}, + }, + "single-column FK with single column duplicated": { + columns: []string{"city"}, + expectedStmts: []string{ + `ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_city" FOREIGN KEY ("_pgroll_new_city") REFERENCES "cities" ("id") ON DELETE NO ACTION`, + }, + }, + "single-column FK with multiple columns duplicated": { + columns: []string{"city", "description"}, + expectedStmts: []string{ + `ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_city" FOREIGN KEY ("_pgroll_new_city") REFERENCES "cities" ("id") ON DELETE NO ACTION`, + }, + }, + "multi-column FK with single column duplicated": { + columns: []string{"name"}, + expectedStmts: []string{ + `ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`, + }, + }, + "multi-column FK with multiple unrelated column duplicated": { + columns: []string{"name", "description"}, + expectedStmts: []string{ + `ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`, + }, + }, + "multi-column FK with multiple columns": { + columns: []string{"name", "nick"}, + expectedStmts: []string{`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "_pgroll_new_nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`}, + }, + } { + t.Run(name, func(t *testing.T) { + stmts := d.duplicateForeignKeyConstraints(nil, testCases.columns...) + assert.Equal(t, len(testCases.expectedStmts), len(stmts)) + for _, stmt := range stmts { + assert.Contains(t, testCases.expectedStmts, stmt) + } + }) + } +} diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index 199b98f6..0f506ed6 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -285,7 +285,7 @@ func (w ColumnSQLWriter) Write(col Column) (string, error) { sql += fmt.Sprintf(" DEFAULT %s", d) } if col.References != nil { - onDelete := "NO ACTION" + onDelete := string(ForeignKeyReferenceOnDeleteNOACTION) if col.References.OnDelete != "" { onDelete = strings.ToUpper(string(col.References.OnDelete)) } diff --git a/pkg/migrations/op_add_column_test.go b/pkg/migrations/op_add_column_test.go index a53e673a..6ee898b8 100644 --- a/pkg/migrations/op_add_column_test.go +++ b/pkg/migrations/op_add_column_test.go @@ -626,7 +626,7 @@ func TestAddForeignKeyColumn(t *testing.T) { Name: "fk_users_id", Table: "users", Column: "id", - OnDelete: "CASCADE", + OnDelete: migrations.ForeignKeyReferenceOnDeleteCASCADE, }, }, }, diff --git a/pkg/migrations/op_create_constraint.go b/pkg/migrations/op_create_constraint.go index 846215e0..4d82341e 100644 --- a/pkg/migrations/op_create_constraint.go +++ b/pkg/migrations/op_create_constraint.go @@ -70,6 +70,8 @@ func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema return table, o.addUniqueIndex(ctx, conn) case OpCreateConstraintTypeCheck: return table, o.addCheckConstraint(ctx, conn) + case OpCreateConstraintTypeForeignKey: + return table, o.addForeignKeyConstraint(ctx, conn) } return table, nil @@ -97,6 +99,17 @@ func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTra if err != nil { return err } + case OpCreateConstraintTypeForeignKey: + fkOp := &OpSetForeignKey{ + Table: o.Table, + References: ForeignKeyReference{ + Name: o.Name, + }, + } + err := fkOp.Complete(ctx, conn, tr, s) + if err != nil { + return err + } } // remove old columns @@ -198,6 +211,22 @@ func (o *OpCreateConstraint) Validate(ctx context.Context, s *schema.Schema) err if o.Check == nil || *o.Check == "" { return FieldRequiredError{Name: "check"} } + case OpCreateConstraintTypeForeignKey: + if o.References == nil { + return FieldRequiredError{Name: "references"} + } + table := s.GetTable(o.References.Table) + if table == nil { + return TableDoesNotExistError{Name: o.References.Table} + } + for _, col := range o.References.Columns { + if table.GetColumn(col) == nil { + return ColumnDoesNotExistError{ + Table: o.References.Table, + Name: col, + } + } + } } return nil @@ -223,6 +252,25 @@ func (o *OpCreateConstraint) addCheckConstraint(ctx context.Context, conn db.DB) return err } +func (o *OpCreateConstraint) addForeignKeyConstraint(ctx context.Context, conn db.DB) error { + onDelete := "NO ACTION" + if o.References.OnDelete != "" { + onDelete = strings.ToUpper(string(o.References.OnDelete)) + } + + _, err := conn.ExecContext(ctx, + fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s NOT VALID", + pq.QuoteIdentifier(o.Table), + pq.QuoteIdentifier(o.Name), + strings.Join(quotedTemporaryNames(o.Columns), ","), + pq.QuoteIdentifier(o.References.Table), + strings.Join(quoteColumnNames(o.References.Columns), ","), + onDelete, + )) + + return err +} + func quotedTemporaryNames(columns []string) []string { names := make([]string, len(columns)) for i, col := range columns { diff --git a/pkg/migrations/op_create_constraint_test.go b/pkg/migrations/op_create_constraint_test.go index 9a372a5d..b265e83a 100644 --- a/pkg/migrations/op_create_constraint_test.go +++ b/pkg/migrations/op_create_constraint_test.go @@ -355,6 +355,147 @@ func TestCreateConstraint(t *testing.T) { }, rows) }, }, + { + name: "create foreign key constraint on multiple columns", + migrations: []migrations.Migration{ + { + Name: "01_add_tables", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + { + Name: "zip", + Type: "integer", + Pk: ptr(true), + }, + { + Name: "name", + Type: "varchar(255)", + Nullable: ptr(false), + }, + { + Name: "email", + Type: "varchar(255)", + Nullable: ptr(false), + }, + }, + }, + &migrations.OpCreateTable{ + Name: "reports", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + { + Name: "users_id", + Type: "integer", + Nullable: ptr(true), + }, + { + Name: "users_zip", + Type: "integer", + Nullable: ptr(true), + }, + { + Name: "description", + Type: "varchar(255)", + Nullable: ptr(false), + }, + }, + }, + }, + }, + { + Name: "02_create_constraint", + Operations: migrations.Operations{ + &migrations.OpCreateConstraint{ + Name: "fk_users", + Table: "reports", + Type: "foreign_key", + Columns: []string{"users_id", "users_zip"}, + References: &migrations.OpCreateConstraintReferences{ + Table: "users", + Columns: []string{"id", "zip"}, + }, + Up: migrations.OpCreateConstraintUp(map[string]string{ + "users_id": "1", + "users_zip": "12345", + }), + Down: migrations.OpCreateConstraintDown(map[string]string{ + "users_id": "users_id", + "users_zip": "users_zip", + }), + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB, schema string) { + // The new (temporary) column should exist on the underlying table. + ColumnMustExist(t, db, schema, "reports", migrations.TemporaryName("users_id")) + // The new (temporary) column should exist on the underlying table. + ColumnMustExist(t, db, schema, "reports", migrations.TemporaryName("users_zip")) + // A temporary FK constraint has been created on the temporary column + NotValidatedForeignKeyMustExist(t, db, schema, "reports", "fk_users") + + // Insert values to refer to. + MustInsert(t, db, schema, "01_add_tables", "users", map[string]string{ + "name": "alice", + "email": "alice@example.com", + "zip": "12345", + }) + + // Inserting values into the old schema that the violate the fk constraint must succeed. + MustInsert(t, db, schema, "01_add_tables", "reports", map[string]string{ + "description": "random report", + }) + + // Inserting values into the new schema that meet the FK constraint should succeed. + MustInsert(t, db, schema, "02_create_constraint", "reports", map[string]string{ + "description": "alice report", + "users_id": "1", + "users_zip": "12345", + }) + // Inserting data into the new `reports` view with an invalid user reference fails. + MustNotInsert(t, db, schema, "02_create_constraint", "reports", map[string]string{ + "description": "no one report", + "users_id": "100", + "users_zip": "100", + }, testutils.FKViolationErrorCode) + }, + afterRollback: func(t *testing.T, db *sql.DB, schema string) { + // The check constraint must not exists on the table. + CheckConstraintMustNotExist(t, db, schema, "reports", "fk_users") + // Functions, triggers and temporary columns are dropped. + tableCleanedUp(t, db, schema, "reports", "users_id") + tableCleanedUp(t, db, schema, "reports", "users_zip") + }, + afterComplete: func(t *testing.T, db *sql.DB, schema string) { + // Functions, triggers and temporary columns are dropped. + tableCleanedUp(t, db, schema, "reports", "users_id") + tableCleanedUp(t, db, schema, "reports", "users_zip") + + // Inserting values into the new schema that the violate the check constraint must fail. + MustNotInsert(t, db, schema, "02_create_constraint", "reports", map[string]string{ + "description": "no one report", + "users_id": "100", + "users_zip": "100", + }, testutils.FKViolationErrorCode) + + rows := MustSelect(t, db, schema, "02_create_constraint", "reports") + assert.Equal(t, []map[string]any{ + {"id": 1, "description": "random report", "users_id": 1, "users_zip": 12345}, + {"id": 2, "description": "alice report", "users_id": 1, "users_zip": 12345}, + }, rows) + }, + }, { name: "invalid constraint name", migrations: []migrations.Migration{ @@ -490,6 +631,56 @@ func TestCreateConstraint(t *testing.T) { afterRollback: func(t *testing.T, db *sql.DB, schema string) {}, afterComplete: func(t *testing.T, db *sql.DB, schema string) {}, }, + { + name: "missing referenced table for foreign key constraint creation", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + { + Name: "name", + Type: "varchar(255)", + Nullable: ptr(false), + }, + }, + }, + }, + }, + { + Name: "02_create_constraint_with_missing_referenced_table", + Operations: migrations.Operations{ + &migrations.OpCreateConstraint{ + Name: "fk_missing_table", + Table: "users", + Columns: []string{"name"}, + Type: "foreign_key", + References: &migrations.OpCreateConstraintReferences{ + Table: "missing_table", + Columns: []string{"id"}, + }, + Up: migrations.OpCreateConstraintUp(map[string]string{ + "name": "name", + }), + Down: migrations.OpCreateConstraintDown(map[string]string{ + "name": "name", + }), + }, + }, + }, + }, + wantStartErr: migrations.TableDoesNotExistError{Name: "missing_table"}, + afterStart: func(t *testing.T, db *sql.DB, schema string) {}, + afterRollback: func(t *testing.T, db *sql.DB, schema string) {}, + afterComplete: func(t *testing.T, db *sql.DB, schema string) {}, + }, }) } diff --git a/pkg/migrations/op_set_fk_test.go b/pkg/migrations/op_set_fk_test.go index e110ab5d..6c2a129c 100644 --- a/pkg/migrations/op_set_fk_test.go +++ b/pkg/migrations/op_set_fk_test.go @@ -1341,7 +1341,7 @@ func TestSetForeignKeyValidation(t *testing.T) { Name: "fk_users_doesntexist", Table: "users", Column: "id", - OnDelete: "invalid", + OnDelete: migrations.ForeignKeyReferenceOnDelete("invalid"), }, Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)", Down: "user_id", @@ -1368,7 +1368,7 @@ func TestSetForeignKeyValidation(t *testing.T) { Name: "fk_users_doesntexist", Table: "users", Column: "id", - OnDelete: "no action", + OnDelete: migrations.ForeignKeyReferenceOnDeleteNOACTION, }, Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)", Down: "user_id", @@ -1392,7 +1392,7 @@ func TestSetForeignKeyValidation(t *testing.T) { Name: "fk_users_doesntexist", Table: "users", Column: "id", - OnDelete: "SET NULL", + OnDelete: migrations.ForeignKeyReferenceOnDeleteSETNULL, }, Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)", Down: "user_id", diff --git a/pkg/migrations/rename.go b/pkg/migrations/rename.go index 2b27b77e..67d2c776 100644 --- a/pkg/migrations/rename.go +++ b/pkg/migrations/rename.go @@ -57,6 +57,7 @@ func RenameDuplicatedColumn(ctx context.Context, conn db.DB, table *schema.Table if err != nil { return fmt.Errorf("failed to rename foreign key constraint %q: %w", fk.Name, err) } + delete(table.ForeignKeys, fk.Name) } } @@ -88,6 +89,7 @@ func RenameDuplicatedColumn(ctx context.Context, conn db.DB, table *schema.Table if err != nil { return fmt.Errorf("failed to rename check constraint %q: %w", cc.Name, err) } + delete(table.CheckConstraints, cc.Name) // If the constraint is a `NOT NULL` constraint, convert the duplicated // unchecked `NOT NULL` constraint into a `NOT NULL` attribute on the diff --git a/pkg/migrations/types.go b/pkg/migrations/types.go index a8274fd7..59acb96c 100644 --- a/pkg/migrations/types.go +++ b/pkg/migrations/types.go @@ -134,6 +134,9 @@ type OpCreateConstraint struct { // Name of the constraint Name string `json:"name"` + // Reference to the foreign key + References *OpCreateConstraintReferences `json:"references,omitempty"` + // Name of the table Table string `json:"table"` @@ -147,9 +150,22 @@ type OpCreateConstraint struct { // SQL expression of down migration by column type OpCreateConstraintDown map[string]string +// Reference to the foreign key +type OpCreateConstraintReferences struct { + // Columns to reference + Columns []string `json:"columns"` + + // On delete behavior of the foreign key constraint + OnDelete ForeignKeyReferenceOnDelete `json:"on_delete,omitempty"` + + // Name of the table + Table string `json:"table"` +} + type OpCreateConstraintType string const OpCreateConstraintTypeCheck OpCreateConstraintType = "check" +const OpCreateConstraintTypeForeignKey OpCreateConstraintType = "foreign_key" const OpCreateConstraintTypeUnique OpCreateConstraintType = "unique" // SQL expression of up migration by column diff --git a/schema.json b/schema.json index 59640f47..719d80a9 100644 --- a/schema.json +++ b/schema.json @@ -83,14 +83,18 @@ }, "on_delete": { "description": "On delete behavior of the foreign key constraint", - "type": "string", - "enum": ["NO ACTION", "RESTRICT", "CASCADE", "SET NULL", "SET DEFAULT"], + "$ref": "#/$defs/ForeignKeyReferenceOnDelete", "default": "NO ACTION" } }, "required": ["column", "name", "table"], "type": "object" }, + "ForeignKeyReferenceOnDelete": { + "description": "On delete behavior of the foreign key constraint", + "type": "string", + "enum": ["NO ACTION", "RESTRICT", "CASCADE", "SET NULL", "SET DEFAULT"] + }, "OpAddColumn": { "additionalProperties": false, "description": "Add column operation", @@ -459,12 +463,36 @@ "type": { "description": "Type of the constraint", "type": "string", - "enum": ["unique", "check"] + "enum": ["unique", "check", "foreign_key"] }, "check": { "description": "Check constraint expression", "type": "string" }, + "references": { + "description": "Reference to the foreign key", + "type": "object", + "additionalProperties": false, + "required": ["table", "columns"], + "properties": { + "table": { + "description": "Name of the table", + "type": "string" + }, + "columns": { + "description": "Columns to reference", + "type": "array", + "items": { + "type": "string" + } + }, + "on_delete": { + "description": "On delete behavior of the foreign key constraint", + "$ref": "#/$defs/ForeignKeyReferenceOnDelete", + "default": "NO ACTION" + } + } + }, "up": { "type": "object", "additionalProperties": { "type": "string" },