diff --git a/go.mod b/go.mod index 3381a8673..1982d987a 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/stretchr/testify v1.10.0 github.com/testcontainers/testcontainers-go v0.34.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.34.0 - github.com/xataio/pg_query_go/v6 v6.0.0-20241216080535-894186571799 + github.com/xataio/pg_query_go/v6 v6.0.0-20241217092625-e7ba1fbaf89e golang.org/x/tools v0.28.0 ) diff --git a/go.sum b/go.sum index 3cbb39814..8f91503d2 100644 --- a/go.sum +++ b/go.sum @@ -217,8 +217,8 @@ github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZ github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY= github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYgY= github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE= -github.com/xataio/pg_query_go/v6 v6.0.0-20241216080535-894186571799 h1:FUY7PHaOXMicK0kSh32BdLGwd6svrbNZrnup69jG0DM= -github.com/xataio/pg_query_go/v6 v6.0.0-20241216080535-894186571799/go.mod h1:GK6bpfAhPtZb7wG/IccqvnH+cz3cmvvRTkC+MosESGo= +github.com/xataio/pg_query_go/v6 v6.0.0-20241217092625-e7ba1fbaf89e h1:9DShoOhR7/IsNPwTAMkTMbsEZRVcuJCb20RIVGQTIdU= +github.com/xataio/pg_query_go/v6 v6.0.0-20241217092625-e7ba1fbaf89e/go.mod h1:GK6bpfAhPtZb7wG/IccqvnH+cz3cmvvRTkC+MosESGo= github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= diff --git a/pkg/sql2pgroll/alter_table.go b/pkg/sql2pgroll/alter_table.go index 1cc0b35dc..c124d8291 100644 --- a/pkg/sql2pgroll/alter_table.go +++ b/pkg/sql2pgroll/alter_table.go @@ -89,6 +89,11 @@ func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTa return nil, fmt.Errorf("expected column definition, got %T", cmd.GetDef().Node) } + typeName, err := pgq.DeparseTypeName(node.ColumnDef.GetTypeName()) + if err != nil { + return nil, fmt.Errorf("failed to deparse type name: %w", err) + } + if !canConvertColumnForSetDataType(node.ColumnDef) { return nil, nil } @@ -96,7 +101,7 @@ func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTa return &migrations.OpAlterColumn{ Table: stmt.GetRelation().GetRelname(), Column: cmd.GetName(), - Type: ptr(convertTypeName(node.ColumnDef.GetTypeName())), + Type: ptr(typeName), Up: PlaceHolderSQL, Down: PlaceHolderSQL, }, nil diff --git a/pkg/sql2pgroll/create_table.go b/pkg/sql2pgroll/create_table.go index aed171828..5add694a9 100644 --- a/pkg/sql2pgroll/create_table.go +++ b/pkg/sql2pgroll/create_table.go @@ -3,6 +3,8 @@ package sql2pgroll import ( + "fmt" + pgq "github.com/xataio/pg_query_go/v6" "github.com/xataio/pgroll/pkg/migrations" @@ -12,7 +14,11 @@ import ( func convertCreateStmt(stmt *pgq.CreateStmt) (migrations.Operations, error) { columns := make([]migrations.Column, 0, len(stmt.TableElts)) for _, elt := range stmt.TableElts { - columns = append(columns, convertColumnDef(elt.GetColumnDef())) + column, err := convertColumnDef(elt.GetColumnDef()) + if err != nil { + return nil, fmt.Errorf("error converting column definition: %w", err) + } + columns = append(columns, *column) } return migrations.Operations{ @@ -23,9 +29,12 @@ func convertCreateStmt(stmt *pgq.CreateStmt) (migrations.Operations, error) { }, nil } -func convertColumnDef(col *pgq.ColumnDef) migrations.Column { +func convertColumnDef(col *pgq.ColumnDef) (*migrations.Column, error) { // Convert the column type - typeString := convertTypeName(col.TypeName) + typeString, err := pgq.DeparseTypeName(col.TypeName) + if err != nil { + return nil, fmt.Errorf("error deparsing column type: %w", err) + } // Determine column nullability, uniqueness, and primary key status var notNull, unique, pk bool @@ -43,12 +52,12 @@ func convertColumnDef(col *pgq.ColumnDef) migrations.Column { } } - return migrations.Column{ + return &migrations.Column{ Name: col.Colname, Type: typeString, Nullable: !notNull, Unique: unique, Default: defaultValue, Pk: pk, - } + }, nil } diff --git a/pkg/sql2pgroll/expect/create_table.go b/pkg/sql2pgroll/expect/create_table.go index d8ded16c6..6c7271da3 100644 --- a/pkg/sql2pgroll/expect/create_table.go +++ b/pkg/sql2pgroll/expect/create_table.go @@ -9,7 +9,7 @@ var CreateTableOp1 = &migrations.OpCreateTable{ Columns: []migrations.Column{ { Name: "a", - Type: "int4", + Type: "int", Nullable: true, }, }, @@ -20,7 +20,7 @@ var CreateTableOp2 = &migrations.OpCreateTable{ Columns: []migrations.Column{ { Name: "a", - Type: "int4", + Type: "int", }, }, } @@ -41,7 +41,7 @@ var CreateTableOp4 = &migrations.OpCreateTable{ Columns: []migrations.Column{ { Name: "a", - Type: "numeric(10,2)", + Type: "numeric(10, 2)", Nullable: true, }, }, @@ -52,7 +52,7 @@ var CreateTableOp5 = &migrations.OpCreateTable{ Columns: []migrations.Column{ { Name: "a", - Type: "int4", + Type: "int", Nullable: true, Unique: true, }, @@ -64,7 +64,7 @@ var CreateTableOp6 = &migrations.OpCreateTable{ Columns: []migrations.Column{ { Name: "a", - Type: "int4", + Type: "int", Pk: true, }, }, diff --git a/pkg/sql2pgroll/typename.go b/pkg/sql2pgroll/typename.go deleted file mode 100644 index 372eee0da..000000000 --- a/pkg/sql2pgroll/typename.go +++ /dev/null @@ -1,52 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package sql2pgroll - -import ( - "fmt" - "strings" - - pgq "github.com/xataio/pg_query_go/v6" -) - -// convertTypeName converts a TypeName node to a string. -func convertTypeName(typeName *pgq.TypeName) string { - ignoredTypeParts := map[string]bool{ - "pg_catalog": true, - } - - // Build the type name, including any schema qualifiers - typeParts := make([]string, 0, len(typeName.Names)) - for _, node := range typeName.Names { - typePart := node.GetString_().GetSval() - if _, ok := ignoredTypeParts[typePart]; ok { - continue - } - typeParts = append(typeParts, typePart) - } - - // Build the type modifiers, such as precision and scale for numeric types - var typeMods []string - for _, node := range typeName.Typmods { - if x, ok := node.GetAConst().Val.(*pgq.A_Const_Ival); ok { - typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval())) - } - } - var typeModifier string - if len(typeMods) > 0 { - typeModifier = fmt.Sprintf("(%s)", strings.Join(typeMods, ",")) - } - - // Build the array bounds for array types - var arrayBounds string - for _, node := range typeName.ArrayBounds { - bound := node.GetInteger().GetIval() - if bound == -1 { - arrayBounds = "[]" - } else { - arrayBounds = fmt.Sprintf("%s[%d]", arrayBounds, bound) - } - } - - return strings.Join(typeParts, ".") + typeModifier + arrayBounds -}