Skip to content

Commit

Permalink
Split ColumnNameType into ColumnNames/ColumnTypes
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex committed Dec 29, 2023
1 parent fc27471 commit 268264c
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 203 deletions.
24 changes: 13 additions & 11 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -593,15 +593,9 @@ func (c *PostgresConnector) getTableSchemaForTable(
defer rows.Close()

fields := rows.FieldDescriptions()
res := &protos.TableSchema{
TableIdentifier: tableName,
Columns: nil,
PrimaryKeyColumns: pKeyCols,
IsReplicaIdentityFull: replicaIdentityType == ReplicaIdentityFull,
ColumnNameType: make([]string, 0, len(fields)*2),
}

for _, fieldDescription := range rows.FieldDescriptions() {
// Start empty with capacity len(fields): the loop below appends, so a non-zero
// length would leave len(fields) blank entries ahead of the real columns.
columnNames := make([]string, 0, len(fields))
columnTypes := make([]string, 0, len(fields))
for _, fieldDescription := range fields {
genericColType := postgresOIDToQValueKind(fieldDescription.DataTypeOID)
if genericColType == qvalue.QValueKindInvalid {
typeName, ok := c.customTypesMapping[fieldDescription.DataTypeOID]
Expand All @@ -612,14 +606,22 @@ func (c *PostgresConnector) getTableSchemaForTable(
}
}

res.ColumnNameType = append(res.ColumnNameType, fieldDescription.Name, string(genericColType))
columnNames = append(columnNames, fieldDescription.Name)
columnTypes = append(columnTypes, string(genericColType))
}

if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating over table schema: %w", err)
}

return res, nil
return &protos.TableSchema{
TableIdentifier: tableName,
Columns: nil,
PrimaryKeyColumns: pKeyCols,
IsReplicaIdentityFull: replicaIdentityType == ReplicaIdentityFull,
ColumnNames: columnNames,
ColumnTypes: columnTypes,
}, nil
}

