Skip to content

Commit

Permalink
Refactor replica identity type and primary key column retrieval in Po…
Browse files Browse the repository at this point in the history
…stgres (#860)

This is to support `USING INDEX` replica identity types in Postgres.
Treating the index columns as primary key columns for now as it is the
fastest way for us to support replica identity index.
  • Loading branch information
iskakaushik authored Dec 20, 2023
1 parent d5baa48 commit 8e4e7ba
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 28 deletions.
87 changes: 65 additions & 22 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/jackc/pglogrepl"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/lib/pq/oid"
"golang.org/x/exp/maps"
)

Expand Down Expand Up @@ -77,6 +78,15 @@ const (
deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE MIRROR_JOB_NAME=$1"
)

type ReplicaIdentityType rune

const (
ReplicaIdentityDefault ReplicaIdentityType = 'd'
ReplicaIdentityFull = 'f'
ReplicaIdentityIndex = 'i'
ReplicaIdentityNothing = 'n'
)

// getRelIDForTable returns the relation ID for a table.
func (c *PostgresConnector) getRelIDForTable(schemaTable *utils.SchemaTable) (uint32, error) {
var relID pgtype.Uint32
Expand All @@ -92,54 +102,87 @@ func (c *PostgresConnector) getRelIDForTable(schemaTable *utils.SchemaTable) (ui
}

// getReplicaIdentity returns the replica identity for a table.
func (c *PostgresConnector) isTableFullReplica(schemaTable *utils.SchemaTable) (bool, error) {
func (c *PostgresConnector) getReplicaIdentityType(schemaTable *utils.SchemaTable) (ReplicaIdentityType, error) {
relID, relIDErr := c.getRelIDForTable(schemaTable)
if relIDErr != nil {
return false, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, relIDErr)
return ReplicaIdentityDefault, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, relIDErr)
}

var replicaIdentity rune
err := c.pool.QueryRow(c.ctx,
`SELECT relreplident FROM pg_class WHERE oid = $1;`,
relID).Scan(&replicaIdentity)
if err != nil {
return false, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, err)
return ReplicaIdentityDefault, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, err)
}
return string(replicaIdentity) == "f", nil

return ReplicaIdentityType(replicaIdentity), nil
}

// getPrimaryKeyColumns for table returns the primary key column for a given table
// errors if there is no primary key column or if there is more than one primary key column.
func (c *PostgresConnector) getPrimaryKeyColumns(schemaTable *utils.SchemaTable) ([]string, error) {
// getPrimaryKeyColumns returns the primary key columns for a given table.
// Errors if there is no primary key column or if there is more than one primary key column.
func (c *PostgresConnector) getPrimaryKeyColumns(
replicaIdentity ReplicaIdentityType,
schemaTable *utils.SchemaTable,
) ([]string, error) {
relID, err := c.getRelIDForTable(schemaTable)
if err != nil {
return nil, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, err)
}

// Get the primary key column name
var pkCol pgtype.Text
pkCols := make([]string, 0)
if replicaIdentity == ReplicaIdentityIndex {
return c.getReplicaIdentityIndexColumns(relID, schemaTable)
}

// Find the primary key index OID
var pkIndexOID oid.Oid
err = c.pool.QueryRow(c.ctx,
`SELECT indexrelid FROM pg_index WHERE indrelid = $1 AND indisprimary`,
relID).Scan(&pkIndexOID)
if err != nil {
return nil, fmt.Errorf("error finding primary key index for table %s: %w", schemaTable, err)
}

return c.getColumnNamesForIndex(pkIndexOID)
}

// getReplicaIdentityIndexColumns returns the columns used in the replica identity index.
func (c *PostgresConnector) getReplicaIdentityIndexColumns(relID uint32, schemaTable *utils.SchemaTable) ([]string, error) {
var indexRelID oid.Oid
// Fetch the OID of the index used as the replica identity
err := c.pool.QueryRow(c.ctx,
`SELECT indexrelid FROM pg_index
WHERE indrelid = $1 AND indisreplident = true`,
relID).Scan(&indexRelID)
if err != nil {
return nil, fmt.Errorf("error finding replica identity index for table %s: %w", schemaTable, err)
}

return c.getColumnNamesForIndex(indexRelID)
}

