diff --git a/docs/README.md b/docs/README.md index 29996378..2094d0c3 100644 --- a/docs/README.md +++ b/docs/README.md @@ -27,6 +27,7 @@ * [Add unique constraint](#add-unique-constraint) * [Create index](#create-index) * [Create table](#create-table) + * [Create constraint](#create-constraint) * [Drop column](#drop-column) * [Drop constraint](#drop-constraint) * [Drop index](#drop-index) @@ -687,6 +688,7 @@ See the [examples](../examples) directory for examples of each kind of operation * [Add unique constraint](#add-unique-constraint) * [Create index](#create-index) * [Create table](#create-table) +* [Create constraint](#create-constraint) * [Drop column](#drop-column) * [Drop constraint](#drop-constraint) * [Drop index](#drop-index) @@ -1037,6 +1039,40 @@ Example **create table** migrations: * [25_add_table_with_check_constraint.json](../examples/25_add_table_with_check_constraint.json) * [28_different_defaults.json](../examples/28_different_defaults.json) +### Create constraint + +A create constraint operation adds a new constraint to an existing table. + +Only `UNIQUE` constraints are supported. + +Required fields: `name`, `table`, `type`, `up`, `down`. + +**create constraint** operations have this structure: + +```json +{ + "create_constraint": { + "table": "name of table", + "name": "my_unique_constraint", + "columns": ["col1", "col2"], + "type": "unique" + "up": { + "col1": "col1 || random()", + "col2": "col2 || random()" + }, + "down": { + "col1": "col1", + "col2": "col2" + } + } +} +``` + +Example **create constraint** migrations: + +* [44_add_table_unique_constraint.json](../examples/44_add_table_unique_constraint.json) + + ### Drop column A drop column operation drops a column from an existing table. diff --git a/examples/.ledger b/examples/.ledger index aad357af..a8067c2a 100644 --- a/examples/.ledger +++ b/examples/.ledger @@ -40,3 +40,5 @@ 40_create_enum_type.json 41_add_enum_column.json 42_create_unique_index.json +43_create_tickets_table.json +44_add_table_unique_constraint.json diff --git a/examples/43_create_tickets_table.json b/examples/43_create_tickets_table.json new file mode 100644 index 00000000..be4be181 --- /dev/null +++ b/examples/43_create_tickets_table.json @@ -0,0 +1,25 @@ +{ + "name": "43_create_tickets_table", + "operations": [ + { + "create_table": { + "name": "tickets", + "columns": [ + { + "name": "ticket_id", + "type": "serial", + "pk": true + }, + { + "name": "sellers_name", + "type": "varchar(255)" + }, + { + "name": "sellers_zip", + "type": "integer" + } + ] + } + } + ] +} diff --git a/examples/44_add_table_unique_constraint.json b/examples/44_add_table_unique_constraint.json new file mode 100644 index 00000000..3ffa9433 --- /dev/null +++ b/examples/44_add_table_unique_constraint.json @@ -0,0 +1,24 @@ +{ + "name": "44_add_table_unique_constraint", + "operations": [ + { + "create_constraint": { + "type": "unique", + "table": "tickets", + "name": "unique_zip_name", + "columns": [ + "sellers_name", + "sellers_zip" + ], + "up": { + "sellers_name": "sellers_name", + "sellers_zip": "sellers_zip" + }, + "down": { + "sellers_name": "sellers_name", + "sellers_zip": "sellers_zip" + } + } + } + ] +} diff --git a/pkg/migrations/errors.go b/pkg/migrations/errors.go index a9566fa6..344d9fab 100644 --- a/pkg/migrations/errors.go +++ b/pkg/migrations/errors.go @@ -54,6 +54,15 @@ func (e ColumnDoesNotExistError) Error() string { return fmt.Sprintf("column %q does not exist on table %q", e.Name, e.Table) } +type ColumnMigrationMissingError struct { + Table string + Name string +} + +func (e ColumnMigrationMissingError) Error() string { + return fmt.Sprintf("migration for column %q in %q is missing", e.Name, e.Table) +} + type ColumnIsNotNullableError struct { Table string Name string diff --git a/pkg/migrations/op_common.go b/pkg/migrations/op_common.go index 364e70f0..5c732fee 100644 --- a/pkg/migrations/op_common.go +++ b/pkg/migrations/op_common.go @@ -24,6 +24,7 @@ const ( OpNameDropConstraint OpName = "drop_constraint" OpNameSetReplicaIdentity OpName = "set_replica_identity" OpRawSQLName OpName = "sql" + OpCreateConstraintName OpName = "create_constraint" // Internal operation types used by `alter_column` OpNameRenameColumn OpName = "rename_column" @@ -124,6 +125,9 @@ func (v *Operations) UnmarshalJSON(data []byte) error { case OpRawSQLName: item = &OpRawSQL{} + case OpCreateConstraintName: + item = &OpCreateConstraint{} + default: return fmt.Errorf("unknown migration type: %v", opName) } @@ -210,6 +214,9 @@ func OperationName(op Operation) OpName { case *OpRawSQL: return OpRawSQLName + case *OpCreateConstraint: + return OpCreateConstraintName + } panic(fmt.Errorf("unknown operation for %T", op)) diff --git a/pkg/migrations/op_create_constraint.go b/pkg/migrations/op_create_constraint.go new file mode 100644 index 00000000..ff8613a7 --- /dev/null +++ b/pkg/migrations/op_create_constraint.go @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: Apache-2.0 + +package migrations + +import ( + "context" + "fmt" + "strings" + + "github.com/lib/pq" + + "github.com/xataio/pgroll/pkg/db" + "github.com/xataio/pgroll/pkg/schema" +) + +var _ Operation = (*OpCreateConstraint)(nil) + +func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { + var err error + var table *schema.Table + for _, col := range o.Columns { + if table, err = o.duplicateColumnBeforeStart(ctx, conn, latestSchema, tr, col, s); err != nil { + return nil, err + } + } + + switch o.Type { //nolint:gocritic // more cases will be added + case OpCreateConstraintTypeUnique: + return table, o.addUniqueIndex(ctx, conn) + } + + return table, nil +} + +func (o *OpCreateConstraint) duplicateColumnBeforeStart(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, colName string, s *schema.Schema) (*schema.Table, error) { + table := s.GetTable(o.Table) + column := table.GetColumn(colName) + + d := NewColumnDuplicator(conn, table, column) + if err := d.Duplicate(ctx); err != nil { + return nil, fmt.Errorf("failed to duplicate column for new constraint: %w", err) + } + + upSQL, ok := o.Up[colName] + if !ok { + return nil, fmt.Errorf("up migration is missing for column %s", colName) + } + physicalColumnName := TemporaryName(colName) + err := createTrigger(ctx, conn, tr, triggerConfig{ + Name: TriggerName(o.Table, colName), + Direction: TriggerDirectionUp, + Columns: table.Columns, + SchemaName: s.Name, + LatestSchema: latestSchema, + TableName: o.Table, + PhysicalColumn: physicalColumnName, + SQL: upSQL, + }) + if err != nil { + return nil, fmt.Errorf("failed to create up trigger: %w", err) + } + + table.AddColumn(colName, schema.Column{ + Name: physicalColumnName, + }) + + downSQL, ok := o.Down[colName] + if !ok { + return nil, fmt.Errorf("down migration is missing for column %s", colName) + } + err = createTrigger(ctx, conn, tr, triggerConfig{ + Name: TriggerName(o.Table, physicalColumnName), + Direction: TriggerDirectionDown, + Columns: table.Columns, + LatestSchema: latestSchema, + SchemaName: s.Name, + TableName: o.Table, + PhysicalColumn: colName, + SQL: downSQL, + }) + if err != nil { + return nil, fmt.Errorf("failed to create down trigger: %w", err) + } + return table, nil +} + +func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { + switch o.Type { //nolint:gocritic // more cases will be added + case OpCreateConstraintTypeUnique: + uniqueOp := &OpSetUnique{ + Table: o.Table, + Name: o.Name, + } + err := uniqueOp.Complete(ctx, conn, tr, s) + if err != nil { + return err + } + } + + // remove old columns + _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s %s", + pq.QuoteIdentifier(o.Table), + dropMultipleColumns(quoteColumnNames(o.Columns)), + )) + if err != nil { + return err + } + + // rename new columns to old name + table := s.GetTable(o.Table) + for _, col := range o.Columns { + column := table.GetColumn(col) + if err := RenameDuplicatedColumn(ctx, conn, table, column); err != nil { + return err + } + } + + return o.removeTriggers(ctx, conn) +} + +func (o *OpCreateConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { + _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s %s", + pq.QuoteIdentifier(o.Table), + dropMultipleColumns(quotedTemporaryNames(o.Columns)), + )) + if err != nil { + return err + } + + return o.removeTriggers(ctx, conn) +} + +func (o *OpCreateConstraint) removeTriggers(ctx context.Context, conn db.DB) error { + dropFuncs := make([]string, len(o.Columns)*2) + for i, j := 0, 0; i < len(o.Columns); i, j = i+1, j+2 { + dropFuncs[j] = pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Columns[i])) + dropFuncs[j+1] = pq.QuoteIdentifier(TriggerFunctionName(o.Table, TemporaryName(o.Columns[i]))) + } + _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE", + strings.Join(dropFuncs, ", "), + )) + return err +} + +func dropMultipleColumns(columns []string) string { + for i, col := range columns { + columns[i] = "DROP COLUMN IF EXISTS " + col + } + return strings.Join(columns, ", ") +} + +func (o *OpCreateConstraint) Validate(ctx context.Context, s *schema.Schema) error { + table := s.GetTable(o.Table) + if table == nil { + return TableDoesNotExistError{Name: o.Table} + } + + if err := ValidateIdentifierLength(o.Name); err != nil { + return err + } + + if table.ConstraintExists(o.Name) { + return ConstraintAlreadyExistsError{ + Table: o.Table, + Constraint: o.Name, + } + } + + for _, col := range o.Columns { + if table.GetColumn(col) == nil { + return ColumnDoesNotExistError{ + Table: o.Table, + Name: col, + } + } + if _, ok := o.Up[col]; !ok { + return ColumnMigrationMissingError{ + Table: o.Table, + Name: col, + } + } + if _, ok := o.Down[col]; !ok { + return ColumnMigrationMissingError{ + Table: o.Table, + Name: col, + } + } + } + + switch o.Type { //nolint:gocritic // more cases will be added + case OpCreateConstraintTypeUnique: + if len(o.Columns) == 0 { + return FieldRequiredError{Name: "columns"} + } + } + + return nil +} + +func (o *OpCreateConstraint) addUniqueIndex(ctx context.Context, conn db.DB) error { + _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (%s)", + pq.QuoteIdentifier(o.Name), + pq.QuoteIdentifier(o.Table), + strings.Join(quotedTemporaryNames(o.Columns), ", "), + )) + + return err +} + +func quotedTemporaryNames(columns []string) []string { + names := make([]string, len(columns)) + for i, col := range columns { + names[i] = pq.QuoteIdentifier(TemporaryName(col)) + } + return names +} diff --git a/pkg/migrations/op_create_constraint_test.go b/pkg/migrations/op_create_constraint_test.go new file mode 100644 index 00000000..8b2ca8e6 --- /dev/null +++ b/pkg/migrations/op_create_constraint_test.go @@ -0,0 +1,289 @@ +// SPDX-License-Identifier: Apache-2.0 + +package migrations_test + +import ( + "database/sql" + "strings" + "testing" + + "github.com/xataio/pgroll/internal/testutils" + "github.com/xataio/pgroll/pkg/migrations" +) + +func TestCreateConstraint(t *testing.T) { + t.Parallel() + + invalidName := strings.Repeat("x", 64) + ExecuteTests(t, TestCases{ + { + name: "create unique constraint on single column", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + { + Name: "name", + Type: "varchar(255)", + Nullable: ptr(false), + }, + }, + }, + }, + }, + { + Name: "02_create_constraint", + Operations: migrations.Operations{ + &migrations.OpCreateConstraint{ + Name: "unique_name", + Table: "users", + Type: "unique", + Columns: []string{"name"}, + Up: migrations.OpCreateConstraintUp(map[string]string{ + "name": "name || random()", + }), + Down: migrations.OpCreateConstraintDown(map[string]string{ + "name": "name", + }), + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB, schema string) { + // The index has been created on the underlying table. + IndexMustExist(t, db, schema, "users", "unique_name") + + // Inserting values into the old schema that violate uniqueness should succeed. + MustInsert(t, db, schema, "01_add_table", "users", map[string]string{ + "name": "alice", + }) + MustInsert(t, db, schema, "01_add_table", "users", map[string]string{ + "name": "alice", + }) + + // Inserting values into the new schema that violate uniqueness should fail. + MustInsert(t, db, schema, "02_create_constraint", "users", map[string]string{ + "name": "bob", + }) + MustNotInsert(t, db, schema, "02_create_constraint", "users", map[string]string{ + "name": "bob", + }, testutils.UniqueViolationErrorCode) + }, + afterRollback: func(t *testing.T, db *sql.DB, schema string) { + // The index has been dropped from the the underlying table. + IndexMustNotExist(t, db, schema, "users", "uniue_name") + + // Functions, triggers and temporary columns are dropped. + tableCleanedUp(t, db, schema, "users", "name") + }, + afterComplete: func(t *testing.T, db *sql.DB, schema string) { + // Functions, triggers and temporary columns are dropped. + tableCleanedUp(t, db, schema, "users", "name") + + // Inserting values into the new schema that violate uniqueness should fail. + MustInsert(t, db, schema, "02_create_constraint", "users", map[string]string{ + "name": "carol", + }) + MustNotInsert(t, db, schema, "02_create_constraint", "users", map[string]string{ + "name": "carol", + }, testutils.UniqueViolationErrorCode) + }, + }, + { + name: "create unique constraint on multiple columns", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + { + Name: "name", + Type: "varchar(255)", + Nullable: ptr(false), + }, + { + Name: "email", + Type: "varchar(255)", + Nullable: ptr(false), + }, + }, + }, + }, + }, + { + Name: "02_create_constraint", + Operations: migrations.Operations{ + &migrations.OpCreateConstraint{ + Name: "unique_name_email", + Table: "users", + Type: "unique", + Columns: []string{"name", "email"}, + Up: migrations.OpCreateConstraintUp(map[string]string{ + "name": "name || random()", + "email": "email || random()", + }), + Down: migrations.OpCreateConstraintDown(map[string]string{ + "name": "name", + "email": "email", + }), + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB, schema string) { + // The index has been created on the underlying table. + IndexMustExist(t, db, schema, "users", "unique_name_email") + + // Inserting values into the old schema that violate uniqueness should succeed. + MustInsert(t, db, schema, "01_add_table", "users", map[string]string{ + "name": "alice", + "email": "alice@alice.me", + }) + MustInsert(t, db, schema, "01_add_table", "users", map[string]string{ + "name": "alice", + "email": "alice@alice.me", + }) + + // Inserting values into the new schema that violate uniqueness should fail. + MustInsert(t, db, schema, "02_create_constraint", "users", map[string]string{ + "name": "bob", + "email": "bob@bob.me", + }) + MustNotInsert(t, db, schema, "02_create_constraint", "users", map[string]string{ + "name": "bob", + "email": "bob@bob.me", + }, testutils.UniqueViolationErrorCode) + }, + afterRollback: func(t *testing.T, db *sql.DB, schema string) { + // The index has been dropped from the the underlying table. + IndexMustNotExist(t, db, schema, "users", "unique_name_email") + + // Functions, triggers and temporary columns are dropped. + tableCleanedUp(t, db, schema, "users", "name") + tableCleanedUp(t, db, schema, "users", "email") + }, + afterComplete: func(t *testing.T, db *sql.DB, schema string) { + // Complete is a no-op. + }, + }, + { + name: "invalid constraint name", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + { + Name: "name", + Type: "varchar(255)", + Nullable: ptr(false), + }, + { + Name: "registered_at_year", + Type: "integer", + Nullable: ptr(false), + }, + }, + }, + }, + }, + { + Name: "02_create_constraint_with_invalid_name", + Operations: migrations.Operations{ + &migrations.OpCreateConstraint{ + Name: invalidName, + Table: "users", + Columns: []string{"registered_at_year"}, + Type: "unique", + }, + }, + }, + }, + wantStartErr: migrations.ValidateIdentifierLength(invalidName), + afterStart: func(t *testing.T, db *sql.DB, schema string) {}, + afterRollback: func(t *testing.T, db *sql.DB, schema string) {}, + afterComplete: func(t *testing.T, db *sql.DB, schema string) {}, + }, + { + name: "missing migration for constraint creation", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + { + Name: "name", + Type: "varchar(255)", + Nullable: ptr(false), + }, + }, + }, + }, + }, + { + Name: "02_create_constraint_with_missing_migration", + Operations: migrations.Operations{ + &migrations.OpCreateConstraint{ + Name: "unique_name", + Table: "users", + Columns: []string{"name"}, + Type: "unique", + Up: migrations.OpCreateConstraintUp(map[string]string{}), + Down: migrations.OpCreateConstraintDown(map[string]string{ + "name": "name", + }), + }, + }, + }, + }, + wantStartErr: migrations.ColumnMigrationMissingError{Table: "users", Name: "name"}, + afterStart: func(t *testing.T, db *sql.DB, schema string) {}, + afterRollback: func(t *testing.T, db *sql.DB, schema string) {}, + afterComplete: func(t *testing.T, db *sql.DB, schema string) {}, + }, + }) +} + +func tableCleanedUp(t *testing.T, db *sql.DB, schema, table, column string) { + // The new, temporary column should not exist on the underlying table. + ColumnMustNotExist(t, db, schema, table, migrations.TemporaryName(column)) + + // The up function no longer exists. + FunctionMustNotExist(t, db, schema, migrations.TriggerFunctionName(table, column)) + // The down function no longer exists. + FunctionMustNotExist(t, db, schema, migrations.TriggerFunctionName(table, migrations.TemporaryName(column))) + + // The up trigger no longer exists. + TriggerMustNotExist(t, db, schema, table, migrations.TriggerName(table, column)) + // The down trigger no longer exists. + TriggerMustNotExist(t, db, schema, table, migrations.TriggerName(table, migrations.TemporaryName(column))) +} diff --git a/pkg/migrations/types.go b/pkg/migrations/types.go index 5efb1acb..bd11428b 100644 --- a/pkg/migrations/types.go +++ b/pkg/migrations/types.go @@ -119,6 +119,37 @@ type OpAlterColumn struct { Up string `json:"up,omitempty"` } +// Add constraint to table operation +type OpCreateConstraint struct { + // Columns to add constraint to + Columns []string `json:"columns,omitempty"` + + // SQL expression of down migration by column + Down OpCreateConstraintDown `json:"down"` + + // Name of the constraint + Name string `json:"name"` + + // Name of the table + Table string `json:"table"` + + // Type of the constraint + Type OpCreateConstraintType `json:"type"` + + // SQL expression of up migration by column + Up OpCreateConstraintUp `json:"up"` +} + +// SQL expression of down migration by column +type OpCreateConstraintDown map[string]string + +type OpCreateConstraintType string + +const OpCreateConstraintTypeUnique OpCreateConstraintType = "unique" + +// SQL expression of up migration by column +type OpCreateConstraintUp map[string]string + // Create index operation type OpCreateIndex struct { // Names of columns on which to define the index diff --git a/schema.json b/schema.json index 4d7cdedb..428c11f1 100644 --- a/schema.json +++ b/schema.json @@ -432,6 +432,44 @@ "required": ["identity", "table"], "type": "object" }, + "OpCreateConstraint": { + "additionalProperties": false, + "description": "Add constraint to table operation", + "properties": { + "table": { + "description": "Name of the table", + "type": "string" + }, + "name": { + "description": "Name of the constraint", + "type": "string" + }, + "columns": { + "description": "Columns to add constraint to", + "type": "array", + "items": { + "type": "string" + } + }, + "type": { + "description": "Type of the constraint", + "type": "string", + "enum": ["unique"] + }, + "up": { + "type": "object", + "additionalProperties": { "type": "string" }, + "description": "SQL expression of up migration by column" + }, + "down": { + "type": "object", + "additionalProperties": { "type": "string" }, + "description": "SQL expression of down migration by column" + } + }, + "required": ["name", "table", "type", "up", "down"], + "type": "object" + }, "PgRollOperation": { "anyOf": [ { @@ -565,6 +603,17 @@ } }, "required": ["set_replica_identity"] + }, + { + "type": "object", + "description": "Add constraint operation", + "additionalProperties": false, + "properties": { + "create_constraint": { + "$ref": "#/$defs/OpCreateConstraint" + } + }, + "required": ["create_constraint"] } ] },