diff --git a/flow/connectors/sql/query_executor.go b/flow/connectors/sql/query_executor.go index 91972a75c3..05279fdde4 100644 --- a/flow/connectors/sql/query_executor.go +++ b/flow/connectors/sql/query_executor.go @@ -138,6 +138,18 @@ func (g *GenericSQLQueryExecutor) CountNonNullRows( return count.Int64, err } +func (g *GenericSQLQueryExecutor) CountSRIDs( + ctx context.Context, + schemaName string, + tableName string, + columnName string, +) (int64, error) { + var count pgtype.Int8 + err := g.db.QueryRowxContext(ctx, "SELECT COUNT(CASE WHEN ST_SRID("+columnName+ + ") <> 0 THEN 1 END) AS not_zero FROM "+schemaName+"."+tableName).Scan(&count) + return count.Int64, err +} + func (g *GenericSQLQueryExecutor) columnTypeToQField(ct *sql.ColumnType) (model.QField, error) { qvKind, ok := g.dbtypeToQValueKind[ct.DatabaseTypeName()] if !ok { diff --git a/flow/e2e/snowflake/peer_flow_sf_test.go b/flow/e2e/snowflake/peer_flow_sf_test.go index 56084a1a27..635a7f3e0d 100644 --- a/flow/e2e/snowflake/peer_flow_sf_test.go +++ b/flow/e2e/snowflake/peer_flow_sf_test.go @@ -127,7 +127,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Geo_SF_Avro_CDC() { for range 6 { _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s (line,poly) VALUES ($1,$2) - `, srcTableName), "010200000002000000000000000000F03F000000000000004000000000000008400000000000001040", + `, srcTableName), "SRID=5678;010200000002000000000000000000F03F000000000000004000000000000008400000000000001040", "010300000001000000050000000000000000000000000000000000000000000000"+ "00000000000000000000f03f000000000000f03f000000000000f03f0000000000"+ "00f03f000000000000000000000000000000000000000000000000") @@ -143,6 +143,13 @@ func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Geo_SF_Avro_CDC() { return false } + // Make sure SRIDs are set + sridCount, err := s.sfHelper.CountSRIDs("test_invalid_geo_sf_avro_cdc", "line") + if err != nil { + s.t.Log(err) + return false + } + polyCount, err := s.sfHelper.CountNonNullRows("test_invalid_geo_sf_avro_cdc", "poly") if err != nil { return false @@ -151,9 +158,14 @@ func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Geo_SF_Avro_CDC() { if lineCount != 6 || polyCount != 6 { s.t.Logf("wrong counts, expect 6 lines 6 polies, not %d lines %d polies", lineCount, polyCount) return false - } else { - return true } + + if sridCount != 6 { + s.t.Logf("there are some srids that are 0, expected 6 non-zero srids, got %d non-zero srids", sridCount) + return false + } + + return true }) env.Cancel() diff --git a/flow/e2e/snowflake/snowflake_helper.go b/flow/e2e/snowflake/snowflake_helper.go index e0d41e838d..14ca9dc35f 100644 --- a/flow/e2e/snowflake/snowflake_helper.go +++ b/flow/e2e/snowflake/snowflake_helper.go @@ -136,6 +136,15 @@ func (s *SnowflakeTestHelper) CountNonNullRows(tableName string, columnName stri return int(res), nil } +func (s *SnowflakeTestHelper) CountSRIDs(tableName string, columnName string) (int, error) { + res, err := s.testClient.CountSRIDs(context.Background(), s.testSchemaName, tableName, columnName) + if err != nil { + return 0, err + } + + return int(res), nil +} + func (s *SnowflakeTestHelper) CheckNull(tableName string, colNames []string) (bool, error) { return s.testClient.CheckNull(context.Background(), s.testSchemaName, tableName, colNames) } diff --git a/flow/geo/geo.go b/flow/geo/geo.go index 6602a26d53..7e35d9bfe6 100644 --- a/flow/geo/geo.go +++ b/flow/geo/geo.go @@ -31,6 +31,10 @@ func GeoValidate(hexWkb string) (string, error) { } wkt := geometryObject.ToWKT() + + if SRID := geometryObject.SRID(); SRID != 0 { + wkt = fmt.Sprintf("SRID=%d;%s", geometryObject.SRID(), wkt) + } return wkt, nil } diff --git a/flow/model/qrecord_batch.go b/flow/model/qrecord_batch.go index 08c5ce7770..dd55ef7ecc 100644 --- a/flow/model/qrecord_batch.go +++ b/flow/model/qrecord_batch.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "log/slog" + "strings" "time" "github.com/google/uuid" @@ -237,9 +238,17 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) { return nil, src.err } - wkb, err := geo.GeoToWKB(v) + geoWkt := v + if strings.HasPrefix(v, "SRID=") { + _, wkt, found := strings.Cut(v, ";") + if found { + geoWkt = wkt + } + } + + wkb, err := geo.GeoToWKB(geoWkt) if err != nil { - src.err = errors.New("failed to convert Geospatial value to wkb") + src.err = fmt.Errorf("failed to convert Geospatial value to wkb: %v", err) return nil, src.err } diff --git a/flow/model/qvalue/qvalue.go b/flow/model/qvalue/qvalue.go index ae0a3945ab..4a495a31cc 100644 --- a/flow/model/qvalue/qvalue.go +++ b/flow/model/qvalue/qvalue.go @@ -292,7 +292,15 @@ func compareGeometry(value1, value2 interface{}) bool { case *geom.Geom: return v1.Equals(geo2) case string: - geo1, err := geom.NewGeomFromWKT(v1) + geoWkt := v1 + if strings.HasPrefix(geoWkt, "SRID=") { + _, wkt, found := strings.Cut(geoWkt, ";") + if found { + geoWkt = wkt + } + } + + geo1, err := geom.NewGeomFromWKT(geoWkt) if err != nil { panic(err) }