v1 of postgis support
Amogh-Bharadwaj committed Oct 15, 2023
1 parent 58aa14b commit b00e32f
Showing 14 changed files with 199 additions and 59 deletions.
29 changes: 6 additions & 23 deletions flow/connectors/postgres/cdc.go
@@ -30,6 +30,7 @@ type PostgresCDCSource struct {
 	typeMap    *pgtype.Map
 	startLSN   pglogrepl.LSN
 	commitLock bool
+	ConnStr    string
 }

@@ -40,6 +41,7 @@ type PostgresCDCConfig struct {
 	SrcTableIDNameMapping  map[uint32]string
 	TableNameMapping       map[string]string
 	RelationMessageMapping model.RelationMessageMapping
+	ConnStr                string
 }

 // Create a new PostgresCDCSource
@@ -54,6 +56,7 @@ func NewPostgresCDCSource(cdcConfig *PostgresCDCConfig) (*PostgresCDCSource, err
 		relationMessageMapping: cdcConfig.RelationMessageMapping,
 		typeMap:                pgtype.NewMap(),
 		commitLock:             false,
+		ConnStr:                cdcConfig.ConnStr,
 	}, nil
 }

@@ -515,37 +518,17 @@ func (p *PostgresCDCSource) decodeColumnData(data []byte, dataType uint32, forma
 		if err != nil {
 			return nil, err
 		}
-		retVal, err := parseFieldFromPostgresOID(dataType, parsedData)
+		retVal, err := parseFieldFromPostgresOID(dataType, parsedData, p.ConnStr)
 		if err != nil {
 			return nil, err
 		}
 		return retVal, nil
 	} else if dataType == uint32(oid.T_timetz) { // ugly TIMETZ workaround for CDC decoding.
-		retVal, err := parseFieldFromPostgresOID(dataType, string(data))
+		retVal, err := parseFieldFromPostgresOID(dataType, string(data), p.ConnStr)
 		if err != nil {
 			return nil, err
 		}
 		return retVal, nil
-	} else { // For custom types, let's identify with schema information
-		var typeName string
-		res, err := p.replPool.Query(p.ctx, "SELECT typname FROM pg_type WHERE oid = $1", dataType)
-		if err != nil {
-			return nil, fmt.Errorf("error querying type name for column: %w", err)
-		}
-		scanErr := res.Scan(&typeName)
-		if err != nil {
-			return nil, fmt.Errorf("error scanning type name: %w", scanErr)
-		}
-		fmt.Println("datatype: ", dataType)
-		fmt.Println("typeName: ", typeName)
-		// POSTGIS and HSTORE support
-		switch typeName {
-		case "geometry":
-
-		case "geography":
-		case "hstore":
-		case "point":
-		}
 	}
 	return &qvalue.QValue{Kind: qvalue.QValueKindString, Value: string(data)}, nil
 }

@@ -599,7 +582,7 @@ func (p *PostgresCDCSource) processRelationMessage(
 		if prevRelMap[column.Name] == nil {
 			schemaDelta.AddedColumns = append(schemaDelta.AddedColumns, &protos.DeltaAddedColumn{
 				ColumnName: column.Name,
-				ColumnType: string(postgresOIDToQValueKind(column.DataType)),
+				ColumnType: string(postgresOIDToQValueKind(column.DataType, p.ConnStr)),
 			})
 			// present in previous and current relation messages, but data types have changed.
 			// so we add it to AddedColumns and DroppedColumns, knowing that we process DroppedColumns first.
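Note (editor's aside, not part of the commit): the reason a connection string now has to travel with the CDC source is that pgtype's built-in registry only knows core Postgres types. Extension types such as PostGIS's geometry and geography receive database-specific OIDs at CREATE EXTENSION time, so classifying them requires a live pg_type lookup. A minimal illustration, with a made-up OID:

package main

import (
    "fmt"

    "github.com/jackc/pgx/v5/pgtype"
)

func main() {
    tm := pgtype.NewMap()
    // 16390 is a made-up example OID; PostGIS OIDs differ per database,
    // which is why they can never appear in pgtype's static map.
    if _, ok := tm.TypeForOID(16390); !ok {
        fmt.Println("unknown OID: fall back to querying pg_type over a real connection")
    }
}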
3 changes: 2 additions & 1 deletion flow/connectors/postgres/postgres.go
@@ -217,6 +217,7 @@ func (c *PostgresConnector) PullRecords(req *model.PullRecordsRequest) (*model.R
 		Publication:            publicationName,
 		TableNameMapping:       req.TableNameMapping,
 		RelationMessageMapping: req.RelationMessageMapping,
+		ConnStr:                c.connStr,
 	})
 	if err != nil {
 		return nil, fmt.Errorf("failed to create cdc source: %w", err)

@@ -588,7 +589,7 @@ func (c *PostgresConnector) getTableSchemaForTable(
 	defer rows.Close()

 	for _, fieldDescription := range rows.FieldDescriptions() {
-		genericColType := postgresOIDToQValueKind(fieldDescription.DataTypeOID)
+		genericColType := postgresOIDToQValueKind(fieldDescription.DataTypeOID, c.connStr)
 		if genericColType == qvalue.QValueKindInvalid {
 			// we use string for invalid types
 			genericColType = qvalue.QValueKindString
8 changes: 4 additions & 4 deletions flow/connectors/postgres/qrep.go
@@ -291,7 +291,7 @@ func (c *PostgresConnector) PullQRepRecords(
 		"partitionId": partition.PartitionId,
 	}).Infof("pulling full table partition for flow job %s", config.FlowJobName)
 	executor := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot,
-		config.FlowJobName, partition.PartitionId)
+		config.FlowJobName, partition.PartitionId, c.connStr)
 	query := config.Query
 	return executor.ExecuteAndProcessQuery(query)
 }

@@ -337,7 +337,7 @@ func (c *PostgresConnector) PullQRepRecords(
 	}

 	executor := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot,
-		config.FlowJobName, partition.PartitionId)
+		config.FlowJobName, partition.PartitionId, c.connStr)
 	records, err := executor.ExecuteAndProcessQuery(query,
 		rangeStart, rangeEnd)
 	if err != nil {

@@ -363,7 +363,7 @@ func (c *PostgresConnector) PullQRepRecordStream(
 		"partitionId": partition.PartitionId,
 	}).Infof("pulling full table partition for flow job %s", config.FlowJobName)
 	executor := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot,
-		config.FlowJobName, partition.PartitionId)
+		config.FlowJobName, partition.PartitionId, c.connStr)
 	query := config.Query
 	_, err := executor.ExecuteAndProcessQueryStream(stream, query)
 	return 0, err

@@ -410,7 +410,7 @@ func (c *PostgresConnector) PullQRepRecordStream(
 	}

 	executor := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot,
-		config.FlowJobName, partition.PartitionId)
+		config.FlowJobName, partition.PartitionId, c.connStr)
 	numRecords, err := executor.ExecuteAndProcessQueryStream(stream, query, rangeStart, rangeEnd)
 	if err != nil {
 		return 0, err
37 changes: 28 additions & 9 deletions flow/connectors/postgres/qrep_query_executor.go
@@ -7,6 +7,7 @@ import (

 	"github.com/PeerDB-io/peer-flow/connectors/utils"
 	"github.com/PeerDB-io/peer-flow/model"
+	"github.com/PeerDB-io/peer-flow/model/qvalue"
 	"github.com/PeerDB-io/peer-flow/shared"
 	util "github.com/PeerDB-io/peer-flow/utils"
 	"github.com/jackc/pgx/v5"

@@ -23,6 +24,7 @@ type QRepQueryExecutor struct {
 	testEnv     bool
 	flowJobName string
 	partitionID string
+	connStr     string
 }

 func NewQRepQueryExecutor(pool *pgxpool.Pool, ctx context.Context,
@@ -37,7 +39,7 @@ func NewQRepQueryExecutor(pool *pgxpool.Pool, ctx context.Context,
 }

 func NewQRepQueryExecutorSnapshot(pool *pgxpool.Pool, ctx context.Context, snapshot string,
-	flowJobName string, partitionID string) *QRepQueryExecutor {
+	flowJobName string, partitionID string, connStr string) *QRepQueryExecutor {
 	log.WithFields(log.Fields{
 		"flowName":    flowJobName,
 		"partitionID": partitionID,
@@ -48,6 +50,7 @@ func NewQRepQueryExecutorSnapshot(pool *pgxpool.Pool, ctx context.Context, snaps
 		snapshot:    snapshot,
 		flowJobName: flowJobName,
 		partitionID: partitionID,
+		connStr:     connStr,
 	}
 }

@@ -89,11 +92,27 @@ func (qe *QRepQueryExecutor) executeQueryInTx(tx pgx.Tx, cursorName string, fetc
 }

 // FieldDescriptionsToSchema converts a slice of pgconn.FieldDescription to a QRecordSchema.
-func fieldDescriptionsToSchema(fds []pgconn.FieldDescription) *model.QRecordSchema {
+func (qe *QRepQueryExecutor) fieldDescriptionsToSchema(fds []pgconn.FieldDescription) *model.QRecordSchema {
 	qfields := make([]*model.QField, len(fds))
 	for i, fd := range fds {
 		cname := fd.Name
-		ctype := postgresOIDToQValueKind(fd.DataTypeOID)
+		ctype := postgresOIDToQValueKind(fd.DataTypeOID, qe.connStr)
+		if ctype == qvalue.QValueKindInvalid {
+			var typeName string
+			err := qe.pool.QueryRow(qe.ctx, "SELECT typname FROM pg_type WHERE oid = $1", fd.DataTypeOID).Scan(&typeName)
+			if err != nil {
+				ctype = qvalue.QValueKindInvalid
+			} else {
+				switch typeName {
+				case "geometry":
+					ctype = qvalue.QValueKindGeometry
+				case "geography":
+					ctype = qvalue.QValueKindGeography
+				default:
+					ctype = qvalue.QValueKindInvalid
+				}
+			}
+		}
 		// there isn't a way to know if a column is nullable or not
 		// TODO fix this.
 		cnullable := true

@@ -118,7 +137,7 @@ func (qe *QRepQueryExecutor) ProcessRows(
 	}).Info("Processing rows")
 	// Iterate over the rows
 	for rows.Next() {
-		record, err := mapRowToQRecord(rows, fieldDescriptions)
+		record, err := mapRowToQRecord(rows, fieldDescriptions, qe.connStr)
 		if err != nil {
 			return nil, fmt.Errorf("failed to map row to QRecord: %w", err)
 		}

@@ -133,7 +152,7 @@
 	batch := &model.QRecordBatch{
 		NumRecords: uint32(len(records)),
 		Records:    records,
-		Schema:     fieldDescriptionsToSchema(fieldDescriptions),
+		Schema:     qe.fieldDescriptionsToSchema(fieldDescriptions),
 	}

 	log.WithFields(log.Fields{

@@ -155,7 +174,7 @@ func (qe *QRepQueryExecutor) processRowsStream(

 	// Iterate over the rows
 	for rows.Next() {
-		record, err := mapRowToQRecord(rows, fieldDescriptions)
+		record, err := mapRowToQRecord(rows, fieldDescriptions, qe.connStr)
 		if err != nil {
 			stream.Records <- &model.QRecordOrError{
 				Err: fmt.Errorf("failed to map row to QRecord: %w", err),

@@ -214,7 +233,7 @@ func (qe *QRepQueryExecutor) processFetchedRows(

 	fieldDescriptions := rows.FieldDescriptions()
 	if !stream.IsSchemaSet() {
-		schema := fieldDescriptionsToSchema(fieldDescriptions)
+		schema := qe.fieldDescriptionsToSchema(fieldDescriptions)
 		_ = stream.SetSchema(schema)
 	}

@@ -395,7 +414,7 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStream(
 	return totalRecordsFetched, nil
 }

-func mapRowToQRecord(row pgx.Rows, fds []pgconn.FieldDescription) (*model.QRecord, error) {
+func mapRowToQRecord(row pgx.Rows, fds []pgconn.FieldDescription, connStr string) (*model.QRecord, error) {
 	// make vals an empty array of QValue of size len(fds)
 	record := model.NewQRecord(len(fds))

@@ -405,7 +424,7 @@ func mapRowToQRecord(row pgx.Rows, fds []pgconn.FieldDescription) (*model.QRecor
 	}

 	for i, fd := range fds {
-		tmp, err := parseFieldFromPostgresOID(fd.DataTypeOID, values[i])
+		tmp, err := parseFieldFromPostgresOID(fd.DataTypeOID, values[i], connStr)
 		if err != nil {
 			return nil, fmt.Errorf("failed to parse field: %w", err)
 		}
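Note (editor's aside, not part of the commit): the qvalue.QValueKindGeometry, QValueKindGeography, and QValueKindPoint kinds referenced in this file and the next are defined in flow/model/qvalue, one of the 14 changed files not displayed on this page. Assuming QValueKind remains the string-typed enum used elsewhere in this codebase, the additions presumably look like the sketch below; the literal string values are guesses.

// Sketch only; the real definitions live in flow/model/qvalue.
const (
    QValueKindGeometry  QValueKind = "geometry"
    QValueKindGeography QValueKind = "geography"
    QValueKindPoint     QValueKind = "point"
)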
27 changes: 22 additions & 5 deletions flow/connectors/postgres/qvalue_convert.go
@@ -1,6 +1,7 @@
 package connpostgres

 import (
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -9,13 +10,14 @@ import (
 	"strings"
 	"time"

+	"github.com/PeerDB-io/peer-flow/connectors/utils"
 	"github.com/PeerDB-io/peer-flow/model/qvalue"
 	"github.com/jackc/pgx/v5/pgtype"
 	"github.com/lib/pq/oid"
 	log "github.com/sirupsen/logrus"
 )

-func postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind {
+func postgresOIDToQValueKind(recvOID uint32, connStr string) qvalue.QValueKind {
 	switch recvOID {
 	case pgtype.BoolOID:
 		return qvalue.QValueKindBoolean

@@ -55,6 +57,8 @@ func postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind {
 		return qvalue.QValueKindArrayInt32
 	case pgtype.Int8ArrayOID:
 		return qvalue.QValueKindArrayInt64
+	case pgtype.PointOID:
+		return qvalue.QValueKindPoint
 	case pgtype.Float4ArrayOID:
 		return qvalue.QValueKindArrayFloat32
 	case pgtype.Float8ArrayOID:

@@ -77,9 +81,18 @@ func postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind {
 			return qvalue.QValueKindString
 		} else if recvOID == uint32(oid.T_tsquery) { // TSQUERY
 			return qvalue.QValueKindString
+		} else if recvOID == uint32(oid.T_point) {
+			return qvalue.QValueKindPoint
 		}
-		// log.Warnf("failed to get type name for oid: %v", recvOID)
-		return qvalue.QValueKindInvalid
+		ctx := context.Background()
+		ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
+		defer cancel()
+		qKind, err := utils.GetCustomDataType(ctx, connStr, recvOID)
+		if err != nil {
+			log.Warnf("failed to get type name for oid: %v", recvOID)
+			return qvalue.QValueKindInvalid
+		}
+		return qKind
 	} else {
 		log.Warnf("unsupported field type: %v - type name - %s; returning as string", recvOID, typeName.Name)
 		return qvalue.QValueKindString

@@ -337,6 +350,10 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (
 			return nil, fmt.Errorf("failed to parse hstore: %w", err)
 		}
 		val = &qvalue.QValue{Kind: qvalue.QValueKindHStore, Value: hstoreVal}
+	case qvalue.QValueKindPoint:
+		x_coord := value.(pgtype.Point).P.X
+		y_coord := value.(pgtype.Point).P.Y
+		val = &qvalue.QValue{Kind: qvalue.QValueKindPoint, Value: fmt.Sprintf("POINT(%f %f)", x_coord, y_coord)}
 	default:
 		// log.Warnf("unhandled QValueKind => %v, parsing as string", qvalueKind)
 		textVal, ok := value.(string)
@@ -353,8 +370,8 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (
 	return val, nil
 }

-func parseFieldFromPostgresOID(oid uint32, value interface{}) (*qvalue.QValue, error) {
-	return parseFieldFromQValueKind(postgresOIDToQValueKind(oid), value)
+func parseFieldFromPostgresOID(oid uint32, value interface{}, connStr string) (*qvalue.QValue, error) {
+	return parseFieldFromQValueKind(postgresOIDToQValueKind(oid, connStr), value)
 }

 func numericToRat(numVal *pgtype.Numeric) (*big.Rat, error) {
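Note (editor's aside, not part of the commit): utils.GetCustomDataType, called in the hunk above, lives in flow/connectors/utils and is not displayed on this page. Based solely on how it is called here (a context bounded by a five-second timeout, a connection string, and an OID, returning a qvalue.QValueKind and an error), a plausible sketch follows; the body is a hypothetical reconstruction, not the committed code.

package utils

import (
    "context"
    "fmt"

    "github.com/PeerDB-io/peer-flow/model/qvalue"
    "github.com/jackc/pgx/v5"
)

// GetCustomDataType resolves an OID that pgtype does not recognize by asking
// the source database for its catalog name, then mapping PostGIS type names
// onto qvalue kinds. Hypothetical sketch; the real implementation is unshown.
func GetCustomDataType(ctx context.Context, connStr string, recvOID uint32) (qvalue.QValueKind, error) {
    conn, err := pgx.Connect(ctx, connStr)
    if err != nil {
        return qvalue.QValueKindInvalid, fmt.Errorf("failed to connect: %w", err)
    }
    defer conn.Close(ctx)

    var typeName string
    err = conn.QueryRow(ctx, "SELECT typname FROM pg_type WHERE oid = $1", recvOID).Scan(&typeName)
    if err != nil {
        return qvalue.QValueKindInvalid, fmt.Errorf("failed to query type name for oid %d: %w", recvOID, err)
    }

    // Anything other than the PostGIS types stays invalid, so callers fall
    // back to their string handling.
    switch typeName {
    case "geometry":
        return qvalue.QValueKindGeometry, nil
    case "geography":
        return qvalue.QValueKindGeography, nil
    default:
        return qvalue.QValueKindInvalid, nil
    }
}

Opening a fresh connection per unknown OID is simple but repetitive; since type OIDs are stable for the life of a database, results could in principle be memoized per OID (an observation, not something this commit does).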
25 changes: 17 additions & 8 deletions flow/connectors/snowflake/qrep.go
@@ -253,14 +253,15 @@ func (c *SnowflakeConnector) ConsolidateQRepPartitions(config *protos.QRepConfig
 	case protos.QRepSyncMode_QREP_SYNC_MODE_MULTI_INSERT:
 		return fmt.Errorf("multi-insert sync mode not supported for snowflake")
 	case protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO:
-		allCols, err := c.getColsFromTable(destTable)
+		colInfo, err := c.getColsFromTable(destTable)
 		if err != nil {
 			log.WithFields(log.Fields{
 				"flowName": config.FlowJobName,
 			}).Errorf("failed to get columns from table %s: %v", destTable, err)
 			return fmt.Errorf("failed to get columns from table %s: %w", destTable, err)
 		}

+		allCols := colInfo.Columns
 		err = CopyStageToDestination(c, config, destTable, stageName, allCols)
 		if err != nil {
 			log.WithFields(log.Fields{

@@ -283,7 +284,7 @@ func (c *SnowflakeConnector) CleanupQRepFlow(config *protos.QRepConfig) error {
 	return c.dropStage(config.StagingPath, config.FlowJobName)
 }

-func (c *SnowflakeConnector) getColsFromTable(tableName string) ([]string, error) {
+func (c *SnowflakeConnector) getColsFromTable(tableName string) (*model.ColumnInformation, error) {
 	// parse the table name to get the schema and table name
 	components, err := parseTableName(tableName)
 	if err != nil {

@@ -296,7 +297,7 @@ func (c *SnowflakeConnector) getColsFromTable(tableName string) ([]string, error

 	//nolint:gosec
 	queryString := fmt.Sprintf(`
-		SELECT column_name
+		SELECT column_name, data_type
 		FROM information_schema.columns
 		WHERE UPPER(table_name) = '%s' AND UPPER(table_schema) = '%s'
 	`, components.tableIdentifier, components.schemaIdentifier)

@@ -307,16 +308,24 @@ func (c *SnowflakeConnector) getColsFromTable(tableName string) ([]string, error
 	}
 	defer rows.Close()

-	var cols []string
+	columnMap := map[string]string{}
 	for rows.Next() {
-		var col string
-		if err := rows.Scan(&col); err != nil {
+		var colName string
+		var colType string
+		if err := rows.Scan(&colName, &colType); err != nil {
 			return nil, fmt.Errorf("failed to scan row: %w", err)
 		}
-		cols = append(cols, col)
+		columnMap[colName] = colType
 	}
+	var cols []string
+	for k := range columnMap {
+		cols = append(cols, k)
+	}

-	return cols, nil
+	return &model.ColumnInformation{
+		ColumnMap: columnMap,
+		Columns:   cols,
+	}, nil
 }

 // dropStage drops the stage for the given job.
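Note (editor's aside, not part of the commit): model.ColumnInformation, the new return type of getColsFromTable, is defined in the flow model package, another of the changed files not displayed here. Its shape is fully determined by the usage above; presumably:

// Inferred from usage in getColsFromTable; the real definition is not shown
// in this diff.
type ColumnInformation struct {
    // ColumnMap maps column name to its Snowflake data type.
    ColumnMap map[string]string
    // Columns lists the column names.
    Columns []string
}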
(The remaining 8 of the 14 changed files are not shown.)
