Skip to content

Commit

Permalink
misc fixes and improvements for PG, along with softdel tests
Browse files Browse the repository at this point in the history
  • Loading branch information
heavycrystal committed Jan 3, 2024
1 parent 8e4d68c commit 7be2b55
Show file tree
Hide file tree
Showing 5 changed files with 435 additions and 56 deletions.
94 changes: 57 additions & 37 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,22 @@ const (
createRawTableBatchIDIndexSQL = "CREATE INDEX IF NOT EXISTS %s_batchid_idx ON %s.%s(_peerdb_batch_id)"
createRawTableDstTableIndexSQL = "CREATE INDEX IF NOT EXISTS %s_dst_table_idx ON %s.%s(_peerdb_destination_table_name)"

getLastOffsetSQL = "SELECT lsn_offset FROM %s.%s WHERE mirror_job_name=$1"
setLastOffsetSQL = "UPDATE %s.%s SET lsn_offset=GREATEST(lsn_offset, $1) WHERE mirror_job_name=$2"
getLastSyncBatchID_SQL = "SELECT sync_batch_id FROM %s.%s WHERE mirror_job_name=$1"
getLastNormalizeBatchID_SQL = "SELECT normalize_batch_id FROM %s.%s WHERE mirror_job_name=$1"
createNormalizedTableSQL = "CREATE TABLE IF NOT EXISTS %s(%s)"
getLastOffsetSQL = "SELECT lsn_offset FROM %s.%s WHERE mirror_job_name=$1"
setLastOffsetSQL = "UPDATE %s.%s SET lsn_offset=GREATEST(lsn_offset, $1) WHERE mirror_job_name=$2"
getLastSyncBatchID_SQL = "SELECT sync_batch_id FROM %s.%s WHERE mirror_job_name=$1"
getLastSyncAndNormalizeBatchID_SQL = "SELECT sync_batch_id,normalize_batch_id FROM %s.%s WHERE mirror_job_name=$1"
createNormalizedTableSQL = "CREATE TABLE IF NOT EXISTS %s(%s)"

insertJobMetadataSQL = "INSERT INTO %s.%s VALUES ($1,$2,$3,$4)"
checkIfJobMetadataExistsSQL = "SELECT COUNT(1)::TEXT::BOOL FROM %s.%s WHERE mirror_job_name=$1"
updateMetadataForSyncRecordsSQL = "UPDATE %s.%s SET lsn_offset=GREATEST(lsn_offset, $1), sync_batch_id=$2 WHERE mirror_job_name=$3"
updateMetadataForNormalizeRecordsSQL = "UPDATE %s.%s SET normalize_batch_id=$1 WHERE mirror_job_name=$2"

getDistinctDestinationTableNamesSQL = `SELECT DISTINCT _peerdb_destination_table_name FROM %s.%s WHERE
_peerdb_batch_id>$1 AND _peerdb_batch_id<=$2`
getTableNameToUnchangedToastColsSQL = `SELECT _peerdb_destination_table_name,
ARRAY_AGG(DISTINCT _peerdb_unchanged_toast_columns) FROM %s.%s WHERE
_peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 GROUP BY _peerdb_destination_table_name`
_peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_record_type!=2 GROUP BY _peerdb_destination_table_name`
srcTableName = "src"
mergeStatementSQL = `WITH src_rank AS (
SELECT _peerdb_data,_peerdb_record_type,_peerdb_unchanged_toast_columns,
Expand Down Expand Up @@ -428,46 +430,40 @@ func generateCreateTableSQLForNormalizedTable(
}

func (c *PostgresConnector) GetLastSyncBatchID(jobName string) (int64, error) {
rows, err := c.pool.Query(c.ctx, fmt.Sprintf(
var result pgtype.Int8
err := c.pool.QueryRow(c.ctx, fmt.Sprintf(
getLastSyncBatchID_SQL,
c.metadataSchema,
mirrorJobsTableIdentifier,
), jobName)
if err != nil {
return 0, fmt.Errorf("error querying Postgres peer for last syncBatchId: %w", err)
}
defer rows.Close()

var result pgtype.Int8
if !rows.Next() {
c.logger.Info("No row found, returning 0")
return 0, nil
}
err = rows.Scan(&result)
), jobName).Scan(&result)
if err != nil {
if err == pgx.ErrNoRows {
c.logger.Info("No row found, returning 0")
return 0, nil
}
return 0, fmt.Errorf("error while reading result row: %w", err)
}
return result.Int64, nil
}

func (c *PostgresConnector) getLastNormalizeBatchID(jobName string) (int64, error) {
rows, err := c.pool.Query(c.ctx, fmt.Sprintf(getLastNormalizeBatchID_SQL, c.metadataSchema,
mirrorJobsTableIdentifier), jobName)
if err != nil {
return 0, fmt.Errorf("error querying Postgres peer for last normalizeBatchId: %w", err)
}
defer rows.Close()

var result pgtype.Int8
if !rows.Next() {
c.logger.Info("No row found returning 0")
return 0, nil
}
err = rows.Scan(&result)
func (c *PostgresConnector) GetLastSyncAndNormalizeBatchID(jobName string) (*model.SyncAndNormalizeBatchID, error) {
var syncResult, normalizeResult pgtype.Int8
err := c.pool.QueryRow(c.ctx, fmt.Sprintf(
getLastSyncAndNormalizeBatchID_SQL,
c.metadataSchema,
mirrorJobsTableIdentifier,
), jobName).Scan(&syncResult, &normalizeResult)
if err != nil {
return 0, fmt.Errorf("error while reading result row: %w", err)
if err == pgx.ErrNoRows {
c.logger.Info("No row found, returning 0")
return &model.SyncAndNormalizeBatchID{}, nil
}
return nil, fmt.Errorf("error while reading result row: %w", err)
}
return result.Int64, nil
return &model.SyncAndNormalizeBatchID{
SyncBatchID: syncResult.Int64,
NormalizeBatchID: normalizeResult.Int64,
}, nil
}

func (c *PostgresConnector) jobMetadataExists(jobName string) (bool, error) {
Expand Down Expand Up @@ -549,6 +545,30 @@ func (c *PostgresConnector) updateNormalizeMetadata(flowJobName string, normaliz
return nil
}

// getDistinctTableNamesInBatch returns the distinct destination table names
// present in the raw table for batches in the half-open range
// (normalizeBatchID, syncBatchID], i.e. the batches that still need to be
// normalized for this flow job.
func (c *PostgresConnector) getDistinctTableNamesInBatch(flowJobName string, syncBatchID int64,
	normalizeBatchID int64,
) ([]string, error) {
	rawTableIdentifier := getRawTableIdentifier(flowJobName)

	rows, err := c.pool.Query(c.ctx, fmt.Sprintf(getDistinctDestinationTableNamesSQL, c.metadataSchema,
		rawTableIdentifier), normalizeBatchID, syncBatchID)
	if err != nil {
		return nil, fmt.Errorf("error while retrieving table names for normalization: %w", err)
	}
	defer rows.Close()

	var result pgtype.Text
	destinationTableNames := make([]string, 0)
	for rows.Next() {
		err = rows.Scan(&result)
		if err != nil {
			return nil, fmt.Errorf("failed to read row: %w", err)
		}
		destinationTableNames = append(destinationTableNames, result.String)
	}
	// rows.Next() returns false both at end-of-data and on error; without this
	// check a mid-stream failure would silently yield a truncated list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating distinct table names: %w", err)
	}
	return destinationTableNames, nil
}

func (c *PostgresConnector) getTableNametoUnchangedCols(flowJobName string, syncBatchID int64,
normalizeBatchID int64,
) (map[string][]string, error) {
Expand Down Expand Up @@ -768,7 +788,7 @@ func (c *PostgresConnector) generateUpdateStatement(allCols []string,

ssep := strings.Join(tmpArray, ",")
updateStmt := fmt.Sprintf(`WHEN MATCHED AND
src._peerdb_record_type=1 AND _peerdb_unchanged_toast_columns='%s'
src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='%s'
THEN UPDATE SET %s `, cols, ssep)
updateStmts = append(updateStmts, updateStmt)

Expand All @@ -780,7 +800,7 @@ func (c *PostgresConnector) generateUpdateStatement(allCols []string,
fmt.Sprintf(`"%s" = TRUE`, peerdbCols.SoftDeleteColName))
ssep := strings.Join(tmpArray, ", ")
updateStmt := fmt.Sprintf(`WHEN MATCHED AND
src._peerdb_record_type = 2 AND _peerdb_unchanged_toast_columns='%s'
src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='%s'
THEN UPDATE SET %s `, cols, ssep)
updateStmts = append(updateStmts, updateStmt)
}
Expand Down
48 changes: 30 additions & 18 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,30 +404,41 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S

func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) {
rawTableIdentifier := getRawTableIdentifier(req.FlowJobName)
syncBatchID, err := c.GetLastSyncBatchID(req.FlowJobName)

jobMetadataExists, err := c.jobMetadataExists(req.FlowJobName)
if err != nil {
return nil, err
}
normalizeBatchID, err := c.getLastNormalizeBatchID(req.FlowJobName)
if err != nil {
return nil, err
// no SyncFlow has run, chill until more records are loaded.
if !jobMetadataExists {
c.logger.Info("no metadata found for mirror")
return &model.NormalizeResponse{
Done: false,
}, nil
}
jobMetadataExists, err := c.jobMetadataExists(req.FlowJobName)

batchIDs, err := c.GetLastSyncAndNormalizeBatchID(req.FlowJobName)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get batch for the current mirror: %v", err)
}
// normalize has caught up with sync or no SyncFlow has run, chill until more records are loaded.
if normalizeBatchID >= syncBatchID || !jobMetadataExists {
// normalize has caught up with sync, chill until more records are loaded.
if batchIDs.NormalizeBatchID >= batchIDs.SyncBatchID {
c.logger.Info(fmt.Sprintf("no records to normalize: syncBatchID %d, normalizeBatchID %d",
syncBatchID, normalizeBatchID))
batchIDs.SyncBatchID, batchIDs.NormalizeBatchID))
return &model.NormalizeResponse{
Done: false,
StartBatchID: normalizeBatchID,
EndBatchID: syncBatchID,
StartBatchID: batchIDs.NormalizeBatchID,
EndBatchID: batchIDs.SyncBatchID,
}, nil
}

unchangedToastColsMap, err := c.getTableNametoUnchangedCols(req.FlowJobName, syncBatchID, normalizeBatchID)
destinationTableNames, err := c.getDistinctTableNamesInBatch(
req.FlowJobName, batchIDs.SyncBatchID, batchIDs.NormalizeBatchID)
if err != nil {
return nil, err
}
unchangedToastColsMap, err := c.getTableNametoUnchangedCols(req.FlowJobName,
batchIDs.SyncBatchID, batchIDs.NormalizeBatchID)
if err != nil {
return nil, err
}
Expand All @@ -449,16 +460,17 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
}
mergeStatementsBatch := &pgx.Batch{}
totalRowsAffected := 0
for destinationTableName, unchangedToastCols := range unchangedToastColsMap {
for _, destinationTableName := range destinationTableNames {
peerdbCols := protos.PeerDBColumns{
SoftDeleteColName: req.SoftDeleteColName,
SyncedAtColName: req.SyncedAtColName,
SoftDelete: req.SoftDelete,
}
normalizeStatements := c.generateNormalizeStatements(destinationTableName, unchangedToastCols,
normalizeStatements := c.generateNormalizeStatements(destinationTableName, unchangedToastColsMap[destinationTableName],
rawTableIdentifier, supportsMerge, &peerdbCols)
fmt.Println(normalizeStatements)
for _, normalizeStatement := range normalizeStatements {
mergeStatementsBatch.Queue(normalizeStatement, normalizeBatchID, syncBatchID, destinationTableName).Exec(
mergeStatementsBatch.Queue(normalizeStatement, batchIDs.NormalizeBatchID, batchIDs.SyncBatchID, destinationTableName).Exec(
func(ct pgconn.CommandTag) error {
totalRowsAffected += int(ct.RowsAffected())
return nil
Expand All @@ -475,7 +487,7 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
c.logger.Info(fmt.Sprintf("normalized %d records", totalRowsAffected))

// updating metadata with new normalizeBatchID
err = c.updateNormalizeMetadata(req.FlowJobName, syncBatchID, normalizeRecordsTx)
err = c.updateNormalizeMetadata(req.FlowJobName, batchIDs.SyncBatchID, normalizeRecordsTx)
if err != nil {
return nil, err
}
Expand All @@ -487,8 +499,8 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)

return &model.NormalizeResponse{
Done: true,
StartBatchID: normalizeBatchID + 1,
EndBatchID: syncBatchID,
StartBatchID: batchIDs.NormalizeBatchID + 1,
EndBatchID: batchIDs.SyncBatchID,
}, nil
}

Expand Down
Loading

0 comments on commit 7be2b55

Please sign in to comment.