Skip to content

Commit

Permalink
move to validate.go
Browse files Browse the repository at this point in the history
  • Loading branch information
Amogh-Bharadwaj committed Feb 15, 2024
1 parent 467f331 commit e696b8f
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 173 deletions.
147 changes: 0 additions & 147 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"regexp"
"strconv"
"strings"

"github.com/jackc/pglogrepl"
Expand All @@ -17,7 +16,6 @@ import (
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/model/numeric"
"github.com/PeerDB-io/peer-flow/shared"
)

type PGVersion int
Expand Down Expand Up @@ -627,148 +625,3 @@ func (c *PostgresConnector) getCurrentLSN(ctx context.Context) (pglogrepl.LSN, e
func (c *PostgresConnector) getDefaultPublicationName(jobName string) string {
return "peerflow_pub_" + jobName
}

func (c *PostgresConnector) CheckSourceTables(ctx context.Context,
tableNames []*utils.SchemaTable, pubName string,
) error {
if c.conn == nil {
return errors.New("check tables: conn is nil")
}

// Check that we can select from all tables
tableArr := make([]string, 0, len(tableNames))
for _, parsedTable := range tableNames {
var row pgx.Row
tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`,
QuoteLiteral(parsedTable.Schema), QuoteLiteral(parsedTable.Table)))
err := c.conn.QueryRow(ctx,
fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;",
QuoteIdentifier(parsedTable.Schema), QuoteIdentifier(parsedTable.Table))).Scan(&row)
if err != nil && err != pgx.ErrNoRows {
return err
}
}

tableStr := strings.Join(tableArr, ",")
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
}

// Check if tables belong to publication
var pubTableCount int
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
with source_table_components (sname, tname) as (values %s)
select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables
INNER JOIN source_table_components stc
ON schemaname=stc.sname and tablename=stc.tname where pubname=$1;`, tableStr), pubName).Scan(&pubTableCount)
if err != nil {
return err
}

if pubTableCount != len(tableNames) {
return errors.New("not all tables belong to publication")
}

return nil
}

func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, username string) error {
if c.conn == nil {
return errors.New("check replication permissions: conn is nil")
}

var replicationRes bool
err := c.conn.QueryRow(ctx, "SELECT rolreplication FROM pg_roles WHERE rolname = $1", username).Scan(&replicationRes)
if err != nil {
return err
}

if !replicationRes {
// RDS case: check pg_settings for rds.logical_replication
var setting string
err := c.conn.QueryRow(ctx, "SELECT setting FROM pg_settings WHERE name = 'rds.logical_replication'").Scan(&setting)
if err != nil || setting != "on" {
return errors.New("postgres user does not have replication role")
}
}

// check wal_level
var walLevel string
err = c.conn.QueryRow(ctx, "SHOW wal_level").Scan(&walLevel)
if err != nil {
return err
}

if walLevel != "logical" {
return errors.New("wal_level is not logical")
}

// max_wal_senders must be at least 2
var maxWalSendersRes string
err = c.conn.QueryRow(ctx, "SHOW max_wal_senders").Scan(&maxWalSendersRes)
if err != nil {
return err
}

maxWalSenders, err := strconv.Atoi(maxWalSendersRes)
if err != nil {
return err
}

if maxWalSenders < 2 {
return errors.New("max_wal_senders must be at least 2")
}

return nil
}

func (c *PostgresConnector) CheckPublicationPermission(ctx context.Context, tableNameString string) error {
publication := "_PEERDB_DUMMY_PUBLICATION_" + shared.RandomString(4)
// check and enable publish_via_partition_root
supportsPubViaRoot, _, err := c.MajorVersionCheck(ctx, POSTGRES_13)
if err != nil {
return fmt.Errorf("error checking Postgres version: %w", err)
}
var pubViaRootString string
if supportsPubViaRoot {
pubViaRootString = "WITH(publish_via_partition_root=true)"
}
tx, err := c.conn.Begin(ctx)
if err != nil {
return fmt.Errorf("error starting transaction: %w", err)
}
defer func() {
err := tx.Rollback(ctx)
if err != nil && err != pgx.ErrTxClosed {
c.logger.Error("[validate publication create] failed to rollback transaction", "error", err)
}
}()

// Create the publication
createStmt := fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s %s",
publication, tableNameString, pubViaRootString)
_, err = tx.Exec(ctx, createStmt)
if err != nil {
return fmt.Errorf("it will not be possible to create a publication for selected tables: %w", err)
}

