Skip to content

Commit

Permalink
Add support for creating CHECK constraints with create_constraint (
Browse files Browse the repository at this point in the history
…#464)

This PR introduces a new constraint `type` to `create_constraint`
operation called `check`. Now it is possible to create check constraints
on multiple columns.

### Example
```json
{
  "name": "45_add_table_check_constraint",
  "operations": [
    {
      "create_constraint": {
        "type": "check",
        "table": "tickets",
        "name": "check_zip_name",
        "columns": [
          "sellers_name",
          "sellers_zip"
        ],
        "check": "sellers_name ~ 'Alice' AND sellers_zip IS NOT NULL",
        "up": {
          "sellers_name": "Alice",
          "sellers_zip": "(SELECT CASE WHEN sellers_zip IS NOT NULL THEN sellers_zip ELSE '00000' END)"
        },
        "down": {
          "sellers_name": "sellers_name",
          "sellers_zip": "sellers_zip"
        }
      }
    }
  ]
}
```

---------

Co-authored-by: Andrew Farries <[email protected]>
  • Loading branch information
kvch and andrew-farries authored Nov 18, 2024
1 parent 5d63a7f commit d231ee0
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 5 deletions.
3 changes: 2 additions & 1 deletion docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1101,7 +1101,7 @@ Example **create table** migrations:

A create constraint operation adds a new constraint to an existing table.

Only `UNIQUE` constraints are supported.
Only `UNIQUE` and `CHECK` constraints are supported.

Required fields: `name`, `table`, `type`, `up`, `down`.

Expand Down Expand Up @@ -1129,6 +1129,7 @@ Required fields: `name`, `table`, `type`, `up`, `down`.
Example **create constraint** migrations:

* [44_add_table_unique_constraint.json](../examples/44_add_table_unique_constraint.json)
* [45_add_table_check_constraint.json](../examples/45_add_table_check_constraint.json)


### Drop column
Expand Down
1 change: 1 addition & 0 deletions examples/.ledger
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@
42_create_unique_index.json
43_create_tickets_table.json
44_add_table_unique_constraint.json
45_add_table_check_constraint.json
25 changes: 25 additions & 0 deletions examples/45_add_table_check_constraint.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"name": "45_add_table_check_constraint",
"operations": [
{
"create_constraint": {
"type": "check",
"table": "tickets",
"name": "check_zip_name",
"columns": [
"sellers_name",
"sellers_zip"
],
"check": "sellers_name = 'alice' OR sellers_zip > 0",
"up": {
"sellers_name": "sellers_name",
"sellers_zip": "(SELECT CASE WHEN sellers_name != 'alice' AND sellers_zip <= 0 THEN 123 WHEN sellers_name != 'alice' THEN sellers_zip ELSE sellers_zip END)"
},
"down": {
"sellers_name": "sellers_name",
"sellers_zip": "sellers_zip"
}
}
}
]
}
33 changes: 30 additions & 3 deletions pkg/migrations/op_create_constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,18 @@ func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema
}
}

switch o.Type { //nolint:gocritic // more cases will be added
switch o.Type {
case OpCreateConstraintTypeUnique:
return table, o.addUniqueIndex(ctx, conn)
case OpCreateConstraintTypeCheck:
return table, o.addCheckConstraint(ctx, conn)
}

return table, nil
}

func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
switch o.Type { //nolint:gocritic // more cases will be added
switch o.Type {
case OpCreateConstraintTypeUnique:
uniqueOp := &OpSetUnique{
Table: o.Table,
Expand All @@ -84,6 +86,17 @@ func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTra
if err != nil {
return err
}
case OpCreateConstraintTypeCheck:
checkOp := &OpSetCheckConstraint{
Table: o.Table,
Check: CheckConstraint{
Name: o.Name,
},
}
err := checkOp.Complete(ctx, conn, tr, s)
if err != nil {
return err
}
}

// remove old columns
Expand Down Expand Up @@ -176,11 +189,15 @@ func (o *OpCreateConstraint) Validate(ctx context.Context, s *schema.Schema) err
}
}

