Skip to content

Commit

Permalink
Add a sql2pgroll package to convert SQL to pgroll migrations (#502)
Browse files Browse the repository at this point in the history
Add a `sql2pgroll` package to convert SQL to `pgroll` migrations.

Add a (hidden for now) `pgroll sql` command that uses the package to
convert SQL strings on the command line to `pgroll` migrations.

The `sql2pgroll` package is incomplete, with almost all SQL falling back
to conversion using raw SQL migrations. Only some `CREATE TABLE`
statements and the `ALTER TABLE ... ALTER COLUMN ... SET NOT NULL`
statement are currently handled.

```bash
$ pgroll sql "create table foo(a serial primary key, b text unique)" 
```

```json
[
  {
    "create_table": {
      "columns": [
        {
          "name": "a",
          "pk": true,
          "type": "serial"
        },
        {
          "name": "b",
          "nullable": true,
          "type": "text",
          "unique": true
        }
      ],
      "name": "foo"
    }
  }
]
```

Part of #504
  • Loading branch information
andrew-farries authored Dec 3, 2024
1 parent 43a6d97 commit 1f6f49b
Show file tree
Hide file tree
Showing 11 changed files with 476 additions and 0 deletions.
1 change: 1 addition & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ func Execute() error {
rootCmd.AddCommand(migrateCmd())
rootCmd.AddCommand(pullCmd())
rootCmd.AddCommand(latestCmd())
rootCmd.AddCommand(sqlCmd())

return rootCmd.Execute()
}
39 changes: 39 additions & 0 deletions cmd/sql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// SPDX-License-Identifier: Apache-2.0

package cmd

import (
"encoding/json"
"fmt"
"os"

"github.com/spf13/cobra"
"github.com/xataio/pgroll/pkg/sql2pgroll"
)

// sqlCmd builds the hidden `sql` subcommand. The command takes exactly one
// argument, converts it with sql2pgroll.Convert and writes the resulting
// operations to standard output as indented JSON.
func sqlCmd() *cobra.Command {
	// run converts the single positional argument and prints the result.
	run := func(_ *cobra.Command, args []string) error {
		ops, err := sql2pgroll.Convert(args[0])
		if err != nil {
			return fmt.Errorf("failed to convert SQL statement: %w", err)
		}

		encoder := json.NewEncoder(os.Stdout)
		encoder.SetIndent("", " ")
		if err := encoder.Encode(ops); err != nil {
			return fmt.Errorf("failed to encode operations: %w", err)
		}

		return nil
	}

	return &cobra.Command{
		Use:    "sql <sql statement>",
		Short:  "Convert SQL statements to pgroll operations",
		Args:   cobra.ExactArgs(1),
		Hidden: true,
		RunE:   run,
	}
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pganalyze/pg_query_go/v6 v6.0.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
Expand Down Expand Up @@ -140,6 +142,8 @@ github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQ
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pganalyze/pg_query_go/v6 v6.0.0 h1:in6RkR/apfqlAtvqgDxd4Y4o87a5Pr8fkKDB4DrDo2c=
github.com/pganalyze/pg_query_go/v6 v6.0.0/go.mod h1:nvTHIuoud6e1SfrUaFwHqT0i4b5Nr+1rPWVds3B5+50=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
Expand Down Expand Up @@ -325,6 +329,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20240730163845-b1a4ccb954bf h1:
google.golang.org/genproto/googleapis/rpc v0.0.0-20240730163845-b1a4ccb954bf/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY=
google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc=
google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
47 changes: 47 additions & 0 deletions pkg/sql2pgroll/alter_table.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-License-Identifier: Apache-2.0

package sql2pgroll

import (
pgq "github.com/pganalyze/pg_query_go/v6"
"github.com/xataio/pgroll/pkg/migrations"
)

// PlaceHolderSQL is the text written into the Up and Down fields of converted
// operations (see convertAlterTableSetNotNull). It marks SQL data migrations
// that still have to be written by hand.
const PlaceHolderSQL = "TODO: Implement SQL data migration"

// convertAlterTableStmt converts an ALTER TABLE statement to pgroll operations.
//
// It returns nil operations (and a nil error) when the statement cannot be
// converted in full: either it does not target a table, or it contains a
// subcommand other than ALTER COLUMN ... SET NOT NULL. Returning nil rather
// than a partial result matters for multi-command statements such as
//
//	ALTER TABLE foo ALTER COLUMN a SET NOT NULL, ADD COLUMN b int
//
// where converting only the supported subcommand would silently drop the
// rest of the statement.
func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, error) {
	if stmt.GetObjtype() != pgq.ObjectType_OBJECT_TABLE {
		return nil, nil
	}

	var ops migrations.Operations
	for _, cmd := range stmt.GetCmds() {
		alterTableCmd := cmd.GetAlterTableCmd()
		if alterTableCmd == nil {
			continue
		}

		switch alterTableCmd.GetSubtype() {
		case pgq.AlterTableType_AT_SetNotNull:
			ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd))
		default:
			// Unsupported subcommand: give up on the whole statement so that
			// no part of it is lost.
			return nil, nil
		}
	}

	return ops, nil
}

