Skip to content

Commit

Permalink
Validate mirror: check replication connectivity, move to validate.go (#…
Browse files Browse the repository at this point in the history
…1300)

Fixes #746 
Moves all validation-related functions from `client.go` to a new
`validate.go` for better organisation.
  • Loading branch information
Amogh-Bharadwaj authored Feb 15, 2024
1 parent 492bac4 commit cc3afe8
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 103 deletions.
8 changes: 8 additions & 0 deletions flow/cmd/validate_mirror.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
}
defer pgPeer.Close(ctx)

// Check replication connectivity
err = pgPeer.CheckReplicationConnectivity(ctx)
if err != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("unable to establish replication connectivity: %v", err)
}

// Check permissions of postgres peer
err = pgPeer.CheckReplicationPermissions(ctx, sourcePeerConfig.User)
if err != nil {
Expand Down
103 changes: 0 additions & 103 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"regexp"
"strconv"
"strings"

"github.com/jackc/pglogrepl"
Expand Down Expand Up @@ -626,105 +625,3 @@ func (c *PostgresConnector) getCurrentLSN(ctx context.Context) (pglogrepl.LSN, e
// getDefaultPublicationName builds the default publication name for a job:
// the fixed prefix "peerflow_pub_" followed by jobName.
func (c *PostgresConnector) getDefaultPublicationName(jobName string) string {
	const pubNameFormat = "peerflow_pub_%s"
	return fmt.Sprintf(pubNameFormat, jobName)
}

// CheckSourceTables validates tableNames against the connected Postgres.
//
// Every entry of tableNames must look like "schema.table"; it is split at the
// first ".". For each table a `SELECT * ... LIMIT 0` is issued to confirm that
// the table can be selected from. When pubName is non-empty the function also
// checks that the publication exists in pg_publication and that every listed
// table appears in pg_publication_tables for that publication.
//
// It returns nil when all checks pass, otherwise an error describing the first
// check that failed. Underlying query errors are wrapped with %w.
func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames []string, pubName string) error {
	if c.conn == nil {
		return errors.New("check tables: conn is nil")
	}

	// Check that we can select from all tables
	tableArr := make([]string, 0, len(tableNames))
	for _, table := range tableNames {
		schemaName, tableName, found := strings.Cut(table, ".")
		if !found {
			return fmt.Errorf("invalid source table identifier: %s", table)
		}

		tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`, QuoteLiteral(schemaName), QuoteLiteral(tableName)))

		// LIMIT 0 never produces a row, so pgx.ErrNoRows is the expected
		// outcome here; any other error means the table is not selectable.
		var row pgx.Row
		err := c.conn.QueryRow(ctx,
			fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;", QuoteIdentifier(schemaName), QuoteIdentifier(tableName))).Scan(&row)
		if err != nil && !errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("failed to select from table %s: %w", table, err)
		}
	}

	if pubName == "" {
		return nil
	}

	// Check if publication exists
	err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("publication does not exist: %s", pubName)
		}
		return fmt.Errorf("error while checking for publication existence: %w", err)
	}

	// With no tables there is nothing to match against the publication, and
	// an empty VALUES list would be a SQL syntax error.
	if len(tableArr) == 0 {
		return nil
	}

	// Check if tables belong to publication
	tableStr := strings.Join(tableArr, ",")
	var pubTableCount int
	err = c.conn.QueryRow(ctx, fmt.Sprintf(`
		with source_table_components (sname, tname) as (values %s)
		select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables
		INNER JOIN source_table_components stc
		ON schemaname=stc.sname and tablename=stc.tname where pubname=$1;`, tableStr), pubName).Scan(&pubTableCount)
	if err != nil {
		return fmt.Errorf("error while checking publication tables: %w", err)
	}

	if pubTableCount != len(tableNames) {
		return errors.New("not all tables belong to publication")
	}

	return nil
}

// CheckReplicationPermissions verifies that the connected Postgres is set up
// for logical replication by username.
//
// The checks, in order:
//  1. username has rolreplication set in pg_roles, or, failing that, the
//     rds.logical_replication setting in pg_settings is "on" (RDS case).
//  2. wal_level is "logical".
//  3. max_wal_senders is at least 2.
//
// It returns nil when all checks pass, otherwise an error for the first check
// that failed. Underlying query errors are wrapped with %w.
func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, username string) error {
	if c.conn == nil {
		return errors.New("check replication permissions: conn is nil")
	}

	var replicationRes bool
	err := c.conn.QueryRow(ctx, "SELECT rolreplication FROM pg_roles WHERE rolname = $1", username).Scan(&replicationRes)
	if err != nil {
		return fmt.Errorf("failed to look up replication attribute for role %s: %w", username, err)
	}

	if !replicationRes {
		// RDS case: check pg_settings for rds.logical_replication
		var setting string
		err := c.conn.QueryRow(ctx, "SELECT setting FROM pg_settings WHERE name = 'rds.logical_replication'").Scan(&setting)
		// A missing row just means the setting does not exist; any other
		// failure is a real query error and must not be reported as a
		// permissions problem.
		if err != nil && !errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("failed to check rds.logical_replication: %w", err)
		}
		if err != nil || setting != "on" {
			return errors.New("postgres user does not have replication role")
		}
	}

	// check wal_level
	var walLevel string
	err = c.conn.QueryRow(ctx, "SHOW wal_level").Scan(&walLevel)
	if err != nil {
		return fmt.Errorf("failed to read wal_level: %w", err)
	}

	if walLevel != "logical" {
		return errors.New("wal_level is not logical")
	}

	// max_wal_senders must be at least 2
	var maxWalSendersRes string
	err = c.conn.QueryRow(ctx, "SHOW max_wal_senders").Scan(&maxWalSendersRes)
	if err != nil {
		return fmt.Errorf("failed to read max_wal_senders: %w", err)
	}

	maxWalSenders, err := strconv.Atoi(maxWalSendersRes)
	if err != nil {
		return fmt.Errorf("failed to parse max_wal_senders %q: %w", maxWalSendersRes, err)
	}

	if maxWalSenders < 2 {
		return errors.New("max_wal_senders must be at least 2")
	}

	return nil
}
131 changes: 131 additions & 0 deletions flow/connectors/postgres/validate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package connpostgres

import (
"context"
"errors"
"fmt"
"strconv"
"strings"

"github.com/jackc/pgx/v5"
)

// CheckSourceTables validates tableNames against the connected Postgres.
//
// Every entry of tableNames must look like "schema.table"; it is split at the
// first ".". For each table a `SELECT * ... LIMIT 0` is issued to confirm that
// the table can be selected from. When pubName is non-empty the function also
// checks that the publication exists in pg_publication and that every listed
// table appears in pg_publication_tables for that publication.
//
// It returns nil when all checks pass, otherwise an error describing the first
// check that failed. Underlying query errors are wrapped with %w.
func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames []string, pubName string) error {
	if c.conn == nil {
		return errors.New("check tables: conn is nil")
	}

	// Check that we can select from all tables
	tableArr := make([]string, 0, len(tableNames))
	for _, table := range tableNames {
		schemaName, tableName, found := strings.Cut(table, ".")
		if !found {
			return fmt.Errorf("invalid source table identifier: %s", table)
		}

		tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`, QuoteLiteral(schemaName), QuoteLiteral(tableName)))

		// LIMIT 0 never produces a row, so pgx.ErrNoRows is the expected
		// outcome here; any other error means the table is not selectable.
		var row pgx.Row
		err := c.conn.QueryRow(ctx,
			fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;", QuoteIdentifier(schemaName), QuoteIdentifier(tableName))).Scan(&row)
		if err != nil && !errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("failed to select from table %s: %w", table, err)
		}
	}

	if pubName == "" {
		return nil
	}

	// Check if publication exists
	err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("publication does not exist: %s", pubName)
		}
		return fmt.Errorf("error while checking for publication existence: %w", err)
	}

	// With no tables there is nothing to match against the publication, and
	// an empty VALUES list would be a SQL syntax error.
	if len(tableArr) == 0 {
		return nil
	}

	// Check if tables belong to publication
	tableStr := strings.Join(tableArr, ",")
	var pubTableCount int
	err = c.conn.QueryRow(ctx, fmt.Sprintf(`
		with source_table_components (sname, tname) as (values %s)
		select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables
		INNER JOIN source_table_components stc
		ON schemaname=stc.sname and tablename=stc.tname where pubname=$1;`, tableStr), pubName).Scan(&pubTableCount)
	if err != nil {
		return fmt.Errorf("error while checking publication tables: %w", err)
	}

	if pubTableCount != len(tableNames) {
		return errors.New("not all tables belong to publication")
	}

	return nil
}

