Skip to content

Commit

Permalink
Support upsert mode pg -> pg
Browse files: browse the repository at this point in the history
  • Loading branch information
iskakaushik committed Oct 21, 2023
1 parent eb4024d commit 8153697
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 11 deletions.
3 changes: 2 additions & 1 deletion flow/connectors/postgres/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,8 @@ func (c *PostgresConnector) SyncQRepRecords(
switch syncMode {
case protos.QRepSyncMode_QREP_SYNC_MODE_MULTI_INSERT:
stagingTableSync := &QRepStagingTableSync{connector: c}
return stagingTableSync.SyncQRepRecords(config.FlowJobName, dstTable, partition, stream)
return stagingTableSync.SyncQRepRecords(
config.FlowJobName, dstTable, partition, stream, config.WriteMode)
case protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO:
return 0, fmt.Errorf("[postgres] SyncQRepRecords not implemented for storage avro sync mode")
default:
Expand Down
95 changes: 85 additions & 10 deletions flow/connectors/postgres/qrep_sync_method.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package connpostgres
import (
"context"
"fmt"
"strings"
"time"

"github.com/PeerDB-io/peer-flow/connectors/utils/metrics"
Expand Down Expand Up @@ -31,6 +32,7 @@ func (s *QRepStagingTableSync) SyncQRepRecords(
dstTableName *SchemaTable,
partition *protos.QRepPartition,
stream *model.QRecordStream,
writeMode *protos.QRepWriteMode,
) (int, error) {
partitionID := partition.PartitionId
startTime := time.Now()
Expand Down Expand Up @@ -66,19 +68,92 @@ func (s *QRepStagingTableSync) SyncQRepRecords(
// Step 2: Insert records into the destination table.
copySource := model.NewQRecordBatchCopyFromSource(stream)

// Perform the COPY FROM operation
syncRecordsStartTime := time.Now()
syncedRows, err := tx.CopyFrom(
context.Background(),
pgx.Identifier{dstTableName.Schema, dstTableName.Table},
schema.GetColumnNames(),
copySource,
)
var numRowsSynced int64

if err != nil {
return -1, fmt.Errorf("failed to copy records into destination table: %v", err)
if writeMode.WriteType == protos.QRepWriteType_QREP_WRITE_MODE_APPEND {
// Perform the COPY FROM operation
numRowsSynced, err = tx.CopyFrom(
context.Background(),
pgx.Identifier{dstTableName.Schema, dstTableName.Table},
schema.GetColumnNames(),
copySource,
)
if err != nil {
return -1, fmt.Errorf("failed to copy records into destination table: %v", err)
}
} else {
// Step 2.1: Create a temp staging table
stagingTableName := fmt.Sprintf("%s_%s", dstTableName.Table, partitionID)
createStagingTableStmt := fmt.Sprintf(
"CREATE TABLE %s LIKE %s;",
pgx.Identifier{dstTableName.Schema, stagingTableName},
pgx.Identifier{dstTableName.Schema, dstTableName.Table},
)

log.Infof("Creating staging table %s", stagingTableName)
_, err = tx.Exec(context.Background(), createStagingTableStmt)

if err != nil {
return -1, fmt.Errorf("failed to create staging table: %v", err)
}

// Step 2.2: Insert records into the staging table
numRowsSynced, err = tx.CopyFrom(
context.Background(),
pgx.Identifier{dstTableName.Schema, stagingTableName},
schema.GetColumnNames(),
copySource,
)
if err != nil {
return -1, fmt.Errorf("failed to copy records into staging table: %v", err)
}

// construct the SET clause for the upsert operation
upsertMatchColsList := writeMode.UpsertKeyColumns
upsertMatchCols := make(map[string]bool)
for _, col := range upsertMatchColsList {
upsertMatchCols[col] = true
}

setClause := ""
for _, col := range schema.GetColumnNames() {
_, ok := upsertMatchCols[col]
if !ok {
setClause += fmt.Sprintf("%s = %s.%s, ", col, stagingTableName, col)
}
}
setClause = setClause[:len(setClause)-2]

// Step 2.3: Perform the upsert operation, ON CONFLICT UPDATE
upsertStmt := fmt.Sprintf(
"INSERT INTO %s SELECT * FROM %s ON CONFLICT (%s) DO UPDATE SET %s;",
pgx.Identifier{dstTableName.Schema, dstTableName.Table},
pgx.Identifier{dstTableName.Schema, stagingTableName},
strings.Join(writeMode.UpsertKeyColumns, ", "),
setClause,
)
log.Infof("Performing upsert operation: %s", upsertStmt)
res, err := tx.Exec(context.Background(), upsertStmt)
if err != nil {
return -1, fmt.Errorf("failed to perform upsert operation: %v", err)
}

numRowsSynced = res.RowsAffected()

// Step 2.4: Drop the staging table
dropStagingTableStmt := fmt.Sprintf(
"DROP TABLE %s;",
pgx.Identifier{dstTableName.Schema, stagingTableName},
)
log.Infof("Dropping staging table %s", stagingTableName)
_, err = tx.Exec(context.Background(), dropStagingTableStmt)
if err != nil {
return -1, fmt.Errorf("failed to drop staging table: %v", err)
}
}
metrics.LogQRepSyncMetrics(s.connector.ctx, flowJobName, syncedRows, time.Since(syncRecordsStartTime))

metrics.LogQRepSyncMetrics(s.connector.ctx, flowJobName, numRowsSynced, time.Since(syncRecordsStartTime))

// marshal the partition to json using protojson
pbytes, err := protojson.Marshal(partition)
Expand Down

0 comments on commit 8153697

Please sign in to comment.