Skip to content

Commit

Permalink
Convert ALTER TABLE ... ADD CONSTRAINT ... CHECK SQL to pgroll op…
Browse files Browse the repository at this point in the history
…eration (#538)

Convert SQL statements like:

```sql
ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0)
```

to the equivalent `pgroll` operation:

```json
[
  {
    "create_constraint": {
      "check": "age > 0",
      "columns": [
        "placeholder"
      ],
      "down": {
        "placeholder": "TODO: Implement SQL data migration"
      },
      "name": "bar",
      "table": "foo",
      "type": "check",
      "up": {
        "placeholder": "TODO: Implement SQL data migration"
      }
    }
  }
]
```

As we don't currently have a reliable way of extracting the columns
covered by the constraint from the the `CHECK` SQL expression, the
converted `pgroll` operation uses placeholders for the `columns`, `up`
and `down` fields.

Some forms of the statement are not currently representable by the
`create_constraint` operation. For these a raw SQL migration is
generated:

```sql
"ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0) NO INHERIT",
"ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0) NOT VALID",
```
  • Loading branch information
andrew-farries authored Dec 17, 2024
1 parent f4c17ff commit be145b1
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 10 deletions.
65 changes: 61 additions & 4 deletions pkg/sql2pgroll/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ import (
"github.com/xataio/pgroll/pkg/migrations"
)

const PlaceHolderSQL = "TODO: Implement SQL data migration"
const (
PlaceHolderColumnName = "placeholder"
PlaceHolderSQL = "TODO: Implement SQL data migration"
)

// convertAlterTableStmt converts an ALTER TABLE statement to pgroll operations.
func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, error) {
Expand Down Expand Up @@ -99,11 +102,12 @@ func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTa
}, nil
}

