From 1f6f49b7f3abbba391618f1f66a7f47c5a7cdc4c Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Tue, 3 Dec 2024 11:29:40 +0000 Subject: [PATCH] Add a `sql2pgroll` package to convert SQL to `pgroll` migrations (#502) Add a `sql2pgroll` package to convert SQL to `pgroll` migrations. Add a (hidden for now) `pgroll sql` command that uses the package to convert SQL strings on the command line to `pgroll` migrations. The `sql2pgroll` package is incomplete, with almost all SQL falling back to conversion using raw SQL migrations. Only some `CREATE TABLE` statements and the `ALTER TABLE ... ALTER COLUMN ... SET NOT NULL` statement are currently handled. ```bash $ pgroll sql "create table foo(a serial primary key, b text unique)" ``` ```json [ { "create_table": { "columns": [ { "name": "a", "pk": true, "type": "serial" }, { "name": "b", "nullable": true, "type": "text", "unique": true } ], "name": "foo" } } ] ``` Part of #504 --- cmd/root.go | 1 + cmd/sql.go | 39 ++++++++++ go.mod | 1 + go.sum | 6 ++ pkg/sql2pgroll/alter_table.go | 47 ++++++++++++ pkg/sql2pgroll/alter_table_test.go | 41 ++++++++++ pkg/sql2pgroll/convert.go | 54 +++++++++++++ pkg/sql2pgroll/create_table.go | 90 ++++++++++++++++++++++ pkg/sql2pgroll/create_table_test.go | 73 ++++++++++++++++++ pkg/sql2pgroll/expect/alter_table.go | 20 +++++ pkg/sql2pgroll/expect/create_table.go | 104 ++++++++++++++++++++++++++ 11 files changed, 476 insertions(+) create mode 100644 cmd/sql.go create mode 100644 pkg/sql2pgroll/alter_table.go create mode 100644 pkg/sql2pgroll/alter_table_test.go create mode 100644 pkg/sql2pgroll/convert.go create mode 100644 pkg/sql2pgroll/create_table.go create mode 100644 pkg/sql2pgroll/create_table_test.go create mode 100644 pkg/sql2pgroll/expect/alter_table.go create mode 100644 pkg/sql2pgroll/expect/create_table.go diff --git a/cmd/root.go b/cmd/root.go index 7dd51cef..e65ebef0 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -79,6 +79,7 @@ func Execute() error { rootCmd.AddCommand(migrateCmd()) rootCmd.AddCommand(pullCmd()) rootCmd.AddCommand(latestCmd()) + rootCmd.AddCommand(sqlCmd()) return rootCmd.Execute() } diff --git a/cmd/sql.go b/cmd/sql.go new file mode 100644 index 00000000..da095328 --- /dev/null +++ b/cmd/sql.go @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "encoding/json" + "fmt" + "os" + + "github.com/spf13/cobra" + "github.com/xataio/pgroll/pkg/sql2pgroll" +) + +func sqlCmd() *cobra.Command { + sqlCmd := &cobra.Command{ + Use: "sql ", + Short: "Convert SQL statements to pgroll operations", + Args: cobra.ExactArgs(1), + Hidden: true, + RunE: func(cmd *cobra.Command, args []string) error { + sql := args[0] + + ops, err := sql2pgroll.Convert(sql) + if err != nil { + return fmt.Errorf("failed to convert SQL statement: %w", err) + } + + enc := json.NewEncoder(os.Stdout) + enc.SetIndent("", " ") + if err := enc.Encode(ops); err != nil { + return fmt.Errorf("failed to encode operations: %w", err) + } + + return nil + }, + } + + return sqlCmd +} diff --git a/go.mod b/go.mod index 87db04f5..53a4c756 100644 --- a/go.mod +++ b/go.mod @@ -59,6 +59,7 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pganalyze/pg_query_go/v6 v6.0.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect diff --git a/go.sum b/go.sum index 4d5a42ec..facf6921 100644 --- a/go.sum +++ b/go.sum @@ -69,6 +69,8 @@ github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -140,6 +142,8 @@ github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQ github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pganalyze/pg_query_go/v6 v6.0.0 h1:in6RkR/apfqlAtvqgDxd4Y4o87a5Pr8fkKDB4DrDo2c= +github.com/pganalyze/pg_query_go/v6 v6.0.0/go.mod h1:nvTHIuoud6e1SfrUaFwHqT0i4b5Nr+1rPWVds3B5+50= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -325,6 +329,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20240730163845-b1a4ccb954bf h1: google.golang.org/genproto/googleapis/rpc v0.0.0-20240730163845-b1a4ccb954bf/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc= google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/sql2pgroll/alter_table.go b/pkg/sql2pgroll/alter_table.go new file mode 100644 index 00000000..3351cd5d --- /dev/null +++ b/pkg/sql2pgroll/alter_table.go @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 + +package sql2pgroll + +import ( + pgq "github.com/pganalyze/pg_query_go/v6" + "github.com/xataio/pgroll/pkg/migrations" +) + +const PlaceHolderSQL = "TODO: Implement SQL data migration" + +// convertAlterTableStmt converts an ALTER TABLE statement to pgroll operations. +func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, error) { + if stmt.Objtype != pgq.ObjectType_OBJECT_TABLE { + return nil, nil + } + + var ops migrations.Operations + for _, cmd := range stmt.Cmds { + alterTableCmd := cmd.GetAlterTableCmd() + if alterTableCmd == nil { + continue + } + + //nolint:gocritic + switch alterTableCmd.Subtype { + case pgq.AlterTableType_AT_SetNotNull: + ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd)) + } + } + + return ops, nil +} + +func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) migrations.Operation { + return &migrations.OpAlterColumn{ + Table: stmt.GetRelation().GetRelname(), + Column: cmd.GetName(), + Nullable: ptr(false), + Up: PlaceHolderSQL, + Down: PlaceHolderSQL, + } +} + +func ptr[T any](x T) *T { + return &x +} diff --git a/pkg/sql2pgroll/alter_table_test.go b/pkg/sql2pgroll/alter_table_test.go new file mode 100644 index 00000000..b074d9f5 --- /dev/null +++ b/pkg/sql2pgroll/alter_table_test.go @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 + +package sql2pgroll_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/xataio/pgroll/pkg/migrations" + "github.com/xataio/pgroll/pkg/sql2pgroll" + "github.com/xataio/pgroll/pkg/sql2pgroll/expect" +) + +func TestConvertAlterTableStatements(t *testing.T) { + t.Parallel() + + tests := []struct { + sql string + expectedOp migrations.Operation + }{ + { + sql: "ALTER TABLE foo ALTER COLUMN a SET NOT NULL", + expectedOp: expect.AlterTableOp1, + }, + } + + for _, tc := range tests { + t.Run(tc.sql, func(t *testing.T) { + ops, err := sql2pgroll.Convert(tc.sql) + require.NoError(t, err) + + require.Len(t, ops, 1) + + alterColumnOps, ok := ops[0].(*migrations.OpAlterColumn) + require.True(t, ok) + + assert.Equal(t, tc.expectedOp, alterColumnOps) + }) + } +} diff --git a/pkg/sql2pgroll/convert.go b/pkg/sql2pgroll/convert.go new file mode 100644 index 00000000..372f25cf --- /dev/null +++ b/pkg/sql2pgroll/convert.go @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 + +package sql2pgroll + +import ( + "fmt" + + pgq "github.com/pganalyze/pg_query_go/v6" + "github.com/xataio/pgroll/pkg/migrations" +) + +var ErrStatementCount = fmt.Errorf("expected exactly one statement") + +// Convert converts a SQL statement to a slice of pgroll operations. +func Convert(sql string) (migrations.Operations, error) { + ops, err := convert(sql) + if err != nil { + return nil, err + } + + if ops == nil { + return makeRawSQLOperation(sql), nil + } + + return ops, nil +} + +func convert(sql string) (migrations.Operations, error) { + tree, err := pgq.Parse(sql) + if err != nil { + return nil, fmt.Errorf("parse error: %w", err) + } + + stmts := tree.GetStmts() + if len(stmts) != 1 { + return nil, fmt.Errorf("%w: got %d statements", ErrStatementCount, len(stmts)) + } + node := stmts[0].GetStmt().GetNode() + + switch node := (node).(type) { + case *pgq.Node_CreateStmt: + return convertCreateStmt(node.CreateStmt) + case *pgq.Node_AlterTableStmt: + return convertAlterTableStmt(node.AlterTableStmt) + default: + return makeRawSQLOperation(sql), nil + } +} + +func makeRawSQLOperation(sql string) migrations.Operations { + return migrations.Operations{ + &migrations.OpRawSQL{Up: sql}, + } +} diff --git a/pkg/sql2pgroll/create_table.go b/pkg/sql2pgroll/create_table.go new file mode 100644 index 00000000..63604895 --- /dev/null +++ b/pkg/sql2pgroll/create_table.go @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: Apache-2.0 + +package sql2pgroll + +import ( + "fmt" + "strings" + + pgq "github.com/pganalyze/pg_query_go/v6" + "github.com/xataio/pgroll/pkg/migrations" +) + +// convertCreateStmt converts a CREATE TABLE statement to a pgroll operation. +func convertCreateStmt(stmt *pgq.CreateStmt) ([]migrations.Operation, error) { + columns := make([]migrations.Column, 0, len(stmt.TableElts)) + for _, elt := range stmt.TableElts { + columns = append(columns, convertColumnDef(elt.GetColumnDef())) + } + + return migrations.Operations{ + &migrations.OpCreateTable{ + Name: stmt.Relation.Relname, + Columns: columns, + }, + }, nil +} + +func convertColumnDef(col *pgq.ColumnDef) migrations.Column { + ignoredTypeParts := map[string]bool{ + "pg_catalog": true, + } + + // Build the type name, including any schema qualifiers + typeParts := make([]string, 0, len(col.GetTypeName().Names)) + for _, node := range col.GetTypeName().Names { + typePart := node.GetString_().GetSval() + if _, ok := ignoredTypeParts[typePart]; ok { + continue + } + typeParts = append(typeParts, typePart) + } + + // Build the type modifiers, such as precision and scale for numeric types + var typeMods []string + for _, node := range col.GetTypeName().Typmods { + if x, ok := node.GetAConst().Val.(*pgq.A_Const_Ival); ok { + typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval())) + } + } + var typeModifier string + if len(typeMods) > 0 { + typeModifier = fmt.Sprintf("(%s)", strings.Join(typeMods, ",")) + } + + // Build the array bounds for array types + var arrayBounds string + for _, node := range col.GetTypeName().ArrayBounds { + bound := node.GetInteger().GetIval() + if bound == -1 { + arrayBounds = "[]" + } else { + arrayBounds = fmt.Sprintf("%s[%d]", arrayBounds, bound) + } + } + + // Determine column nullability, uniqueness, and primary key status + var notNull, unique, pk bool + var defaultValue *string + for _, constraint := range col.Constraints { + if constraint.GetConstraint().GetContype() == pgq.ConstrType_CONSTR_NOTNULL { + notNull = true + } + if constraint.GetConstraint().GetContype() == pgq.ConstrType_CONSTR_UNIQUE { + unique = true + } + if constraint.GetConstraint().GetContype() == pgq.ConstrType_CONSTR_PRIMARY { + pk = true + notNull = true + } + } + + return migrations.Column{ + Name: col.Colname, + Type: strings.Join(typeParts, ".") + typeModifier + arrayBounds, + Nullable: !notNull, + Unique: unique, + Default: defaultValue, + Pk: pk, + } +} diff --git a/pkg/sql2pgroll/create_table_test.go b/pkg/sql2pgroll/create_table_test.go new file mode 100644 index 00000000..5636f489 --- /dev/null +++ b/pkg/sql2pgroll/create_table_test.go @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: Apache-2.0 + +package sql2pgroll_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/xataio/pgroll/pkg/migrations" + "github.com/xataio/pgroll/pkg/sql2pgroll" + "github.com/xataio/pgroll/pkg/sql2pgroll/expect" +) + +func TestConvertCreateTableStatements(t *testing.T) { + t.Parallel() + + tests := []struct { + sql string + expectedOp migrations.Operation + }{ + { + sql: "CREATE TABLE foo(a int)", + expectedOp: expect.CreateTableOp1, + }, + { + sql: "CREATE TABLE foo(a int NOT NULL)", + expectedOp: expect.CreateTableOp2, + }, + { + sql: "CREATE TABLE foo(a varchar(255))", + expectedOp: expect.CreateTableOp3, + }, + { + sql: "CREATE TABLE foo(a numeric(10, 2))", + expectedOp: expect.CreateTableOp4, + }, + { + sql: "CREATE TABLE foo(a int UNIQUE)", + expectedOp: expect.CreateTableOp5, + }, + { + sql: "CREATE TABLE foo(a int PRIMARY KEY)", + expectedOp: expect.CreateTableOp6, + }, + { + sql: "CREATE TABLE foo(a text[])", + expectedOp: expect.CreateTableOp7, + }, + { + sql: "CREATE TABLE foo(a text[5])", + expectedOp: expect.CreateTableOp8, + }, + { + sql: "CREATE TABLE foo(a text[5][3])", + expectedOp: expect.CreateTableOp9, + }, + } + + for _, tc := range tests { + t.Run(tc.sql, func(t *testing.T) { + ops, err := sql2pgroll.Convert(tc.sql) + require.NoError(t, err) + + require.Len(t, ops, 1) + + createTableOp, ok := ops[0].(*migrations.OpCreateTable) + require.True(t, ok) + + assert.Equal(t, tc.expectedOp, createTableOp) + }) + } +} diff --git a/pkg/sql2pgroll/expect/alter_table.go b/pkg/sql2pgroll/expect/alter_table.go new file mode 100644 index 00000000..f1da4a85 --- /dev/null +++ b/pkg/sql2pgroll/expect/alter_table.go @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 + +package expect + +import ( + "github.com/xataio/pgroll/pkg/migrations" + "github.com/xataio/pgroll/pkg/sql2pgroll" +) + +var AlterTableOp1 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "a", + Nullable: ptr(false), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + +func ptr[T any](v T) *T { + return &v +} diff --git a/pkg/sql2pgroll/expect/create_table.go b/pkg/sql2pgroll/expect/create_table.go new file mode 100644 index 00000000..d8ded16c --- /dev/null +++ b/pkg/sql2pgroll/expect/create_table.go @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: Apache-2.0 + +package expect + +import "github.com/xataio/pgroll/pkg/migrations" + +var CreateTableOp1 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "int4", + Nullable: true, + }, + }, +} + +var CreateTableOp2 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "int4", + }, + }, +} + +var CreateTableOp3 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "varchar(255)", + Nullable: true, + }, + }, +} + +var CreateTableOp4 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "numeric(10,2)", + Nullable: true, + }, + }, +} + +var CreateTableOp5 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "int4", + Nullable: true, + Unique: true, + }, + }, +} + +var CreateTableOp6 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "int4", + Pk: true, + }, + }, +} + +var CreateTableOp7 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "text[]", + Nullable: true, + }, + }, +} + +var CreateTableOp8 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "text[5]", + Nullable: true, + }, + }, +} + +var CreateTableOp9 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "text[5][3]", + Nullable: true, + }, + }, +}