diff --git a/flow/cmd/validate_mirror.go b/flow/cmd/validate_mirror.go index 848ce4138a..75c46532fb 100644 --- a/flow/cmd/validate_mirror.go +++ b/flow/cmd/validate_mirror.go @@ -4,8 +4,10 @@ import ( "context" "fmt" "log/slog" + "strings" connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" + "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" ) @@ -41,16 +43,38 @@ func (h *FlowRequestHandler) ValidateCDCMirror( } // Check source tables - sourceTables := make([]string, 0, len(req.ConnectionConfigs.TableMappings)) + sourceTables := make([]*utils.SchemaTable, 0, len(req.ConnectionConfigs.TableMappings)) for _, tableMapping := range req.ConnectionConfigs.TableMappings { - sourceTables = append(sourceTables, tableMapping.SourceTableIdentifier) + parsedTable, parseErr := utils.ParseSchemaTable(tableMapping.SourceTableIdentifier) + if parseErr != nil { + return &protos.ValidateCDCMirrorResponse{ + Ok: false, + }, fmt.Errorf("invalid source table identifier: %s", tableMapping.SourceTableIdentifier) + } + + sourceTables = append(sourceTables, parsedTable) } - err = pgPeer.CheckSourceTables(ctx, sourceTables, req.ConnectionConfigs.PublicationName) - if err != nil { - return &protos.ValidateCDCMirrorResponse{ - Ok: false, - }, fmt.Errorf("provided source tables invalidated: %v", err) + pubName := req.ConnectionConfigs.PublicationName + if pubName == "" { + pubTables := make([]string, 0, len(sourceTables)) + for _, table := range sourceTables { + pubTables = append(pubTables, table.String()) + } + pubTableStr := strings.Join(pubTables, ", ") + pubErr := pgPeer.CheckPublicationPermission(ctx, pubTableStr) + if pubErr != nil { + return &protos.ValidateCDCMirrorResponse{ + Ok: false, + }, fmt.Errorf("failed to check publication permission: %v", pubErr) + } + } else { + err = pgPeer.CheckSourceTables(ctx, sourceTables, req.ConnectionConfigs.PublicationName) + if err != nil { + return &protos.ValidateCDCMirrorResponse{ 
+ Ok: false, + }, fmt.Errorf("provided source tables invalidated: %v", err) + } } return &protos.ValidateCDCMirrorResponse{ diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 563e0c6824..1d2cfe17e0 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -17,6 +17,7 @@ import ( "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/numeric" + "github.com/PeerDB-io/peer-flow/shared" ) type PGVersion int @@ -261,9 +262,9 @@ func (c *PostgresConnector) checkSlotAndPublication(ctx context.Context, slot st func getSlotInfo(ctx context.Context, conn *pgx.Conn, slotName string, database string) ([]*protos.SlotInfo, error) { var whereClause string if slotName != "" { - whereClause = fmt.Sprintf("WHERE slot_name=%s", QuoteLiteral(slotName)) + whereClause = "WHERE slot_name=" + QuoteLiteral(slotName) } else { - whereClause = fmt.Sprintf("WHERE database=%s", QuoteLiteral(database)) + whereClause = "WHERE database=" + QuoteLiteral(database) } hasWALStatus, _, err := majorVersionCheck(ctx, conn, POSTGRES_13) @@ -450,12 +451,12 @@ func generateCreateTableSQLForNormalizedTable( if softDeleteColName != "" { createTableSQLArray = append(createTableSQLArray, - fmt.Sprintf(`%s BOOL DEFAULT FALSE`, QuoteIdentifier(softDeleteColName))) + QuoteIdentifier(softDeleteColName)+` BOOL DEFAULT FALSE`) } if syncedAtColName != "" { createTableSQLArray = append(createTableSQLArray, - fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP`, QuoteIdentifier(syncedAtColName))) + QuoteIdentifier(syncedAtColName)+` TIMESTAMP DEFAULT CURRENT_TIMESTAMP`) } // add composite primary key to the table @@ -624,56 +625,53 @@ func (c *PostgresConnector) getCurrentLSN(ctx context.Context) (pglogrepl.LSN, e } func (c *PostgresConnector) getDefaultPublicationName(jobName string) string { - return fmt.Sprintf("peerflow_pub_%s", jobName) + return "peerflow_pub_" + 
jobName } -func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames []string, pubName string) error { +func (c *PostgresConnector) CheckSourceTables(ctx context.Context, + tableNames []*utils.SchemaTable, pubName string, +) error { if c.conn == nil { return errors.New("check tables: conn is nil") } // Check that we can select from all tables tableArr := make([]string, 0, len(tableNames)) - for _, table := range tableNames { + for _, parsedTable := range tableNames { var row pgx.Row - schemaName, tableName, found := strings.Cut(table, ".") - if !found { - return fmt.Errorf("invalid source table identifier: %s", table) - } - - tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`, QuoteLiteral(schemaName), QuoteLiteral(tableName))) + tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`, + QuoteLiteral(parsedTable.Schema), QuoteLiteral(parsedTable.Table))) err := c.conn.QueryRow(ctx, - fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;", QuoteIdentifier(schemaName), QuoteIdentifier(tableName))).Scan(&row) + fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;", + QuoteIdentifier(parsedTable.Schema), QuoteIdentifier(parsedTable.Table))).Scan(&row) if err != nil && err != pgx.ErrNoRows { return err } } tableStr := strings.Join(tableArr, ",") - if pubName != "" { - // Check if publication exists - err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil) - if err != nil { - if err == pgx.ErrNoRows { - return fmt.Errorf("publication does not exist: %s", pubName) - } - return fmt.Errorf("error while checking for publication existence: %w", err) + // Check if publication exists + err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil) + if err != nil { + if err == pgx.ErrNoRows { + return fmt.Errorf("publication does not exist: %s", pubName) } + return fmt.Errorf("error while checking for publication existence: %w", err) + } - // Check if tables belong to 
publication - var pubTableCount int - err = c.conn.QueryRow(ctx, fmt.Sprintf(` + // Check if tables belong to publication + var pubTableCount int + err = c.conn.QueryRow(ctx, fmt.Sprintf(` with source_table_components (sname, tname) as (values %s) select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables INNER JOIN source_table_components stc ON schemaname=stc.sname and tablename=stc.tname where pubname=$1;`, tableStr), pubName).Scan(&pubTableCount) - if err != nil { - return err - } + if err != nil { + return err + } - if pubTableCount != len(tableNames) { - return errors.New("not all tables belong to publication") - } + if pubTableCount != len(tableNames) { + return errors.New("not all tables belong to publication") } return nil @@ -681,7 +679,7 @@ func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames [] func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, username string) error { if c.conn == nil { - return fmt.Errorf("check replication permissions: conn is nil") + return errors.New("check replication permissions: conn is nil") } var replicationRes bool @@ -695,7 +693,7 @@ func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, use var setting string err := c.conn.QueryRow(ctx, "SELECT setting FROM pg_settings WHERE name = 'rds.logical_replication'").Scan(&setting) if err != nil || setting != "on" { - return fmt.Errorf("postgres user does not have replication role") + return errors.New("postgres user does not have replication role") } } @@ -707,7 +705,7 @@ func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, use } if walLevel != "logical" { - return fmt.Errorf("wal_level is not logical") + return errors.New("wal_level is not logical") } // max_wal_senders must be at least 2 @@ -723,8 +721,59 @@ func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, use } if maxWalSenders < 2 { - return fmt.Errorf("max_wal_senders must be at least 2") + 
return errors.New("max_wal_senders must be at least 2") } return nil } + +// CheckPublicationPermission checks that a publication can be created for the tables in tableNameString +// (interpolated verbatim after FOR TABLE) by creating and dropping a randomly named one in a transaction. +func (c *PostgresConnector) CheckPublicationPermission(ctx context.Context, tableNameString string) error { + if c.conn == nil { + return errors.New("check publication permission: conn is nil") + } + publication := "_PEERDB_DUMMY_PUBLICATION_" + shared.RandomString(4) + // check and enable publish_via_partition_root + supportsPubViaRoot, _, err := c.MajorVersionCheck(ctx, POSTGRES_13) + if err != nil { + return fmt.Errorf("error checking Postgres version: %w", err) + } + var pubViaRootString string + if supportsPubViaRoot { + pubViaRootString = "WITH(publish_via_partition_root=true)" + } + tx, err := c.conn.Begin(ctx) + if err != nil { + return fmt.Errorf("error starting transaction: %w", err) + } + defer func() { + err := tx.Rollback(ctx) + if err != nil && err != pgx.ErrTxClosed { + c.logger.Error("[validate publication create] failed to rollback transaction", "error", err) + } + }() + + // Create the publication + createStmt := fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s %s", + publication, tableNameString, pubViaRootString) + _, err = tx.Exec(ctx, createStmt) + if err != nil { + return fmt.Errorf("it will not be possible to create a publication for selected tables: %w", err) + } + + // Drop the publication + dropStmt := "DROP PUBLICATION IF EXISTS " + publication + _, err = tx.Exec(ctx, dropStmt) + if err != nil { + return fmt.Errorf("it will not be possible to drop the publication created for this mirror: %w", + err) + } + + // commit transaction + err = tx.Commit(ctx) + if err != nil { + return fmt.Errorf("unable to validate publication create permission: %w", err) + } + return nil +}