Skip to content

Commit

Permalink
Validate mirror: publication creation (#1299)
Browse files Browse the repository at this point in the history
Validate mirror now also checks: if no publication provided, we can
create (and drop) a publication for the selected tables.
  • Loading branch information
Amogh-Bharadwaj authored Feb 15, 2024
1 parent 8e29852 commit 99de22d
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 38 deletions.
38 changes: 31 additions & 7 deletions flow/cmd/validate_mirror.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"context"
"fmt"
"log/slog"
"strings"

connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
)

Expand Down Expand Up @@ -49,16 +51,38 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
}

// Check source tables
sourceTables := make([]string, 0, len(req.ConnectionConfigs.TableMappings))
sourceTables := make([]*utils.SchemaTable, 0, len(req.ConnectionConfigs.TableMappings))
for _, tableMapping := range req.ConnectionConfigs.TableMappings {
sourceTables = append(sourceTables, tableMapping.SourceTableIdentifier)
parsedTable, parseErr := utils.ParseSchemaTable(tableMapping.SourceTableIdentifier)
if parseErr != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("invalid source table identifier: %s", tableMapping.SourceTableIdentifier)
}

sourceTables = append(sourceTables, parsedTable)
}

err = pgPeer.CheckSourceTables(ctx, sourceTables, req.ConnectionConfigs.PublicationName)
if err != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("provided source tables invalidated: %v", err)
pubName := req.ConnectionConfigs.PublicationName
if pubName == "" {
pubTables := make([]string, 0, len(sourceTables))
for _, table := range sourceTables {
pubTables = append(pubTables, table.String())
}
pubTableStr := strings.Join(pubTables, ", ")
pubErr := pgPeer.CheckPublicationPermission(ctx, pubTableStr)
if pubErr != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("failed to check publication permission: %v", pubErr)
}
} else {
err = pgPeer.CheckSourceTables(ctx, sourceTables, req.ConnectionConfigs.PublicationName)
if err != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("provided source tables invalidated: %v", err)
}
}

return &protos.ValidateCDCMirrorResponse{
Expand Down
10 changes: 5 additions & 5 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,9 @@ func (c *PostgresConnector) checkSlotAndPublication(ctx context.Context, slot st
func getSlotInfo(ctx context.Context, conn *pgx.Conn, slotName string, database string) ([]*protos.SlotInfo, error) {
var whereClause string
if slotName != "" {
whereClause = fmt.Sprintf("WHERE slot_name=%s", QuoteLiteral(slotName))
whereClause = "WHERE slot_name=" + QuoteLiteral(slotName)
} else {
whereClause = fmt.Sprintf("WHERE database=%s", QuoteLiteral(database))
whereClause = "WHERE database=" + QuoteLiteral(database)
}

hasWALStatus, _, err := majorVersionCheck(ctx, conn, POSTGRES_13)
Expand Down Expand Up @@ -449,12 +449,12 @@ func generateCreateTableSQLForNormalizedTable(

if softDeleteColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`%s BOOL DEFAULT FALSE`, QuoteIdentifier(softDeleteColName)))
QuoteIdentifier(softDeleteColName)+` BOOL DEFAULT FALSE`)
}

if syncedAtColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP`, QuoteIdentifier(syncedAtColName)))
QuoteIdentifier(syncedAtColName)+` TIMESTAMP DEFAULT CURRENT_TIMESTAMP`)
}

// add composite primary key to the table
Expand Down Expand Up @@ -623,5 +623,5 @@ func (c *PostgresConnector) getCurrentLSN(ctx context.Context) (pglogrepl.LSN, e
}

func (c *PostgresConnector) getDefaultPublicationName(jobName string) string {
return fmt.Sprintf("peerflow_pub_%s", jobName)
return "peerflow_pub_" + jobName
}
98 changes: 72 additions & 26 deletions flow/connectors/postgres/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,55 +8,55 @@ import (
"strings"

"github.com/jackc/pgx/v5"

"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/shared"
)

func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames []string, pubName string) error {
func (c *PostgresConnector) CheckSourceTables(ctx context.Context,
tableNames []*utils.SchemaTable, pubName string,
) error {
if c.conn == nil {
return errors.New("check tables: conn is nil")
}

// Check that we can select from all tables
tableArr := make([]string, 0, len(tableNames))
for _, table := range tableNames {
for _, parsedTable := range tableNames {
var row pgx.Row
schemaName, tableName, found := strings.Cut(table, ".")
if !found {
return fmt.Errorf("invalid source table identifier: %s", table)
}

tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`, QuoteLiteral(schemaName), QuoteLiteral(tableName)))
tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`,
QuoteLiteral(parsedTable.Schema), QuoteLiteral(parsedTable.Table)))
err := c.conn.QueryRow(ctx,
fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;", QuoteIdentifier(schemaName), QuoteIdentifier(tableName))).Scan(&row)
fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;",
QuoteIdentifier(parsedTable.Schema), QuoteIdentifier(parsedTable.Table))).Scan(&row)
if err != nil && err != pgx.ErrNoRows {
return err
}
}

tableStr := strings.Join(tableArr, ",")
if pubName != "" {
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
}

// Check if tables belong to publication
var pubTableCount int
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
// Check if tables belong to publication
var pubTableCount int
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
with source_table_components (sname, tname) as (values %s)
select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables
INNER JOIN source_table_components stc
ON schemaname=stc.sname and tablename=stc.tname where pubname=$1;`, tableStr), pubName).Scan(&pubTableCount)
if err != nil {
return err
}
if err != nil {
return err
}

if pubTableCount != len(tableNames) {
return errors.New("not all tables belong to publication")
}
if pubTableCount != len(tableNames) {
return errors.New("not all tables belong to publication")
}

return nil
Expand Down Expand Up @@ -112,6 +112,52 @@ func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, use
return nil
}

func (c *PostgresConnector) CheckPublicationPermission(ctx context.Context, tableNameString string) error {
publication := "_PEERDB_DUMMY_PUBLICATION_" + shared.RandomString(4)
// check and enable publish_via_partition_root
supportsPubViaRoot, _, err := c.MajorVersionCheck(ctx, POSTGRES_13)
if err != nil {
return fmt.Errorf("error checking Postgres version: %w", err)
}
var pubViaRootString string
if supportsPubViaRoot {
pubViaRootString = "WITH(publish_via_partition_root=true)"
}
tx, err := c.conn.Begin(ctx)
if err != nil {
return fmt.Errorf("error starting transaction: %w", err)
}
defer func() {
err := tx.Rollback(ctx)
if err != nil && err != pgx.ErrTxClosed {
c.logger.Error("[validate publication create] failed to rollback transaction", "error", err)
}
}()

// Create the publication
createStmt := fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s %s",
publication, tableNameString, pubViaRootString)
_, err = tx.Exec(ctx, createStmt)
if err != nil {
return fmt.Errorf("it will not be possible to create a publication for selected tables: %w", err)
}

// Drop the publication
dropStmt := "DROP PUBLICATION IF EXISTS " + publication
_, err = tx.Exec(ctx, dropStmt)
if err != nil {
return fmt.Errorf("it will not be possible to drop the publication created for this mirror: %w",
err)
}

// commit transaction
err = tx.Commit(ctx)
if err != nil {
return fmt.Errorf("unable to validate publication create permission: %w", err)
}
return nil
}

func (c *PostgresConnector) CheckReplicationConnectivity(ctx context.Context) error {
// Check if we can create a replication connection
conn, err := c.CreateReplConn(ctx)
Expand Down

0 comments on commit 99de22d

Please sign in to comment.