// Drop the publication
dropStmt := "DROP PUBLICATION IF EXISTS " + publication
_, err = tx.Exec(ctx, dropStmt)
if err != nil {
return fmt.Errorf("it will not be possible to drop the publication created for this mirror: %w",
err)
}

// commit transaction
err = tx.Commit(ctx)
if err != nil {
return fmt.Errorf("unable to validate publication create permission: %w", err)
}
return nil
}
98 changes: 72 additions & 26 deletions flow/connectors/postgres/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,55 +8,55 @@ import (
"strings"

"github.com/jackc/pgx/v5"

"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/shared"
)

func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames []string, pubName string) error {
func (c *PostgresConnector) CheckSourceTables(ctx context.Context,
tableNames []*utils.SchemaTable, pubName string,
) error {
if c.conn == nil {
return errors.New("check tables: conn is nil")
}

// Check that we can select from all tables
tableArr := make([]string, 0, len(tableNames))
for _, table := range tableNames {
for _, parsedTable := range tableNames {
var row pgx.Row
schemaName, tableName, found := strings.Cut(table, ".")
if !found {
return fmt.Errorf("invalid source table identifier: %s", table)
}

tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`, QuoteLiteral(schemaName), QuoteLiteral(tableName)))
tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`,
QuoteLiteral(parsedTable.Schema), QuoteLiteral(parsedTable.Table)))
err := c.conn.QueryRow(ctx,
fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;", QuoteIdentifier(schemaName), QuoteIdentifier(tableName))).Scan(&row)
fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;",
QuoteIdentifier(parsedTable.Schema), QuoteIdentifier(parsedTable.Table))).Scan(&row)
if err != nil && err != pgx.ErrNoRows {
return err
}
}

tableStr := strings.Join(tableArr, ",")
if pubName != "" {
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
}

// Check if tables belong to publication
var pubTableCount int
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
// Check if tables belong to publication
var pubTableCount int
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
with source_table_components (sname, tname) as (values %s)
select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables
INNER JOIN source_table_components stc
ON schemaname=stc.sname and tablename=stc.tname where pubname=$1;`, tableStr), pubName).Scan(&pubTableCount)
if err != nil {
return err
}
if err != nil {
return err
}

if pubTableCount != len(tableNames) {
return errors.New("not all tables belong to publication")
}
if pubTableCount != len(tableNames) {
return errors.New("not all tables belong to publication")
}

return nil
Expand Down Expand Up @@ -112,6 +112,52 @@ func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, use
return nil
}

func (c *PostgresConnector) CheckPublicationPermission(ctx context.Context, tableNameString string) error {
publication := "_PEERDB_DUMMY_PUBLICATION_" + shared.RandomString(4)
// check and enable publish_via_partition_root
supportsPubViaRoot, _, err := c.MajorVersionCheck(ctx, POSTGRES_13)
if err != nil {
return fmt.Errorf("error checking Postgres version: %w", err)
}
var pubViaRootString string
if supportsPubViaRoot {
pubViaRootString = "WITH(publish_via_partition_root=true)"
}
tx, err := c.conn.Begin(ctx)
if err != nil {
return fmt.Errorf("error starting transaction: %w", err)
}
defer func() {
err := tx.Rollback(ctx)
if err != nil && err != pgx.ErrTxClosed {
c.logger.Error("[validate publication create] failed to rollback transaction", "error", err)
}
}()

// Create the publication
createStmt := fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s %s",
publication, tableNameString, pubViaRootString)
_, err = tx.Exec(ctx, createStmt)
if err != nil {
return fmt.Errorf("it will not be possible to create a publication for selected tables: %w", err)
}

// Drop the publication
dropStmt := "DROP PUBLICATION IF EXISTS " + publication
_, err = tx.Exec(ctx, dropStmt)
if err != nil {
return fmt.Errorf("it will not be possible to drop the publication created for this mirror: %w",
err)
}

// commit transaction
err = tx.Commit(ctx)
if err != nil {
return fmt.Errorf("unable to validate publication create permission: %w", err)
}
return nil
}

func (c *PostgresConnector) CheckReplicationConnectivity(ctx context.Context) error {
// Check if we can create a replication connection
conn, err := c.CreateReplConn(ctx)
Expand Down

0 comments on commit e696b8f

Please sign in to comment.