Skip to content

Commit

Permalink
added tests for PG and SF schema changes, fixed edge cases and quoting
Browse files (browse the repository at this point in the history)
  • Loading branch information
heavycrystal committed Oct 4, 2023
1 parent b181cff commit 03a8f0a
Show file tree
Hide file tree
Showing 9 changed files with 902 additions and 61 deletions.
7 changes: 4 additions & 3 deletions flow/connectors/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func (c *BigQueryConnector) ReplayTableSchemaDelta(flowJobName string,
}

for _, droppedColumn := range schemaDelta.DroppedColumns {
_, err := c.client.Query(fmt.Sprintf("ALTER TABLE %s.%s DROP COLUMN %s", c.datasetID,
_, err := c.client.Query(fmt.Sprintf("ALTER TABLE %s.%s DROP COLUMN `%s`", c.datasetID,
schemaDelta.DstTableName, droppedColumn)).Read(c.ctx)
if err != nil {
return fmt.Errorf("failed to drop column %s for table %s: %w", droppedColumn,
Expand All @@ -226,8 +226,9 @@ func (c *BigQueryConnector) ReplayTableSchemaDelta(flowJobName string,
}).Infof("[schema delta replay] dropped column %s", droppedColumn)
}
for _, addedColumn := range schemaDelta.AddedColumns {
_, err := c.client.Query(fmt.Sprintf("ALTER TABLE %s.%s ADD COLUMN %s %s", c.datasetID,
schemaDelta.DstTableName, addedColumn.ColumnName, addedColumn.ColumnType)).Read(c.ctx)
_, err := c.client.Query(fmt.Sprintf("ALTER TABLE %s.%s ADD COLUMN `%s` %s", c.datasetID,
schemaDelta.DstTableName, addedColumn.ColumnName,
qValueKindToBigQueryType(addedColumn.ColumnType))).Read(c.ctx)
if err != nil {
return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName,
schemaDelta.SrcTableName, err)
Expand Down
20 changes: 14 additions & 6 deletions flow/connectors/postgres/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,13 +545,13 @@ func (p *PostgresCDCSource) processRelationMessage(
// retrieve initial RelationMessage for table changed.
prevRel := p.relationMessageMapping[currRel.RelationId]
// creating maps for lookup later
prevRelMap := make(map[string]bool)
currRelMap := make(map[string]bool)
prevRelMap := make(map[string]*uint32)
currRelMap := make(map[string]*uint32)
for _, column := range prevRel.Columns {
prevRelMap[column.Name] = true
prevRelMap[column.Name] = &column.DataType
}
for _, column := range currRel.Columns {
currRelMap[column.Name] = true
currRelMap[column.Name] = &column.DataType
}

schemaDelta := &protos.TableSchemaDelta{
Expand All @@ -564,7 +564,15 @@ func (p *PostgresCDCSource) processRelationMessage(
}
for _, column := range currRel.Columns {
// not present in previous relation message, but in current one, so added.
if !prevRelMap[column.Name] {
if prevRelMap[column.Name] == nil {
schemaDelta.AddedColumns = append(schemaDelta.AddedColumns, &protos.DeltaAddedColumn{
ColumnName: column.Name,
ColumnType: string(postgresOIDToQValueKind(column.DataType)),
})
// present in both the previous and current relation messages, but the data type has changed,
// so we add it to both DroppedColumns and AddedColumns, relying on DroppedColumns being processed first.
} else if *prevRelMap[column.Name] != *currRelMap[column.Name] {
schemaDelta.DroppedColumns = append(schemaDelta.DroppedColumns, column.Name)
schemaDelta.AddedColumns = append(schemaDelta.AddedColumns, &protos.DeltaAddedColumn{
ColumnName: column.Name,
ColumnType: string(postgresOIDToQValueKind(column.DataType)),
Expand All @@ -573,7 +581,7 @@ func (p *PostgresCDCSource) processRelationMessage(
}
for _, column := range prevRel.Columns {
// present in previous relation message, but not in current one, so dropped.
if !currRelMap[column.Name] {
if currRelMap[column.Name] == nil {
schemaDelta.DroppedColumns = append(schemaDelta.DroppedColumns, column.Name)
}
}
Expand Down
4 changes: 2 additions & 2 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,10 +516,10 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier st
pgType := qValueKindToPostgresType(genericColumnType)
if strings.Contains(genericColumnType, "array") {
flattenedCastsSQLArray = append(flattenedCastsSQLArray,
fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS %s",
fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"",
strings.Trim(columnName, "\""), pgType, columnName))
} else {
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS %s",
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"",
strings.Trim(columnName, "\""), pgType, columnName))
}
if normalizedTableSchema.PrimaryKeyColumn == columnName {
Expand Down
38 changes: 20 additions & 18 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,14 +560,6 @@ func (c *PostgresConnector) getTableSchemaForTable(
return nil, err
}

// Get the column names and types
rows, err := c.pool.Query(c.ctx,
fmt.Sprintf(`SELECT * FROM %s LIMIT 0`, tableName))
if err != nil {
return nil, fmt.Errorf("error getting table schema for table %s: %w", schemaTable, err)
}
defer rows.Close()

isFullReplica, replErr := c.isTableFullReplica(schemaTable)
if replErr != nil {
return nil, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, replErr)
Expand All @@ -587,6 +579,14 @@ func (c *PostgresConnector) getTableSchemaForTable(
IsReplicaIdentityFull: isFullReplica,
}

// Get the column names and types
rows, err := c.pool.Query(c.ctx,
fmt.Sprintf(`SELECT * FROM %s LIMIT 0`, tableName), pgx.QueryExecModeSimpleProtocol)
if err != nil {
return nil, fmt.Errorf("error getting table schema for table %s: %w", schemaTable, err)
}
defer rows.Close()

for _, fieldDescription := range rows.FieldDescriptions() {
genericColType := postgresOIDToQValueKind(fieldDescription.DataTypeOID)
if genericColType == qvalue.QValueKindInvalid {
Expand Down Expand Up @@ -676,7 +676,7 @@ func (c *PostgresConnector) ReplayTableSchemaDelta(flowJobName string, schemaDel
tableSchemaModifyTx, err := c.pool.Begin(c.ctx)
if err != nil {
return fmt.Errorf("error starting transaction for schema modification for table %s: %w",
schemaDelta.SrcTableName, err)
schemaDelta.DstTableName, err)
}
defer func() {
deferErr := tableSchemaModifyTx.Rollback(c.ctx)
Expand All @@ -688,36 +688,38 @@ func (c *PostgresConnector) ReplayTableSchemaDelta(flowJobName string, schemaDel
}()

for _, droppedColumn := range schemaDelta.DroppedColumns {
_, err = tableSchemaModifyTx.Exec(c.ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", schemaDelta.DstTableName,
_, err = tableSchemaModifyTx.Exec(c.ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN \"%s\"", schemaDelta.DstTableName,
droppedColumn))
if err != nil {
return fmt.Errorf("failed to drop column %s for table %s: %w", droppedColumn,
schemaDelta.SrcTableName, err)
schemaDelta.DstTableName, err)
}
log.WithFields(log.Fields{
"flowName": flowJobName,
"tableName": schemaDelta.SrcTableName,
"flowName": flowJobName,
"srcTableName": schemaDelta.SrcTableName,
"dstTableName": schemaDelta.DstTableName,
}).Infof("[schema delta replay] dropped column %s", droppedColumn)
}
for _, addedColumn := range schemaDelta.AddedColumns {
_, err = tableSchemaModifyTx.Exec(c.ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s",
_, err = tableSchemaModifyTx.Exec(c.ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN \"%s\" %s",
schemaDelta.DstTableName, addedColumn.ColumnName,
qValueKindToPostgresType(addedColumn.ColumnType)))
if err != nil {
return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName,
schemaDelta.SrcTableName, err)
schemaDelta.DstTableName, err)
}
log.WithFields(log.Fields{
"flowName": flowJobName,
"tableName": schemaDelta.SrcTableName,
"flowName": flowJobName,
"srcTableName": schemaDelta.SrcTableName,
"dstTableName": schemaDelta.DstTableName,
}).Infof("[schema delta replay] added column %s with data type %s",
addedColumn.ColumnName, addedColumn.ColumnType)
}

err = tableSchemaModifyTx.Commit(c.ctx)
if err != nil {
return fmt.Errorf("failed to commit transaction for table schema modification for table %s: %w",
schemaDelta.SrcTableName, err)
schemaDelta.DstTableName, err)
}

return nil
Expand Down
4 changes: 1 addition & 3 deletions flow/connectors/postgres/postgres_cdc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,6 @@ func (suite *PostgresCDCTestSuite) validateMutatedToastRecords(records []model.R
}

func (suite *PostgresCDCTestSuite) SetupSuite() {
rand.Seed(time.Now().UnixNano())

var err error
suite.connector, err = NewPostgresConnector(context.Background(), &protos.PostgresConfig{
Host: "localhost",
Expand Down Expand Up @@ -828,6 +826,6 @@ func (suite *PostgresCDCTestSuite) TestToastHappyFlow() {
suite.dropTable(toastHappyFlowSrcTableName)
}

func TestPostgresTestSuite(t *testing.T) {
func TestPostgresCDCTestSuite(t *testing.T) {
suite.Run(t, new(PostgresCDCTestSuite))
}
Loading

0 comments on commit 03a8f0a

Please sign in to comment.