diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 51f5d025f5..2527a560c8 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -587,26 +587,31 @@ func (c *PostgresConnector) CheckSourceTables(tableNames []string, pubName strin } // Check that we can select from all tables - for _, tableName := range tableNames { + tableArr := make([]string, 0, len(tableNames)) + for _, table := range tableNames { var row pgx.Row - err := c.pool.QueryRow(c.ctx, fmt.Sprintf("SELECT * FROM %s LIMIT 0;", tableName)).Scan(&row) + schemaName, tableName, found := strings.Cut(table, ".") + if !found { + return fmt.Errorf("invalid source table identifier: %s", table) + } + + tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`, QuoteLiteral(schemaName), QuoteLiteral(tableName))) + err := c.pool.QueryRow(c.ctx, + fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;", QuoteIdentifier(schemaName), QuoteIdentifier(tableName))).Scan(&row) if err != nil && err != pgx.ErrNoRows { return err } } // Check if tables belong to publication - tableArr := make([]string, 0, len(tableNames)) - for _, tableName := range tableNames { - tableArr = append(tableArr, fmt.Sprintf("'%s'", tableName)) - } - tableStr := strings.Join(tableArr, ",") - if pubName != "" { var pubTableCount int - err := c.pool.QueryRow(c.ctx, fmt.Sprintf("select COUNT(DISTINCT(schemaname||'.'||tablename)) from pg_publication_tables "+ - "where schemaname||'.'||tablename in (%s) and pubname=$1;", tableStr), pubName).Scan(&pubTableCount) + err := c.pool.QueryRow(c.ctx, fmt.Sprintf(` + with source_table_components (sname, tname) as (values %s) + select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables + INNER JOIN source_table_components stc + ON schemaname=stc.sname and tablename=stc.tname where pubname=$1;`, tableStr), pubName).Scan(&pubTableCount) if err != nil { return err }