// convertAlterTableSetNotNull builds the alter-column operation for an
// ALTER COLUMN ... SET NOT NULL subcommand. The table name comes from the
// statement's relation and the column name from the subcommand; the up and
// down migrations are filled with PlaceHolderSQL.
func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) migrations.Operation {
	tableName := stmt.GetRelation().GetRelname()
	columnName := cmd.GetName()

	op := &migrations.OpAlterColumn{
		Table:    tableName,
		Column:   columnName,
		Nullable: ptr(false),
		Up:       PlaceHolderSQL,
		Down:     PlaceHolderSQL,
	}
	return op
}

// ptr returns a pointer to a freshly allocated copy of x.
func ptr[T any](x T) *T {
	p := new(T)
	*p = x
	return p
}
41 changes: 41 additions & 0 deletions pkg/sql2pgroll/alter_table_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// SPDX-License-Identifier: Apache-2.0

package sql2pgroll_test

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/sql2pgroll"
"github.com/xataio/pgroll/pkg/sql2pgroll/expect"
)

// TestConvertAlterTableStatements checks that each ALTER TABLE statement is
// converted to exactly one OpAlterColumn equal to the expected operation.
func TestConvertAlterTableStatements(t *testing.T) {
	t.Parallel()

	type testCase struct {
		sql        string
		expectedOp migrations.Operation
	}

	cases := []testCase{
		{"ALTER TABLE foo ALTER COLUMN a SET NOT NULL", expect.AlterTableOp1},
	}

	for _, c := range cases {
		t.Run(c.sql, func(t *testing.T) {
			converted, err := sql2pgroll.Convert(c.sql)
			require.NoError(t, err)
			require.Len(t, converted, 1)

			got, isAlterColumn := converted[0].(*migrations.OpAlterColumn)
			require.True(t, isAlterColumn)

			assert.Equal(t, c.expectedOp, got)
		})
	}
}
54 changes: 54 additions & 0 deletions pkg/sql2pgroll/convert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// SPDX-License-Identifier: Apache-2.0

package sql2pgroll

import (
"fmt"

pgq "github.com/pganalyze/pg_query_go/v6"
"github.com/xataio/pgroll/pkg/migrations"
)

// ErrStatementCount is wrapped in the error returned by convert when the
// parsed input does not contain exactly one statement.
var ErrStatementCount = fmt.Errorf("expected exactly one statement")

// Convert converts a SQL statement to a slice of pgroll operations.
//
// Errors from the underlying conversion are returned unchanged. When the
// conversion yields nil operations, the statement is wrapped in a single raw
// SQL operation instead.
func Convert(sql string) (migrations.Operations, error) {
	switch converted, err := convert(sql); {
	case err != nil:
		return nil, err
	case converted == nil:
		return makeRawSQLOperation(sql), nil
	default:
		return converted, nil
	}
}

// convert parses sql and dispatches on the type of its single statement.
// CREATE TABLE and ALTER TABLE statements go to their dedicated converters;
// every other statement type becomes a raw SQL operation. Input that fails
// to parse, or that does not contain exactly one statement, is an error.
func convert(sql string) (migrations.Operations, error) {
	parsed, err := pgq.Parse(sql)
	if err != nil {
		return nil, fmt.Errorf("parse error: %w", err)
	}

	statements := parsed.GetStmts()
	if count := len(statements); count != 1 {
		return nil, fmt.Errorf("%w: got %d statements", ErrStatementCount, count)
	}

	switch n := statements[0].GetStmt().GetNode().(type) {
	case *pgq.Node_CreateStmt:
		return convertCreateStmt(n.CreateStmt)
	case *pgq.Node_AlterTableStmt:
		return convertAlterTableStmt(n.AlterTableStmt)
	}

	return makeRawSQLOperation(sql), nil
}

// makeRawSQLOperation wraps sql in a one-element operation list holding a
// raw SQL operation whose Up migration is the statement itself.
func makeRawSQLOperation(sql string) migrations.Operations {
	rawOp := &migrations.OpRawSQL{Up: sql}
	return migrations.Operations{rawOp}
}
90 changes: 90 additions & 0 deletions pkg/sql2pgroll/create_table.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// SPDX-License-Identifier: Apache-2.0

package sql2pgroll

import (
"fmt"
"strings"

pgq "github.com/pganalyze/pg_query_go/v6"
"github.com/xataio/pgroll/pkg/migrations"
)

// convertCreateStmt converts a CREATE TABLE statement to a pgroll operation.
//
// Only plain column definitions are understood. If the statement contains
// any other table element (for example a table-level constraint such as
// PRIMARY KEY (a)), nil operations and a nil error are returned so that the
// caller can fall back to a raw SQL operation. Previously such an element
// reached convertColumnDef as a nil column definition and caused a
// nil-pointer panic.
func convertCreateStmt(stmt *pgq.CreateStmt) ([]migrations.Operation, error) {
	elts := stmt.GetTableElts()

	columns := make([]migrations.Column, 0, len(elts))
	for _, elt := range elts {
		colDef := elt.GetColumnDef()
		if colDef == nil {
			// Not a column definition: this statement cannot be converted.
			return nil, nil
		}
		columns = append(columns, convertColumnDef(colDef))
	}

	return migrations.Operations{
		&migrations.OpCreateTable{
			Name:    stmt.GetRelation().GetRelname(),
			Columns: columns,
		},
	}, nil
}