// SetupNormalizedTable sets up a normalized table, implementing the Connector interface.
Expand Down
85 changes: 46 additions & 39 deletions flow/connectors/postgres/postgres_schema_delta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,9 @@ func (suite *PostgresSchemaDeltaTestSuite) TestSimpleAddColumn() {
})
suite.failTestError(err)
suite.Equal(&protos.TableSchema{
TableIdentifier: tableName,
ColumnNameType: []string{
"id", string(qvalue.QValueKindInt32),
"hi", string(qvalue.QValueKindInt64),
},
TableIdentifier: tableName,
ColumnNames: []string{"id", "hi"},
ColumnTypes: []string{string(qvalue.QValueKindInt32), string(qvalue.QValueKindInt64)},
PrimaryKeyColumns: []string{"id"},
}, output.TableNameSchemaMapping[tableName])
}
Expand All @@ -113,24 +111,28 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddAllColumnTypes() {
expectedTableSchema := &protos.TableSchema{
TableIdentifier: tableName,
// goal is to test all types we're currently mapping to, not all QValue types
ColumnNameType: []string{
"id", string(qvalue.QValueKindInt32),
"c1", string(qvalue.QValueKindBit),
"c2", string(qvalue.QValueKindBoolean),
"c3", string(qvalue.QValueKindBytes),
"c4", string(qvalue.QValueKindDate),
"c5", string(qvalue.QValueKindFloat32),
"c6", string(qvalue.QValueKindFloat64),
"c7", string(qvalue.QValueKindInt16),
"c8", string(qvalue.QValueKindInt32),
"c9", string(qvalue.QValueKindInt64),
"c10", string(qvalue.QValueKindJSON),
"c11", string(qvalue.QValueKindNumeric),
"c12", string(qvalue.QValueKindString),
"c13", string(qvalue.QValueKindTime),
"c14", string(qvalue.QValueKindTimestamp),
"c15", string(qvalue.QValueKindTimestampTZ),
"c16", string(qvalue.QValueKindUUID),
ColumnNames: []string{
"id", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9",
"c10", "c11", "c12", "c13", "c14", "c15", "c16",
},
ColumnTypes: []string{
string(qvalue.QValueKindInt32),
string(qvalue.QValueKindBit),
string(qvalue.QValueKindBoolean),
string(qvalue.QValueKindBytes),
string(qvalue.QValueKindDate),
string(qvalue.QValueKindFloat32),
string(qvalue.QValueKindFloat64),
string(qvalue.QValueKindInt16),
string(qvalue.QValueKindInt32),
string(qvalue.QValueKindInt64),
string(qvalue.QValueKindJSON),
string(qvalue.QValueKindNumeric),
string(qvalue.QValueKindString),
string(qvalue.QValueKindTime),
string(qvalue.QValueKindTimestamp),
string(qvalue.QValueKindTimestampTZ),
string(qvalue.QValueKindUUID),
},
PrimaryKeyColumns: []string{"id"},
}
Expand Down Expand Up @@ -166,17 +168,21 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddTrickyColumnNames() {

expectedTableSchema := &protos.TableSchema{
TableIdentifier: tableName,
ColumnNameType: []string{
"id", string(qvalue.QValueKindInt32),
"c1", string(qvalue.QValueKindString),
"C1", string(qvalue.QValueKindString),
"C 1", string(qvalue.QValueKindString),
"right", string(qvalue.QValueKindString),
"select", string(qvalue.QValueKindString),
"XMIN", string(qvalue.QValueKindString),
"Cariño", string(qvalue.QValueKindString),
"±ªþ³§", string(qvalue.QValueKindString),
"カラム", string(qvalue.QValueKindString),
ColumnNames: []string{
"id", "c1", "C1", "C 1", "right",
"select", "XMIN", "Cariño", "±ªþ³§", "カラム",
},
ColumnTypes: []string{
string(qvalue.QValueKindInt32),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
string(qvalue.QValueKindString),
},
PrimaryKeyColumns: []string{"id"},
}
Expand Down Expand Up @@ -212,11 +218,12 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() {

expectedTableSchema := &protos.TableSchema{
TableIdentifier: tableName,
ColumnNameType: []string{
" ", string(qvalue.QValueKindInt32),
" ", string(qvalue.QValueKindString),
" ", string(qvalue.QValueKindInt64),
" ", string(qvalue.QValueKindDate),
ColumnNames: []string{" ", " ", " ", "\t"},
ColumnTypes: []string{
string(qvalue.QValueKindInt32),
string(qvalue.QValueKindString),
string(qvalue.QValueKindInt64),
string(qvalue.QValueKindDate),
},
PrimaryKeyColumns: []string{" "},
}
Expand Down
17 changes: 9 additions & 8 deletions flow/connectors/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,9 @@ func (c *SnowflakeConnector) getTableSchemaForTable(tableName string) (*protos.T
}
}()

res := &protos.TableSchema{
TableIdentifier: tableName,
Columns: nil,
ColumnNameType: make([]string, 0, 16),
}

var columnName, columnType pgtype.Text
columnNames := make([]string, 0, 8)
columnTypes := make([]string, 0, 8)
for rows.Next() {
err = rows.Scan(&columnName, &columnType)
if err != nil {
Expand All @@ -275,10 +271,15 @@ func (c *SnowflakeConnector) getTableSchemaForTable(tableName string) (*protos.T
genericColType = qvalue.QValueKindString
}

res.ColumnNameType = append(res.ColumnNameType, columnName.String, string(genericColType))
columnNames = append(columnNames, columnName.String)
columnTypes = append(columnTypes, string(genericColType))
}

return res, nil
return &protos.TableSchema{
TableIdentifier: tableName,
ColumnNames: columnNames,
ColumnTypes: columnTypes,
}, nil
}

func (c *SnowflakeConnector) GetLastOffset(jobName string) (int64, error) {
Expand Down
56 changes: 8 additions & 48 deletions flow/connectors/utils/columns.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package utils

import (
"slices"

"github.com/PeerDB-io/peer-flow/generated/protos"
"golang.org/x/exp/maps"
)
Expand All @@ -9,19 +11,15 @@ func TableSchemaColumns(schema *protos.TableSchema) int {
if schema.Columns != nil {
return len(schema.Columns)
} else {
return len(schema.ColumnNameType) / 2
return len(schema.ColumnNames)
}
}

func TableSchemaColumnNames(schema *protos.TableSchema) []string {
if schema.Columns != nil {
return maps.Keys(schema.Columns)
} else {
ret := make([]string, 0, len(schema.ColumnNameType)/2)
for i := 0; i < len(schema.ColumnNameType); i += 2 {
ret = append(ret, schema.ColumnNameType[i])
}
return ret
return slices.Clone(schema.ColumnNames)
}
}

Expand All @@ -31,24 +29,8 @@ func IterColumns(schema *protos.TableSchema, iter func(k, v string)) {
iter(k, v)
}
} else {
for i := 0; i < len(schema.ColumnNameType); i += 2 {
iter(schema.ColumnNameType[i], schema.ColumnNameType[i+1])
}
}
}

func IterColumns0(schema *protos.TableSchema, iter func(k, v string) bool) {
if schema.Columns != nil {
for k, v := range schema.Columns {
if iter(k, v) {
return
}
}
} else {
for i := 0; i < len(schema.ColumnNameType); i += 2 {
if iter(schema.ColumnNameType[i], schema.ColumnNameType[i+1]) {
return
}
for i, name := range schema.ColumnNames {
iter(name, schema.ColumnTypes[i])
}
}
}
Expand All @@ -64,34 +46,12 @@ func IterColumns1[T any](schema *protos.TableSchema, iter func(k, v string) (boo
}
return false, zero
} else {
for i := 0; i < len(schema.ColumnNameType); i += 2 {
done, ret := iter(schema.ColumnNameType[i], schema.ColumnNameType[i+1])
for i, name := range schema.ColumnNames {
done, ret := iter(name, schema.ColumnTypes[i])
if done {
return true, ret
}
}
return false, zero
}
}

func IterColumns2[T1 any, T2 any](schema *protos.TableSchema, iter func(k, v string) (bool, T1, T2)) (bool, T1, T2) {
var zero1 T1
var zero2 T2
if schema.Columns != nil {
for k, v := range schema.Columns {
done, r1, r2 := iter(k, v)
if done {
return true, r1, r2
}
}
return false, zero1, zero2
} else {
for i := 0; i < len(schema.ColumnNameType); i += 2 {
done, r1, r2 := iter(schema.ColumnNameType[i], schema.ColumnNameType[i+1])
if done {
return true, r1, r2
}
}
return false, zero1, zero2
}
}
39 changes: 20 additions & 19 deletions flow/e2e/postgres/peer_flow_pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,9 @@ func (s *PeerFlowE2ETestSuitePG) Test_Simple_Schema_Changes_PG() {
// verify we got our first row.
e2e.NormalizeFlowCountQuery(env, connectionGen, 2)
expectedTableSchema := &protos.TableSchema{
TableIdentifier: dstTableName,
ColumnNameType: []string{
"id", string(qvalue.QValueKindInt64),
"c1", string(qvalue.QValueKindInt64),
},
TableIdentifier: dstTableName,
ColumnNames: []string{"id", "c1"},
ColumnTypes: []string{string(qvalue.QValueKindInt64), string(qvalue.QValueKindInt64)},
PrimaryKeyColumns: []string{"id"},
}
output, err := s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{
Expand All @@ -175,10 +173,11 @@ func (s *PeerFlowE2ETestSuitePG) Test_Simple_Schema_Changes_PG() {
e2e.NormalizeFlowCountQuery(env, connectionGen, 4)
expectedTableSchema = &protos.TableSchema{
TableIdentifier: dstTableName,
ColumnNameType: []string{
"id", string(qvalue.QValueKindInt64),
"c1", string(qvalue.QValueKindInt64),
"c2", string(qvalue.QValueKindInt64),
ColumnNames: []string{"id", "c1", "c2"},
ColumnTypes: []string{
string(qvalue.QValueKindInt64),
string(qvalue.QValueKindInt64),
string(qvalue.QValueKindInt64),
},
PrimaryKeyColumns: []string{"id"},
}
Expand All @@ -204,11 +203,12 @@ func (s *PeerFlowE2ETestSuitePG) Test_Simple_Schema_Changes_PG() {
e2e.NormalizeFlowCountQuery(env, connectionGen, 6)
expectedTableSchema = &protos.TableSchema{
TableIdentifier: dstTableName,
ColumnNameType: []string{
"id", string(qvalue.QValueKindInt64),
"c1", string(qvalue.QValueKindInt64),
"c2", string(qvalue.QValueKindInt64),
"c3", string(qvalue.QValueKindInt64),
ColumnNames: []string{"id", "c1", "c2", "c3"},
ColumnTypes: []string{
string(qvalue.QValueKindInt64),
string(qvalue.QValueKindInt64),
string(qvalue.QValueKindInt64),
string(qvalue.QValueKindInt64),
},
PrimaryKeyColumns: []string{"id"},
}
Expand All @@ -234,11 +234,12 @@ func (s *PeerFlowE2ETestSuitePG) Test_Simple_Schema_Changes_PG() {
e2e.NormalizeFlowCountQuery(env, connectionGen, 8)
expectedTableSchema = &protos.TableSchema{
TableIdentifier: dstTableName,
ColumnNameType: []string{
"id", string(qvalue.QValueKindInt64),
"c1", string(qvalue.QValueKindInt64),
"c2", string(qvalue.QValueKindInt64),
"c3", string(qvalue.QValueKindInt64),
ColumnNames: []string{"id", "c1", "c2", "c3"},
ColumnTypes: []string{
string(qvalue.QValueKindInt64),
string(qvalue.QValueKindInt64),
string(qvalue.QValueKindInt64),
string(qvalue.QValueKindInt64),
},
PrimaryKeyColumns: []string{"id"},
}
Expand Down
Loading

0 comments on commit 268264c

Please sign in to comment.