diff --git a/flow/connectors/postgres/qrep_sql_sync.go b/flow/connectors/postgres/qrep_sql_sync.go index 16c5102b8f..bc97fcf93d 100644 --- a/flow/connectors/postgres/qrep_sql_sync.go +++ b/flow/connectors/postgres/qrep_sql_sync.go @@ -84,7 +84,7 @@ func (s *QRepStagingTableSync) SyncQRepRecords( // both overwrite and append if writeMode == nil || - writeMode.WriteType != protos.QRepWriteType_QREP_WRITE_MODE_UPSERT { + writeMode.WriteType == protos.QRepWriteType_QREP_WRITE_MODE_APPEND { // Perform the COPY FROM operation numRowsSynced, err = tx.CopyFrom( context.Background(), @@ -108,6 +108,54 @@ func (s *QRepStagingTableSync) SyncQRepRecords( return -1, fmt.Errorf("failed to update synced_at column: %v", err) } } + } else if writeMode.WriteType == protos.QRepWriteType_QREP_WRITE_MODE_OVERWRITE { + dstTableIdentifier := pgx.Identifier{dstTableName.Schema, dstTableName.Table}.Sanitize() + overwriteTempTable := pgx.Identifier{dstTableName.Schema, dstTableName.Table + "_overwrite"} + overwriteTempTableIdentifier := overwriteTempTable.Sanitize() + newColumns := make([]string, 0, len(schema.Fields)) + for _, field := range schema.Fields { + newColumns = append(newColumns, fmt.Sprintf("%s %s", QuoteIdentifier(field.Name), + qValueKindToPostgresType(string(field.Type)))) + } + newColumns = append(newColumns, fmt.Sprintf("%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP", + QuoteIdentifier(syncedAtCol))) + _, err := tx.Exec(context.Background(), fmt.Sprintf("CREATE UNLOGGED TABLE %s (%s);", + overwriteTempTableIdentifier, + strings.Join(newColumns, ", "), + )) + if err != nil { + return -1, fmt.Errorf("failed to create %s: %v", overwriteTempTableIdentifier, err) + } + + _, err = tx.CopyFrom(context.Background(), overwriteTempTable, schema.GetColumnNames(), copySource) + if err != nil { + return -1, fmt.Errorf("failed to copy records into %s: %v", overwriteTempTableIdentifier, err) + } + + _, err = tx.Exec(context.Background(), fmt.Sprintf("DROP TABLE %s;", dstTableIdentifier)) + if err != nil { + return -1, fmt.Errorf("failed to drop %s: %v", dstTableIdentifier, err) + } + + _, err = tx.Exec(context.Background(), fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", + overwriteTempTableIdentifier, QuoteIdentifier(dstTableName.Table))) + if err != nil { + return -1, fmt.Errorf("failed to rename %s to %s: %v", + overwriteTempTableIdentifier, dstTableIdentifier, err) + } + + if syncedAtCol != "" { + updateSyncedAtStmt := fmt.Sprintf( + `UPDATE %s SET %s = CURRENT_TIMESTAMP WHERE %s IS NULL;`, + dstTableIdentifier, + QuoteIdentifier(syncedAtCol), + QuoteIdentifier(syncedAtCol), + ) + _, err = tx.Exec(context.Background(), updateSyncedAtStmt) + if err != nil { + return -1, fmt.Errorf("failed to update synced_at column: %v", err) + } + } } else { // Step 2.1: Create a temp staging table stagingTableName := "_peerdb_staging_" + shared.RandomString(8)