Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rm connectors/utils/columns.go #1149

Merged
merged 1 commit into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions flow/cmd/validate_mirror.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,20 @@ import (
func (h *FlowRequestHandler) ValidateCDCMirror(
ctx context.Context, req *protos.CreateCDCFlowRequest,
) (*protos.ValidateCDCMirrorResponse, error) {
pgPeer, err := connpostgres.NewPostgresConnector(ctx, req.ConnectionConfigs.Source.GetPostgresConfig())
sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig()
if sourcePeerConfig == nil {
slog.Error("/validatecdc source peer config is nil", slog.Any("peer", req.ConnectionConfigs.Source))
return nil, fmt.Errorf("source peer config is nil")
}

pgPeer, err := connpostgres.NewPostgresConnector(ctx, sourcePeerConfig)
if err != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("failed to create postgres connector: %v", err)
}

defer pgPeer.Close()

sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig()
if sourcePeerConfig == nil {
slog.Error("/validatecdc source peer config is nil", slog.Any("peer", req.ConnectionConfigs.Source))
return nil, fmt.Errorf("source peer config is nil")
}

// Check permissions of postgres peer
err = pgPeer.CheckReplicationPermissions(sourcePeerConfig.User)
if err != nil {
Expand Down
7 changes: 4 additions & 3 deletions flow/connectors/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -810,13 +810,14 @@ func (c *BigQueryConnector) SetupNormalizedTables(

// convert the column names and types to bigquery types
columns := make([]*bigquery.FieldSchema, 0, len(tableSchema.ColumnNames)+2)
utils.IterColumns(tableSchema, func(colName, genericColType string) {
for i, colName := range tableSchema.ColumnNames {
genericColType := tableSchema.ColumnTypes[i]
columns = append(columns, &bigquery.FieldSchema{
Name: colName,
Type: qValueKindToBigQueryType(genericColType),
Repeated: qvalue.QValueKind(genericColType).IsArray(),
})
})
}

if req.SoftDeleteColName != "" {
columns = append(columns, &bigquery.FieldSchema{
Expand Down Expand Up @@ -908,7 +909,7 @@ func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos
dstDatasetTable.string()))

if req.SoftDeleteColName != nil {
allCols := strings.Join(utils.TableSchemaColumnNames(renameRequest.TableSchema), ",")
allCols := strings.Join(renameRequest.TableSchema.ColumnNames, ",")
pkeyCols := strings.Join(renameRequest.TableSchema.PrimaryKeyColumns, ",")

c.logger.InfoContext(c.ctx, fmt.Sprintf("handling soft-deletes for table '%s'...", dstDatasetTable.string()))
Expand Down
4 changes: 2 additions & 2 deletions flow/connectors/bigquery/merge_stmt_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type mergeStmtGenerator struct {
func (m *mergeStmtGenerator) generateFlattenedCTE() string {
// for each column in the normalized table, generate CAST + JSON_EXTRACT_SCALAR
// statement.
flattenedProjs := make([]string, 0, utils.TableSchemaColumns(m.normalizedTableSchema)+3)
flattenedProjs := make([]string, 0, len(m.normalizedTableSchema.ColumnNames)+3)

for i, colName := range m.normalizedTableSchema.ColumnNames {
colType := m.normalizedTableSchema.ColumnTypes[i]
Expand Down Expand Up @@ -124,7 +124,7 @@ func (m *mergeStmtGenerator) generateDeDupedCTE() string {
// generateMergeStmt generates a merge statement.
func (m *mergeStmtGenerator) generateMergeStmt(unchangedToastColumns []string) string {
// comma separated list of column names
columnCount := utils.TableSchemaColumns(m.normalizedTableSchema)
columnCount := len(m.normalizedTableSchema.ColumnNames)
backtickColNames := make([]string, 0, columnCount)
shortBacktickColNames := make([]string, 0, columnCount)
pureColNames := make([]string, 0, columnCount)
Expand Down
7 changes: 4 additions & 3 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,11 +424,12 @@ func generateCreateTableSQLForNormalizedTable(
softDeleteColName string,
syncedAtColName string,
) string {
createTableSQLArray := make([]string, 0, utils.TableSchemaColumns(sourceTableSchema)+2)
utils.IterColumns(sourceTableSchema, func(columnName, genericColumnType string) {
createTableSQLArray := make([]string, 0, len(sourceTableSchema.ColumnNames)+2)
for i, columnName := range sourceTableSchema.ColumnNames {
genericColumnType := sourceTableSchema.ColumnTypes[i]
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf("%s %s", QuoteIdentifier(columnName), qValueKindToPostgresType(genericColumnType)))
})
}

if softDeleteColName != "" {
createTableSQLArray = append(createTableSQLArray,
Expand Down
30 changes: 16 additions & 14 deletions flow/connectors/postgres/normalize_stmt_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@ func (n *normalizeStmtGenerator) generateNormalizeStatements() []string {
}

func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
columnCount := utils.TableSchemaColumns(n.normalizedTableSchema)
columnCount := len(n.normalizedTableSchema.ColumnNames)
columnNames := make([]string, 0, columnCount)
flattenedCastsSQLArray := make([]string, 0, columnCount)
primaryKeyColumnCasts := make(map[string]string, len(n.normalizedTableSchema.PrimaryKeyColumns))
utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) {
for i, columnName := range n.normalizedTableSchema.ColumnNames {
genericColumnType := n.normalizedTableSchema.ColumnTypes[i]
quotedCol := QuoteIdentifier(columnName)
stringCol := QuoteLiteral(columnName)
columnNames = append(columnNames, quotedCol)
Expand All @@ -64,16 +65,16 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) {
primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>%s)::%s", stringCol, pgType)
}
})
}
flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName)

insertColumnsSQL := strings.Join(columnNames, ",")
updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema))
utils.IterColumns(n.normalizedTableSchema, func(columnName, _ string) {
updateColumnsSQLArray := make([]string, 0, columnCount)
for _, columnName := range n.normalizedTableSchema.ColumnNames {
quotedCol := QuoteIdentifier(columnName)
updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`%s=EXCLUDED.%s`, quotedCol, quotedCol))
})
}
updateColumnsSQL := strings.Join(updateColumnsSQLArray, ",")
deleteWhereClauseArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
for columnName, columnCast := range primaryKeyColumnCasts {
Expand Down Expand Up @@ -104,19 +105,20 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
}

func (n *normalizeStmtGenerator) generateMergeStatement() string {
quotedColumnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema)
for i, columnName := range quotedColumnNames {
quotedColumnNames[i] = QuoteIdentifier(columnName)
}
columnCount := len(n.normalizedTableSchema.ColumnNames)
quotedColumnNames := make([]string, columnCount)

flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema))
flattenedCastsSQLArray := make([]string, 0, columnCount)
parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName)

