diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index d5de8947c5..fcb3e64174 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -12,6 +12,7 @@ import ( "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/connectors/utils/cdc_records" "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/geo" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/PeerDB-io/peer-flow/shared" @@ -362,8 +363,8 @@ func (p *PostgresCDCSource) consumeStream( if retryAttemptForWALSegmentRemoved > maxRetriesForWalSegmentRemoved { return fmt.Errorf("max retries for WAL segment removed exceeded: %+v", errMsg) } else { - p.logger.Warn(fmt.Sprintf( - "WAL segment removed, restarting replication retrying in 30 seconds..."), + p.logger.Warn( + "WAL segment removed, restarting replication retrying in 30 seconds...", slog.Any("error", errMsg), slog.Int("retryAttempt", retryAttemptForWALSegmentRemoved)) time.Sleep(30 * time.Second) continue @@ -761,7 +762,7 @@ func (p *PostgresCDCSource) decodeColumnData(data []byte, dataType uint32, forma if ok { customQKind := customTypeToQKind(typeName) if customQKind == qvalue.QValueKindGeography || customQKind == qvalue.QValueKindGeometry { - wkt, err := GeoValidate(string(data)) + wkt, err := geo.GeoValidate(string(data)) if err != nil { return qvalue.QValue{ Kind: customQKind, diff --git a/flow/connectors/postgres/qrep_query_executor.go b/flow/connectors/postgres/qrep_query_executor.go index 75e258e70d..bb07fbb98f 100644 --- a/flow/connectors/postgres/qrep_query_executor.go +++ b/flow/connectors/postgres/qrep_query_executor.go @@ -7,6 +7,7 @@ import ( "time" "github.com/PeerDB-io/peer-flow/connectors/utils" + "github.com/PeerDB-io/peer-flow/geo" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/PeerDB-io/peer-flow/shared" @@ -446,7 +447,7 @@ func mapRowToQRecord(row pgx.Rows, fds []pgconn.FieldDescription, customQKind := customTypeToQKind(typeName) if customQKind == qvalue.QValueKindGeography || customQKind == qvalue.QValueKindGeometry { wkbString, ok := values[i].(string) - wkt, err := GeoValidate(wkbString) + wkt, err := geo.GeoValidate(wkbString) if err != nil || !ok { values[i] = nil } else { diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go index aafa00073e..0037a43fa1 100644 --- a/flow/connectors/postgres/qvalue_convert.go +++ b/flow/connectors/postgres/qvalue_convert.go @@ -1,7 +1,6 @@ package connpostgres import ( - "encoding/hex" "encoding/json" "errors" "fmt" @@ -14,8 +13,6 @@ import ( "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/jackc/pgx/v5/pgtype" "github.com/lib/pq/oid" - - geom "github.com/twpayne/go-geos" ) func postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind { @@ -407,28 +404,3 @@ func customTypeToQKind(typeName string) qvalue.QValueKind { } return qValueKind } - -// returns the WKT representation of the geometry object if it is valid -func GeoValidate(hexWkb string) (string, error) { - // Decode the WKB hex string into binary - wkb, hexErr := hex.DecodeString(hexWkb) - if hexErr != nil { - slog.Warn(fmt.Sprintf("Ignoring invalid WKB: %s", hexWkb)) - return "", hexErr - } - - // UnmarshalWKB performs geometry validation along with WKB parsing - geometryObject, geoErr := geom.NewGeomFromWKB(wkb) - if geoErr != nil { - return "", geoErr - } - - invalidReason := geometryObject.IsValidReason() - if invalidReason != "Valid Geometry" { - slog.Warn(fmt.Sprintf("Ignoring invalid geometry shape %s: %s", hexWkb, invalidReason)) - return "", errors.New(invalidReason) - } - - wkt := geometryObject.ToWKT() - return wkt, nil -} diff --git a/flow/e2e/bigquery/qrep_flow_bq_test.go b/flow/e2e/bigquery/qrep_flow_bq_test.go index da1bade8bd..ddfd8b6373 100644 --- a/flow/e2e/bigquery/qrep_flow_bq_test.go +++ b/flow/e2e/bigquery/qrep_flow_bq_test.go @@ -1,10 +1,8 @@ package e2e_bigquery import ( - "context" "fmt" - connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" "github.com/PeerDB-io/peer-flow/e2e" "github.com/stretchr/testify/require" ) @@ -17,13 +15,7 @@ func (s PeerFlowE2ETestSuiteBQ) setupSourceTable(tableName string, rowCount int) } func (s PeerFlowE2ETestSuiteBQ) compareTableContentsBQ(tableName string, colsString string) { - // read rows from source table - pgQueryExecutor := connpostgres.NewQRepQueryExecutor(s.pool, context.Background(), "testflow", "testpart") - pgQueryExecutor.SetTestEnv(true) - - pgRows, err := pgQueryExecutor.ExecuteAndProcessQuery( - fmt.Sprintf("SELECT %s FROM e2e_test_%s.%s ORDER BY id", colsString, s.bqSuffix, tableName), - ) + pgRows, err := e2e.GetPgRows(s.pool, s.bqSuffix, tableName, colsString) require.NoError(s.t, err) // read rows from destination table @@ -33,7 +25,7 @@ func (s PeerFlowE2ETestSuiteBQ) compareTableContentsBQ(tableName string, colsStr bqRows, err := s.bqHelper.ExecuteAndProcessQuery(bqSelQuery) require.NoError(s.t, err) - e2e.RequireEqualRecordBatchs(s.t, pgRows, bqRows) + e2e.RequireEqualRecordBatches(s.t, pgRows, bqRows) } func (s PeerFlowE2ETestSuiteBQ) Test_Complete_QRep_Flow_Avro() { diff --git a/flow/e2e/snowflake/qrep_flow_sf_test.go b/flow/e2e/snowflake/qrep_flow_sf_test.go index c9303ec2d6..e48d4d62a6 100644 --- a/flow/e2e/snowflake/qrep_flow_sf_test.go +++ b/flow/e2e/snowflake/qrep_flow_sf_test.go @@ -1,10 +1,8 @@ package e2e_snowflake import ( - "context" "fmt" - connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" "github.com/PeerDB-io/peer-flow/e2e" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/google/uuid" @@ -19,10 +17,6 @@ func (s PeerFlowE2ETestSuiteSF) setupSourceTable(tableName string, numRows int) require.NoError(s.t, err) } -func (s PeerFlowE2ETestSuiteSF) compareTableContentsSF(tableName, selector string) { - s.compareTableContentsWithDiffSelectorsSF(tableName, selector, selector, false) -} - func (s PeerFlowE2ETestSuiteSF) checkJSONValue(tableName, colName, fieldName, value string) error { res, err := s.sfHelper.ExecuteAndProcessQuery(fmt.Sprintf( "SELECT %s:%s FROM %s;", @@ -38,19 +32,17 @@ func (s PeerFlowE2ETestSuiteSF) checkJSONValue(tableName, colName, fieldName, va return nil } +func (s PeerFlowE2ETestSuiteSF) compareTableContentsSF(tableName, selector string) { + s.compareTableContentsWithDiffSelectorsSF(tableName, selector, selector, false) +} + func (s PeerFlowE2ETestSuiteSF) compareTableContentsWithDiffSelectorsSF(tableName, pgSelector, sfSelector string, tableCaseSensitive bool, ) { - // read rows from source table - pgQueryExecutor := connpostgres.NewQRepQueryExecutor(s.pool, context.Background(), "testflow", "testpart") - pgQueryExecutor.SetTestEnv(true) - pgRows, err := pgQueryExecutor.ExecuteAndProcessQuery( - fmt.Sprintf(`SELECT %s FROM e2e_test_%s."%s" ORDER BY id`, pgSelector, s.pgSuffix, tableName), - ) + pgRows, err := e2e.GetPgRows(s.pool, s.pgSuffix, tableName, pgSelector) require.NoError(s.t, err) // read rows from destination table - var qualifiedTableName string if tableCaseSensitive { qualifiedTableName = fmt.Sprintf(`%s.%s."%s"`, s.sfHelper.testDatabaseName, s.sfHelper.testSchemaName, tableName) @@ -59,11 +51,11 @@ func (s PeerFlowE2ETestSuiteSF) compareTableContentsWithDiffSelectorsSF(tableNam } sfSelQuery := fmt.Sprintf(`SELECT %s FROM %s ORDER BY id`, sfSelector, qualifiedTableName) - s.t.Logf("running query on snowflake: %s\n", sfSelQuery) + s.t.Logf("running query on snowflake: %s", sfSelQuery) sfRows, err := s.sfHelper.ExecuteAndProcessQuery(sfSelQuery) require.NoError(s.t, err) - e2e.RequireEqualRecordBatchs(s.t, pgRows, sfRows) + e2e.RequireEqualRecordBatches(s.t, pgRows, sfRows) } func (s PeerFlowE2ETestSuiteSF) Test_Complete_QRep_Flow_Avro_SF() { diff --git a/flow/e2e/test_utils.go b/flow/e2e/test_utils.go index dc7e83d1c3..4f011a8e1b 100644 --- a/flow/e2e/test_utils.go +++ b/flow/e2e/test_utils.go @@ -11,6 +11,7 @@ import ( "time" "github.com/PeerDB-io/peer-flow/activities" + connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake" utils "github.com/PeerDB-io/peer-flow/connectors/utils/catalog" "github.com/PeerDB-io/peer-flow/e2eshared" @@ -23,6 +24,7 @@ import ( peerflow "github.com/PeerDB-io/peer-flow/workflows" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" "go.temporal.io/sdk/testsuite" ) @@ -58,6 +60,15 @@ func RegisterWorkflowsAndActivities(t *testing.T, env *testsuite.TestWorkflowEnv env.RegisterActivity(&activities.SnapshotActivity{}) } +func GetPgRows(pool *pgxpool.Pool, suffix string, tableName string, cols string) (*model.QRecordBatch, error) { + pgQueryExecutor := connpostgres.NewQRepQueryExecutor(pool, context.Background(), "testflow", "testpart") + pgQueryExecutor.SetTestEnv(true) + + return pgQueryExecutor.ExecuteAndProcessQuery( + fmt.Sprintf(`SELECT %s FROM e2e_test_%s."%s" ORDER BY id`, cols, suffix, tableName), + ) +} + func SetupCDCFlowStatusQuery(env *testsuite.TestWorkflowEnvironment, connectionGen FlowConnectionGenerationConfig, ) { @@ -357,6 +368,12 @@ func GetOwnersSchema() *model.QRecordSchema { {Name: "f6", Type: qvalue.QValueKindJSON, Nullable: true}, {Name: "f7", Type: qvalue.QValueKindJSON, Nullable: true}, {Name: "f8", Type: qvalue.QValueKindInt16, Nullable: true}, + {Name: "geometryPoint", Type: qvalue.QValueKindGeometry, Nullable: true}, + {Name: "geometry_linestring", Type: qvalue.QValueKindGeometry, Nullable: true}, + {Name: "geometry_polygon", Type: qvalue.QValueKindGeometry, Nullable: true}, + {Name: "geography_point", Type: qvalue.QValueKindGeography, Nullable: true}, + {Name: "geography_linestring", Type: qvalue.QValueKindGeography, Nullable: true}, + {Name: "geography_polygon", Type: qvalue.QValueKindGeography, Nullable: true}, }, } } @@ -367,7 +384,16 @@ func GetOwnersSelectorStringsSF() [2]string { sfFields := make([]string, 0, len(schema.Fields)) for _, field := range schema.Fields { pgFields = append(pgFields, fmt.Sprintf(`"%s"`, field.Name)) - sfFields = append(sfFields, connsnowflake.SnowflakeIdentifierNormalize(field.Name)) + if strings.Contains(field.Name, "geo") { + colName := connsnowflake.SnowflakeIdentifierNormalize(field.Name) + + // Have to apply a WKT transformation here, + // else the sql driver we use receives the values as snowflake's OBJECT + // which is troublesome to deal with. Now it receives it as string. + sfFields = append(sfFields, fmt.Sprintf(`ST_ASWKT(%s) as %s`, colName, colName)) + } else { + sfFields = append(sfFields, connsnowflake.SnowflakeIdentifierNormalize(field.Name)) + } } return [2]string{strings.Join(pgFields, ","), strings.Join(sfFields, ",")} } @@ -419,40 +445,7 @@ func (l *TStructuredLogger) Error(msg string, keyvals ...interface{}) { l.logger.With(l.keyvalsToFields(keyvals)).Error(msg) } -// Equals checks if two QRecordBatches are identical. -func RequireEqualRecordBatchs(t *testing.T, q *model.QRecordBatch, other *model.QRecordBatch) bool { +func RequireEqualRecordBatches(t *testing.T, q *model.QRecordBatch, other *model.QRecordBatch) { t.Helper() - - if other == nil { - t.Log("other is nil") - return q == nil - } - - // First check simple attributes - if q.NumRecords != other.NumRecords { - // print num records - t.Logf("q.NumRecords: %d", q.NumRecords) - t.Logf("other.NumRecords: %d", other.NumRecords) - return false - } - - // Compare column names - if !q.Schema.EqualNames(other.Schema) { - t.Log("Column names are not equal") - t.Logf("Schema 1: %v", q.Schema.GetColumnNames()) - t.Logf("Schema 2: %v", other.Schema.GetColumnNames()) - return false - } - - // Compare records - for i, record := range q.Records { - if !e2eshared.CheckQRecordEquality(t, record, other.Records[i]) { - t.Logf("Record %d is not equal", i) - t.Logf("Record 1: %v", record) - t.Logf("Record 2: %v", other.Records[i]) - return false - } - } - - return true + require.True(t, e2eshared.CheckEqualRecordBatches(t, q, other)) } diff --git a/flow/e2eshared/e2eshared.go b/flow/e2eshared/e2eshared.go index 0324751a93..283499aa78 100644 --- a/flow/e2eshared/e2eshared.go +++ b/flow/e2eshared/e2eshared.go @@ -51,14 +51,52 @@ func CheckQRecordEquality(t *testing.T, q model.QRecord, other model.QRecord) bo t.Helper() if q.NumEntries != other.NumEntries { - t.Logf("unequal entry count: %d != %d\n", q.NumEntries, other.NumEntries) + t.Logf("unequal entry count: %d != %d", q.NumEntries, other.NumEntries) return false } for i, entry := range q.Entries { otherEntry := other.Entries[i] if !entry.Equals(otherEntry) { - t.Logf("entry %d: %v != %v\n", i, entry, otherEntry) + t.Logf("entry %d: %v != %v", i, entry, otherEntry) + return false + } + } + + return true +} + +// Equals checks if two QRecordBatches are identical. +func CheckEqualRecordBatches(t *testing.T, q *model.QRecordBatch, other *model.QRecordBatch) bool { + t.Helper() + + if q == nil || other == nil { + t.Logf("q nil? %v, other nil? %v", q == nil, other == nil) + return q == nil && other == nil + } + + // First check simple attributes + if q.NumRecords != other.NumRecords { + // print num records + t.Logf("q.NumRecords: %d", q.NumRecords) + t.Logf("other.NumRecords: %d", other.NumRecords) + return false + } + + // Compare column names + if !q.Schema.EqualNames(other.Schema) { + t.Log("Column names are not equal") + t.Logf("Schema 1: %v", q.Schema.GetColumnNames()) + t.Logf("Schema 2: %v", other.Schema.GetColumnNames()) + return false + } + + // Compare records + for i, record := range q.Records { + if !CheckQRecordEquality(t, record, other.Records[i]) { + t.Logf("Record %d is not equal", i) + t.Logf("Record 1: %v", record) + t.Logf("Record 2: %v", other.Records[i]) return false } } diff --git a/flow/geo/geo.go b/flow/geo/geo.go new file mode 100644 index 0000000000..a7f87e0174 --- /dev/null +++ b/flow/geo/geo.go @@ -0,0 +1,51 @@ +//nolint:all +package geo + +import ( + "encoding/hex" + "errors" + "fmt" + "log/slog" + + geom "github.com/twpayne/go-geos" +) + +// returns the WKT representation of the geometry object if it is valid +func GeoValidate(hexWkb string) (string, error) { + // Decode the WKB hex string into binary + wkb, hexErr := hex.DecodeString(hexWkb) + if hexErr != nil { + slog.Warn(fmt.Sprintf("Ignoring invalid WKB: %s", hexWkb)) + return "", hexErr + } + + // UnmarshalWKB performs geometry validation along with WKB parsing + geometryObject, geoErr := geom.NewGeomFromWKB(wkb) + if geoErr != nil { + return "", geoErr + } + + invalidReason := geometryObject.IsValidReason() + if invalidReason != "Valid Geometry" { + slog.Warn(fmt.Sprintf("Ignoring invalid geometry shape %s: %s", hexWkb, invalidReason)) + return "", errors.New(invalidReason) + } + + wkt := geometryObject.ToWKT() + return wkt, nil +} + +// compares WKTs +func GeoCompare(wkt1, wkt2 string) bool { + geom1, geoErr := geom.NewGeomFromWKT(wkt1) + if geoErr != nil { + return false + } + + geom2, geoErr := geom.NewGeomFromWKT(wkt2) + if geoErr != nil { + return false + } + + return geom1.Equals(geom2) +} diff --git a/flow/model/qvalue/qvalue.go b/flow/model/qvalue/qvalue.go index 4d80b0fb79..24c92e9275 100644 --- a/flow/model/qvalue/qvalue.go +++ b/flow/model/qvalue/qvalue.go @@ -9,6 +9,7 @@ import ( "strconv" "time" + "github.com/PeerDB-io/peer-flow/geo" "github.com/google/uuid" ) @@ -203,8 +204,16 @@ func compareString(value1, value2 interface{}) bool { str1, ok1 := value1.(string) str2, ok2 := value2.(string) + if !ok1 || !ok2 { + return false + } + if str1 == str2 { + return true + } - return ok1 && ok2 && str1 == str2 + // Catch matching WKB(in Postgres)-WKT(in destination) geo values + geoConvertedWKT, err := geo.GeoValidate(str1) + return err == nil && geo.GeoCompare(geoConvertedWKT, str2) } func compareStruct(value1, value2 interface{}) bool { @@ -267,6 +276,10 @@ func compareNumericArrays(value1, value2 interface{}) bool { return true } + if value1 == nil && value2 == "" { + return true + } + // Helper function to convert a value to float64 convertToFloat64 := func(val interface{}) []float64 { switch v := val.(type) { @@ -321,6 +334,11 @@ func compareArrayString(value1, value2 interface{}) bool { return true } + // nulls end up as empty 'variants' in snowflake + if value1 == nil && value2 == "" { + return true + } + array1, ok1 := value1.([]string) array2, ok2 := value2.([]string)