Commit d08460b

fixes post-rebase

Amogh-Bharadwaj committed Sep 7, 2023
1 parent 50d981c commit d08460b

Showing 11 changed files with 632 additions and 340 deletions.
4 changes: 2 additions & 2 deletions flow/connectors/bigquery/bigquery.go
@@ -1283,7 +1283,7 @@ func (m *MergeStmtGenerator) generateDeDupedCTE() string {
 }
 
 // generateMergeStmt generates a merge statement.
-func (m *MergeStmtGenerator) generateMergeStmt() string {
+func (m *MergeStmtGenerator) generateMergeStmt(tempTable string) string {
 	pkey := m.NormalizedTableSchema.PrimaryKeyColumns[0]
 
 	// comma separated list of column names
@@ -1295,7 +1295,7 @@ func (m *MergeStmtGenerator) generateMergeStmt() string {
 	}
 	csep := strings.Join(backtickColNames, ", ")
 
-	updateStatementsforToastCols := m.generateUpdateStatement(colNames, m.UnchangedToastColumns)
+	updateStatementsforToastCols := m.generateUpdateStatement(pureColNames, m.UnchangedToastColumns)
 	updateStringToastCols := strings.Join(updateStatementsforToastCols, " ")
 
 	return fmt.Sprintf(`
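The second hunk swaps colNames for pureColNames, presumably because _peerdb_unchanged_toast_columns stores bare column names while the SELECT and INSERT lists use backtick-quoted ones. A minimal, runnable sketch of the two name forms (column names are illustrative, not from the repository):

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Bare names, e.g. for matching against _peerdb_unchanged_toast_columns.
	pureColNames := []string{"id", "name", "details"}

	// Backtick-quoted names for BigQuery column lists.
	backtickColNames := make([]string, 0, len(pureColNames))
	for _, col := range pureColNames {
		backtickColNames = append(backtickColNames, fmt.Sprintf("`%s`", col))
	}
	fmt.Println(strings.Join(backtickColNames, ", ")) // `id`, `name`, `details`
}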
5 changes: 2 additions & 3 deletions flow/connectors/eventhub/eventhub.go
@@ -339,9 +339,8 @@ func (c *EventHubConnector) SetupNormalizedTables(
 	req *protos.SetupNormalizedTableBatchInput) (
 	*protos.SetupNormalizedTableBatchOutput, error) {
 	log.Infof("normalization for event hub is a no-op")
-	return &protos.SetupNormalizedTableOutput{
-		TableIdentifier: req.TableIdentifier,
-		AlreadyExists:   false,
+	return &protos.SetupNormalizedTableBatchOutput{
+		TableExistsMapping: nil,
 	}, nil
 }

91 changes: 38 additions & 53 deletions flow/connectors/postgres/client.go
@@ -46,11 +46,11 @@ const (
srcTableName = "src"
mergeStatementSQL = `WITH src_rank AS (
SELECT _peerdb_data,_peerdb_record_type,_peerdb_unchanged_toast_columns,
RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS rank
FROM %s.%s WHERE _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_destination_table_name=$3
)
MERGE INTO %s dst
USING (SELECT %s,_peerdb_record_type,_peerdb_unchanged_toast_columns FROM src_rank WHERE _peerdb_rank=1) src
USING (SELECT %s,_peerdb_record_type,_peerdb_unchanged_toast_columns FROM src_rank WHERE rank=1) src
ON dst.%s=src.%s
WHEN NOT MATCHED AND src._peerdb_record_type!=2 THEN
INSERT (%s) VALUES (%s)
@@ -59,17 +59,17 @@ const (
 	DELETE`
 	fallbackUpsertStatementSQL = `WITH src_rank AS (
 		SELECT _peerdb_data,_peerdb_record_type,_peerdb_unchanged_toast_columns,
-		RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
+		RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS rank
 		FROM %s.%s WHERE _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_destination_table_name=$3
 	)
-	INSERT INTO %s (%s) SELECT %s FROM src_rank WHERE _peerdb_rank=1 AND _peerdb_record_type!=2
+	INSERT INTO %s (%s) SELECT %s FROM src_rank WHERE rank=1 AND _peerdb_record_type!=2
 	ON CONFLICT (%s) DO UPDATE SET %s`
 	fallbackDeleteStatementSQL = `WITH src_rank AS (
 		SELECT _peerdb_data,_peerdb_record_type,_peerdb_unchanged_toast_columns,
-		RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
+		RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS rank
 		FROM %s.%s WHERE _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_destination_table_name=$3
 	)
-	DELETE FROM %s USING src_rank WHERE %s.%s=%s AND src_rank._peerdb_rank=1 AND src_rank._peerdb_record_type=2`
+	DELETE FROM %s USING src_rank WHERE %s AND src_rank.rank=1 AND src_rank._peerdb_record_type=2`
 
 	dropTableIfExistsSQL = "DROP TABLE IF EXISTS %s.%s"
 	deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE MIRROR_JOB_NAME=$1"
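For orientation, here is one plausible expansion of the reworked fallbackDeleteStatementSQL. The old template matched a single key through %s.%s=%s; the new lone %s can carry a whole multi-column condition. Every table, schema, and column name below is illustrative, and the WHERE fill-in is assumed from the primary-key casts built in generateFallbackStatements further down:

package main

// Hypothetical expansion for a destination table public.users with a
// composite primary key (id, region); not taken from the repository.
const exampleFallbackDelete = `WITH src_rank AS (
	SELECT _peerdb_data,_peerdb_record_type,_peerdb_unchanged_toast_columns,
	RANK() OVER (PARTITION BY (_peerdb_data->>'id')::BIGINT,(_peerdb_data->>'region')::TEXT ORDER BY _peerdb_timestamp DESC) AS rank
	FROM peerdb_internal.raw_mirror_1 WHERE _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_destination_table_name=$3
)
DELETE FROM public.users USING src_rank
WHERE public.users.id=(src_rank._peerdb_data->>'id')::BIGINT
AND public.users.region=(src_rank._peerdb_data->>'region')::TEXT
AND src_rank.rank=1 AND src_rank._peerdb_record_type=2`

func main() { println(exampleFallbackDelete) }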
@@ -103,7 +103,6 @@ func (c *PostgresConnector) getReplicaIdentityForTable(schemaTable *SchemaTable)
 	if err != nil {
 		return "", fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, err)
 	}
-
 	return string(replicaIdentity), nil
 }

@@ -308,14 +307,13 @@ func generateCreateTableSQLForNormalizedTable(sourceTableIdentifier string,
 	sourceTableSchema *protos.TableSchema) string {
 	createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns))
 	for columnName, genericColumnType := range sourceTableSchema.Columns {
-		if sourceTableSchema.PrimaryKeyColumn == strings.ToLower(columnName) {
-			createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s PRIMARY KEY,",
-				columnName, qValueKindToPostgresType(genericColumnType)))
-		} else {
-			createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s,", columnName,
-				qValueKindToPostgresType(genericColumnType)))
-		}
+		createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s,", columnName,
+			qValueKindToPostgresType(genericColumnType)))
 	}
+	createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(\"%s\"),",
+		strings.TrimSuffix(strings.Join(sourceTableSchema.PrimaryKeyColumns, ","), ",")))
+	log.Error(fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier,
+		strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ",")))
 	return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier,
 		strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ","))
 }
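The primary key now lands in a table-level constraint rather than a column-level one. Assuming createNormalizedTableSQL is a CREATE TABLE IF NOT EXISTS template (its text is not shown in this diff), the generated DDL for a hypothetical table with columns id and name and primary key [id] would look roughly like:

package main

// Illustrative output only; table and column names are invented.
const exampleCreateTable = `CREATE TABLE IF NOT EXISTS public.users("id" BIGINT,"name" TEXT,PRIMARY KEY("id"))`

func main() { println(exampleCreateTable) }

For a composite key such as [id, region], the same Sprintf emits PRIMARY KEY("id,region"), quoting the joined column list as a single identifier.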
@@ -467,6 +465,9 @@ func (c *PostgresConnector) getTableNametoUnchangedCols(flowJobName string, sync

 func (c *PostgresConnector) generateNormalizeStatements(destinationTableIdentifier string,
 	unchangedToastColumns []string, rawTableIdentifier string, supportsMerge bool) []string {
+	if supportsMerge {
+		return []string{c.generateMergeStatement(destinationTableIdentifier, unchangedToastColumns, rawTableIdentifier)}
+	}
 	log.Warnf("Postgres version is not high enough to support MERGE, falling back to UPSERT + DELETE")
 	log.Warnf("TOAST columns will not be updated properly, use REPLICA IDENTITY FULL or upgrade Postgres")
 	return c.generateFallbackStatements(destinationTableIdentifier, rawTableIdentifier)
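The new fast path returns a single MERGE statement whenever supportsMerge is set. MERGE was introduced in PostgreSQL 15, so the flag is presumably derived from the server version by the caller; a minimal sketch of such a check (function name assumed):

package main

// Sketch only: PostgreSQL gained MERGE in version 15.
func supportsMerge(majorVersion int) bool {
	return majorVersion >= 15
}

func main() {
	println(supportsMerge(15)) // true
}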
@@ -476,12 +477,13 @@ func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifie
 	rawTableIdentifier string) []string {
 	normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier]
 	columnNames := make([]string, 0, len(normalizedTableSchema.Columns))
 
 	flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
+	primaryKeyColumnCasts := make(map[string]string)
 	for columnName, genericColumnType := range normalizedTableSchema.Columns {
-		columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName))
+		columnNames = append(columnNames, columnName)
 		pgType := qValueKindToPostgresType(genericColumnType)
-		flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"",
+		flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS %s",
 			columnName, pgType, columnName))
 		if slices.Contains(normalizedTableSchema.PrimaryKeyColumns, columnName) {
 			primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType)
@@ -519,51 +521,34 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier st
 	rawTableIdentifier string) string {
 	normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier]
 	columnNames := maps.Keys(normalizedTableSchema.Columns)
-	for i, columnName := range columnNames {
-		columnNames[i] = fmt.Sprintf("\"%s\"", columnName)
-	}
 
 	flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
-	var primaryKeyColumnCast string
+	primaryKeyColumnCasts := make(map[string]string)
 	for columnName, genericColumnType := range normalizedTableSchema.Columns {
 		pgType := qValueKindToPostgresType(genericColumnType)
-		if strings.Contains(genericColumnType, "array") {
-			flattenedCastsSQLArray = append(flattenedCastsSQLArray,
-				fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS %s",
-					strings.Trim(columnName, "\""), pgType, columnName))
-		} else {
-			flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS %s",
-				strings.Trim(columnName, "\""), pgType, columnName))
-		}
-		if normalizedTableSchema.PrimaryKeyColumn == columnName {
-			primaryKeyColumnCast = fmt.Sprintf("(_peerdb_data->>'%s')::%s", strings.Trim(columnName, "\""), pgType)
+		flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS %s",
+			columnName, pgType, columnName))
+		if slices.Contains(normalizedTableSchema.PrimaryKeyColumns, columnName) {
+			primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType)
 		}
 	}
 	flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")
 
-	// return fmt.Sprintf(mergeStatementSQL, primaryKeyColumnCast, internalSchema, rawTableIdentifier,
-	// 	destinationTableIdentifier, flattenedCastsSQL, normalizedTableSchema.PrimaryKeyColumn,
-	// 	normalizedTableSchema.PrimaryKeyColumn, insertColumnsSQL, insertValuesSQL, updateStatements)
-	// }
-
-	// func (c *PostgresConnector) generateUpdateStatement(allCols []string, unchangedToastColsLists []string) string {
-	// 	updateStmts := make([]string, 0)
-
-	// 	for _, cols := range unchangedToastColsLists {
-	// 		unchangedColsArray := strings.Split(cols, ",")
-	// 		otherCols := utils.ArrayMinus(allCols, unchangedColsArray)
-	// 		tmpArray := make([]string, 0)
-	// 		for _, colName := range otherCols {
-	// 			tmpArray = append(tmpArray, fmt.Sprintf("%s=src.%s", colName, colName))
-	// 		}
-	// 		ssep := strings.Join(tmpArray, ",")
-	// 		updateStmt := fmt.Sprintf(`WHEN MATCHED AND
-	// 		src._peerdb_record_type=1 AND _peerdb_unchanged_toast_columns='%s'
-	// 		THEN UPDATE SET %s `, cols, ssep)
-	// 		updateStmts = append(updateStmts, updateStmt)
-	// 	}
-	// 	return strings.Join(updateStmts, "\n")
-	// }
+	insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",")
+	insertValuesSQLArray := make([]string, 0, len(columnNames))
+	for _, columnName := range columnNames {
+		insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", columnName))
+	}
+	insertValuesSQL := strings.TrimSuffix(strings.Join(insertValuesSQLArray, ","), ",")
+	updateStatements := c.generateUpdateStatement(columnNames, unchangedToastColumns)
 
+	return fmt.Sprintf(mergeStatementSQL, primaryKeyColumnCasts, internalSchema, rawTableIdentifier,
+		destinationTableIdentifier, flattenedCastsSQL, normalizedTableSchema.PrimaryKeyColumns,
+		normalizedTableSchema.PrimaryKeyColumns, insertColumnsSQL, insertValuesSQL, updateStatements)
 }
 
+func (c *PostgresConnector) generateUpdateStatement(allCols []string, unchangedToastColsLists []string) string {
+	updateStmts := make([]string, 0)
+
+	for _, cols := range unchangedToastColsLists {
+		unchangedColsArray := strings.Split(cols, ",")
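The body of the reinstated generateUpdateStatement is truncated above; the following runnable sketch reconstructs the full algorithm from the commented-out copy that this commit deletes, with a local arrayMinus standing in for the repository's utils.ArrayMinus:

package main

import (
	"fmt"
	"strings"
)

// generateUpdateStatement builds one MERGE "WHEN MATCHED ... UPDATE" branch per
// observed set of unchanged TOAST columns, updating only the remaining columns.
func generateUpdateStatement(allCols []string, unchangedToastColsLists []string) string {
	updateStmts := make([]string, 0, len(unchangedToastColsLists))
	for _, cols := range unchangedToastColsLists {
		unchangedColsArray := strings.Split(cols, ",")
		otherCols := arrayMinus(allCols, unchangedColsArray)
		tmpArray := make([]string, 0, len(otherCols))
		for _, colName := range otherCols {
			tmpArray = append(tmpArray, fmt.Sprintf("%s=src.%s", colName, colName))
		}
		updateStmts = append(updateStmts, fmt.Sprintf(`WHEN MATCHED AND
		src._peerdb_record_type=1 AND _peerdb_unchanged_toast_columns='%s'
		THEN UPDATE SET %s `, cols, strings.Join(tmpArray, ",")))
	}
	return strings.Join(updateStmts, "\n")
}

// arrayMinus returns the elements of all that do not appear in remove.
func arrayMinus(all, remove []string) []string {
	removed := make(map[string]struct{}, len(remove))
	for _, r := range remove {
		removed[r] = struct{}{}
	}
	out := make([]string, 0, len(all))
	for _, a := range all {
		if _, ok := removed[a]; !ok {
			out = append(out, a)
		}
	}
	return out
}

func main() {
	// Here "details" was TOASTed and unchanged, so only id and name are rewritten
	// in that MERGE branch.
	fmt.Println(generateUpdateStatement([]string{"id", "name", "details"}, []string{"details"}))
}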
17 changes: 10 additions & 7 deletions flow/connectors/postgres/postgres.go
@@ -546,7 +546,7 @@ func (c *PostgresConnector) GetTableSchema(
 	req *protos.GetTableSchemaBatchInput) (*protos.GetTableSchemaBatchOutput, error) {
 	res := make(map[string]*protos.TableSchema)
 	for _, tableName := range req.TableIdentifiers {
-		tableSchema, err := c.getTableSchemaForTable(tableName)
+		tableSchema, err := c.getTableSchemaForTable(tableName, req)
 		if err != nil {
 			return nil, err
 		}
@@ -561,11 +561,13 @@ func (c *PostgresConnector) GetTableSchema(

 func (c *PostgresConnector) getTableSchemaForTable(
 	tableName string,
+	req *protos.GetTableSchemaBatchInput,
 ) (*protos.TableSchema, error) {
 	schemaTable, err := parseSchemaTable(tableName)
 	if err != nil {
 		return nil, err
 	}
+	log.Infof("getting schema for table %s", tableName)
 
 	// Get the column names and types
 	rows, err := c.pool.Query(c.ctx,
@@ -577,16 +579,17 @@ func (c *PostgresConnector) getTableSchemaForTable(

 	pKeyCols, err := c.getPrimaryKeyColumns(schemaTable)
 	if err != nil {
-		replicaIdentity, err := c.getReplicaIdentityForTable(schemaTable)
-		if req.DestinationPeerType != protos.DBType_EVENTHUB || err != nil || replicaIdentity != "f" {
-			return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err)
+		replicaIdentity, replicaIdentityErr := c.getReplicaIdentityForTable(schemaTable)
+		log.Infof("replica identity: %s", replicaIdentity)
+		if req.DestinationPeerType != protos.DBType_EVENTHUB || replicaIdentityErr != nil || replicaIdentity != "f" {
+			return nil, fmt.Errorf("error getting primary key column or replica identity for table %s: %w", schemaTable, err)
 		}
 	}
 
 	res := &protos.TableSchema{
-		TableIdentifier:  tableName,
-		Columns:          make(map[string]string),
-		PrimaryKeyColumn: pkey,
+		TableIdentifier:   tableName,
+		Columns:           make(map[string]string),
+		PrimaryKeyColumns: pKeyCols,
 	}
 
 	for _, fieldDescription := range rows.FieldDescriptions() {
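getPrimaryKeyColumns itself is outside this diff. For orientation, this is the kind of catalog query it would need in order to return every primary-key column; the exact query is an assumption, not the repository's:

package main

// Assumed query: list all columns of the table's primary-key index.
const getPrimaryKeyColumnsSQL = `SELECT a.attname
FROM pg_index i
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = $1::regclass AND i.indisprimary`

func main() { println(getPrimaryKeyColumnsSQL) }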
6 changes: 3 additions & 3 deletions flow/connectors/postgres/postgres_cdc_test.go
@@ -452,7 +452,7 @@ func (suite *PostgresCDCTestSuite) TestSimpleHappyFlow() {
"id": string(qvalue.QValueKindInt32),
"name": string(qvalue.QValueKindString),
},
PrimaryKeyColumn: "id",
PrimaryKeyColumns: []string{"id"},
},
}}, tableNameSchema)
tableNameSchemaMapping[simpleHappyFlowDstTableName] =
@@ -600,7 +600,7 @@ func (suite *PostgresCDCTestSuite) TestAllTypesHappyFlow() {
"c40": string(qvalue.QValueKindUUID),
"c41": string(qvalue.QValueKindString),
},
PrimaryKeyColumn: "id",
PrimaryKeyColumns: []string{"id"},
},
},
}, tableNameSchema)
@@ -688,7 +688,7 @@ func (suite *PostgresCDCTestSuite) TestToastHappyFlow() {
"n_b": string(qvalue.QValueKindBytes),
"lz4_b": string(qvalue.QValueKindBytes),
},
PrimaryKeyColumn: "id",
PrimaryKeyColumns: []string{"id"},
},
}}, tableNameSchema)
tableNameSchemaMapping[toastHappyFlowDstTableName] =
2 changes: 1 addition & 1 deletion flow/connectors/snowflake/snowflake.go
@@ -1056,7 +1056,7 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
 	mergeStatement := fmt.Sprintf(mergeStatementSQL, destinationTableIdentifier, toVariantColumnName,
 		rawTableIdentifier, normalizeBatchID, syncBatchID, flattenedCastsSQL,
 		normalizedTableSchema.PrimaryKeyColumns[0], pkeyColStr, insertColumnsSQL, insertValuesSQL,
-		updateStringToastCols)
+		updateStringToastCols, deletePart)
 
 	result, err := normalizeRecordsTx.ExecContext(c.ctx, mergeStatement, destinationTableIdentifier)
 	if err != nil {
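deletePart is not defined in the visible hunk; by analogy with the Postgres mergeStatementSQL above, it is presumably a MERGE branch of roughly this shape (assumed, not from the repository):

package main

// Assumed shape of the new fmt.Sprintf argument: a Snowflake MERGE clause
// mirroring the record-type-2 DELETE branch of the Postgres template.
const deletePart = `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) THEN DELETE`

func main() { println(deletePart) }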