From f429ee23ccf8d32027ef47052488ea3bc404a543 Mon Sep 17 00:00:00 2001 From: Ryan Slade Date: Wed, 11 Dec 2024 17:49:13 +0100 Subject: [PATCH] sql2pgroll: Support set and drop column defaults (#526) Support translating the following alter table statements into pgroll operations: ```sql ALTER TABLE foo COLUMN bar SET DEFAULT 'foo' ALTER TABLE foo COLUMN bar SET DEFAULT 123 ALTER TABLE foo COLUMN bar SET DEFAULT 123.456 ALTER TABLE foo COLUMN bar SET DEFAULT true ALTER TABLE foo COLUMN bar SET DEFAULT B'0101' ALTER TABLE foo COLUMN bar SET DEFAULT null ALTER TABLE foo COLUMN bar DROP DEFAULT ``` For cases where the default is set to something other than a simple literal, we fall back to raw SQL. --- pkg/sql2pgroll/alter_table.go | 65 +++++++++++++++++++++++++-- pkg/sql2pgroll/alter_table_test.go | 31 +++++++++++++ pkg/sql2pgroll/expect/alter_column.go | 50 +++++++++++++++++++++ 3 files changed, 143 insertions(+), 3 deletions(-) diff --git a/pkg/sql2pgroll/alter_table.go b/pkg/sql2pgroll/alter_table.go index 915429ea..2d6b3d6b 100644 --- a/pkg/sql2pgroll/alter_table.go +++ b/pkg/sql2pgroll/alter_table.go @@ -4,7 +4,9 @@ package sql2pgroll import ( "fmt" + "strconv" + "github.com/oapi-codegen/nullable" pgq "github.com/pganalyze/pg_query_go/v6" "github.com/xataio/pgroll/pkg/migrations" @@ -38,6 +40,8 @@ func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, err op, err = convertAlterTableAddConstraint(stmt, alterTableCmd) case pgq.AlterTableType_AT_DropColumn: op, err = convertAlterTableDropColumn(stmt, alterTableCmd) + case pgq.AlterTableType_AT_ColumnDefault: + op, err = convertAlterTableSetColumnDefault(stmt, alterTableCmd) } if err != nil { @@ -158,11 +162,66 @@ func convertAlterTableAddUniqueConstraint(stmt *pgq.AlterTableStmt, constraint * }, nil } -// convertAlterTableDropColumn converts SQL statements like: +// convertAlterTableSetColumnDefault converts SQL statements like: // -// `ALTER TABLE foo DROP COLUMN bar +// `ALTER TABLE foo COLUMN bar SET DEFAULT 'foo'` +// `ALTER TABLE foo COLUMN bar SET DEFAULT 123` +// `ALTER TABLE foo COLUMN bar SET DEFAULT 123.456` +// `ALTER TABLE foo COLUMN bar SET DEFAULT true` +// `ALTER TABLE foo COLUMN bar SET DEFAULT B'0101'` +// `ALTER TABLE foo COLUMN bar SET DEFAULT null` +// `ALTER TABLE foo COLUMN bar DROP DEFAULT` // -// to an OpDropColumn operation. +// to an OpAlterColumn operation. +func convertAlterTableSetColumnDefault(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) { + operation := &migrations.OpAlterColumn{ + Table: stmt.GetRelation().GetRelname(), + Column: cmd.GetName(), + Down: PlaceHolderSQL, + Up: PlaceHolderSQL, + } + + if c := cmd.GetDef().GetAConst(); c != nil { + if c.GetIsnull() { + // The default can be set to null + operation.Default = nullable.NewNullNullable[string]() + return operation, nil + } + + // We have a constant + switch v := c.GetVal().(type) { + case *pgq.A_Const_Sval: + operation.Default = nullable.NewNullableWithValue(v.Sval.GetSval()) + case *pgq.A_Const_Ival: + operation.Default = nullable.NewNullableWithValue(strconv.FormatInt(int64(v.Ival.Ival), 10)) + case *pgq.A_Const_Fval: + operation.Default = nullable.NewNullableWithValue(v.Fval.Fval) + case *pgq.A_Const_Boolval: + operation.Default = nullable.NewNullableWithValue(strconv.FormatBool(v.Boolval.Boolval)) + case *pgq.A_Const_Bsval: + operation.Default = nullable.NewNullableWithValue(v.Bsval.Bsval) + default: + return nil, fmt.Errorf("unknown constant type: %T", c.GetVal()) + } + + return operation, nil + } + + if cmd.GetDef() != nil { + // We're setting it to something other than a constant + return nil, nil + } + + // We're not setting it to anything, which is the case when we are dropping it + if cmd.GetBehavior() == pgq.DropBehavior_DROP_RESTRICT { + operation.Default = nullable.NewNullNullable[string]() + return operation, nil + } + + // Unknown case, fall back to raw SQL + return nil, nil +} + func convertAlterTableDropColumn(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) { if !canConvertDropColumn(cmd) { return nil, nil diff --git a/pkg/sql2pgroll/alter_table_test.go b/pkg/sql2pgroll/alter_table_test.go index 144b220b..65c39ba8 100644 --- a/pkg/sql2pgroll/alter_table_test.go +++ b/pkg/sql2pgroll/alter_table_test.go @@ -36,6 +36,34 @@ func TestConvertAlterTableStatements(t *testing.T) { sql: "ALTER TABLE foo ALTER COLUMN a TYPE text", expectedOp: expect.AlterColumnOp3, }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT 'baz'", + expectedOp: expect.AlterColumnOp5, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT 123", + expectedOp: expect.AlterColumnOp6, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT true", + expectedOp: expect.AlterColumnOp9, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT B'0101'", + expectedOp: expect.AlterColumnOp10, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT 123.456", + expectedOp: expect.AlterColumnOp8, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar DROP DEFAULT", + expectedOp: expect.AlterColumnOp7, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT null", + expectedOp: expect.AlterColumnOp7, + }, { sql: "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)", expectedOp: expect.CreateConstraintOp1, @@ -85,6 +113,9 @@ func TestUnconvertableAlterTableStatements(t *testing.T) { // CASCADE and IF EXISTS clauses are not represented by OpDropColumn "ALTER TABLE foo DROP COLUMN bar CASCADE", "ALTER TABLE foo DROP COLUMN IF EXISTS bar", + + // Non literal default values + "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT now()", } for _, sql := range tests { diff --git a/pkg/sql2pgroll/expect/alter_column.go b/pkg/sql2pgroll/expect/alter_column.go index ec9b0acc..63041908 100644 --- a/pkg/sql2pgroll/expect/alter_column.go +++ b/pkg/sql2pgroll/expect/alter_column.go @@ -3,6 +3,8 @@ package expect import ( + "github.com/oapi-codegen/nullable" + "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/sql2pgroll" ) @@ -37,6 +39,54 @@ var AlterColumnOp4 = &migrations.OpAlterColumn{ Name: ptr("b"), } +var AlterColumnOp5 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullableWithValue("baz"), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + +var AlterColumnOp6 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullableWithValue("123"), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + +var AlterColumnOp7 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullNullable[string](), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + +var AlterColumnOp8 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullableWithValue("123.456"), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + +var AlterColumnOp9 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullableWithValue("true"), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + +var AlterColumnOp10 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullableWithValue("b0101"), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + func ptr[T any](v T) *T { return &v }