Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[postgres] Copy to destination, not staging #498

Merged
merged 2 commits on
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flow/connectors/postgres/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ func (p *PostgresCDCSource) PullRecords(req *model.PullRecordsRequest) (
if p.publication != "" {
pubOpt := fmt.Sprintf("publication_names '%s'", p.publication)
pluginArguments = append(pluginArguments, pubOpt)
} else {
return nil, fmt.Errorf("publication name is not set")
}

replicationOpts := pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments}
Expand Down
102 changes: 22 additions & 80 deletions flow/connectors/postgres/qrep_sync_method.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@ package connpostgres
import (
"context"
"fmt"
"strings"
"time"

"github.com/PeerDB-io/peer-flow/connectors/utils/metrics"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
util "github.com/PeerDB-io/peer-flow/utils"
"github.com/jackc/pgx/v5"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protojson"
Expand All @@ -35,38 +33,9 @@ func (s *QRepStagingTableSync) SyncQRepRecords(
stream *model.QRecordStream,
) (int, error) {
partitionID := partition.PartitionId
runID, err := util.RandomUInt64()
if err != nil {
return -1, fmt.Errorf("failed to generate random runID: %v", err)
}

startTime := time.Now()
pool := s.connector.pool

// create a staging temporary table with the same schema as the destination table
stagingTable := fmt.Sprintf("_%d_staging", runID)

// create the staging temporary table if not exists
tmpTableStmt := fmt.Sprintf(
`CREATE TEMP TABLE %s AS SELECT * FROM %s LIMIT 0;`,
stagingTable,
dstTableName.String(),
)
_, err = pool.Exec(context.Background(), tmpTableStmt)
if err != nil {
log.WithFields(log.Fields{
"flowName": flowJobName,
"partitionID": partitionID,
"destinationTable": dstTableName,
}).Errorf(
"failed to create staging temporary table %s, statement: '%s'. Error: %v",
stagingTable,
tmpTableStmt,
err,
)
return 0, fmt.Errorf("failed to create staging temporary table %s: %w", stagingTable, err)
}

pool := s.connector.pool
schema, err := stream.Schema()
if err != nil {
log.WithFields(log.Fields{
Expand All @@ -77,30 +46,13 @@ func (s *QRepStagingTableSync) SyncQRepRecords(
return 0, fmt.Errorf("failed to get schema from stream: %w", err)
}

// Step 2: Insert records into the staging table.
copySource := model.NewQRecordBatchCopyFromSource(stream)

// Perform the COPY FROM operation
syncRecordsStartTime := time.Now()
syncedRows, err := pool.CopyFrom(
context.Background(),
pgx.Identifier{stagingTable},
schema.GetColumnNames(),
copySource,
)

if err != nil {
return -1, fmt.Errorf("failed to copy records into staging temporary table: %v", err)
}
metrics.LogQRepSyncMetrics(s.connector.ctx, flowJobName, syncedRows, time.Since(syncRecordsStartTime))

// Second transaction - to handle rest of the processing
tx2, err := pool.Begin(context.Background())
tx, err := pool.Begin(context.Background())
if err != nil {
return 0, fmt.Errorf("failed to begin transaction: %v", err)
}
defer func() {
if err := tx2.Rollback(context.Background()); err != nil {
if err := tx.Rollback(context.Background()); err != nil {
if err != pgx.ErrTxClosed {
log.WithFields(log.Fields{
"flowName": flowJobName,
Expand All @@ -111,33 +63,22 @@ func (s *QRepStagingTableSync) SyncQRepRecords(
}
}()

colNames := schema.GetColumnNames()
// wrap the column names in double quotes to handle reserved keywords
for i, colName := range colNames {
colNames[i] = fmt.Sprintf("\"%s\"", colName)
}
colNamesStr := strings.Join(colNames, ", ")
log.WithFields(log.Fields{
"flowName": flowJobName,
"partitionID": partitionID,
}).Infof("Obtained column names and quoted them in QRep sync")
insertFromStagingStmt := fmt.Sprintf(
"INSERT INTO %s (%s) SELECT %s FROM %s",
dstTableName.String(),
colNamesStr,
colNamesStr,
stagingTable,
// Step 2: Insert records into the destination table.
copySource := model.NewQRecordBatchCopyFromSource(stream)

// Perform the COPY FROM operation
syncRecordsStartTime := time.Now()
syncedRows, err := tx.CopyFrom(
context.Background(),
pgx.Identifier{dstTableName.Schema, dstTableName.Table},
schema.GetColumnNames(),
copySource,
)

_, err = tx2.Exec(context.Background(), insertFromStagingStmt)
if err != nil {
log.WithFields(log.Fields{
"flowName": flowJobName,
"partitionID": partitionID,
"destinationTable": dstTableName,
}).Errorf("failed to execute statement '%s': %v", insertFromStagingStmt, err)
return -1, fmt.Errorf("failed to execute statements in a transaction: %v", err)
return -1, fmt.Errorf("failed to copy records into destination table: %v", err)
}
metrics.LogQRepSyncMetrics(s.connector.ctx, flowJobName, syncedRows, time.Since(syncRecordsStartTime))

// marshal the partition to json using protojson
pbytes, err := protojson.Marshal(partition)
Expand All @@ -155,7 +96,7 @@ func (s *QRepStagingTableSync) SyncQRepRecords(
"partitionID": partitionID,
"destinationTable": dstTableName,
}).Infof("Executing transaction inside Qrep sync")
rows, err := tx2.Exec(
rows, err := tx.Exec(
context.Background(),
insertMetadataStmt,
flowJobName,
Expand All @@ -167,18 +108,19 @@ func (s *QRepStagingTableSync) SyncQRepRecords(
if err != nil {
return -1, fmt.Errorf("failed to execute statements in a transaction: %v", err)
}

err = tx.Commit(context.Background())
if err != nil {
return -1, fmt.Errorf("failed to commit transaction: %v", err)
}

totalRecordsAtTarget, err := s.connector.getApproxTableCounts([]string{dstTableName.String()})
if err != nil {
return -1, fmt.Errorf("failed to get total records at target: %v", err)
}
metrics.LogQRepNormalizeMetrics(s.connector.ctx, flowJobName, rows.RowsAffected(),
time.Since(normalizeRecordsStartTime), totalRecordsAtTarget)

err = tx2.Commit(context.Background())
if err != nil {
return -1, fmt.Errorf("failed to commit transaction: %v", err)
}

numRowsInserted := copySource.NumRecords()
log.WithFields(log.Fields{
"flowName": flowJobName,
Expand Down
Loading