// convertColumnDef converts a parsed column definition into a pgroll column.
//
// The column type is rendered as the dot-joined type name (with the
// pg_catalog qualifier omitted), followed by any integer type modifiers in
// parentheses and then the array bounds. Nullability, uniqueness and primary
// key status are taken from the column constraints; a primary key column is
// always reported as not nullable. Default values are not converted yet, so
// the Default field is left unset.
//
// All parse-tree fields are read through the generated nil-safe getters, so
// a nil column, a missing type name or a non-constant type modifier gives
// empty values rather than a nil-pointer panic.
func convertColumnDef(col *pgq.ColumnDef) migrations.Column {
	typeName := col.GetTypeName()

	// Build the type name, including any schema qualifiers.
	names := typeName.GetNames()
	typeParts := make([]string, 0, len(names))
	for _, node := range names {
		part := node.GetString_().GetSval()
		if part == "pg_catalog" {
			continue
		}
		typeParts = append(typeParts, part)
	}

	// Build the type modifiers, such as precision and scale for numeric
	// types. Modifiers that are not integer constants are skipped.
	var typeMods []string
	for _, node := range typeName.GetTypmods() {
		if x, ok := node.GetAConst().GetVal().(*pgq.A_Const_Ival); ok {
			typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval()))
		}
	}
	var typeModifier string
	if len(typeMods) > 0 {
		typeModifier = "(" + strings.Join(typeMods, ",") + ")"
	}

	// Build the array bounds for array types. A bound of -1 is an unbounded
	// dimension. Each dimension is appended, so that multi-dimensional types
	// such as text[][] or text[5][] keep all of their dimensions.
	var arrayBounds strings.Builder
	for _, node := range typeName.GetArrayBounds() {
		if bound := node.GetInteger().GetIval(); bound == -1 {
			arrayBounds.WriteString("[]")
		} else {
			fmt.Fprintf(&arrayBounds, "[%d]", bound)
		}
	}

	// Determine column nullability, uniqueness, and primary key status.
	var notNull, unique, pk bool
	for _, constraint := range col.GetConstraints() {
		switch constraint.GetConstraint().GetContype() {
		case pgq.ConstrType_CONSTR_NOTNULL:
			notNull = true
		case pgq.ConstrType_CONSTR_UNIQUE:
			unique = true
		case pgq.ConstrType_CONSTR_PRIMARY:
			pk = true
			notNull = true
		}
	}

	return migrations.Column{
		Name:     col.GetColname(),
		Type:     strings.Join(typeParts, ".") + typeModifier + arrayBounds.String(),
		Nullable: !notNull,
		Unique:   unique,
		Pk:       pk,
	}
}
73 changes: 73 additions & 0 deletions pkg/sql2pgroll/create_table_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// SPDX-License-Identifier: Apache-2.0

package sql2pgroll_test

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/sql2pgroll"
"github.com/xataio/pgroll/pkg/sql2pgroll/expect"
)

// TestConvertCreateTableStatements checks that each CREATE TABLE statement is
// converted to exactly one OpCreateTable equal to the expected operation.
func TestConvertCreateTableStatements(t *testing.T) {
	t.Parallel()

	type testCase struct {
		sql        string
		expectedOp migrations.Operation
	}

	cases := []testCase{
		{"CREATE TABLE foo(a int)", expect.CreateTableOp1},
		{"CREATE TABLE foo(a int NOT NULL)", expect.CreateTableOp2},
		{"CREATE TABLE foo(a varchar(255))", expect.CreateTableOp3},
		{"CREATE TABLE foo(a numeric(10, 2))", expect.CreateTableOp4},
		{"CREATE TABLE foo(a int UNIQUE)", expect.CreateTableOp5},
		{"CREATE TABLE foo(a int PRIMARY KEY)", expect.CreateTableOp6},
		{"CREATE TABLE foo(a text[])", expect.CreateTableOp7},
		{"CREATE TABLE foo(a text[5])", expect.CreateTableOp8},
		{"CREATE TABLE foo(a text[5][3])", expect.CreateTableOp9},
	}

	for _, c := range cases {
		t.Run(c.sql, func(t *testing.T) {
			converted, err := sql2pgroll.Convert(c.sql)
			require.NoError(t, err)
			require.Len(t, converted, 1)

			got, isCreateTable := converted[0].(*migrations.OpCreateTable)
			require.True(t, isCreateTable)

			assert.Equal(t, c.expectedOp, got)
		})
	}
}
Loading

0 comments on commit 1f6f49b

Please sign in to comment.