diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go
index ac50f3ecf7..0e92a7105e 100644
--- a/flow/connectors/postgres/cdc.go
+++ b/flow/connectors/postgres/cdc.go
@@ -30,6 +30,7 @@ type PostgresCDCSource struct {
 	typeMap    *pgtype.Map
 	startLSN   pglogrepl.LSN
 	commitLock bool
+	ConnStr    string
 }
 
 type PostgresCDCConfig struct {
@@ -40,6 +41,7 @@ type PostgresCDCConfig struct {
 	SrcTableIDNameMapping  map[uint32]string
 	TableNameMapping       map[string]string
 	RelationMessageMapping model.RelationMessageMapping
+	ConnStr                string
 }
 
 // Create a new PostgresCDCSource
@@ -54,6 +56,7 @@ func NewPostgresCDCSource(cdcConfig *PostgresCDCConfig) (*PostgresCDCSource, err
 		relationMessageMapping: cdcConfig.RelationMessageMapping,
 		typeMap:                pgtype.NewMap(),
 		commitLock:             false,
+		ConnStr:                cdcConfig.ConnStr,
 	}, nil
 }
 
@@ -515,37 +518,17 @@ func (p *PostgresCDCSource) decodeColumnData(data []byte, dataType uint32, forma
 		if err != nil {
 			return nil, err
 		}
-		retVal, err := parseFieldFromPostgresOID(dataType, parsedData)
+		retVal, err := parseFieldFromPostgresOID(dataType, parsedData, p.ConnStr)
 		if err != nil {
 			return nil, err
 		}
 		return retVal, nil
 	} else if dataType == uint32(oid.T_timetz) { // ugly TIMETZ workaround for CDC decoding.
-		retVal, err := parseFieldFromPostgresOID(dataType, string(data))
+		retVal, err := parseFieldFromPostgresOID(dataType, string(data), p.ConnStr)
 		if err != nil {
 			return nil, err
 		}
 		return retVal, nil
-	} else { // For custom types, let's identify with schema information
-		var typeName string
-		res, err := p.replPool.Query(p.ctx, "SELECT typname FROM pg_type WHERE oid = $1", dataType)
-		if err != nil {
-			return nil, fmt.Errorf("error querying type name for column: %w", err)
-		}
-		scanErr := res.Scan(&typeName)
-		if err != nil {
-			return nil, fmt.Errorf("error scanning type name: %w", scanErr)
-		}
-		fmt.Println("datatype: ", dataType)
-		fmt.Println("typeName: ", typeName)
-		// POSTGIS and HSTORE support
-		switch typeName {
-		case "geometry":
-
-		case "geography":
-		case "hstore":
-		case "point":
-		}
 	}
 	return &qvalue.QValue{Kind: qvalue.QValueKindString, Value: string(data)}, nil
 }
@@ -599,7 +582,7 @@ func (p *PostgresCDCSource) processRelationMessage(
 		if prevRelMap[column.Name] == nil {
 			schemaDelta.AddedColumns = append(schemaDelta.AddedColumns, &protos.DeltaAddedColumn{
 				ColumnName: column.Name,
-				ColumnType: string(postgresOIDToQValueKind(column.DataType)),
+				ColumnType: string(postgresOIDToQValueKind(column.DataType, p.ConnStr)),
 			})
 			// present in previous and current relation messages, but data types have changed.
 			// so we add it to AddedColumns and DroppedColumns, knowing that we process DroppedColumns first.
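Aside: the deleted else-branch had two latent bugs — after `res.Scan(&typeName)` it checked `err` rather than `scanErr`, and it never advanced the result set with `res.Next()` — which is part of why the lookup moves into a shared helper (`utils.GetCustomDataType`, added later in this diff). A minimal sketch of that catalog lookup with pgx, using `QueryRow` (which advances and scans in one step) and a placeholder connection string:

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
)

// resolveTypeName runs the same pg_type catalog query the deleted branch
// attempted, but via QueryRow so row advancement and scanning are one call.
func resolveTypeName(ctx context.Context, pool *pgxpool.Pool, typeOID uint32) (string, error) {
	var typeName string
	err := pool.QueryRow(ctx, "SELECT typname FROM pg_type WHERE oid = $1", typeOID).Scan(&typeName)
	if err != nil {
		return "", fmt.Errorf("error querying type name for oid %d: %w", typeOID, err)
	}
	return typeName, nil
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// placeholder connection string
	pool, err := pgxpool.New(ctx, "postgres://user:pass@localhost:5432/db")
	if err != nil {
		panic(err)
	}
	defer pool.Close()

	name, err := resolveTypeName(ctx, pool, 600) // 600 is the built-in OID of "point"
	if err != nil {
		panic(err)
	}
	fmt.Println(name) // point
}
```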
diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go
index 9725a6a2af..03f7fa549c 100644
--- a/flow/connectors/postgres/postgres.go
+++ b/flow/connectors/postgres/postgres.go
@@ -217,6 +217,7 @@ func (c *PostgresConnector) PullRecords(req *model.PullRecordsRequest) (*model.R
 		Publication:            publicationName,
 		TableNameMapping:       req.TableNameMapping,
 		RelationMessageMapping: req.RelationMessageMapping,
+		ConnStr:                c.connStr,
 	})
 	if err != nil {
 		return nil, fmt.Errorf("failed to create cdc source: %w", err)
@@ -588,7 +589,7 @@ func (c *PostgresConnector) getTableSchemaForTable(
 	defer rows.Close()
 
 	for _, fieldDescription := range rows.FieldDescriptions() {
-		genericColType := postgresOIDToQValueKind(fieldDescription.DataTypeOID)
+		genericColType := postgresOIDToQValueKind(fieldDescription.DataTypeOID, c.connStr)
 		if genericColType == qvalue.QValueKindInvalid {
 			// we use string for invalid types
 			genericColType = qvalue.QValueKindString
diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go
index 3c0d3124b8..7449e680c1 100644
--- a/flow/connectors/postgres/qrep.go
+++ b/flow/connectors/postgres/qrep.go
@@ -291,7 +291,7 @@ func (c *PostgresConnector) PullQRepRecords(
 		"partitionId": partition.PartitionId,
 	}).Infof("pulling full table partition for flow job %s", config.FlowJobName)
 	executor := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot,
-		config.FlowJobName, partition.PartitionId)
+		config.FlowJobName, partition.PartitionId, c.connStr)
 	query := config.Query
 	return executor.ExecuteAndProcessQuery(query)
 }
@@ -337,7 +337,7 @@ func (c *PostgresConnector) PullQRepRecords(
 	}
 
 	executor := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot,
-		config.FlowJobName, partition.PartitionId)
+		config.FlowJobName, partition.PartitionId, c.connStr)
 	records, err := executor.ExecuteAndProcessQuery(query, rangeStart, rangeEnd)
 	if err != nil {
@@ -363,7 +363,7 @@ func (c *PostgresConnector) PullQRepRecordStream(
 		"partitionId": partition.PartitionId,
 	}).Infof("pulling full table partition for flow job %s", config.FlowJobName)
 	executor := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot,
-		config.FlowJobName, partition.PartitionId)
+		config.FlowJobName, partition.PartitionId, c.connStr)
 	query := config.Query
 	_, err := executor.ExecuteAndProcessQueryStream(stream, query)
 	return 0, err
@@ -410,7 +410,7 @@ func (c *PostgresConnector) PullQRepRecordStream(
 	}
 
 	executor := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot,
-		config.FlowJobName, partition.PartitionId)
+		config.FlowJobName, partition.PartitionId, c.connStr)
 	numRecords, err := executor.ExecuteAndProcessQueryStream(stream, query, rangeStart, rangeEnd)
 	if err != nil {
 		return 0, err
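Note that `getTableSchemaForTable` keeps its conservative fallback: whatever the (now catalog-assisted) OID mapper still reports as invalid is exposed to downstream peers as a string column. A toy illustration of that policy, with local stand-ins for the qvalue kinds rather than the real package:

```go
package main

import "fmt"

type QValueKind string

const (
	QValueKindInvalid QValueKind = "invalid"
	QValueKindString  QValueKind = "string"
)

// schemaKind mirrors getTableSchemaForTable's policy: anything the OID
// mapper cannot classify is surfaced downstream as a plain string column.
func schemaKind(mapped QValueKind) QValueKind {
	if mapped == QValueKindInvalid {
		return QValueKindString
	}
	return mapped
}

func main() {
	fmt.Println(schemaKind(QValueKindInvalid)) // string
}
```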
diff --git a/flow/connectors/postgres/qrep_query_executor.go b/flow/connectors/postgres/qrep_query_executor.go
index 5e2ef59d59..1dcf906e7e 100644
--- a/flow/connectors/postgres/qrep_query_executor.go
+++ b/flow/connectors/postgres/qrep_query_executor.go
@@ -7,6 +7,7 @@ import (
 
 	"github.com/PeerDB-io/peer-flow/connectors/utils"
 	"github.com/PeerDB-io/peer-flow/model"
+	"github.com/PeerDB-io/peer-flow/model/qvalue"
 	"github.com/PeerDB-io/peer-flow/shared"
 	util "github.com/PeerDB-io/peer-flow/utils"
 	"github.com/jackc/pgx/v5"
@@ -23,6 +24,7 @@ type QRepQueryExecutor struct {
 	testEnv     bool
 	flowJobName string
 	partitionID string
+	connStr     string
 }
 
 func NewQRepQueryExecutor(pool *pgxpool.Pool, ctx context.Context,
@@ -37,7 +39,7 @@ func NewQRepQueryExecutor(pool *pgxpool.Pool, ctx context.Context,
 }
 
 func NewQRepQueryExecutorSnapshot(pool *pgxpool.Pool, ctx context.Context, snapshot string,
-	flowJobName string, partitionID string) *QRepQueryExecutor {
+	flowJobName string, partitionID string, connStr string) *QRepQueryExecutor {
 	log.WithFields(log.Fields{
 		"flowName":    flowJobName,
 		"partitionID": partitionID,
@@ -48,6 +50,7 @@ func NewQRepQueryExecutorSnapshot(pool *pgxpool.Pool, ctx context.Context, snaps
 		snapshot:    snapshot,
 		flowJobName: flowJobName,
 		partitionID: partitionID,
+		connStr:     connStr,
 	}
 }
 
@@ -89,11 +92,27 @@ func (qe *QRepQueryExecutor) executeQueryInTx(tx pgx.Tx, cursorName string, fetc
 }
 
 // FieldDescriptionsToSchema converts a slice of pgconn.FieldDescription to a QRecordSchema.
-func fieldDescriptionsToSchema(fds []pgconn.FieldDescription) *model.QRecordSchema {
+func (qe *QRepQueryExecutor) fieldDescriptionsToSchema(fds []pgconn.FieldDescription) *model.QRecordSchema {
 	qfields := make([]*model.QField, len(fds))
 	for i, fd := range fds {
 		cname := fd.Name
-		ctype := postgresOIDToQValueKind(fd.DataTypeOID)
+		ctype := postgresOIDToQValueKind(fd.DataTypeOID, qe.connStr)
+		if ctype == qvalue.QValueKindInvalid {
+			var typeName string
+			err := qe.pool.QueryRow(qe.ctx, "SELECT typname FROM pg_type WHERE oid = $1", fd.DataTypeOID).Scan(&typeName)
+			if err != nil {
+				ctype = qvalue.QValueKindInvalid
+			} else {
+				switch typeName {
+				case "geometry":
+					ctype = qvalue.QValueKindGeometry
+				case "geography":
+					ctype = qvalue.QValueKindGeography
+				default:
+					ctype = qvalue.QValueKindInvalid
+				}
+			}
+		}
 		// there isn't a way to know if a column is nullable or not
 		// TODO fix this.
 		cnullable := true
@@ -118,7 +137,7 @@ func (qe *QRepQueryExecutor) ProcessRows(
 	}).Info("Processing rows")
 	// Iterate over the rows
 	for rows.Next() {
-		record, err := mapRowToQRecord(rows, fieldDescriptions)
+		record, err := mapRowToQRecord(rows, fieldDescriptions, qe.connStr)
 		if err != nil {
 			return nil, fmt.Errorf("failed to map row to QRecord: %w", err)
 		}
@@ -133,7 +152,7 @@ func (qe *QRepQueryExecutor) ProcessRows(
 	batch := &model.QRecordBatch{
 		NumRecords: uint32(len(records)),
 		Records:    records,
-		Schema:     fieldDescriptionsToSchema(fieldDescriptions),
+		Schema:     qe.fieldDescriptionsToSchema(fieldDescriptions),
 	}
 
 	log.WithFields(log.Fields{
@@ -155,7 +174,7 @@ func (qe *QRepQueryExecutor) processRowsStream(
 
 	// Iterate over the rows
 	for rows.Next() {
-		record, err := mapRowToQRecord(rows, fieldDescriptions)
+		record, err := mapRowToQRecord(rows, fieldDescriptions, qe.connStr)
 		if err != nil {
 			stream.Records <- &model.QRecordOrError{
 				Err: fmt.Errorf("failed to map row to QRecord: %w", err),
@@ -214,7 +233,7 @@ func (qe *QRepQueryExecutor) processFetchedRows(
 	fieldDescriptions := rows.FieldDescriptions()
 
 	if !stream.IsSchemaSet() {
-		schema := fieldDescriptionsToSchema(fieldDescriptions)
+		schema := qe.fieldDescriptionsToSchema(fieldDescriptions)
 		_ = stream.SetSchema(schema)
 	}
 
@@ -395,7 +414,7 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStream(
 	return totalRecordsFetched, nil
 }
 
-func mapRowToQRecord(row pgx.Rows, fds []pgconn.FieldDescription) (*model.QRecord, error) {
+func mapRowToQRecord(row pgx.Rows, fds []pgconn.FieldDescription, connStr string) (*model.QRecord, error) {
 	// make vals an empty array of QValue of size len(fds)
 	record := model.NewQRecord(len(fds))
 
@@ -405,7 +424,7 @@ func mapRowToQRecord(row pgx.Rows, fds []pgconn.FieldDescription) (*model.QRecor
 	}
 
 	for i, fd := range fds {
-		tmp, err := parseFieldFromPostgresOID(fd.DataTypeOID, values[i])
+		tmp, err := parseFieldFromPostgresOID(fd.DataTypeOID, values[i], connStr)
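`fieldDescriptionsToSchema` becomes a method so it can reach `qe.pool` and `qe.connStr`: when the static OID switch comes back invalid, it asks pg_type whether the column is PostGIS geometry/geography before giving up. A self-contained sketch of that two-step classification, with a stubbed catalog lookup and a hypothetical OID (16385) standing in for an installation-specific PostGIS type:

```go
package main

import (
	"fmt"

	"github.com/jackc/pgx/v5/pgconn"
)

// kindForOID is a stand-in for postgresOIDToQValueKind; only the shape of
// the fallback matters here.
func kindForOID(oid uint32) string {
	switch oid {
	case 23: // int4
		return "int32"
	case 25: // text
		return "string"
	default:
		return "invalid"
	}
}

// toSchema mirrors fieldDescriptionsToSchema: map each field's OID to a
// kind, then consult the catalog (stubbed via lookup) when it is invalid.
func toSchema(fds []pgconn.FieldDescription, lookup func(uint32) (string, error)) []string {
	out := make([]string, len(fds))
	for i, fd := range fds {
		kind := kindForOID(fd.DataTypeOID)
		if kind == "invalid" {
			if typeName, err := lookup(fd.DataTypeOID); err == nil {
				switch typeName {
				case "geometry", "geography":
					kind = typeName
				}
			}
		}
		out[i] = fmt.Sprintf("%s:%s", fd.Name, kind)
	}
	return out
}

func main() {
	fds := []pgconn.FieldDescription{
		{Name: "id", DataTypeOID: 23},
		{Name: "shape", DataTypeOID: 16385}, // hypothetical PostGIS geometry OID
	}
	lookup := func(uint32) (string, error) { return "geometry", nil }
	fmt.Println(toSchema(fds, lookup)) // [id:int32 shape:geometry]
}
```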
		if err != nil {
			return nil, fmt.Errorf("failed to parse field: %w", err)
		}
diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go
index b9c7dcc904..98b498362f 100644
--- a/flow/connectors/postgres/qvalue_convert.go
+++ b/flow/connectors/postgres/qvalue_convert.go
@@ -1,6 +1,7 @@
 package connpostgres
 
 import (
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -9,13 +10,14 @@ import (
 	"strings"
 	"time"
 
+	"github.com/PeerDB-io/peer-flow/connectors/utils"
 	"github.com/PeerDB-io/peer-flow/model/qvalue"
 	"github.com/jackc/pgx/v5/pgtype"
 	"github.com/lib/pq/oid"
 	log "github.com/sirupsen/logrus"
 )
 
-func postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind {
+func postgresOIDToQValueKind(recvOID uint32, connStr string) qvalue.QValueKind {
 	switch recvOID {
 	case pgtype.BoolOID:
 		return qvalue.QValueKindBoolean
@@ -55,6 +57,8 @@ func postgresOIDToQValueKind(recvOID uint32, connStr string) qvalue.QValueKind {
 		return qvalue.QValueKindArrayInt32
 	case pgtype.Int8ArrayOID:
 		return qvalue.QValueKindArrayInt64
+	case pgtype.PointOID:
+		return qvalue.QValueKindPoint
 	case pgtype.Float4ArrayOID:
 		return qvalue.QValueKindArrayFloat32
 	case pgtype.Float8ArrayOID:
@@ -77,9 +81,18 @@
 			return qvalue.QValueKindString
 		} else if recvOID == uint32(oid.T_tsquery) { // TSQUERY
 			return qvalue.QValueKindString
+		} else if recvOID == uint32(oid.T_point) {
+			return qvalue.QValueKindPoint
 		}
-		// log.Warnf("failed to get type name for oid: %v", recvOID)
-		return qvalue.QValueKindInvalid
+		ctx := context.Background()
+		ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
+		defer cancel()
+		qKind, err := utils.GetCustomDataType(ctx, connStr, recvOID)
+		if err != nil {
+			log.Warnf("failed to get type name for oid: %v", recvOID)
+			return qvalue.QValueKindInvalid
+		}
+		return qKind
 	} else {
 		log.Warnf("unsupported field type: %v - type name - %s; returning as string", recvOID, typeName.Name)
 		return qvalue.QValueKindString
@@ -337,6 +350,10 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (
 			return nil, fmt.Errorf("failed to parse hstore: %w", err)
 		}
 		val = &qvalue.QValue{Kind: qvalue.QValueKindHStore, Value: hstoreVal}
+	case qvalue.QValueKindPoint:
+		x_coord := value.(pgtype.Point).P.X
+		y_coord := value.(pgtype.Point).P.Y
+		val = &qvalue.QValue{Kind: qvalue.QValueKindPoint, Value: fmt.Sprintf("POINT(%f %f)", x_coord, y_coord)}
 	default:
 		// log.Warnf("unhandled QValueKind => %v, parsing as string", qvalueKind)
 		textVal, ok := value.(string)
@@ -353,8 +370,8 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (
 	return val, nil
 }
 
-func parseFieldFromPostgresOID(oid uint32, value interface{}) (*qvalue.QValue, error) {
-	return parseFieldFromQValueKind(postgresOIDToQValueKind(oid), value)
+func parseFieldFromPostgresOID(oid uint32, value interface{}, connStr string) (*qvalue.QValue, error) {
+	return parseFieldFromQValueKind(postgresOIDToQValueKind(oid, connStr), value)
 }
 
 func numericToRat(numVal *pgtype.Numeric) (*big.Rat, error) {
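The new `QValueKindPoint` branch renders a `pgtype.Point` as well-known text. A runnable snippet showing the same conversion (note that `%f` fixes six decimal places, which bounds coordinate precision):

```go
package main

import (
	"fmt"

	"github.com/jackc/pgx/v5/pgtype"
)

// pointToWKT renders a pgtype.Point as well-known text, mirroring the
// QValueKindPoint branch in parseFieldFromQValueKind.
func pointToWKT(p pgtype.Point) string {
	return fmt.Sprintf("POINT(%f %f)", p.P.X, p.P.Y)
}

func main() {
	pt := pgtype.Point{P: pgtype.Vec2{X: 1.5, Y: -2.25}, Valid: true}
	fmt.Println(pointToWKT(pt)) // POINT(1.500000 -2.250000)
}
```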
fmt.Errorf("multi-insert sync mode not supported for snowflake") case protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO: - allCols, err := c.getColsFromTable(destTable) + colInfo, err := c.getColsFromTable(destTable) if err != nil { log.WithFields(log.Fields{ "flowName": config.FlowJobName, @@ -261,6 +261,7 @@ func (c *SnowflakeConnector) ConsolidateQRepPartitions(config *protos.QRepConfig return fmt.Errorf("failed to get columns from table %s: %w", destTable, err) } + allCols := colInfo.Columns err = CopyStageToDestination(c, config, destTable, stageName, allCols) if err != nil { log.WithFields(log.Fields{ @@ -283,7 +284,7 @@ func (c *SnowflakeConnector) CleanupQRepFlow(config *protos.QRepConfig) error { return c.dropStage(config.StagingPath, config.FlowJobName) } -func (c *SnowflakeConnector) getColsFromTable(tableName string) ([]string, error) { +func (c *SnowflakeConnector) getColsFromTable(tableName string) (*model.ColumnInformation, error) { // parse the table name to get the schema and table name components, err := parseTableName(tableName) if err != nil { @@ -296,7 +297,7 @@ func (c *SnowflakeConnector) getColsFromTable(tableName string) ([]string, error //nolint:gosec queryString := fmt.Sprintf(` - SELECT column_name + SELECT column_name, data_type FROM information_schema.columns WHERE UPPER(table_name) = '%s' AND UPPER(table_schema) = '%s' `, components.tableIdentifier, components.schemaIdentifier) @@ -307,16 +308,24 @@ func (c *SnowflakeConnector) getColsFromTable(tableName string) ([]string, error } defer rows.Close() - var cols []string + columnMap := map[string]string{} for rows.Next() { - var col string - if err := rows.Scan(&col); err != nil { + var colName string + var colType string + if err := rows.Scan(&colName, &colType); err != nil { return nil, fmt.Errorf("failed to scan row: %w", err) } - cols = append(cols, col) + columnMap[colName] = colType + } + var cols []string + for k := range columnMap { + cols = append(cols, k) } - return cols, nil + return &model.ColumnInformation{ + ColumnMap: columnMap, + Columns: cols, + }, nil } // dropStage drops the stage for the given job. 
diff --git a/flow/connectors/snowflake/qrep_avro_sync.go b/flow/connectors/snowflake/qrep_avro_sync.go
index 71c7d74d9a..914fe3af61 100644
--- a/flow/connectors/snowflake/qrep_avro_sync.go
+++ b/flow/connectors/snowflake/qrep_avro_sync.go
@@ -18,6 +18,11 @@ import (
 	"go.temporal.io/sdk/activity"
 )
 
+type CopyInfo struct {
+	transformationSQL string
+	columnsSQL        string
+}
+
 type SnowflakeAvroSyncMethod struct {
 	config    *protos.QRepConfig
 	connector *SnowflakeConnector
@@ -73,11 +78,12 @@ func (s *SnowflakeAvroSyncMethod) SyncRecords(
 		"flowName": flowJobName,
 	}).Infof("Created stage %s", stage)
 
-	allCols, err := s.connector.getColsFromTable(s.config.DestinationTableIdentifier)
+	colInfo, err := s.connector.getColsFromTable(s.config.DestinationTableIdentifier)
 	if err != nil {
 		return 0, err
 	}
 
+	allCols := colInfo.Columns
 	err = s.putFileToStage(localFilePath, stage)
 	if err != nil {
 		return 0, err
@@ -251,6 +257,36 @@ func (s *SnowflakeAvroSyncMethod) putFileToStage(localFilePath string, stage str
 	return nil
 }
 
+func (sc *SnowflakeConnector) GetCopyTransformation(dstTableName string) (*CopyInfo, error) {
+	colInfo, colsErr := sc.getColsFromTable(dstTableName)
+	if colsErr != nil {
+		return nil, fmt.Errorf("failed to get columns from destination table: %w", colsErr)
+	}
+
+	var transformations []string
+	var columnOrder []string
+	for col, colType := range colInfo.ColumnMap {
+		if col == "_PEERDB_IS_DELETED" {
+			continue
+		}
+		colName := strings.ToLower(col)
+		columnOrder = append(columnOrder, colName)
+		switch colType {
+		case "GEOGRAPHY":
+			transformations = append(transformations,
+				fmt.Sprintf("TO_GEOGRAPHY($1:\"%s\"::string) AS \"%s\"", colName, colName))
+		case "GEOMETRY":
+			transformations = append(transformations,
+				fmt.Sprintf("TO_GEOMETRY($1:\"%s\"::string) AS \"%s\"", colName, colName))
+		default:
+			transformations = append(transformations, fmt.Sprintf("$1:%s AS %s", colName, colName))
+		}
+	}
+	transformationSQL := strings.Join(transformations, ",")
+	columnsSQL := strings.Join(columnOrder, ",")
+	return &CopyInfo{transformationSQL, columnsSQL}, nil
+}
+
 func CopyStageToDestination(
 	connector *SnowflakeConnector,
 	config *protos.QRepConfig,
@@ -263,7 +299,6 @@ func CopyStageToDestination(
 	}).Infof("Copying stage to destination %s", dstTableName)
 	copyOpts := []string{
 		"FILE_FORMAT = (TYPE = AVRO)",
-		"MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE'",
 		"PURGE = TRUE",
 		"ON_ERROR = 'CONTINUE'",
 	}
@@ -278,9 +313,13 @@ func CopyStageToDestination(
 		}
 	}
 
+	copyTransformation, err := connector.GetCopyTransformation(dstTableName)
+	if err != nil {
+		return fmt.Errorf("failed to get copy transformation: %w", err)
+	}
 	switch appendMode {
 	case true:
-		err := writeHandler.HandleAppendMode(config.FlowJobName)
+		err := writeHandler.HandleAppendMode(config.FlowJobName, copyTransformation)
 		if err != nil {
 			return fmt.Errorf("failed to handle append mode: %w", err)
 		}
@@ -288,7 +327,7 @@ func CopyStageToDestination(
 	case false:
 		upsertKeyCols := config.WriteMode.UpsertKeyColumns
 		err := writeHandler.HandleUpsertMode(allCols, upsertKeyCols, config.WatermarkColumn,
-			config.FlowJobName)
+			config.FlowJobName, copyTransformation)
 		if err != nil {
 			return fmt.Errorf("failed to handle upsert mode: %w", err)
 		}
@@ -348,9 +387,12 @@ func NewSnowflakeAvroWriteHandler(
 	}
 }
 
-func (s *SnowflakeAvroWriteHandler) HandleAppendMode(flowJobName string) error {
+func (s *SnowflakeAvroWriteHandler) HandleAppendMode(
+	flowJobName string,
+	copyInfo *CopyInfo) error {
 	//nolint:gosec
-	copyCmd := fmt.Sprintf("COPY INTO %s FROM @%s %s", s.dstTableName, s.stage, strings.Join(s.copyOpts, ","))
+	copyCmd := fmt.Sprintf("COPY INTO %s(%s) FROM (SELECT %s FROM @%s) %s",
+		s.dstTableName, copyInfo.columnsSQL, copyInfo.transformationSQL, s.stage, strings.Join(s.copyOpts, ","))
 	log.Infof("running copy command: %s", copyCmd)
 	_, err := s.connector.database.Exec(copyCmd)
 	if err != nil {
@@ -424,6 +466,7 @@ func (s *SnowflakeAvroWriteHandler) HandleUpsertMode(
 	upsertKeyCols []string,
 	watermarkCol string,
 	flowJobName string,
+	copyInfo *CopyInfo,
 ) error {
 	runID, err := util.RandomUInt64()
 	if err != nil {
@@ -443,8 +486,8 @@ func (s *SnowflakeAvroWriteHandler) HandleUpsertMode(
 	}).Infof("created temp table %s", tempTableName)
 
 	//nolint:gosec
-	copyCmd := fmt.Sprintf("COPY INTO %s FROM @%s %s",
-		tempTableName, s.stage, strings.Join(s.copyOpts, ","))
+	copyCmd := fmt.Sprintf("COPY INTO %s(%s) FROM (SELECT %s FROM @%s) %s",
+		tempTableName, copyInfo.columnsSQL, copyInfo.transformationSQL, s.stage, strings.Join(s.copyOpts, ","))
 	_, err = s.connector.database.Exec(copyCmd)
 	if err != nil {
 		return fmt.Errorf("failed to run COPY INTO command: %w", err)
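With `MATCH_BY_COLUMN_NAME` dropped, the COPY now names its columns explicitly and routes geospatial ones through `TO_GEOMETRY`/`TO_GEOGRAPHY`. The sketch below assembles the same statement that `GetCopyTransformation` plus `HandleAppendMode` produce together; `buildCopy` and the table/stage names are placeholders of mine:

```go
package main

import (
	"fmt"
	"strings"
)

// buildCopy sketches the combined effect of GetCopyTransformation and
// HandleAppendMode: an explicit column list, with geospatial columns
// converted from their Avro string form on the way in.
func buildCopy(dstTable, stage string, colTypes map[string]string, copyOpts []string) string {
	var transformations, columnOrder []string
	for col, colType := range colTypes {
		colName := strings.ToLower(col)
		columnOrder = append(columnOrder, colName)
		switch colType {
		case "GEOGRAPHY":
			transformations = append(transformations,
				fmt.Sprintf("TO_GEOGRAPHY($1:\"%s\"::string) AS \"%s\"", colName, colName))
		case "GEOMETRY":
			transformations = append(transformations,
				fmt.Sprintf("TO_GEOMETRY($1:\"%s\"::string) AS \"%s\"", colName, colName))
		default:
			transformations = append(transformations, fmt.Sprintf("$1:%s AS %s", colName, colName))
		}
	}
	return fmt.Sprintf("COPY INTO %s(%s) FROM (SELECT %s FROM @%s) %s",
		dstTable, strings.Join(columnOrder, ","),
		strings.Join(transformations, ","), stage, strings.Join(copyOpts, ","))
}

func main() {
	fmt.Println(buildCopy("public.dst", "peerdb_stage",
		map[string]string{"ID": "NUMBER", "SHAPE": "GEOMETRY"},
		[]string{"FILE_FORMAT = (TYPE = AVRO)", "PURGE = TRUE"}))
}
```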
diff --git a/flow/connectors/snowflake/qvalue_convert.go b/flow/connectors/snowflake/qvalue_convert.go
index 32b84dd07e..b88517856d 100644
--- a/flow/connectors/snowflake/qvalue_convert.go
+++ b/flow/connectors/snowflake/qvalue_convert.go
@@ -27,8 +27,11 @@ var qValueKindToSnowflakeTypeMap = map[qvalue.QValueKind]string{
 	qvalue.QValueKindTimeTZ:  "STRING",
 	qvalue.QValueKindInvalid: "STRING",
 	qvalue.QValueKindHStore:  "STRING",
+	qvalue.QValueKindGeography: "GEOGRAPHY",
+	qvalue.QValueKindGeometry:  "GEOMETRY",
+	qvalue.QValueKindPoint:     "GEOMETRY",
 
-	// array types will be mapped to STRING
+	// array types will be mapped to VARIANT
 	qvalue.QValueKindArrayFloat32: "VARIANT",
 	qvalue.QValueKindArrayFloat64: "VARIANT",
 	qvalue.QValueKindArrayInt32:   "VARIANT",
@@ -60,6 +63,8 @@ var snowflakeTypeToQValueKindMap = map[string]qvalue.QValueKind{
 	"DECIMAL": qvalue.QValueKindNumeric,
 	"NUMERIC": qvalue.QValueKindNumeric,
 	"VARIANT": qvalue.QValueKindJSON,
+	"GEOMETRY":  qvalue.QValueKindGeometry,
+	"GEOGRAPHY": qvalue.QValueKindGeography,
 }
 
 func qValueKindToSnowflakeType(colType qvalue.QValueKind) string {
diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go
index 7046746ecc..d691fef970 100644
--- a/flow/connectors/snowflake/snowflake.go
+++ b/flow/connectors/snowflake/snowflake.go
@@ -958,6 +958,9 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
 		case qvalue.QValueKindBytes, qvalue.QValueKindBit:
 			flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:\"%s\") "+
 				"AS %s,", toVariantColumnName, columnName, targetColumnName))
+		case qvalue.QValueKindGeography, qvalue.QValueKindGeometry, qvalue.QValueKindPoint:
+			flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS STRING) AS %s,",
+				toVariantColumnName, columnName, targetColumnName))
 		// TODO: https://github.com/PeerDB-io/peerdb/issues/189 - handle time types and interval types
 		// case model.ColumnTypeTime:
 		// 	flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+
diff --git a/flow/connectors/utils/postgres.go b/flow/connectors/utils/postgres.go
index 883cf36791..7a4e672df5 100644
--- a/flow/connectors/utils/postgres.go
+++ b/flow/connectors/utils/postgres.go
@@ -1,10 +1,13 @@
 package utils
 
 import (
+	"context"
 	"fmt"
 	"net/url"
 
 	"github.com/PeerDB-io/peer-flow/generated/protos"
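On the merge path, geospatial values held in the staging VARIANT column are cast back out as strings before landing in GEOMETRY/GEOGRAPHY columns, letting Snowflake re-parse the WKT on insert. A sketch of that cast generation — `buildGeoCast` and the `_peerdb_data` column name are illustrative stand-ins, not the connector's actual identifiers:

```go
package main

import (
	"fmt"
	"strings"
)

// buildGeoCast mirrors the new merge-statement branch: geospatial values
// arrive in the VARIANT staging column as strings and are cast back out.
func buildGeoCast(variantCol, srcCol, dstCol string) string {
	return fmt.Sprintf("CAST(%s:\"%s\" AS STRING) AS %s,", variantCol, srcCol, dstCol)
}

func main() {
	casts := []string{
		buildGeoCast("_peerdb_data", "shape", "SHAPE"),
		buildGeoCast("_peerdb_data", "region", "REGION"),
	}
	fmt.Println(strings.Join(casts, " "))
}
```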
"github.com/PeerDB-io/peer-flow/model/qvalue" + "github.com/jackc/pgx/v5/pgxpool" ) func GetPGConnectionString(pgConfig *protos.PostgresConfig) string { @@ -20,3 +23,30 @@ func GetPGConnectionString(pgConfig *protos.PostgresConfig) string { ) return connString } + +func GetCustomDataType(ctx context.Context, connStr string, dataType uint32) (qvalue.QValueKind, error) { + var typeName string + pool, err := pgxpool.New(ctx, connStr) + if err != nil { + return qvalue.QValueKindString, fmt.Errorf("failed to create postgres connection pool"+ + " for custom type handling: %w", err) + } + + defer pool.Close() + + err = pool.QueryRow(ctx, "SELECT typname FROM pg_type WHERE oid = $1", dataType).Scan(&typeName) + if err != nil { + return qvalue.QValueKindString, fmt.Errorf("failed to query pg_type for custom type handling: %w", err) + } + + var qValueKind qvalue.QValueKind + switch typeName { + case "geometry": + qValueKind = qvalue.QValueKindGeometry + case "geography": + qValueKind = qvalue.QValueKindGeography + default: + qValueKind = qvalue.QValueKindString + } + return qValueKind, nil +} diff --git a/flow/model/column.go b/flow/model/column.go new file mode 100644 index 0000000000..367b7d11f4 --- /dev/null +++ b/flow/model/column.go @@ -0,0 +1,6 @@ +package model + +type ColumnInformation struct { + ColumnMap map[string]string + Columns []string +} diff --git a/flow/model/conversion_avro.go b/flow/model/conversion_avro.go index 3c4ba07076..f312a1a109 100644 --- a/flow/model/conversion_avro.go +++ b/flow/model/conversion_avro.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/PeerDB-io/peer-flow/model/qvalue" + log "github.com/sirupsen/logrus" ) type QRecordAvroConverter struct { @@ -75,6 +76,7 @@ func GetAvroSchemaDefinition( nullableFields := map[string]bool{} for _, qField := range qRecordSchema.Fields { + log.Infof("qField name: %s, qField type: %s", qField.Name, qField.Type) avroType, err := qvalue.GetAvroSchemaFromQValueKind(qField.Type, qField.Nullable) if err != nil { return nil, err diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go index 62fefe698d..5437aed052 100644 --- a/flow/model/qvalue/avro_converter.go +++ b/flow/model/qvalue/avro_converter.go @@ -36,6 +36,10 @@ func GetAvroSchemaFromQValueKind(kind QValueKind, nullable bool) (*QValueKindAvr return &QValueKindAvroSchema{ AvroLogicalSchema: "string", }, nil + case QValueKindGeometry, QValueKindGeography, QValueKindPoint: + return &QValueKindAvroSchema{ + AvroLogicalSchema: "string", + }, nil case QValueKindInt16, QValueKindInt32, QValueKindInt64: return &QValueKindAvroSchema{ AvroLogicalSchema: "long", @@ -202,6 +206,8 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { return c.processArrayString() case QValueKindUUID: return c.processUUID() + case QValueKindGeography, QValueKindGeometry, QValueKindPoint: + return c.processGeospatial() default: return nil, fmt.Errorf("[toavro] unsupported QValueKind: %s", c.Value.Kind) } @@ -330,6 +336,19 @@ func (c *QValueAvroConverter) processUUID() (interface{}, error) { return uuidString, nil } +func (c *QValueAvroConverter) processGeospatial() (interface{}, error) { + if c.Value.Value == nil { + return nil, nil + } + + geoString, ok := c.Value.Value.(string) + if !ok { + return nil, fmt.Errorf("[conversion] invalid geospatial value %v", c.Value.Value) + } + + return geoString, nil +} + func (c *QValueAvroConverter) processArrayInt32() (interface{}, error) { if c.Value.Value == nil && c.Nullable { return nil, nil diff --git 
diff --git a/flow/model/qvalue/kind.go b/flow/model/qvalue/kind.go
index 1def708611..1e8e5f5099 100644
--- a/flow/model/qvalue/kind.go
+++ b/flow/model/qvalue/kind.go
@@ -23,6 +23,9 @@ const (
 	QValueKindJSON    QValueKind = "json"
 	QValueKindBit     QValueKind = "bit"
 	QValueKindHStore  QValueKind = "hstore"
+	QValueKindGeography QValueKind = "geography"
+	QValueKindGeometry  QValueKind = "geometry"
+	QValueKindPoint     QValueKind = "point"
 
 	// array types
 	QValueKindArrayFloat32 QValueKind = "array_float32"
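A quick check that the three new kinds round-trip through the Avro schema helper — per the avro_converter.go change above, each maps to a plain Avro string:

```go
package main

import (
	"fmt"

	"github.com/PeerDB-io/peer-flow/model/qvalue"
)

func main() {
	for _, kind := range []qvalue.QValueKind{
		qvalue.QValueKindGeography,
		qvalue.QValueKindGeometry,
		qvalue.QValueKindPoint,
	} {
		schema, err := qvalue.GetAvroSchemaFromQValueKind(kind, false)
		if err != nil {
			panic(err)
		}
		fmt.Printf("%s -> %v\n", kind, schema.AvroLogicalSchema) // all map to "string"
	}
}
```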