Skip to content

Commit

Permalink
support mixed case table names pg->pg
Browse files Browse the repository at this point in the history
  • Loading branch information
iskakaushik committed Nov 27, 2023
1 parent c8afa19 commit c92bffa
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 40 deletions.
43 changes: 6 additions & 37 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,14 +508,15 @@ func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifie
destinationTableIdentifier, columnName, columnCast))
}
deleteWhereClauseSQL := strings.TrimSuffix(strings.Join(deleteWhereClauseArray, ""), "AND ")
parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier)

fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL,
strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), c.metadataSchema,
rawTableIdentifier, destinationTableIdentifier, insertColumnsSQL, flattenedCastsSQL,
rawTableIdentifier, parsedDstTable.String(), insertColumnsSQL, flattenedCastsSQL,
strings.Join(normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL)
fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL,
strings.Join(maps.Values(primaryKeyColumnCasts), ","), c.metadataSchema,
rawTableIdentifier, destinationTableIdentifier, deleteWhereClauseSQL)
rawTableIdentifier, parsedDstTable.String(), deleteWhereClauseSQL)

return []string{fallbackUpsertStatement, fallbackDeleteStatement}
}
Expand All @@ -529,6 +530,8 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier st
}

flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier)

primaryKeyColumnCasts := make(map[string]string)
primaryKeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
for columnName, genericColumnType := range normalizedTableSchema.Columns {
Expand Down Expand Up @@ -558,7 +561,7 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier st
updateStatements := c.generateUpdateStatement(columnNames, unchangedToastColumns)

return fmt.Sprintf(mergeStatementSQL, strings.Join(maps.Values(primaryKeyColumnCasts), ","),
c.metadataSchema, rawTableIdentifier, destinationTableIdentifier, flattenedCastsSQL,
c.metadataSchema, rawTableIdentifier, parsedDstTable.String(), flattenedCastsSQL,
strings.Join(primaryKeySelectSQLArray, " AND "), insertColumnsSQL, insertValuesSQL, updateStatements)
}

Expand All @@ -584,40 +587,6 @@ func (c *PostgresConnector) generateUpdateStatement(allCols []string, unchangedT
return strings.Join(updateStmts, "\n")
}

func (c *PostgresConnector) getApproxTableCounts(tables []string) (int64, error) {
countTablesBatch := &pgx.Batch{}
totalCount := int64(0)
for _, table := range tables {
parsedTable, err := utils.ParseSchemaTable(table)
if err != nil {
log.Errorf("error while parsing table %s: %v", table, err)
return 0, fmt.Errorf("error while parsing table %s: %w", table, err)
}
countTablesBatch.Queue(
fmt.Sprintf("SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = '%s'::regclass;",
parsedTable.String())).
QueryRow(func(row pgx.Row) error {
var count int64
err := row.Scan(&count)
if err != nil {
log.WithFields(log.Fields{
"table": table,
}).Errorf("error while scanning row: %v", err)
return fmt.Errorf("error while scanning row: %w", err)
}
totalCount += count
return nil
})
}
countTablesResults := c.pool.SendBatch(c.ctx, countTablesBatch)
err := countTablesResults.Close()
if err != nil {
log.Errorf("error while closing statement batch: %v", err)
return 0, fmt.Errorf("error while closing statement batch: %w", err)
}
return totalCount, nil
}

func (c *PostgresConnector) getCurrentLSN() (pglogrepl.LSN, error) {
row := c.pool.QueryRow(c.ctx, "SELECT pg_current_wal_lsn();")
var result string
Expand Down
7 changes: 4 additions & 3 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -635,11 +635,11 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab
}()

for tableIdentifier, tableSchema := range req.TableNameSchemaMapping {
normalizedTableNameComponents, err := utils.ParseSchemaTable(tableIdentifier)
parsedNormalizedTable, err := utils.ParseSchemaTable(tableIdentifier)
if err != nil {
return nil, fmt.Errorf("error while parsing table schema and name: %w", err)
}
tableAlreadyExists, err := c.tableExists(normalizedTableNameComponents)
tableAlreadyExists, err := c.tableExists(parsedNormalizedTable)
if err != nil {
return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err)
}
Expand All @@ -649,7 +649,8 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab
}

// convert the column names and types to Postgres types
normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable(tableIdentifier, tableSchema)
normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable(
parsedNormalizedTable.String(), tableSchema)
_, err = createNormalizedTablesTx.Exec(c.ctx, normalizedTableCreateSQL)
if err != nil {
return nil, fmt.Errorf("error while creating normalized table: %w", err)
Expand Down

0 comments on commit c92bffa

Please sign in to comment.