Skip to content

Commit

Permalink
Fix e2e.RequireEqualRecordBatches
Browse files Browse the repository at this point in the history
Function was returning bool while callers expect testify/require semantics
  • Loading branch information
serprex committed Jan 3, 2024
1 parent 8e4d68c commit 2cfcd10
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 62 deletions.
12 changes: 2 additions & 10 deletions flow/e2e/bigquery/qrep_flow_bq_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package e2e_bigquery

import (
"context"
"fmt"

connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
"github.com/PeerDB-io/peer-flow/e2e"
"github.com/stretchr/testify/require"
)
Expand All @@ -17,13 +15,7 @@ func (s PeerFlowE2ETestSuiteBQ) setupSourceTable(tableName string, rowCount int)
}

func (s PeerFlowE2ETestSuiteBQ) compareTableContentsBQ(tableName string, colsString string) {
// read rows from source table
pgQueryExecutor := connpostgres.NewQRepQueryExecutor(s.pool, context.Background(), "testflow", "testpart")
pgQueryExecutor.SetTestEnv(true)

pgRows, err := pgQueryExecutor.ExecuteAndProcessQuery(
fmt.Sprintf("SELECT %s FROM e2e_test_%s.%s ORDER BY id", colsString, s.bqSuffix, tableName),
)
pgRows, err := e2e.GetPgRows(s.pool, s.bqSuffix, tableName, colsString)
require.NoError(s.t, err)

// read rows from destination table
Expand All @@ -33,7 +25,7 @@ func (s PeerFlowE2ETestSuiteBQ) compareTableContentsBQ(tableName string, colsStr
bqRows, err := s.bqHelper.ExecuteAndProcessQuery(bqSelQuery)
require.NoError(s.t, err)

e2e.RequireEqualRecordBatchs(s.t, pgRows, bqRows)
e2e.RequireEqualRecordBatches(s.t, pgRows, bqRows)
}

func (s PeerFlowE2ETestSuiteBQ) Test_Complete_QRep_Flow_Avro() {
Expand Down
22 changes: 7 additions & 15 deletions flow/e2e/snowflake/qrep_flow_sf_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package e2e_snowflake

import (
"context"
"fmt"

connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
"github.com/PeerDB-io/peer-flow/e2e"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/google/uuid"
Expand All @@ -19,10 +17,6 @@ func (s PeerFlowE2ETestSuiteSF) setupSourceTable(tableName string, numRows int)
require.NoError(s.t, err)
}

func (s PeerFlowE2ETestSuiteSF) compareTableContentsSF(tableName, selector string) {
s.compareTableContentsWithDiffSelectorsSF(tableName, selector, selector, false)
}

func (s PeerFlowE2ETestSuiteSF) checkJSONValue(tableName, colName, fieldName, value string) error {
res, err := s.sfHelper.ExecuteAndProcessQuery(fmt.Sprintf(
"SELECT %s:%s FROM %s;",
Expand All @@ -38,19 +32,17 @@ func (s PeerFlowE2ETestSuiteSF) checkJSONValue(tableName, colName, fieldName, va
return nil
}

func (s PeerFlowE2ETestSuiteSF) compareTableContentsSF(tableName, selector string) {
s.compareTableContentsWithDiffSelectorsSF(tableName, selector, selector, false)
}

func (s PeerFlowE2ETestSuiteSF) compareTableContentsWithDiffSelectorsSF(tableName, pgSelector, sfSelector string,
tableCaseSensitive bool,
) {
// read rows from source table
pgQueryExecutor := connpostgres.NewQRepQueryExecutor(s.pool, context.Background(), "testflow", "testpart")
pgQueryExecutor.SetTestEnv(true)
pgRows, err := pgQueryExecutor.ExecuteAndProcessQuery(
fmt.Sprintf(`SELECT %s FROM e2e_test_%s."%s" ORDER BY id`, pgSelector, s.pgSuffix, tableName),
)
pgRows, err := e2e.GetPgRows(s.pool, s.pgSuffix, tableName, pgSelector)
require.NoError(s.t, err)

// read rows from destination table

var qualifiedTableName string
if tableCaseSensitive {
qualifiedTableName = fmt.Sprintf(`%s.%s."%s"`, s.sfHelper.testDatabaseName, s.sfHelper.testSchemaName, tableName)
Expand All @@ -59,11 +51,11 @@ func (s PeerFlowE2ETestSuiteSF) compareTableContentsWithDiffSelectorsSF(tableNam
}

sfSelQuery := fmt.Sprintf(`SELECT %s FROM %s ORDER BY id`, sfSelector, qualifiedTableName)
s.t.Logf("running query on snowflake: %s\n", sfSelQuery)
s.t.Logf("running query on snowflake: %s", sfSelQuery)
sfRows, err := s.sfHelper.ExecuteAndProcessQuery(sfSelQuery)
require.NoError(s.t, err)

e2e.RequireEqualRecordBatchs(s.t, pgRows, sfRows)
e2e.RequireEqualRecordBatches(s.t, pgRows, sfRows)
}

func (s PeerFlowE2ETestSuiteSF) Test_Complete_QRep_Flow_Avro_SF() {
Expand Down
48 changes: 13 additions & 35 deletions flow/e2e/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/PeerDB-io/peer-flow/activities"
connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake"
utils "github.com/PeerDB-io/peer-flow/connectors/utils/catalog"
"github.com/PeerDB-io/peer-flow/e2eshared"
Expand All @@ -22,6 +23,7 @@ import (
peerflow "github.com/PeerDB-io/peer-flow/workflows"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/require"
"go.temporal.io/sdk/testsuite"
)

Expand Down Expand Up @@ -57,6 +59,15 @@ func RegisterWorkflowsAndActivities(t *testing.T, env *testsuite.TestWorkflowEnv
env.RegisterActivity(&activities.SnapshotActivity{})
}

func GetPgRows(pool *pgxpool.Pool, suffix string, tableName string, cols string) (*model.QRecordBatch, error) {
pgQueryExecutor := connpostgres.NewQRepQueryExecutor(pool, context.Background(), "testflow", "testpart")
pgQueryExecutor.SetTestEnv(true)

return pgQueryExecutor.ExecuteAndProcessQuery(
fmt.Sprintf("SELECT %s FROM e2e_test_%s.%s ORDER BY id", cols, suffix, tableName),
)
}

func SetupCDCFlowStatusQuery(env *testsuite.TestWorkflowEnvironment,
connectionGen FlowConnectionGenerationConfig,
) {
Expand Down Expand Up @@ -418,40 +429,7 @@ func (l *TStructuredLogger) Error(msg string, keyvals ...interface{}) {
l.logger.With(l.keyvalsToFields(keyvals)).Error(msg)
}

// Equals checks if two QRecordBatches are identical.
func RequireEqualRecordBatchs(t *testing.T, q *model.QRecordBatch, other *model.QRecordBatch) bool {
func RequireEqualRecordBatches(t *testing.T, q *model.QRecordBatch, other *model.QRecordBatch) {
t.Helper()

if other == nil {
t.Log("other is nil")
return q == nil
}

// First check simple attributes
if q.NumRecords != other.NumRecords {
// print num records
t.Logf("q.NumRecords: %d", q.NumRecords)
t.Logf("other.NumRecords: %d", other.NumRecords)
return false
}

// Compare column names
if !q.Schema.EqualNames(other.Schema) {
t.Log("Column names are not equal")
t.Logf("Schema 1: %v", q.Schema.GetColumnNames())
t.Logf("Schema 2: %v", other.Schema.GetColumnNames())
return false
}

// Compare records
for i, record := range q.Records {
if !e2eshared.CheckQRecordEquality(t, record, other.Records[i]) {
t.Logf("Record %d is not equal", i)
t.Logf("Record 1: %v", record)
t.Logf("Record 2: %v", other.Records[i])
return false
}
}

return true
require.True(t, e2eshared.CheckEqualRecordBatches(t, q, other))
}
42 changes: 40 additions & 2 deletions flow/e2eshared/e2eshared.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,52 @@ func CheckQRecordEquality(t *testing.T, q model.QRecord, other model.QRecord) bo
t.Helper()

if q.NumEntries != other.NumEntries {
t.Logf("unequal entry count: %d != %d\n", q.NumEntries, other.NumEntries)
t.Logf("unequal entry count: %d != %d", q.NumEntries, other.NumEntries)
return false
}

for i, entry := range q.Entries {
otherEntry := other.Entries[i]
if !entry.Equals(otherEntry) {
t.Logf("entry %d: %v != %v\n", i, entry, otherEntry)
t.Logf("entry %d: %v != %v", i, entry, otherEntry)
return false
}
}

return true
}

// Equals checks if two QRecordBatches are identical.
func CheckEqualRecordBatches(t *testing.T, q *model.QRecordBatch, other *model.QRecordBatch) bool {
t.Helper()

if q == nil || other == nil {
t.Logf("q nil? %v, other nil? %v", q == nil, other == nil)
return q == nil && other == nil
}

// First check simple attributes
if q.NumRecords != other.NumRecords {
// print num records
t.Logf("q.NumRecords: %d", q.NumRecords)
t.Logf("other.NumRecords: %d", other.NumRecords)
return false
}

// Compare column names
if !q.Schema.EqualNames(other.Schema) {
t.Log("Column names are not equal")
t.Logf("Schema 1: %v", q.Schema.GetColumnNames())
t.Logf("Schema 2: %v", other.Schema.GetColumnNames())
return false
}

// Compare records
for i, record := range q.Records {
if !CheckQRecordEquality(t, record, other.Records[i]) {
t.Logf("Record %d is not equal", i)
t.Logf("Record 1: %v", record)
t.Logf("Record 2: %v", other.Records[i])
return false
}
}
Expand Down

0 comments on commit 2cfcd10

Please sign in to comment.