Commit a4d6e28
rm connectors/utils/columns.go
Introduced while schemas had to support map[string]string
alongside ColumnNames/ColumnTypes, no longer needed

Also clean up validate_mirror config error timing
serprex committed Jan 24, 2024
1 parent 8eb2567 commit a4d6e28
Showing 11 changed files with 66 additions and 98 deletions.
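
Every file below follows the same mechanical pattern: a callback-style helper from the deleted connectors/utils/columns.go is replaced by a direct range over the schema's parallel ColumnNames/ColumnTypes slices. A minimal runnable sketch of the before/after shape — the TableSchema stub here is illustrative, not the real protos.TableSchema:

package main

import "fmt"

// TableSchema stands in for protos.TableSchema; only the two parallel
// slices that the diffs below rely on are modeled.
type TableSchema struct {
    ColumnNames []string
    ColumnTypes []string
}

// IterColumns mirrors the deleted helper: it pairs up the parallel
// slices and hands each (name, type) to a callback.
func IterColumns(schema *TableSchema, f func(col, colType string)) {
    for i, name := range schema.ColumnNames {
        f(name, schema.ColumnTypes[i])
    }
}

func main() {
    schema := &TableSchema{
        ColumnNames: []string{"id", "val"},
        ColumnTypes: []string{"numeric", "string"},
    }

    // Before: callback indirection through the helper.
    IterColumns(schema, func(col, colType string) {
        fmt.Printf("%s %s\n", col, colType)
    })

    // After: index the parallel slices directly; no helper needed,
    // and ordinary control flow (continue, early return) just works.
    for i, col := range schema.ColumnNames {
        fmt.Printf("%s %s\n", col, schema.ColumnTypes[i])
    }
}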
15 changes: 7 additions & 8 deletions flow/cmd/validate_mirror.go
@@ -12,21 +12,20 @@ import (
 func (h *FlowRequestHandler) ValidateCDCMirror(
     ctx context.Context, req *protos.CreateCDCFlowRequest,
 ) (*protos.ValidateCDCMirrorResponse, error) {
-    pgPeer, err := connpostgres.NewPostgresConnector(ctx, req.ConnectionConfigs.Source.GetPostgresConfig())
+    sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig()
+    if sourcePeerConfig == nil {
+        slog.Error("/validatecdc source peer config is nil", slog.Any("peer", req.ConnectionConfigs.Source))
+        return nil, fmt.Errorf("source peer config is nil")
+    }
+
+    pgPeer, err := connpostgres.NewPostgresConnector(ctx, sourcePeerConfig)
     if err != nil {
         return &protos.ValidateCDCMirrorResponse{
             Ok: false,
         }, fmt.Errorf("failed to create postgres connector: %v", err)
     }
-
     defer pgPeer.Close()
 
-    sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig()
-    if sourcePeerConfig == nil {
-        slog.Error("/validatecdc source peer config is nil", slog.Any("peer", req.ConnectionConfigs.Source))
-        return nil, fmt.Errorf("source peer config is nil")
-    }
-
     // Check permissions of postgres peer
     err = pgPeer.CheckReplicationPermissions(sourcePeerConfig.User)
     if err != nil {
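
This hunk is the "validate_mirror config error timing" cleanup from the commit message: the nil-config guard now runs before NewPostgresConnector, so a missing Postgres config is reported as the root cause instead of surfacing as a connector-construction failure (the old guard sat after the construction it was meant to protect). A minimal runnable sketch of the guard-first shape, with hypothetical stand-ins for the connector types:

package main

import (
    "errors"
    "fmt"
)

type pgConfig struct{ user string }

// newConnector stands in for connpostgres.NewPostgresConnector.
func newConnector(cfg *pgConfig) (string, error) {
    if cfg == nil {
        // The old ordering funneled nil configs into this generic failure.
        return "", errors.New("failed to create postgres connector")
    }
    return "connected as " + cfg.user, nil
}

func validate(cfg *pgConfig) (string, error) {
    // Guard first: report the precise cause before building anything.
    if cfg == nil {
        return "", errors.New("source peer config is nil")
    }
    return newConnector(cfg)
}

func main() {
    if _, err := validate(nil); err != nil {
        fmt.Println(err) // prints the precise cause, not the generic error
    }
}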
7 changes: 4 additions & 3 deletions flow/connectors/bigquery/bigquery.go
@@ -810,13 +810,14 @@ func (c *BigQueryConnector) SetupNormalizedTables(
 
     // convert the column names and types to bigquery types
     columns := make([]*bigquery.FieldSchema, 0, len(tableSchema.ColumnNames)+2)
-    utils.IterColumns(tableSchema, func(colName, genericColType string) {
+    for i, colName := range tableSchema.ColumnNames {
+        genericColType := tableSchema.ColumnTypes[i]
         columns = append(columns, &bigquery.FieldSchema{
             Name:     colName,
             Type:     qValueKindToBigQueryType(genericColType),
             Repeated: qvalue.QValueKind(genericColType).IsArray(),
         })
-    })
+    }
 
     if req.SoftDeleteColName != "" {
         columns = append(columns, &bigquery.FieldSchema{
@@ -908,7 +909,7 @@ func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos
             dstDatasetTable.string()))
 
         if req.SoftDeleteColName != nil {
-            allCols := strings.Join(utils.TableSchemaColumnNames(renameRequest.TableSchema), ",")
+            allCols := strings.Join(renameRequest.TableSchema.ColumnNames, ",")
             pkeyCols := strings.Join(renameRequest.TableSchema.PrimaryKeyColumns, ",")
 
             c.logger.InfoContext(c.ctx, fmt.Sprintf("handling soft-deletes for table '%s'...", dstDatasetTable.string()))
4 changes: 2 additions & 2 deletions flow/connectors/bigquery/merge_stmt_generator.go
@@ -34,7 +34,7 @@ type mergeStmtGenerator struct
 func (m *mergeStmtGenerator) generateFlattenedCTE() string {
     // for each column in the normalized table, generate CAST + JSON_EXTRACT_SCALAR
     // statement.
-    flattenedProjs := make([]string, 0, utils.TableSchemaColumns(m.normalizedTableSchema)+3)
+    flattenedProjs := make([]string, 0, len(m.normalizedTableSchema.ColumnNames)+3)
 
     for i, colName := range m.normalizedTableSchema.ColumnNames {
         colType := m.normalizedTableSchema.ColumnTypes[i]
@@ -124,7 +124,7 @@ func (m *mergeStmtGenerator) generateDeDupedCTE() string {
 // generateMergeStmt generates a merge statement.
 func (m *mergeStmtGenerator) generateMergeStmt(unchangedToastColumns []string) string {
     // comma separated list of column names
-    columnCount := utils.TableSchemaColumns(m.normalizedTableSchema)
+    columnCount := len(m.normalizedTableSchema.ColumnNames)
     backtickColNames := make([]string, 0, columnCount)
     shortBacktickColNames := make([]string, 0, columnCount)
     pureColNames := make([]string, 0, columnCount)
7 changes: 4 additions & 3 deletions flow/connectors/postgres/client.go
@@ -424,11 +424,12 @@ func generateCreateTableSQLForNormalizedTable(
     softDeleteColName string,
     syncedAtColName string,
 ) string {
-    createTableSQLArray := make([]string, 0, utils.TableSchemaColumns(sourceTableSchema)+2)
-    utils.IterColumns(sourceTableSchema, func(columnName, genericColumnType string) {
+    createTableSQLArray := make([]string, 0, len(sourceTableSchema.ColumnNames)+2)
+    for i, columnName := range sourceTableSchema.ColumnNames {
+        genericColumnType := sourceTableSchema.ColumnTypes[i]
         createTableSQLArray = append(createTableSQLArray,
             fmt.Sprintf("%s %s", QuoteIdentifier(columnName), qValueKindToPostgresType(genericColumnType)))
-    })
+    }
 
     if softDeleteColName != "" {
         createTableSQLArray = append(createTableSQLArray,
30 changes: 16 additions & 14 deletions flow/connectors/postgres/normalize_stmt_generator.go
@@ -44,11 +44,12 @@ func (n *normalizeStmtGenerator) generateNormalizeStatements() []string {
 }
 
 func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
-    columnCount := utils.TableSchemaColumns(n.normalizedTableSchema)
+    columnCount := len(n.normalizedTableSchema.ColumnNames)
     columnNames := make([]string, 0, columnCount)
     flattenedCastsSQLArray := make([]string, 0, columnCount)
     primaryKeyColumnCasts := make(map[string]string, len(n.normalizedTableSchema.PrimaryKeyColumns))
-    utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) {
+    for i, columnName := range n.normalizedTableSchema.ColumnNames {
+        genericColumnType := n.normalizedTableSchema.ColumnTypes[i]
         quotedCol := QuoteIdentifier(columnName)
         stringCol := QuoteLiteral(columnName)
         columnNames = append(columnNames, quotedCol)
@@ -64,16 +65,16 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
         if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) {
             primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>%s)::%s", stringCol, pgType)
         }
-    })
+    }
     flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
     parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName)
 
     insertColumnsSQL := strings.Join(columnNames, ",")
-    updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema))
-    utils.IterColumns(n.normalizedTableSchema, func(columnName, _ string) {
+    updateColumnsSQLArray := make([]string, 0, columnCount)
+    for _, columnName := range n.normalizedTableSchema.ColumnNames {
         quotedCol := QuoteIdentifier(columnName)
         updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`%s=EXCLUDED.%s`, quotedCol, quotedCol))
-    })
+    }
     updateColumnsSQL := strings.Join(updateColumnsSQLArray, ",")
     deleteWhereClauseArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
     for columnName, columnCast := range primaryKeyColumnCasts {
@@ -104,19 +105,20 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
 }
 
 func (n *normalizeStmtGenerator) generateMergeStatement() string {
-    quotedColumnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema)
-    for i, columnName := range quotedColumnNames {
-        quotedColumnNames[i] = QuoteIdentifier(columnName)
-    }
-    flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema))
+    columnCount := len(n.normalizedTableSchema.ColumnNames)
+    quotedColumnNames := make([]string, columnCount)
+
+    flattenedCastsSQLArray := make([]string, 0, columnCount)
     parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName)
 
     primaryKeyColumnCasts := make(map[string]string)
     primaryKeySelectSQLArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
-    utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) {
+    for i, columnName := range n.normalizedTableSchema.ColumnNames {
+        genericColumnType := n.normalizedTableSchema.ColumnTypes[i]
         quotedCol := QuoteIdentifier(columnName)
         stringCol := QuoteLiteral(columnName)
+        quotedColumnNames[i] = quotedCol
 
         pgType := qValueKindToPostgresType(genericColumnType)
         if qvalue.QValueKind(genericColumnType).IsArray() {
             flattenedCastsSQLArray = append(flattenedCastsSQLArray,
@@ -131,9 +133,9 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string {
         primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s",
             quotedCol, quotedCol))
         }
-    })
+    }
     flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
-    insertValuesSQLArray := make([]string, 0, len(quotedColumnNames)+2)
+    insertValuesSQLArray := make([]string, 0, columnCount+2)
     for _, quotedCol := range quotedColumnNames {
         insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", quotedCol))
     }
19 changes: 9 additions & 10 deletions flow/connectors/postgres/postgres_schema_delta_test.go
@@ -9,7 +9,6 @@ import (
     "github.com/jackc/pgx/v5"
     "github.com/stretchr/testify/require"
 
-    "github.com/PeerDB-io/peer-flow/connectors/utils"
     "github.com/PeerDB-io/peer-flow/e2eshared"
     "github.com/PeerDB-io/peer-flow/generated/protos"
     "github.com/PeerDB-io/peer-flow/model/qvalue"
@@ -121,14 +120,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddAllColumnTypes() {
         PrimaryKeyColumns: []string{"id"},
     }
     addedColumns := make([]*protos.DeltaAddedColumn, 0)
-    utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+    for i, columnName := range expectedTableSchema.ColumnNames {
         if columnName != "id" {
             addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
                 ColumnName: columnName,
-                ColumnType: columnType,
+                ColumnType: expectedTableSchema.ColumnTypes[i],
             })
         }
-    })
+    }
 
     err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
         SrcTableName: tableName,
@@ -171,14 +170,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddTrickyColumnNames() {
         PrimaryKeyColumns: []string{"id"},
     }
     addedColumns := make([]*protos.DeltaAddedColumn, 0)
-    utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+    for i, columnName := range expectedTableSchema.ColumnNames {
         if columnName != "id" {
             addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
                 ColumnName: columnName,
-                ColumnType: columnType,
+                ColumnType: expectedTableSchema.ColumnTypes[i],
             })
         }
-    })
+    }
 
     err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
         SrcTableName: tableName,
@@ -212,14 +211,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() {
         PrimaryKeyColumns: []string{" "},
     }
     addedColumns := make([]*protos.DeltaAddedColumn, 0)
-    utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+    for i, columnName := range expectedTableSchema.ColumnNames {
         if columnName != " " {
             addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
                 ColumnName: columnName,
-                ColumnType: columnType,
+                ColumnType: expectedTableSchema.ColumnTypes[i],
            })
         }
-    })
+    }
 
     err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
         SrcTableName: tableName,
13 changes: 5 additions & 8 deletions flow/connectors/snowflake/merge_stmt_generator.go
@@ -27,14 +27,15 @@ type mergeStmtGenerator struct
 
 func (m *mergeStmtGenerator) generateMergeStmt() (string, error) {
     parsedDstTable, _ := utils.ParseSchemaTable(m.dstTableName)
-    columnNames := utils.TableSchemaColumnNames(m.normalizedTableSchema)
+    columnNames := m.normalizedTableSchema.ColumnNames
 
-    flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(m.normalizedTableSchema))
-    err := utils.IterColumnsError(m.normalizedTableSchema, func(columnName, genericColumnType string) error {
+    flattenedCastsSQLArray := make([]string, 0, len(columnNames))
+    for i, columnName := range columnNames {
+        genericColumnType := m.normalizedTableSchema.ColumnTypes[i]
         qvKind := qvalue.QValueKind(genericColumnType)
         sfType, err := qValueKindToSnowflakeType(qvKind)
         if err != nil {
-            return fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err)
+            return "", fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err)
         }
 
         targetColumnName := SnowflakeIdentifierNormalize(columnName)
@@ -69,10 +70,6 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) {
                 toVariantColumnName, columnName, sfType, targetColumnName))
             }
         }
-        return nil
-    })
-    if err != nil {
-        return "", err
-    }
     }
     flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
 
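
One wrinkle in this file compared to the other call sites: the helper removed here was the error-propagating variant, utils.IterColumnsError, so inlining it also flattens the error path — the callback returned an error, the helper surfaced it, and the caller re-checked it; a plain loop returns from the failure site in one hop. A runnable sketch of the two shapes, with hypothetical stand-ins (toType is not the real qValueKindToSnowflakeType):

package main

import (
    "errors"
    "fmt"
    "strings"
)

type TableSchema struct {
    ColumnNames []string
    ColumnTypes []string
}

// iterColumnsError plays the role of the deleted utils.IterColumnsError:
// it drives the callback and stops at the first error.
func iterColumnsError(s *TableSchema, f func(col, colType string) error) error {
    for i, name := range s.ColumnNames {
        if err := f(name, s.ColumnTypes[i]); err != nil {
            return err
        }
    }
    return nil
}

// toType stands in for the Snowflake type mapping.
func toType(t string) (string, error) {
    if t == "" {
        return "", errors.New("unknown type")
    }
    return strings.ToUpper(t), nil
}

// Before: three hops for one failure (callback -> helper -> caller).
func before(s *TableSchema) (string, error) {
    casts := make([]string, 0, len(s.ColumnNames))
    err := iterColumnsError(s, func(col, colType string) error {
        sfType, err := toType(colType)
        if err != nil {
            return err
        }
        casts = append(casts, col+" "+sfType)
        return nil
    })
    if err != nil {
        return "", err
    }
    return strings.Join(casts, ","), nil
}

// After: the loop body returns the wrapped error directly.
func after(s *TableSchema) (string, error) {
    casts := make([]string, 0, len(s.ColumnNames))
    for i, col := range s.ColumnNames {
        sfType, err := toType(s.ColumnTypes[i])
        if err != nil {
            return "", err
        }
        casts = append(casts, col+" "+sfType)
    }
    return strings.Join(casts, ","), nil
}

func main() {
    s := &TableSchema{ColumnNames: []string{"id"}, ColumnTypes: []string{"numeric"}}
    fmt.Println(before(s))
    fmt.Println(after(s))
}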
11 changes: 6 additions & 5 deletions flow/connectors/snowflake/snowflake.go
@@ -834,17 +834,18 @@ func generateCreateTableSQLForNormalizedTable(
     softDeleteColName string,
     syncedAtColName string,
 ) string {
-    createTableSQLArray := make([]string, 0, utils.TableSchemaColumns(sourceTableSchema)+2)
-    utils.IterColumns(sourceTableSchema, func(columnName, genericColumnType string) {
+    createTableSQLArray := make([]string, 0, len(sourceTableSchema.ColumnNames)+2)
+    for i, columnName := range sourceTableSchema.ColumnNames {
+        genericColumnType := sourceTableSchema.ColumnTypes[i]
         normalizedColName := SnowflakeIdentifierNormalize(columnName)
         sfColType, err := qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType))
         if err != nil {
             slog.Warn(fmt.Sprintf("failed to convert column type %s to snowflake type", genericColumnType),
                 slog.Any("error", err))
-            return
+            continue
         }
         createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`%s %s`, normalizedColName, sfColType))
-    })
+    }
 
     // add a _peerdb_is_deleted column to the normalized table
     // this is boolean default false, and is used to mark records as deleted
@@ -1000,7 +1001,7 @@ func (c *SnowflakeConnector) RenameTables(req *protos.RenameTablesInput) (*proto
     for _, renameRequest := range req.RenameTableOptions {
         src := renameRequest.CurrentName
         dst := renameRequest.NewName
-        allCols := strings.Join(utils.TableSchemaColumnNames(renameRequest.TableSchema), ",")
+        allCols := strings.Join(renameRequest.TableSchema.ColumnNames, ",")
         pkeyCols := strings.Join(renameRequest.TableSchema.PrimaryKeyColumns, ",")
 
         c.logger.Info(fmt.Sprintf("handling soft-deletes for table '%s'...", dst))
31 changes: 0 additions & 31 deletions flow/connectors/utils/columns.go

This file was deleted.
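
The deleted file's contents aren't shown here, but the call sites removed above imply it exported four small helpers. A plausible reconstruction — signatures inferred from usage, roughly consistent with the 31 deleted lines, so treat this as a sketch rather than the exact deleted source:

package utils

import (
    "slices"

    "github.com/PeerDB-io/peer-flow/generated/protos"
)

// TableSchemaColumns returns the number of columns in the schema.
func TableSchemaColumns(schema *protos.TableSchema) int {
    return len(schema.ColumnNames)
}

// TableSchemaColumnNames returns a copy of the column names. It must
// copy: the old generateMergeStatement overwrote entries of the
// returned slice in place, which would otherwise corrupt the schema.
func TableSchemaColumnNames(schema *protos.TableSchema) []string {
    return slices.Clone(schema.ColumnNames)
}

// IterColumns invokes f for each (name, type) pair of the parallel slices.
func IterColumns(schema *protos.TableSchema, f func(col, colType string)) {
    for i, name := range schema.ColumnNames {
        f(name, schema.ColumnTypes[i])
    }
}

// IterColumnsError is IterColumns with error propagation: it stops at
// and returns the first error from the callback.
func IterColumnsError(schema *protos.TableSchema, f func(col, colType string) error) error {
    for i, name := range schema.ColumnNames {
        if err := f(name, schema.ColumnTypes[i]); err != nil {
            return err
        }
    }
    return nil
}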

19 changes: 9 additions & 10 deletions flow/e2e/snowflake/snowflake_schema_delta_test.go
@@ -9,7 +9,6 @@ import (
     "github.com/stretchr/testify/require"
 
     connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake"
-    "github.com/PeerDB-io/peer-flow/connectors/utils"
     "github.com/PeerDB-io/peer-flow/e2eshared"
     "github.com/PeerDB-io/peer-flow/generated/protos"
     "github.com/PeerDB-io/peer-flow/model/qvalue"
@@ -99,14 +98,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddAllColumnTypes() {
         },
     }
     addedColumns := make([]*protos.DeltaAddedColumn, 0)
-    utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+    for i, columnName := range expectedTableSchema.ColumnNames {
         if columnName != "ID" {
             addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
                 ColumnName: columnName,
-                ColumnType: columnType,
+                ColumnType: expectedTableSchema.ColumnTypes[i],
             })
         }
-    })
+    }
 
     err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
         SrcTableName: tableName,
@@ -154,14 +153,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddTrickyColumnNames() {
         },
     }
     addedColumns := make([]*protos.DeltaAddedColumn, 0)
-    utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+    for i, columnName := range expectedTableSchema.ColumnNames {
         if columnName != "ID" {
             addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
                 ColumnName: columnName,
-                ColumnType: columnType,
+                ColumnType: expectedTableSchema.ColumnTypes[i],
             })
         }
-    })
+    }
 
     err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
         SrcTableName: tableName,
@@ -193,14 +192,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddWhitespaceColumnNames() {
         },
     }
     addedColumns := make([]*protos.DeltaAddedColumn, 0)
-    utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+    for i, columnName := range expectedTableSchema.ColumnNames {
         if columnName != " " {
             addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
                 ColumnName: columnName,
-                ColumnType: columnType,
+                ColumnType: expectedTableSchema.ColumnTypes[i],
             })
         }
-    })
+    }
 
     err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
         SrcTableName: tableName,