Skip to content

Commit

Permalink
Fix e2e.RequireEqualRecordBatches (#967)
Browse files Browse the repository at this point in the history
I introduced a regression in #923

Function was returning bool while callers expect testify/require
semantics

---------

Co-authored-by: Amogh-Bharadwaj <[email protected]>
Co-authored-by: Kevin Biju <[email protected]>
  • Loading branch information
3 people authored Jan 3, 2024
1 parent c18b45e commit 20ed252
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 96 deletions.
7 changes: 4 additions & 3 deletions flow/connectors/postgres/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/connectors/utils/cdc_records"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/geo"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/PeerDB-io/peer-flow/shared"
Expand Down Expand Up @@ -362,8 +363,8 @@ func (p *PostgresCDCSource) consumeStream(
if retryAttemptForWALSegmentRemoved > maxRetriesForWalSegmentRemoved {
return fmt.Errorf("max retries for WAL segment removed exceeded: %+v", errMsg)
} else {
p.logger.Warn(fmt.Sprintf(
"WAL segment removed, restarting replication retrying in 30 seconds..."),
p.logger.Warn(
"WAL segment removed, restarting replication retrying in 30 seconds...",
slog.Any("error", errMsg), slog.Int("retryAttempt", retryAttemptForWALSegmentRemoved))
time.Sleep(30 * time.Second)
continue
Expand Down Expand Up @@ -761,7 +762,7 @@ func (p *PostgresCDCSource) decodeColumnData(data []byte, dataType uint32, forma
if ok {
customQKind := customTypeToQKind(typeName)
if customQKind == qvalue.QValueKindGeography || customQKind == qvalue.QValueKindGeometry {
wkt, err := GeoValidate(string(data))
wkt, err := geo.GeoValidate(string(data))
if err != nil {
return qvalue.QValue{
Kind: customQKind,
Expand Down
3 changes: 2 additions & 1 deletion flow/connectors/postgres/qrep_query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/geo"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/PeerDB-io/peer-flow/shared"
Expand Down Expand Up @@ -446,7 +447,7 @@ func mapRowToQRecord(row pgx.Rows, fds []pgconn.FieldDescription,
customQKind := customTypeToQKind(typeName)
if customQKind == qvalue.QValueKindGeography || customQKind == qvalue.QValueKindGeometry {
wkbString, ok := values[i].(string)
wkt, err := GeoValidate(wkbString)
wkt, err := geo.GeoValidate(wkbString)
if err != nil || !ok {
values[i] = nil
} else {
Expand Down
28 changes: 0 additions & 28 deletions flow/connectors/postgres/qvalue_convert.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package connpostgres

import (
"encoding/hex"
"encoding/json"
"errors"
"fmt"
Expand All @@ -14,8 +13,6 @@ import (
"github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/jackc/pgx/v5/pgtype"
"github.com/lib/pq/oid"

geom "github.com/twpayne/go-geos"
)

func postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind {
Expand Down Expand Up @@ -407,28 +404,3 @@ func customTypeToQKind(typeName string) qvalue.QValueKind {
}
return qValueKind
}

// GeoValidate decodes a hex-encoded WKB string, parses it into a geometry and
// returns that geometry's WKT representation.
//
// An error is returned when the input is not valid hex, when the WKB cannot be
// parsed, or when the geometry's validity reason is anything other than
// "Valid Geometry". Invalid hex and invalid shapes are additionally logged as
// warnings.
func GeoValidate(hexWkb string) (string, error) {
	// Turn the hex text into raw WKB bytes.
	rawWkb, err := hex.DecodeString(hexWkb)
	if err != nil {
		slog.Warn(fmt.Sprintf("Ignoring invalid WKB: %s", hexWkb))
		return "", err
	}

	// Parse failures are returned to the caller without being logged here.
	parsed, err := geom.NewGeomFromWKB(rawWkb)
	if err != nil {
		return "", err
	}

	// Only the exact reason text "Valid Geometry" is accepted; any other text
	// is treated as the description of what is wrong with the shape.
	if reason := parsed.IsValidReason(); reason != "Valid Geometry" {
		slog.Warn(fmt.Sprintf("Ignoring invalid geometry shape %s: %s", hexWkb, reason))
		return "", errors.New(reason)
	}

	return parsed.ToWKT(), nil
}
12 changes: 2 additions & 10 deletions flow/e2e/bigquery/qrep_flow_bq_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package e2e_bigquery

import (
"context"
"fmt"

connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
"github.com/PeerDB-io/peer-flow/e2e"
"github.com/stretchr/testify/require"
)
Expand All @@ -17,13 +15,7 @@ func (s PeerFlowE2ETestSuiteBQ) setupSourceTable(tableName string, rowCount int)
}

func (s PeerFlowE2ETestSuiteBQ) compareTableContentsBQ(tableName string, colsString string) {
// read rows from source table
pgQueryExecutor := connpostgres.NewQRepQueryExecutor(s.pool, context.Background(), "testflow", "testpart")
pgQueryExecutor.SetTestEnv(true)

pgRows, err := pgQueryExecutor.ExecuteAndProcessQuery(
fmt.Sprintf("SELECT %s FROM e2e_test_%s.%s ORDER BY id", colsString, s.bqSuffix, tableName),
)
pgRows, err := e2e.GetPgRows(s.pool, s.bqSuffix, tableName, colsString)
require.NoError(s.t, err)

// read rows from destination table
Expand All @@ -33,7 +25,7 @@ func (s PeerFlowE2ETestSuiteBQ) compareTableContentsBQ(tableName string, colsStr
bqRows, err := s.bqHelper.ExecuteAndProcessQuery(bqSelQuery)
require.NoError(s.t, err)

e2e.RequireEqualRecordBatchs(s.t, pgRows, bqRows)
e2e.RequireEqualRecordBatches(s.t, pgRows, bqRows)
}

func (s PeerFlowE2ETestSuiteBQ) Test_Complete_QRep_Flow_Avro() {
Expand Down
22 changes: 7 additions & 15 deletions flow/e2e/snowflake/qrep_flow_sf_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package e2e_snowflake

import (
"context"
"fmt"

connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
"github.com/PeerDB-io/peer-flow/e2e"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/google/uuid"
Expand All @@ -19,10 +17,6 @@ func (s PeerFlowE2ETestSuiteSF) setupSourceTable(tableName string, numRows int)
require.NoError(s.t, err)
}

// compareTableContentsSF delegates to compareTableContentsWithDiffSelectorsSF,
// passing the same selector for both the Postgres and Snowflake sides and
// a false tableCaseSensitive flag.
func (s PeerFlowE2ETestSuiteSF) compareTableContentsSF(tableName, selector string) {
	const tableCaseSensitive = false
	s.compareTableContentsWithDiffSelectorsSF(tableName, selector, selector, tableCaseSensitive)
}

func (s PeerFlowE2ETestSuiteSF) checkJSONValue(tableName, colName, fieldName, value string) error {
res, err := s.sfHelper.ExecuteAndProcessQuery(fmt.Sprintf(
"SELECT %s:%s FROM %s;",
Expand All @@ -38,19 +32,17 @@ func (s PeerFlowE2ETestSuiteSF) checkJSONValue(tableName, colName, fieldName, va
return nil
}

// compareTableContentsSF delegates to compareTableContentsWithDiffSelectorsSF,
// passing the same selector for both the Postgres and Snowflake sides and
// a false tableCaseSensitive flag.
func (s PeerFlowE2ETestSuiteSF) compareTableContentsSF(tableName, selector string) {
	const tableCaseSensitive = false
	s.compareTableContentsWithDiffSelectorsSF(tableName, selector, selector, tableCaseSensitive)
}

func (s PeerFlowE2ETestSuiteSF) compareTableContentsWithDiffSelectorsSF(tableName, pgSelector, sfSelector string,
tableCaseSensitive bool,
) {
// read rows from source table
pgQueryExecutor := connpostgres.NewQRepQueryExecutor(s.pool, context.Background(), "testflow", "testpart")
pgQueryExecutor.SetTestEnv(true)
pgRows, err := pgQueryExecutor.ExecuteAndProcessQuery(
fmt.Sprintf(`SELECT %s FROM e2e_test_%s."%s" ORDER BY id`, pgSelector, s.pgSuffix, tableName),
)
pgRows, err := e2e.GetPgRows(s.pool, s.pgSuffix, tableName, pgSelector)
require.NoError(s.t, err)

// read rows from destination table

var qualifiedTableName string
if tableCaseSensitive {
qualifiedTableName = fmt.Sprintf(`%s.%s."%s"`, s.sfHelper.testDatabaseName, s.sfHelper.testSchemaName, tableName)
Expand All @@ -59,11 +51,11 @@ func (s PeerFlowE2ETestSuiteSF) compareTableContentsWithDiffSelectorsSF(tableNam
}

sfSelQuery := fmt.Sprintf(`SELECT %s FROM %s ORDER BY id`, sfSelector, qualifiedTableName)
s.t.Logf("running query on snowflake: %s\n", sfSelQuery)
s.t.Logf("running query on snowflake: %s", sfSelQuery)
sfRows, err := s.sfHelper.ExecuteAndProcessQuery(sfSelQuery)
require.NoError(s.t, err)

e2e.RequireEqualRecordBatchs(s.t, pgRows, sfRows)
e2e.RequireEqualRecordBatches(s.t, pgRows, sfRows)
}

func (s PeerFlowE2ETestSuiteSF) Test_Complete_QRep_Flow_Avro_SF() {
Expand Down
65 changes: 29 additions & 36 deletions flow/e2e/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/PeerDB-io/peer-flow/activities"
connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake"
utils "github.com/PeerDB-io/peer-flow/connectors/utils/catalog"
"github.com/PeerDB-io/peer-flow/e2eshared"
Expand All @@ -23,6 +24,7 @@ import (
peerflow "github.com/PeerDB-io/peer-flow/workflows"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/require"
"go.temporal.io/sdk/testsuite"
)

Expand Down Expand Up @@ -58,6 +60,15 @@ func RegisterWorkflowsAndActivities(t *testing.T, env *testsuite.TestWorkflowEnv
env.RegisterActivity(&activities.SnapshotActivity{})
}

// GetPgRows selects cols from the table "tableName" in schema e2e_test_<suffix>,
// ordered by id, and returns the resulting record batch.
//
// The query runs through a QRep query executor created on pool with
// SetTestEnv(true) and a background context.
func GetPgRows(pool *pgxpool.Pool, suffix string, tableName string, cols string) (*model.QRecordBatch, error) {
	// The table name is double-quoted in the statement; cols and suffix are
	// interpolated verbatim.
	query := fmt.Sprintf(`SELECT %s FROM e2e_test_%s."%s" ORDER BY id`, cols, suffix, tableName)

	executor := connpostgres.NewQRepQueryExecutor(pool, context.Background(), "testflow", "testpart")
	executor.SetTestEnv(true)

	return executor.ExecuteAndProcessQuery(query)
}

func SetupCDCFlowStatusQuery(env *testsuite.TestWorkflowEnvironment,
connectionGen FlowConnectionGenerationConfig,
) {
Expand Down Expand Up @@ -357,6 +368,12 @@ func GetOwnersSchema() *model.QRecordSchema {
{Name: "f6", Type: qvalue.QValueKindJSON, Nullable: true},
{Name: "f7", Type: qvalue.QValueKindJSON, Nullable: true},
{Name: "f8", Type: qvalue.QValueKindInt16, Nullable: true},
{Name: "geometryPoint", Type: qvalue.QValueKindGeometry, Nullable: true},
{Name: "geometry_linestring", Type: qvalue.QValueKindGeometry, Nullable: true},
{Name: "geometry_polygon", Type: qvalue.QValueKindGeometry, Nullable: true},
{Name: "geography_point", Type: qvalue.QValueKindGeography, Nullable: true},
{Name: "geography_linestring", Type: qvalue.QValueKindGeography, Nullable: true},
{Name: "geography_polygon", Type: qvalue.QValueKindGeography, Nullable: true},
},
}
}
Expand All @@ -367,7 +384,16 @@ func GetOwnersSelectorStringsSF() [2]string {
sfFields := make([]string, 0, len(schema.Fields))
for _, field := range schema.Fields {
pgFields = append(pgFields, fmt.Sprintf(`"%s"`, field.Name))
sfFields = append(sfFields, connsnowflake.SnowflakeIdentifierNormalize(field.Name))
if strings.Contains(field.Name, "geo") {
colName := connsnowflake.SnowflakeIdentifierNormalize(field.Name)

// Have to apply a WKT transformation here,
// else the sql driver we use receives the values as snowflake's OBJECT
// which is troublesome to deal with. Now it receives it as string.
sfFields = append(sfFields, fmt.Sprintf(`ST_ASWKT(%s) as %s`, colName, colName))
} else {
sfFields = append(sfFields, connsnowflake.SnowflakeIdentifierNormalize(field.Name))
}
}
return [2]string{strings.Join(pgFields, ","), strings.Join(sfFields, ",")}
}
Expand Down Expand Up @@ -419,40 +445,7 @@ func (l *TStructuredLogger) Error(msg string, keyvals ...interface{}) {
l.logger.With(l.keyvalsToFields(keyvals)).Error(msg)
}

// Equals checks if two QRecordBatches are identical.
func RequireEqualRecordBatchs(t *testing.T, q *model.QRecordBatch, other *model.QRecordBatch) bool {
func RequireEqualRecordBatches(t *testing.T, q *model.QRecordBatch, other *model.QRecordBatch) {
t.Helper()

if other == nil {
t.Log("other is nil")
return q == nil
}

// First check simple attributes
if q.NumRecords != other.NumRecords {
// print num records
t.Logf("q.NumRecords: %d", q.NumRecords)
t.Logf("other.NumRecords: %d", other.NumRecords)
return false
}

// Compare column names
if !q.Schema.EqualNames(other.Schema) {
t.Log("Column names are not equal")
t.Logf("Schema 1: %v", q.Schema.GetColumnNames())
t.Logf("Schema 2: %v", other.Schema.GetColumnNames())
return false
}

// Compare records
for i, record := range q.Records {
if !e2eshared.CheckQRecordEquality(t, record, other.Records[i]) {
t.Logf("Record %d is not equal", i)
t.Logf("Record 1: %v", record)
t.Logf("Record 2: %v", other.Records[i])
return false
}
}

return true
require.True(t, e2eshared.CheckEqualRecordBatches(t, q, other))
}
42 changes: 40 additions & 2 deletions flow/e2eshared/e2eshared.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,52 @@ func CheckQRecordEquality(t *testing.T, q model.QRecord, other model.QRecord) bo
t.Helper()

if q.NumEntries != other.NumEntries {
t.Logf("unequal entry count: %d != %d\n", q.NumEntries, other.NumEntries)
t.Logf("unequal entry count: %d != %d", q.NumEntries, other.NumEntries)
return false
}

for i, entry := range q.Entries {
otherEntry := other.Entries[i]
if !entry.Equals(otherEntry) {
t.Logf("entry %d: %v != %v\n", i, entry, otherEntry)
t.Logf("entry %d: %v != %v", i, entry, otherEntry)
return false
}
}

return true
}

// CheckEqualRecordBatches reports whether two QRecordBatches are identical.
func CheckEqualRecordBatches(t *testing.T, q *model.QRecordBatch, other *model.QRecordBatch) bool {
t.Helper()

if q == nil || other == nil {
t.Logf("q nil? %v, other nil? %v", q == nil, other == nil)
return q == nil && other == nil
}

// First check simple attributes
if q.NumRecords != other.NumRecords {
// print num records
t.Logf("q.NumRecords: %d", q.NumRecords)
t.Logf("other.NumRecords: %d", other.NumRecords)
return false
}

// Compare column names
if !q.Schema.EqualNames(other.Schema) {
t.Log("Column names are not equal")
t.Logf("Schema 1: %v", q.Schema.GetColumnNames())
t.Logf("Schema 2: %v", other.Schema.GetColumnNames())
return false
}

// Compare records
for i, record := range q.Records {
if !CheckQRecordEquality(t, record, other.Records[i]) {
t.Logf("Record %d is not equal", i)
t.Logf("Record 1: %v", record)
t.Logf("Record 2: %v", other.Records[i])
return false
}
}
Expand Down
51 changes: 51 additions & 0 deletions flow/geo/geo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//nolint:all
package geo

import (
"encoding/hex"
"errors"
"fmt"
"log/slog"

geom "github.com/twpayne/go-geos"
)

// GeoValidate decodes a hex-encoded WKB string, parses it into a geometry and
// returns that geometry's WKT representation.
//
// An error is returned when the input is not valid hex, when the WKB cannot be
// parsed, or when the geometry's validity reason is anything other than
// "Valid Geometry". Invalid hex and invalid shapes are additionally logged as
// warnings.
func GeoValidate(hexWkb string) (string, error) {
	// Turn the hex text into raw WKB bytes.
	rawWkb, err := hex.DecodeString(hexWkb)
	if err != nil {
		slog.Warn(fmt.Sprintf("Ignoring invalid WKB: %s", hexWkb))
		return "", err
	}

	// Parse failures are returned to the caller without being logged here.
	parsed, err := geom.NewGeomFromWKB(rawWkb)
	if err != nil {
		return "", err
	}

	// Only the exact reason text "Valid Geometry" is accepted; any other text
	// is treated as the description of what is wrong with the shape.
	if reason := parsed.IsValidReason(); reason != "Valid Geometry" {
		slog.Warn(fmt.Sprintf("Ignoring invalid geometry shape %s: %s", hexWkb, reason))
		return "", errors.New(reason)
	}

	return parsed.ToWKT(), nil
}

// GeoCompare parses both arguments as WKT and returns the result of the first
// geometry's Equals check against the second. It returns false if either
// string fails to parse.
func GeoCompare(wkt1, wkt2 string) bool {
	left, err := geom.NewGeomFromWKT(wkt1)
	if err != nil {
		// First operand is not parseable WKT.
		return false
	}

	right, err := geom.NewGeomFromWKT(wkt2)
	if err != nil {
		// Second operand is not parseable WKT.
		return false
	}

	return left.Equals(right)
}
Loading

0 comments on commit 20ed252

Please sign in to comment.