diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 8c582b7b05..f9a4c5126e 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -602,7 +602,7 @@ func (c *PostgresConnector) getDefaultPublicationName(jobName string) string { func (c *PostgresConnector) CheckSourceTables(tableNames []string, pubName string) error { if c.conn == nil { - return fmt.Errorf("check tables: conn is nil") + return errors.New("check tables: conn is nil") } // Check that we can select from all tables @@ -625,8 +625,18 @@ func (c *PostgresConnector) CheckSourceTables(tableNames []string, pubName strin // Check if tables belong to publication tableStr := strings.Join(tableArr, ",") if pubName != "" { + // Check if publication exists + err := c.conn.QueryRow(c.ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil) + if err != nil { + if err == pgx.ErrNoRows { + return fmt.Errorf("publication does not exist: %s", pubName) + } + return fmt.Errorf("error while checking for publication existence: %w", err) + } + + // Check if tables belong to publication var pubTableCount int - err := c.conn.QueryRow(c.ctx, fmt.Sprintf(` + err = c.conn.QueryRow(c.ctx, fmt.Sprintf(` with source_table_components (sname, tname) as (values %s) select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables INNER JOIN source_table_components stc @@ -636,7 +646,7 @@ func (c *PostgresConnector) CheckSourceTables(tableNames []string, pubName strin } if pubTableCount != len(tableNames) { - return fmt.Errorf("not all tables belong to publication") + return errors.New("not all tables belong to publication") } } diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 5138fea606..a964c66584 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -995,7 +995,7 @@ func (c *PostgresConnector) AddTablesToPublication(req *protos.AddTablesToPublic // just check if we have all the tables already in the publication if req.PublicationName != "" { rows, err := c.conn.Query(c.ctx, - "SELECT tablename FROM pg_publication_tables WHERE pubname=$1", req.PublicationName) + "SELECT schemaname || '.' || tablename FROM pg_publication_tables WHERE pubname=$1", req.PublicationName) if err != nil { return fmt.Errorf("failed to check tables in publication: %w", err) } @@ -1004,18 +1004,29 @@ func (c *PostgresConnector) AddTablesToPublication(req *protos.AddTablesToPublic if err != nil { return fmt.Errorf("failed to check tables in publication: %w", err) } - notPresentTables := utils.ArrayMinus(tableNames, additionalSrcTables) + notPresentTables := utils.ArrayMinus(additionalSrcTables, tableNames) if len(notPresentTables) > 0 { return fmt.Errorf("some additional tables not present in custom publication: %s", strings.Join(notPresentTables, ", ")) } + } else { + for _, additionalSrcTable := range additionalSrcTables { + schemaTable, err := utils.ParseSchemaTable(additionalSrcTable) + if err != nil { + return err + } + _, err = c.conn.Exec(c.ctx, fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s", + utils.QuoteIdentifier(c.getDefaultPublicationName(req.FlowJobName)), + schemaTable.String())) + // don't error out if table is already added to our publication + if err != nil && !strings.Contains(err.Error(), "SQLSTATE 42710") { + return fmt.Errorf("failed to alter publication: %w", err) + } + c.logger.Info("added table to publication", + slog.String("publication", c.getDefaultPublicationName(req.FlowJobName)), + slog.String("table", additionalSrcTable)) + } } - additionalSrcTablesString := strings.Join(additionalSrcTables, ",") - _, err := c.conn.Exec(c.ctx, fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s", - c.getDefaultPublicationName(req.FlowJobName), additionalSrcTablesString)) - if err != nil { - return fmt.Errorf("failed to alter publication: %w", err) - } return nil }