Skip to content

Commit

Permalink
Use a helper function to get a fully qualified relation name (#540)
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanslade authored Dec 17, 2024
1 parent be145b1 commit 91f65ea
Showing 1 changed file with 11 additions and 17 deletions.
28 changes: 11 additions & 17 deletions pkg/sql2pgroll/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,8 @@ func convertAlterTableAddForeignKeyConstraint(stmt *pgq.AlterTableStmt, constrai
return nil, fmt.Errorf("unknown delete action: %q", constraint.GetFkDelAction())
}

tableName := stmt.GetRelation().GetRelname()
if stmt.GetRelation().GetSchemaname() != "" {
tableName = stmt.GetRelation().GetSchemaname() + "." + tableName
}

foreignTable := constraint.GetPktable().GetRelname()
if constraint.GetPktable().GetSchemaname() != "" {
foreignTable = constraint.GetPktable().GetSchemaname() + "." + foreignTable
}
tableName := getQualifiedRelationName(stmt.Relation)
foreignTable := getQualifiedRelationName(constraint.GetPktable())

return &migrations.OpCreateConstraint{
Columns: columns,
Expand Down Expand Up @@ -268,10 +261,7 @@ func convertAlterTableAddCheckConstraint(stmt *pgq.AlterTableStmt, constraint *p
return nil, nil
}

tableName := stmt.GetRelation().GetRelname()
if stmt.GetRelation().GetSchemaname() != "" {
tableName = stmt.GetRelation().GetSchemaname() + "." + tableName
}
tableName := getQualifiedRelationName(stmt.GetRelation())

expr, err := pgq.DeparseExpr(constraint.GetRawExpr())
if err != nil {
Expand Down Expand Up @@ -367,10 +357,7 @@ func convertAlterTableDropConstraint(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTab
return nil, nil
}

tableName := stmt.GetRelation().GetRelname()
if stmt.GetRelation().GetSchemaname() != "" {
tableName = stmt.GetRelation().GetSchemaname() + "." + tableName
}
tableName := getQualifiedRelationName(stmt.GetRelation())

return &migrations.OpDropMultiColumnConstraint{
Up: migrations.MultiColumnUpSQL{
Expand Down Expand Up @@ -443,6 +430,13 @@ func canConvertColumnForSetDataType(column *pgq.ColumnDef) bool {
return true
}

func getQualifiedRelationName(rel *pgq.RangeVar) string {
if rel.GetSchemaname() == "" {
return rel.GetRelname()
}
return fmt.Sprintf("%s.%s", rel.GetSchemaname(), rel.GetRelname())
}

func ptr[T any](x T) *T {
return &x
}

0 comments on commit 91f65ea

Please sign in to comment.