primaryKeyColumnCasts := make(map[string]string)
primaryKeySelectSQLArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) {
for i, columnName := range n.normalizedTableSchema.ColumnNames {
genericColumnType := n.normalizedTableSchema.ColumnTypes[i]
quotedCol := QuoteIdentifier(columnName)
stringCol := QuoteLiteral(columnName)
quotedColumnNames[i] = quotedCol

pgType := qValueKindToPostgresType(genericColumnType)
if qvalue.QValueKind(genericColumnType).IsArray() {
flattenedCastsSQLArray = append(flattenedCastsSQLArray,
Expand All @@ -131,9 +133,9 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string {
primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s",
quotedCol, quotedCol))
}
})
}
flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
insertValuesSQLArray := make([]string, 0, len(quotedColumnNames)+2)
insertValuesSQLArray := make([]string, 0, columnCount+2)
for _, quotedCol := range quotedColumnNames {
insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", quotedCol))
}
Expand Down
19 changes: 9 additions & 10 deletions flow/connectors/postgres/postgres_schema_delta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/require"

"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/e2eshared"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model/qvalue"
Expand Down Expand Up @@ -121,14 +120,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddAllColumnTypes() {
PrimaryKeyColumns: []string{"id"},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
for i, columnName := range expectedTableSchema.ColumnNames {
if columnName != "id" {
addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
ColumnName: columnName,
ColumnType: columnType,
ColumnType: expectedTableSchema.ColumnTypes[i],
})
}
})
}

err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
SrcTableName: tableName,
Expand Down Expand Up @@ -171,14 +170,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddTrickyColumnNames() {
PrimaryKeyColumns: []string{"id"},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
for i, columnName := range expectedTableSchema.ColumnNames {
if columnName != "id" {
addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
ColumnName: columnName,
ColumnType: columnType,
ColumnType: expectedTableSchema.ColumnTypes[i],
})
}
})
}

err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
SrcTableName: tableName,
Expand Down Expand Up @@ -212,14 +211,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() {
PrimaryKeyColumns: []string{" "},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
for i, columnName := range expectedTableSchema.ColumnNames {
if columnName != " " {
addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
ColumnName: columnName,
ColumnType: columnType,
ColumnType: expectedTableSchema.ColumnTypes[i],
})
}
})
}