// CheckReplicationPermissions verifies that the connected Postgres is set up
// for logical replication by username.
//
// The checks, in order:
//  1. username has rolreplication set in pg_roles, or, failing that, the
//     rds.logical_replication setting in pg_settings is "on" (RDS case).
//  2. wal_level is "logical".
//  3. max_wal_senders is at least 2.
//
// It returns nil when all checks pass, otherwise an error for the first check
// that failed. Underlying query errors are wrapped with %w.
func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, username string) error {
	if c.conn == nil {
		return errors.New("check replication permissions: conn is nil")
	}

	var replicationRes bool
	err := c.conn.QueryRow(ctx, "SELECT rolreplication FROM pg_roles WHERE rolname = $1", username).Scan(&replicationRes)
	if err != nil {
		return fmt.Errorf("failed to look up replication attribute for role %s: %w", username, err)
	}

	if !replicationRes {
		// RDS case: check pg_settings for rds.logical_replication
		var setting string
		err := c.conn.QueryRow(ctx, "SELECT setting FROM pg_settings WHERE name = 'rds.logical_replication'").Scan(&setting)
		// A missing row just means the setting does not exist; any other
		// failure is a real query error and must not be reported as a
		// permissions problem.
		if err != nil && !errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("failed to check rds.logical_replication: %w", err)
		}
		if err != nil || setting != "on" {
			return errors.New("postgres user does not have replication role")
		}
	}

	// check wal_level
	var walLevel string
	err = c.conn.QueryRow(ctx, "SHOW wal_level").Scan(&walLevel)
	if err != nil {
		return fmt.Errorf("failed to read wal_level: %w", err)
	}

	if walLevel != "logical" {
		return errors.New("wal_level is not logical")
	}

	// max_wal_senders must be at least 2
	var maxWalSendersRes string
	err = c.conn.QueryRow(ctx, "SHOW max_wal_senders").Scan(&maxWalSendersRes)
	if err != nil {
		return fmt.Errorf("failed to read max_wal_senders: %w", err)
	}

	maxWalSenders, err := strconv.Atoi(maxWalSendersRes)
	if err != nil {
		return fmt.Errorf("failed to parse max_wal_senders %q: %w", maxWalSendersRes, err)
	}

	if maxWalSenders < 2 {
		return errors.New("max_wal_senders must be at least 2")
	}

	return nil
}

// CheckReplicationConnectivity verifies that a replication connection can be
// opened and used: it obtains a connection from CreateReplConn, runs
// `SELECT 1` on it, and closes the connection before returning.
//
// It returns nil on success. Failures to connect or to run the probe query are
// returned wrapped with %w so callers can inspect the cause with errors.Is /
// errors.As.
//
// NOTE(review): PostgreSQL replication connections only accept the simple
// query protocol. Confirm that the connection returned by CreateReplConn is
// configured so that this QueryRow call is accepted.
func (c *PostgresConnector) CheckReplicationConnectivity(ctx context.Context) error {
	// Check if we can create a replication connection
	conn, err := c.CreateReplConn(ctx)
	if err != nil {
		return fmt.Errorf("failed to create replication connection: %w", err)
	}
	defer conn.Close(ctx)

	var one int
	if err := conn.QueryRow(ctx, "SELECT 1").Scan(&one); err != nil {
		return fmt.Errorf("failed to query replication connection: %w", err)
	}

	return nil
}

0 comments on commit cc3afe8

Please sign in to comment.