diff --git a/flow/e2e/bigquery/qrep_flow_bq_test.go b/flow/e2e/bigquery/qrep_flow_bq_test.go
index da1bade8bd..ddfd8b6373 100644
--- a/flow/e2e/bigquery/qrep_flow_bq_test.go
+++ b/flow/e2e/bigquery/qrep_flow_bq_test.go
@@ -1,10 +1,8 @@
 package e2e_bigquery
 
 import (
-	"context"
 	"fmt"
 
-	connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
 	"github.com/PeerDB-io/peer-flow/e2e"
 	"github.com/stretchr/testify/require"
 )
@@ -17,13 +15,7 @@ func (s PeerFlowE2ETestSuiteBQ) setupSourceTable(tableName string, rowCount int)
 }
 
 func (s PeerFlowE2ETestSuiteBQ) compareTableContentsBQ(tableName string, colsString string) {
-	// read rows from source table
-	pgQueryExecutor := connpostgres.NewQRepQueryExecutor(s.pool, context.Background(), "testflow", "testpart")
-	pgQueryExecutor.SetTestEnv(true)
-
-	pgRows, err := pgQueryExecutor.ExecuteAndProcessQuery(
-		fmt.Sprintf("SELECT %s FROM e2e_test_%s.%s ORDER BY id", colsString, s.bqSuffix, tableName),
-	)
+	pgRows, err := e2e.GetPgRows(s.pool, s.bqSuffix, tableName, colsString)
 	require.NoError(s.t, err)
 
 	// read rows from destination table
@@ -33,7 +25,7 @@ func (s PeerFlowE2ETestSuiteBQ) compareTableContentsBQ(tableName string, colsStr
 	bqRows, err := s.bqHelper.ExecuteAndProcessQuery(bqSelQuery)
 	require.NoError(s.t, err)
 
-	e2e.RequireEqualRecordBatchs(s.t, pgRows, bqRows)
+	e2e.RequireEqualRecordBatches(s.t, pgRows, bqRows)
 }
 
 func (s PeerFlowE2ETestSuiteBQ) Test_Complete_QRep_Flow_Avro() {
diff --git a/flow/e2e/snowflake/qrep_flow_sf_test.go b/flow/e2e/snowflake/qrep_flow_sf_test.go
index c9303ec2d6..e48d4d62a6 100644
--- a/flow/e2e/snowflake/qrep_flow_sf_test.go
+++ b/flow/e2e/snowflake/qrep_flow_sf_test.go
@@ -1,10 +1,8 @@
 package e2e_snowflake
 
 import (
-	"context"
 	"fmt"
 
-	connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
 	"github.com/PeerDB-io/peer-flow/e2e"
 	"github.com/PeerDB-io/peer-flow/generated/protos"
 	"github.com/google/uuid"
@@ -19,10 +17,6 @@ func (s PeerFlowE2ETestSuiteSF) setupSourceTable(tableName string, numRows int)
 	require.NoError(s.t, err)
 }
 
-func (s PeerFlowE2ETestSuiteSF) compareTableContentsSF(tableName, selector string) {
-	s.compareTableContentsWithDiffSelectorsSF(tableName, selector, selector, false)
-}
-
 func (s PeerFlowE2ETestSuiteSF) checkJSONValue(tableName, colName, fieldName, value string) error {
 	res, err := s.sfHelper.ExecuteAndProcessQuery(fmt.Sprintf(
 		"SELECT %s:%s FROM %s;",
@@ -38,19 +32,17 @@ func (s PeerFlowE2ETestSuiteSF) checkJSONValue(tableName, colName, fieldName, va
 	return nil
 }
 
+func (s PeerFlowE2ETestSuiteSF) compareTableContentsSF(tableName, selector string) {
+	s.compareTableContentsWithDiffSelectorsSF(tableName, selector, selector, false)
+}
+
 func (s PeerFlowE2ETestSuiteSF) compareTableContentsWithDiffSelectorsSF(tableName, pgSelector, sfSelector string,
 	tableCaseSensitive bool,
 ) {
-	// read rows from source table
-	pgQueryExecutor := connpostgres.NewQRepQueryExecutor(s.pool, context.Background(), "testflow", "testpart")
-	pgQueryExecutor.SetTestEnv(true)
-	pgRows, err := pgQueryExecutor.ExecuteAndProcessQuery(
-		fmt.Sprintf(`SELECT %s FROM e2e_test_%s."%s" ORDER BY id`, pgSelector, s.pgSuffix, tableName),
-	)
+	pgRows, err := e2e.GetPgRows(s.pool, s.pgSuffix, tableName, pgSelector)
 	require.NoError(s.t, err)
 
 	// read rows from destination table
-
 	var qualifiedTableName string
 	if tableCaseSensitive {
 		qualifiedTableName = fmt.Sprintf(`%s.%s."%s"`, s.sfHelper.testDatabaseName, s.sfHelper.testSchemaName, tableName)
@@ -59,11 +51,11 @@ func (s PeerFlowE2ETestSuiteSF) compareTableContentsWithDiffSelectorsSF(tableNam
 	}
 
 	sfSelQuery := fmt.Sprintf(`SELECT %s FROM %s ORDER BY id`, sfSelector, qualifiedTableName)
-	s.t.Logf("running query on snowflake: %s\n", sfSelQuery)
+	s.t.Logf("running query on snowflake: %s", sfSelQuery)
 	sfRows, err := s.sfHelper.ExecuteAndProcessQuery(sfSelQuery)
 	require.NoError(s.t, err)
 
-	e2e.RequireEqualRecordBatchs(s.t, pgRows, sfRows)
+	e2e.RequireEqualRecordBatches(s.t, pgRows, sfRows)
 }
 
 func (s PeerFlowE2ETestSuiteSF) Test_Complete_QRep_Flow_Avro_SF() {
diff --git a/flow/e2e/test_utils.go b/flow/e2e/test_utils.go
index 795d191ab6..c2337f21d0 100644
--- a/flow/e2e/test_utils.go
+++ b/flow/e2e/test_utils.go
@@ -11,6 +11,7 @@ import (
 	"time"
 
 	"github.com/PeerDB-io/peer-flow/activities"
+	connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
 	connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake"
 	utils "github.com/PeerDB-io/peer-flow/connectors/utils/catalog"
 	"github.com/PeerDB-io/peer-flow/e2eshared"
@@ -22,6 +23,7 @@ import (
 	peerflow "github.com/PeerDB-io/peer-flow/workflows"
 	"github.com/google/uuid"
 	"github.com/jackc/pgx/v5/pgxpool"
+	"github.com/stretchr/testify/require"
 	"go.temporal.io/sdk/testsuite"
 )
 
@@ -57,6 +59,19 @@ func RegisterWorkflowsAndActivities(t *testing.T, env *testsuite.TestWorkflowEnv
 	env.RegisterActivity(&activities.SnapshotActivity{})
 }
 
+// GetPgRows selects cols from e2e_test_<suffix>."<tableName>" ordered by id and
+// returns the rows as a QRecordBatch. cols is interpolated into the query verbatim.
+// The table name is double-quoted, as the Snowflake suite's query was before it
+// moved here, so that mixed-case table names are not folded to lowercase.
+func GetPgRows(pool *pgxpool.Pool, suffix string, tableName string, cols string) (*model.QRecordBatch, error) {
+	pgQueryExecutor := connpostgres.NewQRepQueryExecutor(pool, context.Background(), "testflow", "testpart")
+	pgQueryExecutor.SetTestEnv(true)
+
+	return pgQueryExecutor.ExecuteAndProcessQuery(
+		fmt.Sprintf(`SELECT %s FROM e2e_test_%s."%s" ORDER BY id`, cols, suffix, tableName),
+	)
+}
+
 func SetupCDCFlowStatusQuery(env *testsuite.TestWorkflowEnvironment,
 	connectionGen FlowConnectionGenerationConfig,
 ) {
@@ -418,40 +433,9 @@ func (l *TStructuredLogger) Error(msg string, keyvals ...interface{}) {
 	l.logger.With(l.keyvalsToFields(keyvals)).Error(msg)
 }
 
-// Equals checks if two QRecordBatches are identical.
-func RequireEqualRecordBatchs(t *testing.T, q *model.QRecordBatch, other *model.QRecordBatch) bool {
+// RequireEqualRecordBatches asserts, via require.True, that
+// e2eshared.CheckEqualRecordBatches reports the two batches as equal.
+func RequireEqualRecordBatches(t *testing.T, q *model.QRecordBatch, other *model.QRecordBatch) {
 	t.Helper()
-
-	if other == nil {
-		t.Log("other is nil")
-		return q == nil
-	}
-
-	// First check simple attributes
-	if q.NumRecords != other.NumRecords {
-		// print num records
-		t.Logf("q.NumRecords: %d", q.NumRecords)
-		t.Logf("other.NumRecords: %d", other.NumRecords)
-		return false
-	}
-
-	// Compare column names
-	if !q.Schema.EqualNames(other.Schema) {
-		t.Log("Column names are not equal")
-		t.Logf("Schema 1: %v", q.Schema.GetColumnNames())
-		t.Logf("Schema 2: %v", other.Schema.GetColumnNames())
-		return false
-	}
-
-	// Compare records
-	for i, record := range q.Records {
-		if !e2eshared.CheckQRecordEquality(t, record, other.Records[i]) {
-			t.Logf("Record %d is not equal", i)
-			t.Logf("Record 1: %v", record)
-			t.Logf("Record 2: %v", other.Records[i])
-			return false
-		}
-	}
-
-	return true
+	require.True(t, e2eshared.CheckEqualRecordBatches(t, q, other))
 }
diff --git a/flow/e2eshared/e2eshared.go b/flow/e2eshared/e2eshared.go
index 0324751a93..1044fe4ad3 100644
--- a/flow/e2eshared/e2eshared.go
+++ b/flow/e2eshared/e2eshared.go
@@ -65,3 +65,50 @@ func CheckQRecordEquality(t *testing.T, q model.QRecord, other model.QRecord) bo
 
 	return true
 }
+
+// CheckEqualRecordBatches reports whether two QRecordBatches are identical:
+// both nil, or equal record counts, equal column names and pairwise-equal records.
+// The first mismatch found is logged through t before false is returned.
+func CheckEqualRecordBatches(t *testing.T, q *model.QRecordBatch, other *model.QRecordBatch) bool {
+	t.Helper()
+
+	if q == nil || other == nil {
+		t.Logf("q nil? %v, other nil? %v", q == nil, other == nil)
+		return q == nil && other == nil
+	}
+
+	// First check simple attributes
+	if q.NumRecords != other.NumRecords {
+		// print num records
+		t.Logf("q.NumRecords: %d", q.NumRecords)
+		t.Logf("other.NumRecords: %d", other.NumRecords)
+		return false
+	}
+
+	// other.Records is indexed below by position in q.Records; NumRecords is a
+	// separate field, so compare the actual lengths too instead of risking a panic.
+	if len(q.Records) != len(other.Records) {
+		t.Logf("len(q.Records): %d, len(other.Records): %d", len(q.Records), len(other.Records))
+		return false
+	}
+
+	// Compare column names
+	if !q.Schema.EqualNames(other.Schema) {
+		t.Log("Column names are not equal")
+		t.Logf("Schema 1: %v", q.Schema.GetColumnNames())
+		t.Logf("Schema 2: %v", other.Schema.GetColumnNames())
+		return false
+	}
+
+	// Compare records
+	for i, record := range q.Records {
+		if !CheckQRecordEquality(t, record, other.Records[i]) {
+			t.Logf("Record %d is not equal", i)
+			t.Logf("Record 1: %v", record)
+			t.Logf("Record 2: %v", other.Records[i])
+			return false
+		}
+	}
+
+	return true
+}