// convertAlterTableAddConstraint converts SQL statements that add UNIQUE or FOREIGN KEY constraints,
// convertAlterTableAddConstraint converts SQL statements that add constraints,
// for example:
//
// `ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)`
// `ALTER TABLE foo ADD CONSTRAINT fk_bar_c FOREIGN KEY (a) REFERENCES bar (c);`
// `ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0)`
//
// An OpCreateConstraint operation is returned.
func convertAlterTableAddConstraint(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) {
Expand All @@ -119,6 +123,8 @@ func convertAlterTableAddConstraint(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTabl
op, err = convertAlterTableAddUniqueConstraint(stmt, node.Constraint)
case pgq.ConstrType_CONSTR_FOREIGN:
op, err = convertAlterTableAddForeignKeyConstraint(stmt, node.Constraint)
case pgq.ConstrType_CONSTR_CHECK:
op, err = convertAlterTableAddCheckConstraint(stmt, node.Constraint)
default:
return nil, nil
}
Expand Down Expand Up @@ -229,6 +235,10 @@ func convertAlterTableAddForeignKeyConstraint(stmt *pgq.AlterTableStmt, constrai
}

func canConvertAlterTableAddForeignKeyConstraint(constraint *pgq.Constraint) bool {
if constraint.SkipValidation {
return false
}

switch constraint.GetFkUpdAction() {
case "r", "c", "n", "d":
// RESTRICT, CASCADE, SET NULL, SET DEFAULT
Expand All @@ -248,6 +258,53 @@ func canConvertAlterTableAddForeignKeyConstraint(constraint *pgq.Constraint) boo
return true
}

// convertAlterTableAddCheckConstraint converts SQL statements like:
//
// `ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0)`
//
// to an OpCreateConstraint operation.
func convertAlterTableAddCheckConstraint(stmt *pgq.AlterTableStmt, constraint *pgq.Constraint) (migrations.Operation, error) {
if !canConvertCheckConstraint(constraint) {
return nil, nil
}

tableName := stmt.GetRelation().GetRelname()
if stmt.GetRelation().GetSchemaname() != "" {
tableName = stmt.GetRelation().GetSchemaname() + "." + tableName
}

expr, err := pgq.DeparseExpr(constraint.GetRawExpr())
if err != nil {
return nil, fmt.Errorf("failed to deparse CHECK expression: %w", err)
}

return &migrations.OpCreateConstraint{
Type: migrations.OpCreateConstraintTypeCheck,
Name: constraint.GetConname(),
Table: tableName,
Check: ptr(expr),
Columns: []string{PlaceHolderColumnName},
Up: migrations.MultiColumnUpSQL{
PlaceHolderColumnName: PlaceHolderSQL,
},
Down: migrations.MultiColumnDownSQL{
PlaceHolderColumnName: PlaceHolderSQL,
},
}, nil
}

// canConvertCheckConstraint checks if the CHECK constraint `constraint` can
// be faithfully converted to an OpCreateConstraint operation without losing
// information.
func canConvertCheckConstraint(constraint *pgq.Constraint) bool {
switch {
case constraint.IsNoInherit, constraint.SkipValidation:
return false
default:
return true
}
}

// convertAlterTableSetColumnDefault converts SQL statements like:
//
// `ALTER TABLE foo COLUMN bar SET DEFAULT 'foo'`
Expand Down Expand Up @@ -317,10 +374,10 @@ func convertAlterTableDropConstraint(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTab

return &migrations.OpDropMultiColumnConstraint{
Up: migrations.MultiColumnUpSQL{
"placeholder": PlaceHolderSQL,
PlaceHolderColumnName: PlaceHolderSQL,
},
Down: migrations.MultiColumnDownSQL{
"placeholder": PlaceHolderSQL,
PlaceHolderColumnName: PlaceHolderSQL,
},
Table: tableName,
Name: cmd.GetName(),
Expand Down
18 changes: 14 additions & 4 deletions pkg/sql2pgroll/alter_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,6 @@ func TestConvertAlterTableStatements(t *testing.T) {
sql: "ALTER TABLE foo ADD CONSTRAINT fk_bar_c FOREIGN KEY (a) REFERENCES bar (c);",
expectedOp: expect.AddForeignKeyOp2,
},
{
sql: "ALTER TABLE foo ADD CONSTRAINT fk_bar_c FOREIGN KEY (a) REFERENCES bar (c) NOT VALID;",
expectedOp: expect.AddForeignKeyOp2,
},
{
sql: "ALTER TABLE schema_a.foo ADD CONSTRAINT fk_bar_c FOREIGN KEY (a) REFERENCES schema_a.bar (c);",
expectedOp: expect.AddForeignKeyOp3,
Expand All @@ -136,6 +132,14 @@ func TestConvertAlterTableStatements(t *testing.T) {
sql: "ALTER TABLE foo DROP CONSTRAINT IF EXISTS constraint_foo RESTRICT",
expectedOp: expect.OpDropConstraintWithTable("foo"),
},
{
sql: "ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0)",
expectedOp: expect.CreateConstraintOp3,
},
{
sql: "ALTER TABLE schema.foo ADD CONSTRAINT bar CHECK (age > 0)",
expectedOp: expect.CreateConstraintOp4,
},
}

for _, tc := range tests {
Expand Down Expand Up @@ -176,11 +180,17 @@ func TestUnconvertableAlterTableStatements(t *testing.T) {
"ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) ON UPDATE SET NULL;",
"ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) ON UPDATE SET DEFAULT;",
"ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) MATCH FULL;",
"ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) NOT VALID",
// MATCH PARTIAL is not implemented in the actual parser yet
//"ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) MATCH PARTIAL;",

// Drop constraint with CASCADE
"ALTER TABLE foo DROP CONSTRAINT bar CASCADE",

// NO INHERIT and NOT VALID options on CHECK constraints are not
// representable by `OpCreateConstraint`
"ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0) NO INHERIT",
"ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0) NOT VALID",
}

for _, sql := range tests {
Expand Down
28 changes: 28 additions & 0 deletions pkg/sql2pgroll/expect/create_constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,31 @@ var CreateConstraintOp2 = &migrations.OpCreateConstraint{
"b": sql2pgroll.PlaceHolderSQL,
},
}

var CreateConstraintOp3 = &migrations.OpCreateConstraint{
Type: migrations.OpCreateConstraintTypeCheck,
Name: "bar",
Table: "foo",
Check: ptr("age > 0"),
Columns: []string{sql2pgroll.PlaceHolderColumnName},
Up: map[string]string{
sql2pgroll.PlaceHolderColumnName: sql2pgroll.PlaceHolderSQL,
},
Down: map[string]string{
sql2pgroll.PlaceHolderColumnName: sql2pgroll.PlaceHolderSQL,
},
}

var CreateConstraintOp4 = &migrations.OpCreateConstraint{
Type: migrations.OpCreateConstraintTypeCheck,
Name: "bar",
Table: "schema.foo",
Check: ptr("age > 0"),
Columns: []string{sql2pgroll.PlaceHolderColumnName},
Up: map[string]string{
sql2pgroll.PlaceHolderColumnName: sql2pgroll.PlaceHolderSQL,
},
Down: map[string]string{
sql2pgroll.PlaceHolderColumnName: sql2pgroll.PlaceHolderSQL,
},
}
4 changes: 2 additions & 2 deletions pkg/sql2pgroll/expect/drop_constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import (
func OpDropConstraintWithTable(table string) *migrations.OpDropMultiColumnConstraint {
return &migrations.OpDropMultiColumnConstraint{
Up: migrations.MultiColumnUpSQL{
"placeholder": sql2pgroll.PlaceHolderSQL,
sql2pgroll.PlaceHolderColumnName: sql2pgroll.PlaceHolderSQL,
},
Down: migrations.MultiColumnDownSQL{
"placeholder": sql2pgroll.PlaceHolderSQL,
sql2pgroll.PlaceHolderColumnName: sql2pgroll.PlaceHolderSQL,
},
Table: table,
Name: "constraint_foo",
Expand Down

0 comments on commit be145b1

Please sign in to comment.