diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go index 4ee4cf9a29..2a14368fcb 100644 --- a/flow/connectors/postgres/qrep.go +++ b/flow/connectors/postgres/qrep.go @@ -524,7 +524,8 @@ func (c *PostgresConnector) SyncQRepRecords( switch syncMode { case protos.QRepSyncMode_QREP_SYNC_MODE_MULTI_INSERT: stagingTableSync := &QRepStagingTableSync{connector: c} - return stagingTableSync.SyncQRepRecords(config.FlowJobName, dstTable, partition, stream) + return stagingTableSync.SyncQRepRecords( + config.FlowJobName, dstTable, partition, stream, config.WriteMode) case protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO: return 0, fmt.Errorf("[postgres] SyncQRepRecords not implemented for storage avro sync mode") default: diff --git a/flow/connectors/postgres/qrep_sync_method.go b/flow/connectors/postgres/qrep_sync_method.go index 3f6047d9f6..e6dafcaf61 100644 --- a/flow/connectors/postgres/qrep_sync_method.go +++ b/flow/connectors/postgres/qrep_sync_method.go @@ -3,6 +3,7 @@ package connpostgres import ( "context" "fmt" + "strings" "time" "github.com/PeerDB-io/peer-flow/connectors/utils/metrics" @@ -31,6 +32,7 @@ func (s *QRepStagingTableSync) SyncQRepRecords( dstTableName *SchemaTable, partition *protos.QRepPartition, stream *model.QRecordStream, + writeMode *protos.QRepWriteMode, ) (int, error) { partitionID := partition.PartitionId startTime := time.Now() @@ -66,19 +68,90 @@ func (s *QRepStagingTableSync) SyncQRepRecords( // Step 2: Insert records into the destination table. copySource := model.NewQRecordBatchCopyFromSource(stream) - // Perform the COPY FROM operation syncRecordsStartTime := time.Now() - syncedRows, err := tx.CopyFrom( - context.Background(), - pgx.Identifier{dstTableName.Schema, dstTableName.Table}, - schema.GetColumnNames(), - copySource, - ) + var numRowsSynced int64 - if err != nil { - return -1, fmt.Errorf("failed to copy records into destination table: %v", err) + if writeMode.WriteType == protos.QRepWriteType_QREP_WRITE_MODE_APPEND { + // Perform the COPY FROM operation + numRowsSynced, err = tx.CopyFrom( + context.Background(), + pgx.Identifier{dstTableName.Schema, dstTableName.Table}, + schema.GetColumnNames(), + copySource, + ) + if err != nil { + return -1, fmt.Errorf("failed to copy records into destination table: %v", err) + } + } else { + // Step 2.1: Create a temp staging table + stagingTableName := fmt.Sprintf("%s_%s", dstTableName.Table, partitionID) + createStagingTableStmt := fmt.Sprintf( + "CREATE TABLE %s LIKE %s;", + pgx.Identifier{dstTableName.Schema, stagingTableName}, + pgx.Identifier{dstTableName.Schema, dstTableName.Table}, + ) + + log.Infof("Creating staging table %s", stagingTableName) + _, err = tx.Exec(context.Background(), createStagingTableStmt) + + if err != nil { + return -1, fmt.Errorf("failed to create staging table: %v", err) + } + + // Step 2.2: Insert records into the staging table + numRowsSynced, err = tx.CopyFrom( + context.Background(), + pgx.Identifier{dstTableName.Schema, stagingTableName}, + schema.GetColumnNames(), + copySource, + ) + if err != nil { + return -1, fmt.Errorf("failed to copy records into staging table: %v", err) + } + + // construct the SET clause for the upsert operation + upsertMatchColsList := writeMode.UpsertKeyColumns + upsertMatchCols := make(map[string]bool) + for _, col := range upsertMatchColsList { + upsertMatchCols[col] = true + } + + setClause := "" + for _, col := range schema.GetColumnNames() { + _, ok := upsertMatchCols[col] + if !ok { + setClause += fmt.Sprintf("%s = %s.%s, ", col, stagingTableName, col) + } + } + setClause = setClause[:len(setClause)-2] + + // Step 2.3: Perform the upsert operation, ON CONFLICT UPDATE + upsertStmt := fmt.Sprintf( + "INSERT INTO %s SELECT * FROM %s ON CONFLICT (%s) DO UPDATE SET %s;", + pgx.Identifier{dstTableName.Schema, dstTableName.Table}, + pgx.Identifier{dstTableName.Schema, stagingTableName}, + strings.Join(writeMode.UpsertKeyColumns, ", "), + setClause, + ) + log.Infof("Performing upsert operation: %s", upsertStmt) + _, err = tx.Exec(context.Background(), upsertStmt) + if err != nil { + return -1, fmt.Errorf("failed to perform upsert operation: %v", err) + } + + // Step 2.4: Drop the staging table + dropStagingTableStmt := fmt.Sprintf( + "DROP TABLE %s;", + pgx.Identifier{dstTableName.Schema, stagingTableName}, + ) + log.Infof("Dropping staging table %s", stagingTableName) + _, err = tx.Exec(context.Background(), dropStagingTableStmt) + if err != nil { + return -1, fmt.Errorf("failed to drop staging table: %v", err) + } } - metrics.LogQRepSyncMetrics(s.connector.ctx, flowJobName, syncedRows, time.Since(syncRecordsStartTime)) + + metrics.LogQRepSyncMetrics(s.connector.ctx, flowJobName, numRowsSynced, time.Since(syncRecordsStartTime)) // marshal the partition to json using protojson pbytes, err := protojson.Marshal(partition)