Skip to content

Commit

Permalink
Validate mirror: account for mixed case (#1123)
Browse files Browse the repository at this point in the history
Forgot to quote the source table identifiers we use in validate mirror's
source table check

---------

Co-authored-by: Philip Dubé <[email protected]>
  • Loading branch information
Amogh-Bharadwaj and serprex authored Jan 22, 2024
1 parent 6ef3707 commit 785facf
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -587,26 +587,31 @@ func (c *PostgresConnector) CheckSourceTables(tableNames []string, pubName strin
}

// Check that we can select from all tables
for _, tableName := range tableNames {
tableArr := make([]string, 0, len(tableNames))
for _, table := range tableNames {
var row pgx.Row
err := c.pool.QueryRow(c.ctx, fmt.Sprintf("SELECT * FROM %s LIMIT 0;", tableName)).Scan(&row)
schemaName, tableName, found := strings.Cut(table, ".")
if !found {
return fmt.Errorf("invalid source table identifier: %s", table)
}

tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`, QuoteLiteral(schemaName), QuoteLiteral(tableName)))
err := c.pool.QueryRow(c.ctx,
fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;", QuoteIdentifier(schemaName), QuoteIdentifier(tableName))).Scan(&row)
if err != nil && err != pgx.ErrNoRows {
return err
}
}

// Check if tables belong to publication
tableArr := make([]string, 0, len(tableNames))
for _, tableName := range tableNames {
tableArr = append(tableArr, fmt.Sprintf("'%s'", tableName))
}

tableStr := strings.Join(tableArr, ",")

if pubName != "" {
var pubTableCount int
err := c.pool.QueryRow(c.ctx, fmt.Sprintf("select COUNT(DISTINCT(schemaname||'.'||tablename)) from pg_publication_tables "+
"where schemaname||'.'||tablename in (%s) and pubname=$1;", tableStr), pubName).Scan(&pubTableCount)
err := c.pool.QueryRow(c.ctx, fmt.Sprintf(`
with source_table_components (sname, tname) as (values %s)
select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables
INNER JOIN source_table_components stc
ON schemaname=stc.sname and tablename=stc.tname where pubname=$1;`, tableStr), pubName).Scan(&pubTableCount)
if err != nil {
return err
}
Expand Down

0 comments on commit 785facf

Please sign in to comment.