diff --git a/pkg/sql2pgroll/alter_table.go b/pkg/sql2pgroll/alter_table.go index 915429ea..2d6b3d6b 100644 --- a/pkg/sql2pgroll/alter_table.go +++ b/pkg/sql2pgroll/alter_table.go @@ -4,7 +4,9 @@ package sql2pgroll import ( "fmt" + "strconv" + "github.com/oapi-codegen/nullable" pgq "github.com/pganalyze/pg_query_go/v6" "github.com/xataio/pgroll/pkg/migrations" @@ -38,6 +40,8 @@ func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, err op, err = convertAlterTableAddConstraint(stmt, alterTableCmd) case pgq.AlterTableType_AT_DropColumn: op, err = convertAlterTableDropColumn(stmt, alterTableCmd) + case pgq.AlterTableType_AT_ColumnDefault: + op, err = convertAlterTableSetColumnDefault(stmt, alterTableCmd) } if err != nil { @@ -158,11 +162,66 @@ func convertAlterTableAddUniqueConstraint(stmt *pgq.AlterTableStmt, constraint * }, nil } -// convertAlterTableDropColumn converts SQL statements like: +// convertAlterTableSetColumnDefault converts SQL statements like: // -// `ALTER TABLE foo DROP COLUMN bar +// `ALTER TABLE foo COLUMN bar SET DEFAULT 'foo'` +// `ALTER TABLE foo COLUMN bar SET DEFAULT 123` +// `ALTER TABLE foo COLUMN bar SET DEFAULT 123.456` +// `ALTER TABLE foo COLUMN bar SET DEFAULT true` +// `ALTER TABLE foo COLUMN bar SET DEFAULT B'0101'` +// `ALTER TABLE foo COLUMN bar SET DEFAULT null` +// `ALTER TABLE foo COLUMN bar DROP DEFAULT` // -// to an OpDropColumn operation. +// to an OpAlterColumn operation. +func convertAlterTableSetColumnDefault(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) { + operation := &migrations.OpAlterColumn{ + Table: stmt.GetRelation().GetRelname(), + Column: cmd.GetName(), + Down: PlaceHolderSQL, + Up: PlaceHolderSQL, + } + + if c := cmd.GetDef().GetAConst(); c != nil { + if c.GetIsnull() { + // The default can be set to null + operation.Default = nullable.NewNullNullable[string]() + return operation, nil + } + + // We have a constant + switch v := c.GetVal().(type) { + case *pgq.A_Const_Sval: + operation.Default = nullable.NewNullableWithValue(v.Sval.GetSval()) + case *pgq.A_Const_Ival: + operation.Default = nullable.NewNullableWithValue(strconv.FormatInt(int64(v.Ival.Ival), 10)) + case *pgq.A_Const_Fval: + operation.Default = nullable.NewNullableWithValue(v.Fval.Fval) + case *pgq.A_Const_Boolval: + operation.Default = nullable.NewNullableWithValue(strconv.FormatBool(v.Boolval.Boolval)) + case *pgq.A_Const_Bsval: + operation.Default = nullable.NewNullableWithValue(v.Bsval.Bsval) + default: + return nil, fmt.Errorf("unknown constant type: %T", c.GetVal()) + } + + return operation, nil + } + + if cmd.GetDef() != nil { + // We're setting it to something other than a constant + return nil, nil + } + + // We're not setting it to anything, which is the case when we are dropping it + if cmd.GetBehavior() == pgq.DropBehavior_DROP_RESTRICT { + operation.Default = nullable.NewNullNullable[string]() + return operation, nil + } + + // Unknown case, fall back to raw SQL + return nil, nil +} + func convertAlterTableDropColumn(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) { if !canConvertDropColumn(cmd) { return nil, nil diff --git a/pkg/sql2pgroll/alter_table_test.go b/pkg/sql2pgroll/alter_table_test.go index 144b220b..65c39ba8 100644 --- a/pkg/sql2pgroll/alter_table_test.go +++ b/pkg/sql2pgroll/alter_table_test.go @@ -36,6 +36,34 @@ func TestConvertAlterTableStatements(t *testing.T) { sql: "ALTER TABLE foo ALTER COLUMN a TYPE text", expectedOp: expect.AlterColumnOp3, }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT 'baz'", + expectedOp: expect.AlterColumnOp5, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT 123", + expectedOp: expect.AlterColumnOp6, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT true", + expectedOp: expect.AlterColumnOp9, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT B'0101'", + expectedOp: expect.AlterColumnOp10, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT 123.456", + expectedOp: expect.AlterColumnOp8, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar DROP DEFAULT", + expectedOp: expect.AlterColumnOp7, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT null", + expectedOp: expect.AlterColumnOp7, + }, { sql: "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)", expectedOp: expect.CreateConstraintOp1, @@ -85,6 +113,9 @@ func TestUnconvertableAlterTableStatements(t *testing.T) { // CASCADE and IF EXISTS clauses are not represented by OpDropColumn "ALTER TABLE foo DROP COLUMN bar CASCADE", "ALTER TABLE foo DROP COLUMN IF EXISTS bar", + + // Non literal default values + "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT now()", } for _, sql := range tests { diff --git a/pkg/sql2pgroll/expect/alter_column.go b/pkg/sql2pgroll/expect/alter_column.go index ec9b0acc..63041908 100644 --- a/pkg/sql2pgroll/expect/alter_column.go +++ b/pkg/sql2pgroll/expect/alter_column.go @@ -3,6 +3,8 @@ package expect import ( + "github.com/oapi-codegen/nullable" + "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/sql2pgroll" ) @@ -37,6 +39,54 @@ var AlterColumnOp4 = &migrations.OpAlterColumn{ Name: ptr("b"), } +var AlterColumnOp5 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullableWithValue("baz"), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + +var AlterColumnOp6 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullableWithValue("123"), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + +var AlterColumnOp7 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullNullable[string](), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + +var AlterColumnOp8 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullableWithValue("123.456"), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + +var AlterColumnOp9 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullableWithValue("true"), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + +var AlterColumnOp10 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullableWithValue("b0101"), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + func ptr[T any](v T) *T { return &v }