diff --git a/flow/cmd/peer_data.go b/flow/cmd/peer_data.go index 110b9b5a7f..0bc5f7d245 100644 --- a/flow/cmd/peer_data.go +++ b/flow/cmd/peer_data.go @@ -8,11 +8,9 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgxpool" "google.golang.org/protobuf/proto" connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" - "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" ) @@ -33,17 +31,19 @@ func (h *FlowRequestHandler) getPGPeerConfig(ctx context.Context, peerName strin return &pgPeerConfig, nil } -func (h *FlowRequestHandler) getPoolForPGPeer(ctx context.Context, peerName string) (*pgxpool.Pool, error) { +func (h *FlowRequestHandler) getPoolForPGPeer(ctx context.Context, peerName string) (*connpostgres.SSHWrappedPostgresPool, error) { pgPeerConfig, err := h.getPGPeerConfig(ctx, peerName) if err != nil { return nil, err } - connStr := utils.GetPGConnectionString(pgPeerConfig) - peerPool, err := pgxpool.New(ctx, connStr) + + pool, err := connpostgres.NewSSHWrappedPostgresPoolFromConfig(ctx, pgPeerConfig) if err != nil { + slog.Error("Failed to create postgres pool", slog.Any("error", err)) return nil, err } - return peerPool, nil + + return pool, nil } func (h *FlowRequestHandler) GetSchemas( diff --git a/flow/cmd/validate_mirror.go b/flow/cmd/validate_mirror.go index 7f10020e58..8240f7df01 100644 --- a/flow/cmd/validate_mirror.go +++ b/flow/cmd/validate_mirror.go @@ -12,21 +12,20 @@ import ( func (h *FlowRequestHandler) ValidateCDCMirror( ctx context.Context, req *protos.CreateCDCFlowRequest, ) (*protos.ValidateCDCMirrorResponse, error) { - pgPeer, err := connpostgres.NewPostgresConnector(ctx, req.ConnectionConfigs.Source.GetPostgresConfig()) + sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig() + if sourcePeerConfig == nil { + slog.Error("/validatecdc source peer config is nil", slog.Any("peer", req.ConnectionConfigs.Source)) + return nil, fmt.Errorf("source peer config is nil") + } + + pgPeer, err := connpostgres.NewPostgresConnector(ctx, sourcePeerConfig) if err != nil { return &protos.ValidateCDCMirrorResponse{ Ok: false, }, fmt.Errorf("failed to create postgres connector: %v", err) } - defer pgPeer.Close() - sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig() - if sourcePeerConfig == nil { - slog.Error("/validatecdc source peer config is nil", slog.Any("peer", req.ConnectionConfigs.Source)) - return nil, fmt.Errorf("source peer config is nil") - } - // Check permissions of postgres peer err = pgPeer.CheckReplicationPermissions(sourcePeerConfig.User) if err != nil { diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index 8e6fd3e38c..3443993512 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -260,11 +260,11 @@ func (c *BigQueryConnector) ReplayTableSchemaDeltas(flowJobName string, for _, addedColumn := range schemaDelta.AddedColumns { dstDatasetTable, _ := c.convertToDatasetTable(schemaDelta.DstTableName) query := c.client.Query(fmt.Sprintf( - "ALTER TABLE %s.%s ADD COLUMN IF NOT EXISTS `%s` %s", dstDatasetTable.dataset, + "ALTER TABLE %s ADD COLUMN IF NOT EXISTS `%s` %s", dstDatasetTable.table, addedColumn.ColumnName, qValueKindToBigQueryType(addedColumn.ColumnType))) query.DefaultProjectID = c.projectID - query.DefaultDatasetID = c.datasetID + query.DefaultDatasetID = dstDatasetTable.dataset _, err := query.Read(c.ctx) if err != nil { return fmt.Errorf("failed to 
add column %s for table %s: %w", addedColumn.ColumnName, @@ -312,7 +312,7 @@ func (c *BigQueryConnector) SetupMetadataTables() error { } func (c *BigQueryConnector) GetLastOffset(jobName string) (int64, error) { - query := fmt.Sprintf("SELECT offset FROM %s.%s WHERE mirror_job_name = '%s'", c.datasetID, MirrorJobsTable, jobName) + query := fmt.Sprintf("SELECT offset FROM %s WHERE mirror_job_name = '%s'", MirrorJobsTable, jobName) q := c.client.Query(query) q.DefaultProjectID = c.projectID q.DefaultDatasetID = c.datasetID @@ -339,8 +339,7 @@ func (c *BigQueryConnector) GetLastOffset(jobName string) (int64, error) { func (c *BigQueryConnector) SetLastOffset(jobName string, lastOffset int64) error { query := fmt.Sprintf( - "UPDATE %s.%s SET offset = GREATEST(offset, %d) WHERE mirror_job_name = '%s'", - c.datasetID, + "UPDATE %s SET offset = GREATEST(offset, %d) WHERE mirror_job_name = '%s'", MirrorJobsTable, lastOffset, jobName, @@ -357,8 +356,8 @@ func (c *BigQueryConnector) SetLastOffset(jobName string, lastOffset int64) erro } func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) { - query := fmt.Sprintf("SELECT sync_batch_id FROM %s.%s WHERE mirror_job_name = '%s'", - c.datasetID, MirrorJobsTable, jobName) + query := fmt.Sprintf("SELECT sync_batch_id FROM %s WHERE mirror_job_name = '%s'", + MirrorJobsTable, jobName) q := c.client.Query(query) q.DisableQueryCache = true q.DefaultProjectID = c.projectID @@ -385,8 +384,8 @@ func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) { } func (c *BigQueryConnector) GetLastNormalizeBatchID(jobName string) (int64, error) { - query := fmt.Sprintf("SELECT normalize_batch_id FROM %s.%s WHERE mirror_job_name = '%s'", - c.datasetID, MirrorJobsTable, jobName) + query := fmt.Sprintf("SELECT normalize_batch_id FROM %s WHERE mirror_job_name = '%s'", + MirrorJobsTable, jobName) q := c.client.Query(query) q.DisableQueryCache = true q.DefaultProjectID = c.projectID @@ -416,9 +415,9 @@ func (c *BigQueryConnector) getDistinctTableNamesInBatch(flowJobName string, syn rawTableName := c.getRawTableName(flowJobName) // Prepare the query to retrieve distinct tables in that batch - query := fmt.Sprintf(`SELECT DISTINCT _peerdb_destination_table_name FROM %s.%s + query := fmt.Sprintf(`SELECT DISTINCT _peerdb_destination_table_name FROM %s WHERE _peerdb_batch_id > %d and _peerdb_batch_id <= %d`, - c.datasetID, rawTableName, normalizeBatchID, syncBatchID) + rawTableName, normalizeBatchID, syncBatchID) // Run the query q := c.client.Query(query) q.DefaultProjectID = c.projectID @@ -459,10 +458,10 @@ func (c *BigQueryConnector) getTableNametoUnchangedCols(flowJobName string, sync // where a placeholder value for unchanged cols can be set in DeleteRecord if there is no backfill // we don't want these particular DeleteRecords to be used in the update statement query := fmt.Sprintf(`SELECT _peerdb_destination_table_name, - array_agg(DISTINCT _peerdb_unchanged_toast_columns) as unchanged_toast_columns FROM %s.%s + array_agg(DISTINCT _peerdb_unchanged_toast_columns) as unchanged_toast_columns FROM %s WHERE _peerdb_batch_id > %d AND _peerdb_batch_id <= %d AND _peerdb_record_type != 2 GROUP BY _peerdb_destination_table_name`, - c.datasetID, rawTableName, normalizeBatchID, syncBatchID) + rawTableName, normalizeBatchID, syncBatchID) // Run the query q := c.client.Query(query) q.DefaultDatasetID = c.datasetID @@ -587,6 +586,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) dstDatasetTable, _ := 
c.convertToDatasetTable(tableName) mergeGen := &mergeStmtGenerator{ rawDatasetTable: &datasetTable{ + project: c.projectID, dataset: c.datasetID, table: rawTableName, }, @@ -609,7 +609,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) i+1, len(mergeStmts), tableName)) q := c.client.Query(mergeStmt) q.DefaultProjectID = c.projectID - q.DefaultDatasetID = c.datasetID + q.DefaultDatasetID = dstDatasetTable.dataset _, err = q.Read(c.ctx) if err != nil { return nil, fmt.Errorf("failed to execute merge statement %s: %v", mergeStmt, err) @@ -618,8 +618,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) } // update metadata to make the last normalized batch id to the recent last sync batch id. updateMetadataStmt := fmt.Sprintf( - "UPDATE %s.%s SET normalize_batch_id=%d WHERE mirror_job_name='%s';", - c.datasetID, MirrorJobsTable, req.SyncBatchID, req.FlowJobName) + "UPDATE %s SET normalize_batch_id=%d WHERE mirror_job_name='%s';", + MirrorJobsTable, req.SyncBatchID, req.FlowJobName) query := c.client.Query(updateMetadataStmt) query.DefaultProjectID = c.projectID @@ -719,12 +719,12 @@ func (c *BigQueryConnector) getUpdateMetadataStmt(jobName string, lastSyncedChec // create the job in the metadata table jobStatement := fmt.Sprintf( - "INSERT INTO %s.%s (mirror_job_name,offset,sync_batch_id) VALUES ('%s',%d,%d);", - c.datasetID, MirrorJobsTable, jobName, lastSyncedCheckpointID, batchID) + "INSERT INTO %s (mirror_job_name,offset,sync_batch_id) VALUES ('%s',%d,%d);", + MirrorJobsTable, jobName, lastSyncedCheckpointID, batchID) if hasJob { jobStatement = fmt.Sprintf( - "UPDATE %s.%s SET offset=GREATEST(offset,%d),sync_batch_id=%d WHERE mirror_job_name = '%s';", - c.datasetID, MirrorJobsTable, lastSyncedCheckpointID, batchID, jobName) + "UPDATE %s SET offset=GREATEST(offset,%d),sync_batch_id=%d WHERE mirror_job_name = '%s';", + MirrorJobsTable, lastSyncedCheckpointID, batchID, jobName) } return jobStatement, nil @@ -733,8 +733,8 @@ func (c *BigQueryConnector) getUpdateMetadataStmt(jobName string, lastSyncedChec // metadataHasJob checks if the metadata table has the given job. 
func (c *BigQueryConnector) metadataHasJob(jobName string) (bool, error) { checkStmt := fmt.Sprintf( - "SELECT COUNT(*) FROM %s.%s WHERE mirror_job_name = '%s'", - c.datasetID, MirrorJobsTable, jobName) + "SELECT COUNT(*) FROM %s WHERE mirror_job_name = '%s'", + MirrorJobsTable, jobName) q := c.client.Query(checkStmt) q.DefaultProjectID = c.projectID @@ -804,13 +804,14 @@ func (c *BigQueryConnector) SetupNormalizedTables( // convert the column names and types to bigquery types columns := make([]*bigquery.FieldSchema, 0, len(tableSchema.ColumnNames)+2) - utils.IterColumns(tableSchema, func(colName, genericColType string) { + for i, colName := range tableSchema.ColumnNames { + genericColType := tableSchema.ColumnTypes[i] columns = append(columns, &bigquery.FieldSchema{ Name: colName, Type: qValueKindToBigQueryType(genericColType), Repeated: qvalue.QValueKind(genericColType).IsArray(), }) - }) + } if req.SoftDeleteColName != "" { columns = append(columns, &bigquery.FieldSchema{ @@ -872,7 +873,7 @@ func (c *BigQueryConnector) SyncFlowCleanup(jobName string) error { } // deleting job from metadata table - query := fmt.Sprintf("DELETE FROM %s.%s WHERE mirror_job_name = '%s'", c.datasetID, MirrorJobsTable, jobName) + query := fmt.Sprintf("DELETE FROM %s WHERE mirror_job_name = '%s'", MirrorJobsTable, jobName) queryHandler := c.client.Query(query) queryHandler.DefaultProjectID = c.projectID queryHandler.DefaultDatasetID = c.datasetID @@ -902,7 +903,7 @@ func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos dstDatasetTable.string())) if req.SoftDeleteColName != nil { - allCols := strings.Join(utils.TableSchemaColumnNames(renameRequest.TableSchema), ",") + allCols := strings.Join(renameRequest.TableSchema.ColumnNames, ",") pkeyCols := strings.Join(renameRequest.TableSchema.PrimaryKeyColumns, ",") c.logger.InfoContext(c.ctx, fmt.Sprintf("handling soft-deletes for table '%s'...", dstDatasetTable.string())) @@ -1011,12 +1012,16 @@ func (c *BigQueryConnector) CreateTablesFromExisting(req *protos.CreateTablesFro } type datasetTable struct { + project string dataset string table string } func (d *datasetTable) string() string { - return fmt.Sprintf("%s.%s", d.dataset, d.table) + if d.project == "" { + return fmt.Sprintf("%s.%s", d.dataset, d.table) + } + return fmt.Sprintf("%s.%s.%s", d.project, d.dataset, d.table) } func (c *BigQueryConnector) convertToDatasetTable(tableName string) (*datasetTable, error) { @@ -1031,6 +1036,12 @@ func (c *BigQueryConnector) convertToDatasetTable(tableName string) (*datasetTab dataset: parts[0], table: parts[1], }, nil + } else if len(parts) == 3 { + return &datasetTable{ + project: parts[0], + dataset: parts[1], + table: parts[2], + }, nil } else { return nil, fmt.Errorf("invalid BigQuery table name: %s", tableName) } diff --git a/flow/connectors/bigquery/merge_stmt_generator.go b/flow/connectors/bigquery/merge_stmt_generator.go index d87a83a290..8c7d1992cc 100644 --- a/flow/connectors/bigquery/merge_stmt_generator.go +++ b/flow/connectors/bigquery/merge_stmt_generator.go @@ -34,7 +34,7 @@ type mergeStmtGenerator struct { func (m *mergeStmtGenerator) generateFlattenedCTE() string { // for each column in the normalized table, generate CAST + JSON_EXTRACT_SCALAR // statement. 
- flattenedProjs := make([]string, 0, utils.TableSchemaColumns(m.normalizedTableSchema)+3) + flattenedProjs := make([]string, 0, len(m.normalizedTableSchema.ColumnNames)+3) for i, colName := range m.normalizedTableSchema.ColumnNames { colType := m.normalizedTableSchema.ColumnTypes[i] @@ -93,9 +93,9 @@ func (m *mergeStmtGenerator) generateFlattenedCTE() string { ) // normalize anything between last normalized batch id to last sync batchid - return fmt.Sprintf(`WITH _f AS - (SELECT %s FROM %s WHERE _peerdb_batch_id>%d AND _peerdb_batch_id<=%d AND - _peerdb_destination_table_name='%s')`, + return fmt.Sprintf("WITH _f AS "+ + "(SELECT %s FROM `%s` WHERE _peerdb_batch_id>%d AND _peerdb_batch_id<=%d AND "+ + "_peerdb_destination_table_name='%s')", strings.Join(flattenedProjs, ","), m.rawDatasetTable.string(), m.normalizeBatchID, m.syncBatchID, m.dstTableName) } @@ -124,7 +124,7 @@ func (m *mergeStmtGenerator) generateDeDupedCTE() string { // generateMergeStmt generates a merge statement. func (m *mergeStmtGenerator) generateMergeStmt(unchangedToastColumns []string) string { // comma separated list of column names - columnCount := utils.TableSchemaColumns(m.normalizedTableSchema) + columnCount := len(m.normalizedTableSchema.ColumnNames) backtickColNames := make([]string, 0, columnCount) shortBacktickColNames := make([]string, 0, columnCount) pureColNames := make([]string, 0, columnCount) @@ -169,15 +169,11 @@ func (m *mergeStmtGenerator) generateMergeStmt(unchangedToastColumns []string) s } } - return fmt.Sprintf(` - MERGE %s _t USING(%s,%s) _d - ON %s - WHEN NOT MATCHED AND _d._rt!=2 THEN - INSERT (%s) VALUES(%s) - %s - WHEN MATCHED AND _d._rt=2 THEN - %s; - `, m.dstDatasetTable.string(), m.generateFlattenedCTE(), m.generateDeDupedCTE(), + return fmt.Sprintf("MERGE `%s` _t USING(%s,%s) _d"+ + " ON %s WHEN NOT MATCHED AND _d._rt!=2 THEN "+ + "INSERT (%s) VALUES(%s) "+ + "%s WHEN MATCHED AND _d._rt=2 THEN %s;", + m.dstDatasetTable.table, m.generateFlattenedCTE(), m.generateDeDupedCTE(), pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart) } diff --git a/flow/connectors/bigquery/qrep.go b/flow/connectors/bigquery/qrep.go index 5878dd2bd0..c0aafe045f 100644 --- a/flow/connectors/bigquery/qrep.go +++ b/flow/connectors/bigquery/qrep.go @@ -112,10 +112,10 @@ func (c *BigQueryConnector) createMetadataInsertStatement( partitionJSON := string(pbytes) insertMetadataStmt := fmt.Sprintf( - "INSERT INTO %s._peerdb_query_replication_metadata"+ + "INSERT INTO _peerdb_query_replication_metadata"+ "(flowJobName, partitionID, syncPartition, syncStartTime, syncFinishTime) "+ "VALUES ('%s', '%s', JSON '%s', TIMESTAMP('%s'), CURRENT_TIMESTAMP());", - c.datasetID, jobName, partition.PartitionId, + jobName, partition.PartitionId, partitionJSON, startTime.Format(time.RFC3339)) return insertMetadataStmt, nil @@ -170,8 +170,8 @@ func (c *BigQueryConnector) SetupQRepMetadataTables(config *protos.QRepConfig) e func (c *BigQueryConnector) isPartitionSynced(partitionID string) (bool, error) { queryString := fmt.Sprintf( - "SELECT COUNT(*) FROM %s._peerdb_query_replication_metadata WHERE partitionID = '%s';", - c.datasetID, partitionID, + "SELECT COUNT(*) FROM _peerdb_query_replication_metadata WHERE partitionID = '%s';", + partitionID, ) query := c.client.Query(queryString) diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go index 7e6768963a..c8b182706f 100644 --- a/flow/connectors/bigquery/qrep_avro_sync.go +++ 
b/flow/connectors/bigquery/qrep_avro_sync.go @@ -58,6 +58,7 @@ func (s *QRepAvroSyncMethod) SyncRecords( stagingTable := fmt.Sprintf("%s_%s_staging", rawTableName, strconv.FormatInt(syncBatchID, 10)) numRecords, err := s.writeToStage(strconv.FormatInt(syncBatchID, 10), rawTableName, avroSchema, &datasetTable{ + project: s.connector.projectID, dataset: s.connector.datasetID, table: stagingTable, }, stream, req.FlowJobName) @@ -67,8 +68,8 @@ func (s *QRepAvroSyncMethod) SyncRecords( bqClient := s.connector.client datasetID := s.connector.datasetID - insertStmt := fmt.Sprintf("INSERT INTO `%s.%s` SELECT * FROM `%s.%s`;", - datasetID, rawTableName, datasetID, stagingTable) + insertStmt := fmt.Sprintf("INSERT INTO `%s` SELECT * FROM `%s`;", + rawTableName, stagingTable) lastCP, err := req.Records.GetLastCheckpoint() if err != nil { @@ -171,6 +172,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords( // create a staging table name with partitionID replace hyphens with underscores dstDatasetTable, _ := s.connector.convertToDatasetTable(dstTableName) stagingDatasetTable := &datasetTable{ + project: s.connector.projectID, dataset: dstDatasetTable.dataset, table: fmt.Sprintf("%s_%s_staging", dstDatasetTable.table, strings.ReplaceAll(partition.PartitionId, "-", "_")), @@ -198,7 +200,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords( } // Insert the records from the staging table into the destination table insertStmt := fmt.Sprintf("INSERT INTO `%s` SELECT %s FROM `%s`;", - dstDatasetTable.string(), selector, stagingDatasetTable.string()) + dstTableName, selector, stagingDatasetTable.string()) insertMetadataStmt, err := s.connector.createMetadataInsertStatement(partition, flowJobName, startTime) if err != nil { @@ -229,7 +231,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords( flowLog) } - slog.Info(fmt.Sprintf("loaded stage into %s", dstDatasetTable.string()), flowLog) + slog.Info(fmt.Sprintf("loaded stage into %s", dstTableName), flowLog) return numRecords, nil } diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 5ef8f70b7a..c6ae8f7d7f 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -77,7 +77,7 @@ const ( RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank FROM %s.%s WHERE _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_destination_table_name=$3 ) - DELETE FROM %s USING %s FROM src_rank WHERE %s AND src_rank._peerdb_rank=1 AND src_rank._peerdb_record_type=2` + %s src_rank WHERE %s AND src_rank._peerdb_rank=1 AND src_rank._peerdb_record_type=2` dropTableIfExistsSQL = "DROP TABLE IF EXISTS %s.%s" deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE mirror_job_name=$1" @@ -257,18 +257,27 @@ func (c *PostgresConnector) checkSlotAndPublication(slot string, publication str func (c *PostgresConnector) GetSlotInfo(slotName string) ([]*protos.SlotInfo, error) { var whereClause string if slotName != "" { - whereClause = fmt.Sprintf(" WHERE slot_name = %s", QuoteLiteral(slotName)) + whereClause = fmt.Sprintf("WHERE slot_name=%s", QuoteLiteral(slotName)) } else { - whereClause = fmt.Sprintf(" WHERE database = %s", QuoteLiteral(c.config.Database)) + whereClause = fmt.Sprintf("WHERE database=%s", QuoteLiteral(c.config.Database)) } - rows, err := c.pool.Query(c.ctx, "SELECT slot_name, redo_lsn::Text,restart_lsn::text,wal_status,"+ - "confirmed_flush_lsn::text,active,"+ - "round((CASE WHEN pg_is_in_recovery() THEN pg_last_wal_receive_lsn() ELSE pg_current_wal_lsn() END"+ - " - confirmed_flush_lsn) / 1024 / 
1024) AS MB_Behind"+ - " FROM pg_control_checkpoint(), pg_replication_slots"+whereClause) + + hasWALStatus, _, err := c.MajorVersionCheck(POSTGRES_13) if err != nil { return nil, err } + walStatusSelector := "wal_status" + if !hasWALStatus { + walStatusSelector = "'unknown'" + } + rows, err := c.pool.Query(c.ctx, fmt.Sprintf(`SELECT slot_name, redo_lsn::Text,restart_lsn::text,%s, + confirmed_flush_lsn::text,active, + round((CASE WHEN pg_is_in_recovery() THEN pg_last_wal_receive_lsn() ELSE pg_current_wal_lsn() END + - confirmed_flush_lsn) / 1024 / 1024) AS MB_Behind + FROM pg_control_checkpoint(),pg_replication_slots %s`, walStatusSelector, whereClause)) + if err != nil { + return nil, fmt.Errorf("failed to read information for slots: %w", err) + } defer rows.Close() var slotInfoRows []*protos.SlotInfo for rows.Next() { @@ -415,11 +424,12 @@ func generateCreateTableSQLForNormalizedTable( softDeleteColName string, syncedAtColName string, ) string { - createTableSQLArray := make([]string, 0, utils.TableSchemaColumns(sourceTableSchema)+2) - utils.IterColumns(sourceTableSchema, func(columnName, genericColumnType string) { + createTableSQLArray := make([]string, 0, len(sourceTableSchema.ColumnNames)+2) + for i, columnName := range sourceTableSchema.ColumnNames { + genericColumnType := sourceTableSchema.ColumnTypes[i] createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("%s %s", QuoteIdentifier(columnName), qValueKindToPostgresType(genericColumnType))) - }) + } if softDeleteColName != "" { createTableSQLArray = append(createTableSQLArray, diff --git a/flow/connectors/postgres/normalize_stmt_generator.go b/flow/connectors/postgres/normalize_stmt_generator.go index 3792c188af..47da11f0c4 100644 --- a/flow/connectors/postgres/normalize_stmt_generator.go +++ b/flow/connectors/postgres/normalize_stmt_generator.go @@ -44,11 +44,12 @@ func (n *normalizeStmtGenerator) generateNormalizeStatements() []string { } func (n *normalizeStmtGenerator) generateFallbackStatements() []string { - columnCount := utils.TableSchemaColumns(n.normalizedTableSchema) + columnCount := len(n.normalizedTableSchema.ColumnNames) columnNames := make([]string, 0, columnCount) flattenedCastsSQLArray := make([]string, 0, columnCount) primaryKeyColumnCasts := make(map[string]string, len(n.normalizedTableSchema.PrimaryKeyColumns)) - utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) { + for i, columnName := range n.normalizedTableSchema.ColumnNames { + genericColumnType := n.normalizedTableSchema.ColumnTypes[i] quotedCol := QuoteIdentifier(columnName) stringCol := QuoteLiteral(columnName) columnNames = append(columnNames, quotedCol) @@ -64,16 +65,16 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string { if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) { primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>%s)::%s", stringCol, pgType) } - }) + } flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",") parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName) insertColumnsSQL := strings.Join(columnNames, ",") - updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema)) - utils.IterColumns(n.normalizedTableSchema, func(columnName, _ string) { + updateColumnsSQLArray := make([]string, 0, columnCount) + for _, columnName := range n.normalizedTableSchema.ColumnNames { quotedCol := QuoteIdentifier(columnName) updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`%s=EXCLUDED.%s`, 
quotedCol, quotedCol)) - }) + } updateColumnsSQL := strings.Join(updateColumnsSQLArray, ",") deleteWhereClauseArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns)) for columnName, columnCast := range primaryKeyColumnCasts { @@ -82,13 +83,15 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string { } deleteWhereClauseSQL := strings.Join(deleteWhereClauseArray, " AND ") - deleteUpdate := "" + // make it update instead in case soft-delete is enabled + deleteUpdate := fmt.Sprintf(`DELETE FROM %s USING `, parsedDstTable.String()) if n.peerdbCols.SoftDelete { deleteUpdate = fmt.Sprintf(`UPDATE %s SET %s=TRUE`, parsedDstTable.String(), QuoteIdentifier(n.peerdbCols.SoftDeleteColName)) if n.peerdbCols.SyncedAtColName != "" { deleteUpdate += fmt.Sprintf(`,%s=CURRENT_TIMESTAMP`, QuoteIdentifier(n.peerdbCols.SyncedAtColName)) } + deleteUpdate += " FROM" } fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL, strings.Join(maps.Values(primaryKeyColumnCasts), ","), n.metadataSchema, @@ -96,25 +99,26 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string { strings.Join(n.normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL) fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL, strings.Join(maps.Values(primaryKeyColumnCasts), ","), n.metadataSchema, - n.rawTableName, parsedDstTable.String(), deleteUpdate, deleteWhereClauseSQL) + n.rawTableName, deleteUpdate, deleteWhereClauseSQL) return []string{fallbackUpsertStatement, fallbackDeleteStatement} } func (n *normalizeStmtGenerator) generateMergeStatement() string { - quotedColumnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema) - for i, columnName := range quotedColumnNames { - quotedColumnNames[i] = QuoteIdentifier(columnName) - } + columnCount := len(n.normalizedTableSchema.ColumnNames) + quotedColumnNames := make([]string, columnCount) - flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema)) + flattenedCastsSQLArray := make([]string, 0, columnCount) parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName) primaryKeyColumnCasts := make(map[string]string) primaryKeySelectSQLArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns)) - utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) { + for i, columnName := range n.normalizedTableSchema.ColumnNames { + genericColumnType := n.normalizedTableSchema.ColumnTypes[i] quotedCol := QuoteIdentifier(columnName) stringCol := QuoteLiteral(columnName) + quotedColumnNames[i] = quotedCol + pgType := qValueKindToPostgresType(genericColumnType) if qvalue.QValueKind(genericColumnType).IsArray() { flattenedCastsSQLArray = append(flattenedCastsSQLArray, @@ -129,9 +133,9 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string { primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s", quotedCol, quotedCol)) } - }) + } flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",") - insertValuesSQLArray := make([]string, 0, len(quotedColumnNames)+2) + insertValuesSQLArray := make([]string, 0, columnCount+2) for _, quotedCol := range quotedColumnNames { insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", quotedCol)) } diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 4cbfde6f19..231d5f78f2 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -55,15 +55,16 @@ func 
NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) // set pool size to 3 to avoid connection pool exhaustion connConfig.MaxConns = 3 - // ensure that replication is set to database - replConfig.ConnConfig.RuntimeParams["replication"] = "database" - replConfig.ConnConfig.RuntimeParams["bytea_output"] = "hex" - replConfig.MaxConns = 1 pool, err := NewSSHWrappedPostgresPool(ctx, connConfig, pgConfig.SshConfig) if err != nil { return nil, fmt.Errorf("failed to create connection pool: %w", err) } + // ensure that replication is set to database + replConfig.ConnConfig.RuntimeParams["replication"] = "database" + replConfig.ConnConfig.RuntimeParams["bytea_output"] = "hex" + replConfig.MaxConns = 1 + customTypeMap, err := utils.GetCustomDataTypes(ctx, pool.Pool) if err != nil { return nil, fmt.Errorf("failed to get custom type map: %w", err) diff --git a/flow/connectors/postgres/postgres_schema_delta_test.go b/flow/connectors/postgres/postgres_schema_delta_test.go index 4c3b012243..450303da42 100644 --- a/flow/connectors/postgres/postgres_schema_delta_test.go +++ b/flow/connectors/postgres/postgres_schema_delta_test.go @@ -9,7 +9,6 @@ import ( "github.com/jackc/pgx/v5" "github.com/stretchr/testify/require" - "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/e2eshared" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model/qvalue" @@ -121,14 +120,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddAllColumnTypes() { PrimaryKeyColumns: []string{"id"}, } addedColumns := make([]*protos.DeltaAddedColumn, 0) - utils.IterColumns(expectedTableSchema, func(columnName, columnType string) { + for i, columnName := range expectedTableSchema.ColumnNames { if columnName != "id" { addedColumns = append(addedColumns, &protos.DeltaAddedColumn{ ColumnName: columnName, - ColumnType: columnType, + ColumnType: expectedTableSchema.ColumnTypes[i], }) } - }) + } err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, @@ -171,14 +170,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddTrickyColumnNames() { PrimaryKeyColumns: []string{"id"}, } addedColumns := make([]*protos.DeltaAddedColumn, 0) - utils.IterColumns(expectedTableSchema, func(columnName, columnType string) { + for i, columnName := range expectedTableSchema.ColumnNames { if columnName != "id" { addedColumns = append(addedColumns, &protos.DeltaAddedColumn{ ColumnName: columnName, - ColumnType: columnType, + ColumnType: expectedTableSchema.ColumnTypes[i], }) } - }) + } err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, @@ -212,14 +211,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() { PrimaryKeyColumns: []string{" "}, } addedColumns := make([]*protos.DeltaAddedColumn, 0) - utils.IterColumns(expectedTableSchema, func(columnName, columnType string) { + for i, columnName := range expectedTableSchema.ColumnNames { if columnName != " " { addedColumns = append(addedColumns, &protos.DeltaAddedColumn{ ColumnName: columnName, - ColumnType: columnType, + ColumnType: expectedTableSchema.ColumnTypes[i], }) } - }) + } err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, diff --git a/flow/connectors/postgres/ssh_wrapped_pool.go b/flow/connectors/postgres/ssh_wrapped_pool.go index 4f17116ea4..4dcd2cd0ce 100644 --- a/flow/connectors/postgres/ssh_wrapped_pool.go +++ 
b/flow/connectors/postgres/ssh_wrapped_pool.go @@ -27,6 +27,23 @@ type SSHWrappedPostgresPool struct { cancel context.CancelFunc } +func NewSSHWrappedPostgresPoolFromConfig( + ctx context.Context, + pgConfig *protos.PostgresConfig, +) (*SSHWrappedPostgresPool, error) { + connectionString := utils.GetPGConnectionString(pgConfig) + + connConfig, err := pgxpool.ParseConfig(connectionString) + if err != nil { + return nil, err + } + + // set pool size to 3 to avoid connection pool exhaustion + connConfig.MaxConns = 3 + + return NewSSHWrappedPostgresPool(ctx, connConfig, pgConfig.SshConfig) +} + func NewSSHWrappedPostgresPool( ctx context.Context, poolConfig *pgxpool.Config, diff --git a/flow/connectors/snowflake/merge_stmt_generator.go b/flow/connectors/snowflake/merge_stmt_generator.go index 291b3314d9..d849fe5a58 100644 --- a/flow/connectors/snowflake/merge_stmt_generator.go +++ b/flow/connectors/snowflake/merge_stmt_generator.go @@ -27,14 +27,15 @@ type mergeStmtGenerator struct { func (m *mergeStmtGenerator) generateMergeStmt() (string, error) { parsedDstTable, _ := utils.ParseSchemaTable(m.dstTableName) - columnNames := utils.TableSchemaColumnNames(m.normalizedTableSchema) + columnNames := m.normalizedTableSchema.ColumnNames - flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(m.normalizedTableSchema)) - err := utils.IterColumnsError(m.normalizedTableSchema, func(columnName, genericColumnType string) error { + flattenedCastsSQLArray := make([]string, 0, len(columnNames)) + for i, columnName := range columnNames { + genericColumnType := m.normalizedTableSchema.ColumnTypes[i] qvKind := qvalue.QValueKind(genericColumnType) sfType, err := qValueKindToSnowflakeType(qvKind) if err != nil { - return fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err) + return "", fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err) } targetColumnName := SnowflakeIdentifierNormalize(columnName) @@ -69,10 +70,6 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) { toVariantColumnName, columnName, sfType, targetColumnName)) } } - return nil - }) - if err != nil { - return "", err } flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",") diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index e60bf8993e..ac8e8badcf 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -831,17 +831,18 @@ func generateCreateTableSQLForNormalizedTable( softDeleteColName string, syncedAtColName string, ) string { - createTableSQLArray := make([]string, 0, utils.TableSchemaColumns(sourceTableSchema)+2) - utils.IterColumns(sourceTableSchema, func(columnName, genericColumnType string) { + createTableSQLArray := make([]string, 0, len(sourceTableSchema.ColumnNames)+2) + for i, columnName := range sourceTableSchema.ColumnNames { + genericColumnType := sourceTableSchema.ColumnTypes[i] normalizedColName := SnowflakeIdentifierNormalize(columnName) sfColType, err := qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType)) if err != nil { slog.Warn(fmt.Sprintf("failed to convert column type %s to snowflake type", genericColumnType), slog.Any("error", err)) - return + continue } createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`%s %s`, normalizedColName, sfColType)) - }) + } // add a _peerdb_is_deleted column to the normalized table // this is boolean default false, and is used to mark records as deleted @@ -997,7 +998,7 
@@ func (c *SnowflakeConnector) RenameTables(req *protos.RenameTablesInput) (*proto for _, renameRequest := range req.RenameTableOptions { src := renameRequest.CurrentName dst := renameRequest.NewName - allCols := strings.Join(utils.TableSchemaColumnNames(renameRequest.TableSchema), ",") + allCols := strings.Join(renameRequest.TableSchema.ColumnNames, ",") pkeyCols := strings.Join(renameRequest.TableSchema.PrimaryKeyColumns, ",") c.logger.Info(fmt.Sprintf("handling soft-deletes for table '%s'...", dst)) diff --git a/flow/connectors/utils/columns.go b/flow/connectors/utils/columns.go deleted file mode 100644 index f1e0340f03..0000000000 --- a/flow/connectors/utils/columns.go +++ /dev/null @@ -1,31 +0,0 @@ -package utils - -import ( - "slices" - - "github.com/PeerDB-io/peer-flow/generated/protos" -) - -func TableSchemaColumns(schema *protos.TableSchema) int { - return len(schema.ColumnNames) -} - -func TableSchemaColumnNames(schema *protos.TableSchema) []string { - return slices.Clone(schema.ColumnNames) -} - -func IterColumns(schema *protos.TableSchema, iter func(k, v string)) { - for i, name := range schema.ColumnNames { - iter(name, schema.ColumnTypes[i]) - } -} - -func IterColumnsError(schema *protos.TableSchema, iter func(k, v string) error) error { - for i, name := range schema.ColumnNames { - err := iter(name, schema.ColumnTypes[i]) - if err != nil { - return err - } - } - return nil -} diff --git a/flow/e2e/snowflake/snowflake_schema_delta_test.go b/flow/e2e/snowflake/snowflake_schema_delta_test.go index 52f02b005e..f83d1ac679 100644 --- a/flow/e2e/snowflake/snowflake_schema_delta_test.go +++ b/flow/e2e/snowflake/snowflake_schema_delta_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/require" connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake" - "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/e2eshared" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model/qvalue" @@ -99,14 +98,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddAllColumnTypes() { }, } addedColumns := make([]*protos.DeltaAddedColumn, 0) - utils.IterColumns(expectedTableSchema, func(columnName, columnType string) { + for i, columnName := range expectedTableSchema.ColumnNames { if columnName != "ID" { addedColumns = append(addedColumns, &protos.DeltaAddedColumn{ ColumnName: columnName, - ColumnType: columnType, + ColumnType: expectedTableSchema.ColumnTypes[i], }) } - }) + } err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, @@ -154,14 +153,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddTrickyColumnNames() { }, } addedColumns := make([]*protos.DeltaAddedColumn, 0) - utils.IterColumns(expectedTableSchema, func(columnName, columnType string) { + for i, columnName := range expectedTableSchema.ColumnNames { if columnName != "ID" { addedColumns = append(addedColumns, &protos.DeltaAddedColumn{ ColumnName: columnName, - ColumnType: columnType, + ColumnType: expectedTableSchema.ColumnTypes[i], }) } - }) + } err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, @@ -193,14 +192,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddWhitespaceColumnNames() { }, } addedColumns := make([]*protos.DeltaAddedColumn, 0) - utils.IterColumns(expectedTableSchema, func(columnName, columnType string) { + for i, columnName := range expectedTableSchema.ColumnNames { if columnName != " " { addedColumns = 
append(addedColumns, &protos.DeltaAddedColumn{ ColumnName: columnName, - ColumnType: columnType, + ColumnType: expectedTableSchema.ColumnTypes[i], }) } - }) + } err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, diff --git a/flow/model/model.go b/flow/model/model.go index dd875a2583..d8e3a7c751 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -512,8 +512,6 @@ type SyncRecordsRequest struct { Records *CDCRecordStream // FlowJobName is the name of the flow job. FlowJobName string - // SyncMode to use for pushing raw records - SyncMode protos.QRepSyncMode // source:destination mappings TableMappings []*protos.TableMapping // Staging path for AVRO files in CDC diff --git a/flow/workflows/cdc_flow.go b/flow/workflows/cdc_flow.go index 34c72c9371..a0d73aa593 100644 --- a/flow/workflows/cdc_flow.go +++ b/flow/workflows/cdc_flow.go @@ -55,11 +55,6 @@ type CDCFlowWorkflowState struct { FlowConfigUpdates []*protos.CDCFlowConfigUpdate } -type SignalProps struct { - BatchSize uint32 - IdleTimeout uint64 -} - // returns a new empty PeerFlowState func NewCDCFlowWorkflowState(numTables int) *CDCFlowWorkflowState { return &CDCFlowWorkflowState{ @@ -170,8 +165,6 @@ func (w *CDCFlowWorkflowExecution) processCDCFlowConfigUpdates(ctx workflow.Cont additionalTablesWorkflowCfg.DoInitialSnapshot = true additionalTablesWorkflowCfg.InitialSnapshotOnly = true additionalTablesWorkflowCfg.TableMappings = flowConfigUpdate.AdditionalTables - additionalTablesWorkflowCfg.FlowJobName = fmt.Sprintf("%s_additional_tables_%s", cfg.FlowJobName, - strings.ToLower(shared.RandomString(8))) childAdditionalTablesCDCFlowID, err := GetChildWorkflowID(ctx, "cdc-flow", additionalTablesWorkflowCfg.FlowJobName) @@ -377,18 +370,23 @@ func CDCFlowWorkflowWithConfig( cdcPropertiesSignalChannel := workflow.GetSignalChannel(ctx, shared.CDCDynamicPropertiesSignalName) cdcPropertiesSelector := workflow.NewSelector(ctx) cdcPropertiesSelector.AddReceive(cdcPropertiesSignalChannel, func(c workflow.ReceiveChannel, more bool) { - var cdcSignal SignalProps - c.Receive(ctx, &cdcSignal) + var cdcConfigUpdate *protos.CDCFlowConfigUpdate + c.Receive(ctx, &cdcConfigUpdate) // only modify for options since SyncFlow uses it - if cdcSignal.BatchSize > 0 { - syncFlowOptions.BatchSize = cdcSignal.BatchSize + if cdcConfigUpdate.BatchSize > 0 { + syncFlowOptions.BatchSize = cdcConfigUpdate.BatchSize + } + if cdcConfigUpdate.IdleTimeout > 0 { + syncFlowOptions.IdleTimeoutSeconds = cdcConfigUpdate.IdleTimeout } - if cdcSignal.IdleTimeout > 0 { - syncFlowOptions.IdleTimeoutSeconds = cdcSignal.IdleTimeout + if len(cdcConfigUpdate.AdditionalTables) > 0 { + state.FlowConfigUpdates = append(state.FlowConfigUpdates, cdcConfigUpdate) } - slog.Info("CDC Signal received. Parameters on signal reception:", slog.Int("BatchSize", int(cfg.MaxBatchSize)), - slog.Int("IdleTimeout", int(cfg.IdleTimeoutSeconds))) + slog.Info("CDC Signal received. 
Parameters on signal reception:", + slog.Int("BatchSize", int(syncFlowOptions.BatchSize)), + slog.Int("IdleTimeout", int(syncFlowOptions.IdleTimeoutSeconds)), + slog.Any("AdditionalTables", cdcConfigUpdate.AdditionalTables)) }) cdcPropertiesSelector.AddDefault(func() { diff --git a/flow/workflows/setup_flow.go b/flow/workflows/setup_flow.go index 7bee648f4e..54d5c95a99 100644 --- a/flow/workflows/setup_flow.go +++ b/flow/workflows/setup_flow.go @@ -11,7 +11,6 @@ import ( "golang.org/x/exp/maps" "github.com/PeerDB-io/peer-flow/activities" - "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" ) @@ -207,15 +206,16 @@ func (s *SetupFlowExecution) fetchTableSchemaAndSetupNormalizedTables( for _, mapping := range flowConnectionConfigs.TableMappings { if mapping.SourceTableIdentifier == srcTableName { if len(mapping.Exclude) != 0 { - columnCount := utils.TableSchemaColumns(tableSchema) + columnCount := len(tableSchema.ColumnNames) columnNames := make([]string, 0, columnCount) columnTypes := make([]string, 0, columnCount) - utils.IterColumns(tableSchema, func(columnName, columnType string) { + for i, columnName := range tableSchema.ColumnNames { + columnType := tableSchema.ColumnTypes[i] if !slices.Contains(mapping.Exclude, columnName) { columnNames = append(columnNames, columnName) columnTypes = append(columnTypes, columnType) } - }) + } tableSchema = &protos.TableSchema{ TableIdentifier: tableSchema.TableIdentifier, PrimaryKeyColumns: tableSchema.PrimaryKeyColumns, diff --git a/flow/workflows/snapshot_flow.go b/flow/workflows/snapshot_flow.go index d38801b599..5b1035eb1c 100644 --- a/flow/workflows/snapshot_flow.go +++ b/flow/workflows/snapshot_flow.go @@ -4,6 +4,7 @@ import ( "fmt" "log/slog" "regexp" + "slices" "strings" "time" @@ -13,6 +14,7 @@ import ( "go.temporal.io/sdk/workflow" "github.com/PeerDB-io/peer-flow/concurrency" + connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/shared" @@ -141,11 +143,13 @@ func (s *SnapshotFlowExecution) cloneTable( if len(mapping.Exclude) != 0 { for _, v := range s.tableNameSchemaMapping { if v.TableIdentifier == srcName { - colNames := utils.TableSchemaColumnNames(v) - for i, colName := range colNames { - colNames[i] = fmt.Sprintf(`"%s"`, colName) + quotedColumns := make([]string, 0, len(v.ColumnNames)) + for _, colName := range v.ColumnNames { + if !slices.Contains(mapping.Exclude, colName) { + quotedColumns = append(quotedColumns, connpostgres.QuoteIdentifier(colName)) + } } - from = strings.Join(colNames, ",") + from = strings.Join(quotedColumns, ",") break } } @@ -254,20 +258,15 @@ func SnapshotFlowWorkflow(ctx workflow.Context, config *protos.FlowConnectionCon logger: logger, } - numTablesInParallel := int(config.SnapshotNumTablesInParallel) - if numTablesInParallel <= 0 { - numTablesInParallel = 1 - } - - replCtx := ctx + numTablesInParallel := int(max(config.SnapshotNumTablesInParallel, 1)) if !config.DoInitialSnapshot { - _, err := se.setupReplication(replCtx) + _, err := se.setupReplication(ctx) if err != nil { return fmt.Errorf("failed to setup replication: %w", err) } - if err := se.closeSlotKeepAlive(replCtx); err != nil { + if err := se.closeSlotKeepAlive(ctx); err != nil { return fmt.Errorf("failed to close slot keep alive: %w", err) } diff --git a/protos/flow.proto b/protos/flow.proto index 66ba78c8be..fe0f84ff86 100644 --- a/protos/flow.proto +++ 
b/protos/flow.proto @@ -246,11 +246,6 @@ message PartitionRange { } // protos for qrep -enum QRepSyncMode { - QREP_SYNC_MODE_MULTI_INSERT = 0; - QREP_SYNC_MODE_STORAGE_AVRO = 1; -} - enum QRepWriteType { QREP_WRITE_MODE_APPEND = 0; QREP_WRITE_MODE_UPSERT = 1; @@ -277,19 +272,13 @@ message QRepConfig { string watermark_column = 7; bool initial_copy_only = 8; - QRepSyncMode sync_mode = 9; - - // DEPRECATED: eliminate when breaking changes are allowed. - uint32 batch_size_int = 10; - // DEPRECATED: eliminate when breaking changes are allowed. - uint32 batch_duration_seconds = 11; - uint32 max_parallel_workers = 12; + uint32 max_parallel_workers = 9; // time to wait between getting partitions to process - uint32 wait_between_batches_seconds = 13; + uint32 wait_between_batches_seconds = 10; - QRepWriteMode write_mode = 14; + QRepWriteMode write_mode = 11; // This is only used when sync_mode is AVRO // this is the location where the avro files will be written @@ -297,22 +286,22 @@ message QRepConfig { // if this starts with s3:// then it will be written to S3, only supported in Snowflake // if nothing is specified then it will be written to local disk // if using GCS or S3 make sure your instance has the correct permissions. - string staging_path = 15; + string staging_path = 12; // This setting overrides batch_size_int and batch_duration_seconds // and instead uses the number of rows per partition to determine // how many rows to process per batch. - uint32 num_rows_per_partition = 16; + uint32 num_rows_per_partition = 13; // Creates the watermark table on the destination as-is, can be used for some queries. - bool setup_watermark_table_on_destination = 17; + bool setup_watermark_table_on_destination = 14; // create new tables with "_peerdb_resync" suffix, perform initial load and then swap the new table with the old ones // to be used after the old mirror is dropped - bool dst_table_full_resync = 18; + bool dst_table_full_resync = 15; - string synced_at_col_name = 19; - string soft_delete_col_name = 20; + string synced_at_col_name = 16; + string soft_delete_col_name = 17; } message QRepPartition { @@ -390,6 +379,8 @@ enum FlowStatus { message CDCFlowConfigUpdate { repeated TableMapping additional_tables = 1; + uint32 batch_size = 2; + uint64 idle_timeout = 3; } message QRepFlowConfigUpdate { @@ -412,3 +403,4 @@ message AddTablesToPublicationInput{ string publication_name = 2; repeated TableMapping additional_tables = 3; } + diff --git a/protos/route.proto b/protos/route.proto index 046a63aca3..48db51e019 100644 --- a/protos/route.proto +++ b/protos/route.proto @@ -285,7 +285,9 @@ service FlowService { rpc ShutdownFlow(ShutdownRequest) returns (ShutdownResponse) { option (google.api.http) = { post: "/v1/mirrors/drop", body: "*" }; } - rpc FlowStateChange(FlowStateChangeRequest) returns (FlowStateChangeResponse) {} + rpc FlowStateChange(FlowStateChangeRequest) returns (FlowStateChangeResponse) { + option (google.api.http) = { post: "/v1/mirrors/state_change", body: "*" }; + } rpc MirrorStatus(MirrorStatusRequest) returns (MirrorStatusResponse) { option (google.api.http) = { get: "/v1/mirrors/{flow_job_name}" }; } diff --git a/ui/app/api/mirrors/drop/route.ts b/ui/app/api/mirrors/drop/route.ts index e3be0f7c41..a52e4b078d 100644 --- a/ui/app/api/mirrors/drop/route.ts +++ b/ui/app/api/mirrors/drop/route.ts @@ -13,7 +13,7 @@ export async function POST(request: Request) { destinationPeer, removeFlowEntry: true, }; - console.log('/drop/mirror: req:', req); + console.log('/mirrors/drop: req:', req); 
try { const dropStatus: ShutdownResponse = await fetch( `${flowServiceAddr}/v1/mirrors/drop`, diff --git a/ui/app/api/mirrors/state_change/route.ts b/ui/app/api/mirrors/state_change/route.ts new file mode 100644 index 0000000000..bcc4dd2a8a --- /dev/null +++ b/ui/app/api/mirrors/state_change/route.ts @@ -0,0 +1,23 @@ +import { FlowStateChangeResponse } from '@/grpc_generated/route'; +import { GetFlowHttpAddressFromEnv } from '@/rpc/http'; + +export async function POST(request: Request) { + const body = await request.json(); + const flowServiceAddr = GetFlowHttpAddressFromEnv(); + console.log('/mirrors/state_change: req:', body); + try { + const res: FlowStateChangeResponse = await fetch( + `${flowServiceAddr}/v1/mirrors/state_change`, + { + method: 'POST', + body: JSON.stringify(body), + } + ).then((res) => { + return res.json(); + }); + + return new Response(JSON.stringify(res)); + } catch (e) { + console.error(e); + } +} diff --git a/ui/app/mirrors/create/cdc/cdc.tsx b/ui/app/mirrors/create/cdc/cdc.tsx index 9101703c20..db26cefd51 100644 --- a/ui/app/mirrors/create/cdc/cdc.tsx +++ b/ui/app/mirrors/create/cdc/cdc.tsx @@ -1,5 +1,4 @@ 'use client'; -import { QRepSyncMode } from '@/grpc_generated/flow'; import { DBType } from '@/grpc_generated/peers'; import { Button } from '@/lib/Button'; import { Icon } from '@/lib/Icon'; @@ -39,7 +38,7 @@ export default function CDCConfigForm({ }: MirrorConfigProps) { const [show, setShow] = useState(false); const handleChange = (val: string | boolean, setting: MirrorSetting) => { - let stateVal: string | boolean | QRepSyncMode = val; + let stateVal: string | boolean = val; setting.stateHandler(stateVal, setter); }; diff --git a/ui/app/mirrors/create/handlers.ts b/ui/app/mirrors/create/handlers.ts index 00dab7347b..a3fdab3178 100644 --- a/ui/app/mirrors/create/handlers.ts +++ b/ui/app/mirrors/create/handlers.ts @@ -8,7 +8,6 @@ import { import { FlowConnectionConfigs, QRepConfig, - QRepSyncMode, QRepWriteType, } from '@/grpc_generated/flow'; import { DBType, Peer, dBTypeToJSON } from '@/grpc_generated/peers'; @@ -28,28 +27,6 @@ export const handlePeer = ( ) => { if (!peer) return; if (peerEnd === 'dst') { - if (peer.type === DBType.POSTGRES) { - setConfig((curr) => { - return { - ...curr, - cdcSyncMode: QRepSyncMode.QREP_SYNC_MODE_MULTI_INSERT, - snapshotSyncMode: QRepSyncMode.QREP_SYNC_MODE_MULTI_INSERT, - syncMode: QRepSyncMode.QREP_SYNC_MODE_MULTI_INSERT, - }; - }); - } else if ( - peer.type === DBType.SNOWFLAKE || - peer.type === DBType.BIGQUERY - ) { - setConfig((curr) => { - return { - ...curr, - cdcSyncMode: QRepSyncMode.QREP_SYNC_MODE_STORAGE_AVRO, - snapshotSyncMode: QRepSyncMode.QREP_SYNC_MODE_STORAGE_AVRO, - syncMode: QRepSyncMode.QREP_SYNC_MODE_STORAGE_AVRO, - }; - }); - } setConfig((curr) => ({ ...curr, destination: peer, @@ -238,12 +215,6 @@ export const handleCreateQRep = async ( config.flowJobName = flowJobName; config.query = query; - if (config.destinationPeer?.type == DBType.POSTGRES) { - config.syncMode = QRepSyncMode.QREP_SYNC_MODE_MULTI_INSERT; - } else { - config.syncMode = QRepSyncMode.QREP_SYNC_MODE_STORAGE_AVRO; - } - setLoading(true); const statusMessage: UCreateMirrorResponse = await fetch( '/api/mirrors/qrep', diff --git a/ui/app/mirrors/create/helpers/cdc.ts b/ui/app/mirrors/create/helpers/cdc.ts index f4bdf285f9..663e7aee22 100644 --- a/ui/app/mirrors/create/helpers/cdc.ts +++ b/ui/app/mirrors/create/helpers/cdc.ts @@ -29,7 +29,7 @@ export const cdcSettings: MirrorSetting[] = [ stateHandler: (value, setter) => setter((curr: 
CDCConfig) => ({ ...curr, - idleTimeoutSeconds: (value as number) || 100000, + idleTimeoutSeconds: (value as number) || 60, })), tips: 'Time after which a Sync flow ends, if it happens before pull batch size is reached. Defaults to 60 seconds.', helpfulLink: 'https://docs.peerdb.io/metrics/important_cdc_configs', diff --git a/ui/app/mirrors/create/helpers/common.ts b/ui/app/mirrors/create/helpers/common.ts index 1af51e8d85..5d1b8d9a93 100644 --- a/ui/app/mirrors/create/helpers/common.ts +++ b/ui/app/mirrors/create/helpers/common.ts @@ -1,14 +1,10 @@ -import { - FlowConnectionConfigs, - QRepSyncMode, - QRepWriteType, -} from '@/grpc_generated/flow'; +import { FlowConnectionConfigs, QRepWriteType } from '@/grpc_generated/flow'; import { Peer } from '@/grpc_generated/peers'; export interface MirrorSetting { label: string; stateHandler: ( - value: string | string[] | Peer | boolean | QRepSyncMode | QRepWriteType, + value: string | string[] | Peer | boolean | QRepWriteType, setter: any ) => void; type?: string; diff --git a/ui/app/mirrors/create/qrep/qrep.tsx b/ui/app/mirrors/create/qrep/qrep.tsx index bfabd6a3ce..06b0379b47 100644 --- a/ui/app/mirrors/create/qrep/qrep.tsx +++ b/ui/app/mirrors/create/qrep/qrep.tsx @@ -1,6 +1,6 @@ 'use client'; import { RequiredIndicator } from '@/components/RequiredIndicator'; -import { QRepConfig, QRepSyncMode, QRepWriteType } from '@/grpc_generated/flow'; +import { QRepConfig, QRepWriteType } from '@/grpc_generated/flow'; import { DBType } from '@/grpc_generated/peers'; import { Label } from '@/lib/Label'; import { RowWithSelect, RowWithSwitch, RowWithTextField } from '@/lib/Layout'; @@ -51,8 +51,7 @@ export default function QRepConfigForm({ const [loading, setLoading] = useState(false); const handleChange = (val: string | boolean, setting: MirrorSetting) => { - let stateVal: string | boolean | QRepSyncMode | QRepWriteType | string[] = - val; + let stateVal: string | boolean | QRepWriteType | string[] = val; if (setting.label.includes('Write Type')) { switch (val) { case 'Upsert': diff --git a/ui/app/mirrors/edit/[mirrorId]/cdc.tsx b/ui/app/mirrors/edit/[mirrorId]/cdc.tsx index e7f7ee0ca1..dd55b5a9bf 100644 --- a/ui/app/mirrors/edit/[mirrorId]/cdc.tsx +++ b/ui/app/mirrors/edit/[mirrorId]/cdc.tsx @@ -2,8 +2,8 @@ import { SyncStatusRow } from '@/app/dto/MirrorsDTO'; import TimeLabel from '@/components/TimeComponent'; import { - CDCMirrorStatus, CloneTableSummary, + MirrorStatusResponse, SnapshotStatus, } from '@/grpc_generated/route'; import { Button } from '@/lib/Button'; @@ -230,13 +230,13 @@ export const SnapshotStatusTable = ({ status }: SnapshotStatusProps) => { }; type CDCMirrorStatusProps = { - cdc: CDCMirrorStatus; + status: MirrorStatusResponse; rows: SyncStatusRow[]; createdAt?: Date; syncStatusChild?: React.ReactNode; }; export function CDCMirror({ - cdc, + status, rows, createdAt, syncStatusChild, @@ -249,8 +249,10 @@ export function CDCMirror({ }; let snapshot = <>; - if (cdc.snapshotStatus) { - snapshot = ; + if (status.cdcStatus?.snapshotStatus) { + snapshot = ( + + ); } useEffect(() => { setMounted(true); @@ -283,7 +285,8 @@ export function CDCMirror({ {syncStatusChild} diff --git a/ui/app/mirrors/edit/[mirrorId]/cdcDetails.tsx b/ui/app/mirrors/edit/[mirrorId]/cdcDetails.tsx index 24f4c2ac38..b71dd15615 100644 --- a/ui/app/mirrors/edit/[mirrorId]/cdcDetails.tsx +++ b/ui/app/mirrors/edit/[mirrorId]/cdcDetails.tsx @@ -3,8 +3,11 @@ import { SyncStatusRow } from '@/app/dto/MirrorsDTO'; import MirrorInfo from '@/components/MirrorInfo'; import 
PeerButton from '@/components/PeerComponent'; import TimeLabel from '@/components/TimeComponent'; -import { FlowConnectionConfigs } from '@/grpc_generated/flow'; +import { FlowConnectionConfigs, FlowStatus } from '@/grpc_generated/flow'; import { dBTypeFromJSON } from '@/grpc_generated/peers'; +import { FlowStateChangeRequest } from '@/grpc_generated/route'; +import { Button } from '@/lib/Button'; +import { Icon } from '@/lib/Icon'; import { Label } from '@/lib/Label'; import moment from 'moment'; import Link from 'next/link'; @@ -13,10 +16,11 @@ import TablePairs from './tablePairs'; type props = { syncs: SyncStatusRow[]; - mirrorConfig: FlowConnectionConfigs | undefined; + mirrorConfig: FlowConnectionConfigs; createdAt?: Date; + mirrorStatus: FlowStatus; }; -function CdcDetails({ syncs, createdAt, mirrorConfig }: props) { +function CdcDetails({ syncs, createdAt, mirrorConfig, mirrorStatus }: props) { let lastSyncedAt = moment( syncs.length > 1 ? syncs[1]?.endTime @@ -31,7 +35,7 @@ function CdcDetails({ syncs, createdAt, mirrorConfig }: props) { return acc; }, 0); - const tablesSynced = mirrorConfig?.tableMappings; + const tablesSynced = mirrorConfig.tableMappings; return ( <>
@@ -49,11 +53,14 @@ function CdcDetails({ syncs, createdAt, mirrorConfig }: props) { borderRadius: '1rem', border: '1px solid rgba(0,0,0,0.1)', cursor: 'pointer', + display: 'flex', + alignItems: 'center', }} > - - + + + {statusChangeHandle(mirrorConfig, mirrorStatus)}
@@ -140,4 +147,80 @@ export function numberWithCommas(x: any): string { return x.toString().replace(/\B(?=(\d{3})+(?!\d))/g, ','); } +function statusChangeHandle( + mirrorConfig: FlowConnectionConfigs, + mirrorStatus: FlowStatus +) { + // hopefully there's a better way to do this cast + if (mirrorStatus.toString() === FlowStatus[FlowStatus.STATUS_RUNNING]) { + return ( + + ); + } else if (mirrorStatus.toString() === FlowStatus[FlowStatus.STATUS_PAUSED]) { + return ( + + ); + } else { + return ( + + ); + } +} + +function formatStatus(mirrorStatus: FlowStatus) { + const mirrorStatusLower = mirrorStatus + .toString() + .split('_') + .at(-1) + ?.toLocaleLowerCase()!; + return ( + mirrorStatusLower.at(0)?.toLocaleUpperCase() + mirrorStatusLower.slice(1) + ); +} + export default CdcDetails; diff --git a/ui/app/mirrors/edit/[mirrorId]/configValues.ts b/ui/app/mirrors/edit/[mirrorId]/configValues.ts index dd5d5bdb97..8d49f3090f 100644 --- a/ui/app/mirrors/edit/[mirrorId]/configValues.ts +++ b/ui/app/mirrors/edit/[mirrorId]/configValues.ts @@ -1,15 +1,5 @@ -import { FlowConnectionConfigs, QRepSyncMode } from '@/grpc_generated/flow'; +import { FlowConnectionConfigs } from '@/grpc_generated/flow'; -const syncModeToLabel = (mode: QRepSyncMode) => { - switch (mode.toString()) { - case 'QREP_SYNC_MODE_STORAGE_AVRO': - return 'AVRO'; - case 'QREP_SYNC_MODE_MULTI_INSERT': - return 'Copy with Binary'; - default: - return 'AVRO'; - } -}; const MirrorValues = (mirrorConfig: FlowConnectionConfigs | undefined) => { return [ { diff --git a/ui/app/mirrors/edit/[mirrorId]/page.tsx b/ui/app/mirrors/edit/[mirrorId]/page.tsx index d472b849e7..c4c8f75cd8 100644 --- a/ui/app/mirrors/edit/[mirrorId]/page.tsx +++ b/ui/app/mirrors/edit/[mirrorId]/page.tsx @@ -33,9 +33,10 @@ export default async function EditMirror({ return
No mirror status found!
; } - let createdAt = await prisma.flows.findFirst({ + let mirrorInfo = await prisma.flows.findFirst({ select: { created_at: true, + workflow_id: true, }, where: { name: mirrorId, @@ -86,9 +87,9 @@ export default async function EditMirror({
{mirrorId}
  );
}

diff --git a/ui/components/PeerForms/ClickhouseConfig.tsx b/ui/components/PeerForms/ClickhouseConfig.tsx
index 1d76286db0..3e589c5045 100644
--- a/ui/components/PeerForms/ClickhouseConfig.tsx
+++ b/ui/components/PeerForms/ClickhouseConfig.tsx
@@ -18,7 +18,7 @@ interface ConfigProps {
   setter: PeerSetter;
 }
 
-export default function PostgresForm({ settings, setter }: ConfigProps) {
+export default function ClickhouseForm({ settings, setter }: ConfigProps) {
   const [showSSH, setShowSSH] = useState(false);
   const [sshConfig, setSSHConfig] = useState(blankSSHConfig);
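A recurring change in the BigQuery connector above is that queries stop splicing c.datasetID into the SQL text and instead rely on the query's DefaultProjectID/DefaultDatasetID, while datasetTable gains an optional project component that convertToDatasetTable now parses from three-part names. The following is a minimal, self-contained Go sketch of that parse/render logic, not the connector code itself: parseDatasetTable and the caller-supplied defaultDataset fallback for bare table names are assumptions made for illustration (the diff only shows the two- and three-part branches explicitly).

package main

import (
	"fmt"
	"strings"
)

// datasetTable mirrors the updated struct in flow/connectors/bigquery/bigquery.go:
// project is optional and is only rendered when present.
type datasetTable struct {
	project string
	dataset string
	table   string
}

// string renders "dataset.table" or "project.dataset.table", matching the
// updated string() method in the diff.
func (d *datasetTable) string() string {
	if d.project == "" {
		return fmt.Sprintf("%s.%s", d.dataset, d.table)
	}
	return fmt.Sprintf("%s.%s.%s", d.project, d.dataset, d.table)
}

// parseDatasetTable is a hypothetical standalone stand-in for the connector's
// convertToDatasetTable: it accepts "table", "dataset.table", or
// "project.dataset.table"; the bare-table fallback to defaultDataset is an
// assumption for this sketch.
func parseDatasetTable(name, defaultDataset string) (*datasetTable, error) {
	parts := strings.Split(name, ".")
	switch len(parts) {
	case 1:
		return &datasetTable{dataset: defaultDataset, table: parts[0]}, nil
	case 2:
		return &datasetTable{dataset: parts[0], table: parts[1]}, nil
	case 3:
		return &datasetTable{project: parts[0], dataset: parts[1], table: parts[2]}, nil
	default:
		return nil, fmt.Errorf("invalid BigQuery table name: %s", name)
	}
}

func main() {
	for _, name := range []string{"events_raw", "analytics.events_raw", "my-project.analytics.events_raw"} {
		dt, err := parseDatasetTable(name, "peerdb_dataset")
		if err != nil {
			fmt.Println("error:", err)
			continue
		}
		fmt.Println(dt.string())
	}
}

The other pattern repeated throughout the diff is replacing the deleted utils.IterColumns/TableSchemaColumns helpers with a plain indexed loop over the parallel ColumnNames/ColumnTypes slices; one practical effect, visible in the Snowflake merge statement generator change, is that errors can now be returned directly from the enclosing function instead of being threaded through a callback.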