Merge branch 'main' into fix-walheartbeat-tests
Amogh-Bharadwaj authored Nov 27, 2023
2 parents 60f071f + f302598 commit 8830d37
Showing 3 changed files with 18 additions and 46 deletions.
43 changes: 6 additions & 37 deletions flow/connectors/postgres/client.go
@@ -508,14 +508,15 @@ func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifie
 			destinationTableIdentifier, columnName, columnCast))
 	}
 	deleteWhereClauseSQL := strings.TrimSuffix(strings.Join(deleteWhereClauseArray, ""), "AND ")
+	parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier)
 
 	fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL,
 		strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), c.metadataSchema,
-		rawTableIdentifier, destinationTableIdentifier, insertColumnsSQL, flattenedCastsSQL,
+		rawTableIdentifier, parsedDstTable.String(), insertColumnsSQL, flattenedCastsSQL,
 		strings.Join(normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL)
 	fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL,
 		strings.Join(maps.Values(primaryKeyColumnCasts), ","), c.metadataSchema,
-		rawTableIdentifier, destinationTableIdentifier, deleteWhereClauseSQL)
+		rawTableIdentifier, parsedDstTable.String(), deleteWhereClauseSQL)
 
 	return []string{fallbackUpsertStatement, fallbackDeleteStatement}
 }
@@ -529,6 +530,8 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier st
 	}
 
 	flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
+	parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier)
+
 	primaryKeyColumnCasts := make(map[string]string)
 	primaryKeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
 	for columnName, genericColumnType := range normalizedTableSchema.Columns {
@@ -558,7 +561,7 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier st
 	updateStatements := c.generateUpdateStatement(columnNames, unchangedToastColumns)
 
 	return fmt.Sprintf(mergeStatementSQL, strings.Join(maps.Values(primaryKeyColumnCasts), ","),
-		c.metadataSchema, rawTableIdentifier, destinationTableIdentifier, flattenedCastsSQL,
+		c.metadataSchema, rawTableIdentifier, parsedDstTable.String(), flattenedCastsSQL,
 		strings.Join(primaryKeySelectSQLArray, " AND "), insertColumnsSQL, insertValuesSQL, updateStatements)
 }
 
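The hunks above route the destination table identifier through utils.ParseSchemaTable and splice parsedDstTable.String() into the generated fallback and merge SQL instead of the raw string. The actual ParseSchemaTable/String() implementation is not part of this diff, so the sketch below is only an assumption of what such a helper could look like; the names schemaTable and parseSchemaTable are hypothetical. The point it illustrates is that a parsed, double-quoted identifier survives mixed-case or reserved-word table names when embedded in SQL, whereas the raw identifier may not.

// A minimal sketch, assuming a "schema.table" input; not the real utils implementation.
package main

import (
	"fmt"
	"strings"
)

// schemaTable is a hypothetical stand-in for the value returned by utils.ParseSchemaTable.
type schemaTable struct {
	Schema string
	Table  string
}

// parseSchemaTable splits a "schema.table" identifier, defaulting to the public schema.
func parseSchemaTable(identifier string) (schemaTable, error) {
	parts := strings.Split(identifier, ".")
	switch len(parts) {
	case 1:
		return schemaTable{Schema: "public", Table: parts[0]}, nil
	case 2:
		return schemaTable{Schema: parts[0], Table: parts[1]}, nil
	default:
		return schemaTable{}, fmt.Errorf("invalid table identifier: %s", identifier)
	}
}

// String renders the identifier with each part double-quoted, so mixed-case or
// reserved-word names survive being spliced into generated SQL.
func (s schemaTable) String() string {
	quote := func(part string) string {
		return `"` + strings.ReplaceAll(part, `"`, `""`) + `"`
	}
	return quote(s.Schema) + "." + quote(s.Table)
}

func main() {
	parsed, err := parseSchemaTable("e2e_test.Mixed Case")
	if err != nil {
		panic(err)
	}
	fmt.Println(parsed.String()) // prints "e2e_test"."Mixed Case"
}
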
@@ -584,40 +587,6 @@ func (c *PostgresConnector) generateUpdateStatement(allCols []string, unchangedT
 	return strings.Join(updateStmts, "\n")
 }
 
-func (c *PostgresConnector) getApproxTableCounts(tables []string) (int64, error) {
-	countTablesBatch := &pgx.Batch{}
-	totalCount := int64(0)
-	for _, table := range tables {
-		parsedTable, err := utils.ParseSchemaTable(table)
-		if err != nil {
-			log.Errorf("error while parsing table %s: %v", table, err)
-			return 0, fmt.Errorf("error while parsing table %s: %w", table, err)
-		}
-		countTablesBatch.Queue(
-			fmt.Sprintf("SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = '%s'::regclass;",
-				parsedTable.String())).
-			QueryRow(func(row pgx.Row) error {
-				var count int64
-				err := row.Scan(&count)
-				if err != nil {
-					log.WithFields(log.Fields{
-						"table": table,
-					}).Errorf("error while scanning row: %v", err)
-					return fmt.Errorf("error while scanning row: %w", err)
-				}
-				totalCount += count
-				return nil
-			})
-	}
-	countTablesResults := c.pool.SendBatch(c.ctx, countTablesBatch)
-	err := countTablesResults.Close()
-	if err != nil {
-		log.Errorf("error while closing statement batch: %v", err)
-		return 0, fmt.Errorf("error while closing statement batch: %w", err)
-	}
-	return totalCount, nil
-}
-
 func (c *PostgresConnector) getCurrentLSN() (pglogrepl.LSN, error) {
 	row := c.pool.QueryRow(c.ctx, "SELECT pg_current_wal_lsn();")
 	var result string
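The removed getApproxTableCounts helper estimated row counts from the pg_class.reltuples statistic, batching one query per table with pgx. For reference, a standalone sketch of the same estimate for a single table is below; it assumes pgx v5's pgxpool and uses a made-up connection string and table name. reltuples is a planner statistic refreshed by (auto)VACUUM and ANALYZE, so the value is approximate and can lag behind recent writes.

// A standalone sketch of the reltuples-based estimate; connection string and table name are illustrative only.
package main

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v5/pgxpool"
)

// approxTableCount returns the planner's row estimate for one schema-qualified table.
// reltuples is a statistic, not an exact count, and may be stale or -1 for never-analyzed tables.
func approxTableCount(ctx context.Context, pool *pgxpool.Pool, table string) (int64, error) {
	var estimate int64
	err := pool.QueryRow(ctx,
		"SELECT reltuples::bigint FROM pg_class WHERE oid = $1::regclass",
		table,
	).Scan(&estimate)
	if err != nil {
		return 0, fmt.Errorf("error estimating row count for %s: %w", table, err)
	}
	return estimate, nil
}

func main() {
	ctx := context.Background()
	// Hypothetical connection string for illustration.
	pool, err := pgxpool.New(ctx, "postgres://postgres@localhost:5432/postgres")
	if err != nil {
		panic(err)
	}
	defer pool.Close()

	count, err := approxTableCount(ctx, pool, "public.events")
	if err != nil {
		panic(err)
	}
	fmt.Println("approximate rows:", count)
}
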
7 changes: 4 additions & 3 deletions flow/connectors/postgres/postgres.go
@@ -635,11 +635,11 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab
 	}()
 
 	for tableIdentifier, tableSchema := range req.TableNameSchemaMapping {
-		normalizedTableNameComponents, err := utils.ParseSchemaTable(tableIdentifier)
+		parsedNormalizedTable, err := utils.ParseSchemaTable(tableIdentifier)
 		if err != nil {
 			return nil, fmt.Errorf("error while parsing table schema and name: %w", err)
 		}
-		tableAlreadyExists, err := c.tableExists(normalizedTableNameComponents)
+		tableAlreadyExists, err := c.tableExists(parsedNormalizedTable)
 		if err != nil {
 			return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err)
 		}
@@ -649,7 +649,8 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab
 		}
 
 		// convert the column names and types to Postgres types
-		normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable(tableIdentifier, tableSchema)
+		normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable(
+			parsedNormalizedTable.String(), tableSchema)
 		_, err = createNormalizedTablesTx.Exec(c.ctx, normalizedTableCreateSQL)
 		if err != nil {
 			return nil, fmt.Errorf("error while creating normalized table: %w", err)
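This mirrors the client.go change: the CREATE TABLE statement generator now receives parsedNormalizedTable.String() rather than the raw identifier, so the normalized table name arrives already schema-qualified. generateCreateTableSQLForNormalizedTable itself is not shown in this diff; as a design note, pgx also provides an Identifier type that performs the same kind of quoting, shown in the short sketch below with made-up schema and table names.

// A short sketch of pgx's built-in identifier quoting; schema and table names are invented.
package main

import (
	"fmt"

	"github.com/jackc/pgx/v5"
)

func main() {
	// pgx.Identifier quotes each part and joins them with a dot,
	// doubling any embedded double quotes along the way.
	name := pgx.Identifier{"e2e_test", "Users"}.Sanitize()
	fmt.Println(name)                                              // "e2e_test"."Users"
	fmt.Printf("CREATE TABLE IF NOT EXISTS %s(id BIGINT)\n", name) // safe to embed in DDL
}
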
14 changes: 8 additions & 6 deletions flow/connectors/postgres/qrep_sync_method.go
@@ -115,21 +115,23 @@ func (s *QRepStagingTableSync) SyncQRepRecords(
 
 		// construct the SET clause for the upsert operation
 		upsertMatchColsList := writeMode.UpsertKeyColumns
-		upsertMatchCols := make(map[string]bool)
+		upsertMatchCols := make(map[string]struct{})
 		for _, col := range upsertMatchColsList {
-			upsertMatchCols[col] = true
+			upsertMatchCols[col] = struct{}{}
 		}
 
-		setClause := ""
+		setClauseArray := make([]string, 0)
+		selectStrArray := make([]string, 0)
 		for _, col := range schema.GetColumnNames() {
 			_, ok := upsertMatchCols[col]
 			if !ok {
-				setClause += fmt.Sprintf("%s = EXCLUDED.%s,", col, col)
+				setClauseArray = append(setClauseArray, fmt.Sprintf(`"%s" = EXCLUDED."%s"`, col, col))
 			}
+			selectStrArray = append(selectStrArray, fmt.Sprintf(`"%s"`, col))
 		}
 
-		setClause = strings.TrimSuffix(setClause, ",")
-		selectStr := strings.Join(schema.GetColumnNames(), ", ")
+		setClause := strings.Join(setClauseArray, ",")
+		selectStr := strings.Join(selectStrArray, ",")
 
 		// Step 2.3: Perform the upsert operation, ON CONFLICT UPDATE
 		upsertStmt := fmt.Sprintf(
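The rewritten clause construction replaces the map[string]bool with a zero-size map[string]struct{} set and builds both the SET clause and the select list as slices that are joined once at the end, with every column name double-quoted. A reduced, self-contained sketch of the same pattern, using illustrative column names, follows.

// A reduced sketch of the upsert clause construction after this change; column names are illustrative.
package main

import (
	"fmt"
	"strings"
)

// buildUpsertClauses builds the ON CONFLICT ... DO UPDATE SET clause and the
// quoted select list, skipping the upsert key columns in the SET clause.
func buildUpsertClauses(columns, keyColumns []string) (string, string) {
	keySet := make(map[string]struct{}, len(keyColumns))
	for _, col := range keyColumns {
		keySet[col] = struct{}{} // struct{} carries no data; the map acts as a set
	}

	setParts := make([]string, 0, len(columns))
	selectParts := make([]string, 0, len(columns))
	for _, col := range columns {
		if _, isKey := keySet[col]; !isKey {
			setParts = append(setParts, fmt.Sprintf(`"%s" = EXCLUDED."%s"`, col, col))
		}
		selectParts = append(selectParts, fmt.Sprintf(`"%s"`, col))
	}
	// Joining once at the end avoids the trailing-comma trimming the old code needed.
	return strings.Join(setParts, ","), strings.Join(selectParts, ",")
}

func main() {
	setClause, selectStr := buildUpsertClauses(
		[]string{"id", "updatedAt", "value"}, []string{"id"})
	fmt.Println("SET:", setClause)
	fmt.Println("SELECT:", selectStr)
}
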
