Eventhub fixes with Composite PKeys Support (#379)
Co-authored-by: Kevin Biju <[email protected]>
Amogh-Bharadwaj and heavycrystal authored Sep 7, 2023
1 parent befefed commit 1d20fc6
Showing 14 changed files with 789 additions and 492 deletions.
10 changes: 6 additions & 4 deletions flow/activities/flowable.go
@@ -311,10 +311,12 @@ func (a *FlowableActivity) StartNormalize(
 		return nil, fmt.Errorf("failed to normalized records: %w", err)
 	}
 
-	err = a.CatalogMirrorMonitor.UpdateEndTimeForCDCBatch(ctx, input.FlowConnectionConfigs.FlowJobName,
-		res.EndBatchID)
-	if err != nil {
-		return nil, err
+	if res.Done {
+		err = a.CatalogMirrorMonitor.UpdateEndTimeForCDCBatch(ctx, input.FlowConnectionConfigs.FlowJobName,
+			res.EndBatchID)
+		if err != nil {
+			return nil, err
+		}
 	}
 
 	// log the number of batches normalized
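Taken together with the bigquery.go change below (Done is now false when sync has not caught up), this guard means the CDC batch end time is only recorded when a batch was actually normalized. A minimal sketch of the pattern, using hypothetical stand-ins for the real model and monitor types:

```go
package main

import "fmt"

// NormalizeResponse is a hypothetical stand-in for model.NormalizeResponse.
type NormalizeResponse struct {
	Done       bool
	EndBatchID int64
}

// updateEndTime stands in for CatalogMirrorMonitor.UpdateEndTimeForCDCBatch.
func updateEndTime(batchID int64) error {
	fmt.Println("recording end time for batch", batchID)
	return nil
}

func main() {
	responses := []NormalizeResponse{
		{Done: false, EndBatchID: 3}, // sync not caught up: skip the monitor update
		{Done: true, EndBatchID: 4},  // a batch was normalized: record its end time
	}
	for _, res := range responses {
		if res.Done {
			if err := updateEndTime(res.EndBatchID); err != nil {
				fmt.Println("error:", err)
			}
		}
	}
}
```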
2 changes: 1 addition & 1 deletion flow/cmd/handler.go
@@ -34,7 +34,7 @@ func (h *FlowRequestHandler) CreatePeerFlow(
 
 	maxBatchSize := int(cfg.MaxBatchSize)
 	if maxBatchSize == 0 {
-		maxBatchSize = 100000
+		maxBatchSize = 200000
 		cfg.MaxBatchSize = uint32(maxBatchSize)
 	}
 
10 changes: 5 additions & 5 deletions flow/connectors/bigquery/bigquery.go
@@ -868,7 +868,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
 	if !hasJob || normalizeBatchID == syncBatchID {
 		log.Printf("waiting for sync to catch up for job %s, so finishing", req.FlowJobName)
 		return &model.NormalizeResponse{
-			Done:         true,
+			Done:         false,
 			StartBatchID: normalizeBatchID,
 			EndBatchID:   syncBatchID,
 		}, nil
@@ -1278,13 +1278,13 @@ func (m *MergeStmtGenerator) generateDeDupedCTE() string {
 		) _peerdb_ranked
 		WHERE _peerdb_rank = 1
 	) SELECT * FROM _peerdb_de_duplicated_data_res`
-	pkey := m.NormalizedTableSchema.PrimaryKeyColumn
+	pkey := m.NormalizedTableSchema.PrimaryKeyColumns[0]
 	return fmt.Sprintf(cte, pkey)
 }
 
 // generateMergeStmt generates a merge statement.
 func (m *MergeStmtGenerator) generateMergeStmt(tempTable string) string {
-	pkey := m.NormalizedTableSchema.PrimaryKeyColumn
+	pkey := m.NormalizedTableSchema.PrimaryKeyColumns[0]
 
 	// comma separated list of column names
 	backtickColNames := make([]string, 0)
@@ -1295,8 +1295,8 @@ func (m *MergeStmtGenerator) generateMergeStmt(tempTable string) string {
 	}
 	csep := strings.Join(backtickColNames, ", ")
 
-	udateStatementsforToastCols := m.generateUpdateStatement(pureColNames, m.UnchangedToastColumns)
-	updateStringToastCols := strings.Join(udateStatementsforToastCols, " ")
+	updateStatementsforToastCols := m.generateUpdateStatement(pureColNames, m.UnchangedToastColumns)
+	updateStringToastCols := strings.Join(updateStatementsforToastCols, " ")
 
 	return fmt.Sprintf(`
 	MERGE %s.%s _peerdb_target USING %s _peerdb_deduped
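Note that the BigQuery generators take only PrimaryKeyColumns[0], so the dedup CTE still partitions on a single key column even when a table has a composite key. A composite-aware CTE would partition on every key column; a sketch of that fragment under that assumption (the helper below is hypothetical, not part of this commit):

```go
package main

import (
	"fmt"
	"strings"
)

// compositePartitionClause builds a PARTITION BY list from all primary key
// columns instead of just the first, quoting each with backticks for BigQuery.
func compositePartitionClause(pkeyCols []string) string {
	quoted := make([]string, 0, len(pkeyCols))
	for _, col := range pkeyCols {
		quoted = append(quoted, fmt.Sprintf("`%s`", col))
	}
	return strings.Join(quoted, ", ")
}

func main() {
	cols := []string{"tenant_id", "order_id"}
	fmt.Printf("PARTITION BY %s ORDER BY _peerdb_timestamp DESC\n",
		compositePartitionClause(cols))
}
```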
4 changes: 3 additions & 1 deletion flow/connectors/eventhub/eventhub.go
@@ -339,7 +339,9 @@ func (c *EventHubConnector) SetupNormalizedTables(
 	req *protos.SetupNormalizedTableBatchInput) (
 	*protos.SetupNormalizedTableBatchOutput, error) {
 	log.Infof("normalization for event hub is a no-op")
-	return nil, nil
+	return &protos.SetupNormalizedTableBatchOutput{
+		TableExistsMapping: nil,
+	}, nil
 }
 
 func (c *EventHubConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) {
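Returning a typed, non-nil output here (with an empty TableExistsMapping) presumably spares callers a nil check before reading the result. A sketch of that calling pattern, with a hypothetical stand-in for the protos type:

```go
package main

import "fmt"

// SetupNormalizedTableBatchOutput is a hypothetical stand-in for the protos type.
type SetupNormalizedTableBatchOutput struct {
	TableExistsMapping map[string]bool
}

// setupNormalizedTables mimics the no-op connector: a valid output, never nil.
func setupNormalizedTables() (*SetupNormalizedTableBatchOutput, error) {
	return &SetupNormalizedTableBatchOutput{TableExistsMapping: nil}, nil
}

func main() {
	out, err := setupNormalizedTables()
	if err != nil {
		panic(err)
	}
	// Safe to dereference even for the no-op connector.
	fmt.Println("tables already present:", len(out.TableExistsMapping))
}
```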
19 changes: 13 additions & 6 deletions flow/connectors/postgres/cdc.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"reflect"
+	"strings"
 	"time"
 
 	"github.com/PeerDB-io/peer-flow/model"
@@ -210,11 +211,14 @@ func (p *PostgresCDCSource) consumeStream(
 			// tableName here is destination tableName.
 			// should be ideally sourceTableName as we are in pullRecrods.
 			// will change in future
-			pkeyCol := req.TableNameSchemaMapping[tableName].PrimaryKeyColumn
-			pkeyColVal := rec.GetItems()[pkeyCol]
+			pkeyColsMerged := make([]string, 0)
+			for _, pkeyCol := range req.TableNameSchemaMapping[tableName].PrimaryKeyColumns {
+				pkeyColVal := rec.GetItems()[pkeyCol]
+				pkeyColsMerged = append(pkeyColsMerged, fmt.Sprintf("%v", pkeyColVal))
+			}
 			tablePkeyVal := model.TableWithPkey{
 				TableName:  tableName,
-				PkeyColVal: pkeyColVal,
+				PkeyColVal: strings.Join(pkeyColsMerged, " "),
 			}
 			_, ok := result.TablePKeyLastSeen[tablePkeyVal]
 			if !ok {
@@ -233,11 +237,14 @@ func (p *PostgresCDCSource) consumeStream(
 				result.TablePKeyLastSeen[tablePkeyVal] = len(result.Records) - 1
 			}
 		case *model.InsertRecord:
-			pkeyCol := req.TableNameSchemaMapping[tableName].PrimaryKeyColumn
-			pkeyColVal := rec.GetItems()[pkeyCol]
+			pkeyColsMerged := make([]string, 0)
+			for _, pkeyCol := range req.TableNameSchemaMapping[tableName].PrimaryKeyColumns {
+				pkeyColVal := rec.GetItems()[pkeyCol]
+				pkeyColsMerged = append(pkeyColsMerged, fmt.Sprintf("%v", pkeyColVal))
+			}
 			tablePkeyVal := model.TableWithPkey{
 				TableName:  tableName,
-				PkeyColVal: pkeyColVal,
+				PkeyColVal: strings.Join(pkeyColsMerged, " "),
 			}
 			result.Records = append(result.Records, rec)
 			// all columns will be set in insert record, so add it to the map
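The pull path now folds every primary key value into one space-joined string so that TableWithPkey remains a comparable map key for TablePKeyLastSeen. A standalone sketch of the merging logic, with a hypothetical record standing in for the real model types:

```go
package main

import (
	"fmt"
	"strings"
)

// TableWithPkey is a hypothetical stand-in for model.TableWithPkey.
type TableWithPkey struct {
	TableName  string
	PkeyColVal string
}

func main() {
	pkeyColumns := []string{"tenant_id", "order_id"}
	items := map[string]interface{}{"tenant_id": 42, "order_id": "A-17", "amount": 9.99}

	// Merge all key values into one string, mirroring the cdc.go loop.
	merged := make([]string, 0, len(pkeyColumns))
	for _, col := range pkeyColumns {
		merged = append(merged, fmt.Sprintf("%v", items[col]))
	}
	key := TableWithPkey{TableName: "public.orders", PkeyColVal: strings.Join(merged, " ")}

	lastSeen := map[TableWithPkey]int{}
	lastSeen[key] = 7 // usable as a map key because both fields are comparable
	fmt.Printf("%+v -> %d\n", key, lastSeen[key])
}
```

One caveat a reviewer might flag: a plain space join can collide when key values themselves contain spaces.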
140 changes: 82 additions & 58 deletions flow/connectors/postgres/client.go
@@ -12,6 +12,7 @@ import (
 	"github.com/jackc/pgx/v5"
 	log "github.com/sirupsen/logrus"
 	"golang.org/x/exp/maps"
+	"golang.org/x/exp/slices"
 )
 
 //nolint:stylecheck
@@ -26,6 +27,9 @@ const (
 	_peerdb_timestamp BIGINT NOT NULL,_peerdb_destination_table_name TEXT NOT NULL,_peerdb_data JSONB NOT NULL,
 	_peerdb_record_type INTEGER NOT NULL, _peerdb_match_data JSONB,_peerdb_batch_id INTEGER,
 	_peerdb_unchanged_toast_columns TEXT)`
+	createRawTableBatchIDIndexSQL  = "CREATE INDEX IF NOT EXISTS %s_batchid_idx ON %s.%s(_peerdb_batch_id)"
+	createRawTableDstTableIndexSQL = `CREATE INDEX IF NOT EXISTS
+	 %s_dst_table_idx ON %s.%s(_peerdb_destination_table_name)`
 
 	getLastOffsetSQL = "SELECT lsn_offset FROM %s.%s WHERE mirror_job_name=$1"
 	getLastSyncBatchID_SQL = "SELECT sync_batch_id FROM %s.%s WHERE mirror_job_name=$1"
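The two new indexes match the filters used by the normalize statements below, which scan the raw table by _peerdb_batch_id range and by _peerdb_destination_table_name. A sketch of how the templates render, with illustrative schema and table names:

```go
package main

import "fmt"

// Template constants copied from the diff; the identifiers fed into them
// below are illustrative, not the connector's real naming scheme.
const (
	createRawTableBatchIDIndexSQL  = "CREATE INDEX IF NOT EXISTS %s_batchid_idx ON %s.%s(_peerdb_batch_id)"
	createRawTableDstTableIndexSQL = `CREATE INDEX IF NOT EXISTS
	 %s_dst_table_idx ON %s.%s(_peerdb_destination_table_name)`
)

func main() {
	rawTable, schema := "_peerdb_raw_mirror_job", "_peerdb_internal"
	fmt.Println(fmt.Sprintf(createRawTableBatchIDIndexSQL, rawTable, schema, rawTable))
	fmt.Println(fmt.Sprintf(createRawTableDstTableIndexSQL, rawTable, schema, rawTable))
}
```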
@@ -43,11 +47,11 @@ const (
 	srcTableName = "src"
 	mergeStatementSQL = `WITH src_rank AS (
 		SELECT _peerdb_data,_peerdb_record_type,_peerdb_unchanged_toast_columns,
-		RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
+		RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS rank
 		FROM %s.%s WHERE _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_destination_table_name=$3
 	)
 	MERGE INTO %s dst
-	USING (SELECT %s,_peerdb_record_type,_peerdb_unchanged_toast_columns FROM src_rank WHERE _peerdb_rank=1) src
+	USING (SELECT %s,_peerdb_record_type,_peerdb_unchanged_toast_columns FROM src_rank WHERE rank=1) src
 	ON dst.%s=src.%s
 	WHEN NOT MATCHED AND src._peerdb_record_type!=2 THEN
 	INSERT (%s) VALUES (%s)
@@ -56,17 +60,17 @@ const (
 	DELETE`
 	fallbackUpsertStatementSQL = `WITH src_rank AS (
 		SELECT _peerdb_data,_peerdb_record_type,_peerdb_unchanged_toast_columns,
-		RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
+		RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS rank
 		FROM %s.%s WHERE _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_destination_table_name=$3
 	)
-	INSERT INTO %s (%s) SELECT %s FROM src_rank WHERE _peerdb_rank=1 AND _peerdb_record_type!=2
+	INSERT INTO %s (%s) SELECT %s FROM src_rank WHERE rank=1 AND _peerdb_record_type!=2
 	ON CONFLICT (%s) DO UPDATE SET %s`
 	fallbackDeleteStatementSQL = `WITH src_rank AS (
 		SELECT _peerdb_data,_peerdb_record_type,_peerdb_unchanged_toast_columns,
-		RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
+		RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS rank
 		FROM %s.%s WHERE _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_destination_table_name=$3
 	)
-	DELETE FROM %s USING src_rank WHERE %s.%s=%s AND src_rank._peerdb_rank=1 AND src_rank._peerdb_record_type=2`
+	DELETE FROM %s USING src_rank WHERE %s AND src_rank.rank=1 AND src_rank._peerdb_record_type=2`
 
 	dropTableIfExistsSQL = "DROP TABLE IF EXISTS %s.%s"
 	deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE MIRROR_JOB_NAME=$1"
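With composite keys, the fallback DELETE no longer hard-codes a single dst.pk=cast equality; the whole WHERE fragment (%s) is now built by generateFallbackStatements further down. A rendered example for a two-column key, using the template verbatim with illustrative identifiers:

```go
package main

import "fmt"

// Copied verbatim from the diff.
const fallbackDeleteStatementSQL = `WITH src_rank AS (
	SELECT _peerdb_data,_peerdb_record_type,_peerdb_unchanged_toast_columns,
	RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS rank
	FROM %s.%s WHERE _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_destination_table_name=$3
)
DELETE FROM %s USING src_rank WHERE %s AND src_rank.rank=1 AND src_rank._peerdb_record_type=2`

func main() {
	// Illustrative values; the connector derives these from the table schema.
	pkCasts := "(_peerdb_data->>'tenant_id')::integer,(_peerdb_data->>'order_id')::text"
	whereClause := "public.orders.tenant_id=(_peerdb_data->>'tenant_id')::integer AND " +
		"public.orders.order_id=(_peerdb_data->>'order_id')::text"
	fmt.Println(fmt.Sprintf(fallbackDeleteStatementSQL,
		pkCasts, "_peerdb_internal", "_peerdb_raw_mirror_job", "public.orders", whereClause))
}
```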
@@ -86,39 +90,59 @@ func (c *PostgresConnector) getRelIDForTable(schemaTable *SchemaTable) (uint32,
 	return relID, nil
 }
 
-// getPrimaryKeyColumn for table returns the primary key column for a given table
+// getReplicaIdentity returns the replica identity for a table.
+func (c *PostgresConnector) getReplicaIdentityForTable(schemaTable *SchemaTable) (string, error) {
+	relID, relIDErr := c.getRelIDForTable(schemaTable)
+	if relIDErr != nil {
+		return "", fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, relIDErr)
+	}
+
+	var replicaIdentity rune
+	err := c.pool.QueryRow(c.ctx,
+		`SELECT relreplident FROM pg_class WHERE oid = $1;`,
+		relID).Scan(&replicaIdentity)
+	if err != nil {
+		return "", fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, err)
+	}
+	return string(replicaIdentity), nil
+}
+
+// getPrimaryKeyColumns for table returns the primary key column for a given table
 // errors if there is no primary key column or if there is more than one primary key column.
-func (c *PostgresConnector) getPrimaryKeyColumn(schemaTable *SchemaTable) (string, error) {
+func (c *PostgresConnector) getPrimaryKeyColumns(schemaTable *SchemaTable) ([]string, error) {
 	relID, err := c.getRelIDForTable(schemaTable)
 	if err != nil {
-		return "", fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, err)
+		return nil, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, err)
 	}
 
 	// Get the primary key column name
 	var pkCol string
+	pkCols := make([]string, 0)
 	rows, err := c.pool.Query(c.ctx,
 		`SELECT a.attname FROM pg_index i
 		 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
-		 WHERE i.indrelid = $1 AND i.indisprimary`,
+		 WHERE i.indrelid = $1 AND i.indisprimary ORDER BY a.attname ASC`,
 		relID)
 	if err != nil {
-		return "", fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err)
+		return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err)
 	}
 	defer rows.Close()
 	// 0 rows returned, table has no primary keys
 	if !rows.Next() {
-		return "", fmt.Errorf("table %s has no primary keys", schemaTable)
+		return nil, fmt.Errorf("table %s has no primary keys", schemaTable)
 	}
-	err = rows.Scan(&pkCol)
-	if err != nil {
-		return "", fmt.Errorf("error scanning primary key column for table %s: %w", schemaTable, err)
-	}
-	// more than 1 row returned, table has more than 1 primary key
-	if rows.Next() {
-		return "", fmt.Errorf("table %s has more than one primary key", schemaTable)
+	for {
+		err = rows.Scan(&pkCol)
+		if err != nil {
+			return nil, fmt.Errorf("error scanning primary key column for table %s: %w", schemaTable, err)
+		}
+		pkCols = append(pkCols, pkCol)
+		if !rows.Next() {
+			break
+		}
 	}
 
-	return pkCol, nil
+	return pkCols, nil
 }
 
func (c *PostgresConnector) tableExists(schemaTable *SchemaTable) (bool, error) {
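The pg_index/pg_attribute join returns one row per key column, so the function now collects all of them (ordered by name for determinism) instead of erroring on composite keys; the second comment line above it still describes the old single-key behavior. A standalone pgx sketch of the same query (the DSN and table name are placeholders):

```go
package main

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()
	// Placeholder DSN; supply a real connection string to run this.
	conn, err := pgx.Connect(ctx, "postgres://user:pass@localhost:5432/db")
	if err != nil {
		panic(err)
	}
	defer conn.Close(ctx)

	// Resolve a table's relation OID; "public.orders" is illustrative.
	var relID uint32
	if err := conn.QueryRow(ctx, `SELECT 'public.orders'::regclass::oid`).Scan(&relID); err != nil {
		panic(err)
	}

	rows, err := conn.Query(ctx,
		`SELECT a.attname FROM pg_index i
		 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
		 WHERE i.indrelid = $1 AND i.indisprimary ORDER BY a.attname ASC`, relID)
	if err != nil {
		panic(err)
	}
	defer rows.Close()

	pkCols := make([]string, 0)
	for rows.Next() { // idiomatic alternative to the scan-then-Next loop in the diff
		var col string
		if err := rows.Scan(&col); err != nil {
			panic(err)
		}
		pkCols = append(pkCols, col)
	}
	fmt.Println("primary key columns:", pkCols)
}
```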
@@ -284,14 +308,13 @@ func generateCreateTableSQLForNormalizedTable(sourceTableIdentifier string,
 	sourceTableSchema *protos.TableSchema) string {
 	createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns))
 	for columnName, genericColumnType := range sourceTableSchema.Columns {
-		if sourceTableSchema.PrimaryKeyColumn == strings.ToLower(columnName) {
-			createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s PRIMARY KEY,",
-				columnName, qValueKindToPostgresType(genericColumnType)))
-		} else {
-			createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s,", columnName,
-				qValueKindToPostgresType(genericColumnType)))
-		}
+		createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s,", columnName,
+			qValueKindToPostgresType(genericColumnType)))
 	}
+	createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(\"%s\"),",
+		strings.TrimSuffix(strings.Join(sourceTableSchema.PrimaryKeyColumns, ","), ",")))
+	log.Error(fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier,
+		strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ",")))
 	return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier,
 		strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ","))
 }
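The inline per-column PRIMARY KEY clause is replaced by one table-level constraint, which is what makes multi-column keys expressible. A rendered example for a two-column key follows; note the sketch quotes each key column separately, whereas the fmt.Sprintf above wraps the whole comma-joined list in a single pair of quotes, which a reviewer might want to double-check. The wrapper constant here is illustrative:

```go
package main

import (
	"fmt"
	"strings"
)

// Illustrative wrapper; the real createNormalizedTableSQL constant is defined in client.go.
const createNormalizedTableSQL = "CREATE TABLE IF NOT EXISTS %s(%s)"

func main() {
	columns := []string{`"tenant_id" INTEGER,`, `"order_id" TEXT,`, `"amount" NUMERIC,`}
	pkeyCols := []string{"tenant_id", "order_id"}

	// Quote each key column on its own: PRIMARY KEY("tenant_id","order_id")
	parts := append(columns, fmt.Sprintf(`PRIMARY KEY("%s"),`, strings.Join(pkeyCols, `","`)))
	ddl := fmt.Sprintf(createNormalizedTableSQL, "public.orders",
		strings.TrimSuffix(strings.Join(parts, ""), ","))
	fmt.Println(ddl)
}
```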
@@ -445,26 +468,26 @@ func (c *PostgresConnector) generateNormalizeStatements(destinationTableIdentifi
 	unchangedToastColumns []string, rawTableIdentifier string, supportsMerge bool) []string {
 	if supportsMerge {
 		return []string{c.generateMergeStatement(destinationTableIdentifier, unchangedToastColumns, rawTableIdentifier)}
-	} else {
-		log.Warnf("Postgres version is not high enough to support MERGE, falling back to UPSERT + DELETE")
-		log.Warnf("TOAST columns will not be updated properly, use REPLICA IDENTITY FULL or upgrade Postgres")
-		return c.generateFallbackStatements(destinationTableIdentifier, rawTableIdentifier)
 	}
+	log.Warnf("Postgres version is not high enough to support MERGE, falling back to UPSERT + DELETE")
+	log.Warnf("TOAST columns will not be updated properly, use REPLICA IDENTITY FULL or upgrade Postgres")
+	return c.generateFallbackStatements(destinationTableIdentifier, rawTableIdentifier)
 }
 
 func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifier string,
 	rawTableIdentifier string) []string {
 	normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier]
 	columnNames := make([]string, 0, len(normalizedTableSchema.Columns))
 
 	flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
-	var primaryKeyColumnCast string
+	primaryKeyColumnCasts := make(map[string]string)
 	for columnName, genericColumnType := range normalizedTableSchema.Columns {
-		columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName))
+		columnNames = append(columnNames, columnName)
 		pgType := qValueKindToPostgresType(genericColumnType)
-		flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"",
+		flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS %s",
 			columnName, pgType, columnName))
-		if normalizedTableSchema.PrimaryKeyColumn == columnName {
-			primaryKeyColumnCast = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType)
+		if slices.Contains(normalizedTableSchema.PrimaryKeyColumns, columnName) {
+			primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType)
 		}
 	}
 	flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")
@@ -475,38 +498,39 @@ func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifie
 		updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf("%s=EXCLUDED.%s", columnName, columnName))
 	}
 	updateColumnsSQL := strings.TrimSuffix(strings.Join(updateColumnsSQLArray, ","), ",")
-	fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL, primaryKeyColumnCast, internalSchema,
+	deleteWhereClauseArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
+	for columnName, columnCast := range primaryKeyColumnCasts {
+		deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf("%s.%s=%s AND ",
+			destinationTableIdentifier, columnName, columnCast))
+	}
+	deleteWhereClauseSQL := strings.TrimSuffix(strings.Join(deleteWhereClauseArray, ""), "AND ")
+
+	fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL,
+		strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), internalSchema,
 		rawTableIdentifier, destinationTableIdentifier, insertColumnsSQL, flattenedCastsSQL,
-		normalizedTableSchema.PrimaryKeyColumn, updateColumnsSQL)
-	fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL, primaryKeyColumnCast, internalSchema,
-		rawTableIdentifier, destinationTableIdentifier, destinationTableIdentifier,
-		normalizedTableSchema.PrimaryKeyColumn, primaryKeyColumnCast)
+		strings.TrimSuffix(strings.Join(normalizedTableSchema.PrimaryKeyColumns, ","), ","), updateColumnsSQL)
+	fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL,
+		strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), internalSchema,
+		rawTableIdentifier, destinationTableIdentifier, deleteWhereClauseSQL)
 
+	log.Errorln("fallbackUpsertStatement", fallbackUpsertStatement)
+	log.Errorln("fallbackDeleteStatement", fallbackDeleteStatement)
 	return []string{fallbackUpsertStatement, fallbackDeleteStatement}
 }
 
 func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier string, unchangedToastColumns []string,
 	rawTableIdentifier string) string {
 	normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier]
 	columnNames := maps.Keys(normalizedTableSchema.Columns)
-	for i, columnName := range columnNames {
-		columnNames[i] = fmt.Sprintf("\"%s\"", columnName)
-	}
 
 	flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
-	var primaryKeyColumnCast string
+	primaryKeyColumnCasts := make(map[string]string)
 	for columnName, genericColumnType := range normalizedTableSchema.Columns {
 		pgType := qValueKindToPostgresType(genericColumnType)
-		if strings.Contains(genericColumnType, "array") {
-			flattenedCastsSQLArray = append(flattenedCastsSQLArray,
-				fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS %s",
-					strings.Trim(columnName, "\""), pgType, columnName))
-		} else {
-			flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS %s",
-				strings.Trim(columnName, "\""), pgType, columnName))
-		}
-		if normalizedTableSchema.PrimaryKeyColumn == columnName {
-			primaryKeyColumnCast = fmt.Sprintf("(_peerdb_data->>'%s')::%s", strings.Trim(columnName, "\""), pgType)
+		flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS %s",
+			columnName, pgType, columnName))
+		if slices.Contains(normalizedTableSchema.PrimaryKeyColumns, columnName) {
+			primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType)
 		}
 	}
 	flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")
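The DELETE's WHERE fragment is assembled per key column and the trailing "AND " is then trimmed. A tiny runnable version of that assembly (table name and casts are illustrative):

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	dst := "public.orders"
	// Illustrative casts, keyed by column as in primaryKeyColumnCasts.
	casts := map[string]string{
		"tenant_id": "(_peerdb_data->>'tenant_id')::integer",
		"order_id":  "(_peerdb_data->>'order_id')::text",
	}

	clauses := make([]string, 0, len(casts))
	for col, cast := range casts {
		clauses = append(clauses, fmt.Sprintf("%s.%s=%s AND ", dst, col, cast))
	}
	// Clause order varies with map iteration, which is fine for AND-ed equality.
	where := strings.TrimSuffix(strings.Join(clauses, ""), "AND ")
	fmt.Println(where)
}
```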
@@ -519,9 +543,9 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier st
 	insertValuesSQL := strings.TrimSuffix(strings.Join(insertValuesSQLArray, ","), ",")
 	updateStatements := c.generateUpdateStatement(columnNames, unchangedToastColumns)
 
-	return fmt.Sprintf(mergeStatementSQL, primaryKeyColumnCast, internalSchema, rawTableIdentifier,
-		destinationTableIdentifier, flattenedCastsSQL, normalizedTableSchema.PrimaryKeyColumn,
-		normalizedTableSchema.PrimaryKeyColumn, insertColumnsSQL, insertValuesSQL, updateStatements)
+	return fmt.Sprintf(mergeStatementSQL, primaryKeyColumnCasts, internalSchema, rawTableIdentifier,
+		destinationTableIdentifier, flattenedCastsSQL, normalizedTableSchema.PrimaryKeyColumns,
+		normalizedTableSchema.PrimaryKeyColumns, insertColumnsSQL, insertValuesSQL, updateStatements)
 }
 
 func (c *PostgresConnector) generateUpdateStatement(allCols []string, unchangedToastColsLists []string) string {
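Unlike the fallback path, the merge path hands primaryKeyColumnCasts (a map) and PrimaryKeyColumns (a slice) to fmt.Sprintf's %s verbs directly, which renders Go's map[...] and [...] syntax rather than a SQL column list; a reviewer might flag that these likely need joining first. A sketch of the difference, with illustrative values:

```go
package main

import (
	"fmt"
	"strings"

	"golang.org/x/exp/maps"
)

func main() {
	casts := map[string]string{"tenant_id": "(_peerdb_data->>'tenant_id')::integer"}
	pkCols := []string{"tenant_id", "order_id"}

	// %s on a map or slice prints Go syntax, not a SQL list:
	fmt.Printf("PARTITION BY %s\n", casts)   // PARTITION BY map[tenant_id:(_peerdb_data->>'tenant_id')::integer]
	fmt.Printf("ON CONFLICT (%s)\n", pkCols) // ON CONFLICT ([tenant_id order_id])

	// Joining first produces usable SQL fragments:
	fmt.Printf("PARTITION BY %s\n", strings.Join(maps.Values(casts), ","))
	fmt.Printf("ON CONFLICT (%s)\n", strings.Join(pkCols, ","))
}
```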