diff --git a/flow/connectors/postgres/postgres_schema_delta_test.go b/flow/connectors/postgres/postgres_schema_delta_test.go index 3ef85cf210..b817e7be51 100644 --- a/flow/connectors/postgres/postgres_schema_delta_test.go +++ b/flow/connectors/postgres/postgres_schema_delta_test.go @@ -3,84 +3,70 @@ package connpostgres import ( "context" "fmt" + "strings" "testing" "github.com/PeerDB-io/peer-flow/connectors/utils" + "github.com/PeerDB-io/peer-flow/e2eshared" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model/qvalue" + "github.com/PeerDB-io/peer-flow/shared" "github.com/jackc/pgx/v5" - "github.com/stretchr/testify/suite" + "github.com/stretchr/testify/require" + "github.com/ysmood/got" ) type PostgresSchemaDeltaTestSuite struct { - suite.Suite + got.G + t *testing.T connector *PostgresConnector + schema string } -const schemaDeltaTestSchemaName = "pgschema_delta_test" +func SetupSuite(t *testing.T, g got.G) PostgresSchemaDeltaTestSuite { + t.Helper() -func (suite *PostgresSchemaDeltaTestSuite) failTestError(err error) { - if err != nil { - suite.FailNow(err.Error()) - } -} - -func (suite *PostgresSchemaDeltaTestSuite) SetupSuite() { - var err error - suite.connector, err = NewPostgresConnector(context.Background(), &protos.PostgresConfig{ + connector, err := NewPostgresConnector(context.Background(), &protos.PostgresConfig{ Host: "localhost", Port: 7132, User: "postgres", Password: "postgres", Database: "postgres", }, false) - suite.failTestError(err) + require.NoError(t, err) - setupTx, err := suite.connector.pool.Begin(context.Background()) - suite.failTestError(err) + setupTx, err := connector.pool.Begin(context.Background()) + require.NoError(t, err) defer func() { err := setupTx.Rollback(context.Background()) if err != pgx.ErrTxClosed { - suite.failTestError(err) + require.NoError(t, err) } }() + schema := "pgdelta_" + strings.ToLower(shared.RandomString(8)) _, err = setupTx.Exec(context.Background(), fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", - schemaDeltaTestSchemaName)) - suite.failTestError(err) - _, err = setupTx.Exec(context.Background(), fmt.Sprintf("CREATE SCHEMA %s", schemaDeltaTestSchemaName)) - suite.failTestError(err) + schema)) + require.NoError(t, err) + _, err = setupTx.Exec(context.Background(), fmt.Sprintf("CREATE SCHEMA %s", schema)) + require.NoError(t, err) err = setupTx.Commit(context.Background()) - suite.failTestError(err) -} - -func (suite *PostgresSchemaDeltaTestSuite) TearDownSuite() { - teardownTx, err := suite.connector.pool.Begin(context.Background()) - suite.failTestError(err) - defer func() { - err := teardownTx.Rollback(context.Background()) - if err != pgx.ErrTxClosed { - suite.failTestError(err) - } - }() - _, err = teardownTx.Exec(context.Background(), fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", - schemaDeltaTestSchemaName)) - suite.failTestError(err) - err = teardownTx.Commit(context.Background()) - suite.failTestError(err) + require.NoError(t, err) - suite.True(suite.connector.ConnectionActive() == nil) - err = suite.connector.Close() - suite.failTestError(err) - suite.False(suite.connector.ConnectionActive() == nil) + return PostgresSchemaDeltaTestSuite{ + G: g, + t: t, + connector: connector, + schema: schema, + } } -func (suite *PostgresSchemaDeltaTestSuite) TestSimpleAddColumn() { - tableName := fmt.Sprintf("%s.simple_add_column", schemaDeltaTestSchemaName) - _, err := suite.connector.pool.Exec(context.Background(), +func (s PostgresSchemaDeltaTestSuite) TestSimpleAddColumn() { + tableName := fmt.Sprintf("%s.simple_add_column", s.schema) + _, err := s.connector.pool.Exec(context.Background(), fmt.Sprintf("CREATE TABLE %s(id INT PRIMARY KEY)", tableName)) - suite.failTestError(err) + require.NoError(s.t, err) - err = suite.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: []*protos.DeltaAddedColumn{{ @@ -88,13 +74,13 @@ func (suite *PostgresSchemaDeltaTestSuite) TestSimpleAddColumn() { ColumnType: string(qvalue.QValueKindInt64), }}, }}) - suite.failTestError(err) + require.NoError(s.t, err) - output, err := suite.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err := s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{tableName}, }) - suite.failTestError(err) - suite.Equal(&protos.TableSchema{ + require.NoError(s.t, err) + s.Equal(&protos.TableSchema{ TableIdentifier: tableName, ColumnNames: []string{"id", "hi"}, ColumnTypes: []string{string(qvalue.QValueKindInt32), string(qvalue.QValueKindInt64)}, @@ -102,11 +88,11 @@ func (suite *PostgresSchemaDeltaTestSuite) TestSimpleAddColumn() { }, output.TableNameSchemaMapping[tableName]) } -func (suite *PostgresSchemaDeltaTestSuite) TestAddAllColumnTypes() { - tableName := fmt.Sprintf("%s.add_drop_all_column_types", schemaDeltaTestSchemaName) - _, err := suite.connector.pool.Exec(context.Background(), +func (s PostgresSchemaDeltaTestSuite) TestAddAllColumnTypes() { + tableName := fmt.Sprintf("%s.add_drop_all_column_types", s.schema) + _, err := s.connector.pool.Exec(context.Background(), fmt.Sprintf("CREATE TABLE %s(id INT PRIMARY KEY)", tableName)) - suite.failTestError(err) + require.NoError(s.t, err) expectedTableSchema := &protos.TableSchema{ TableIdentifier: tableName, @@ -146,25 +132,25 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddAllColumnTypes() { } }) - err = suite.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, }}) - suite.failTestError(err) + require.NoError(s.t, err) - output, err := suite.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err := s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{tableName}, }) - suite.failTestError(err) - suite.Equal(expectedTableSchema, output.TableNameSchemaMapping[tableName]) + require.NoError(s.t, err) + s.Equal(expectedTableSchema, output.TableNameSchemaMapping[tableName]) } -func (suite *PostgresSchemaDeltaTestSuite) TestAddTrickyColumnNames() { - tableName := fmt.Sprintf("%s.add_drop_tricky_column_names", schemaDeltaTestSchemaName) - _, err := suite.connector.pool.Exec(context.Background(), +func (s PostgresSchemaDeltaTestSuite) TestAddTrickyColumnNames() { + tableName := fmt.Sprintf("%s.add_drop_tricky_column_names", s.schema) + _, err := s.connector.pool.Exec(context.Background(), fmt.Sprintf("CREATE TABLE %s(id INT PRIMARY KEY)", tableName)) - suite.failTestError(err) + require.NoError(s.t, err) expectedTableSchema := &protos.TableSchema{ TableIdentifier: tableName, @@ -196,25 +182,25 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddTrickyColumnNames() { } }) - err = suite.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, }}) - suite.failTestError(err) + require.NoError(s.t, err) - output, err := suite.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err := s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{tableName}, }) - suite.failTestError(err) - suite.Equal(expectedTableSchema, output.TableNameSchemaMapping[tableName]) + require.NoError(s.t, err) + s.Equal(expectedTableSchema, output.TableNameSchemaMapping[tableName]) } -func (suite *PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() { - tableName := fmt.Sprintf("%s.add_drop_whitespace_column_names", schemaDeltaTestSchemaName) - _, err := suite.connector.pool.Exec(context.Background(), +func (s PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() { + tableName := fmt.Sprintf("%s.add_drop_whitespace_column_names", s.schema) + _, err := s.connector.pool.Exec(context.Background(), fmt.Sprintf("CREATE TABLE %s(\" \" INT PRIMARY KEY)", tableName)) - suite.failTestError(err) + require.NoError(s.t, err) expectedTableSchema := &protos.TableSchema{ TableIdentifier: tableName, @@ -237,20 +223,39 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() { } }) - err = suite.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, }}) - suite.failTestError(err) + require.NoError(s.t, err) - output, err := suite.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err := s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{tableName}, }) - suite.failTestError(err) - suite.Equal(expectedTableSchema, output.TableNameSchemaMapping[tableName]) + require.NoError(s.t, err) + s.Equal(expectedTableSchema, output.TableNameSchemaMapping[tableName]) } func TestPostgresSchemaDeltaTestSuite(t *testing.T) { - suite.Run(t, new(PostgresSchemaDeltaTestSuite)) + e2eshared.GotSuite(t, SetupSuite, func(s PostgresSchemaDeltaTestSuite) { + teardownTx, err := s.connector.pool.Begin(context.Background()) + require.NoError(s.t, err) + defer func() { + err := teardownTx.Rollback(context.Background()) + if err != pgx.ErrTxClosed { + require.NoError(s.t, err) + } + }() + _, err = teardownTx.Exec(context.Background(), fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", + s.schema)) + require.NoError(s.t, err) + err = teardownTx.Commit(context.Background()) + require.NoError(s.t, err) + + require.True(s.t, s.connector.ConnectionActive() == nil) + err = s.connector.Close() + require.NoError(s.t, err) + require.False(s.t, s.connector.ConnectionActive() == nil) + }) }