Skip to content

Commit

Permalink
add more publication validation
Browse files Browse the repository at this point in the history
  • Loading branch information
Amogh-Bharadwaj committed Feb 15, 2024
1 parent 59d7ee2 commit cc617dc
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 42 deletions.
38 changes: 31 additions & 7 deletions flow/cmd/validate_mirror.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"context"
"fmt"
"log/slog"
"strings"

connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
)

Expand Down Expand Up @@ -41,16 +43,38 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
}

// Check source tables
sourceTables := make([]string, 0, len(req.ConnectionConfigs.TableMappings))
sourceTables := make([]*utils.SchemaTable, 0, len(req.ConnectionConfigs.TableMappings))
for _, tableMapping := range req.ConnectionConfigs.TableMappings {
sourceTables = append(sourceTables, tableMapping.SourceTableIdentifier)
parsedTable, parseErr := utils.ParseSchemaTable(tableMapping.SourceTableIdentifier)
if parseErr != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("invalid source table identifier: %s", tableMapping.SourceTableIdentifier)
}

sourceTables = append(sourceTables, parsedTable)
}

err = pgPeer.CheckSourceTables(ctx, sourceTables, req.ConnectionConfigs.PublicationName)
if err != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("provided source tables invalidated: %v", err)
pubName := req.ConnectionConfigs.PublicationName
if pubName == "" {
pubTables := make([]string, 0, len(sourceTables))
for _, table := range sourceTables {
pubTables = append(pubTables, table.String())
}
pubTableStr := strings.Join(pubTables, ", ")
pubErr := pgPeer.CheckPublicationPermission(ctx, pubTableStr)
if pubErr != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("failed to check publication permission: %v", pubErr)
}
} else {
err = pgPeer.CheckSourceTables(ctx, sourceTables, req.ConnectionConfigs.PublicationName)
if err != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("provided source tables invalidated: %v", err)
}
}

return &protos.ValidateCDCMirrorResponse{
Expand Down
114 changes: 79 additions & 35 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/model/numeric"
"github.com/PeerDB-io/peer-flow/shared"
)

type PGVersion int
Expand Down Expand Up @@ -261,9 +262,9 @@ func (c *PostgresConnector) checkSlotAndPublication(ctx context.Context, slot st
func getSlotInfo(ctx context.Context, conn *pgx.Conn, slotName string, database string) ([]*protos.SlotInfo, error) {
var whereClause string
if slotName != "" {
whereClause = fmt.Sprintf("WHERE slot_name=%s", QuoteLiteral(slotName))
whereClause = "WHERE slot_name=" + QuoteLiteral(slotName)
} else {
whereClause = fmt.Sprintf("WHERE database=%s", QuoteLiteral(database))
whereClause = "WHERE database=" + QuoteLiteral(database)
}

hasWALStatus, _, err := majorVersionCheck(ctx, conn, POSTGRES_13)
Expand Down Expand Up @@ -450,12 +451,12 @@ func generateCreateTableSQLForNormalizedTable(

if softDeleteColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`%s BOOL DEFAULT FALSE`, QuoteIdentifier(softDeleteColName)))
QuoteIdentifier(softDeleteColName)+`%s BOOL DEFAULT FALSE`)
}

if syncedAtColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP`, QuoteIdentifier(syncedAtColName)))
QuoteIdentifier(syncedAtColName)+` TIMESTAMP DEFAULT CURRENT_TIMESTAMP`)
}

// add composite primary key to the table
Expand Down Expand Up @@ -624,64 +625,61 @@ func (c *PostgresConnector) getCurrentLSN(ctx context.Context) (pglogrepl.LSN, e
}

func (c *PostgresConnector) getDefaultPublicationName(jobName string) string {
return fmt.Sprintf("peerflow_pub_%s", jobName)
return "peerflow_pub_" + jobName
}

func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames []string, pubName string) error {
func (c *PostgresConnector) CheckSourceTables(ctx context.Context,
tableNames []*utils.SchemaTable, pubName string,
) error {
if c.conn == nil {
return errors.New("check tables: conn is nil")
}

// Check that we can select from all tables
tableArr := make([]string, 0, len(tableNames))
for _, table := range tableNames {
for _, parsedTable := range tableNames {
var row pgx.Row
schemaName, tableName, found := strings.Cut(table, ".")
if !found {
return fmt.Errorf("invalid source table identifier: %s", table)
}

tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`, QuoteLiteral(schemaName), QuoteLiteral(tableName)))
tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`,
QuoteLiteral(parsedTable.Schema), QuoteLiteral(parsedTable.Table)))
err := c.conn.QueryRow(ctx,
fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;", QuoteIdentifier(schemaName), QuoteIdentifier(tableName))).Scan(&row)
fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;",
QuoteIdentifier(parsedTable.Schema), QuoteIdentifier(parsedTable.Table))).Scan(&row)
if err != nil && err != pgx.ErrNoRows {
return err
}
}

tableStr := strings.Join(tableArr, ",")
if pubName != "" {
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
}

// Check if tables belong to publication
var pubTableCount int
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
// Check if tables belong to publication
var pubTableCount int
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
with source_table_components (sname, tname) as (values %s)
select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables
INNER JOIN source_table_components stc
ON schemaname=stc.sname and tablename=stc.tname where pubname=$1;`, tableStr), pubName).Scan(&pubTableCount)
if err != nil {
return err
}
if err != nil {
return err
}

if pubTableCount != len(tableNames) {
return errors.New("not all tables belong to publication")
}
if pubTableCount != len(tableNames) {
return errors.New("not all tables belong to publication")
}

return nil
}

func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, username string) error {
if c.conn == nil {
return fmt.Errorf("check replication permissions: conn is nil")
return errors.New("check replication permissions: conn is nil")
}

var replicationRes bool
Expand All @@ -695,7 +693,7 @@ func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, use
var setting string
err := c.conn.QueryRow(ctx, "SELECT setting FROM pg_settings WHERE name = 'rds.logical_replication'").Scan(&setting)
if err != nil || setting != "on" {
return fmt.Errorf("postgres user does not have replication role")
return errors.New("postgres user does not have replication role")
}
}

Expand All @@ -707,7 +705,7 @@ func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, use
}

if walLevel != "logical" {
return fmt.Errorf("wal_level is not logical")
return errors.New("wal_level is not logical")
}

// max_wal_senders must be at least 2
Expand All @@ -723,8 +721,54 @@ func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, use
}

if maxWalSenders < 2 {
return fmt.Errorf("max_wal_senders must be at least 2")
return errors.New("max_wal_senders must be at least 2")
}

return nil
}

func (c *PostgresConnector) CheckPublicationPermission(ctx context.Context, tableNameString string) error {
publication := "_PEERDB_DUMMY_PUBLICATION_" + shared.RandomString(4)
// check and enable publish_via_partition_root
supportsPubViaRoot, _, err := c.MajorVersionCheck(ctx, POSTGRES_13)
if err != nil {
return fmt.Errorf("error checking Postgres version: %w", err)
}
var pubViaRootString string
if supportsPubViaRoot {
pubViaRootString = "WITH(publish_via_partition_root=true)"
}
tx, err := c.conn.Begin(ctx)
if err != nil {
return fmt.Errorf("error starting transaction: %w", err)
}
defer func() {
err := tx.Rollback(ctx)
if err != nil && err != pgx.ErrTxClosed {
c.logger.Error("[validate publication create] failed to rollback transaction", "error", err)
}
}()

// Create the publication
createStmt := fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s %s",
publication, tableNameString, pubViaRootString)
_, err = tx.Exec(ctx, createStmt)
if err != nil {
return fmt.Errorf("it will not be possible to create a publication for selected tables: %w", err)
}

// Drop the publication
dropStmt := "DROP PUBLICATION IF EXISTS " + publication
_, err = tx.Exec(ctx, dropStmt)
if err != nil {
return fmt.Errorf("it will not be possible to drop the publication created for this mirror: %w",
err)
}

// commit transaction
err = tx.Commit(ctx)
if err != nil {
return fmt.Errorf("unable to validate publication create permission: %w", err)
}
return nil
}

0 comments on commit cc617dc

Please sign in to comment.