Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate mirror: publication creation #1299

Merged
merged 5 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions flow/cmd/validate_mirror.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"context"
"fmt"
"log/slog"
"strings"

connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
)

Expand Down Expand Up @@ -49,16 +51,38 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
}

// Check source tables
sourceTables := make([]string, 0, len(req.ConnectionConfigs.TableMappings))
sourceTables := make([]*utils.SchemaTable, 0, len(req.ConnectionConfigs.TableMappings))
for _, tableMapping := range req.ConnectionConfigs.TableMappings {
sourceTables = append(sourceTables, tableMapping.SourceTableIdentifier)
parsedTable, parseErr := utils.ParseSchemaTable(tableMapping.SourceTableIdentifier)
if parseErr != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("invalid source table identifier: %s", tableMapping.SourceTableIdentifier)
}

sourceTables = append(sourceTables, parsedTable)
}

err = pgPeer.CheckSourceTables(ctx, sourceTables, req.ConnectionConfigs.PublicationName)
if err != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("provided source tables invalidated: %v", err)
pubName := req.ConnectionConfigs.PublicationName
if pubName == "" {
pubTables := make([]string, 0, len(sourceTables))
for _, table := range sourceTables {
pubTables = append(pubTables, table.String())
}
pubTableStr := strings.Join(pubTables, ", ")
pubErr := pgPeer.CheckPublicationPermission(ctx, pubTableStr)
if pubErr != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("failed to check publication permission: %v", pubErr)
}
} else {
err = pgPeer.CheckSourceTables(ctx, sourceTables, req.ConnectionConfigs.PublicationName)
if err != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("provided source tables invalidated: %v", err)
}
}

return &protos.ValidateCDCMirrorResponse{
Expand Down
10 changes: 5 additions & 5 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,9 @@ func (c *PostgresConnector) checkSlotAndPublication(ctx context.Context, slot st
func getSlotInfo(ctx context.Context, conn *pgx.Conn, slotName string, database string) ([]*protos.SlotInfo, error) {
var whereClause string
if slotName != "" {
whereClause = fmt.Sprintf("WHERE slot_name=%s", QuoteLiteral(slotName))
whereClause = "WHERE slot_name=" + QuoteLiteral(slotName)
} else {
whereClause = fmt.Sprintf("WHERE database=%s", QuoteLiteral(database))
whereClause = "WHERE database=" + QuoteLiteral(database)
}

hasWALStatus, _, err := majorVersionCheck(ctx, conn, POSTGRES_13)
Expand Down Expand Up @@ -449,12 +449,12 @@ func generateCreateTableSQLForNormalizedTable(

if softDeleteColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`%s BOOL DEFAULT FALSE`, QuoteIdentifier(softDeleteColName)))
QuoteIdentifier(softDeleteColName)+` BOOL DEFAULT FALSE`)
}

if syncedAtColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP`, QuoteIdentifier(syncedAtColName)))
QuoteIdentifier(syncedAtColName)+` TIMESTAMP DEFAULT CURRENT_TIMESTAMP`)
}

// add composite primary key to the table
Expand Down Expand Up @@ -623,5 +623,5 @@ func (c *PostgresConnector) getCurrentLSN(ctx context.Context) (pglogrepl.LSN, e
}

func (c *PostgresConnector) getDefaultPublicationName(jobName string) string {
return fmt.Sprintf("peerflow_pub_%s", jobName)
return "peerflow_pub_" + jobName
}
99 changes: 73 additions & 26 deletions flow/connectors/postgres/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,55 +8,55 @@ import (
"strings"

"github.com/jackc/pgx/v5"

"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/shared"
)

func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames []string, pubName string) error {
func (c *PostgresConnector) CheckSourceTables(ctx context.Context,
tableNames []*utils.SchemaTable, pubName string,
) error {
if c.conn == nil {
return errors.New("check tables: conn is nil")
}

// Check that we can select from all tables
tableArr := make([]string, 0, len(tableNames))
for _, table := range tableNames {
for _, parsedTable := range tableNames {
var row pgx.Row
schemaName, tableName, found := strings.Cut(table, ".")
if !found {
return fmt.Errorf("invalid source table identifier: %s", table)
}

tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`, QuoteLiteral(schemaName), QuoteLiteral(tableName)))
tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`,
QuoteLiteral(parsedTable.Schema), QuoteLiteral(parsedTable.Table)))
err := c.conn.QueryRow(ctx,
fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;", QuoteIdentifier(schemaName), QuoteIdentifier(tableName))).Scan(&row)
fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;",
QuoteIdentifier(parsedTable.Schema), QuoteIdentifier(parsedTable.Table))).Scan(&row)
if err != nil && err != pgx.ErrNoRows {
return err
}
}

tableStr := strings.Join(tableArr, ",")
if pubName != "" {
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
}

// Check if tables belong to publication
var pubTableCount int
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
// Check if tables belong to publication
var pubTableCount int
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
with source_table_components (sname, tname) as (values %s)
select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables
INNER JOIN source_table_components stc
ON schemaname=stc.sname and tablename=stc.tname where pubname=$1;`, tableStr), pubName).Scan(&pubTableCount)
if err != nil {
return err
}
if err != nil {
return err
}

