Skip to content

Commit

Permalink
improve api
Browse files Browse the repository at this point in the history
  • Loading branch information
Amogh-Bharadwaj committed Feb 15, 2024
1 parent 99de22d commit 151b29d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 51 deletions.
1 change: 1 addition & 0 deletions flow/cmd/peer_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ func (h *FlowRequestHandler) GetColumns(
AND
relname = $2
AND pg_attribute.attnum > 0
AND attname NOT LIKE '%........pg.dropped.%'
ORDER BY
attnum;
`, req.SchemaName, req.TableName)
Expand Down
8 changes: 1 addition & 7 deletions flow/cmd/validate_mirror.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"strings"

connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
"github.com/PeerDB-io/peer-flow/connectors/utils"
Expand Down Expand Up @@ -65,12 +64,7 @@ func (h *FlowRequestHandler) ValidateCDCMirror(

pubName := req.ConnectionConfigs.PublicationName
if pubName == "" {
pubTables := make([]string, 0, len(sourceTables))
for _, table := range sourceTables {
pubTables = append(pubTables, table.String())
}
pubTableStr := strings.Join(pubTables, ", ")
pubErr := pgPeer.CheckPublicationPermission(ctx, pubTableStr)
pubErr := pgPeer.CheckPublicationPermission(ctx, sourceTables)
if pubErr != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
Expand Down
70 changes: 26 additions & 44 deletions flow/connectors/postgres/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/jackc/pgx/v5"

"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/shared"
)

func (c *PostgresConnector) CheckSourceTables(ctx context.Context,
Expand Down Expand Up @@ -112,49 +111,38 @@ func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, use
return nil
}

func (c *PostgresConnector) CheckPublicationPermission(ctx context.Context, tableNameString string) error {
publication := "_PEERDB_DUMMY_PUBLICATION_" + shared.RandomString(4)
// check and enable publish_via_partition_root
supportsPubViaRoot, _, err := c.MajorVersionCheck(ctx, POSTGRES_13)
if err != nil {
return fmt.Errorf("error checking Postgres version: %w", err)
}
var pubViaRootString string
if supportsPubViaRoot {
pubViaRootString = "WITH(publish_via_partition_root=true)"
}
tx, err := c.conn.Begin(ctx)
if err != nil {
return fmt.Errorf("error starting transaction: %w", err)
func (c *PostgresConnector) CheckPublicationPermission(ctx context.Context, tableNames []*utils.SchemaTable) error {
var hasSuper bool
var canCreateDatabase bool
queryErr := c.conn.QueryRow(ctx, fmt.Sprintf(`
SELECT
rolsuper,
has_database_privilege(rolname, current_database(), 'CREATE') AS can_create_database
FROM pg_roles
WHERE rolname = %s;
`, QuoteLiteral(c.config.User))).Scan(&hasSuper, &canCreateDatabase)
if queryErr != nil {
return fmt.Errorf("error while checking user privileges: %w", queryErr)
}
defer func() {
err := tx.Rollback(ctx)
if err != nil && err != pgx.ErrTxClosed {
c.logger.Error("[validate publication create] failed to rollback transaction", "error", err)
}
}()

// Create the publication
createStmt := fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s %s",
publication, tableNameString, pubViaRootString)
_, err = tx.Exec(ctx, createStmt)
if err != nil {
return fmt.Errorf("it will not be possible to create a publication for selected tables: %w", err)
if !hasSuper && !canCreateDatabase {
return errors.New("user does not have superuser or create database privileges")
}

// Drop the publication
dropStmt := "DROP PUBLICATION IF EXISTS " + publication
_, err = tx.Exec(ctx, dropStmt)
if err != nil {
return fmt.Errorf("it will not be possible to drop the publication created for this mirror: %w",
err)
}
// for each table, check if the user is an owner
for _, table := range tableNames {
var owner string
err := c.conn.QueryRow(ctx, fmt.Sprintf("SELECT tableowner FROM pg_tables WHERE schemaname=%s AND tablename=%s",
QuoteLiteral(table.Schema), QuoteLiteral(table.Table))).Scan(&owner)
if err != nil {
return fmt.Errorf("error while checking table owner: %w", err)
}

// commit transaction
err = tx.Commit(ctx)
if err != nil {
return fmt.Errorf("unable to validate publication create permission: %w", err)
if owner != c.config.User {
return fmt.Errorf("user %s is not the owner of table %s", c.config.User, table.String())
}
}

return nil
}

Expand All @@ -167,11 +155,5 @@ func (c *PostgresConnector) CheckReplicationConnectivity(ctx context.Context) er

defer conn.Close(ctx)

var one int
queryErr := conn.QueryRow(ctx, "SELECT 1").Scan(&one)
if queryErr != nil {
return fmt.Errorf("failed to query replication connection: %v", queryErr)
}

return nil
}

0 comments on commit 151b29d

Please sign in to comment.