Skip to content

Commit

Permalink
Add pkg/sql2pgroll package
Browse files Browse the repository at this point in the history
  • Loading branch information
andrew-farries committed Dec 3, 2024
1 parent 2fd7105 commit 404fa8a
Show file tree
Hide file tree
Showing 7 changed files with 429 additions and 0 deletions.
47 changes: 47 additions & 0 deletions pkg/sql2pgroll/alter_table.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-License-Identifier: Apache-2.0

package sql2pgroll

import (
pgq "github.com/pganalyze/pg_query_go/v6"
"github.com/xataio/pgroll/pkg/migrations"
)

const PlaceHolderSQL = "TODO: Implement SQL data migration"

// convertAlterTableStmt converts an ALTER TABLE statement to pgroll operations.
func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, error) {
if stmt.Objtype != pgq.ObjectType_OBJECT_TABLE {
return nil, nil
}

var ops migrations.Operations
for _, cmd := range stmt.Cmds {
alterTableCmd := cmd.GetAlterTableCmd()
if alterTableCmd == nil {
continue
}

//nolint:gocritic
switch alterTableCmd.Subtype {
case pgq.AlterTableType_AT_SetNotNull:
ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd))
}
}

return ops, nil
}

func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) migrations.Operation {
return &migrations.OpAlterColumn{
Table: stmt.GetRelation().GetRelname(),
Column: cmd.GetName(),
Nullable: ptr(false),
Up: PlaceHolderSQL,
Down: PlaceHolderSQL,
}
}

func ptr[T any](x T) *T {
return &x
}
41 changes: 41 additions & 0 deletions pkg/sql2pgroll/alter_table_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// SPDX-License-Identifier: Apache-2.0

package sql2pgroll_test

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/sql2pgroll"
"github.com/xataio/pgroll/pkg/sql2pgroll/expect"
)

func TestConvertAlterTableStatements(t *testing.T) {
t.Parallel()

tests := []struct {
sql string
expectedOp migrations.Operation
}{
{
sql: "ALTER TABLE foo ALTER COLUMN a SET NOT NULL",
expectedOp: expect.AlterTableOp1,
},
}

for _, tc := range tests {
t.Run(tc.sql, func(t *testing.T) {
ops, err := sql2pgroll.Convert(tc.sql)
require.NoError(t, err)

require.Len(t, ops, 1)

alterColumnOps, ok := ops[0].(*migrations.OpAlterColumn)
require.True(t, ok)

assert.Equal(t, tc.expectedOp, alterColumnOps)
})
}
}
54 changes: 54 additions & 0 deletions pkg/sql2pgroll/convert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// SPDX-License-Identifier: Apache-2.0

package sql2pgroll

import (
"fmt"

pgq "github.com/pganalyze/pg_query_go/v6"
"github.com/xataio/pgroll/pkg/migrations"
)

var ErrStatementCount = fmt.Errorf("expected exactly one statement")

// Convert converts a SQL statement to a slice of pgroll operations.
func Convert(sql string) (migrations.Operations, error) {
ops, err := convert(sql)
if err != nil {
return nil, err
}

if ops == nil {
return makeRawSQLOperation(sql), nil
}

return ops, nil
}

func convert(sql string) (migrations.Operations, error) {
tree, err := pgq.Parse(sql)
if err != nil {
return nil, fmt.Errorf("parse error: %w", err)
}

stmts := tree.GetStmts()
if len(stmts) != 1 {
return nil, fmt.Errorf("%w: got %d statements", ErrStatementCount, len(stmts))
}
node := stmts[0].GetStmt().GetNode()

switch node := (node).(type) {
case *pgq.Node_CreateStmt:
return convertCreateStmt(node.CreateStmt)
case *pgq.Node_AlterTableStmt:
return convertAlterTableStmt(node.AlterTableStmt)
default:
return makeRawSQLOperation(sql), nil
}
}

func makeRawSQLOperation(sql string) migrations.Operations {
return migrations.Operations{
&migrations.OpRawSQL{Up: sql},
}
}
90 changes: 90 additions & 0 deletions pkg/sql2pgroll/create_table.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// SPDX-License-Identifier: Apache-2.0

package sql2pgroll

import (
"fmt"
"strings"

pgq "github.com/pganalyze/pg_query_go/v6"
"github.com/xataio/pgroll/pkg/migrations"
)

// convertCreateStmt converts a CREATE TABLE statement to a pgroll operation.
func convertCreateStmt(stmt *pgq.CreateStmt) ([]migrations.Operation, error) {
columns := make([]migrations.Column, 0, len(stmt.TableElts))
for _, elt := range stmt.TableElts {
columns = append(columns, convertColumnDef(elt.GetColumnDef()))
}

return migrations.Operations{
&migrations.OpCreateTable{
Name: stmt.Relation.Relname,
Columns: columns,
},
}, nil
}