switch o.Type { //nolint:gocritic // more cases will be added
switch o.Type {
case OpCreateConstraintTypeUnique:
if len(o.Columns) == 0 {
return FieldRequiredError{Name: "columns"}
}
case OpCreateConstraintTypeCheck:
if o.Check == nil || *o.Check == "" {
return FieldRequiredError{Name: "check"}
}
}

return nil
Expand All @@ -196,6 +213,16 @@ func (o *OpCreateConstraint) addUniqueIndex(ctx context.Context, conn db.DB) err
return err
}

func (o *OpCreateConstraint) addCheckConstraint(ctx context.Context, conn db.DB) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Name),
rewriteCheckExpression(*o.Check, o.Columns...),
))

return err
}

func quotedTemporaryNames(columns []string) []string {
names := make([]string, len(columns))
for i, col := range columns {
Expand Down
220 changes: 220 additions & 0 deletions pkg/migrations/op_create_constraint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"strings"
"testing"

"github.com/stretchr/testify/assert"

"github.com/xataio/pgroll/internal/testutils"
"github.com/xataio/pgroll/pkg/migrations"
)
Expand Down Expand Up @@ -97,6 +99,80 @@ func TestCreateConstraint(t *testing.T) {
}, testutils.UniqueViolationErrorCode)
},
},
{
name: "create check constraint on single column",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: ptr(true),
},
{
Name: "name",
Type: "varchar(255)",
Nullable: ptr(false),
},
},
},
},
},
{
Name: "02_create_constraint",
Operations: migrations.Operations{
&migrations.OpCreateConstraint{
Name: "name_letters",
Table: "users",
Type: "check",
Check: ptr("name ~ '^[a-zA-Z]+$'"),
Columns: []string{"name"},
Up: migrations.OpCreateConstraintUp(map[string]string{
"name": "regexp_replace(name, '\\d+', '', 'g')",
}),
Down: migrations.OpCreateConstraintDown(map[string]string{
"name": "name",
}),
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// The new (temporary) column should exist on the underlying table.
ColumnMustExist(t, db, schema, "users", migrations.TemporaryName("name"))
// The check constraint exists on the new table.
CheckConstraintMustExist(t, db, schema, "users", "name_letters")
// Inserting values into the old schema that violate the check constraint must succeed.
MustInsert(t, db, schema, "01_add_table", "users", map[string]string{
"name": "alice11",
})

// Inserting values into the new schema that violate the check constraint should fail.
MustInsert(t, db, schema, "02_create_constraint", "users", map[string]string{
"name": "bob",
})
MustNotInsert(t, db, schema, "02_create_constraint", "users", map[string]string{
"name": "bob2",
}, testutils.CheckViolationErrorCode)
},
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
// Functions, triggers and temporary columns are dropped.
tableCleanedUp(t, db, schema, "users", "name")
},
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// Functions, triggers and temporary columns are dropped.
tableCleanedUp(t, db, schema, "users", "name")

// Inserting values into the new schema that violate the check constraint should fail.
MustNotInsert(t, db, schema, "02_create_constraint", "users", map[string]string{
"name": "carol0",
}, testutils.CheckViolationErrorCode)
},
},
{
name: "create unique constraint on multiple columns",
migrations: []migrations.Migration{
Expand Down Expand Up @@ -181,6 +257,104 @@ func TestCreateConstraint(t *testing.T) {
// Complete is a no-op.
},
},
{
name: "create check constraint on multiple columns",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: ptr(true),
},
{
Name: "name",
Type: "varchar(255)",
Nullable: ptr(false),
},
{
Name: "email",
Type: "varchar(255)",
Nullable: ptr(false),
},
},
},
},
},
{
Name: "02_create_constraint",
Operations: migrations.Operations{
&migrations.OpCreateConstraint{
Name: "check_name_email",
Table: "users",
Type: "check",
Check: ptr("name != email"),
Columns: []string{"name", "email"},
Up: migrations.OpCreateConstraintUp(map[string]string{
"name": "name",
"email": "(SELECT CASE WHEN email ~ '@' THEN email ELSE email || '@example.com' END)",
}),
Down: migrations.OpCreateConstraintDown(map[string]string{
"name": "name",
"email": "email",
}),
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// The new (temporary) column should exist on the underlying table.
ColumnMustExist(t, db, schema, "users", migrations.TemporaryName("name"))
// The new (temporary) column should exist on the underlying table.
ColumnMustExist(t, db, schema, "users", migrations.TemporaryName("email"))
// The check constraint exists on the new table.
CheckConstraintMustExist(t, db, schema, "users", "check_name_email")

// Inserting values into the old schema that the violate the check constraint must succeed.
MustInsert(t, db, schema, "01_add_table", "users", map[string]string{
"name": "alice",
"email": "alice",
})

// Inserting values into the new schema that meet the check constraint should succeed.
MustInsert(t, db, schema, "02_create_constraint", "users", map[string]string{
"name": "bob",
"email": "[email protected]",
})
MustNotInsert(t, db, schema, "02_create_constraint", "users", map[string]string{
"name": "bob",
"email": "bob",
}, testutils.CheckViolationErrorCode)
},
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
// The check constraint must not exists on the table.
CheckConstraintMustNotExist(t, db, schema, "users", "check_name_email")
// Functions, triggers and temporary columns are dropped.
tableCleanedUp(t, db, schema, "users", "name")
tableCleanedUp(t, db, schema, "users", "email")
},
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// Functions, triggers and temporary columns are dropped.
tableCleanedUp(t, db, schema, "users", "name")
tableCleanedUp(t, db, schema, "users", "email")

// Inserting values into the new schema that the violate the check constraint must fail.
MustNotInsert(t, db, schema, "02_create_constraint", "users", map[string]string{
"name": "carol",
"email": "carol",
}, testutils.CheckViolationErrorCode)

rows := MustSelect(t, db, schema, "02_create_constraint", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "alice", "email": "[email protected]"},
{"id": 2, "name": "bob", "email": "[email protected]"},
}, rows)
},
},
{
name: "invalid constraint name",
migrations: []migrations.Migration{
Expand Down Expand Up @@ -270,6 +444,52 @@ func TestCreateConstraint(t *testing.T) {
afterRollback: func(t *testing.T, db *sql.DB, schema string) {},
afterComplete: func(t *testing.T, db *sql.DB, schema string) {},
},
{
name: "expression of check constraint is missing",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: ptr(true),
},
{
Name: "name",
Type: "varchar(255)",
Nullable: ptr(false),
},
},
},
},
},
{
Name: "02_create_constraint_with_missing_migration",
Operations: migrations.Operations{
&migrations.OpCreateConstraint{
Name: "check_name",
Table: "users",
Columns: []string{"name"},
Type: "check",
Up: migrations.OpCreateConstraintUp(map[string]string{
"name": "name",
}),
Down: migrations.OpCreateConstraintDown(map[string]string{
"name": "name",
}),
},
},
},
},
wantStartErr: migrations.FieldRequiredError{Name: "check"},
afterStart: func(t *testing.T, db *sql.DB, schema string) {},
afterRollback: func(t *testing.T, db *sql.DB, schema string) {},
afterComplete: func(t *testing.T, db *sql.DB, schema string) {},
},
})
}

Expand Down
4 changes: 4 additions & 0 deletions pkg/migrations/types.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,11 @@
"type": {
"description": "Type of the constraint",
"type": "string",
"enum": ["unique"]
"enum": ["unique", "check"]
},
"check": {
"description": "Check constraint expression",
"type": "string"
},
"up": {
"type": "object",
Expand Down

0 comments on commit d231ee0

Please sign in to comment.