diff --git a/flow/cmd/peer_data.go b/flow/cmd/peer_data.go
index 110b9b5a7f..0bc5f7d245 100644
--- a/flow/cmd/peer_data.go
+++ b/flow/cmd/peer_data.go
@@ -8,11 +8,9 @@ import (
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
- "github.com/jackc/pgx/v5/pgxpool"
"google.golang.org/protobuf/proto"
connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
- "github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
)
@@ -33,17 +31,19 @@ func (h *FlowRequestHandler) getPGPeerConfig(ctx context.Context, peerName strin
return &pgPeerConfig, nil
}
-func (h *FlowRequestHandler) getPoolForPGPeer(ctx context.Context, peerName string) (*pgxpool.Pool, error) {
+func (h *FlowRequestHandler) getPoolForPGPeer(ctx context.Context, peerName string) (*connpostgres.SSHWrappedPostgresPool, error) {
pgPeerConfig, err := h.getPGPeerConfig(ctx, peerName)
if err != nil {
return nil, err
}
- connStr := utils.GetPGConnectionString(pgPeerConfig)
- peerPool, err := pgxpool.New(ctx, connStr)
+
+ pool, err := connpostgres.NewSSHWrappedPostgresPoolFromConfig(ctx, pgPeerConfig)
if err != nil {
+ slog.Error("Failed to create postgres pool", slog.Any("error", err))
return nil, err
}
- return peerPool, nil
+
+ return pool, nil
}
func (h *FlowRequestHandler) GetSchemas(
diff --git a/flow/cmd/validate_mirror.go b/flow/cmd/validate_mirror.go
index 7f10020e58..8240f7df01 100644
--- a/flow/cmd/validate_mirror.go
+++ b/flow/cmd/validate_mirror.go
@@ -12,21 +12,20 @@ import (
func (h *FlowRequestHandler) ValidateCDCMirror(
ctx context.Context, req *protos.CreateCDCFlowRequest,
) (*protos.ValidateCDCMirrorResponse, error) {
- pgPeer, err := connpostgres.NewPostgresConnector(ctx, req.ConnectionConfigs.Source.GetPostgresConfig())
+ sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig()
+ if sourcePeerConfig == nil {
+ slog.Error("/validatecdc source peer config is nil", slog.Any("peer", req.ConnectionConfigs.Source))
+ return nil, fmt.Errorf("source peer config is nil")
+ }
+
+ pgPeer, err := connpostgres.NewPostgresConnector(ctx, sourcePeerConfig)
if err != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("failed to create postgres connector: %v", err)
}
-
defer pgPeer.Close()
- sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig()
- if sourcePeerConfig == nil {
- slog.Error("/validatecdc source peer config is nil", slog.Any("peer", req.ConnectionConfigs.Source))
- return nil, fmt.Errorf("source peer config is nil")
- }
-
// Check permissions of postgres peer
err = pgPeer.CheckReplicationPermissions(sourcePeerConfig.User)
if err != nil {
diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go
index 8e6fd3e38c..3443993512 100644
--- a/flow/connectors/bigquery/bigquery.go
+++ b/flow/connectors/bigquery/bigquery.go
@@ -260,11 +260,11 @@ func (c *BigQueryConnector) ReplayTableSchemaDeltas(flowJobName string,
for _, addedColumn := range schemaDelta.AddedColumns {
dstDatasetTable, _ := c.convertToDatasetTable(schemaDelta.DstTableName)
query := c.client.Query(fmt.Sprintf(
- "ALTER TABLE %s.%s ADD COLUMN IF NOT EXISTS `%s` %s", dstDatasetTable.dataset,
+ "ALTER TABLE %s ADD COLUMN IF NOT EXISTS `%s` %s",
dstDatasetTable.table, addedColumn.ColumnName,
qValueKindToBigQueryType(addedColumn.ColumnType)))
query.DefaultProjectID = c.projectID
- query.DefaultDatasetID = c.datasetID
+ query.DefaultDatasetID = dstDatasetTable.dataset
_, err := query.Read(c.ctx)
if err != nil {
return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName,
@@ -312,7 +312,7 @@ func (c *BigQueryConnector) SetupMetadataTables() error {
}
func (c *BigQueryConnector) GetLastOffset(jobName string) (int64, error) {
- query := fmt.Sprintf("SELECT offset FROM %s.%s WHERE mirror_job_name = '%s'", c.datasetID, MirrorJobsTable, jobName)
+ query := fmt.Sprintf("SELECT offset FROM %s WHERE mirror_job_name = '%s'", MirrorJobsTable, jobName)
q := c.client.Query(query)
q.DefaultProjectID = c.projectID
q.DefaultDatasetID = c.datasetID
@@ -339,8 +339,7 @@ func (c *BigQueryConnector) GetLastOffset(jobName string) (int64, error) {
func (c *BigQueryConnector) SetLastOffset(jobName string, lastOffset int64) error {
query := fmt.Sprintf(
- "UPDATE %s.%s SET offset = GREATEST(offset, %d) WHERE mirror_job_name = '%s'",
- c.datasetID,
+ "UPDATE %s SET offset = GREATEST(offset, %d) WHERE mirror_job_name = '%s'",
MirrorJobsTable,
lastOffset,
jobName,
@@ -357,8 +356,8 @@ func (c *BigQueryConnector) SetLastOffset(jobName string, lastOffset int64) erro
}
func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) {
- query := fmt.Sprintf("SELECT sync_batch_id FROM %s.%s WHERE mirror_job_name = '%s'",
- c.datasetID, MirrorJobsTable, jobName)
+ query := fmt.Sprintf("SELECT sync_batch_id FROM %s WHERE mirror_job_name = '%s'",
+ MirrorJobsTable, jobName)
q := c.client.Query(query)
q.DisableQueryCache = true
q.DefaultProjectID = c.projectID
@@ -385,8 +384,8 @@ func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) {
}
func (c *BigQueryConnector) GetLastNormalizeBatchID(jobName string) (int64, error) {
- query := fmt.Sprintf("SELECT normalize_batch_id FROM %s.%s WHERE mirror_job_name = '%s'",
- c.datasetID, MirrorJobsTable, jobName)
+ query := fmt.Sprintf("SELECT normalize_batch_id FROM %s WHERE mirror_job_name = '%s'",
+ MirrorJobsTable, jobName)
q := c.client.Query(query)
q.DisableQueryCache = true
q.DefaultProjectID = c.projectID
@@ -416,9 +415,9 @@ func (c *BigQueryConnector) getDistinctTableNamesInBatch(flowJobName string, syn
rawTableName := c.getRawTableName(flowJobName)
// Prepare the query to retrieve distinct tables in that batch
- query := fmt.Sprintf(`SELECT DISTINCT _peerdb_destination_table_name FROM %s.%s
+ query := fmt.Sprintf(`SELECT DISTINCT _peerdb_destination_table_name FROM %s
WHERE _peerdb_batch_id > %d and _peerdb_batch_id <= %d`,
- c.datasetID, rawTableName, normalizeBatchID, syncBatchID)
+ rawTableName, normalizeBatchID, syncBatchID)
// Run the query
q := c.client.Query(query)
q.DefaultProjectID = c.projectID
@@ -459,10 +458,10 @@ func (c *BigQueryConnector) getTableNametoUnchangedCols(flowJobName string, sync
// where a placeholder value for unchanged cols can be set in DeleteRecord if there is no backfill
// we don't want these particular DeleteRecords to be used in the update statement
query := fmt.Sprintf(`SELECT _peerdb_destination_table_name,
- array_agg(DISTINCT _peerdb_unchanged_toast_columns) as unchanged_toast_columns FROM %s.%s
+ array_agg(DISTINCT _peerdb_unchanged_toast_columns) as unchanged_toast_columns FROM %s
WHERE _peerdb_batch_id > %d AND _peerdb_batch_id <= %d AND _peerdb_record_type != 2
GROUP BY _peerdb_destination_table_name`,
- c.datasetID, rawTableName, normalizeBatchID, syncBatchID)
+ rawTableName, normalizeBatchID, syncBatchID)
// Run the query
q := c.client.Query(query)
q.DefaultDatasetID = c.datasetID
@@ -587,6 +586,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
dstDatasetTable, _ := c.convertToDatasetTable(tableName)
mergeGen := &mergeStmtGenerator{
rawDatasetTable: &datasetTable{
+ project: c.projectID,
dataset: c.datasetID,
table: rawTableName,
},
@@ -609,7 +609,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
i+1, len(mergeStmts), tableName))
q := c.client.Query(mergeStmt)
q.DefaultProjectID = c.projectID
- q.DefaultDatasetID = c.datasetID
+ q.DefaultDatasetID = dstDatasetTable.dataset
_, err = q.Read(c.ctx)
if err != nil {
return nil, fmt.Errorf("failed to execute merge statement %s: %v", mergeStmt, err)
@@ -618,8 +618,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
}
// update metadata to make the last normalized batch id to the recent last sync batch id.
updateMetadataStmt := fmt.Sprintf(
- "UPDATE %s.%s SET normalize_batch_id=%d WHERE mirror_job_name='%s';",
- c.datasetID, MirrorJobsTable, req.SyncBatchID, req.FlowJobName)
+ "UPDATE %s SET normalize_batch_id=%d WHERE mirror_job_name='%s';",
+ MirrorJobsTable, req.SyncBatchID, req.FlowJobName)
query := c.client.Query(updateMetadataStmt)
query.DefaultProjectID = c.projectID
@@ -719,12 +719,12 @@ func (c *BigQueryConnector) getUpdateMetadataStmt(jobName string, lastSyncedChec
// create the job in the metadata table
jobStatement := fmt.Sprintf(
- "INSERT INTO %s.%s (mirror_job_name,offset,sync_batch_id) VALUES ('%s',%d,%d);",
- c.datasetID, MirrorJobsTable, jobName, lastSyncedCheckpointID, batchID)
+ "INSERT INTO %s (mirror_job_name,offset,sync_batch_id) VALUES ('%s',%d,%d);",
+ MirrorJobsTable, jobName, lastSyncedCheckpointID, batchID)
if hasJob {
jobStatement = fmt.Sprintf(
- "UPDATE %s.%s SET offset=GREATEST(offset,%d),sync_batch_id=%d WHERE mirror_job_name = '%s';",
- c.datasetID, MirrorJobsTable, lastSyncedCheckpointID, batchID, jobName)
+ "UPDATE %s SET offset=GREATEST(offset,%d),sync_batch_id=%d WHERE mirror_job_name = '%s';",
+ MirrorJobsTable, lastSyncedCheckpointID, batchID, jobName)
}
return jobStatement, nil
@@ -733,8 +733,8 @@ func (c *BigQueryConnector) getUpdateMetadataStmt(jobName string, lastSyncedChec
// metadataHasJob checks if the metadata table has the given job.
func (c *BigQueryConnector) metadataHasJob(jobName string) (bool, error) {
checkStmt := fmt.Sprintf(
- "SELECT COUNT(*) FROM %s.%s WHERE mirror_job_name = '%s'",
- c.datasetID, MirrorJobsTable, jobName)
+ "SELECT COUNT(*) FROM %s WHERE mirror_job_name = '%s'",
+ MirrorJobsTable, jobName)
q := c.client.Query(checkStmt)
q.DefaultProjectID = c.projectID
@@ -804,13 +804,14 @@ func (c *BigQueryConnector) SetupNormalizedTables(
// convert the column names and types to bigquery types
columns := make([]*bigquery.FieldSchema, 0, len(tableSchema.ColumnNames)+2)
- utils.IterColumns(tableSchema, func(colName, genericColType string) {
+ for i, colName := range tableSchema.ColumnNames {
+ genericColType := tableSchema.ColumnTypes[i]
columns = append(columns, &bigquery.FieldSchema{
Name: colName,
Type: qValueKindToBigQueryType(genericColType),
Repeated: qvalue.QValueKind(genericColType).IsArray(),
})
- })
+ }
if req.SoftDeleteColName != "" {
columns = append(columns, &bigquery.FieldSchema{
@@ -872,7 +873,7 @@ func (c *BigQueryConnector) SyncFlowCleanup(jobName string) error {
}
// deleting job from metadata table
- query := fmt.Sprintf("DELETE FROM %s.%s WHERE mirror_job_name = '%s'", c.datasetID, MirrorJobsTable, jobName)
+ query := fmt.Sprintf("DELETE FROM %s WHERE mirror_job_name = '%s'", MirrorJobsTable, jobName)
queryHandler := c.client.Query(query)
queryHandler.DefaultProjectID = c.projectID
queryHandler.DefaultDatasetID = c.datasetID
@@ -902,7 +903,7 @@ func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos
dstDatasetTable.string()))
if req.SoftDeleteColName != nil {
- allCols := strings.Join(utils.TableSchemaColumnNames(renameRequest.TableSchema), ",")
+ allCols := strings.Join(renameRequest.TableSchema.ColumnNames, ",")
pkeyCols := strings.Join(renameRequest.TableSchema.PrimaryKeyColumns, ",")
c.logger.InfoContext(c.ctx, fmt.Sprintf("handling soft-deletes for table '%s'...", dstDatasetTable.string()))
@@ -1011,12 +1012,16 @@ func (c *BigQueryConnector) CreateTablesFromExisting(req *protos.CreateTablesFro
}
type datasetTable struct {
+ project string
dataset string
table string
}
func (d *datasetTable) string() string {
- return fmt.Sprintf("%s.%s", d.dataset, d.table)
+ if d.project == "" {
+ return fmt.Sprintf("%s.%s", d.dataset, d.table)
+ }
+ return fmt.Sprintf("%s.%s.%s", d.project, d.dataset, d.table)
}
func (c *BigQueryConnector) convertToDatasetTable(tableName string) (*datasetTable, error) {
@@ -1031,6 +1036,12 @@ func (c *BigQueryConnector) convertToDatasetTable(tableName string) (*datasetTab
dataset: parts[0],
table: parts[1],
}, nil
+ } else if len(parts) == 3 {
+ return &datasetTable{
+ project: parts[0],
+ dataset: parts[1],
+ table: parts[2],
+ }, nil
} else {
return nil, fmt.Errorf("invalid BigQuery table name: %s", tableName)
}
diff --git a/flow/connectors/bigquery/merge_stmt_generator.go b/flow/connectors/bigquery/merge_stmt_generator.go
index d87a83a290..8c7d1992cc 100644
--- a/flow/connectors/bigquery/merge_stmt_generator.go
+++ b/flow/connectors/bigquery/merge_stmt_generator.go
@@ -34,7 +34,7 @@ type mergeStmtGenerator struct {
func (m *mergeStmtGenerator) generateFlattenedCTE() string {
// for each column in the normalized table, generate CAST + JSON_EXTRACT_SCALAR
// statement.
- flattenedProjs := make([]string, 0, utils.TableSchemaColumns(m.normalizedTableSchema)+3)
+ flattenedProjs := make([]string, 0, len(m.normalizedTableSchema.ColumnNames)+3)
for i, colName := range m.normalizedTableSchema.ColumnNames {
colType := m.normalizedTableSchema.ColumnTypes[i]
@@ -93,9 +93,9 @@ func (m *mergeStmtGenerator) generateFlattenedCTE() string {
)
// normalize anything between last normalized batch id to last sync batchid
- return fmt.Sprintf(`WITH _f AS
- (SELECT %s FROM %s WHERE _peerdb_batch_id>%d AND _peerdb_batch_id<=%d AND
- _peerdb_destination_table_name='%s')`,
+ return fmt.Sprintf("WITH _f AS "+
+ "(SELECT %s FROM `%s` WHERE _peerdb_batch_id>%d AND _peerdb_batch_id<=%d AND "+
+ "_peerdb_destination_table_name='%s')",
strings.Join(flattenedProjs, ","), m.rawDatasetTable.string(), m.normalizeBatchID,
m.syncBatchID, m.dstTableName)
}
@@ -124,7 +124,7 @@ func (m *mergeStmtGenerator) generateDeDupedCTE() string {
// generateMergeStmt generates a merge statement.
func (m *mergeStmtGenerator) generateMergeStmt(unchangedToastColumns []string) string {
// comma separated list of column names
- columnCount := utils.TableSchemaColumns(m.normalizedTableSchema)
+ columnCount := len(m.normalizedTableSchema.ColumnNames)
backtickColNames := make([]string, 0, columnCount)
shortBacktickColNames := make([]string, 0, columnCount)
pureColNames := make([]string, 0, columnCount)
@@ -169,15 +169,11 @@ func (m *mergeStmtGenerator) generateMergeStmt(unchangedToastColumns []string) s
}
}
- return fmt.Sprintf(`
- MERGE %s _t USING(%s,%s) _d
- ON %s
- WHEN NOT MATCHED AND _d._rt!=2 THEN
- INSERT (%s) VALUES(%s)
- %s
- WHEN MATCHED AND _d._rt=2 THEN
- %s;
- `, m.dstDatasetTable.string(), m.generateFlattenedCTE(), m.generateDeDupedCTE(),
+ return fmt.Sprintf("MERGE `%s` _t USING(%s,%s) _d"+
+ " ON %s WHEN NOT MATCHED AND _d._rt!=2 THEN "+
+ "INSERT (%s) VALUES(%s) "+
+ "%s WHEN MATCHED AND _d._rt=2 THEN %s;",
+ m.dstDatasetTable.table, m.generateFlattenedCTE(), m.generateDeDupedCTE(),
pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart)
}
diff --git a/flow/connectors/bigquery/qrep.go b/flow/connectors/bigquery/qrep.go
index 5878dd2bd0..c0aafe045f 100644
--- a/flow/connectors/bigquery/qrep.go
+++ b/flow/connectors/bigquery/qrep.go
@@ -112,10 +112,10 @@ func (c *BigQueryConnector) createMetadataInsertStatement(
partitionJSON := string(pbytes)
insertMetadataStmt := fmt.Sprintf(
- "INSERT INTO %s._peerdb_query_replication_metadata"+
+ "INSERT INTO _peerdb_query_replication_metadata"+
"(flowJobName, partitionID, syncPartition, syncStartTime, syncFinishTime) "+
"VALUES ('%s', '%s', JSON '%s', TIMESTAMP('%s'), CURRENT_TIMESTAMP());",
- c.datasetID, jobName, partition.PartitionId,
+ jobName, partition.PartitionId,
partitionJSON, startTime.Format(time.RFC3339))
return insertMetadataStmt, nil
@@ -170,8 +170,8 @@ func (c *BigQueryConnector) SetupQRepMetadataTables(config *protos.QRepConfig) e
func (c *BigQueryConnector) isPartitionSynced(partitionID string) (bool, error) {
queryString := fmt.Sprintf(
- "SELECT COUNT(*) FROM %s._peerdb_query_replication_metadata WHERE partitionID = '%s';",
- c.datasetID, partitionID,
+ "SELECT COUNT(*) FROM _peerdb_query_replication_metadata WHERE partitionID = '%s';",
+ partitionID,
)
query := c.client.Query(queryString)
diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go
index 7e6768963a..c8b182706f 100644
--- a/flow/connectors/bigquery/qrep_avro_sync.go
+++ b/flow/connectors/bigquery/qrep_avro_sync.go
@@ -58,6 +58,7 @@ func (s *QRepAvroSyncMethod) SyncRecords(
stagingTable := fmt.Sprintf("%s_%s_staging", rawTableName, strconv.FormatInt(syncBatchID, 10))
numRecords, err := s.writeToStage(strconv.FormatInt(syncBatchID, 10), rawTableName, avroSchema,
&datasetTable{
+ project: s.connector.projectID,
dataset: s.connector.datasetID,
table: stagingTable,
}, stream, req.FlowJobName)
@@ -67,8 +68,8 @@ func (s *QRepAvroSyncMethod) SyncRecords(
bqClient := s.connector.client
datasetID := s.connector.datasetID
- insertStmt := fmt.Sprintf("INSERT INTO `%s.%s` SELECT * FROM `%s.%s`;",
- datasetID, rawTableName, datasetID, stagingTable)
+ insertStmt := fmt.Sprintf("INSERT INTO `%s` SELECT * FROM `%s`;",
+ rawTableName, stagingTable)
lastCP, err := req.Records.GetLastCheckpoint()
if err != nil {
@@ -171,6 +172,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
// create a staging table name with partitionID replace hyphens with underscores
dstDatasetTable, _ := s.connector.convertToDatasetTable(dstTableName)
stagingDatasetTable := &datasetTable{
+ project: s.connector.projectID,
dataset: dstDatasetTable.dataset,
table: fmt.Sprintf("%s_%s_staging", dstDatasetTable.table,
strings.ReplaceAll(partition.PartitionId, "-", "_")),
@@ -198,7 +200,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
}
// Insert the records from the staging table into the destination table
insertStmt := fmt.Sprintf("INSERT INTO `%s` SELECT %s FROM `%s`;",
- dstDatasetTable.string(), selector, stagingDatasetTable.string())
+ dstTableName, selector, stagingDatasetTable.string())
insertMetadataStmt, err := s.connector.createMetadataInsertStatement(partition, flowJobName, startTime)
if err != nil {
@@ -229,7 +231,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
flowLog)
}
- slog.Info(fmt.Sprintf("loaded stage into %s", dstDatasetTable.string()), flowLog)
+ slog.Info(fmt.Sprintf("loaded stage into %s", dstTableName), flowLog)
return numRecords, nil
}
diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go
index 5ef8f70b7a..c6ae8f7d7f 100644
--- a/flow/connectors/postgres/client.go
+++ b/flow/connectors/postgres/client.go
@@ -77,7 +77,7 @@ const (
RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
FROM %s.%s WHERE _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_destination_table_name=$3
)
- DELETE FROM %s USING %s FROM src_rank WHERE %s AND src_rank._peerdb_rank=1 AND src_rank._peerdb_record_type=2`
+ %s src_rank WHERE %s AND src_rank._peerdb_rank=1 AND src_rank._peerdb_record_type=2`
dropTableIfExistsSQL = "DROP TABLE IF EXISTS %s.%s"
deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE mirror_job_name=$1"
@@ -257,18 +257,27 @@ func (c *PostgresConnector) checkSlotAndPublication(slot string, publication str
func (c *PostgresConnector) GetSlotInfo(slotName string) ([]*protos.SlotInfo, error) {
var whereClause string
if slotName != "" {
- whereClause = fmt.Sprintf(" WHERE slot_name = %s", QuoteLiteral(slotName))
+ whereClause = fmt.Sprintf("WHERE slot_name=%s", QuoteLiteral(slotName))
} else {
- whereClause = fmt.Sprintf(" WHERE database = %s", QuoteLiteral(c.config.Database))
+ whereClause = fmt.Sprintf("WHERE database=%s", QuoteLiteral(c.config.Database))
}
- rows, err := c.pool.Query(c.ctx, "SELECT slot_name, redo_lsn::Text,restart_lsn::text,wal_status,"+
- "confirmed_flush_lsn::text,active,"+
- "round((CASE WHEN pg_is_in_recovery() THEN pg_last_wal_receive_lsn() ELSE pg_current_wal_lsn() END"+
- " - confirmed_flush_lsn) / 1024 / 1024) AS MB_Behind"+
- " FROM pg_control_checkpoint(), pg_replication_slots"+whereClause)
+
+ hasWALStatus, _, err := c.MajorVersionCheck(POSTGRES_13)
if err != nil {
return nil, err
}
+ walStatusSelector := "wal_status"
+ if !hasWALStatus {
+ walStatusSelector = "'unknown'"
+ }
+ rows, err := c.pool.Query(c.ctx, fmt.Sprintf(`SELECT slot_name, redo_lsn::Text,restart_lsn::text,%s,
+ confirmed_flush_lsn::text,active,
+ round((CASE WHEN pg_is_in_recovery() THEN pg_last_wal_receive_lsn() ELSE pg_current_wal_lsn() END
+ - confirmed_flush_lsn) / 1024 / 1024) AS MB_Behind
+ FROM pg_control_checkpoint(),pg_replication_slots %s`, walStatusSelector, whereClause))
+ if err != nil {
+ return nil, fmt.Errorf("failed to read information for slots: %w", err)
+ }
defer rows.Close()
var slotInfoRows []*protos.SlotInfo
for rows.Next() {
@@ -415,11 +424,12 @@ func generateCreateTableSQLForNormalizedTable(
softDeleteColName string,
syncedAtColName string,
) string {
- createTableSQLArray := make([]string, 0, utils.TableSchemaColumns(sourceTableSchema)+2)
- utils.IterColumns(sourceTableSchema, func(columnName, genericColumnType string) {
+ createTableSQLArray := make([]string, 0, len(sourceTableSchema.ColumnNames)+2)
+ for i, columnName := range sourceTableSchema.ColumnNames {
+ genericColumnType := sourceTableSchema.ColumnTypes[i]
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf("%s %s", QuoteIdentifier(columnName), qValueKindToPostgresType(genericColumnType)))
- })
+ }
if softDeleteColName != "" {
createTableSQLArray = append(createTableSQLArray,
diff --git a/flow/connectors/postgres/normalize_stmt_generator.go b/flow/connectors/postgres/normalize_stmt_generator.go
index 3792c188af..47da11f0c4 100644
--- a/flow/connectors/postgres/normalize_stmt_generator.go
+++ b/flow/connectors/postgres/normalize_stmt_generator.go
@@ -44,11 +44,12 @@ func (n *normalizeStmtGenerator) generateNormalizeStatements() []string {
}
func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
- columnCount := utils.TableSchemaColumns(n.normalizedTableSchema)
+ columnCount := len(n.normalizedTableSchema.ColumnNames)
columnNames := make([]string, 0, columnCount)
flattenedCastsSQLArray := make([]string, 0, columnCount)
primaryKeyColumnCasts := make(map[string]string, len(n.normalizedTableSchema.PrimaryKeyColumns))
- utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) {
+ for i, columnName := range n.normalizedTableSchema.ColumnNames {
+ genericColumnType := n.normalizedTableSchema.ColumnTypes[i]
quotedCol := QuoteIdentifier(columnName)
stringCol := QuoteLiteral(columnName)
columnNames = append(columnNames, quotedCol)
@@ -64,16 +65,16 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) {
primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>%s)::%s", stringCol, pgType)
}
- })
+ }
flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName)
insertColumnsSQL := strings.Join(columnNames, ",")
- updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema))
- utils.IterColumns(n.normalizedTableSchema, func(columnName, _ string) {
+ updateColumnsSQLArray := make([]string, 0, columnCount)
+ for _, columnName := range n.normalizedTableSchema.ColumnNames {
quotedCol := QuoteIdentifier(columnName)
updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`%s=EXCLUDED.%s`, quotedCol, quotedCol))
- })
+ }
updateColumnsSQL := strings.Join(updateColumnsSQLArray, ",")
deleteWhereClauseArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
for columnName, columnCast := range primaryKeyColumnCasts {
@@ -82,13 +83,15 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
}
deleteWhereClauseSQL := strings.Join(deleteWhereClauseArray, " AND ")
- deleteUpdate := ""
+ // make it update instead in case soft-delete is enabled
+ deleteUpdate := fmt.Sprintf(`DELETE FROM %s USING `, parsedDstTable.String())
if n.peerdbCols.SoftDelete {
deleteUpdate = fmt.Sprintf(`UPDATE %s SET %s=TRUE`,
parsedDstTable.String(), QuoteIdentifier(n.peerdbCols.SoftDeleteColName))
if n.peerdbCols.SyncedAtColName != "" {
deleteUpdate += fmt.Sprintf(`,%s=CURRENT_TIMESTAMP`, QuoteIdentifier(n.peerdbCols.SyncedAtColName))
}
+ deleteUpdate += " FROM"
}
fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL,
strings.Join(maps.Values(primaryKeyColumnCasts), ","), n.metadataSchema,
@@ -96,25 +99,26 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
strings.Join(n.normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL)
fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL,
strings.Join(maps.Values(primaryKeyColumnCasts), ","), n.metadataSchema,
- n.rawTableName, parsedDstTable.String(), deleteUpdate, deleteWhereClauseSQL)
+ n.rawTableName, deleteUpdate, deleteWhereClauseSQL)
return []string{fallbackUpsertStatement, fallbackDeleteStatement}
}
func (n *normalizeStmtGenerator) generateMergeStatement() string {
- quotedColumnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema)
- for i, columnName := range quotedColumnNames {
- quotedColumnNames[i] = QuoteIdentifier(columnName)
- }
+ columnCount := len(n.normalizedTableSchema.ColumnNames)
+ quotedColumnNames := make([]string, columnCount)
- flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema))
+ flattenedCastsSQLArray := make([]string, 0, columnCount)
parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName)
primaryKeyColumnCasts := make(map[string]string)
primaryKeySelectSQLArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
- utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) {
+ for i, columnName := range n.normalizedTableSchema.ColumnNames {
+ genericColumnType := n.normalizedTableSchema.ColumnTypes[i]
quotedCol := QuoteIdentifier(columnName)
stringCol := QuoteLiteral(columnName)
+ quotedColumnNames[i] = quotedCol
+
pgType := qValueKindToPostgresType(genericColumnType)
if qvalue.QValueKind(genericColumnType).IsArray() {
flattenedCastsSQLArray = append(flattenedCastsSQLArray,
@@ -129,9 +133,9 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string {
primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s",
quotedCol, quotedCol))
}
- })
+ }
flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
- insertValuesSQLArray := make([]string, 0, len(quotedColumnNames)+2)
+ insertValuesSQLArray := make([]string, 0, columnCount+2)
for _, quotedCol := range quotedColumnNames {
insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", quotedCol))
}
diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go
index 4cbfde6f19..231d5f78f2 100644
--- a/flow/connectors/postgres/postgres.go
+++ b/flow/connectors/postgres/postgres.go
@@ -55,15 +55,16 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig)
// set pool size to 3 to avoid connection pool exhaustion
connConfig.MaxConns = 3
- // ensure that replication is set to database
- replConfig.ConnConfig.RuntimeParams["replication"] = "database"
- replConfig.ConnConfig.RuntimeParams["bytea_output"] = "hex"
- replConfig.MaxConns = 1
pool, err := NewSSHWrappedPostgresPool(ctx, connConfig, pgConfig.SshConfig)
if err != nil {
return nil, fmt.Errorf("failed to create connection pool: %w", err)
}
+ // ensure that replication is set to database
+ replConfig.ConnConfig.RuntimeParams["replication"] = "database"
+ replConfig.ConnConfig.RuntimeParams["bytea_output"] = "hex"
+ replConfig.MaxConns = 1
+
customTypeMap, err := utils.GetCustomDataTypes(ctx, pool.Pool)
if err != nil {
return nil, fmt.Errorf("failed to get custom type map: %w", err)
diff --git a/flow/connectors/postgres/postgres_schema_delta_test.go b/flow/connectors/postgres/postgres_schema_delta_test.go
index 4c3b012243..450303da42 100644
--- a/flow/connectors/postgres/postgres_schema_delta_test.go
+++ b/flow/connectors/postgres/postgres_schema_delta_test.go
@@ -9,7 +9,6 @@ import (
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/require"
- "github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/e2eshared"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model/qvalue"
@@ -121,14 +120,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddAllColumnTypes() {
PrimaryKeyColumns: []string{"id"},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
- utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+ for i, columnName := range expectedTableSchema.ColumnNames {
if columnName != "id" {
addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
ColumnName: columnName,
- ColumnType: columnType,
+ ColumnType: expectedTableSchema.ColumnTypes[i],
})
}
- })
+ }
err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
SrcTableName: tableName,
@@ -171,14 +170,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddTrickyColumnNames() {
PrimaryKeyColumns: []string{"id"},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
- utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+ for i, columnName := range expectedTableSchema.ColumnNames {
if columnName != "id" {
addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
ColumnName: columnName,
- ColumnType: columnType,
+ ColumnType: expectedTableSchema.ColumnTypes[i],
})
}
- })
+ }
err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
SrcTableName: tableName,
@@ -212,14 +211,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() {
PrimaryKeyColumns: []string{" "},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
- utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+ for i, columnName := range expectedTableSchema.ColumnNames {
if columnName != " " {
addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
ColumnName: columnName,
- ColumnType: columnType,
+ ColumnType: expectedTableSchema.ColumnTypes[i],
})
}
- })
+ }
err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
SrcTableName: tableName,
diff --git a/flow/connectors/postgres/ssh_wrapped_pool.go b/flow/connectors/postgres/ssh_wrapped_pool.go
index 4f17116ea4..4dcd2cd0ce 100644
--- a/flow/connectors/postgres/ssh_wrapped_pool.go
+++ b/flow/connectors/postgres/ssh_wrapped_pool.go
@@ -27,6 +27,23 @@ type SSHWrappedPostgresPool struct {
cancel context.CancelFunc
}
+func NewSSHWrappedPostgresPoolFromConfig(
+ ctx context.Context,
+ pgConfig *protos.PostgresConfig,
+) (*SSHWrappedPostgresPool, error) {
+ connectionString := utils.GetPGConnectionString(pgConfig)
+
+ connConfig, err := pgxpool.ParseConfig(connectionString)
+ if err != nil {
+ return nil, err
+ }
+
+ // set pool size to 3 to avoid connection pool exhaustion
+ connConfig.MaxConns = 3
+
+ return NewSSHWrappedPostgresPool(ctx, connConfig, pgConfig.SshConfig)
+}
+
func NewSSHWrappedPostgresPool(
ctx context.Context,
poolConfig *pgxpool.Config,
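The new constructor keeps connection-string parsing and the MaxConns=3 default in one place, so callers such as getPoolForPGPeer only need the peer's PostgresConfig. A usage sketch, assuming the wrapper exposes the underlying pgxpool.Pool as its Pool field (as postgres.go does above); the config values are placeholders:

package main

import (
	"context"
	"log"

	connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
	"github.com/PeerDB-io/peer-flow/generated/protos"
)

func main() {
	ctx := context.Background()
	// placeholder peer config; SshConfig may be left nil when no tunnel is needed
	pgConfig := &protos.PostgresConfig{
		Host:     "localhost",
		Port:     5432,
		User:     "postgres",
		Password: "postgres",
		Database: "postgres",
	}

	pool, err := connpostgres.NewSSHWrappedPostgresPoolFromConfig(ctx, pgConfig)
	if err != nil {
		log.Fatalf("failed to create pool: %v", err)
	}

	var one int
	// the wrapper exposes the underlying pgxpool.Pool (see pool.Pool usage in postgres.go)
	if err := pool.Pool.QueryRow(ctx, "SELECT 1").Scan(&one); err != nil {
		log.Fatalf("query failed: %v", err)
	}
	log.Printf("connected, got %d", one)
}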
diff --git a/flow/connectors/snowflake/merge_stmt_generator.go b/flow/connectors/snowflake/merge_stmt_generator.go
index 291b3314d9..d849fe5a58 100644
--- a/flow/connectors/snowflake/merge_stmt_generator.go
+++ b/flow/connectors/snowflake/merge_stmt_generator.go
@@ -27,14 +27,15 @@ type mergeStmtGenerator struct {
func (m *mergeStmtGenerator) generateMergeStmt() (string, error) {
parsedDstTable, _ := utils.ParseSchemaTable(m.dstTableName)
- columnNames := utils.TableSchemaColumnNames(m.normalizedTableSchema)
+ columnNames := m.normalizedTableSchema.ColumnNames
- flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(m.normalizedTableSchema))
- err := utils.IterColumnsError(m.normalizedTableSchema, func(columnName, genericColumnType string) error {
+ flattenedCastsSQLArray := make([]string, 0, len(columnNames))
+ for i, columnName := range columnNames {
+ genericColumnType := m.normalizedTableSchema.ColumnTypes[i]
qvKind := qvalue.QValueKind(genericColumnType)
sfType, err := qValueKindToSnowflakeType(qvKind)
if err != nil {
- return fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err)
+ return "", fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err)
}
targetColumnName := SnowflakeIdentifierNormalize(columnName)
@@ -69,10 +70,6 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) {
toVariantColumnName, columnName, sfType, targetColumnName))
}
}
- return nil
- })
- if err != nil {
- return "", err
}
flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go
index e60bf8993e..ac8e8badcf 100644
--- a/flow/connectors/snowflake/snowflake.go
+++ b/flow/connectors/snowflake/snowflake.go
@@ -831,17 +831,18 @@ func generateCreateTableSQLForNormalizedTable(
softDeleteColName string,
syncedAtColName string,
) string {
- createTableSQLArray := make([]string, 0, utils.TableSchemaColumns(sourceTableSchema)+2)
- utils.IterColumns(sourceTableSchema, func(columnName, genericColumnType string) {
+ createTableSQLArray := make([]string, 0, len(sourceTableSchema.ColumnNames)+2)
+ for i, columnName := range sourceTableSchema.ColumnNames {
+ genericColumnType := sourceTableSchema.ColumnTypes[i]
normalizedColName := SnowflakeIdentifierNormalize(columnName)
sfColType, err := qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType))
if err != nil {
slog.Warn(fmt.Sprintf("failed to convert column type %s to snowflake type", genericColumnType),
slog.Any("error", err))
- return
+ continue
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`%s %s`, normalizedColName, sfColType))
- })
+ }
// add a _peerdb_is_deleted column to the normalized table
// this is boolean default false, and is used to mark records as deleted
@@ -997,7 +998,7 @@ func (c *SnowflakeConnector) RenameTables(req *protos.RenameTablesInput) (*proto
for _, renameRequest := range req.RenameTableOptions {
src := renameRequest.CurrentName
dst := renameRequest.NewName
- allCols := strings.Join(utils.TableSchemaColumnNames(renameRequest.TableSchema), ",")
+ allCols := strings.Join(renameRequest.TableSchema.ColumnNames, ",")
pkeyCols := strings.Join(renameRequest.TableSchema.PrimaryKeyColumns, ",")
c.logger.Info(fmt.Sprintf("handling soft-deletes for table '%s'...", dst))
diff --git a/flow/connectors/utils/columns.go b/flow/connectors/utils/columns.go
deleted file mode 100644
index f1e0340f03..0000000000
--- a/flow/connectors/utils/columns.go
+++ /dev/null
@@ -1,31 +0,0 @@
-package utils
-
-import (
- "slices"
-
- "github.com/PeerDB-io/peer-flow/generated/protos"
-)
-
-func TableSchemaColumns(schema *protos.TableSchema) int {
- return len(schema.ColumnNames)
-}
-
-func TableSchemaColumnNames(schema *protos.TableSchema) []string {
- return slices.Clone(schema.ColumnNames)
-}
-
-func IterColumns(schema *protos.TableSchema, iter func(k, v string)) {
- for i, name := range schema.ColumnNames {
- iter(name, schema.ColumnTypes[i])
- }
-}
-
-func IterColumnsError(schema *protos.TableSchema, iter func(k, v string) error) error {
- for i, name := range schema.ColumnNames {
- err := iter(name, schema.ColumnTypes[i])
- if err != nil {
- return err
- }
- }
- return nil
-}
diff --git a/flow/e2e/snowflake/snowflake_schema_delta_test.go b/flow/e2e/snowflake/snowflake_schema_delta_test.go
index 52f02b005e..f83d1ac679 100644
--- a/flow/e2e/snowflake/snowflake_schema_delta_test.go
+++ b/flow/e2e/snowflake/snowflake_schema_delta_test.go
@@ -9,7 +9,6 @@ import (
"github.com/stretchr/testify/require"
connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake"
- "github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/e2eshared"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model/qvalue"
@@ -99,14 +98,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddAllColumnTypes() {
},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
- utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+ for i, columnName := range expectedTableSchema.ColumnNames {
if columnName != "ID" {
addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
ColumnName: columnName,
- ColumnType: columnType,
+ ColumnType: expectedTableSchema.ColumnTypes[i],
})
}
- })
+ }
err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
SrcTableName: tableName,
@@ -154,14 +153,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddTrickyColumnNames() {
},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
- utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+ for i, columnName := range expectedTableSchema.ColumnNames {
if columnName != "ID" {
addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
ColumnName: columnName,
- ColumnType: columnType,
+ ColumnType: expectedTableSchema.ColumnTypes[i],
})
}
- })
+ }
err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
SrcTableName: tableName,
@@ -193,14 +192,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddWhitespaceColumnNames() {
},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
- utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+ for i, columnName := range expectedTableSchema.ColumnNames {
if columnName != " " {
addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
ColumnName: columnName,
- ColumnType: columnType,
+ ColumnType: expectedTableSchema.ColumnTypes[i],
})
}
- })
+ }
err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
SrcTableName: tableName,
diff --git a/flow/model/model.go b/flow/model/model.go
index dd875a2583..d8e3a7c751 100644
--- a/flow/model/model.go
+++ b/flow/model/model.go
@@ -512,8 +512,6 @@ type SyncRecordsRequest struct {
Records *CDCRecordStream
// FlowJobName is the name of the flow job.
FlowJobName string
- // SyncMode to use for pushing raw records
- SyncMode protos.QRepSyncMode
// source:destination mappings
TableMappings []*protos.TableMapping
// Staging path for AVRO files in CDC
diff --git a/flow/workflows/cdc_flow.go b/flow/workflows/cdc_flow.go
index 34c72c9371..a0d73aa593 100644
--- a/flow/workflows/cdc_flow.go
+++ b/flow/workflows/cdc_flow.go
@@ -55,11 +55,6 @@ type CDCFlowWorkflowState struct {
FlowConfigUpdates []*protos.CDCFlowConfigUpdate
}
-type SignalProps struct {
- BatchSize uint32
- IdleTimeout uint64
-}
-
// returns a new empty PeerFlowState
func NewCDCFlowWorkflowState(numTables int) *CDCFlowWorkflowState {
return &CDCFlowWorkflowState{
@@ -170,8 +165,6 @@ func (w *CDCFlowWorkflowExecution) processCDCFlowConfigUpdates(ctx workflow.Cont
additionalTablesWorkflowCfg.DoInitialSnapshot = true
additionalTablesWorkflowCfg.InitialSnapshotOnly = true
additionalTablesWorkflowCfg.TableMappings = flowConfigUpdate.AdditionalTables
- additionalTablesWorkflowCfg.FlowJobName = fmt.Sprintf("%s_additional_tables_%s", cfg.FlowJobName,
- strings.ToLower(shared.RandomString(8)))
childAdditionalTablesCDCFlowID,
err := GetChildWorkflowID(ctx, "cdc-flow", additionalTablesWorkflowCfg.FlowJobName)
@@ -377,18 +370,23 @@ func CDCFlowWorkflowWithConfig(
cdcPropertiesSignalChannel := workflow.GetSignalChannel(ctx, shared.CDCDynamicPropertiesSignalName)
cdcPropertiesSelector := workflow.NewSelector(ctx)
cdcPropertiesSelector.AddReceive(cdcPropertiesSignalChannel, func(c workflow.ReceiveChannel, more bool) {
- var cdcSignal SignalProps
- c.Receive(ctx, &cdcSignal)
+ var cdcConfigUpdate *protos.CDCFlowConfigUpdate
+ c.Receive(ctx, &cdcConfigUpdate)
// only modify for options since SyncFlow uses it
- if cdcSignal.BatchSize > 0 {
- syncFlowOptions.BatchSize = cdcSignal.BatchSize
+ if cdcConfigUpdate.BatchSize > 0 {
+ syncFlowOptions.BatchSize = cdcConfigUpdate.BatchSize
+ }
+ if cdcConfigUpdate.IdleTimeout > 0 {
+ syncFlowOptions.IdleTimeoutSeconds = cdcConfigUpdate.IdleTimeout
}
- if cdcSignal.IdleTimeout > 0 {
- syncFlowOptions.IdleTimeoutSeconds = cdcSignal.IdleTimeout
+ if len(cdcConfigUpdate.AdditionalTables) > 0 {
+ state.FlowConfigUpdates = append(state.FlowConfigUpdates, cdcConfigUpdate)
}
- slog.Info("CDC Signal received. Parameters on signal reception:", slog.Int("BatchSize", int(cfg.MaxBatchSize)),
- slog.Int("IdleTimeout", int(cfg.IdleTimeoutSeconds)))
+ slog.Info("CDC Signal received. Parameters on signal reception:",
+ slog.Int("BatchSize", int(syncFlowOptions.BatchSize)),
+ slog.Int("IdleTimeout", int(syncFlowOptions.IdleTimeoutSeconds)),
+ slog.Any("AdditionalTables", cdcConfigUpdate.AdditionalTables))
})
cdcPropertiesSelector.AddDefault(func() {
diff --git a/flow/workflows/setup_flow.go b/flow/workflows/setup_flow.go
index 7bee648f4e..54d5c95a99 100644
--- a/flow/workflows/setup_flow.go
+++ b/flow/workflows/setup_flow.go
@@ -11,7 +11,6 @@ import (
"golang.org/x/exp/maps"
"github.com/PeerDB-io/peer-flow/activities"
- "github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
)
@@ -207,15 +206,16 @@ func (s *SetupFlowExecution) fetchTableSchemaAndSetupNormalizedTables(
for _, mapping := range flowConnectionConfigs.TableMappings {
if mapping.SourceTableIdentifier == srcTableName {
if len(mapping.Exclude) != 0 {
- columnCount := utils.TableSchemaColumns(tableSchema)
+ columnCount := len(tableSchema.ColumnNames)
columnNames := make([]string, 0, columnCount)
columnTypes := make([]string, 0, columnCount)
- utils.IterColumns(tableSchema, func(columnName, columnType string) {
+ for i, columnName := range tableSchema.ColumnNames {
+ columnType := tableSchema.ColumnTypes[i]
if !slices.Contains(mapping.Exclude, columnName) {
columnNames = append(columnNames, columnName)
columnTypes = append(columnTypes, columnType)
}
- })
+ }
tableSchema = &protos.TableSchema{
TableIdentifier: tableSchema.TableIdentifier,
PrimaryKeyColumns: tableSchema.PrimaryKeyColumns,
diff --git a/flow/workflows/snapshot_flow.go b/flow/workflows/snapshot_flow.go
index d38801b599..5b1035eb1c 100644
--- a/flow/workflows/snapshot_flow.go
+++ b/flow/workflows/snapshot_flow.go
@@ -4,6 +4,7 @@ import (
"fmt"
"log/slog"
"regexp"
+ "slices"
"strings"
"time"
@@ -13,6 +14,7 @@ import (
"go.temporal.io/sdk/workflow"
"github.com/PeerDB-io/peer-flow/concurrency"
+ connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/shared"
@@ -141,11 +143,13 @@ func (s *SnapshotFlowExecution) cloneTable(
if len(mapping.Exclude) != 0 {
for _, v := range s.tableNameSchemaMapping {
if v.TableIdentifier == srcName {
- colNames := utils.TableSchemaColumnNames(v)
- for i, colName := range colNames {
- colNames[i] = fmt.Sprintf(`"%s"`, colName)
+ quotedColumns := make([]string, 0, len(v.ColumnNames))
+ for _, colName := range v.ColumnNames {
+ if !slices.Contains(mapping.Exclude, colName) {
+ quotedColumns = append(quotedColumns, connpostgres.QuoteIdentifier(colName))
+ }
}
- from = strings.Join(colNames, ",")
+ from = strings.Join(quotedColumns, ",")
break
}
}
@@ -254,20 +258,15 @@ func SnapshotFlowWorkflow(ctx workflow.Context, config *protos.FlowConnectionCon
logger: logger,
}
- numTablesInParallel := int(config.SnapshotNumTablesInParallel)
- if numTablesInParallel <= 0 {
- numTablesInParallel = 1
- }
-
- replCtx := ctx
+ numTablesInParallel := int(max(config.SnapshotNumTablesInParallel, 1))
if !config.DoInitialSnapshot {
- _, err := se.setupReplication(replCtx)
+ _, err := se.setupReplication(ctx)
if err != nil {
return fmt.Errorf("failed to setup replication: %w", err)
}
- if err := se.closeSlotKeepAlive(replCtx); err != nil {
+ if err := se.closeSlotKeepAlive(ctx); err != nil {
return fmt.Errorf("failed to close slot keep alive: %w", err)
}
diff --git a/protos/flow.proto b/protos/flow.proto
index 66ba78c8be..fe0f84ff86 100644
--- a/protos/flow.proto
+++ b/protos/flow.proto
@@ -246,11 +246,6 @@ message PartitionRange {
}
// protos for qrep
-enum QRepSyncMode {
- QREP_SYNC_MODE_MULTI_INSERT = 0;
- QREP_SYNC_MODE_STORAGE_AVRO = 1;
-}
-
enum QRepWriteType {
QREP_WRITE_MODE_APPEND = 0;
QREP_WRITE_MODE_UPSERT = 1;
@@ -277,19 +272,13 @@ message QRepConfig {
string watermark_column = 7;
bool initial_copy_only = 8;
- QRepSyncMode sync_mode = 9;
-
- // DEPRECATED: eliminate when breaking changes are allowed.
- uint32 batch_size_int = 10;
- // DEPRECATED: eliminate when breaking changes are allowed.
- uint32 batch_duration_seconds = 11;
- uint32 max_parallel_workers = 12;
+ uint32 max_parallel_workers = 9;
// time to wait between getting partitions to process
- uint32 wait_between_batches_seconds = 13;
+ uint32 wait_between_batches_seconds = 10;
- QRepWriteMode write_mode = 14;
+ QRepWriteMode write_mode = 11;
// This is only used when sync_mode is AVRO
// this is the location where the avro files will be written
@@ -297,22 +286,22 @@ message QRepConfig {
// if this starts with s3:// then it will be written to S3, only supported in Snowflake
// if nothing is specified then it will be written to local disk
// if using GCS or S3 make sure your instance has the correct permissions.
- string staging_path = 15;
+ string staging_path = 12;
// This setting overrides batch_size_int and batch_duration_seconds
// and instead uses the number of rows per partition to determine
// how many rows to process per batch.
- uint32 num_rows_per_partition = 16;
+ uint32 num_rows_per_partition = 13;
// Creates the watermark table on the destination as-is, can be used for some queries.
- bool setup_watermark_table_on_destination = 17;
+ bool setup_watermark_table_on_destination = 14;
// create new tables with "_peerdb_resync" suffix, perform initial load and then swap the new table with the old ones
// to be used after the old mirror is dropped
- bool dst_table_full_resync = 18;
+ bool dst_table_full_resync = 15;
- string synced_at_col_name = 19;
- string soft_delete_col_name = 20;
+ string synced_at_col_name = 16;
+ string soft_delete_col_name = 17;
}
message QRepPartition {
@@ -390,6 +379,8 @@ enum FlowStatus {
message CDCFlowConfigUpdate {
repeated TableMapping additional_tables = 1;
+ uint32 batch_size = 2;
+ uint64 idle_timeout = 3;
}
message QRepFlowConfigUpdate {
@@ -412,3 +403,4 @@ message AddTablesToPublicationInput{
string publication_name = 2;
repeated TableMapping additional_tables = 3;
}
+
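batch_size and idle_timeout extend CDCFlowConfigUpdate so one struct carries both dynamic sync properties and additional tables over the CDCDynamicProperties signal that cdc_flow.go now decodes. A sketch of sending that signal with the Temporal Go SDK; the workflow ID is a placeholder and PeerDB's actual handler wiring may differ:

package main

import (
	"context"
	"log"

	"go.temporal.io/sdk/client"

	"github.com/PeerDB-io/peer-flow/generated/protos"
	"github.com/PeerDB-io/peer-flow/shared"
)

func main() {
	c, err := client.Dial(client.Options{}) // connects to localhost:7233 by default
	if err != nil {
		log.Fatal(err)
	}
	defer c.Close()

	update := &protos.CDCFlowConfigUpdate{
		BatchSize:   100_000,
		IdleTimeout: 120, // seconds, applied to syncFlowOptions.IdleTimeoutSeconds
	}

	// "my-mirror-peerflow" is a placeholder workflow ID; an empty run ID targets the latest run
	err = c.SignalWorkflow(context.Background(), "my-mirror-peerflow", "",
		shared.CDCDynamicPropertiesSignalName, update)
	if err != nil {
		log.Fatal(err)
	}
}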
diff --git a/protos/route.proto b/protos/route.proto
index 046a63aca3..48db51e019 100644
--- a/protos/route.proto
+++ b/protos/route.proto
@@ -285,7 +285,9 @@ service FlowService {
rpc ShutdownFlow(ShutdownRequest) returns (ShutdownResponse) {
option (google.api.http) = { post: "/v1/mirrors/drop", body: "*" };
}
- rpc FlowStateChange(FlowStateChangeRequest) returns (FlowStateChangeResponse) {}
+ rpc FlowStateChange(FlowStateChangeRequest) returns (FlowStateChangeResponse) {
+ option (google.api.http) = { post: "/v1/mirrors/state_change", body: "*" };
+ }
rpc MirrorStatus(MirrorStatusRequest) returns (MirrorStatusResponse) {
option (google.api.http) = { get: "/v1/mirrors/{flow_job_name}" };
}
diff --git a/ui/app/api/mirrors/drop/route.ts b/ui/app/api/mirrors/drop/route.ts
index e3be0f7c41..a52e4b078d 100644
--- a/ui/app/api/mirrors/drop/route.ts
+++ b/ui/app/api/mirrors/drop/route.ts
@@ -13,7 +13,7 @@ export async function POST(request: Request) {
destinationPeer,
removeFlowEntry: true,
};
- console.log('/drop/mirror: req:', req);
+ console.log('/mirrors/drop: req:', req);
try {
const dropStatus: ShutdownResponse = await fetch(
`${flowServiceAddr}/v1/mirrors/drop`,
diff --git a/ui/app/api/mirrors/state_change/route.ts b/ui/app/api/mirrors/state_change/route.ts
new file mode 100644
index 0000000000..bcc4dd2a8a
--- /dev/null
+++ b/ui/app/api/mirrors/state_change/route.ts
@@ -0,0 +1,23 @@
+import { FlowStateChangeResponse } from '@/grpc_generated/route';
+import { GetFlowHttpAddressFromEnv } from '@/rpc/http';
+
+export async function POST(request: Request) {
+ const body = await request.json();
+ const flowServiceAddr = GetFlowHttpAddressFromEnv();
+ console.log('/mirrors/state_change: req:', body);
+ try {
+ const res: FlowStateChangeResponse = await fetch(
+ `${flowServiceAddr}/v1/mirrors/state_change`,
+ {
+ method: 'POST',
+ body: JSON.stringify(body),
+ }
+ ).then((res) => {
+ return res.json();
+ });
+
+ return new Response(JSON.stringify(res));
+ } catch (e) {
+ console.error(e);
+ }
+}
diff --git a/ui/app/mirrors/create/cdc/cdc.tsx b/ui/app/mirrors/create/cdc/cdc.tsx
index 9101703c20..db26cefd51 100644
--- a/ui/app/mirrors/create/cdc/cdc.tsx
+++ b/ui/app/mirrors/create/cdc/cdc.tsx
@@ -1,5 +1,4 @@
'use client';
-import { QRepSyncMode } from '@/grpc_generated/flow';
import { DBType } from '@/grpc_generated/peers';
import { Button } from '@/lib/Button';
import { Icon } from '@/lib/Icon';
@@ -39,7 +38,7 @@ export default function CDCConfigForm({
}: MirrorConfigProps) {
const [show, setShow] = useState(false);
const handleChange = (val: string | boolean, setting: MirrorSetting) => {
- let stateVal: string | boolean | QRepSyncMode = val;
+ let stateVal: string | boolean = val;
setting.stateHandler(stateVal, setter);
};
diff --git a/ui/app/mirrors/create/handlers.ts b/ui/app/mirrors/create/handlers.ts
index 00dab7347b..a3fdab3178 100644
--- a/ui/app/mirrors/create/handlers.ts
+++ b/ui/app/mirrors/create/handlers.ts
@@ -8,7 +8,6 @@ import {
import {
FlowConnectionConfigs,
QRepConfig,
- QRepSyncMode,
QRepWriteType,
} from '@/grpc_generated/flow';
import { DBType, Peer, dBTypeToJSON } from '@/grpc_generated/peers';
@@ -28,28 +27,6 @@ export const handlePeer = (
) => {
if (!peer) return;
if (peerEnd === 'dst') {
- if (peer.type === DBType.POSTGRES) {
- setConfig((curr) => {
- return {
- ...curr,
- cdcSyncMode: QRepSyncMode.QREP_SYNC_MODE_MULTI_INSERT,
- snapshotSyncMode: QRepSyncMode.QREP_SYNC_MODE_MULTI_INSERT,
- syncMode: QRepSyncMode.QREP_SYNC_MODE_MULTI_INSERT,
- };
- });
- } else if (
- peer.type === DBType.SNOWFLAKE ||
- peer.type === DBType.BIGQUERY
- ) {
- setConfig((curr) => {
- return {
- ...curr,
- cdcSyncMode: QRepSyncMode.QREP_SYNC_MODE_STORAGE_AVRO,
- snapshotSyncMode: QRepSyncMode.QREP_SYNC_MODE_STORAGE_AVRO,
- syncMode: QRepSyncMode.QREP_SYNC_MODE_STORAGE_AVRO,
- };
- });
- }
setConfig((curr) => ({
...curr,
destination: peer,
@@ -238,12 +215,6 @@ export const handleCreateQRep = async (
config.flowJobName = flowJobName;
config.query = query;
- if (config.destinationPeer?.type == DBType.POSTGRES) {
- config.syncMode = QRepSyncMode.QREP_SYNC_MODE_MULTI_INSERT;
- } else {
- config.syncMode = QRepSyncMode.QREP_SYNC_MODE_STORAGE_AVRO;
- }
-
setLoading(true);
const statusMessage: UCreateMirrorResponse = await fetch(
'/api/mirrors/qrep',
diff --git a/ui/app/mirrors/create/helpers/cdc.ts b/ui/app/mirrors/create/helpers/cdc.ts
index f4bdf285f9..663e7aee22 100644
--- a/ui/app/mirrors/create/helpers/cdc.ts
+++ b/ui/app/mirrors/create/helpers/cdc.ts
@@ -29,7 +29,7 @@ export const cdcSettings: MirrorSetting[] = [
stateHandler: (value, setter) =>
setter((curr: CDCConfig) => ({
...curr,
- idleTimeoutSeconds: (value as number) || 100000,
+ idleTimeoutSeconds: (value as number) || 60,
})),
tips: 'Time after which a Sync flow ends, if it happens before pull batch size is reached. Defaults to 60 seconds.',
helpfulLink: 'https://docs.peerdb.io/metrics/important_cdc_configs',
diff --git a/ui/app/mirrors/create/helpers/common.ts b/ui/app/mirrors/create/helpers/common.ts
index 1af51e8d85..5d1b8d9a93 100644
--- a/ui/app/mirrors/create/helpers/common.ts
+++ b/ui/app/mirrors/create/helpers/common.ts
@@ -1,14 +1,10 @@
-import {
- FlowConnectionConfigs,
- QRepSyncMode,
- QRepWriteType,
-} from '@/grpc_generated/flow';
+import { FlowConnectionConfigs, QRepWriteType } from '@/grpc_generated/flow';
import { Peer } from '@/grpc_generated/peers';
export interface MirrorSetting {
label: string;
stateHandler: (
- value: string | string[] | Peer | boolean | QRepSyncMode | QRepWriteType,
+ value: string | string[] | Peer | boolean | QRepWriteType,
setter: any
) => void;
type?: string;
diff --git a/ui/app/mirrors/create/qrep/qrep.tsx b/ui/app/mirrors/create/qrep/qrep.tsx
index bfabd6a3ce..06b0379b47 100644
--- a/ui/app/mirrors/create/qrep/qrep.tsx
+++ b/ui/app/mirrors/create/qrep/qrep.tsx
@@ -1,6 +1,6 @@
'use client';
import { RequiredIndicator } from '@/components/RequiredIndicator';
-import { QRepConfig, QRepSyncMode, QRepWriteType } from '@/grpc_generated/flow';
+import { QRepConfig, QRepWriteType } from '@/grpc_generated/flow';
import { DBType } from '@/grpc_generated/peers';
import { Label } from '@/lib/Label';
import { RowWithSelect, RowWithSwitch, RowWithTextField } from '@/lib/Layout';
@@ -51,8 +51,7 @@ export default function QRepConfigForm({
const [loading, setLoading] = useState(false);
const handleChange = (val: string | boolean, setting: MirrorSetting) => {
- let stateVal: string | boolean | QRepSyncMode | QRepWriteType | string[] =
- val;
+ let stateVal: string | boolean | QRepWriteType | string[] = val;
if (setting.label.includes('Write Type')) {
switch (val) {
case 'Upsert':
diff --git a/ui/app/mirrors/edit/[mirrorId]/cdc.tsx b/ui/app/mirrors/edit/[mirrorId]/cdc.tsx
index e7f7ee0ca1..dd55b5a9bf 100644
--- a/ui/app/mirrors/edit/[mirrorId]/cdc.tsx
+++ b/ui/app/mirrors/edit/[mirrorId]/cdc.tsx
@@ -2,8 +2,8 @@
import { SyncStatusRow } from '@/app/dto/MirrorsDTO';
import TimeLabel from '@/components/TimeComponent';
import {
- CDCMirrorStatus,
CloneTableSummary,
+ MirrorStatusResponse,
SnapshotStatus,
} from '@/grpc_generated/route';
import { Button } from '@/lib/Button';
@@ -230,13 +230,13 @@ export const SnapshotStatusTable = ({ status }: SnapshotStatusProps) => {
};
type CDCMirrorStatusProps = {
- cdc: CDCMirrorStatus;
+ status: MirrorStatusResponse;
rows: SyncStatusRow[];
createdAt?: Date;
syncStatusChild?: React.ReactNode;
};
export function CDCMirror({
- cdc,
+ status,
rows,
createdAt,
syncStatusChild,
@@ -249,8 +249,10 @@ export function CDCMirror({
};
let snapshot = <></>;
- if (cdc.snapshotStatus) {
- snapshot =