func convertColumnDef(col *pgq.ColumnDef) migrations.Column {
ignoredTypeParts := map[string]bool{
"pg_catalog": true,
}

// Build the type name, including any schema qualifiers
typeParts := make([]string, 0, len(col.GetTypeName().Names))
for _, node := range col.GetTypeName().Names {
typePart := node.GetString_().GetSval()
if _, ok := ignoredTypeParts[typePart]; ok {
continue
}
typeParts = append(typeParts, typePart)
}

// Build the type modifiers, such as precision and scale for numeric types
var typeMods []string
for _, node := range col.GetTypeName().Typmods {
if x, ok := node.GetAConst().Val.(*pgq.A_Const_Ival); ok {
typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval()))
}
}
var typeModifier string
if len(typeMods) > 0 {
typeModifier = fmt.Sprintf("(%s)", strings.Join(typeMods, ","))
}

// Build the array bounds for array types
var arrayBounds string
for _, node := range col.GetTypeName().ArrayBounds {
bound := node.GetInteger().GetIval()
if bound == -1 {
arrayBounds = "[]"
} else {
arrayBounds = fmt.Sprintf("%s[%d]", arrayBounds, bound)
}
}

// Determine column nullability, uniqueness, and primary key status
var notNull, unique, pk bool
var defaultValue *string
for _, constraint := range col.Constraints {
if constraint.GetConstraint().GetContype() == pgq.ConstrType_CONSTR_NOTNULL {
notNull = true
}
if constraint.GetConstraint().GetContype() == pgq.ConstrType_CONSTR_UNIQUE {
unique = true
}
if constraint.GetConstraint().GetContype() == pgq.ConstrType_CONSTR_PRIMARY {
pk = true
notNull = true
}
}

return migrations.Column{
Name: col.Colname,
Type: strings.Join(typeParts, ".") + typeModifier + arrayBounds,
Nullable: !notNull,
Unique: unique,
Default: defaultValue,
Pk: pk,
}
}
73 changes: 73 additions & 0 deletions pkg/sql2pgroll/create_table_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// SPDX-License-Identifier: Apache-2.0

package sql2pgroll_test

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/sql2pgroll"
"github.com/xataio/pgroll/pkg/sql2pgroll/expect"
)

func TestConvertCreateTableStatements(t *testing.T) {
t.Parallel()

tests := []struct {
sql string
expectedOp migrations.Operation
}{
{
sql: "CREATE TABLE foo(a int)",
expectedOp: expect.CreateTableOp1,
},
{
sql: "CREATE TABLE foo(a int NOT NULL)",
expectedOp: expect.CreateTableOp2,
},
{
sql: "CREATE TABLE foo(a varchar(255))",
expectedOp: expect.CreateTableOp3,
},
{
sql: "CREATE TABLE foo(a numeric(10, 2))",
expectedOp: expect.CreateTableOp4,
},
{
sql: "CREATE TABLE foo(a int UNIQUE)",
expectedOp: expect.CreateTableOp5,
},
{
sql: "CREATE TABLE foo(a int PRIMARY KEY)",
expectedOp: expect.CreateTableOp6,
},
{
sql: "CREATE TABLE foo(a text[])",
expectedOp: expect.CreateTableOp7,
},
{
sql: "CREATE TABLE foo(a text[5])",
expectedOp: expect.CreateTableOp8,
},
{
sql: "CREATE TABLE foo(a text[5][3])",
expectedOp: expect.CreateTableOp9,
},
}

for _, tc := range tests {
t.Run(tc.sql, func(t *testing.T) {
ops, err := sql2pgroll.Convert(tc.sql)
require.NoError(t, err)

require.Len(t, ops, 1)

createTableOp, ok := ops[0].(*migrations.OpCreateTable)
require.True(t, ok)

assert.Equal(t, tc.expectedOp, createTableOp)
})
}
}
20 changes: 20 additions & 0 deletions pkg/sql2pgroll/expect/alter_table.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// SPDX-License-Identifier: Apache-2.0

package expect

import (
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/sql2pgroll"
)

var AlterTableOp1 = &migrations.OpAlterColumn{
Table: "foo",
Column: "a",
Nullable: ptr(false),
Up: sql2pgroll.PlaceHolderSQL,
Down: sql2pgroll.PlaceHolderSQL,
}

func ptr[T any](v T) *T {
return &v
}
Loading

0 comments on commit 404fa8a

Please sign in to comment.