if pubTableCount != len(tableNames) {
return errors.New("not all tables belong to publication")
}
if pubTableCount != len(tableNames) {
return errors.New("not all tables belong to publication")
}

return nil
Expand Down Expand Up @@ -112,6 +112,53 @@ func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, use
return nil
}

// CheckPublicationPermission verifies that the connected role is able to
// create, and afterwards drop, a publication covering the given tables.
//
// tableNameString is spliced verbatim into the FOR TABLE clause of a
// CREATE PUBLICATION statement, so the caller must pass a comma-separated
// list of table identifiers that is already safe to embed in SQL.
//
// The probe publication gets a randomized name and is created and dropped
// inside a single transaction, which is then committed. If any step fails the
// deferred rollback undoes whatever was done, so no publication is left behind.
//
// The connector's connection is left open: it is owned by the connector and
// must remain usable for subsequent checks; closing it is the caller's job.
//
// It returns an error if the connection is nil, if the server version cannot
// be determined, or if any of the transaction steps fail.
func (c *PostgresConnector) CheckPublicationPermission(ctx context.Context, tableNameString string) error {
	// Mirror the guard used by CheckSourceTables so a missing connection
	// surfaces as an error rather than a nil-pointer panic.
	if c.conn == nil {
		return errors.New("check publication permission: conn is nil")
	}

	publication := "_PEERDB_DUMMY_PUBLICATION_" + shared.RandomString(4)

	// publish_via_partition_root is only added when the version check against
	// POSTGRES_13 reports support for it.
	supportsPubViaRoot, _, err := c.MajorVersionCheck(ctx, POSTGRES_13)
	if err != nil {
		return fmt.Errorf("error checking Postgres version: %w", err)
	}
	var pubViaRootString string
	if supportsPubViaRoot {
		pubViaRootString = "WITH(publish_via_partition_root=true)"
	}

	tx, err := c.conn.Begin(ctx)
	if err != nil {
		return fmt.Errorf("error starting transaction: %w", err)
	}
	// Rolling back after a successful commit yields pgx.ErrTxClosed, which is
	// expected and therefore not logged.
	defer func() {
		if rbErr := tx.Rollback(ctx); rbErr != nil && !errors.Is(rbErr, pgx.ErrTxClosed) {
			c.logger.Error("[validate publication create] failed to rollback transaction", "error", rbErr)
		}
	}()

	// Create the publication
	createStmt := fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s %s",
		publication, tableNameString, pubViaRootString)
	if _, err = tx.Exec(ctx, createStmt); err != nil {
		return fmt.Errorf("it will not be possible to create a publication for selected tables: %w", err)
	}

	// Drop the publication
	dropStmt := "DROP PUBLICATION IF EXISTS " + publication
	if _, err = tx.Exec(ctx, dropStmt); err != nil {
		return fmt.Errorf("it will not be possible to drop the publication created for this mirror: %w",
			err)
	}

	// commit transaction
	if err = tx.Commit(ctx); err != nil {
		return fmt.Errorf("unable to validate publication create permission: %w", err)
	}
	return nil
}

func (c *PostgresConnector) CheckReplicationConnectivity(ctx context.Context) error {
// Check if we can create a replication connection
conn, err := c.CreateReplConn(ctx)
Expand Down
Loading