// getColumnNamesForIndex returns the column names for a given index.
func (c *PostgresConnector) getColumnNamesForIndex(indexOID oid.Oid) ([]string, error) {
var col pgtype.Text
cols := make([]string, 0)
rows, err := c.pool.Query(c.ctx,
`SELECT a.attname FROM pg_index i
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = $1 AND i.indisprimary ORDER BY a.attname ASC`,
relID)
WHERE i.indexrelid = $1 ORDER BY a.attname ASC`,
indexOID)
if err != nil {
return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err)
return nil, fmt.Errorf("error getting columns for index %v: %w", indexOID, err)
}
defer rows.Close()
for {
if !rows.Next() {
break
}
err = rows.Scan(&pkCol)

for rows.Next() {
err = rows.Scan(&col)
if err != nil {
return nil, fmt.Errorf("error scanning primary key column for table %s: %w", schemaTable, err)
return nil, fmt.Errorf("error scanning column for index %v: %w", indexOID, err)
}
pkCols = append(pkCols, pkCol.String)
cols = append(cols, col.String)
}

return pkCols, nil
return cols, nil
}

func (c *PostgresConnector) tableExists(schemaTable *utils.SchemaTable) (bool, error) {
Expand Down
12 changes: 6 additions & 6 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,12 +558,12 @@ func (c *PostgresConnector) getTableSchemaForTable(
return nil, err
}

isFullReplica, replErr := c.isTableFullReplica(schemaTable)
replicaIdentityType, replErr := c.getReplicaIdentityType(schemaTable)
if replErr != nil {
return nil, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, replErr)
}

pKeyCols, err := c.getPrimaryKeyColumns(schemaTable)
pKeyCols, err := c.getPrimaryKeyColumns(replicaIdentityType, schemaTable)
if err != nil {
return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err)
}
Expand All @@ -581,7 +581,7 @@ func (c *PostgresConnector) getTableSchemaForTable(
TableIdentifier: tableName,
Columns: make(map[string]string),
PrimaryKeyColumns: pKeyCols,
IsReplicaIdentityFull: isFullReplica,
IsReplicaIdentityFull: replicaIdentityType == ReplicaIdentityFull,
}

for _, fieldDescription := range rows.FieldDescriptions() {
Expand Down Expand Up @@ -731,18 +731,18 @@ func (c *PostgresConnector) EnsurePullability(req *protos.EnsurePullabilityBatch
return nil, err
}

isFullReplica, replErr := c.isTableFullReplica(schemaTable)
replicaIdentity, replErr := c.getReplicaIdentityType(schemaTable)
if replErr != nil {
return nil, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, replErr)
}

pKeyCols, err := c.getPrimaryKeyColumns(schemaTable)
pKeyCols, err := c.getPrimaryKeyColumns(replicaIdentity, schemaTable)
if err != nil {
return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err)
}

// we only allow no primary key if the table has REPLICA IDENTITY FULL
if len(pKeyCols) == 0 && !isFullReplica {
if len(pKeyCols) == 0 && !(replicaIdentity == ReplicaIdentityFull) {
return nil, fmt.Errorf("table %s has no primary keys and does not have REPLICA IDENTITY FULL", schemaTable)
}

Expand Down
66 changes: 66 additions & 0 deletions flow/e2e/snowflake/peer_flow_sf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,72 @@ func (s PeerFlowE2ETestSuiteSF) Test_Complete_Simple_Flow_SF() {
env.AssertExpectations(s.t)
}

func (s PeerFlowE2ETestSuiteSF) Test_Flow_ReplicaIdentity_Index_No_Pkey() {
env := e2e.NewTemporalTestWorkflowEnvironment()
e2e.RegisterWorkflowsAndActivities(env, s.t)

srcTableName := s.attachSchemaSuffix("test_replica_identity_no_pkey")
dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test_replica_identity_no_pkey")

// Create a table without a primary key and create a named unique index
_, err := s.pool.Exec(context.Background(), fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id SERIAL,
key TEXT NOT NULL,
value TEXT NOT NULL
);
CREATE UNIQUE INDEX unique_idx_on_id_key ON %s (id, key);
ALTER TABLE %s REPLICA IDENTITY USING INDEX unique_idx_on_id_key;
`, srcTableName, srcTableName, srcTableName))
require.NoError(s.t, err)

connectionGen := e2e.FlowConnectionGenerationConfig{
FlowJobName: s.attachSuffix("test_simple_flow"),
TableNameMapping: map[string]string{srcTableName: dstTableName},
PostgresPort: e2e.PostgresPort,
Destination: s.sfHelper.Peer,
}

flowConnConfig, err := connectionGen.GenerateFlowConnectionConfigs()
require.NoError(s.t, err)

limits := peerflow.CDCFlowLimits{
ExitAfterRecords: 20,
MaxBatchSize: 100,
}

// in a separate goroutine, wait for PeerFlowStatusQuery to finish setup
// and then insert 20 rows into the source table
go func() {
e2e.SetupCDCFlowStatusQuery(env, connectionGen)
// insert 20 rows into the source table
for i := 0; i < 20; i++ {
testKey := fmt.Sprintf("test_key_%d", i)
testValue := fmt.Sprintf("test_value_%d", i)
_, err = s.pool.Exec(context.Background(), fmt.Sprintf(`
INSERT INTO %s (id, key, value) VALUES ($1, $2, $3)
`, srcTableName), i, testKey, testValue)
require.NoError(s.t, err)
}
fmt.Println("Inserted 20 rows into the source table")
}()

env.ExecuteWorkflow(peerflow.CDCFlowWorkflowWithConfig, flowConnConfig, &limits, nil)

// Verify workflow completes without error
s.True(env.IsWorkflowCompleted())
err = env.GetWorkflowError()

// allow only continue as new error
require.Contains(s.t, err.Error(), "continue as new")

count, err := s.sfHelper.CountRows("test_replica_identity_no_pkey")
require.NoError(s.t, err)
s.Equal(20, count)

env.AssertExpectations(s.t)
}

func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Geo_SF_Avro_CDC() {
env := e2e.NewTemporalTestWorkflowEnvironment()
e2e.RegisterWorkflowsAndActivities(env, s.t)
Expand Down

0 comments on commit 8e4e7ba

Please sign in to comment.