diff --git a/pkg/migrations/op_create_index.go b/pkg/migrations/op_create_index.go index 3add5a9c..7408471b 100644 --- a/pkg/migrations/op_create_index.go +++ b/pkg/migrations/op_create_index.go @@ -31,7 +31,9 @@ func (o *OpCreateIndex) Start(ctx context.Context, conn db.DB, latestSchema stri stmt += fmt.Sprintf(" USING %s", string(*o.Method)) } - stmt += fmt.Sprintf(" (%s)", strings.Join(quoteColumnNames(o.Columns), ", ")) + stmt += fmt.Sprintf(" (%s)", strings.Join( + quoteColumnNames(table.PhysicalColumnNamesFor(o.Columns...)), ", "), + ) if o.StorageParameters != nil { stmt += fmt.Sprintf(" WITH (%s)", *o.StorageParameters) diff --git a/pkg/migrations/op_create_index_test.go b/pkg/migrations/op_create_index_test.go index 28ef5342..3b9cdc64 100644 --- a/pkg/migrations/op_create_index_test.go +++ b/pkg/migrations/op_create_index_test.go @@ -315,7 +315,7 @@ func TestCreateIndexOnMultipleColumns(t *testing.T) { }}) } -func TestCreateIndexOnTableCreatedInSameMigration(t *testing.T) { +func TestCreateIndexOnObjectsCreatedInSameMigration(t *testing.T) { t.Parallel() ExecuteTests(t, TestCases{ @@ -361,5 +361,61 @@ func TestCreateIndexOnTableCreatedInSameMigration(t *testing.T) { IndexMustExist(t, db, schema, "users", "idx_users_name") }, }, - }, roll.WithSkipValidation(true)) // TODO: Remove once this migration passes validation + { + name: "create index on newly created column", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + { + Name: "name", + Type: "varchar(255)", + Nullable: ptr(false), + }, + }, + }, + }, + }, + { + Name: "02_add_column_and_index", + Operations: migrations.Operations{ + &migrations.OpAddColumn{ + Table: "users", + Column: migrations.Column{ + Name: "age", + Type: "integer", + Nullable: ptr(true), + }, + Up: "18", + }, + &migrations.OpCreateIndex{ + Name: "idx_users_age", + Table: "users", + Columns: []string{"age"}, + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB, schema string) { + // The index has been created on the underlying table. + IndexMustExist(t, db, schema, "users", "idx_users_age") + }, + afterRollback: func(t *testing.T, db *sql.DB, schema string) { + // The index has been dropped from the the underlying table. + IndexMustNotExist(t, db, schema, "users", "idx_users_age") + }, + afterComplete: func(t *testing.T, db *sql.DB, schema string) { + // The index has been created on the underlying table. + IndexMustExist(t, db, schema, "users", "idx_users_age") + }, + }, + }, roll.WithSkipValidation(true)) // TODO: Remove once these migrations pass validation } diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index 49f3ab73..1aa9a48d 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -253,6 +253,16 @@ func (t *Table) RenameColumn(from, to string) { delete(t.Columns, from) } +// PhysicalColumnNames returns the physical column names for the given virtual +// column names +func (t *Table) PhysicalColumnNamesFor(columnNames ...string) []string { + physicalNames := make([]string, 0, len(columnNames)) + for _, cn := range columnNames { + physicalNames = append(physicalNames, t.GetColumn(cn).Name) + } + return physicalNames +} + // Make the Schema struct implement the driver.Valuer interface. This method // simply returns the JSON-encoded representation of the struct. func (s Schema) Value() (driver.Value, error) {