diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 82ba655988..f707a04347 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "regexp" - "strconv" "strings" "github.com/jackc/pglogrepl" @@ -17,7 +16,6 @@ import ( "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/numeric" - "github.com/PeerDB-io/peer-flow/shared" ) type PGVersion int @@ -627,148 +625,3 @@ func (c *PostgresConnector) getCurrentLSN(ctx context.Context) (pglogrepl.LSN, e func (c *PostgresConnector) getDefaultPublicationName(jobName string) string { return "peerflow_pub_" + jobName } - -func (c *PostgresConnector) CheckSourceTables(ctx context.Context, - tableNames []*utils.SchemaTable, pubName string, -) error { - if c.conn == nil { - return errors.New("check tables: conn is nil") - } - - // Check that we can select from all tables - tableArr := make([]string, 0, len(tableNames)) - for _, parsedTable := range tableNames { - var row pgx.Row - tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`, - QuoteLiteral(parsedTable.Schema), QuoteLiteral(parsedTable.Table))) - err := c.conn.QueryRow(ctx, - fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;", - QuoteIdentifier(parsedTable.Schema), QuoteIdentifier(parsedTable.Table))).Scan(&row) - if err != nil && err != pgx.ErrNoRows { - return err - } - } - - tableStr := strings.Join(tableArr, ",") - // Check if publication exists - err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil) - if err != nil { - if err == pgx.ErrNoRows { - return fmt.Errorf("publication does not exist: %s", pubName) - } - return fmt.Errorf("error while checking for publication existence: %w", err) - } - - // Check if tables belong to publication - var pubTableCount int - err = c.conn.QueryRow(ctx, fmt.Sprintf(` - with source_table_components (sname, tname) as (values %s) - select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables - INNER JOIN source_table_components stc - ON schemaname=stc.sname and tablename=stc.tname where pubname=$1;`, tableStr), pubName).Scan(&pubTableCount) - if err != nil { - return err - } - - if pubTableCount != len(tableNames) { - return errors.New("not all tables belong to publication") - } - - return nil -} - -func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, username string) error { - if c.conn == nil { - return errors.New("check replication permissions: conn is nil") - } - - var replicationRes bool - err := c.conn.QueryRow(ctx, "SELECT rolreplication FROM pg_roles WHERE rolname = $1", username).Scan(&replicationRes) - if err != nil { - return err - } - - if !replicationRes { - // RDS case: check pg_settings for rds.logical_replication - var setting string - err := c.conn.QueryRow(ctx, "SELECT setting FROM pg_settings WHERE name = 'rds.logical_replication'").Scan(&setting) - if err != nil || setting != "on" { - return errors.New("postgres user does not have replication role") - } - } - - // check wal_level - var walLevel string - err = c.conn.QueryRow(ctx, "SHOW wal_level").Scan(&walLevel) - if err != nil { - return err - } - - if walLevel != "logical" { - return errors.New("wal_level is not logical") - } - - // max_wal_senders must be at least 2 - var maxWalSendersRes string - err = c.conn.QueryRow(ctx, "SHOW max_wal_senders").Scan(&maxWalSendersRes) - if err != nil { - return err - } - - maxWalSenders, err := strconv.Atoi(maxWalSendersRes) - if err != nil { - return err - } - - if maxWalSenders < 2 { - return errors.New("max_wal_senders must be at least 2") - } - - return nil -} - -func (c *PostgresConnector) CheckPublicationPermission(ctx context.Context, tableNameString string) error { - publication := "_PEERDB_DUMMY_PUBLICATION_" + shared.RandomString(4) - // check and enable publish_via_partition_root - supportsPubViaRoot, _, err := c.MajorVersionCheck(ctx, POSTGRES_13) - if err != nil { - return fmt.Errorf("error checking Postgres version: %w", err) - } - var pubViaRootString string - if supportsPubViaRoot { - pubViaRootString = "WITH(publish_via_partition_root=true)" - } - tx, err := c.conn.Begin(ctx) - if err != nil { - return fmt.Errorf("error starting transaction: %w", err) - } - defer func() { - err := tx.Rollback(ctx) - if err != nil && err != pgx.ErrTxClosed { - c.logger.Error("[validate publication create] failed to rollback transaction", "error", err) - } - }() - - // Create the publication - createStmt := fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s %s", - publication, tableNameString, pubViaRootString) - _, err = tx.Exec(ctx, createStmt) - if err != nil { - return fmt.Errorf("it will not be possible to create a publication for selected tables: %w", err) - } - - // Drop the publication - dropStmt := "DROP PUBLICATION IF EXISTS " + publication - _, err = tx.Exec(ctx, dropStmt) - if err != nil { - return fmt.Errorf("it will not be possible to drop the publication created for this mirror: %w", - err) - } - - // commit transaction - err = tx.Commit(ctx) - if err != nil { - return fmt.Errorf("unable to validate publication create permission: %w", err) - } - return nil -} diff --git a/flow/connectors/postgres/validate.go b/flow/connectors/postgres/validate.go index 035623cab8..f31a3103f8 100644 --- a/flow/connectors/postgres/validate.go +++ b/flow/connectors/postgres/validate.go @@ -8,55 +8,55 @@ import ( "strings" "github.com/jackc/pgx/v5" + + "github.com/PeerDB-io/peer-flow/connectors/utils" + "github.com/PeerDB-io/peer-flow/shared" ) -func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames []string, pubName string) error { +func (c *PostgresConnector) CheckSourceTables(ctx context.Context, + tableNames []*utils.SchemaTable, pubName string, +) error { if c.conn == nil { return errors.New("check tables: conn is nil") } // Check that we can select from all tables tableArr := make([]string, 0, len(tableNames)) - for _, table := range tableNames { + for _, parsedTable := range tableNames { var row pgx.Row - schemaName, tableName, found := strings.Cut(table, ".") - if !found { - return fmt.Errorf("invalid source table identifier: %s", table) - } - - tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`, QuoteLiteral(schemaName), QuoteLiteral(tableName))) + tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`, + QuoteLiteral(parsedTable.Schema), QuoteLiteral(parsedTable.Table))) err := c.conn.QueryRow(ctx, - fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;", QuoteIdentifier(schemaName), QuoteIdentifier(tableName))).Scan(&row) + fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;", + QuoteIdentifier(parsedTable.Schema), QuoteIdentifier(parsedTable.Table))).Scan(&row) if err != nil && err != pgx.ErrNoRows { return err } } tableStr := strings.Join(tableArr, ",") - if pubName != "" { - // Check if publication exists - err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil) - if err != nil { - if err == pgx.ErrNoRows { - return fmt.Errorf("publication does not exist: %s", pubName) - } - return fmt.Errorf("error while checking for publication existence: %w", err) + // Check if publication exists + err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil) + if err != nil { + if err == pgx.ErrNoRows { + return fmt.Errorf("publication does not exist: %s", pubName) } + return fmt.Errorf("error while checking for publication existence: %w", err) + } - // Check if tables belong to publication - var pubTableCount int - err = c.conn.QueryRow(ctx, fmt.Sprintf(` + // Check if tables belong to publication + var pubTableCount int + err = c.conn.QueryRow(ctx, fmt.Sprintf(` with source_table_components (sname, tname) as (values %s) select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables INNER JOIN source_table_components stc ON schemaname=stc.sname and tablename=stc.tname where pubname=$1;`, tableStr), pubName).Scan(&pubTableCount) - if err != nil { - return err - } + if err != nil { + return err + } - if pubTableCount != len(tableNames) { - return errors.New("not all tables belong to publication") - } + if pubTableCount != len(tableNames) { + return errors.New("not all tables belong to publication") } return nil @@ -112,6 +112,52 @@ func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, use return nil } +func (c *PostgresConnector) CheckPublicationPermission(ctx context.Context, tableNameString string) error { + publication := "_PEERDB_DUMMY_PUBLICATION_" + shared.RandomString(4) + // check and enable publish_via_partition_root + supportsPubViaRoot, _, err := c.MajorVersionCheck(ctx, POSTGRES_13) + if err != nil { + return fmt.Errorf("error checking Postgres version: %w", err) + } + var pubViaRootString string + if supportsPubViaRoot { + pubViaRootString = "WITH(publish_via_partition_root=true)" + } + tx, err := c.conn.Begin(ctx) + if err != nil { + return fmt.Errorf("error starting transaction: %w", err) + } + defer func() { + err := tx.Rollback(ctx) + if err != nil && err != pgx.ErrTxClosed { + c.logger.Error("[validate publication create] failed to rollback transaction", "error", err) + } + }() + + // Create the publication + createStmt := fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s %s", + publication, tableNameString, pubViaRootString) + _, err = tx.Exec(ctx, createStmt) + if err != nil { + return fmt.Errorf("it will not be possible to create a publication for selected tables: %w", err) + } + + // Drop the publication + dropStmt := "DROP PUBLICATION IF EXISTS " + publication + _, err = tx.Exec(ctx, dropStmt) + if err != nil { + return fmt.Errorf("it will not be possible to drop the publication created for this mirror: %w", + err) + } + + // commit transaction + err = tx.Commit(ctx) + if err != nil { + return fmt.Errorf("unable to validate publication create permission: %w", err) + } + return nil +} + func (c *PostgresConnector) CheckReplicationConnectivity(ctx context.Context) error { // Check if we can create a replication connection conn, err := c.CreateReplConn(ctx)