Skip to content

Commit

Permalink
Merge branch 'main' into snowflake-encrypted-pkey
Browse files Browse the repository at this point in the history
  • Loading branch information
heavycrystal authored Sep 18, 2023
2 parents 212543c + 54f8948 commit ee69cff
Show file tree
Hide file tree
Showing 10 changed files with 368 additions and 304 deletions.
70 changes: 40 additions & 30 deletions flow/connectors/postgres/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,45 +218,55 @@ func (p *PostgresCDCSource) consumeStream(
switch r := rec.(type) {
case *model.UpdateRecord:
// tableName here is destination tableName.
// should be ideally sourceTableName as we are in pullRecrods.
// should be ideally sourceTableName as we are in PullRecords.
// will change in future
pkeyCol := req.TableNameSchemaMapping[tableName].PrimaryKeyColumn
pkeyColVal, err := rec.GetItems().GetValueByColName(pkeyCol)
if err != nil {
return nil, fmt.Errorf("error getting pkey column value: %w", err)
}
isFullReplica := req.TableNameSchemaMapping[tableName].IsReplicaIdentityFull
if isFullReplica {
records.Records = append(records.Records, rec)
} else {
pkeyCol := req.TableNameSchemaMapping[tableName].PrimaryKeyColumn
pkeyColVal, err := rec.GetItems().GetValueByColName(pkeyCol)
if err != nil {
return nil, fmt.Errorf("error getting pkey column value: %w", err)
}

tablePkeyVal := model.TableWithPkey{
TableName: tableName,
PkeyColVal: *pkeyColVal,
tablePkeyVal := model.TableWithPkey{
TableName: tableName,
PkeyColVal: *pkeyColVal,
}
_, ok := records.TablePKeyLastSeen[tablePkeyVal]
if !ok {
records.Records = append(records.Records, rec)
records.TablePKeyLastSeen[tablePkeyVal] = len(records.Records) - 1
} else {
oldRec := records.Records[records.TablePKeyLastSeen[tablePkeyVal]]
// iterate through unchanged toast cols and set them in new record
updatedCols := r.NewItems.UpdateIfNotExists(oldRec.GetItems())
for _, col := range updatedCols {
delete(r.UnchangedToastColumns, col)
}
records.Records = append(records.Records, rec)
records.TablePKeyLastSeen[tablePkeyVal] = len(records.Records) - 1
}
}
_, ok := records.TablePKeyLastSeen[tablePkeyVal]
if !ok {
case *model.InsertRecord:
isFullReplica := req.TableNameSchemaMapping[tableName].IsReplicaIdentityFull
if isFullReplica {
records.Records = append(records.Records, rec)
records.TablePKeyLastSeen[tablePkeyVal] = len(records.Records) - 1
} else {
oldRec := records.Records[records.TablePKeyLastSeen[tablePkeyVal]]
// iterate through unchanged toast cols and set them in new record
updatedCols := r.NewItems.UpdateIfNotExists(oldRec.GetItems())
for _, col := range updatedCols {
delete(r.UnchangedToastColumns, col)
pkeyCol := req.TableNameSchemaMapping[tableName].PrimaryKeyColumn
pkeyColVal, err := rec.GetItems().GetValueByColName(pkeyCol)
if err != nil {
return nil, fmt.Errorf("error getting pkey column value: %w", err)
}
tablePkeyVal := model.TableWithPkey{
TableName: tableName,
PkeyColVal: *pkeyColVal,
}
records.Records = append(records.Records, rec)
// all columns will be set in insert record, so add it to the map
records.TablePKeyLastSeen[tablePkeyVal] = len(records.Records) - 1
}
case *model.InsertRecord:
pkeyCol := req.TableNameSchemaMapping[tableName].PrimaryKeyColumn
pkeyColVal, err := rec.GetItems().GetValueByColName(pkeyCol)
if err != nil {
return nil, fmt.Errorf("error getting pkey column value: %w", err)
}
tablePkeyVal := model.TableWithPkey{
TableName: tableName,
PkeyColVal: *pkeyColVal,
}
records.Records = append(records.Records, rec)
// all columns will be set in insert record, so add it to the map
records.TablePKeyLastSeen[tablePkeyVal] = len(records.Records) - 1
case *model.DeleteRecord:
records.Records = append(records.Records, rec)
case *model.RelationRecord:
Expand Down
8 changes: 4 additions & 4 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,20 @@ func (c *PostgresConnector) getRelIDForTable(schemaTable *SchemaTable) (uint32,
}

// isTableFullReplica reports whether the table's replica identity is FULL (relreplident = 'f').
func (c *PostgresConnector) getReplicaIdentityForTable(schemaTable *SchemaTable) (string, error) {
func (c *PostgresConnector) isTableFullReplica(schemaTable *SchemaTable) (bool, error) {
relID, relIDErr := c.getRelIDForTable(schemaTable)
if relIDErr != nil {
return "", fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, relIDErr)
return false, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, relIDErr)
}

var replicaIdentity rune
err := c.pool.QueryRow(c.ctx,
`SELECT relreplident FROM pg_class WHERE oid = $1;`,
relID).Scan(&replicaIdentity)
if err != nil {
return "", fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, err)
return false, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, err)
}
return string(replicaIdentity), nil
return string(replicaIdentity) == "f", nil
}

// getPrimaryKeyColumn returns the primary key column for a given table
Expand Down
15 changes: 10 additions & 5 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -566,18 +566,23 @@ func (c *PostgresConnector) getTableSchemaForTable(
}
defer rows.Close()

isFullReplica, replErr := c.isTableFullReplica(schemaTable)
if replErr != nil {
return nil, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, replErr)
}

pkey, err := c.getPrimaryKeyColumn(schemaTable)
if err != nil {
replicaIdentity, err := c.getReplicaIdentityForTable(schemaTable)
if err != nil || replicaIdentity != "f" {
if !isFullReplica {
return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err)
}
}

res := &protos.TableSchema{
TableIdentifier: tableName,
Columns: make(map[string]string),
PrimaryKeyColumn: pkey,
TableIdentifier: tableName,
Columns: make(map[string]string),
PrimaryKeyColumn: pkey,
IsReplicaIdentityFull: isFullReplica,
}

for _, fieldDescription := range rows.FieldDescriptions() {
Expand Down
10 changes: 5 additions & 5 deletions flow/e2e/postgres/peer_flow_pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (s *PeerFlowE2ETestSuitePG) Test_Simple_Flow_PG() {
s.Error(err)
s.Contains(err.Error(), "continue as new")

err = s.comparePGTables(srcTableName, dstTableName)
err = s.comparePGTables(srcTableName, dstTableName, "id,key,value")
s.NoError(err)

env.AssertExpectations(s.T())
Expand Down Expand Up @@ -121,7 +121,7 @@ func (s *PeerFlowE2ETestSuitePG) Test_Simple_Schema_Changes_PG() {

// verify we got our first row.
e2e.NormalizeFlowCountQuery(env, connectionGen, 2)
err = s.comparePGTables(srcTableName, dstTableName)
err = s.comparePGTables(srcTableName, dstTableName, "id,c1")
s.NoError(err)

// alter source table, add column c2 and insert another row.
Expand All @@ -136,7 +136,7 @@ func (s *PeerFlowE2ETestSuitePG) Test_Simple_Schema_Changes_PG() {

// verify we got our two rows, if schema did not match up it will error.
e2e.NormalizeFlowCountQuery(env, connectionGen, 4)
err = s.comparePGTables(srcTableName, dstTableName)
err = s.comparePGTables(srcTableName, dstTableName, "id,c1")
s.NoError(err)

// alter source table, add column c3, drop column c2 and insert another row.
Expand All @@ -151,7 +151,7 @@ func (s *PeerFlowE2ETestSuitePG) Test_Simple_Schema_Changes_PG() {

// verify we got our two rows, if schema did not match up it will error.
e2e.NormalizeFlowCountQuery(env, connectionGen, 6)
err = s.comparePGTables(srcTableName, dstTableName)
err = s.comparePGTables(srcTableName, dstTableName, "id,c1")
s.NoError(err)

// alter source table, drop column c3 and insert another row.
Expand All @@ -166,7 +166,7 @@ func (s *PeerFlowE2ETestSuitePG) Test_Simple_Schema_Changes_PG() {

// verify we got our two rows, if schema did not match up it will error.
e2e.NormalizeFlowCountQuery(env, connectionGen, 8)
err = s.comparePGTables(srcTableName, dstTableName)
err = s.comparePGTables(srcTableName, dstTableName, "id,c1")
s.NoError(err)
}()

Expand Down
13 changes: 7 additions & 6 deletions flow/e2e/postgres/qrep_flow_pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ func (s *PeerFlowE2ETestSuitePG) setupSourceTable(tableName string, rowCount int
s.NoError(err)
}

func (s *PeerFlowE2ETestSuitePG) comparePGTables(srcSchemaQualified, dstSchemaQualified string) error {
func (s *PeerFlowE2ETestSuitePG) comparePGTables(srcSchemaQualified, dstSchemaQualified, selector string) error {
// Execute the two EXCEPT queries
for {
err := s.compareQuery(srcSchemaQualified, dstSchemaQualified)
err := s.compareQuery(srcSchemaQualified, dstSchemaQualified, selector)
// while testing, the prepared plan might break due to schema changes
// solution is to retry, prepared statement should be evicted upon the first error
if err != nil && !strings.Contains(err.Error(), "cached plan must not change result type") {
Expand All @@ -78,7 +78,7 @@ func (s *PeerFlowE2ETestSuitePG) comparePGTables(srcSchemaQualified, dstSchemaQu
}

for {
err := s.compareQuery(dstSchemaQualified, srcSchemaQualified)
err := s.compareQuery(dstSchemaQualified, srcSchemaQualified, selector)
// while testing, the prepared plan might break due to schema changes
// solution is to retry, prepared statement should be evicted upon the first error
if err != nil && !strings.Contains(err.Error(), "cached plan must not change result type") {
Expand All @@ -93,8 +93,9 @@ func (s *PeerFlowE2ETestSuitePG) comparePGTables(srcSchemaQualified, dstSchemaQu
return nil
}

func (s *PeerFlowE2ETestSuitePG) compareQuery(schema1, schema2 string) error {
query := fmt.Sprintf("SELECT * FROM %s EXCEPT SELECT * FROM %s", schema1, schema2)
func (s *PeerFlowE2ETestSuitePG) compareQuery(srcSchemaQualified, dstSchemaQualified, selector string) error {
query := fmt.Sprintf("SELECT %s FROM %s EXCEPT SELECT %s FROM %s", selector, srcSchemaQualified,
selector, dstSchemaQualified)
rows, _ := s.pool.Query(context.Background(), query)
rowsPresent := false

Expand Down Expand Up @@ -163,7 +164,7 @@ func (s *PeerFlowE2ETestSuitePG) Test_Complete_QRep_Flow_Multi_Insert_PG() {
err = env.GetWorkflowError()
s.NoError(err)

err = s.comparePGTables(srcSchemaQualified, dstSchemaQualified)
err = s.comparePGTables(srcSchemaQualified, dstSchemaQualified, "*")
if err != nil {
s.FailNow(err.Error())
}
Expand Down
Loading

0 comments on commit ee69cff

Please sign in to comment.