err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
SrcTableName: tableName,
Expand Down
13 changes: 5 additions & 8 deletions flow/connectors/snowflake/merge_stmt_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ type mergeStmtGenerator struct {

func (m *mergeStmtGenerator) generateMergeStmt() (string, error) {
parsedDstTable, _ := utils.ParseSchemaTable(m.dstTableName)
columnNames := utils.TableSchemaColumnNames(m.normalizedTableSchema)
columnNames := m.normalizedTableSchema.ColumnNames

flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(m.normalizedTableSchema))
err := utils.IterColumnsError(m.normalizedTableSchema, func(columnName, genericColumnType string) error {
flattenedCastsSQLArray := make([]string, 0, len(columnNames))
for i, columnName := range columnNames {
genericColumnType := m.normalizedTableSchema.ColumnTypes[i]
qvKind := qvalue.QValueKind(genericColumnType)
sfType, err := qValueKindToSnowflakeType(qvKind)
if err != nil {
return fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err)
return "", fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err)
}

targetColumnName := SnowflakeIdentifierNormalize(columnName)
Expand Down Expand Up @@ -69,10 +70,6 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) {
toVariantColumnName, columnName, sfType, targetColumnName))
}
}
return nil
})
if err != nil {
return "", err
}
flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")

Expand Down
11 changes: 6 additions & 5 deletions flow/connectors/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -834,17 +834,18 @@ func generateCreateTableSQLForNormalizedTable(
softDeleteColName string,
syncedAtColName string,
) string {
createTableSQLArray := make([]string, 0, utils.TableSchemaColumns(sourceTableSchema)+2)
utils.IterColumns(sourceTableSchema, func(columnName, genericColumnType string) {
createTableSQLArray := make([]string, 0, len(sourceTableSchema.ColumnNames)+2)
for i, columnName := range sourceTableSchema.ColumnNames {
genericColumnType := sourceTableSchema.ColumnTypes[i]
normalizedColName := SnowflakeIdentifierNormalize(columnName)
sfColType, err := qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType))
if err != nil {
slog.Warn(fmt.Sprintf("failed to convert column type %s to snowflake type", genericColumnType),
slog.Any("error", err))
return
continue
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`%s %s`, normalizedColName, sfColType))
})
}

// add a _peerdb_is_deleted column to the normalized table
// this is boolean default false, and is used to mark records as deleted
Expand Down Expand Up @@ -1000,7 +1001,7 @@ func (c *SnowflakeConnector) RenameTables(req *protos.RenameTablesInput) (*proto
for _, renameRequest := range req.RenameTableOptions {
src := renameRequest.CurrentName
dst := renameRequest.NewName
allCols := strings.Join(utils.TableSchemaColumnNames(renameRequest.TableSchema), ",")
allCols := strings.Join(renameRequest.TableSchema.ColumnNames, ",")
pkeyCols := strings.Join(renameRequest.TableSchema.PrimaryKeyColumns, ",")

c.logger.Info(fmt.Sprintf("handling soft-deletes for table '%s'...", dst))
Expand Down
31 changes: 0 additions & 31 deletions flow/connectors/utils/columns.go

This file was deleted.

19 changes: 9 additions & 10 deletions flow/e2e/snowflake/snowflake_schema_delta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/stretchr/testify/require"

connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake"
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/e2eshared"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model/qvalue"
Expand Down Expand Up @@ -99,14 +98,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddAllColumnTypes() {
},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
for i, columnName := range expectedTableSchema.ColumnNames {
if columnName != "ID" {
addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
ColumnName: columnName,
ColumnType: columnType,
ColumnType: expectedTableSchema.ColumnTypes[i],
})
}
})
}

err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
SrcTableName: tableName,
Expand Down Expand Up @@ -154,14 +153,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddTrickyColumnNames() {
},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
for i, columnName := range expectedTableSchema.ColumnNames {
if columnName != "ID" {
addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
ColumnName: columnName,
ColumnType: columnType,
ColumnType: expectedTableSchema.ColumnTypes[i],
})
}
})
}

err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
SrcTableName: tableName,
Expand Down Expand Up @@ -193,14 +192,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddWhitespaceColumnNames() {
},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
for i, columnName := range expectedTableSchema.ColumnNames {
if columnName != " " {
addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
ColumnName: columnName,
ColumnType: columnType,
ColumnType: expectedTableSchema.ColumnTypes[i],
})
}
})
}

err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
SrcTableName: tableName,
Expand Down
Loading
Loading