Skip to content

Commit

Permalink
Geospatial data types: set SRID for geometry (#1514)
Browse files Browse the repository at this point in the history
From Snowflake docs:
```
For GeoJSON, WKT, and WKB input, if the srid argument is not specified, the resulting GEOMETRY object has the SRID set to 0.
```

So we need to explicitly set the SRID in our WKT geospatial strings so
that the SRID is set on the target rows and can be verified with `ST_SRID`.


Test added
Functionally tested
  • Loading branch information
Amogh-Bharadwaj authored Mar 21, 2024
1 parent 472d279 commit 156a9b2
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 6 deletions.
12 changes: 12 additions & 0 deletions flow/connectors/sql/query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,18 @@ func (g *GenericSQLQueryExecutor) CountNonNullRows(
return count.Int64, err
}

// CountSRIDs returns the number of rows in schemaName.tableName for which
// ST_SRID(columnName) is not 0.
//
// The schema, table and column names are concatenated into the query text
// as-is (identifiers cannot be passed as bind parameters), so callers must
// only supply trusted identifiers.
func (g *GenericSQLQueryExecutor) CountSRIDs(
	ctx context.Context,
	schemaName string,
	tableName string,
	columnName string,
) (int64, error) {
	// COUNT skips NULLs, so only rows where the CASE yields 1 are counted.
	query := "SELECT COUNT(CASE WHEN ST_SRID(" + columnName +
		") <> 0 THEN 1 END) AS not_zero FROM " + schemaName + "." + tableName

	var nonZero pgtype.Int8
	err := g.db.QueryRowxContext(ctx, query).Scan(&nonZero)
	return nonZero.Int64, err
}

func (g *GenericSQLQueryExecutor) columnTypeToQField(ct *sql.ColumnType) (model.QField, error) {
qvKind, ok := g.dbtypeToQValueKind[ct.DatabaseTypeName()]
if !ok {
Expand Down
18 changes: 15 additions & 3 deletions flow/e2e/snowflake/peer_flow_sf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Geo_SF_Avro_CDC() {
for range 6 {
_, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`
INSERT INTO %s (line,poly) VALUES ($1,$2)
`, srcTableName), "010200000002000000000000000000F03F000000000000004000000000000008400000000000001040",
`, srcTableName), "SRID=5678;010200000002000000000000000000F03F000000000000004000000000000008400000000000001040",
"010300000001000000050000000000000000000000000000000000000000000000"+
"00000000000000000000f03f000000000000f03f000000000000f03f0000000000"+
"00f03f000000000000000000000000000000000000000000000000")
Expand All @@ -143,6 +143,13 @@ func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Geo_SF_Avro_CDC() {
return false
}

// Make sure SRIDs are set
sridCount, err := s.sfHelper.CountSRIDs("test_invalid_geo_sf_avro_cdc", "line")
if err != nil {
s.t.Log(err)
return false
}

polyCount, err := s.sfHelper.CountNonNullRows("test_invalid_geo_sf_avro_cdc", "poly")
if err != nil {
return false
Expand All @@ -151,9 +158,14 @@ func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Geo_SF_Avro_CDC() {
if lineCount != 6 || polyCount != 6 {
s.t.Logf("wrong counts, expect 6 lines 6 polies, not %d lines %d polies", lineCount, polyCount)
return false
} else {
return true
}

if sridCount != 6 {
s.t.Logf("there are some srids that are 0, expected 6 non-zero srids, got %d non-zero srids", sridCount)
return false
}

return true
})
env.Cancel()

Expand Down
9 changes: 9 additions & 0 deletions flow/e2e/snowflake/snowflake_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,15 @@ func (s *SnowflakeTestHelper) CountNonNullRows(tableName string, columnName stri
return int(res), nil
}

// CountSRIDs forwards to the test client's CountSRIDs for columnName of
// tableName in the helper's test schema and converts the result to int.
// It returns 0 together with the error when the underlying call fails.
func (s *SnowflakeTestHelper) CountSRIDs(tableName string, columnName string) (int, error) {
	ctx := context.Background()

	nonZero, err := s.testClient.CountSRIDs(ctx, s.testSchemaName, tableName, columnName)
	if err != nil {
		return 0, err
	}
	return int(nonZero), nil
}

// CheckNull delegates to the test client's CheckNull for the given columns of
// tableName in the helper's test schema, using a background context, and
// returns its result unchanged.
// NOTE(review): the exact meaning of the boolean is defined by the client's
// CheckNull, which is not visible here — confirm before relying on it.
func (s *SnowflakeTestHelper) CheckNull(tableName string, colNames []string) (bool, error) {
	ctx := context.Background()
	return s.testClient.CheckNull(ctx, s.testSchemaName, tableName, colNames)
}
Expand Down
4 changes: 4 additions & 0 deletions flow/geo/geo.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ func GeoValidate(hexWkb string) (string, error) {
}

wkt := geometryObject.ToWKT()

if SRID := geometryObject.SRID(); SRID != 0 {
wkt = fmt.Sprintf("SRID=%d;%s", geometryObject.SRID(), wkt)
}
return wkt, nil
}

Expand Down
13 changes: 11 additions & 2 deletions flow/model/qrecord_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"log/slog"
"strings"
"time"

"github.com/google/uuid"
Expand Down Expand Up @@ -237,9 +238,17 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) {
return nil, src.err
}

wkb, err := geo.GeoToWKB(v)
geoWkt := v
if strings.HasPrefix(v, "SRID=") {
_, wkt, found := strings.Cut(v, ";")
if found {
geoWkt = wkt
}
}

wkb, err := geo.GeoToWKB(geoWkt)
if err != nil {
src.err = errors.New("failed to convert Geospatial value to wkb")
src.err = fmt.Errorf("failed to convert Geospatial value to wkb: %v", err)
return nil, src.err
}

Expand Down
10 changes: 9 additions & 1 deletion flow/model/qvalue/qvalue.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,15 @@ func compareGeometry(value1, value2 interface{}) bool {
case *geom.Geom:
return v1.Equals(geo2)
case string:
geo1, err := geom.NewGeomFromWKT(v1)
geoWkt := v1
if strings.HasPrefix(geoWkt, "SRID=") {
_, wkt, found := strings.Cut(geoWkt, ";")
if found {
geoWkt = wkt
}
}

geo1, err := geom.NewGeomFromWKT(geoWkt)
if err != nil {
panic(err)
}
Expand Down

0 comments on commit 156a9b2

Please sign in to comment.