From 5fb024f5670bb00d3f56fc5d5b7df88bce168db3 Mon Sep 17 00:00:00 2001 From: Amogh Bharadwaj Date: Tue, 17 Oct 2023 01:04:04 +0530 Subject: [PATCH 1/3] Geospatial support for Snowflake (#516) - Supports GEOGRAPHY and GEOMETRY of PostGIS from PostgreSQL to Snowflake - both CDC and QRep - Supports Postgres' POINT data type to Snowflake - Adds unique timestamp in prefix for S3 test --- .github/workflows/flow.yml | 2 +- flow/connectors/postgres/cdc.go | 19 ++++- flow/connectors/postgres/postgres.go | 27 +++++-- flow/connectors/postgres/qrep.go | 25 +++++-- .../postgres/qrep_query_executor.go | 75 +++++++++++++------ flow/connectors/postgres/qvalue_convert.go | 24 +++++- flow/connectors/snowflake/qrep.go | 25 +++++-- flow/connectors/snowflake/qrep_avro_sync.go | 69 +++++++++++++++-- flow/connectors/snowflake/qvalue_convert.go | 7 +- flow/connectors/snowflake/snowflake.go | 3 + flow/connectors/utils/postgres.go | 27 +++++++ flow/e2e/bigquery/qrep_flow_bq_test.go | 28 ------- flow/e2e/s3/s3_helper.go | 13 ++-- flow/e2e/snowflake/peer_flow_sf_test.go | 20 +++-- flow/e2e/snowflake/qrep_flow_sf_test.go | 29 ------- flow/e2e/test_utils.go | 33 ++++++-- flow/model/column.go | 8 ++ flow/model/qvalue/avro_converter.go | 22 ++++++ flow/model/qvalue/kind.go | 3 + flow/workflows/cdc_flow.go | 1 - 20 files changed, 329 insertions(+), 131 deletions(-) create mode 100644 flow/model/column.go diff --git a/.github/workflows/flow.yml b/.github/workflows/flow.yml index ffceae7c53..1c3a385d3f 100644 --- a/.github/workflows/flow.yml +++ b/.github/workflows/flow.yml @@ -16,7 +16,7 @@ jobs: timeout-minutes: 30 services: pg_cdc: - image: postgres:15.4-alpine + image: postgis/postgis:15-3.4-alpine ports: - 7132:5432 env: diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index 81e9ef790d..c20f0f20d6 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -30,6 +30,7 @@ type PostgresCDCSource struct { typeMap *pgtype.Map startLSN pglogrepl.LSN commitLock bool + customTypeMapping map[uint32]string } type PostgresCDCConfig struct { @@ -43,7 +44,7 @@ type PostgresCDCConfig struct { } // Create a new PostgresCDCSource -func NewPostgresCDCSource(cdcConfig *PostgresCDCConfig) (*PostgresCDCSource, error) { +func NewPostgresCDCSource(cdcConfig *PostgresCDCConfig, customTypeMap map[uint32]string) (*PostgresCDCSource, error) { return &PostgresCDCSource{ ctx: cdcConfig.AppContext, replPool: cdcConfig.Connection, @@ -54,6 +55,7 @@ func NewPostgresCDCSource(cdcConfig *PostgresCDCConfig) (*PostgresCDCSource, err relationMessageMapping: cdcConfig.RelationMessageMapping, typeMap: pgtype.NewMap(), commitLock: false, + customTypeMapping: customTypeMap, }, nil } @@ -527,6 +529,12 @@ func (p *PostgresCDCSource) decodeColumnData(data []byte, dataType uint32, forma } return retVal, nil } + typeName, ok := p.customTypeMapping[dataType] + if ok { + return &qvalue.QValue{Kind: customTypeToQKind(typeName), + Value: string(data)}, nil + } + return &qvalue.QValue{Kind: qvalue.QValueKindString, Value: string(data)}, nil } @@ -577,9 +585,16 @@ func (p *PostgresCDCSource) processRelationMessage( for _, column := range currRel.Columns { // not present in previous relation message, but in current one, so added. if prevRelMap[column.Name] == nil { + qKind := postgresOIDToQValueKind(column.DataType) + if qKind == qvalue.QValueKindInvalid { + typeName, ok := p.customTypeMapping[column.DataType] + if ok { + qKind = customTypeToQKind(typeName) + } + } schemaDelta.AddedColumns = append(schemaDelta.AddedColumns, &protos.DeltaAddedColumn{ ColumnName: column.Name, - ColumnType: string(postgresOIDToQValueKind(column.DataType)), + ColumnType: string(qKind), }) // present in previous and current relation messages, but data types have changed. // so we add it to AddedColumns and DroppedColumns, knowing that we process DroppedColumns first. diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 9725a6a2af..e7588a74df 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -31,6 +31,7 @@ type PostgresConnector struct { pool *pgxpool.Pool replPool *pgxpool.Pool tableSchemaMapping map[string]*protos.TableSchema + customTypesMapping map[uint32]string } // SchemaTable is a table in a schema. @@ -56,6 +57,11 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) return nil, fmt.Errorf("failed to create connection pool: %w", err) } + customTypeMap, err := utils.GetCustomDataTypes(ctx, pool) + if err != nil { + return nil, fmt.Errorf("failed to get custom type map: %w", err) + } + // ensure that replication is set to database connConfig, err := pgxpool.ParseConfig(connectionString) if err != nil { @@ -72,11 +78,12 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) } return &PostgresConnector{ - connStr: connectionString, - ctx: ctx, - config: pgConfig, - pool: pool, - replPool: replPool, + connStr: connectionString, + ctx: ctx, + config: pgConfig, + pool: pool, + replPool: replPool, + customTypesMapping: customTypeMap, }, nil } @@ -217,7 +224,7 @@ func (c *PostgresConnector) PullRecords(req *model.PullRecordsRequest) (*model.R Publication: publicationName, TableNameMapping: req.TableNameMapping, RelationMessageMapping: req.RelationMessageMapping, - }) + }, c.customTypesMapping) if err != nil { return nil, fmt.Errorf("failed to create cdc source: %w", err) } @@ -590,8 +597,12 @@ func (c *PostgresConnector) getTableSchemaForTable( for _, fieldDescription := range rows.FieldDescriptions() { genericColType := postgresOIDToQValueKind(fieldDescription.DataTypeOID) if genericColType == qvalue.QValueKindInvalid { - // we use string for invalid types - genericColType = qvalue.QValueKindString + typeName, ok := c.customTypesMapping[fieldDescription.DataTypeOID] + if ok { + genericColType = customTypeToQKind(typeName) + } else { + genericColType = qvalue.QValueKindString + } } res.Columns[fieldDescription.Name] = string(genericColType) diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go index 3c0d3124b8..830ff453bf 100644 --- a/flow/connectors/postgres/qrep.go +++ b/flow/connectors/postgres/qrep.go @@ -290,8 +290,11 @@ func (c *PostgresConnector) PullQRepRecords( log.WithFields(log.Fields{ "partitionId": partition.PartitionId, }).Infof("pulling full table partition for flow job %s", config.FlowJobName) - executor := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot, + executor, err := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) + if err != nil { + return nil, err + } query := config.Query return executor.ExecuteAndProcessQuery(query) } @@ -336,8 +339,12 @@ func (c *PostgresConnector) PullQRepRecords( return nil, err } - executor := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot, + executor, err := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) + if err != nil { + return nil, err + } + records, err := executor.ExecuteAndProcessQuery(query, rangeStart, rangeEnd) if err != nil { @@ -362,10 +369,14 @@ func (c *PostgresConnector) PullQRepRecordStream( "flowName": config.FlowJobName, "partitionId": partition.PartitionId, }).Infof("pulling full table partition for flow job %s", config.FlowJobName) - executor := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot, + executor, err := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) + if err != nil { + return 0, err + } + query := config.Query - _, err := executor.ExecuteAndProcessQueryStream(stream, query) + _, err = executor.ExecuteAndProcessQueryStream(stream, query) return 0, err } log.WithFields(log.Fields{ @@ -409,8 +420,12 @@ func (c *PostgresConnector) PullQRepRecordStream( return 0, err } - executor := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot, + executor, err := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) + if err != nil { + return 0, err + } + numRecords, err := executor.ExecuteAndProcessQueryStream(stream, query, rangeStart, rangeEnd) if err != nil { return 0, err diff --git a/flow/connectors/postgres/qrep_query_executor.go b/flow/connectors/postgres/qrep_query_executor.go index 5e2ef59d59..94943132bb 100644 --- a/flow/connectors/postgres/qrep_query_executor.go +++ b/flow/connectors/postgres/qrep_query_executor.go @@ -7,6 +7,7 @@ import ( "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/PeerDB-io/peer-flow/shared" util "github.com/PeerDB-io/peer-flow/utils" "github.com/jackc/pgx/v5" @@ -17,12 +18,13 @@ import ( ) type QRepQueryExecutor struct { - pool *pgxpool.Pool - ctx context.Context - snapshot string - testEnv bool - flowJobName string - partitionID string + pool *pgxpool.Pool + ctx context.Context + snapshot string + testEnv bool + flowJobName string + partitionID string + customTypeMap map[uint32]string } func NewQRepQueryExecutor(pool *pgxpool.Pool, ctx context.Context, @@ -37,18 +39,23 @@ func NewQRepQueryExecutor(pool *pgxpool.Pool, ctx context.Context, } func NewQRepQueryExecutorSnapshot(pool *pgxpool.Pool, ctx context.Context, snapshot string, - flowJobName string, partitionID string) *QRepQueryExecutor { + flowJobName string, partitionID string) (*QRepQueryExecutor, error) { log.WithFields(log.Fields{ "flowName": flowJobName, "partitionID": partitionID, }).Info("Declared new qrep executor for snapshot") - return &QRepQueryExecutor{ - pool: pool, - ctx: ctx, - snapshot: snapshot, - flowJobName: flowJobName, - partitionID: partitionID, + CustomTypeMap, err := utils.GetCustomDataTypes(ctx, pool) + if err != nil { + return nil, fmt.Errorf("failed to get custom data types: %w", err) } + return &QRepQueryExecutor{ + pool: pool, + ctx: ctx, + snapshot: snapshot, + flowJobName: flowJobName, + partitionID: partitionID, + customTypeMap: CustomTypeMap, + }, nil } func (qe *QRepQueryExecutor) SetTestEnv(testEnv bool) { @@ -89,11 +96,22 @@ func (qe *QRepQueryExecutor) executeQueryInTx(tx pgx.Tx, cursorName string, fetc } // FieldDescriptionsToSchema converts a slice of pgconn.FieldDescription to a QRecordSchema. -func fieldDescriptionsToSchema(fds []pgconn.FieldDescription) *model.QRecordSchema { +func (qe *QRepQueryExecutor) fieldDescriptionsToSchema(fds []pgconn.FieldDescription) *model.QRecordSchema { qfields := make([]*model.QField, len(fds)) for i, fd := range fds { cname := fd.Name ctype := postgresOIDToQValueKind(fd.DataTypeOID) + if ctype == qvalue.QValueKindInvalid { + var err error + ctype = qvalue.QValueKind(qe.customTypeMap[fd.DataTypeOID]) + if err != nil { + ctype = qvalue.QValueKindInvalid + typeName, ok := qe.customTypeMap[fd.DataTypeOID] + if ok { + ctype = customTypeToQKind(typeName) + } + } + } // there isn't a way to know if a column is nullable or not // TODO fix this. cnullable := true @@ -118,7 +136,7 @@ func (qe *QRepQueryExecutor) ProcessRows( }).Info("Processing rows") // Iterate over the rows for rows.Next() { - record, err := mapRowToQRecord(rows, fieldDescriptions) + record, err := mapRowToQRecord(rows, fieldDescriptions, qe.customTypeMap) if err != nil { return nil, fmt.Errorf("failed to map row to QRecord: %w", err) } @@ -133,7 +151,7 @@ func (qe *QRepQueryExecutor) ProcessRows( batch := &model.QRecordBatch{ NumRecords: uint32(len(records)), Records: records, - Schema: fieldDescriptionsToSchema(fieldDescriptions), + Schema: qe.fieldDescriptionsToSchema(fieldDescriptions), } log.WithFields(log.Fields{ @@ -155,7 +173,7 @@ func (qe *QRepQueryExecutor) processRowsStream( // Iterate over the rows for rows.Next() { - record, err := mapRowToQRecord(rows, fieldDescriptions) + record, err := mapRowToQRecord(rows, fieldDescriptions, qe.customTypeMap) if err != nil { stream.Records <- &model.QRecordOrError{ Err: fmt.Errorf("failed to map row to QRecord: %w", err), @@ -214,7 +232,7 @@ func (qe *QRepQueryExecutor) processFetchedRows( fieldDescriptions := rows.FieldDescriptions() if !stream.IsSchemaSet() { - schema := fieldDescriptionsToSchema(fieldDescriptions) + schema := qe.fieldDescriptionsToSchema(fieldDescriptions) _ = stream.SetSchema(schema) } @@ -395,7 +413,8 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStream( return totalRecordsFetched, nil } -func mapRowToQRecord(row pgx.Rows, fds []pgconn.FieldDescription) (*model.QRecord, error) { +func mapRowToQRecord(row pgx.Rows, fds []pgconn.FieldDescription, + customTypeMap map[uint32]string) (*model.QRecord, error) { // make vals an empty array of QValue of size len(fds) record := model.NewQRecord(len(fds)) @@ -405,11 +424,21 @@ func mapRowToQRecord(row pgx.Rows, fds []pgconn.FieldDescription) (*model.QRecor } for i, fd := range fds { - tmp, err := parseFieldFromPostgresOID(fd.DataTypeOID, values[i]) - if err != nil { - return nil, fmt.Errorf("failed to parse field: %w", err) + // Check if it's a custom type first + typeName, ok := customTypeMap[fd.DataTypeOID] + if !ok { + tmp, err := parseFieldFromPostgresOID(fd.DataTypeOID, values[i]) + if err != nil { + return nil, fmt.Errorf("failed to parse field: %w", err) + } + record.Set(i, *tmp) + } else { + customTypeVal := qvalue.QValue{ + Kind: customTypeToQKind(typeName), + Value: values[i], + } + record.Set(i, customTypeVal) } - record.Set(i, *tmp) } return record, nil diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go index b9c7dcc904..21f802f6c0 100644 --- a/flow/connectors/postgres/qvalue_convert.go +++ b/flow/connectors/postgres/qvalue_convert.go @@ -55,6 +55,8 @@ func postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind { return qvalue.QValueKindArrayInt32 case pgtype.Int8ArrayOID: return qvalue.QValueKindArrayInt64 + case pgtype.PointOID: + return qvalue.QValueKindPoint case pgtype.Float4ArrayOID: return qvalue.QValueKindArrayFloat32 case pgtype.Float8ArrayOID: @@ -77,8 +79,10 @@ func postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind { return qvalue.QValueKindString } else if recvOID == uint32(oid.T_tsquery) { // TSQUERY return qvalue.QValueKindString + } else if recvOID == uint32(oid.T_point) { // POINT + return qvalue.QValueKindPoint } - // log.Warnf("failed to get type name for oid: %v", recvOID) + return qvalue.QValueKindInvalid } else { log.Warnf("unsupported field type: %v - type name - %s; returning as string", recvOID, typeName.Name) @@ -337,6 +341,11 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( return nil, fmt.Errorf("failed to parse hstore: %w", err) } val = &qvalue.QValue{Kind: qvalue.QValueKindHStore, Value: hstoreVal} + case qvalue.QValueKindPoint: + xCoord := value.(pgtype.Point).P.X + yCoord := value.(pgtype.Point).P.Y + val = &qvalue.QValue{Kind: qvalue.QValueKindPoint, + Value: fmt.Sprintf("POINT(%f %f)", xCoord, yCoord)} default: // log.Warnf("unhandled QValueKind => %v, parsing as string", qvalueKind) textVal, ok := value.(string) @@ -380,3 +389,16 @@ func numericToRat(numVal *pgtype.Numeric) (*big.Rat, error) { // handle invalid numeric return nil, errors.New("invalid numeric") } + +func customTypeToQKind(typeName string) qvalue.QValueKind { + var qValueKind qvalue.QValueKind + switch typeName { + case "geometry": + qValueKind = qvalue.QValueKindGeometry + case "geography": + qValueKind = qvalue.QValueKindGeography + default: + qValueKind = qvalue.QValueKindString + } + return qValueKind +} diff --git a/flow/connectors/snowflake/qrep.go b/flow/connectors/snowflake/qrep.go index ce0cc48511..094dbcaff9 100644 --- a/flow/connectors/snowflake/qrep.go +++ b/flow/connectors/snowflake/qrep.go @@ -253,7 +253,7 @@ func (c *SnowflakeConnector) ConsolidateQRepPartitions(config *protos.QRepConfig case protos.QRepSyncMode_QREP_SYNC_MODE_MULTI_INSERT: return fmt.Errorf("multi-insert sync mode not supported for snowflake") case protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO: - allCols, err := c.getColsFromTable(destTable) + colInfo, err := c.getColsFromTable(destTable) if err != nil { log.WithFields(log.Fields{ "flowName": config.FlowJobName, @@ -261,6 +261,7 @@ func (c *SnowflakeConnector) ConsolidateQRepPartitions(config *protos.QRepConfig return fmt.Errorf("failed to get columns from table %s: %w", destTable, err) } + allCols := colInfo.Columns err = CopyStageToDestination(c, config, destTable, stageName, allCols) if err != nil { log.WithFields(log.Fields{ @@ -283,7 +284,7 @@ func (c *SnowflakeConnector) CleanupQRepFlow(config *protos.QRepConfig) error { return c.dropStage(config.StagingPath, config.FlowJobName) } -func (c *SnowflakeConnector) getColsFromTable(tableName string) ([]string, error) { +func (c *SnowflakeConnector) getColsFromTable(tableName string) (*model.ColumnInformation, error) { // parse the table name to get the schema and table name components, err := parseTableName(tableName) if err != nil { @@ -296,7 +297,7 @@ func (c *SnowflakeConnector) getColsFromTable(tableName string) ([]string, error //nolint:gosec queryString := fmt.Sprintf(` - SELECT column_name + SELECT column_name, data_type FROM information_schema.columns WHERE UPPER(table_name) = '%s' AND UPPER(table_schema) = '%s' `, components.tableIdentifier, components.schemaIdentifier) @@ -307,16 +308,24 @@ func (c *SnowflakeConnector) getColsFromTable(tableName string) ([]string, error } defer rows.Close() - var cols []string + columnMap := map[string]string{} for rows.Next() { - var col string - if err := rows.Scan(&col); err != nil { + var colName string + var colType string + if err := rows.Scan(&colName, &colType); err != nil { return nil, fmt.Errorf("failed to scan row: %w", err) } - cols = append(cols, col) + columnMap[colName] = colType + } + var cols []string + for k := range columnMap { + cols = append(cols, k) } - return cols, nil + return &model.ColumnInformation{ + ColumnMap: columnMap, + Columns: cols, + }, nil } // dropStage drops the stage for the given job. diff --git a/flow/connectors/snowflake/qrep_avro_sync.go b/flow/connectors/snowflake/qrep_avro_sync.go index 71c7d74d9a..4f5a9b8fc5 100644 --- a/flow/connectors/snowflake/qrep_avro_sync.go +++ b/flow/connectors/snowflake/qrep_avro_sync.go @@ -18,6 +18,11 @@ import ( "go.temporal.io/sdk/activity" ) +type CopyInfo struct { + transformationSQL string + columnsSQL string +} + type SnowflakeAvroSyncMethod struct { config *protos.QRepConfig connector *SnowflakeConnector @@ -73,11 +78,12 @@ func (s *SnowflakeAvroSyncMethod) SyncRecords( "flowName": flowJobName, }).Infof("Created stage %s", stage) - allCols, err := s.connector.getColsFromTable(s.config.DestinationTableIdentifier) + colInfo, err := s.connector.getColsFromTable(s.config.DestinationTableIdentifier) if err != nil { return 0, err } + allCols := colInfo.Columns err = s.putFileToStage(localFilePath, stage) if err != nil { return 0, err @@ -251,6 +257,46 @@ func (s *SnowflakeAvroSyncMethod) putFileToStage(localFilePath string, stage str return nil } +func (sc *SnowflakeConnector) GetCopyTransformation(dstTableName string) (*CopyInfo, error) { + colInfo, colsErr := sc.getColsFromTable(dstTableName) + if colsErr != nil { + return nil, fmt.Errorf("failed to get columns from destination table: %w", colsErr) + } + + var transformations []string + var columnOrder []string + for col, colType := range colInfo.ColumnMap { + if col == "_PEERDB_IS_DELETED" { + continue + } + colName := strings.ToLower(col) + // No need to quote raw table columns + if strings.Contains(dstTableName, "_PEERDB_RAW") { + columnOrder = append(columnOrder, colName) + } else { + columnOrder = append(columnOrder, fmt.Sprintf("\"%s\"", colName)) + } + + switch colType { + case "GEOGRAPHY": + transformations = append(transformations, + fmt.Sprintf("TO_GEOGRAPHY($1:\"%s\"::string) AS \"%s\"", colName, colName)) + case "GEOMETRY": + transformations = append(transformations, + fmt.Sprintf("TO_GEOMETRY($1:\"%s\"::string) AS \"%s\"", colName, colName)) + case "NUMBER": + transformations = append(transformations, + fmt.Sprintf("$1:\"%s\" AS \"%s\"", colName, colName)) + default: + transformations = append(transformations, + fmt.Sprintf("($1:\"%s\")::%s AS \"%s\"", colName, colType, colName)) + } + } + transformationSQL := strings.Join(transformations, ",") + columnsSQL := strings.Join(columnOrder, ",") + return &CopyInfo{transformationSQL, columnsSQL}, nil +} + func CopyStageToDestination( connector *SnowflakeConnector, config *protos.QRepConfig, @@ -263,7 +309,6 @@ func CopyStageToDestination( }).Infof("Copying stage to destination %s", dstTableName) copyOpts := []string{ "FILE_FORMAT = (TYPE = AVRO)", - "MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE'", "PURGE = TRUE", "ON_ERROR = 'CONTINUE'", } @@ -278,9 +323,13 @@ func CopyStageToDestination( } } + copyTransformation, err := connector.GetCopyTransformation(dstTableName) + if err != nil { + return fmt.Errorf("failed to get copy transformation: %w", err) + } switch appendMode { case true: - err := writeHandler.HandleAppendMode(config.FlowJobName) + err := writeHandler.HandleAppendMode(config.FlowJobName, copyTransformation) if err != nil { return fmt.Errorf("failed to handle append mode: %w", err) } @@ -288,7 +337,7 @@ func CopyStageToDestination( case false: upsertKeyCols := config.WriteMode.UpsertKeyColumns err := writeHandler.HandleUpsertMode(allCols, upsertKeyCols, config.WatermarkColumn, - config.FlowJobName) + config.FlowJobName, copyTransformation) if err != nil { return fmt.Errorf("failed to handle upsert mode: %w", err) } @@ -348,9 +397,12 @@ func NewSnowflakeAvroWriteHandler( } } -func (s *SnowflakeAvroWriteHandler) HandleAppendMode(flowJobName string) error { +func (s *SnowflakeAvroWriteHandler) HandleAppendMode( + flowJobName string, + copyInfo *CopyInfo) error { //nolint:gosec - copyCmd := fmt.Sprintf("COPY INTO %s FROM @%s %s", s.dstTableName, s.stage, strings.Join(s.copyOpts, ",")) + copyCmd := fmt.Sprintf("COPY INTO %s(%s) FROM (SELECT %s FROM @%s) %s", + s.dstTableName, copyInfo.columnsSQL, copyInfo.transformationSQL, s.stage, strings.Join(s.copyOpts, ",")) log.Infof("running copy command: %s", copyCmd) _, err := s.connector.database.Exec(copyCmd) if err != nil { @@ -424,6 +476,7 @@ func (s *SnowflakeAvroWriteHandler) HandleUpsertMode( upsertKeyCols []string, watermarkCol string, flowJobName string, + copyInfo *CopyInfo, ) error { runID, err := util.RandomUInt64() if err != nil { @@ -443,8 +496,8 @@ func (s *SnowflakeAvroWriteHandler) HandleUpsertMode( }).Infof("created temp table %s", tempTableName) //nolint:gosec - copyCmd := fmt.Sprintf("COPY INTO %s FROM @%s %s", - tempTableName, s.stage, strings.Join(s.copyOpts, ",")) + copyCmd := fmt.Sprintf("COPY INTO %s(%s) FROM (SELECT %s FROM @%s) %s", + tempTableName, copyInfo.columnsSQL, copyInfo.transformationSQL, s.stage, strings.Join(s.copyOpts, ",")) _, err = s.connector.database.Exec(copyCmd) if err != nil { return fmt.Errorf("failed to run COPY INTO command: %w", err) diff --git a/flow/connectors/snowflake/qvalue_convert.go b/flow/connectors/snowflake/qvalue_convert.go index 32b84dd07e..b88517856d 100644 --- a/flow/connectors/snowflake/qvalue_convert.go +++ b/flow/connectors/snowflake/qvalue_convert.go @@ -27,8 +27,11 @@ var qValueKindToSnowflakeTypeMap = map[qvalue.QValueKind]string{ qvalue.QValueKindTimeTZ: "STRING", qvalue.QValueKindInvalid: "STRING", qvalue.QValueKindHStore: "STRING", + qvalue.QValueKindGeography: "GEOGRAPHY", + qvalue.QValueKindGeometry: "GEOMETRY", + qvalue.QValueKindPoint: "GEOMETRY", - // array types will be mapped to STRING + // array types will be mapped to VARIANT qvalue.QValueKindArrayFloat32: "VARIANT", qvalue.QValueKindArrayFloat64: "VARIANT", qvalue.QValueKindArrayInt32: "VARIANT", @@ -60,6 +63,8 @@ var snowflakeTypeToQValueKindMap = map[string]qvalue.QValueKind{ "DECIMAL": qvalue.QValueKindNumeric, "NUMERIC": qvalue.QValueKindNumeric, "VARIANT": qvalue.QValueKindJSON, + "GEOMETRY": qvalue.QValueKindGeometry, + "GEOGRAPHY": qvalue.QValueKindGeography, } func qValueKindToSnowflakeType(colType qvalue.QValueKind) string { diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 7046746ecc..d691fef970 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -958,6 +958,9 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement( case qvalue.QValueKindBytes, qvalue.QValueKindBit: flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:\"%s\") "+ "AS %s,", toVariantColumnName, columnName, targetColumnName)) + case qvalue.QValueKindGeography, qvalue.QValueKindGeometry, qvalue.QValueKindPoint: + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS STRING) AS %s,", + toVariantColumnName, columnName, targetColumnName)) // TODO: https://github.com/PeerDB-io/peerdb/issues/189 - handle time types and interval types // case model.ColumnTypeTime: // flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+ diff --git a/flow/connectors/utils/postgres.go b/flow/connectors/utils/postgres.go index 883cf36791..2080d2f3dd 100644 --- a/flow/connectors/utils/postgres.go +++ b/flow/connectors/utils/postgres.go @@ -1,10 +1,12 @@ package utils import ( + "context" "fmt" "net/url" "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/jackc/pgx/v5/pgxpool" ) func GetPGConnectionString(pgConfig *protos.PostgresConfig) string { @@ -20,3 +22,28 @@ func GetPGConnectionString(pgConfig *protos.PostgresConfig) string { ) return connString } + +func GetCustomDataTypes(ctx context.Context, pool *pgxpool.Pool) (map[uint32]string, error) { + rows, err := pool.Query(ctx, ` + SELECT t.oid, t.typname as type + FROM pg_type t + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + WHERE (t.typrelid = 0 OR (SELECT c.relkind = 'c' FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid)) + AND NOT EXISTS(SELECT 1 FROM pg_catalog.pg_type el WHERE el.oid = t.typelem AND el.typarray = t.oid) + AND n.nspname NOT IN ('pg_catalog', 'information_schema'); + `) + if err != nil { + return nil, fmt.Errorf("failed to get custom types: %w", err) + } + + customTypeMap := map[uint32]string{} + for rows.Next() { + var typeID uint32 + var typeName string + if err := rows.Scan(&typeID, &typeName); err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + customTypeMap[typeID] = typeName + } + return customTypeMap, nil +} diff --git a/flow/e2e/bigquery/qrep_flow_bq_test.go b/flow/e2e/bigquery/qrep_flow_bq_test.go index fb0fc3b876..5e6374cc1a 100644 --- a/flow/e2e/bigquery/qrep_flow_bq_test.go +++ b/flow/e2e/bigquery/qrep_flow_bq_test.go @@ -3,8 +3,6 @@ package e2e_bigquery import ( "context" "fmt" - "sort" - "strings" connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" "github.com/PeerDB-io/peer-flow/e2e" @@ -29,32 +27,6 @@ func (s *PeerFlowE2ETestSuiteBQ) setupBQDestinationTable(dstTable string) { fmt.Printf("created table on bigquery: %s.%s. %v\n", s.bqHelper.Config.DatasetId, dstTable, err) } -func (s *PeerFlowE2ETestSuiteBQ) compareTableSchemasBQ(tableName string) { - // read rows from source table - pgQueryExecutor := connpostgres.NewQRepQueryExecutor(s.pool, context.Background(), "testflow", "testpart") - pgQueryExecutor.SetTestEnv(true) - - pgRows, err := pgQueryExecutor.ExecuteAndProcessQuery( - fmt.Sprintf("SELECT * FROM e2e_test_%s.%s ORDER BY id", bigquerySuffix, tableName), - ) - s.NoError(err) - sort.Slice(pgRows.Schema.Fields, func(i int, j int) bool { - return strings.Compare(pgRows.Schema.Fields[i].Name, pgRows.Schema.Fields[j].Name) == -1 - }) - - // read rows from destination table - qualifiedTableName := fmt.Sprintf("`%s.%s`", s.bqHelper.Config.DatasetId, tableName) - bqRows, err := s.bqHelper.ExecuteAndProcessQuery( - fmt.Sprintf("SELECT * FROM %s ORDER BY id", qualifiedTableName), - ) - s.NoError(err) - sort.Slice(bqRows.Schema.Fields, func(i int, j int) bool { - return strings.Compare(bqRows.Schema.Fields[i].Name, bqRows.Schema.Fields[j].Name) == -1 - }) - - s.True(pgRows.Schema.EqualNames(bqRows.Schema), "schemas from source and destination tables are not equal") -} - func (s *PeerFlowE2ETestSuiteBQ) compareTableContentsBQ(tableName string, colsString string) { // read rows from source table pgQueryExecutor := connpostgres.NewQRepQueryExecutor(s.pool, context.Background(), "testflow", "testpart") diff --git a/flow/e2e/s3/s3_helper.go b/flow/e2e/s3/s3_helper.go index 5d40d9a0e0..7ea629ad2b 100644 --- a/flow/e2e/s3/s3_helper.go +++ b/flow/e2e/s3/s3_helper.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "os" + "time" "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/e2e" @@ -15,14 +16,14 @@ import ( ) const ( - peerName string = "test_s3_peer" - prefixName string = "test-s3" + peerName string = "test_s3_peer" ) type S3TestHelper struct { client *s3.S3 s3Config *protos.S3Config bucketName string + prefix string } func NewS3TestHelper(switchToGCS bool) (*S3TestHelper, error) { @@ -51,10 +52,11 @@ func NewS3TestHelper(switchToGCS bool) (*S3TestHelper, error) { if err != nil { return nil, err } + prefix := fmt.Sprintf("peerdb_test/%d", time.Now().UnixNano()) return &S3TestHelper{ client, &protos.S3Config{ - Url: fmt.Sprintf("s3://%s/%s", bucketName, prefixName), + Url: fmt.Sprintf("s3://%s/%s", bucketName, prefix), AccessKeyId: &config.AccessKeyID, SecretAccessKey: &config.SecretAccessKey, Region: &config.Region, @@ -68,6 +70,7 @@ func NewS3TestHelper(switchToGCS bool) (*S3TestHelper, error) { }, }, bucketName, + prefix, }, nil } @@ -89,7 +92,7 @@ func (h *S3TestHelper) ListAllFiles( ) ([]*s3.Object, error) { Bucket := h.bucketName - Prefix := fmt.Sprintf("%s/%s/", prefixName, jobName) + Prefix := fmt.Sprintf("%s/%s/", h.prefix, jobName) files, err := h.client.ListObjects(&s3.ListObjectsInput{ Bucket: &Bucket, Prefix: &Prefix, @@ -105,7 +108,7 @@ func (h *S3TestHelper) ListAllFiles( // Delete all generated objects during the test func (h *S3TestHelper) CleanUp() error { Bucket := h.bucketName - Prefix := prefixName + Prefix := h.prefix files, err := h.client.ListObjects(&s3.ListObjectsInput{ Bucket: &Bucket, Prefix: &Prefix, diff --git a/flow/e2e/snowflake/peer_flow_sf_test.go b/flow/e2e/snowflake/peer_flow_sf_test.go index de8e8ba9c8..418bcfd63e 100644 --- a/flow/e2e/snowflake/peer_flow_sf_test.go +++ b/flow/e2e/snowflake/peer_flow_sf_test.go @@ -596,7 +596,8 @@ func (s *PeerFlowE2ETestSuiteSF) Test_Types_SF() { c14 INET,c15 INTEGER,c16 INTERVAL,c17 JSON,c18 JSONB,c21 MACADDR,c22 MONEY, c23 NUMERIC,c24 OID,c28 REAL,c29 SMALLINT,c30 SMALLSERIAL,c31 SERIAL,c32 TEXT, c33 TIMESTAMP,c34 TIMESTAMPTZ,c35 TIME, c36 TIMETZ,c37 TSQUERY,c38 TSVECTOR, - c39 TXID_SNAPSHOT,c40 UUID,c41 XML); + c39 TXID_SNAPSHOT,c40 UUID,c41 XML, c42 GEOMETRY(POINT), c43 GEOGRAPHY(POINT), + c44 GEOGRAPHY(POLYGON), c45 GEOGRAPHY(LINESTRING), c46 GEOMETRY(LINESTRING), c47 GEOMETRY(POLYGON)); CREATE OR REPLACE FUNCTION random_bytea(bytea_length integer) RETURNS bytea AS $body$ SELECT decode(string_agg(lpad(to_hex(width_bucket(random(), 0, 1, 256)-1),2,'0') ,''), 'hex') @@ -637,7 +638,10 @@ func (s *PeerFlowE2ETestSuiteSF) Test_Types_SF() { 1.2,1.23,4::oid,1.23,1,1,1,'test',now(),now(),now()::time,now()::timetz, 'fat & rat'::tsquery,'a fat cat sat on a mat and ate a fat rat'::tsvector, txid_current_snapshot(), - '66073c38-b8df-4bdb-bbca-1c97596b8940'::uuid,xmlcomment('hello'); + '66073c38-b8df-4bdb-bbca-1c97596b8940'::uuid,xmlcomment('hello'), + 'POINT(1 2)','POINT(40.7128 -74.0060)','POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))', + 'LINESTRING(-74.0060 40.7128, -73.9352 40.7306, -73.9123 40.7831)','LINESTRING(0 0, 1 1, 2 2)', + 'POLYGON((-74.0060 40.7128, -73.9352 40.7306, -73.9123 40.7831, -74.0060 40.7128))'; `, srcTableName)) s.NoError(err) fmt.Println("Executed an insert with all types") @@ -656,7 +660,7 @@ func (s *PeerFlowE2ETestSuiteSF) Test_Types_SF() { noNulls, err := s.sfHelper.CheckNull("test_types_sf", []string{"c41", "c1", "c2", "c3", "c4", "c6", "c39", "c40", "id", "c9", "c11", "c12", "c13", "c14", "c15", "c16", "c17", "c18", "c21", "c22", "c23", "c24", "c28", "c29", "c30", "c31", "c33", "c34", "c35", "c36", - "c37", "c38", "c7", "c8", "c32"}) + "c37", "c38", "c7", "c8", "c32", "c42", "c43", "c44", "c45", "c46"}) if err != nil { fmt.Println("error %w", err) } @@ -679,7 +683,8 @@ func (s *PeerFlowE2ETestSuiteSF) Test_Types_SF_Avro_CDC() { c14 INET,c15 INTEGER,c16 INTERVAL,c17 JSON,c18 JSONB,c21 MACADDR,c22 MONEY, c23 NUMERIC,c24 OID,c28 REAL,c29 SMALLINT,c30 SMALLSERIAL,c31 SERIAL,c32 TEXT, c33 TIMESTAMP,c34 TIMESTAMPTZ,c35 TIME, c36 TIMETZ,c37 TSQUERY,c38 TSVECTOR, - c39 TXID_SNAPSHOT,c40 UUID,c41 XML); + c39 TXID_SNAPSHOT,c40 UUID,c41 XML, c42 GEOMETRY(POINT), c43 GEOGRAPHY(POINT), + c44 GEOGRAPHY(POLYGON), c45 GEOGRAPHY(LINESTRING), c46 GEOMETRY(LINESTRING), c47 GEOMETRY(POLYGON)); CREATE OR REPLACE FUNCTION random_bytea(bytea_length integer) RETURNS bytea AS $body$ SELECT decode(string_agg(lpad(to_hex(width_bucket(random(), 0, 1, 256)-1),2,'0') ,''), 'hex') @@ -721,7 +726,10 @@ func (s *PeerFlowE2ETestSuiteSF) Test_Types_SF_Avro_CDC() { 1.2,1.23,4::oid,1.23,1,1,1,'test',now(),now(),now()::time,now()::timetz, 'fat & rat'::tsquery,'a fat cat sat on a mat and ate a fat rat'::tsvector, txid_current_snapshot(), - '66073c38-b8df-4bdb-bbca-1c97596b8940'::uuid,xmlcomment('hello'); + '66073c38-b8df-4bdb-bbca-1c97596b8940'::uuid,xmlcomment('hello'), + 'POINT(1 2)','POINT(40.7128 -74.0060)','POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))', + 'LINESTRING(-74.0060 40.7128, -73.9352 40.7306, -73.9123 40.7831)','LINESTRING(0 0, 1 1, 2 2)', + 'POLYGON((-74.0060 40.7128, -73.9352 40.7306, -73.9123 40.7831, -74.0060 40.7128))'; `, srcTableName)) s.NoError(err) fmt.Println("Executed an insert with all types") @@ -740,7 +748,7 @@ func (s *PeerFlowE2ETestSuiteSF) Test_Types_SF_Avro_CDC() { noNulls, err := s.sfHelper.CheckNull("test_types_sf_avro_cdc", []string{"c41", "c1", "c2", "c3", "c4", "c6", "c39", "c40", "id", "c9", "c11", "c12", "c13", "c14", "c15", "c16", "c17", "c18", "c21", "c22", "c23", "c24", "c28", "c29", "c30", "c31", "c33", "c34", "c35", "c36", - "c37", "c38", "c7", "c8", "c32"}) + "c37", "c38", "c7", "c8", "c32", "c42", "c43", "c44", "c45", "c46"}) if err != nil { fmt.Println("error %w", err) } diff --git a/flow/e2e/snowflake/qrep_flow_sf_test.go b/flow/e2e/snowflake/qrep_flow_sf_test.go index 01d2532e51..8da1df3c5f 100644 --- a/flow/e2e/snowflake/qrep_flow_sf_test.go +++ b/flow/e2e/snowflake/qrep_flow_sf_test.go @@ -3,8 +3,6 @@ package e2e_snowflake import ( "context" "fmt" - "sort" - "strings" connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" "github.com/PeerDB-io/peer-flow/e2e" @@ -32,33 +30,6 @@ func (s *PeerFlowE2ETestSuiteSF) setupSFDestinationTable(dstTable string) { fmt.Printf("created table on snowflake: %s.%s. %v\n", s.sfHelper.testSchemaName, dstTable, err) } -func (s *PeerFlowE2ETestSuiteSF) compareTableSchemasSF(tableName string) { - // read rows from source table - pgQueryExecutor := connpostgres.NewQRepQueryExecutor(s.pool, context.Background(), "testflow", "testpart") - pgQueryExecutor.SetTestEnv(true) - pgRows, err := pgQueryExecutor.ExecuteAndProcessQuery( - fmt.Sprintf("SELECT * FROM e2e_test_%s.%s LIMIT 0", snowflakeSuffix, tableName), - ) - require.NoError(s.T(), err) - sort.Slice(pgRows.Schema.Fields, func(i int, j int) bool { - return strings.Compare(pgRows.Schema.Fields[i].Name, pgRows.Schema.Fields[j].Name) == -1 - }) - - // read rows from destination table - qualifiedTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, tableName) - // excluding soft-delete column during schema conversion - sfSelQuery := fmt.Sprintf(`SELECT * EXCLUDE _PEERDB_IS_DELETED FROM %s LIMIT 0`, qualifiedTableName) - fmt.Printf("running query on snowflake: %s\n", sfSelQuery) - - sfRows, err := s.sfHelper.ExecuteAndProcessQuery(sfSelQuery) - require.NoError(s.T(), err) - sort.Slice(sfRows.Schema.Fields, func(i int, j int) bool { - return strings.Compare(sfRows.Schema.Fields[i].Name, sfRows.Schema.Fields[j].Name) == -1 - }) - - s.True(pgRows.Schema.EqualNames(sfRows.Schema), "schemas from source and destination tables are not equal") -} - func (s *PeerFlowE2ETestSuiteSF) compareTableContentsSF(tableName string, selector string, caseSensitive bool) { // read rows from source table pgQueryExecutor := connpostgres.NewQRepQueryExecutor(s.pool, context.Background(), "testflow", "testpart") diff --git a/flow/e2e/test_utils.go b/flow/e2e/test_utils.go index f3b3b00fa3..29b7a66ed0 100644 --- a/flow/e2e/test_utils.go +++ b/flow/e2e/test_utils.go @@ -152,7 +152,14 @@ func CreateSourceTableQRep(pool *pgxpool.Pool, suffix string, tableName string) "f7 jsonb", "f8 smallint", } - + if strings.Contains(tableName, "sf") { + tblFields = append(tblFields, "geometry_point geometry(point)", + "geography_point geography(point)", + "geometry_linestring geometry(linestring)", + "geography_linestring geography(linestring)", + "geometry_polygon geometry(polygon)", + "geography_polygon geography(polygon)") + } tblFieldStr := strings.Join(tblFields, ",") _, err := pool.Exec(context.Background(), fmt.Sprintf(` @@ -187,6 +194,15 @@ func PopulateSourceTable(pool *pgxpool.Pool, suffix string, tableName string, ro for i := 0; i < rowCount-1; i++ { id := uuid.New().String() ids = append(ids, id) + geoValues := "" + if strings.Contains(tableName, "sf") { + // geo types + geoValues = `,'POINT(1 2)','POINT(40.7128 -74.0060)', + 'LINESTRING(0 0, 1 1, 2 2)', + 'LINESTRING(-74.0060 40.7128, -73.9352 40.7306, -73.9123 40.7831)', + 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))', + 'POLYGON((-74.0060 40.7128, -73.9352 40.7306, -73.9123 40.7831, -74.0060 40.7128))'` + } row := fmt.Sprintf(` ( '%s', '%s', CURRENT_TIMESTAMP, 3.86487206688919, CURRENT_TIMESTAMP, @@ -198,13 +214,19 @@ func PopulateSourceTable(pool *pgxpool.Pool, suffix string, tableName string, ro CURRENT_TIMESTAMP, 1, ARRAY['text1', 'text2'], ARRAY[123, 456], ARRAY[789, 012], ARRAY['varchar1', 'varchar2'], '{"key": 8.5}', '[{"key1": "value1", "key2": "value2", "key3": "value3"}]', - '{"key": "value"}', 15 + '{"key": "value"}', 15 %s )`, id, uuid.New().String(), uuid.New().String(), - uuid.New().String(), uuid.New().String(), uuid.New().String(), uuid.New().String()) + uuid.New().String(), uuid.New().String(), uuid.New().String(), uuid.New().String(), geoValues) rows = append(rows, row) } + geoColumns := "" + if strings.Contains(tableName, "sf") { + geoColumns = ",geometry_point, geography_point," + + "geometry_linestring, geography_linestring," + + "geometry_polygon, geography_polygon" + } _, err := pool.Exec(context.Background(), fmt.Sprintf(` INSERT INTO e2e_test_%s.%s ( id, card_id, "from", price, created_at, @@ -213,9 +235,10 @@ func PopulateSourceTable(pool *pgxpool.Pool, suffix string, tableName string, ro deal_id, ethereum_transaction_id, ignore_price, card_eth_value, paid_eth_price, card_bought_notified, address, account_id, asset_id, status, transaction_id, settled_at, reference_id, - settle_at, settlement_delay_reason, f1, f2, f3, f4, f5, f6, f7, f8 + settle_at, settlement_delay_reason, f1, f2, f3, f4, f5, f6, f7, f8 + %s ) VALUES %s; - `, suffix, tableName, strings.Join(rows, ","))) + `, suffix, tableName, geoColumns, strings.Join(rows, ","))) if err != nil { return err } diff --git a/flow/model/column.go b/flow/model/column.go new file mode 100644 index 0000000000..5cbf25dc2e --- /dev/null +++ b/flow/model/column.go @@ -0,0 +1,8 @@ +package model + +type ColumnInformation struct { + // This is a mapping from column name to column type + // Example: "name" -> "VARCHAR" + ColumnMap map[string]string + Columns []string // List of column names +} diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go index 62fefe698d..9b8db89d78 100644 --- a/flow/model/qvalue/avro_converter.go +++ b/flow/model/qvalue/avro_converter.go @@ -36,6 +36,10 @@ func GetAvroSchemaFromQValueKind(kind QValueKind, nullable bool) (*QValueKindAvr return &QValueKindAvroSchema{ AvroLogicalSchema: "string", }, nil + case QValueKindGeometry, QValueKindGeography, QValueKindPoint: + return &QValueKindAvroSchema{ + AvroLogicalSchema: "string", + }, nil case QValueKindInt16, QValueKindInt32, QValueKindInt64: return &QValueKindAvroSchema{ AvroLogicalSchema: "long", @@ -202,6 +206,8 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { return c.processArrayString() case QValueKindUUID: return c.processUUID() + case QValueKindGeography, QValueKindGeometry, QValueKindPoint: + return c.processGeospatial() default: return nil, fmt.Errorf("[toavro] unsupported QValueKind: %s", c.Value.Kind) } @@ -330,6 +336,22 @@ func (c *QValueAvroConverter) processUUID() (interface{}, error) { return uuidString, nil } +func (c *QValueAvroConverter) processGeospatial() (interface{}, error) { + if c.Value.Value == nil { + return nil, nil + } + + geoString, ok := c.Value.Value.(string) + if !ok { + return nil, fmt.Errorf("[conversion] invalid geospatial value %v", c.Value.Value) + } + + if c.Nullable { + return goavro.Union("string", geoString), nil + } + return geoString, nil +} + func (c *QValueAvroConverter) processArrayInt32() (interface{}, error) { if c.Value.Value == nil && c.Nullable { return nil, nil diff --git a/flow/model/qvalue/kind.go b/flow/model/qvalue/kind.go index 1def708611..1e8e5f5099 100644 --- a/flow/model/qvalue/kind.go +++ b/flow/model/qvalue/kind.go @@ -23,6 +23,9 @@ const ( QValueKindJSON QValueKind = "json" QValueKindBit QValueKind = "bit" QValueKindHStore QValueKind = "hstore" + QValueKindGeography QValueKind = "geography" + QValueKindGeometry QValueKind = "geometry" + QValueKindPoint QValueKind = "point" // array types QValueKindArrayFloat32 QValueKind = "array_float32" diff --git a/flow/workflows/cdc_flow.go b/flow/workflows/cdc_flow.go index fb47ff8ffd..8029a10408 100644 --- a/flow/workflows/cdc_flow.go +++ b/flow/workflows/cdc_flow.go @@ -336,7 +336,6 @@ func CDCFlowWorkflowWithConfig( cfg.TableNameSchemaMapping[modifiedDstTables[i]] = getModifiedSchemaRes.TableNameSchemaMapping[modifiedSrcTables[i]] } - } } From 90d467f5bb52d846380e30d294a9e3b87d4e3336 Mon Sep 17 00:00:00 2001 From: Kaushik Iska Date: Mon, 16 Oct 2023 15:49:15 -0400 Subject: [PATCH 2/3] Make qrep status more useful (#522) --- flow/activities/flowable.go | 12 +- .../connectors/utils/monitoring/monitoring.go | 43 +++++- .../V10__mirror_drop_bad_constraints.sql | 19 +++ .../V11__qrep_runs_start_time_nullable.sql | 5 + ui/app/mirrors/edit/[mirrorId]/cdc.tsx | 12 +- ui/app/mirrors/edit/[mirrorId]/qrep.tsx | 0 .../mirrors/status/qrep/[mirrorId]/page.tsx | 43 ++++++ .../qrep/[mirrorId]/qrepStatusTable.tsx | 146 ++++++++++++++++++ ui/prisma/schema.prisma | 9 +- 9 files changed, 276 insertions(+), 13 deletions(-) create mode 100644 nexus/catalog/migrations/V10__mirror_drop_bad_constraints.sql create mode 100644 nexus/catalog/migrations/V11__qrep_runs_start_time_nullable.sql delete mode 100644 ui/app/mirrors/edit/[mirrorId]/qrep.tsx create mode 100644 ui/app/mirrors/status/qrep/[mirrorId]/page.tsx create mode 100644 ui/app/mirrors/status/qrep/[mirrorId]/qrepStatusTable.tsx diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index 4ecb3afcd8..f66efa25d1 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -418,7 +418,6 @@ func (a *FlowableActivity) GetQRepPartitions(ctx context.Context, shutdown <- true }() - startTime := time.Now() partitions, err := srcConn.GetQRepPartitions(config, last) if err != nil { return nil, fmt.Errorf("failed to get partitions from source: %w", err) @@ -428,7 +427,6 @@ func (a *FlowableActivity) GetQRepPartitions(ctx context.Context, ctx, config, runUUID, - startTime, partitions, ) if err != nil { @@ -447,6 +445,11 @@ func (a *FlowableActivity) ReplicateQRepPartitions(ctx context.Context, partitions *protos.QRepPartitionBatch, runUUID string, ) error { + err := a.CatalogMirrorMonitor.UpdateStartTimeForQRepRun(ctx, runUUID) + if err != nil { + return fmt.Errorf("failed to update start time for qrep run: %w", err) + } + numPartitions := len(partitions.Partitions) log.Infof("replicating partitions for job - %s - batch %d - size: %d\n", config.FlowJobName, partitions.BatchId, numPartitions) @@ -469,6 +472,11 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context, partition *protos.QRepPartition, runUUID string, ) error { + err := a.CatalogMirrorMonitor.UpdateStartTimeForPartition(ctx, runUUID, partition) + if err != nil { + return fmt.Errorf("failed to update start time for partition: %w", err) + } + ctx = context.WithValue(ctx, shared.EnableMetricsKey, a.EnableMetrics) srcConn, err := connectors.GetQRepPullConnector(ctx, config.SourcePeer) if err != nil { diff --git a/flow/connectors/utils/monitoring/monitoring.go b/flow/connectors/utils/monitoring/monitoring.go index 367ea81a85..8c5e9a22e7 100644 --- a/flow/connectors/utils/monitoring/monitoring.go +++ b/flow/connectors/utils/monitoring/monitoring.go @@ -157,7 +157,6 @@ func (c *CatalogMirrorMonitor) InitializeQRepRun( ctx context.Context, config *protos.QRepConfig, runUUID string, - startTime time.Time, partitions []*protos.QRepPartition, ) error { if c == nil || c.catalogConn == nil { @@ -166,8 +165,8 @@ func (c *CatalogMirrorMonitor) InitializeQRepRun( flowJobName := config.GetFlowJobName() _, err := c.catalogConn.Exec(ctx, - "INSERT INTO peerdb_stats.qrep_runs(flow_name,run_uuid,start_time) VALUES($1,$2,$3) ON CONFLICT DO NOTHING", - flowJobName, runUUID, startTime) + "INSERT INTO peerdb_stats.qrep_runs(flow_name,run_uuid) VALUES($1,$2) ON CONFLICT DO NOTHING", + flowJobName, runUUID) if err != nil { return fmt.Errorf("error while inserting qrep run in qrep_runs: %w", err) } @@ -193,6 +192,21 @@ func (c *CatalogMirrorMonitor) InitializeQRepRun( return nil } +func (c *CatalogMirrorMonitor) UpdateStartTimeForQRepRun(ctx context.Context, runUUID string) error { + if c == nil || c.catalogConn == nil { + return nil + } + + _, err := c.catalogConn.Exec(ctx, + "UPDATE peerdb_stats.qrep_runs SET start_time=$1 WHERE run_uuid=$2", + time.Now(), runUUID) + if err != nil { + return fmt.Errorf("error while updating num_rows_to_sync for run_uuid %s in qrep_runs: %w", runUUID, err) + } + + return nil +} + func (c *CatalogMirrorMonitor) UpdateEndTimeForQRepRun(ctx context.Context, runUUID string) error { if c == nil || c.catalogConn == nil { return nil @@ -253,10 +267,10 @@ func (c *CatalogMirrorMonitor) addPartitionToQRepRun(ctx context.Context, flowJo _, err := c.catalogConn.Exec(ctx, `INSERT INTO peerdb_stats.qrep_partitions - (flow_name,run_uuid,partition_uuid,partition_start,partition_end,start_time,restart_count) - VALUES($1,$2,$3,$4,$5,$6,$7) ON CONFLICT(run_uuid,partition_uuid) DO UPDATE SET + (flow_name,run_uuid,partition_uuid,partition_start,partition_end,restart_count) + VALUES($1,$2,$3,$4,$5,$6) ON CONFLICT(run_uuid,partition_uuid) DO UPDATE SET restart_count=qrep_partitions.restart_count+1`, - flowJobName, runUUID, partition.PartitionId, rangeStart, rangeEnd, time.Now(), 0) + flowJobName, runUUID, partition.PartitionId, rangeStart, rangeEnd, 0) if err != nil { return fmt.Errorf("error while inserting qrep partition in qrep_partitions: %w", err) } @@ -264,6 +278,23 @@ func (c *CatalogMirrorMonitor) addPartitionToQRepRun(ctx context.Context, flowJo return nil } +func (c *CatalogMirrorMonitor) UpdateStartTimeForPartition( + ctx context.Context, + runUUID string, + partition *protos.QRepPartition, +) error { + if c == nil || c.catalogConn == nil { + return nil + } + + _, err := c.catalogConn.Exec(ctx, `UPDATE peerdb_stats.qrep_partitions SET start_time=$1 + WHERE run_uuid=$2 AND partition_uuid=$3`, time.Now(), runUUID, partition.PartitionId) + if err != nil { + return fmt.Errorf("error while updating qrep partition in qrep_partitions: %w", err) + } + return nil +} + func (c *CatalogMirrorMonitor) UpdatePullEndTimeAndRowsForPartition(ctx context.Context, runUUID string, partition *protos.QRepPartition, rowsInPartition int64) error { if c == nil || c.catalogConn == nil { diff --git a/nexus/catalog/migrations/V10__mirror_drop_bad_constraints.sql b/nexus/catalog/migrations/V10__mirror_drop_bad_constraints.sql new file mode 100644 index 0000000000..7bf7de1643 --- /dev/null +++ b/nexus/catalog/migrations/V10__mirror_drop_bad_constraints.sql @@ -0,0 +1,19 @@ +-- Drop the foreign key constraint from qrep_partitions to qrep_runs +ALTER TABLE peerdb_stats.qrep_partitions +DROP CONSTRAINT fk_qrep_partitions_run_uuid; + +-- Drop the unique constraint for flow_name from qrep_runs +ALTER TABLE peerdb_stats.qrep_runs +DROP CONSTRAINT uq_qrep_runs_flow_name; + +-- Add unique constraint to qrep_runs for (flow_name, run_uuid) +ALTER TABLE peerdb_stats.qrep_runs +ADD CONSTRAINT uq_qrep_runs_flow_run +UNIQUE (flow_name, run_uuid); + +-- Add foreign key from qrep_partitions to qrep_runs +ALTER TABLE peerdb_stats.qrep_partitions +ADD CONSTRAINT fk_qrep_partitions_run +FOREIGN KEY (flow_name, run_uuid) +REFERENCES peerdb_stats.qrep_runs(flow_name, run_uuid) +ON DELETE CASCADE; diff --git a/nexus/catalog/migrations/V11__qrep_runs_start_time_nullable.sql b/nexus/catalog/migrations/V11__qrep_runs_start_time_nullable.sql new file mode 100644 index 0000000000..3e78249110 --- /dev/null +++ b/nexus/catalog/migrations/V11__qrep_runs_start_time_nullable.sql @@ -0,0 +1,5 @@ +ALTER TABLE peerdb_stats.qrep_runs +ALTER COLUMN start_time DROP NOT NULL; + +ALTER TABLE peerdb_stats.qrep_partitions +ALTER COLUMN start_time DROP NOT NULL; diff --git a/ui/app/mirrors/edit/[mirrorId]/cdc.tsx b/ui/app/mirrors/edit/[mirrorId]/cdc.tsx index 00bbd2a5c6..a83e31cc08 100644 --- a/ui/app/mirrors/edit/[mirrorId]/cdc.tsx +++ b/ui/app/mirrors/edit/[mirrorId]/cdc.tsx @@ -14,6 +14,7 @@ import { ProgressBar } from '@/lib/ProgressBar'; import { SearchField } from '@/lib/SearchField'; import { Table, TableCell, TableRow } from '@/lib/Table'; import moment, { Duration, Moment } from 'moment'; +import Link from 'next/link'; const Badges = [ @@ -35,6 +36,7 @@ const Badges = [ ]; class TableCloneSummary { + flowJobName: string; tableName: string; totalNumPartitions: number; totalNumRows: number; @@ -44,6 +46,7 @@ class TableCloneSummary { cloneStartTime: Moment | null; constructor(clone: QRepMirrorStatus) { + this.flowJobName = clone.config?.flowJobName || ''; this.tableName = clone.config?.watermarkTable || ''; this.totalNumPartitions = 0; this.totalNumRows = 0; @@ -151,7 +154,14 @@ const SnapshotStatusTable = ({ status }: SnapshotStatusProps) => ( - + diff --git a/ui/app/mirrors/edit/[mirrorId]/qrep.tsx b/ui/app/mirrors/edit/[mirrorId]/qrep.tsx deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/ui/app/mirrors/status/qrep/[mirrorId]/page.tsx b/ui/app/mirrors/status/qrep/[mirrorId]/page.tsx new file mode 100644 index 0000000000..131d788ee6 --- /dev/null +++ b/ui/app/mirrors/status/qrep/[mirrorId]/page.tsx @@ -0,0 +1,43 @@ +import prisma from '@/app/utils/prisma'; +import { Header } from '@/lib/Header'; +import { LayoutMain } from '@/lib/Layout'; +import QRepStatusTable, { QRepPartitionStatus } from './qrepStatusTable'; + +type QRepMirrorStatusProps = { + params: { mirrorId: string }; +}; + +export default async function QRepMirrorStatus({ + params: { mirrorId }, +}: QRepMirrorStatusProps) { + const runs = await prisma.qrep_partitions.findMany({ + where: { + flow_name: mirrorId, + start_time: { + not: null, + }, + }, + orderBy: { + start_time: 'desc', + }, + }); + + const partitions = runs.map((run) => { + let ret: QRepPartitionStatus = { + partitionId: run.partition_uuid, + runUuid: run.run_uuid, + startTime: run.start_time, + endTime: run.end_time, + numRows: run.rows_in_partition, + status: '', + }; + return ret; + }); + + return ( + +
{mirrorId}
+ +
+ ); +} diff --git a/ui/app/mirrors/status/qrep/[mirrorId]/qrepStatusTable.tsx b/ui/app/mirrors/status/qrep/[mirrorId]/qrepStatusTable.tsx new file mode 100644 index 0000000000..55ee7d5819 --- /dev/null +++ b/ui/app/mirrors/status/qrep/[mirrorId]/qrepStatusTable.tsx @@ -0,0 +1,146 @@ +'use client'; + +import { Button } from '@/lib/Button'; +import { Checkbox } from '@/lib/Checkbox'; +import { Icon } from '@/lib/Icon'; +import { Label } from '@/lib/Label'; +import { ProgressCircle } from '@/lib/ProgressCircle'; +import { SearchField } from '@/lib/SearchField'; +import { Table, TableCell, TableRow } from '@/lib/Table'; +import moment from 'moment'; +import { useState } from 'react'; + +export type QRepPartitionStatus = { + partitionId: string; + runUuid: string; + status: string; + startTime: Date | null; + endTime: Date | null; + numRows: number | null; +}; + +function TimeOrProgressBar({ time }: { time: Date | null }) { + if (time === null) { + return ; + } else { + return ; + } +} + +function RowPerPartition({ + partitionId, + runUuid, + status, + startTime, + endTime, + numRows, +}: QRepPartitionStatus) { + return ( + + + + + + + + + + + + + + + + + + + + + ); +} + +type QRepStatusTableProps = { + flowJobName: string; + partitions: QRepPartitionStatus[]; +}; + +export default function QRepStatusTable({ + flowJobName, + partitions, +}: QRepStatusTableProps) { + const ROWS_PER_PAGE = 10; + const [currentPage, setCurrentPage] = useState(1); + const totalPages = Math.ceil(partitions.length / ROWS_PER_PAGE); + + const visiblePartitions = partitions.slice( + (currentPage - 1) * ROWS_PER_PAGE, + currentPage * ROWS_PER_PAGE + ); + + const handleNext = () => { + if (currentPage < totalPages) setCurrentPage(currentPage + 1); + }; + + const handlePrevious = () => { + if (currentPage > 1) setCurrentPage(currentPage - 1); + }; + + return ( + Progress} + toolbar={{ + left: ( + <> + + + + + +
+ +
+ + ), + right: , + }} + header={ + + + + + Partition UUID + Run UUID + Start Time + End Time + Num Rows Synced + + } + > + {visiblePartitions.map((partition, index) => ( + + ))} +
+ ); +} diff --git a/ui/prisma/schema.prisma b/ui/prisma/schema.prisma index 0e55d6eaa4..81007f1902 100644 --- a/ui/prisma/schema.prisma +++ b/ui/prisma/schema.prisma @@ -113,13 +113,13 @@ model qrep_partitions { partition_start String partition_end String rows_in_partition Int? - start_time DateTime @db.Timestamp(6) + start_time DateTime? @db.Timestamp(6) pull_end_time DateTime? @db.Timestamp(6) end_time DateTime? @db.Timestamp(6) restart_count Int metadata Json? id Int @id @default(autoincrement()) - qrep_runs qrep_runs @relation(fields: [flow_name], references: [flow_name], onDelete: Cascade, onUpdate: NoAction, map: "fk_qrep_partitions_run_uuid") + qrep_runs qrep_runs @relation(fields: [flow_name, run_uuid], references: [flow_name, run_uuid], onDelete: Cascade, onUpdate: NoAction, map: "fk_qrep_partitions_run") @@unique([run_uuid, partition_uuid]) @@index([flow_name, run_uuid], map: "idx_qrep_partitions_flow_name_run_uuid") @@ -129,15 +129,16 @@ model qrep_partitions { } model qrep_runs { - flow_name String @unique(map: "uq_qrep_runs_flow_name") + flow_name String run_uuid String - start_time DateTime @db.Timestamp(6) + start_time DateTime? @db.Timestamp(6) end_time DateTime? @db.Timestamp(6) metadata Json? config_proto Bytes? id Int @id @default(autoincrement()) qrep_partitions qrep_partitions[] + @@unique([flow_name, run_uuid], map: "uq_qrep_runs_flow_run") @@index([flow_name], map: "idx_qrep_runs_flow_name", type: Hash) @@index([run_uuid], map: "idx_qrep_runs_run_uuid", type: Hash) @@index([start_time], map: "idx_qrep_runs_start_time") From ae666e0e0946718d4bd1d074c2dae53b7f8571d4 Mon Sep 17 00:00:00 2001 From: Kaushik Iska Date: Mon, 16 Oct 2023 20:40:05 -0400 Subject: [PATCH 3/3] Add the ability to push to eventhubs in an asynchronous way (#523) `PEERDB_BETA_EVENTHUB_PUSH_ASYNC=true` enables this --- flow/connectors/eventhub/eventhub.go | 140 ++++++++++++++---------- flow/connectors/eventhub/hub_batches.go | 10 +- flow/connectors/eventhub/hubmanager.go | 22 +--- flow/connectors/utils/env.go | 30 +++++ 4 files changed, 119 insertions(+), 83 deletions(-) create mode 100644 flow/connectors/utils/env.go diff --git a/flow/connectors/eventhub/eventhub.go b/flow/connectors/eventhub/eventhub.go index 51dba78b7d..953f44f197 100644 --- a/flow/connectors/eventhub/eventhub.go +++ b/flow/connectors/eventhub/eventhub.go @@ -17,7 +17,6 @@ import ( "github.com/PeerDB-io/peer-flow/model" cmap "github.com/orcaman/concurrent-map/v2" log "github.com/sirupsen/logrus" - "go.temporal.io/sdk/activity" ) type EventHubConnector struct { @@ -61,15 +60,8 @@ func NewEventHubConnector( func (c *EventHubConnector) Close() error { var allErrors error - // close all the eventhub connections. - err := c.hubManager.Close() - if err != nil { - log.Errorf("failed to close eventhub connections: %v", err) - allErrors = errors.Join(allErrors, err) - } - // close the postgres metadata store. - err = c.pgMetadata.Close() + err := c.pgMetadata.Close() if err != nil { log.Errorf("failed to close postgres metadata store: %v", err) allErrors = errors.Join(allErrors, err) @@ -129,46 +121,32 @@ func (c *EventHubConnector) updateLastOffset(jobName string, offset int64) error return nil } -func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) { - shutdown := utils.HeartbeatRoutine(c.ctx, 10*time.Second, func() string { - return fmt.Sprintf("syncing records to eventhub with"+ - " push parallelism %d and push batch size %d", - req.PushParallelism, req.PushBatchSize) - }) - defer func() { - shutdown <- true - }() - tableNameRowsMapping := cmap.New[uint32]() - batch := req.Records - eventsPerHeartBeat := 1000 - eventsPerBatch := int(req.PushBatchSize) - if eventsPerBatch <= 0 { - eventsPerBatch = 10000 - } - maxParallelism := req.PushParallelism - if maxParallelism <= 0 { - maxParallelism = 10 - } +func (c *EventHubConnector) processBatch( + flowJobName string, + batch *model.RecordBatch, + eventsPerBatch int, + maxParallelism int64, +) error { + ctx := context.Background() + tableNameRowsMapping := cmap.New[uint32]() batchPerTopic := NewHubBatches(c.hubManager) toJSONOpts := model.NewToJSONOptions(c.config.UnnestColumns) - startTime := time.Now() for i, record := range batch.Records { json, err := record.GetItems().ToJSONWithOpts(toJSONOpts) if err != nil { log.WithFields(log.Fields{ - "flowName": req.FlowJobName, + "flowName": flowJobName, }).Infof("failed to convert record to json: %v", err) - return nil, err + return err } flushBatch := func() error { - err := c.sendEventBatch(batchPerTopic, maxParallelism, - req.FlowJobName, tableNameRowsMapping) + err := c.sendEventBatch(ctx, batchPerTopic, maxParallelism, flowJobName, tableNameRowsMapping) if err != nil { log.WithFields(log.Fields{ - "flowName": req.FlowJobName, + "flowName": flowJobName, }).Infof("failed to send event batch: %v", err) return err } @@ -179,45 +157,84 @@ func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S topicName, err := NewScopedEventhub(record.GetTableName()) if err != nil { log.WithFields(log.Fields{ - "flowName": req.FlowJobName, + "flowName": flowJobName, }).Infof("failed to get topic name: %v", err) - return nil, err + return err } - err = batchPerTopic.AddEvent(topicName, json) + err = batchPerTopic.AddEvent(ctx, topicName, json) if err != nil { log.WithFields(log.Fields{ - "flowName": req.FlowJobName, + "flowName": flowJobName, }).Infof("failed to add event to batch: %v", err) - return nil, err - } - - if i%eventsPerHeartBeat == 0 { - activity.RecordHeartbeat(c.ctx, fmt.Sprintf("sent %d records to hub: %s", i, topicName.ToString())) + return err } if (i+1)%eventsPerBatch == 0 { err := flushBatch() if err != nil { - return nil, err + return err } } } - // send the remaining events. if batchPerTopic.Len() > 0 { - err := c.sendEventBatch(batchPerTopic, maxParallelism, - req.FlowJobName, tableNameRowsMapping) + err := c.sendEventBatch(ctx, batchPerTopic, maxParallelism, flowJobName, tableNameRowsMapping) if err != nil { - return nil, err + return err } } + rowsSynced := len(batch.Records) log.WithFields(log.Fields{ - "flowName": req.FlowJobName, + "flowName": flowJobName, }).Infof("[total] successfully sent %d records to event hub", rowsSynced) + return nil +} + +func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) { + shutdown := utils.HeartbeatRoutine(c.ctx, 10*time.Second, func() string { + return fmt.Sprintf("syncing records to eventhub with"+ + " push parallelism %d and push batch size %d", + req.PushParallelism, req.PushBatchSize) + }) + defer func() { + shutdown <- true + }() - err := c.updateLastOffset(req.FlowJobName, batch.LastCheckPointID) + eventsPerBatch := int(req.PushBatchSize) + if eventsPerBatch <= 0 { + eventsPerBatch = 10000 + } + maxParallelism := req.PushParallelism + if maxParallelism <= 0 { + maxParallelism = 10 + } + + var err error + startTime := time.Now() + + batch := req.Records + + // if env var PEERDB_BETA_EVENTHUB_PUSH_ASYNC=true + // we kick off processBatch in a goroutine and return immediately. + // otherwise, we block until processBatch is done. + if utils.GetEnvBool("PEERDB_BETA_EVENTHUB_PUSH_ASYNC", false) { + go func() { + err = c.processBatch(req.FlowJobName, batch, eventsPerBatch, maxParallelism) + if err != nil { + log.Errorf("[async] failed to process batch: %v", err) + } + }() + } else { + err = c.processBatch(req.FlowJobName, batch, eventsPerBatch, maxParallelism) + if err != nil { + log.Errorf("failed to process batch: %v", err) + return nil, err + } + } + + err = c.updateLastOffset(req.FlowJobName, batch.LastCheckPointID) if err != nil { log.Errorf("failed to update last offset: %v", err) return nil, err @@ -228,18 +245,19 @@ func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S return nil, err } - metrics.LogSyncMetrics(c.ctx, req.FlowJobName, int64(rowsSynced), time.Since(startTime)) - metrics.LogNormalizeMetrics(c.ctx, req.FlowJobName, int64(rowsSynced), - time.Since(startTime), int64(rowsSynced)) + rowsSynced := int64(len(batch.Records)) + metrics.LogSyncMetrics(c.ctx, req.FlowJobName, rowsSynced, time.Since(startTime)) + metrics.LogNormalizeMetrics(c.ctx, req.FlowJobName, rowsSynced, time.Since(startTime), rowsSynced) return &model.SyncResponse{ FirstSyncedCheckPointID: batch.FirstCheckPointID, LastSyncedCheckPointID: batch.LastCheckPointID, - NumRecordsSynced: int64(len(batch.Records)), - TableNameRowsMapping: tableNameRowsMapping.Items(), + NumRecordsSynced: rowsSynced, + TableNameRowsMapping: make(map[string]uint32), }, nil } func (c *EventHubConnector) sendEventBatch( + ctx context.Context, events *HubBatches, maxParallelism int64, flowName string, @@ -268,7 +286,7 @@ func (c *EventHubConnector) sendEventBatch( }() numEvents := eventBatch.NumEvents() - err := c.sendBatch(tblName, eventBatch) + err := c.sendBatch(ctx, tblName, eventBatch) if err != nil { once.Do(func() { firstErr = err }) return @@ -298,8 +316,12 @@ func (c *EventHubConnector) sendEventBatch( return nil } -func (c *EventHubConnector) sendBatch(tblName ScopedEventhub, events *azeventhubs.EventDataBatch) error { - subCtx, cancel := context.WithTimeout(c.ctx, 5*time.Minute) +func (c *EventHubConnector) sendBatch( + ctx context.Context, + tblName ScopedEventhub, + events *azeventhubs.EventDataBatch, +) error { + subCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) defer cancel() hub, err := c.hubManager.GetOrCreateHubClient(tblName) diff --git a/flow/connectors/eventhub/hub_batches.go b/flow/connectors/eventhub/hub_batches.go index 97d1e568a4..652e10d45f 100644 --- a/flow/connectors/eventhub/hub_batches.go +++ b/flow/connectors/eventhub/hub_batches.go @@ -1,6 +1,7 @@ package conneventhub import ( + "context" "fmt" "strings" @@ -20,14 +21,14 @@ func NewHubBatches(manager *EventHubManager) *HubBatches { } } -func (h *HubBatches) AddEvent(name ScopedEventhub, event string) error { +func (h *HubBatches) AddEvent(ctx context.Context, name ScopedEventhub, event string) error { batches, ok := h.batches[name] if !ok { batches = []*azeventhubs.EventDataBatch{} } if len(batches) == 0 { - newBatch, err := h.manager.CreateEventDataBatch(name) + newBatch, err := h.manager.CreateEventDataBatch(ctx, name) if err != nil { return err } @@ -36,7 +37,7 @@ func (h *HubBatches) AddEvent(name ScopedEventhub, event string) error { if err := tryAddEventToBatch(event, batches[len(batches)-1]); err != nil { if strings.Contains(err.Error(), "too large for the batch") { - overflowBatch, err := h.handleBatchOverflow(name, event) + overflowBatch, err := h.handleBatchOverflow(ctx, name, event) if err != nil { return fmt.Errorf("failed to handle batch overflow: %v", err) } @@ -51,10 +52,11 @@ func (h *HubBatches) AddEvent(name ScopedEventhub, event string) error { } func (h *HubBatches) handleBatchOverflow( + ctx context.Context, name ScopedEventhub, event string, ) (*azeventhubs.EventDataBatch, error) { - newBatch, err := h.manager.CreateEventDataBatch(name) + newBatch, err := h.manager.CreateEventDataBatch(ctx, name) if err != nil { return nil, err } diff --git a/flow/connectors/eventhub/hubmanager.go b/flow/connectors/eventhub/hubmanager.go index be2681110d..1241c0ec7a 100644 --- a/flow/connectors/eventhub/hubmanager.go +++ b/flow/connectors/eventhub/hubmanager.go @@ -16,7 +16,6 @@ import ( ) type EventHubManager struct { - ctx context.Context creds *azidentity.DefaultAzureCredential // eventhub peer name -> config peerConfig cmap.ConcurrentMap[string, *protos.EventHubConfig] @@ -36,7 +35,6 @@ func NewEventHubManager( } return &EventHubManager{ - ctx: ctx, creds: creds, peerConfig: peerConfig, } @@ -69,30 +67,14 @@ func (m *EventHubManager) GetOrCreateHubClient(name ScopedEventhub) (*azeventhub return hub.(*azeventhubs.ProducerClient), nil } -func (m *EventHubManager) Close() error { - var globalErr error - m.hubs.Range(func(key, value interface{}) bool { - hub := value.(*azeventhubs.ProducerClient) - err := hub.Close(m.ctx) - if err != nil { - log.Errorf("failed to close eventhub client: %v", err) - globalErr = fmt.Errorf("failed to close eventhub client: %v", err) - return false - } - return true - }) - - return globalErr -} - -func (m *EventHubManager) CreateEventDataBatch(name ScopedEventhub) (*azeventhubs.EventDataBatch, error) { +func (m *EventHubManager) CreateEventDataBatch(ctx context.Context, name ScopedEventhub) (*azeventhubs.EventDataBatch, error) { hub, err := m.GetOrCreateHubClient(name) if err != nil { return nil, err } opts := &azeventhubs.EventDataBatchOptions{} - batch, err := hub.NewEventDataBatch(m.ctx, opts) + batch, err := hub.NewEventDataBatch(ctx, opts) if err != nil { return nil, fmt.Errorf("failed to create event data batch: %v", err) } diff --git a/flow/connectors/utils/env.go b/flow/connectors/utils/env.go new file mode 100644 index 0000000000..d3c1acfff5 --- /dev/null +++ b/flow/connectors/utils/env.go @@ -0,0 +1,30 @@ +package utils + +import ( + "os" + "strconv" +) + +// GetEnv returns the value of the environment variable with the given name +// and a boolean indicating whether the environment variable exists. +func GetEnv(name string) (string, bool) { + val, exists := os.LookupEnv(name) + return val, exists +} + +// GetEnvBool returns the value of the environment variable with the given name +// or defaultValue if the environment variable is not set or is not a valid +// boolean value. +func GetEnvBool(name string, defaultValue bool) bool { + val, ok := GetEnv(name) + if !ok { + return defaultValue + } + + b, err := strconv.ParseBool(val) + if err != nil { + return defaultValue + } + + return b +}