Maintain column ordering
Go maps are not ordered maps, i.e. iterating a map has non-deterministic order.

This causes us to lose ordering info when using TableSchema.Columns.

Implement a new field which is a flat array of strings holding name/type pairs;
it is used when TableSchema.Columns is nil.

This way we're backwards compatible, while new mirrors will have correct ordering.

This is a POC:
1. names are bad (ColumnSchema should be TableSchema, but I felt like typing Column a lot)
2. maybe names/types should be two separate arrays, which would make ColumnSchemaColumnNames free,
   though callers would then need to copy the slice before mutating it — see the sketch below
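
As a sketch of the idea (not the actual flow/connectors/utils source — the helper bodies below are reconstructed from the call sites in this diff, so treat the exact signatures and fallback behavior as assumptions):

package utils

import "github.com/PeerDB-io/peer-flow/generated/protos"

// Sketch: ColumnNameType is assumed to be a flat []string alternating
// name, type, name, type, ... so slice order preserves source column order.

// IterColumns visits each (name, type) pair; it falls back to the legacy
// Columns map (unordered) for old mirrors where ColumnNameType is unset.
func IterColumns(schema *protos.TableSchema, cb func(colName, colType string)) {
	if schema.Columns != nil {
		for name, typ := range schema.Columns {
			cb(name, typ)
		}
		return
	}
	for i := 0; i+1 < len(schema.ColumnNameType); i += 2 {
		cb(schema.ColumnNameType[i], schema.ColumnNameType[i+1])
	}
}

// ColumnSchemaColumns returns the column count under either encoding.
func ColumnSchemaColumns(schema *protos.TableSchema) int {
	if schema.Columns != nil {
		return len(schema.Columns)
	}
	return len(schema.ColumnNameType) / 2
}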
serprex committed Dec 29, 2023
1 parent 7e30827 commit 7da08c4
Showing 18 changed files with 590 additions and 1,684 deletions.
12 changes: 5 additions & 7 deletions flow/connectors/bigquery/bigquery.go
@@ -794,16 +794,14 @@ func (c *BigQueryConnector) SetupNormalizedTables(
 	}
 
 	// convert the column names and types to bigquery types
-	columns := make([]*bigquery.FieldSchema, len(tableSchema.Columns), len(tableSchema.Columns)+2)
-	idx := 0
-	for colName, genericColType := range tableSchema.Columns {
-		columns[idx] = &bigquery.FieldSchema{
+	columns := make([]*bigquery.FieldSchema, 0, len(tableSchema.Columns)+2)
+	utils.IterColumns(tableSchema, func(colName, genericColType string) {
+		columns = append(columns, &bigquery.FieldSchema{
 			Name:     colName,
 			Type:     qValueKindToBigQueryType(genericColType),
 			Repeated: qvalue.QValueKind(genericColType).IsArray(),
-		}
-		idx++
-	}
+		})
+	})
 
 	if req.SoftDeleteColName != "" {
 		columns = append(columns, &bigquery.FieldSchema{
15 changes: 8 additions & 7 deletions flow/connectors/bigquery/merge_statement_generator.go
@@ -33,8 +33,8 @@ type mergeStmtGenerator struct {
 func (m *mergeStmtGenerator) generateFlattenedCTE() string {
 	// for each column in the normalized table, generate CAST + JSON_EXTRACT_SCALAR
 	// statement.
-	flattenedProjs := make([]string, 0)
-	for colName, colType := range m.normalizedTableSchema.Columns {
+	flattenedProjs := make([]string, 0, utils.ColumnSchemaColumns(m.normalizedTableSchema)+3)
+	utils.IterColumns(m.normalizedTableSchema, func(colName, colType string) {
 		bqType := qValueKindToBigQueryType(colType)
 		// CAST doesn't work for FLOAT, so rewrite it to FLOAT64.
 		if bqType == bigquery.FloatFieldType {
@@ -76,7 +76,7 @@ func (m *mergeStmtGenerator) generateFlattenedCTE() string {
 				colName, bqType, colName)
 		}
 		flattenedProjs = append(flattenedProjs, castStmt)
-	}
+	})
 	flattenedProjs = append(
 		flattenedProjs,
 		"_peerdb_timestamp",
@@ -111,12 +111,13 @@ func (m *mergeStmtGenerator) generateDeDupedCTE() string {
 // generateMergeStmt generates a merge statement.
 func (m *mergeStmtGenerator) generateMergeStmt() string {
 	// comma separated list of column names
-	backtickColNames := make([]string, 0, len(m.normalizedTableSchema.Columns))
-	pureColNames := make([]string, 0, len(m.normalizedTableSchema.Columns))
-	for colName := range m.normalizedTableSchema.Columns {
+	columnCount := utils.ColumnSchemaColumns(m.normalizedTableSchema)
+	backtickColNames := make([]string, 0, columnCount)
+	pureColNames := make([]string, 0, columnCount)
+	utils.IterColumns(m.normalizedTableSchema, func(colName, _ string) {
 		backtickColNames = append(backtickColNames, fmt.Sprintf("`%s`", colName))
 		pureColNames = append(pureColNames, colName)
-	}
+	})
 	csep := strings.Join(backtickColNames, ", ")
 	insertColumnsSQL := csep + fmt.Sprintf(", `%s`", m.peerdbCols.SyncedAtColName)
 	insertValuesSQL := csep + ",CURRENT_TIMESTAMP"
29 changes: 15 additions & 14 deletions flow/connectors/postgres/client.go
@@ -391,11 +391,11 @@ func generateCreateTableSQLForNormalizedTable(
 	softDeleteColName string,
 	syncedAtColName string,
 ) string {
-	createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns)+2)
-	for columnName, genericColumnType := range sourceTableSchema.Columns {
+	createTableSQLArray := make([]string, 0, utils.ColumnSchemaColumns(sourceTableSchema)+2)
+	utils.IterColumns(sourceTableSchema, func(columnName, genericColumnType string) {
 		createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s,", columnName,
 			qValueKindToPostgresType(genericColumnType)))
-	}
+	})
 
 	if softDeleteColName != "" {
 		createTableSQLArray = append(createTableSQLArray,
@@ -591,10 +591,11 @@ func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifie
 	rawTableIdentifier string, peerdbCols *protos.PeerDBColumns,
 ) []string {
 	normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier]
-	columnNames := make([]string, 0, len(normalizedTableSchema.Columns))
-	flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
+	columnCount := utils.ColumnSchemaColumns(normalizedTableSchema)
+	columnNames := make([]string, 0, columnCount)
+	flattenedCastsSQLArray := make([]string, 0, columnCount)
 	primaryKeyColumnCasts := make(map[string]string)
-	for columnName, genericColumnType := range normalizedTableSchema.Columns {
+	utils.IterColumns(normalizedTableSchema, func(columnName, genericColumnType string) {
 		columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName))
 		pgType := qValueKindToPostgresType(genericColumnType)
 		if qvalue.QValueKind(genericColumnType).IsArray() {
@@ -608,15 +609,15 @@ func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifie
 		if slices.Contains(normalizedTableSchema.PrimaryKeyColumns, columnName) {
 			primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType)
 		}
-	}
+	})
 	flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")
 	parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier)
 
 	insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",")
-	updateColumnsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
-	for columnName := range normalizedTableSchema.Columns {
+	updateColumnsSQLArray := make([]string, 0, utils.ColumnSchemaColumns(normalizedTableSchema))
+	utils.IterColumns(normalizedTableSchema, func(columnName, _ string) {
 		updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`"%s"=EXCLUDED."%s"`, columnName, columnName))
-	}
+	})
 	updateColumnsSQL := strings.TrimSuffix(strings.Join(updateColumnsSQLArray, ","), ",")
 	deleteWhereClauseArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
 	for columnName, columnCast := range primaryKeyColumnCasts {
@@ -655,17 +656,17 @@ func (c *PostgresConnector) generateMergeStatement(
 	peerdbCols *protos.PeerDBColumns,
 ) string {
 	normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier]
-	columnNames := maps.Keys(normalizedTableSchema.Columns)
+	columnNames := utils.ColumnSchemaColumnNames(normalizedTableSchema)
 	for i, columnName := range columnNames {
 		columnNames[i] = fmt.Sprintf("\"%s\"", columnName)
 	}
 
-	flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
+	flattenedCastsSQLArray := make([]string, 0, utils.ColumnSchemaColumns(normalizedTableSchema))
 	parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier)
 
 	primaryKeyColumnCasts := make(map[string]string)
 	primaryKeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
-	for columnName, genericColumnType := range normalizedTableSchema.Columns {
+	utils.IterColumns(normalizedTableSchema, func(columnName, genericColumnType string) {
 		pgType := qValueKindToPostgresType(genericColumnType)
 		if qvalue.QValueKind(genericColumnType).IsArray() {
 			flattenedCastsSQLArray = append(flattenedCastsSQLArray,
@@ -680,7 +681,7 @@ func (c *PostgresConnector) generateMergeStatement(
 			primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s",
 				columnName, columnName))
 		}
-	}
+	})
 	flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")
 	insertValuesSQLArray := make([]string, 0, len(columnNames))
 	for _, columnName := range columnNames {
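
generateMergeStatement above swaps maps.Keys for utils.ColumnSchemaColumnNames. A plausible shape for that helper, continuing the sketched utils package from the commit message (and this is where point 2 of the POC notes bites: with separate name/type arrays the names would be a free slice reference, but since the caller above mutates columnNames in place to quote each name, it would then have to copy first):

import "golang.org/x/exp/maps"

// Sketch: assumed shape of the names helper used above.
func ColumnSchemaColumnNames(schema *protos.TableSchema) []string {
	if schema.Columns != nil {
		return maps.Keys(schema.Columns) // legacy map encoding; order not guaranteed
	}
	names := make([]string, 0, len(schema.ColumnNameType)/2)
	for i := 0; i < len(schema.ColumnNameType); i += 2 {
		names = append(names, schema.ColumnNameType[i])
	}
	return names
}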
6 changes: 4 additions & 2 deletions flow/connectors/postgres/postgres.go
@@ -592,11 +592,13 @@ func (c *PostgresConnector) getTableSchemaForTable(
 	}
 	defer rows.Close()
 
+	fields := rows.FieldDescriptions()
 	res := &protos.TableSchema{
 		TableIdentifier:       tableName,
-		Columns:               make(map[string]string),
+		Columns:               nil,
 		PrimaryKeyColumns:     pKeyCols,
 		IsReplicaIdentityFull: replicaIdentityType == ReplicaIdentityFull,
+		ColumnNameType:        make([]string, 0, len(fields)*2),
 	}
 
 	for _, fieldDescription := range rows.FieldDescriptions() {
@@ -610,7 +612,7 @@ func (c *PostgresConnector) getTableSchemaForTable(
 			}
 		}
 
-		res.Columns[fieldDescription.Name] = string(genericColType)
+		res.ColumnNameType = append(res.ColumnNameType, fieldDescription.Name, string(genericColType))
 	}
 
 	if err = rows.Err(); err != nil {
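
A hypothetical round trip through the new producer/consumer pair (table name and types invented for illustration): the producer above appends pairs in FieldDescriptions order, so iteration replays source column order deterministically.

package main

import (
	"fmt"

	"github.com/PeerDB-io/peer-flow/connectors/utils"
	"github.com/PeerDB-io/peer-flow/generated/protos"
)

func main() {
	// Invented schema: pairs flattened as name, type, name, type, ...
	schema := &protos.TableSchema{
		TableIdentifier: "public.users",
		ColumnNameType:  []string{"id", "numeric", "created_at", "timestamp", "name", "text"},
	}
	utils.IterColumns(schema, func(colName, colType string) {
		fmt.Println(colName, colType) // id, created_at, name — always in this order
	})
	// With the legacy map[string]string Columns field, this order could vary run to run.
}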
13 changes: 7 additions & 6 deletions flow/connectors/postgres/postgres_schema_delta_test.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"testing"
 
+	"github.com/PeerDB-io/peer-flow/connectors/utils"
 	"github.com/PeerDB-io/peer-flow/generated/protos"
 	"github.com/PeerDB-io/peer-flow/model/qvalue"
 	"github.com/jackc/pgx/v5"
@@ -134,14 +135,14 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddAllColumnTypes() {
 		PrimaryKeyColumns: []string{"id"},
 	}
 	addedColumns := make([]*protos.DeltaAddedColumn, 0)
-	for columnName, columnType := range expectedTableSchema.Columns {
+	utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
 		if columnName != "id" {
 			addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
 				ColumnName: columnName,
 				ColumnType: columnType,
 			})
 		}
-	}
+	})
 
 	err = suite.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
 		SrcTableName: tableName,
@@ -180,14 +181,14 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddTrickyColumnNames() {
 		PrimaryKeyColumns: []string{"id"},
 	}
 	addedColumns := make([]*protos.DeltaAddedColumn, 0)
-	for columnName, columnType := range expectedTableSchema.Columns {
+	utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
 		if columnName != "id" {
 			addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
 				ColumnName: columnName,
 				ColumnType: columnType,
 			})
 		}
-	}
+	})
 
 	err = suite.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
 		SrcTableName: tableName,
@@ -220,14 +221,14 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() {
 		PrimaryKeyColumns: []string{" "},
 	}
 	addedColumns := make([]*protos.DeltaAddedColumn, 0)
-	for columnName, columnType := range expectedTableSchema.Columns {
+	utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
 		if columnName != " " {
 			addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
 				ColumnName: columnName,
 				ColumnType: columnType,
 			})
 		}
-	}
+	})
 
 	err = suite.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
 		SrcTableName: tableName,
28 changes: 16 additions & 12 deletions flow/connectors/snowflake/snowflake.go
@@ -21,7 +21,6 @@ import (
 	"github.com/jackc/pgx/v5/pgtype"
 	"github.com/snowflakedb/gosnowflake"
 	"go.temporal.io/sdk/activity"
-	"golang.org/x/exp/maps"
 	"golang.org/x/sync/errgroup"
 )
 
@@ -260,7 +259,8 @@ func (c *SnowflakeConnector) getTableSchemaForTable(tableName string) (*protos.T
 
 	res := &protos.TableSchema{
 		TableIdentifier: tableName,
-		Columns:         make(map[string]string),
+		Columns:         nil,
+		ColumnNameType:  make([]string, 0, 16),
 	}
 
 	var columnName, columnType pgtype.Text
@@ -275,7 +275,7 @@ func (c *SnowflakeConnector) getTableSchemaForTable(tableName string) (*protos.T
 			genericColType = qvalue.QValueKindString
 		}
 
-		res.Columns[columnName.String] = string(genericColType)
+		res.ColumnNameType = append(res.ColumnNameType, columnName.String, string(genericColType))
 	}
 
 	return res, nil
@@ -765,17 +765,17 @@ func generateCreateTableSQLForNormalizedTable(
 	softDeleteColName string,
 	syncedAtColName string,
 ) string {
-	createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns)+2)
-	for columnName, genericColumnType := range sourceTableSchema.Columns {
+	createTableSQLArray := make([]string, 0, utils.ColumnSchemaColumns(sourceTableSchema)+2)
+	utils.IterColumns(sourceTableSchema, func(columnName, genericColumnType string) {
 		normalizedColName := SnowflakeIdentifierNormalize(columnName)
 		sfColType, err := qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType))
 		if err != nil {
 			slog.Warn(fmt.Sprintf("failed to convert column type %s to snowflake type", genericColumnType),
 				slog.Any("error", err))
-			continue
+			return
 		}
 		createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`%s %s,`, normalizedColName, sfColType))
-	}
+	})
 
 	// add a _peerdb_is_deleted column to the normalized table
 	// this is boolean default false, and is used to mark records as deleted
@@ -826,14 +826,14 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
 	if err != nil {
 		return 0, fmt.Errorf("unable to parse destination table '%s'", parsedDstTable)
 	}
-	columnNames := maps.Keys(normalizedTableSchema.Columns)
+	columnNames := utils.ColumnSchemaColumnNames(normalizedTableSchema)
 
-	flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
-	for columnName, genericColumnType := range normalizedTableSchema.Columns {
+	flattenedCastsSQLArray := make([]string, 0, utils.ColumnSchemaColumns(normalizedTableSchema))
+	ret, err := utils.IterColumns1(normalizedTableSchema, func(columnName, genericColumnType string) (bool, error) {
 		qvKind := qvalue.QValueKind(genericColumnType)
 		sfType, err := qValueKindToSnowflakeType(qvKind)
 		if err != nil {
-			return 0, fmt.Errorf("failed to convert column type %s to snowflake type: %w",
+			return true, fmt.Errorf("failed to convert column type %s to snowflake type: %w",
 				genericColumnType, err)
 		}
 
@@ -865,6 +865,10 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
 				toVariantColumnName, columnName, sfType, targetColumnName))
 			}
 		}
+		return false, nil
+	})
+	if ret {
+		return 0, err
 	}
 	flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ""), ",")
 
@@ -1133,7 +1137,7 @@ func (c *SnowflakeConnector) RenameTables(req *protos.RenameTablesInput) (*proto
 	for _, renameRequest := range req.RenameTableOptions {
 		src := renameRequest.CurrentName
 		dst := renameRequest.NewName
-		allCols := strings.Join(maps.Keys(renameRequest.TableSchema.Columns), ",")
+		allCols := strings.Join(utils.ColumnSchemaColumnNames(renameRequest.TableSchema), ",")
 		pkeyCols := strings.Join(renameRequest.TableSchema.PrimaryKeyColumns, ",")
 
 		c.logger.Info(fmt.Sprintf("handling soft-deletes for table '%s'...", dst))
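
The generateAndExecuteMergeStatement hunk introduces utils.IterColumns1, an early-exit variant whose callback returns (stop, error). Its body isn't shown in this commit view; a minimal sketch consistent with the call site above, assuming the same map fallback as IterColumns:

// Sketch: stop iterating when the callback returns true, forwarding its error.
func IterColumns1(
	schema *protos.TableSchema,
	cb func(colName, colType string) (bool, error),
) (bool, error) {
	if schema.Columns != nil {
		for name, typ := range schema.Columns {
			if stop, err := cb(name, typ); stop {
				return stop, err
			}
		}
		return false, nil
	}
	for i := 0; i+1 < len(schema.ColumnNameType); i += 2 {
		if stop, err := cb(schema.ColumnNameType[i], schema.ColumnNameType[i+1]); stop {
			return stop, err
		}
	}
	return false, nil
}

The caller then treats ret == true as "propagate err", which is exactly the if ret { return 0, err } pattern in the hunk.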