Skip to content

Commit

Permalink
Maintain column ordering (#937)
Browse files Browse the repository at this point in the history
Go maps are not ordered maps,
i.e. iterating a map has non-deterministic order

This causes us to lose ordering info when using TableSchema.Columns

Instead store columns as two arrays of column names & column types,
which is used when TableSchema.Columns is nil

This way we're backwards compatible while new mirrors will have correct ordering
  • Loading branch information
serprex authored Dec 31, 2023
1 parent 96b8573 commit 0e468b0
Show file tree
Hide file tree
Showing 22 changed files with 753 additions and 536 deletions.
3 changes: 2 additions & 1 deletion flow/cmd/peer_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ func (h *FlowRequestHandler) GetColumns(
cols.column_name = pk.column_name
WHERE
cols.table_schema = $3
AND cols.table_name = $4;
AND cols.table_name = $4
ORDER BY cols.ordinal_position;
`, req.SchemaName, req.TableName, req.SchemaName, req.TableName)
if err != nil {
return &protos.TableColumnsResponse{Columns: nil}, err
Expand Down
12 changes: 5 additions & 7 deletions flow/connectors/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -794,16 +794,14 @@ func (c *BigQueryConnector) SetupNormalizedTables(
}

// convert the column names and types to bigquery types
columns := make([]*bigquery.FieldSchema, len(tableSchema.Columns), len(tableSchema.Columns)+2)
idx := 0
for colName, genericColType := range tableSchema.Columns {
columns[idx] = &bigquery.FieldSchema{
columns := make([]*bigquery.FieldSchema, 0, len(tableSchema.Columns)+2)
utils.IterColumns(tableSchema, func(colName, genericColType string) {
columns = append(columns, &bigquery.FieldSchema{
Name: colName,
Type: qValueKindToBigQueryType(genericColType),
Repeated: qvalue.QValueKind(genericColType).IsArray(),
}
idx++
}
})
})

if req.SoftDeleteColName != "" {
columns = append(columns, &bigquery.FieldSchema{
Expand Down
15 changes: 8 additions & 7 deletions flow/connectors/bigquery/merge_statement_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ type mergeStmtGenerator struct {
func (m *mergeStmtGenerator) generateFlattenedCTE() string {
// for each column in the normalized table, generate CAST + JSON_EXTRACT_SCALAR
// statement.
flattenedProjs := make([]string, 0)
for colName, colType := range m.normalizedTableSchema.Columns {
flattenedProjs := make([]string, 0, utils.TableSchemaColumns(m.normalizedTableSchema)+3)
utils.IterColumns(m.normalizedTableSchema, func(colName, colType string) {
bqType := qValueKindToBigQueryType(colType)
// CAST doesn't work for FLOAT, so rewrite it to FLOAT64.
if bqType == bigquery.FloatFieldType {
Expand Down Expand Up @@ -76,7 +76,7 @@ func (m *mergeStmtGenerator) generateFlattenedCTE() string {
colName, bqType, colName)
}
flattenedProjs = append(flattenedProjs, castStmt)
}
})
flattenedProjs = append(
flattenedProjs,
"_peerdb_timestamp",
Expand Down Expand Up @@ -111,12 +111,13 @@ func (m *mergeStmtGenerator) generateDeDupedCTE() string {
// generateMergeStmt generates a merge statement.
func (m *mergeStmtGenerator) generateMergeStmt() string {
// comma separated list of column names
backtickColNames := make([]string, 0, len(m.normalizedTableSchema.Columns))
pureColNames := make([]string, 0, len(m.normalizedTableSchema.Columns))
for colName := range m.normalizedTableSchema.Columns {
columnCount := utils.TableSchemaColumns(m.normalizedTableSchema)
backtickColNames := make([]string, 0, columnCount)
pureColNames := make([]string, 0, columnCount)
utils.IterColumns(m.normalizedTableSchema, func(colName, _ string) {
backtickColNames = append(backtickColNames, fmt.Sprintf("`%s`", colName))
pureColNames = append(pureColNames, colName)
}
})
csep := strings.Join(backtickColNames, ", ")
insertColumnsSQL := csep + fmt.Sprintf(", `%s`", m.peerdbCols.SyncedAtColName)
insertValuesSQL := csep + ",CURRENT_TIMESTAMP"
Expand Down
29 changes: 15 additions & 14 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,11 +391,11 @@ func generateCreateTableSQLForNormalizedTable(
softDeleteColName string,
syncedAtColName string,
) string {
createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns)+2)
for columnName, genericColumnType := range sourceTableSchema.Columns {
createTableSQLArray := make([]string, 0, utils.TableSchemaColumns(sourceTableSchema)+2)
utils.IterColumns(sourceTableSchema, func(columnName, genericColumnType string) {
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s,", columnName,
qValueKindToPostgresType(genericColumnType)))
}
})

if softDeleteColName != "" {
createTableSQLArray = append(createTableSQLArray,
Expand Down Expand Up @@ -591,10 +591,11 @@ func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifie
rawTableIdentifier string, peerdbCols *protos.PeerDBColumns,
) []string {
normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier]
columnNames := make([]string, 0, len(normalizedTableSchema.Columns))
flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
columnCount := utils.TableSchemaColumns(normalizedTableSchema)
columnNames := make([]string, 0, columnCount)
flattenedCastsSQLArray := make([]string, 0, columnCount)
primaryKeyColumnCasts := make(map[string]string)
for columnName, genericColumnType := range normalizedTableSchema.Columns {
utils.IterColumns(normalizedTableSchema, func(columnName, genericColumnType string) {
columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName))
pgType := qValueKindToPostgresType(genericColumnType)
if qvalue.QValueKind(genericColumnType).IsArray() {
Expand All @@ -608,15 +609,15 @@ func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifie
if slices.Contains(normalizedTableSchema.PrimaryKeyColumns, columnName) {
primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType)
}
}
})
flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")
parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier)

insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",")
updateColumnsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
for columnName := range normalizedTableSchema.Columns {
updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(normalizedTableSchema))
utils.IterColumns(normalizedTableSchema, func(columnName, _ string) {
updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`"%s"=EXCLUDED."%s"`, columnName, columnName))
}
})
updateColumnsSQL := strings.TrimSuffix(strings.Join(updateColumnsSQLArray, ","), ",")
deleteWhereClauseArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
for columnName, columnCast := range primaryKeyColumnCasts {
Expand Down Expand Up @@ -655,17 +656,17 @@ func (c *PostgresConnector) generateMergeStatement(
peerdbCols *protos.PeerDBColumns,
) string {
normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier]
columnNames := maps.Keys(normalizedTableSchema.Columns)
columnNames := utils.TableSchemaColumnNames(normalizedTableSchema)
for i, columnName := range columnNames {
columnNames[i] = fmt.Sprintf("\"%s\"", columnName)
}

flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(normalizedTableSchema))
parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier)

primaryKeyColumnCasts := make(map[string]string)
primaryKeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
for columnName, genericColumnType := range normalizedTableSchema.Columns {
utils.IterColumns(normalizedTableSchema, func(columnName, genericColumnType string) {
pgType := qValueKindToPostgresType(genericColumnType)
if qvalue.QValueKind(genericColumnType).IsArray() {
flattenedCastsSQLArray = append(flattenedCastsSQLArray,
Expand All @@ -680,7 +681,7 @@ func (c *PostgresConnector) generateMergeStatement(
primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s",
columnName, columnName))
}
}
})
flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")
insertValuesSQLArray := make([]string, 0, len(columnNames))
for _, columnName := range columnNames {
Expand Down
24 changes: 14 additions & 10 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -586,14 +586,10 @@ func (c *PostgresConnector) getTableSchemaForTable(
}
defer rows.Close()

res := &protos.TableSchema{
TableIdentifier: tableName,
Columns: make(map[string]string),
PrimaryKeyColumns: pKeyCols,
IsReplicaIdentityFull: replicaIdentityType == ReplicaIdentityFull,
}

for _, fieldDescription := range rows.FieldDescriptions() {
fields := rows.FieldDescriptions()
columnNames := make([]string, 0, len(fields))
columnTypes := make([]string, 0, len(fields))
for _, fieldDescription := range fields {
genericColType := postgresOIDToQValueKind(fieldDescription.DataTypeOID)
if genericColType == qvalue.QValueKindInvalid {
typeName, ok := c.customTypesMapping[fieldDescription.DataTypeOID]
Expand All @@ -604,14 +600,22 @@ func (c *PostgresConnector) getTableSchemaForTable(
}
}

res.Columns[fieldDescription.Name] = string(genericColType)
columnNames = append(columnNames, fieldDescription.Name)
columnTypes = append(columnTypes, string(genericColType))
}

if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating over table schema: %w", err)
}

return res, nil
return &protos.TableSchema{
TableIdentifier: tableName,
Columns: nil,
PrimaryKeyColumns: pKeyCols,
IsReplicaIdentityFull: replicaIdentityType == ReplicaIdentityFull,
ColumnNames: columnNames,
ColumnTypes: columnTypes,
}, nil
}

// SetupNormalizedTable sets up a normalized table, implementing the Connector interface.
Expand Down
98 changes: 53 additions & 45 deletions flow/connectors/postgres/postgres_schema_delta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"testing"

"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/jackc/pgx/v5"
Expand Down Expand Up @@ -94,11 +95,9 @@ func (suite *PostgresSchemaDeltaTestSuite) TestSimpleAddColumn() {
})
suite.failTestError(err)
suite.Equal(&protos.TableSchema{
TableIdentifier: tableName,
Columns: map[string]string{
"id": string(qvalue.QValueKindInt32),
"hi": string(qvalue.QValueKindInt64),
},
TableIdentifier: tableName,
ColumnNames: []string{"id", "hi"},
ColumnTypes: []string{string(qvalue.QValueKindInt32), string(qvalue.QValueKindInt64)},
PrimaryKeyColumns: []string{"id"},
}, output.TableNameSchemaMapping[tableName])
}
Expand All @@ -112,36 +111,40 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddAllColumnTypes() {
expectedTableSchema := &protos.TableSchema{
TableIdentifier: tableName,
// goal is to test all types we're currently mapping to, not all QValue types
Columns: map[string]string{
"id": string(qvalue.QValueKindInt32),
"c1": string(qvalue.QValueKindBit),
"c2": string(qvalue.QValueKindBoolean),
"c3": string(qvalue.QValueKindBytes),
"c4": string(qvalue.QValueKindDate),
"c5": string(qvalue.QValueKindFloat32),
"c6": string(qvalue.QValueKindFloat64),
"c7": string(qvalue.QValueKindInt16),
"c8": string(qvalue.QValueKindInt32),
"c9": string(qvalue.QValueKindInt64),
"c10": string(qvalue.QValueKindJSON),
"c11": string(qvalue.QValueKindNumeric),
"c12": string(qvalue.QValueKindString),
"c13": string(qvalue.QValueKindTime),
"c14": string(qvalue.QValueKindTimestamp),
"c15": string(qvalue.QValueKindTimestampTZ),
"c16": string(qvalue.QValueKindUUID),
ColumnNames: []string{
"id", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9",
"c10", "c11", "c12", "c13", "c14", "c15", "c16",
},
ColumnTypes: []string{
string(qvalue.QValueKindInt32),
string(qvalue.QValueKindBit),
string(qvalue.QValueKindBoolean),
string(qvalue.QValueKindBytes),
string(qvalue.QValueKindDate),
string(qvalue.QValueKindFloat32),
string(qvalue.QValueKindFloat64),
string(qvalue.QValueKindInt16),
string(qvalue.QValueKindInt32),
string(qvalue.QValueKindInt64),
string(qvalue.QValueKindJSON),
string(qvalue.QValueKindNumeric),
string(qvalue.QValueKindString),
string(qvalue.QValueKindTime),
string(qvalue.QValueKindTimestamp),
string(qvalue.QValueKindTimestampTZ),
string(qvalue.QValueKindUUID),
},
PrimaryKeyColumns: []string{"id"},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
for columnName, columnType := range expectedTableSchema.Columns {
utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
if columnName != "id" {
addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
ColumnName: columnName,
ColumnType: columnType,
})
}
}
})

err = suite.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
SrcTableName: tableName,
Expand All @@ -165,29 +168,33 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddTrickyColumnNames() {

expectedTableSchema := &protos.TableSchema{
TableIdentifier: tableName,
Columns: map[string]string{
"id": string(qvalue.QValueKindInt32),
"c1": string(qvalue.QValueKindString),
"C1": string(qvalue.QValueKindString),
"C 1": string(qvalue.QValueKindString),
"right": string(qvalue.QValueKindString),
"select": string(qvalue.QValueKindString),
"XMIN": string(qvalue.QValueKindString),
"Cariño": string(qvalue.QValueKindString),
"±ªþ³§": string(qvalue.QValueKindString),
"カラム": string(qvalue.QValueKindString),
ColumnNames: []string{
"id", "c1", "C1", "C 1", "right",
"select", "XMIN", "Cariño", "±ªþ³§", "カラム",
},
ColumnTypes: []string{
string(qvalue.QValueKindInt32),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
},
PrimaryKeyColumns: []string{"id"},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
for columnName, columnType := range expectedTableSchema.Columns {
utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
if columnName != "id" {
addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
ColumnName: columnName,
ColumnType: columnType,
})
}
}
})

err = suite.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
SrcTableName: tableName,
Expand All @@ -211,23 +218,24 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() {

expectedTableSchema := &protos.TableSchema{
TableIdentifier: tableName,
Columns: map[string]string{
" ": string(qvalue.QValueKindInt32),
" ": string(qvalue.QValueKindString),
" ": string(qvalue.QValueKindInt64),
" ": string(qvalue.QValueKindDate),
ColumnNames: []string{" ", " ", " ", "\t"},
ColumnTypes: []string{
string(qvalue.QValueKindInt32),
string(qvalue.QValueKindString),
string(qvalue.QValueKindInt64),
string(qvalue.QValueKindDate),
},
PrimaryKeyColumns: []string{" "},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
for columnName, columnType := range expectedTableSchema.Columns {
utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
if columnName != " " {
addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
ColumnName: columnName,
ColumnType: columnType,
})
}
}
})

err = suite.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
SrcTableName: tableName,
Expand Down
Loading

0 comments on commit 0e468b0

Please sign in to comment.