diff --git a/flow/e2e/sqlserver/qrep_flow_sqlserver_test.go b/flow/e2e/sqlserver/qrep_flow_sqlserver_test.go index 5486583d86..3c07a67691 100644 --- a/flow/e2e/sqlserver/qrep_flow_sqlserver_test.go +++ b/flow/e2e/sqlserver/qrep_flow_sqlserver_test.go @@ -5,50 +5,47 @@ import ( "fmt" "log/slog" "os" + "strings" "testing" "time" "github.com/PeerDB-io/peer-flow/e2e" + "github.com/PeerDB-io/peer-flow/e2eshared" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" + "github.com/PeerDB-io/peer-flow/shared" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" "github.com/joho/godotenv" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" - "go.temporal.io/sdk/testsuite" + "github.com/ysmood/got" ) -const sqlserverSuffix = "sqlserver" - type PeerFlowE2ETestSuiteSQLServer struct { - suite.Suite - testsuite.WorkflowTestSuite + got.G + t *testing.T pool *pgxpool.Pool sqlsHelper *SQLServerHelper + suffix string } func TestCDCFlowE2ETestSuiteSQLServer(t *testing.T) { - suite.Run(t, new(PeerFlowE2ETestSuiteSQLServer)) -} - -// setup sql server connection -func (s *PeerFlowE2ETestSuiteSQLServer) setupSQLServer() { - env := os.Getenv("ENABLE_SQLSERVER_TESTS") - if env != "true" { - s.sqlsHelper = nil - return - } + e2eshared.GotSuite(t, SetupSuite, func(s PeerFlowE2ETestSuiteSQLServer) { + err := e2e.TearDownPostgres(s.pool, s.suffix) + require.NoError(s.t, err) - sqlsHelper, err := NewSQLServerHelper("test_sqlserver_peer") - require.NoError(s.T(), err) - s.sqlsHelper = sqlsHelper + if s.sqlsHelper != nil { + err = s.sqlsHelper.CleanUp() + require.NoError(s.t, err) + } + }) } -func (s *PeerFlowE2ETestSuiteSQLServer) SetupSuite() { +func SetupSuite(t *testing.T, g got.G) PeerFlowE2ETestSuiteSQLServer { + t.Helper() + err := godotenv.Load() if err != nil { // it's okay if the .env file is not present @@ -56,37 +53,36 @@ func (s *PeerFlowE2ETestSuiteSQLServer) SetupSuite() { slog.Info("Unable to load .env file, using default values from env") } - pool, err := e2e.SetupPostgres(sqlserverSuffix) - if err != nil || pool == nil { - s.Fail("failed to setup postgres", err) + suffix := "sqls_" + strings.ToLower(shared.RandomString(8)) + pool, err := e2e.SetupPostgres(suffix) + if err != nil { + require.NoError(t, err) } - s.pool = pool - - s.setupSQLServer() -} -// Implement TearDownAllSuite interface to tear down the test suite -func (s *PeerFlowE2ETestSuiteSQLServer) TearDownSuite() { - err := e2e.TearDownPostgres(s.pool, sqlserverSuffix) - if err != nil { - s.Fail("failed to drop Postgres schema", err) + var sqlsHelper *SQLServerHelper + env := os.Getenv("ENABLE_SQLSERVER_TESTS") + if env != "true" { + sqlsHelper = nil + } else { + sqlsHelper, err = NewSQLServerHelper("test_sqlserver_peer") + require.NoError(t, err) } - if s.sqlsHelper != nil { - err = s.sqlsHelper.CleanUp() - if err != nil { - s.Fail("failed to clean up sqlserver", err) - } + return PeerFlowE2ETestSuiteSQLServer{ + G: g, + t: t, + pool: pool, + sqlsHelper: sqlsHelper, } } -func (s *PeerFlowE2ETestSuiteSQLServer) setupSQLServerTable(tableName string) { +func (s PeerFlowE2ETestSuiteSQLServer) setupSQLServerTable(tableName string) { schema := getSimpleTableSchema() err := s.sqlsHelper.CreateTable(schema, tableName) - require.NoError(s.T(), err) + require.NoError(s.t, err) } -func (s *PeerFlowE2ETestSuiteSQLServer) insertRowsIntoSQLServerTable(tableName string, numRows int) { +func (s PeerFlowE2ETestSuiteSQLServer) insertRowsIntoSQLServerTable(tableName string, numRows int) { schemaQualified := fmt.Sprintf("%s.%s", s.sqlsHelper.SchemaName, tableName) for i := 0; i < numRows; i++ { params := make(map[string]interface{}) @@ -101,20 +97,20 @@ func (s *PeerFlowE2ETestSuiteSQLServer) insertRowsIntoSQLServerTable(tableName s params, ) - require.NoError(s.T(), err) + require.NoError(s.t, err) } } -func (s *PeerFlowE2ETestSuiteSQLServer) setupPGDestinationTable(tableName string) { +func (s PeerFlowE2ETestSuiteSQLServer) setupPGDestinationTable(tableName string) { ctx := context.Background() - _, err := s.pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS e2e_test_%s.%s", sqlserverSuffix, tableName)) - require.NoError(s.T(), err) + _, err := s.pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS e2e_test_%s.%s", s.suffix, tableName)) + require.NoError(s.t, err) _, err = s.pool.Exec(ctx, fmt.Sprintf("CREATE TABLE e2e_test_%s.%s (id TEXT, card_id TEXT, v_from TIMESTAMP, price NUMERIC, status INT)", - sqlserverSuffix, tableName)) - require.NoError(s.T(), err) + s.suffix, tableName)) + require.NoError(s.t, err) } func getSimpleTableSchema() *model.QRecordSchema { @@ -129,13 +125,13 @@ func getSimpleTableSchema() *model.QRecordSchema { } } -func (s *PeerFlowE2ETestSuiteSQLServer) Test_Complete_QRep_Flow_SqlServer_Append() { +func (s PeerFlowE2ETestSuiteSQLServer) Test_Complete_QRep_Flow_SqlServer_Append() { if s.sqlsHelper == nil { - s.T().Skip("Skipping SQL Server test") + s.t.Skip("Skipping SQL Server test") } - env := s.NewTestWorkflowEnvironment() - e2e.RegisterWorkflowsAndActivities(s.T(), env) + env := e2e.NewTemporalTestWorkflowEnvironment() + e2e.RegisterWorkflowsAndActivities(s.t, env) numRows := 10 tblName := "test_qrep_flow_avro_ss_append" @@ -145,7 +141,7 @@ func (s *PeerFlowE2ETestSuiteSQLServer) Test_Complete_QRep_Flow_SqlServer_Append s.insertRowsIntoSQLServerTable(tblName, numRows) s.setupPGDestinationTable(tblName) - dstTableName := fmt.Sprintf("e2e_test_%s.%s", sqlserverSuffix, tblName) + dstTableName := fmt.Sprintf("e2e_test_%s.%s", s.suffix, tblName) query := fmt.Sprintf("SELECT * FROM %s.%s WHERE v_from BETWEEN {{.start}} AND {{.end}}", s.sqlsHelper.SchemaName, tblName) @@ -172,13 +168,13 @@ func (s *PeerFlowE2ETestSuiteSQLServer) Test_Complete_QRep_Flow_SqlServer_Append s.True(env.IsWorkflowCompleted()) err := env.GetWorkflowError() - s.NoError(err) + require.NoError(s.t, err) // Verify that the destination table has the same number of rows as the source table var numRowsInDest pgtype.Int8 countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", dstTableName) err = s.pool.QueryRow(context.Background(), countQuery).Scan(&numRowsInDest) - s.NoError(err) + require.NoError(s.t, err) s.Equal(numRows, int(numRowsInDest.Int64)) }