From 68a4fe4efc222961bd85d13ab3c812779a257895 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Wed, 14 Feb 2024 13:23:27 +0000 Subject: [PATCH 1/6] Misc cleanup of qvalue code (#1288) Split out from #1285 --- flow/connectors/postgres/qvalue_convert.go | 5 ++-- .../snowflake/merge_stmt_generator.go | 28 +++++++++---------- flow/model/conversion_avro.go | 2 +- flow/model/qrecord_batch.go | 5 ---- flow/model/qvalue/avro_converter.go | 12 ++++---- 5 files changed, 23 insertions(+), 29 deletions(-) diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go index 4823aa2ecc..72f51dff6b 100644 --- a/flow/connectors/postgres/qvalue_convert.go +++ b/flow/connectors/postgres/qvalue_convert.go @@ -501,10 +501,9 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( } default: textVal, ok := value.(string) - if !ok { - return qvalue.QValue{}, fmt.Errorf("failed to parse value %v into QValueKind %v", value, qvalueKind) + if ok { + val = qvalue.QValue{Kind: qvalue.QValueKindString, Value: textVal} } - val = qvalue.QValue{Kind: qvalue.QValueKindString, Value: textVal} } // parsing into pgtype failed. diff --git a/flow/connectors/snowflake/merge_stmt_generator.go b/flow/connectors/snowflake/merge_stmt_generator.go index 4d8ea7dc5a..1ab579069c 100644 --- a/flow/connectors/snowflake/merge_stmt_generator.go +++ b/flow/connectors/snowflake/merge_stmt_generator.go @@ -40,7 +40,7 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) { } targetColumnName := SnowflakeIdentifierNormalize(column.Name) - switch qvalue.QValueKind(genericColumnType) { + switch qvKind { case qvalue.QValueKindBytes, qvalue.QValueKindBit: flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:\"%s\") "+ "AS %s", toVariantColumnName, column.Name, targetColumnName)) @@ -61,21 +61,19 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) { // flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+ // "Microseconds*1000) "+ // "AS %s", toVariantColumnName, columnName, columnName)) - default: - if qvKind == qvalue.QValueKindNumeric { - precision, scale := numeric.ParseNumericTypmod(column.TypeModifier) - if column.TypeModifier == -1 || precision > 38 || scale > 37 { - precision = numeric.PeerDBNumericPrecision - scale = numeric.PeerDBNumericScale - } - numericType := fmt.Sprintf("NUMERIC(%d,%d)", precision, scale) - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s", - toVariantColumnName, column.Name, numericType, targetColumnName)) - } else { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS %s) AS %s", - toVariantColumnName, column.Name, sfType, targetColumnName)) + case qvalue.QValueKindNumeric: + precision, scale := numeric.ParseNumericTypmod(column.TypeModifier) + if column.TypeModifier == -1 || precision > 38 || scale > 37 { + precision = numeric.PeerDBNumericPrecision + scale = numeric.PeerDBNumericScale } + numericType := fmt.Sprintf("NUMERIC(%d,%d)", precision, scale) + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s", + toVariantColumnName, column.Name, numericType, targetColumnName)) + default: + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS %s) AS %s", + toVariantColumnName, column.Name, sfType, targetColumnName)) } } flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",") diff --git a/flow/model/conversion_avro.go b/flow/model/conversion_avro.go index b379abc09a..b26aeaf9d7 100644 --- a/flow/model/conversion_avro.go +++ b/flow/model/conversion_avro.go @@ -29,7 +29,7 @@ func NewQRecordAvroConverter( } func (qac *QRecordAvroConverter) Convert() (map[string]interface{}, error) { - m := map[string]interface{}{} + m := make(map[string]interface{}, len(qac.QRecord)) for idx, val := range qac.QRecord { key := qac.ColNames[idx] diff --git a/flow/model/qrecord_batch.go b/flow/model/qrecord_batch.go index d7f25c7f5d..572c788630 100644 --- a/flow/model/qrecord_batch.go +++ b/flow/model/qrecord_batch.go @@ -173,11 +173,6 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) { values[i] = timestampTZ case qvalue.QValueKindUUID: - if qValue.Value == nil { - values[i] = nil - break - } - v, ok := qValue.Value.([16]byte) // treat it as byte slice if !ok { src.err = fmt.Errorf("invalid UUID value %v", qValue.Value) diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go index 7c902ffcb5..4f5488cbff 100644 --- a/flow/model/qvalue/avro_converter.go +++ b/flow/model/qvalue/avro_converter.go @@ -169,6 +169,10 @@ func NewQValueAvroConverter(value QValue, targetDWH QDWHType, nullable bool) *QV } func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { + if c.Nullable && c.Value.Value == nil { + return nil, nil + } + switch c.Value.Kind { case QValueKindInvalid: // we will attempt to convert invalid to a string @@ -457,14 +461,12 @@ func (c *QValueAvroConverter) processNullableUnion( avroType string, value interface{}, ) (interface{}, error) { - if value == nil && c.Nullable { - return nil, nil - } - if c.Nullable { + if value == nil { + return nil, nil + } return goavro.Union(avroType, value), nil } - return value, nil } From 0fb50d9069faa182794e3f51a1513bb95e076502 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Wed, 14 Feb 2024 13:24:34 +0000 Subject: [PATCH 2/6] Generalize recent []int32 nil handling fix to rest of array types (#1287) Ideally we'd preserve nulls, but this'll do for now --- flow/connectors/postgres/qvalue_convert.go | 191 +++------------------ flow/connectors/utils/array.go | 13 ++ 2 files changed, 37 insertions(+), 167 deletions(-) diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go index 72f51dff6b..bd39fd1048 100644 --- a/flow/connectors/postgres/qvalue_convert.go +++ b/flow/connectors/postgres/qvalue_convert.go @@ -12,6 +12,7 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/lib/pq/oid" + "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/model/qvalue" ) @@ -190,6 +191,20 @@ func parseJSON(value interface{}) (qvalue.QValue, error) { return qvalue.QValue{Kind: qvalue.QValueKindJSON, Value: string(jsonVal)}, nil } +func convertToArray[T any](kind qvalue.QValueKind, value interface{}) (qvalue.QValue, error) { + switch v := value.(type) { + case pgtype.Array[T]: + if v.Valid { + return qvalue.QValue{Kind: kind, Value: v.Elements}, nil + } + case []T: + return qvalue.QValue{Kind: kind, Value: v}, nil + case []interface{}: + return qvalue.QValue{Kind: kind, Value: utils.ArrayCastElements[T](v)}, nil + } + return qvalue.QValue{}, fmt.Errorf("failed to parse array %s from %T: %v", kind, value, value) +} + func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (qvalue.QValue, error) { val := qvalue.QValue{} @@ -319,179 +334,21 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( val = qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: rat} } case qvalue.QValueKindArrayFloat32: - switch v := value.(type) { - case pgtype.Array[float32]: - if v.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindArrayFloat32, Value: v.Elements} - } - case []float32: - val = qvalue.QValue{Kind: qvalue.QValueKindArrayFloat32, Value: v} - case []interface{}: - float32Array := make([]float32, len(v)) - for i, val := range v { - float32Array[i] = val.(float32) - } - val = qvalue.QValue{Kind: qvalue.QValueKindArrayFloat32, Value: float32Array} - default: - return qvalue.QValue{}, fmt.Errorf("failed to parse array float32: %v", value) - } + return convertToArray[float32](qvalueKind, value) case qvalue.QValueKindArrayFloat64: - switch v := value.(type) { - case pgtype.Array[float64]: - if v.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindArrayFloat64, Value: v.Elements} - } - case []float64: - val = qvalue.QValue{Kind: qvalue.QValueKindArrayFloat64, Value: v} - case []interface{}: - float64Array := make([]float64, len(v)) - for i, val := range v { - float64Array[i] = val.(float64) - } - val = qvalue.QValue{Kind: qvalue.QValueKindArrayFloat64, Value: float64Array} - default: - return qvalue.QValue{}, fmt.Errorf("failed to parse array float64: %v", value) - } + return convertToArray[float64](qvalueKind, value) case qvalue.QValueKindArrayInt16: - switch v := value.(type) { - case pgtype.Array[int16]: - if v.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindArrayInt16, Value: v.Elements} - } - case []int16: - val = qvalue.QValue{Kind: qvalue.QValueKindArrayInt16, Value: v} - case []interface{}: - int16Array := make([]int16, len(v)) - for i, val := range v { - int16Array[i] = val.(int16) - } - val = qvalue.QValue{Kind: qvalue.QValueKindArrayInt16, Value: int16Array} - default: - return qvalue.QValue{}, fmt.Errorf("failed to parse array int16: %v", value) - } + return convertToArray[int16](qvalueKind, value) case qvalue.QValueKindArrayInt32: - switch v := value.(type) { - case pgtype.Array[int32]: - if v.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindArrayInt32, Value: v.Elements} - } - case []int32: - val = qvalue.QValue{Kind: qvalue.QValueKindArrayInt32, Value: v} - case []interface{}: - int32Array := make([]int32, len(v)) - for i, val := range v { - if val == nil { - int32Array[i] = 0 - } else { - int32Array[i] = val.(int32) - } - } - val = qvalue.QValue{Kind: qvalue.QValueKindArrayInt32, Value: int32Array} - default: - return qvalue.QValue{}, fmt.Errorf("failed to parse array int32: %v", value) - } + return convertToArray[int32](qvalueKind, value) case qvalue.QValueKindArrayInt64: - switch v := value.(type) { - case pgtype.Array[int64]: - if v.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindArrayInt64, Value: v.Elements} - } - case []int64: - val = qvalue.QValue{Kind: qvalue.QValueKindArrayInt64, Value: v} - case []interface{}: - int64Array := make([]int64, len(v)) - for i, val := range v { - int64Array[i] = val.(int64) - } - val = qvalue.QValue{Kind: qvalue.QValueKindArrayInt64, Value: int64Array} - default: - return qvalue.QValue{}, fmt.Errorf("failed to parse array int64: %v", value) - } - case qvalue.QValueKindArrayDate: - switch v := value.(type) { - case pgtype.Array[time.Time]: - if v.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindArrayDate, Value: v.Elements} - } - case []time.Time: - val = qvalue.QValue{Kind: qvalue.QValueKindArrayDate, Value: v} - case []interface{}: - dateArray := make([]time.Time, len(v)) - for i, val := range v { - dateArray[i] = val.(time.Time) - } - val = qvalue.QValue{Kind: qvalue.QValueKindArrayDate, Value: dateArray} - default: - return qvalue.QValue{}, fmt.Errorf("failed to parse array date: %v", value) - } - case qvalue.QValueKindArrayTimestamp: - switch v := value.(type) { - case pgtype.Array[time.Time]: - if v.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindArrayTimestamp, Value: v.Elements} - } - case []time.Time: - val = qvalue.QValue{Kind: qvalue.QValueKindArrayTimestamp, Value: v} - case []interface{}: - timestampArray := make([]time.Time, len(v)) - for i, val := range v { - timestampArray[i] = val.(time.Time) - } - val = qvalue.QValue{Kind: qvalue.QValueKindArrayTimestamp, Value: timestampArray} - default: - return qvalue.QValue{}, fmt.Errorf("failed to parse array timestamp: %v", value) - } - case qvalue.QValueKindArrayTimestampTZ: - switch v := value.(type) { - case pgtype.Array[time.Time]: - if v.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindArrayTimestampTZ, Value: v.Elements} - } - case []time.Time: - val = qvalue.QValue{Kind: qvalue.QValueKindArrayTimestampTZ, Value: v} - case []interface{}: - timestampTZArray := make([]time.Time, len(v)) - for i, val := range v { - timestampTZArray[i] = val.(time.Time) - } - val = qvalue.QValue{Kind: qvalue.QValueKindArrayTimestampTZ, Value: timestampTZArray} - default: - return qvalue.QValue{}, fmt.Errorf("failed to parse array timestamptz: %v", value) - } + return convertToArray[int64](qvalueKind, value) + case qvalue.QValueKindArrayDate, qvalue.QValueKindArrayTimestamp, qvalue.QValueKindArrayTimestampTZ: + return convertToArray[time.Time](qvalueKind, value) case qvalue.QValueKindArrayBoolean: - switch v := value.(type) { - case pgtype.Array[bool]: - if v.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindArrayBoolean, Value: v.Elements} - } - case []bool: - val = qvalue.QValue{Kind: qvalue.QValueKindArrayBoolean, Value: v} - case []interface{}: - boolArray := make([]bool, len(v)) - for i, val := range v { - boolArray[i] = val.(bool) - } - val = qvalue.QValue{Kind: qvalue.QValueKindArrayBoolean, Value: boolArray} - default: - return qvalue.QValue{}, fmt.Errorf("failed to parse array boolean: %v", value) - } + return convertToArray[bool](qvalueKind, value) case qvalue.QValueKindArrayString: - switch v := value.(type) { - case pgtype.Array[string]: - if v.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindArrayString, Value: v.Elements} - } - case []string: - val = qvalue.QValue{Kind: qvalue.QValueKindArrayString, Value: v} - case []interface{}: - stringArray := make([]string, len(v)) - for i, val := range v { - stringArray[i] = val.(string) - } - val = qvalue.QValue{Kind: qvalue.QValueKindArrayString, Value: stringArray} - default: - return qvalue.QValue{}, fmt.Errorf("failed to parse array string: %v", value) - } + return convertToArray[string](qvalueKind, value) case qvalue.QValueKindPoint: xCoord := value.(pgtype.Point).P.X yCoord := value.(pgtype.Point).P.Y diff --git a/flow/connectors/utils/array.go b/flow/connectors/utils/array.go index d203beacc2..2633153ae6 100644 --- a/flow/connectors/utils/array.go +++ b/flow/connectors/utils/array.go @@ -52,3 +52,16 @@ func ArraysHaveOverlap[T comparable](first, second []T) bool { return false } + +func ArrayCastElements[T any](arr []any) []T { + res := make([]T, 0, len(arr)) + for _, val := range arr { + if v, ok := val.(T); ok { + res = append(res, v) + } else { + var none T + res = append(res, none) + } + } + return res +} From b2ed20ad5a5bef697b08a35016b4737cab503f2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Wed, 14 Feb 2024 14:22:12 +0000 Subject: [PATCH 3/6] Add support for "char" (#1285) Postgres offers a type "char" distinct from CHAR, represented by one byte Map this type in QValue, sqlserver also has char, & on clickhouse we can represent it with FixedString(1) --- flow/connectors/clickhouse/qvalue_convert.go | 1 + flow/connectors/postgres/qvalue_convert.go | 6 ++++++ .../postgres/schema_delta_test_constants.go | 12 +++++++++--- flow/connectors/snowflake/avro_file_writer_test.go | 3 +++ flow/connectors/sql/query_executor.go | 7 +++++++ flow/connectors/sqlserver/qvalue_convert.go | 3 ++- flow/model/model.go | 6 ++++++ flow/model/qrecord_batch.go | 8 ++++++++ flow/model/qvalue/avro_converter.go | 5 +++-- flow/model/qvalue/kind.go | 3 +++ flow/model/qvalue/qvalue.go | 6 ++++++ 11 files changed, 54 insertions(+), 6 deletions(-) diff --git a/flow/connectors/clickhouse/qvalue_convert.go b/flow/connectors/clickhouse/qvalue_convert.go index f80ea7857e..8ec8a52f0f 100644 --- a/flow/connectors/clickhouse/qvalue_convert.go +++ b/flow/connectors/clickhouse/qvalue_convert.go @@ -16,6 +16,7 @@ var clickhouseTypeToQValueKindMap = map[string]qvalue.QValueKind{ "CHAR": qvalue.QValueKindString, "TEXT": qvalue.QValueKindString, "String": qvalue.QValueKindString, + "FixedString(1)": qvalue.QValueKindQChar, "Bool": qvalue.QValueKindBoolean, "DateTime": qvalue.QValueKindTimestamp, "TIMESTAMP": qvalue.QValueKindTimestamp, diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go index bd39fd1048..81f19e1ec9 100644 --- a/flow/connectors/postgres/qvalue_convert.go +++ b/flow/connectors/postgres/qvalue_convert.go @@ -32,6 +32,8 @@ func postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind { return qvalue.QValueKindFloat32 case pgtype.Float8OID: return qvalue.QValueKindFloat64 + case pgtype.QCharOID: + return qvalue.QValueKindQChar case pgtype.TextOID, pgtype.VarcharOID, pgtype.BPCharOID: return qvalue.QValueKindString case pgtype.ByteaOID: @@ -122,6 +124,8 @@ func qValueKindToPostgresType(colTypeStr string) string { return "REAL" case qvalue.QValueKindFloat64: return "DOUBLE PRECISION" + case qvalue.QValueKindQChar: + return "\"char\"" case qvalue.QValueKindString: return "TEXT" case qvalue.QValueKindBytes: @@ -277,6 +281,8 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( case qvalue.QValueKindFloat64: floatVal := value.(float64) val = qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: floatVal} + case qvalue.QValueKindQChar: + val = qvalue.QValue{Kind: qvalue.QValueKindQChar, Value: uint8(value.(rune))} case qvalue.QValueKindString: // handling all unsupported types with strings as well for now. val = qvalue.QValue{Kind: qvalue.QValueKindString, Value: fmt.Sprint(value)} diff --git a/flow/connectors/postgres/schema_delta_test_constants.go b/flow/connectors/postgres/schema_delta_test_constants.go index d86f8e98a0..6ded70625a 100644 --- a/flow/connectors/postgres/schema_delta_test_constants.go +++ b/flow/connectors/postgres/schema_delta_test_constants.go @@ -19,6 +19,7 @@ var AddAllColumnTypes = []string{ string(qvalue.QValueKindJSON), string(qvalue.QValueKindNumeric), string(qvalue.QValueKindString), + string(qvalue.QValueKindQChar), string(qvalue.QValueKindTime), string(qvalue.QValueKindTimestamp), string(qvalue.QValueKindTimestampTZ), @@ -93,21 +94,26 @@ var AddAllColumnTypesFields = []*protos.FieldDescription{ }, { Name: "c13", - Type: string(qvalue.QValueKindTime), + Type: string(qvalue.QValueKindQChar), TypeModifier: -1, }, { Name: "c14", - Type: string(qvalue.QValueKindTimestamp), + Type: string(qvalue.QValueKindTime), TypeModifier: -1, }, { Name: "c15", - Type: string(qvalue.QValueKindTimestampTZ), + Type: string(qvalue.QValueKindTimestamp), TypeModifier: -1, }, { Name: "c16", + Type: string(qvalue.QValueKindTimestampTZ), + TypeModifier: -1, + }, + { + Name: "c17", Type: string(qvalue.QValueKindUUID), TypeModifier: -1, }, diff --git a/flow/connectors/snowflake/avro_file_writer_test.go b/flow/connectors/snowflake/avro_file_writer_test.go index bd9b155dab..0252b53fc8 100644 --- a/flow/connectors/snowflake/avro_file_writer_test.go +++ b/flow/connectors/snowflake/avro_file_writer_test.go @@ -40,6 +40,8 @@ func createQValue(t *testing.T, kind qvalue.QValueKind, placeHolder int) qvalue. value = big.NewRat(int64(placeHolder), 1) case qvalue.QValueKindUUID: value = uuid.New() // assuming you have the github.com/google/uuid package + case qvalue.QValueKindQChar: + value = uint8(48) // case qvalue.QValueKindArray: // value = []int{1, 2, 3} // placeholder array, replace with actual logic // case qvalue.QValueKindStruct: @@ -85,6 +87,7 @@ func generateRecords( qvalue.QValueKindNumeric, qvalue.QValueKindBytes, qvalue.QValueKindUUID, + qvalue.QValueKindQChar, // qvalue.QValueKindJSON, qvalue.QValueKindBit, } diff --git a/flow/connectors/sql/query_executor.go b/flow/connectors/sql/query_executor.go index 14bc6629b7..64ff4ecc54 100644 --- a/flow/connectors/sql/query_executor.go +++ b/flow/connectors/sql/query_executor.go @@ -308,6 +308,9 @@ func (g *GenericSQLQueryExecutor) CheckNull(ctx context.Context, schema string, } func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) { + if val == nil { + return qvalue.QValue{Kind: kind, Value: nil}, nil + } switch kind { case qvalue.QValueKindInt32: if v, ok := val.(*sql.NullInt32); ok { @@ -341,6 +344,10 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) { return qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: nil}, nil } } + case qvalue.QValueKindQChar: + if v, ok := val.(uint8); ok { + return qvalue.QValue{Kind: qvalue.QValueKindQChar, Value: v}, nil + } case qvalue.QValueKindString: if v, ok := val.(*sql.NullString); ok { if v.Valid { diff --git a/flow/connectors/sqlserver/qvalue_convert.go b/flow/connectors/sqlserver/qvalue_convert.go index cff634139a..b4f73420e1 100644 --- a/flow/connectors/sqlserver/qvalue_convert.go +++ b/flow/connectors/sqlserver/qvalue_convert.go @@ -10,6 +10,7 @@ var qValueKindToSQLServerTypeMap = map[qvalue.QValueKind]string{ qvalue.QValueKindFloat32: "REAL", qvalue.QValueKindFloat64: "FLOAT", qvalue.QValueKindNumeric: "DECIMAL(38, 9)", + qvalue.QValueKindQChar: "CHAR", qvalue.QValueKindString: "NTEXT", qvalue.QValueKindJSON: "NTEXT", // SQL Server doesn't have a native JSON type qvalue.QValueKindTimestamp: "DATETIME2", @@ -51,7 +52,7 @@ var sqlServerTypeToQValueKindMap = map[string]qvalue.QValueKind{ "UNIQUEIDENTIFIER": qvalue.QValueKindUUID, "SMALLINT": qvalue.QValueKindInt32, "TINYINT": qvalue.QValueKindInt32, - "CHAR": qvalue.QValueKindString, + "CHAR": qvalue.QValueKindQChar, "VARCHAR": qvalue.QValueKindString, "NCHAR": qvalue.QValueKindString, "NVARCHAR": qvalue.QValueKindString, diff --git a/flow/model/model.go b/flow/model/model.go index 14de42a44e..a5bb754f2a 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -158,7 +158,13 @@ func (r *RecordItems) toMap(hstoreAsJSON bool) (map[string]interface{}, error) { } jsonStruct[col] = binStr + case qvalue.QValueKindQChar: + ch, ok := v.Value.(uint8) + if !ok { + return nil, fmt.Errorf("expected \"char\" value for column %s for %T", col, v.Value) + } + jsonStruct[col] = string(ch) case qvalue.QValueKindString, qvalue.QValueKindJSON: strVal, ok := v.Value.(string) if !ok { diff --git a/flow/model/qrecord_batch.go b/flow/model/qrecord_batch.go index 572c788630..4455101ad8 100644 --- a/flow/model/qrecord_batch.go +++ b/flow/model/qrecord_batch.go @@ -137,6 +137,14 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) { } values[i] = v + case qvalue.QValueKindQChar: + v, ok := qValue.Value.(uint8) + if !ok { + src.err = fmt.Errorf("invalid \"char\" value") + return nil, src.err + } + values[i] = rune(v) + case qvalue.QValueKindString: v, ok := qValue.Value.(string) if !ok { diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go index 4f5488cbff..e5e6046fb5 100644 --- a/flow/model/qvalue/avro_converter.go +++ b/flow/model/qvalue/avro_converter.go @@ -64,7 +64,7 @@ func GetAvroSchemaFromQValueKind(kind QValueKind, targetDWH QDWHType, precision } switch kind { - case QValueKindString: + case QValueKindString, QValueKindQChar: return "string", nil case QValueKindUUID: return AvroSchemaLogical{ @@ -291,7 +291,8 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { return goavro.Union("int.date", t), nil } return t, nil - + case QValueKindQChar: + return c.processNullableUnion("string", string(c.Value.Value.(uint8))) case QValueKindString, QValueKindCIDR, QValueKindINET, QValueKindMacaddr: if c.TargetDWH == QDWHTypeSnowflake && c.Value.Value != nil && (len(c.Value.Value.(string)) > 15*1024*1024) { diff --git a/flow/model/qvalue/kind.go b/flow/model/qvalue/kind.go index 6328897d3f..78c9ece45b 100644 --- a/flow/model/qvalue/kind.go +++ b/flow/model/qvalue/kind.go @@ -17,6 +17,7 @@ const ( QValueKindInt64 QValueKind = "int64" QValueKindBoolean QValueKind = "bool" QValueKindStruct QValueKind = "struct" + QValueKindQChar QValueKind = "qchar" QValueKindString QValueKind = "string" QValueKindTimestamp QValueKind = "timestamp" QValueKindTimestampTZ QValueKind = "timestamptz" @@ -63,6 +64,7 @@ var QValueKindToSnowflakeTypeMap = map[QValueKind]string{ QValueKindFloat32: "FLOAT", QValueKindFloat64: "FLOAT", QValueKindNumeric: "NUMBER(38, 9)", + QValueKindQChar: "CHAR", QValueKindString: "STRING", QValueKindJSON: "VARIANT", QValueKindTimestamp: "TIMESTAMP_NTZ", @@ -101,6 +103,7 @@ var QValueKindToClickhouseTypeMap = map[QValueKind]string{ QValueKindFloat32: "Float32", QValueKindFloat64: "Float64", QValueKindNumeric: "Decimal128(9)", + QValueKindQChar: "FixedString(1)", QValueKindString: "String", QValueKindJSON: "String", QValueKindTimestamp: "DateTime64(6)", diff --git a/flow/model/qvalue/qvalue.go b/flow/model/qvalue/qvalue.go index b9e040e170..b0b556b3d7 100644 --- a/flow/model/qvalue/qvalue.go +++ b/flow/model/qvalue/qvalue.go @@ -42,6 +42,12 @@ func (q QValue) Equals(other QValue) bool { return compareBoolean(q.Value, other.Value) case QValueKindStruct: return compareStruct(q.Value, other.Value) + case QValueKindQChar: + if (q.Value == nil) == (other.Value == nil) { + return q.Value == nil || q.Value.(uint8) == other.Value.(uint8) + } else { + return false + } case QValueKindString: return compareString(q.Value, other.Value) // all internally represented as a Golang time.Time From db7cdfe51325e2163569dd3484237d4236ef0345 Mon Sep 17 00:00:00 2001 From: Kevin Biju <52661649+heavycrystal@users.noreply.github.com> Date: Wed, 14 Feb 2024 20:36:55 +0530 Subject: [PATCH 4/6] fixes to handling of custom publications for AddTablesToPublication (#1289) 1. need to fetch tables schema-qualified for comparison 2. order of ArrayMinus was wrong 3. also add validation for custom publication existing --- flow/connectors/postgres/client.go | 17 +++++++++++++---- flow/connectors/postgres/postgres.go | 4 ++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 72a5691fe5..563e0c6824 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -629,7 +629,7 @@ func (c *PostgresConnector) getDefaultPublicationName(jobName string) string { func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames []string, pubName string) error { if c.conn == nil { - return fmt.Errorf("check tables: conn is nil") + return errors.New("check tables: conn is nil") } // Check that we can select from all tables @@ -649,11 +649,20 @@ func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames [] } } - // Check if tables belong to publication tableStr := strings.Join(tableArr, ",") if pubName != "" { + // Check if publication exists + err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil) + if err != nil { + if err == pgx.ErrNoRows { + return fmt.Errorf("publication does not exist: %s", pubName) + } + return fmt.Errorf("error while checking for publication existence: %w", err) + } + + // Check if tables belong to publication var pubTableCount int - err := c.conn.QueryRow(ctx, fmt.Sprintf(` + err = c.conn.QueryRow(ctx, fmt.Sprintf(` with source_table_components (sname, tname) as (values %s) select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables INNER JOIN source_table_components stc @@ -663,7 +672,7 @@ func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames [] } if pubTableCount != len(tableNames) { - return fmt.Errorf("not all tables belong to publication") + return errors.New("not all tables belong to publication") } } diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index f41e015672..d81401511b 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -973,7 +973,7 @@ func (c *PostgresConnector) AddTablesToPublication(ctx context.Context, req *pro // just check if we have all the tables already in the publication for custom publications if req.PublicationName != "" { rows, err := c.conn.Query(ctx, - "SELECT tablename FROM pg_publication_tables WHERE pubname=$1", req.PublicationName) + "SELECT schemaname || '.' || tablename FROM pg_publication_tables WHERE pubname=$1", req.PublicationName) if err != nil { return fmt.Errorf("failed to check tables in publication: %w", err) } @@ -982,7 +982,7 @@ func (c *PostgresConnector) AddTablesToPublication(ctx context.Context, req *pro if err != nil { return fmt.Errorf("failed to check tables in publication: %w", err) } - notPresentTables := utils.ArrayMinus(tableNames, additionalSrcTables) + notPresentTables := utils.ArrayMinus(additionalSrcTables, tableNames) if len(notPresentTables) > 0 { return fmt.Errorf("some additional tables not present in custom publication: %s", strings.Join(notPresentTables, ", ")) From cc61dfb45d896548eb5a8c57d2d7187ca1149a9d Mon Sep 17 00:00:00 2001 From: Amogh Bharadwaj Date: Wed, 14 Feb 2024 21:14:44 +0530 Subject: [PATCH 5/6] Filter out invalid BigQuery timestamps (#1256) Fixes #1254 Adds a test for it in QRep Also we are nulling out and logging values of time/timestamps which cannot be scanned by pgx (time.Time). An example is a [year with more than 4 digits](https://pkg.go.dev/time#Parse) --- flow/connectors/postgres/cdc.go | 9 ++++ flow/connectors/utils/avro/avro_writer.go | 1 + flow/e2e/bigquery/bigquery_helper.go | 2 +- flow/e2e/bigquery/qrep_flow_bq_test.go | 66 +++++++++++++++++++++++ flow/model/conversion_avro.go | 7 +++ flow/model/qvalue/avro_converter.go | 54 +++++++++++++------ flow/model/qvalue/timestamp.go | 20 +++++++ 7 files changed, 141 insertions(+), 18 deletions(-) create mode 100644 flow/model/qvalue/timestamp.go diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index 44e81c8227..8992a6562f 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -720,6 +720,15 @@ func (p *PostgresCDCSource) decodeColumnData(data []byte, dataType uint32, forma parsedData, err = dt.Codec.DecodeValue(p.typeMap, dataType, formatCode, data) } if err != nil { + if dt.Name == "time" || dt.Name == "timetz" || + dt.Name == "timestamp" || dt.Name == "timestamptz" { + // indicates year is more than 4 digits or something similar, + // which you can insert into postgres, + // but not representable by time.Time + p.logger.Warn(fmt.Sprintf("Invalidated and hence nulled %s data: %s", + dt.Name, string(data))) + return qvalue.QValue{}, nil + } return qvalue.QValue{}, err } retVal, err := parseFieldFromPostgresOID(dataType, parsedData) diff --git a/flow/connectors/utils/avro/avro_writer.go b/flow/connectors/utils/avro/avro_writer.go index 00855bf769..806d8a0b7d 100644 --- a/flow/connectors/utils/avro/avro_writer.go +++ b/flow/connectors/utils/avro/avro_writer.go @@ -150,6 +150,7 @@ func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, ocfWriter p.targetDWH, p.avroSchema.NullableFields, colNames, + logger, ) avroMap, err := avroConverter.Convert() diff --git a/flow/e2e/bigquery/bigquery_helper.go b/flow/e2e/bigquery/bigquery_helper.go index bab2be062f..6f70046137 100644 --- a/flow/e2e/bigquery/bigquery_helper.go +++ b/flow/e2e/bigquery/bigquery_helper.go @@ -377,7 +377,7 @@ func (b *BigQueryTestHelper) ExecuteAndProcessQuery(query string) (*model.QRecor }, nil } -// returns whether the function errors or there are nulls +// returns whether the function errors or there are no nulls func (b *BigQueryTestHelper) CheckNull(tableName string, colName []string) (bool, error) { if len(colName) == 0 { return true, nil diff --git a/flow/e2e/bigquery/qrep_flow_bq_test.go b/flow/e2e/bigquery/qrep_flow_bq_test.go index d0048b948d..395a2c5ea2 100644 --- a/flow/e2e/bigquery/qrep_flow_bq_test.go +++ b/flow/e2e/bigquery/qrep_flow_bq_test.go @@ -1,7 +1,9 @@ package e2e_bigquery import ( + "context" "fmt" + "strings" "github.com/stretchr/testify/require" @@ -15,6 +17,34 @@ func (s PeerFlowE2ETestSuiteBQ) setupSourceTable(tableName string, rowCount int) require.NoError(s.t, err) } +func (s PeerFlowE2ETestSuiteBQ) setupTimeTable(tableName string) { + tblFields := []string{ + "watermark_ts timestamp", + "mytimestamp timestamp", + "mytztimestamp timestamptz", + } + tblFieldStr := strings.Join(tblFields, ",") + _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + CREATE TABLE e2e_test_%s.%s ( + %s + );`, s.bqSuffix, tableName, tblFieldStr)) + + require.NoError(s.t, err) + + var rows []string + row := `(CURRENT_TIMESTAMP,'10001-03-14 23:05:52','50001-03-14 23:05:52.216809+00')` + rows = append(rows, row) + + _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + INSERT INTO e2e_test_%s.%s ( + watermark_ts, + mytimestamp, + mytztimestamp + ) VALUES %s; + `, s.bqSuffix, tableName, strings.Join(rows, ","))) + require.NoError(s.t, err) +} + func (s PeerFlowE2ETestSuiteBQ) Test_Complete_QRep_Flow_Avro() { env := e2e.NewTemporalTestWorkflowEnvironment(s.t) @@ -46,6 +76,42 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Complete_QRep_Flow_Avro() { e2e.RequireEqualTables(s, tblName, "*") } +func (s PeerFlowE2ETestSuiteBQ) Test_Invalid_Timestamps_QRep() { + env := e2e.NewTemporalTestWorkflowEnvironment(s.t) + + tblName := "test_qrep_flow_avro_bq" + s.setupTimeTable(tblName) + + query := fmt.Sprintf("SELECT * FROM e2e_test_%s.%s WHERE watermark_ts BETWEEN {{.start}} AND {{.end}}", + s.bqSuffix, tblName) + + qrepConfig, err := e2e.CreateQRepWorkflowConfig("test_qrep_flow_avro", + fmt.Sprintf("e2e_test_%s.%s", s.bqSuffix, tblName), + tblName, + query, + s.bqHelper.Peer, + "", + true, + "") + qrepConfig.WatermarkColumn = "watermark_ts" + require.NoError(s.t, err) + e2e.RunQrepFlowWorkflow(env, qrepConfig) + + // Verify workflow completes without error + require.True(s.t, env.IsWorkflowCompleted()) + + err = env.GetWorkflowError() + require.NoError(s.t, err) + + ok, err := s.bqHelper.CheckNull(tblName, []string{"mytimestamp"}) + require.NoError(s.t, err) + require.False(s.t, ok) + + ok, err = s.bqHelper.CheckNull(tblName, []string{"mytztimestamp"}) + require.NoError(s.t, err) + require.False(s.t, ok) +} + func (s PeerFlowE2ETestSuiteBQ) Test_PeerDB_Columns_QRep_BQ() { env := e2e.NewTemporalTestWorkflowEnvironment(s.t) diff --git a/flow/model/conversion_avro.go b/flow/model/conversion_avro.go index b26aeaf9d7..39e3579f8f 100644 --- a/flow/model/conversion_avro.go +++ b/flow/model/conversion_avro.go @@ -4,6 +4,8 @@ import ( "encoding/json" "fmt" + "go.temporal.io/sdk/log" + "github.com/PeerDB-io/peer-flow/model/qvalue" ) @@ -12,6 +14,7 @@ type QRecordAvroConverter struct { TargetDWH qvalue.QDWHType NullableFields map[string]struct{} ColNames []string + logger log.Logger } func NewQRecordAvroConverter( @@ -19,12 +22,14 @@ func NewQRecordAvroConverter( targetDWH qvalue.QDWHType, nullableFields map[string]struct{}, colNames []string, + logger log.Logger, ) *QRecordAvroConverter { return &QRecordAvroConverter{ QRecord: q, TargetDWH: targetDWH, NullableFields: nullableFields, ColNames: colNames, + logger: logger, } } @@ -39,7 +44,9 @@ func (qac *QRecordAvroConverter) Convert() (map[string]interface{}, error) { val, qac.TargetDWH, nullable, + qac.logger, ) + avroVal, err := avroConverter.ToAvroValue() if err != nil { return nil, fmt.Errorf("failed to convert QValue to Avro-compatible value: %w", err) diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go index e5e6046fb5..24aeb73e20 100644 --- a/flow/model/qvalue/avro_converter.go +++ b/flow/model/qvalue/avro_converter.go @@ -1,6 +1,7 @@ package qvalue import ( + "errors" "fmt" "log/slog" "math/big" @@ -8,6 +9,7 @@ import ( "github.com/google/uuid" "github.com/linkedin/goavro/v2" + "go.temporal.io/sdk/log" hstore_util "github.com/PeerDB-io/peer-flow/hstore" "github.com/PeerDB-io/peer-flow/model/numeric" @@ -158,13 +160,15 @@ type QValueAvroConverter struct { Value QValue TargetDWH QDWHType Nullable bool + logger log.Logger } -func NewQValueAvroConverter(value QValue, targetDWH QDWHType, nullable bool) *QValueAvroConverter { +func NewQValueAvroConverter(value QValue, targetDWH QDWHType, nullable bool, logger log.Logger) *QValueAvroConverter { return &QValueAvroConverter{ Value: value, TargetDWH: targetDWH, Nullable: nullable, + logger: logger, } } @@ -245,6 +249,7 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { return t.(int64), nil } } + if c.Nullable { return goavro.Union("long.timestamp-micros", t.(int64)), nil } @@ -269,6 +274,7 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { return t.(int64), nil } } + if c.Nullable { return goavro.Union("long.timestamp-micros", t.(int64)), nil } @@ -318,7 +324,7 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { case QValueKindBoolean: return c.processNullableUnion("boolean", c.Value.Value) case QValueKindStruct: - return nil, fmt.Errorf("QValueKindStruct not supported") + return nil, errors.New("QValueKindStruct not supported") case QValueKindNumeric: return c.processNumeric() case QValueKindBytes, QValueKindBit: @@ -371,7 +377,7 @@ func (c *QValueAvroConverter) processGoTimeTZ() (interface{}, error) { t, ok := c.Value.Value.(time.Time) if !ok { - return nil, fmt.Errorf("invalid TimeTZ value") + return nil, errors.New("invalid TimeTZ value") } // Snowflake has issues with avro timestamp types, returning as string form @@ -389,7 +395,7 @@ func (c *QValueAvroConverter) processGoTime() (interface{}, error) { t, ok := c.Value.Value.(time.Time) if !ok { - return nil, fmt.Errorf("invalid Time value") + return nil, errors.New("invalid Time value") } // Snowflake has issues with avro timestamp types, returning as string form @@ -411,7 +417,7 @@ func (c *QValueAvroConverter) processGoTimestampTZ() (interface{}, error) { t, ok := c.Value.Value.(time.Time) if !ok { - return nil, fmt.Errorf("invalid TimestampTZ value") + return nil, errors.New("invalid TimestampTZ value") } // Snowflake has issues with avro timestamp types, returning as string form @@ -419,6 +425,13 @@ func (c *QValueAvroConverter) processGoTimestampTZ() (interface{}, error) { if c.TargetDWH == QDWHTypeSnowflake { return t.Format("2006-01-02 15:04:05.999999-0700"), nil } + + // Bigquery will not allow timestamp if it is less than 1AD and more than 9999AD + // So make such timestamps null + if DisallowedTimestamp(c.TargetDWH, t, c.logger) { + return nil, nil + } + return t.UnixMicro(), nil } @@ -429,7 +442,7 @@ func (c *QValueAvroConverter) processGoTimestamp() (interface{}, error) { t, ok := c.Value.Value.(time.Time) if !ok { - return nil, fmt.Errorf("invalid Timestamp value") + return nil, errors.New("invalid Timestamp value") } // Snowflake has issues with avro timestamp types, returning as string form @@ -437,6 +450,13 @@ func (c *QValueAvroConverter) processGoTimestamp() (interface{}, error) { if c.TargetDWH == QDWHTypeSnowflake { return t.Format("2006-01-02 15:04:05.999999"), nil } + + // Bigquery will not allow timestamp if it is less than 1AD and more than 9999AD + // So make such timestamps null + if DisallowedTimestamp(c.TargetDWH, t, c.logger) { + return nil, nil + } + return t.UnixMicro(), nil } @@ -447,7 +467,7 @@ func (c *QValueAvroConverter) processGoDate() (interface{}, error) { t, ok := c.Value.Value.(time.Time) if !ok { - return nil, fmt.Errorf("invalid Time value for Date") + return nil, errors.New("invalid Time value for Date") } // Snowflake has issues with avro timestamp types, returning as string form @@ -513,7 +533,7 @@ func (c *QValueAvroConverter) processBytes() (interface{}, error) { byteData, ok := c.Value.Value.([]byte) if !ok { - return nil, fmt.Errorf("invalid Bytes value") + return nil, errors.New("invalid Bytes value") } if c.Nullable { @@ -557,7 +577,7 @@ func (c *QValueAvroConverter) processArrayBoolean() (interface{}, error) { arrayData, ok := c.Value.Value.([]bool) if !ok { - return nil, fmt.Errorf("invalid Boolean array value") + return nil, errors.New("invalid Boolean array value") } if c.Nullable { @@ -574,7 +594,7 @@ func (c *QValueAvroConverter) processArrayTime() (interface{}, error) { arrayTime, ok := c.Value.Value.([]time.Time) if !ok { - return nil, fmt.Errorf("invalid Timestamp array value") + return nil, errors.New("invalid Timestamp array value") } transformedTimeArr := make([]interface{}, 0, len(arrayTime)) @@ -602,7 +622,7 @@ func (c *QValueAvroConverter) processArrayDate() (interface{}, error) { arrayDate, ok := c.Value.Value.([]time.Time) if !ok { - return nil, fmt.Errorf("invalid Date array value") + return nil, errors.New("invalid Date array value") } transformedTimeArr := make([]interface{}, 0, len(arrayDate)) @@ -704,7 +724,7 @@ func (c *QValueAvroConverter) processArrayInt16() (interface{}, error) { arrayData, ok := c.Value.Value.([]int16) if !ok { - return nil, fmt.Errorf("invalid Int16 array value") + return nil, errors.New("invalid Int16 array value") } // cast to int32 @@ -727,7 +747,7 @@ func (c *QValueAvroConverter) processArrayInt32() (interface{}, error) { arrayData, ok := c.Value.Value.([]int32) if !ok { - return nil, fmt.Errorf("invalid Int32 array value") + return nil, errors.New("invalid Int32 array value") } if c.Nullable { @@ -744,7 +764,7 @@ func (c *QValueAvroConverter) processArrayInt64() (interface{}, error) { arrayData, ok := c.Value.Value.([]int64) if !ok { - return nil, fmt.Errorf("invalid Int64 array value") + return nil, errors.New("invalid Int64 array value") } if c.Nullable { @@ -761,7 +781,7 @@ func (c *QValueAvroConverter) processArrayFloat32() (interface{}, error) { arrayData, ok := c.Value.Value.([]float32) if !ok { - return nil, fmt.Errorf("invalid Float32 array value") + return nil, errors.New("invalid Float32 array value") } if c.Nullable { @@ -778,7 +798,7 @@ func (c *QValueAvroConverter) processArrayFloat64() (interface{}, error) { arrayData, ok := c.Value.Value.([]float64) if !ok { - return nil, fmt.Errorf("invalid Float64 array value") + return nil, errors.New("invalid Float64 array value") } if c.Nullable { @@ -795,7 +815,7 @@ func (c *QValueAvroConverter) processArrayString() (interface{}, error) { arrayData, ok := c.Value.Value.([]string) if !ok { - return nil, fmt.Errorf("invalid String array value") + return nil, errors.New("invalid String array value") } if c.Nullable { diff --git a/flow/model/qvalue/timestamp.go b/flow/model/qvalue/timestamp.go new file mode 100644 index 0000000000..20d564b927 --- /dev/null +++ b/flow/model/qvalue/timestamp.go @@ -0,0 +1,20 @@ +package qvalue + +import ( + "time" + + "go.temporal.io/sdk/log" +) + +// Bigquery will not allow timestamp if it is less than 1AD and more than 9999AD +func DisallowedTimestamp(dwh QDWHType, t time.Time, logger log.Logger) bool { + if dwh == QDWHTypeBigQuery { + tMicro := t.UnixMicro() + if tMicro < 0 || tMicro > 253402300799999999 { // 9999-12-31 23:59:59.999999 + logger.Warn("Nulling Timestamp value for BigQuery as it exceeds allowed range", + "timestamp", t.String()) + return true + } + } + return false +} From 758727a93a438cc903dc4ce2cac71084a70b730a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Wed, 14 Feb 2024 19:47:31 +0000 Subject: [PATCH 6/6] Don't spam logs about unknown types being sent as text (#1291) Logs of this nature should only be generated once per connector PostgresCDCSource & QRepQueryExecutor now extend PostgresConnector, stops QRepQueryExecutor querying for customTypesMapping for snapshots QValue equality shifted because we're now using connector with type mappings when retrieving records --- flow/connectors/bigquery/bigquery.go | 5 +- flow/connectors/postgres/cdc.go | 20 ++- flow/connectors/postgres/postgres.go | 16 ++- flow/connectors/postgres/qrep.go | 33 +---- flow/connectors/postgres/qrep_bench_test.go | 18 ++- .../postgres/qrep_query_executor.go | 75 ++++------- .../postgres/qrep_query_executor_test.go | 47 +++---- flow/connectors/postgres/qvalue_convert.go | 13 +- flow/e2e/bigquery/peer_flow_bq_test.go | 123 +++++++++--------- flow/e2e/bigquery/qrep_flow_bq_test.go | 8 +- flow/e2e/congen.go | 18 +-- flow/e2e/postgres/peer_flow_pg_test.go | 102 +++++++-------- flow/e2e/postgres/qrep_flow_pg_test.go | 52 +++----- flow/e2e/s3/cdc_s3_test.go | 4 +- flow/e2e/s3/qrep_flow_s3_test.go | 10 +- flow/e2e/snowflake/peer_flow_sf_test.go | 123 +++++++++--------- flow/e2e/snowflake/qrep_flow_sf_test.go | 4 +- .../e2e/sqlserver/qrep_flow_sqlserver_test.go | 13 +- flow/e2e/test_utils.go | 33 +++-- flow/e2eshared/e2eshared.go | 15 +-- flow/geo/geo.go | 15 --- flow/model/qvalue/qvalue.go | 74 ++++++++--- 22 files changed, 405 insertions(+), 416 deletions(-) diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index 4d63a97657..5aed634e60 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -47,7 +47,6 @@ type BigQueryServiceAccount struct { ClientX509CertURL string `json:"client_x509_cert_url"` } -// BigQueryConnector is a Connector implementation for BigQuery. type BigQueryConnector struct { bqConfig *protos.BigqueryConfig client *bigquery.Client @@ -59,7 +58,6 @@ type BigQueryConnector struct { logger log.Logger } -// Create BigQueryServiceAccount from BigqueryConfig func NewBigQueryServiceAccount(bqConfig *protos.BigqueryConfig) (*BigQueryServiceAccount, error) { var serviceAccount BigQueryServiceAccount serviceAccount.Type = bqConfig.AuthType @@ -178,7 +176,6 @@ func TableCheck(ctx context.Context, client *bigquery.Client, dataset string, pr return nil } -// NewBigQueryConnector creates a new BigQueryConnector from a PeerConnectionConfig. func NewBigQueryConnector(ctx context.Context, config *protos.BigqueryConfig) (*BigQueryConnector, error) { logger := logger.LoggerFromCtx(ctx) @@ -246,7 +243,7 @@ func (c *BigQueryConnector) Close(_ context.Context) error { return c.client.Close() } -// ConnectionActive returns true if the connection is active. +// ConnectionActive returns nil if the connection is active. func (c *BigQueryConnector) ConnectionActive(ctx context.Context) error { _, err := c.client.DatasetInProject(c.projectID, c.datasetID).Metadata(ctx) if err != nil { diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index 8992a6562f..76de074bbd 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -23,10 +23,10 @@ import ( "github.com/PeerDB-io/peer-flow/logger" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" - "github.com/PeerDB-io/peer-flow/shared" ) type PostgresCDCSource struct { + *PostgresConnector replConn *pgx.Conn SrcTableIDNameMapping map[uint32]string TableNameMapping map[string]model.NameAndExclude @@ -35,11 +35,9 @@ type PostgresCDCSource struct { relationMessageMapping model.RelationMessageMapping typeMap *pgtype.Map commitLock bool - customTypeMapping map[uint32]string // for partitioned tables, maps child relid to parent relid childToParentRelIDMapping map[uint32]uint32 - logger slog.Logger // for storing chema delta audit logs to catalog catalogPool *pgxpool.Pool @@ -64,14 +62,14 @@ type startReplicationOpts struct { } // Create a new PostgresCDCSource -func NewPostgresCDCSource(ctx context.Context, cdcConfig *PostgresCDCConfig, customTypeMap map[uint32]string) (*PostgresCDCSource, error) { +func (c *PostgresConnector) NewPostgresCDCSource(ctx context.Context, cdcConfig *PostgresCDCConfig) (*PostgresCDCSource, error) { childToParentRelIDMap, err := getChildToParentRelIDMap(ctx, cdcConfig.Connection) if err != nil { return nil, fmt.Errorf("error getting child to parent relid map: %w", err) } - flowName, _ := ctx.Value(shared.FlowNameKey).(string) return &PostgresCDCSource{ + PostgresConnector: c, replConn: cdcConfig.Connection, SrcTableIDNameMapping: cdcConfig.SrcTableIDNameMapping, TableNameMapping: cdcConfig.TableNameMapping, @@ -81,8 +79,6 @@ func NewPostgresCDCSource(ctx context.Context, cdcConfig *PostgresCDCConfig, cus typeMap: pgtype.NewMap(), childToParentRelIDMapping: childToParentRelIDMap, commitLock: false, - customTypeMapping: customTypeMap, - logger: *slog.With(slog.String(string(shared.FlowNameKey), flowName)), catalogPool: cdcConfig.CatalogPool, flowJobName: cdcConfig.FlowJobName, }, nil @@ -731,20 +727,20 @@ func (p *PostgresCDCSource) decodeColumnData(data []byte, dataType uint32, forma } return qvalue.QValue{}, err } - retVal, err := parseFieldFromPostgresOID(dataType, parsedData) + retVal, err := p.parseFieldFromPostgresOID(dataType, parsedData) if err != nil { return qvalue.QValue{}, err } return retVal, nil } else if dataType == uint32(oid.T_timetz) { // ugly TIMETZ workaround for CDC decoding. - retVal, err := parseFieldFromPostgresOID(dataType, string(data)) + retVal, err := p.parseFieldFromPostgresOID(dataType, string(data)) if err != nil { return qvalue.QValue{}, err } return retVal, nil } - typeName, ok := p.customTypeMapping[dataType] + typeName, ok := p.customTypesMapping[dataType] if ok { customQKind := customTypeToQKind(typeName) if customQKind == qvalue.QValueKindGeography || customQKind == qvalue.QValueKindGeometry { @@ -835,9 +831,9 @@ func (p *PostgresCDCSource) processRelationMessage( for _, column := range currRel.Columns { // not present in previous relation message, but in current one, so added. if prevRelMap[column.Name] == nil { - qKind := postgresOIDToQValueKind(column.DataType) + qKind := p.postgresOIDToQValueKind(column.DataType) if qKind == qvalue.QValueKindInvalid { - typeName, ok := p.customTypeMapping[column.DataType] + typeName, ok := p.customTypesMapping[column.DataType] if ok { qKind = customTypeToQKind(typeName) } diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index d81401511b..b8d3d27f20 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -24,7 +24,6 @@ import ( "github.com/PeerDB-io/peer-flow/shared/alerting" ) -// PostgresConnector is a Connector implementation for Postgres. type PostgresConnector struct { connStr string config *protos.PostgresConfig @@ -33,10 +32,10 @@ type PostgresConnector struct { replConfig *pgx.ConnConfig customTypesMapping map[uint32]string metadataSchema string + hushWarnOID map[uint32]struct{} logger log.Logger } -// NewPostgresConnector creates a new instance of PostgresConnector. func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) (*PostgresConnector, error) { connectionString := utils.GetPGConnectionString(pgConfig) @@ -84,6 +83,7 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) replConfig: replConfig, customTypesMapping: customTypeMap, metadataSchema: metadataSchema, + hushWarnOID: make(map[uint32]struct{}), logger: logger.LoggerFromCtx(ctx), }, nil } @@ -107,7 +107,11 @@ func (c *PostgresConnector) Close(ctx context.Context) error { return nil } -// ConnectionActive returns true if the connection is active. +func (c *PostgresConnector) Conn() *pgx.Conn { + return c.conn +} + +// ConnectionActive returns nil if the connection is active. func (c *PostgresConnector) ConnectionActive(ctx context.Context) error { if c.conn == nil { return fmt.Errorf("connection is nil") @@ -214,7 +218,7 @@ func (c *PostgresConnector) PullRecords(ctx context.Context, catalogPool *pgxpoo } defer replConn.Close(ctx) - cdc, err := NewPostgresCDCSource(ctx, &PostgresCDCConfig{ + cdc, err := c.NewPostgresCDCSource(ctx, &PostgresCDCConfig{ Connection: replConn, SrcTableIDNameMapping: req.SrcTableIDNameMapping, Slot: slotName, @@ -223,7 +227,7 @@ func (c *PostgresConnector) PullRecords(ctx context.Context, catalogPool *pgxpoo RelationMessageMapping: req.RelationMessageMapping, CatalogPool: catalogPool, FlowJobName: req.FlowJobName, - }, c.customTypesMapping) + }) if err != nil { return fmt.Errorf("failed to create cdc source: %w", err) } @@ -603,7 +607,7 @@ func (c *PostgresConnector) getTableSchemaForTable( columnNames := make([]string, 0, len(fields)) columns := make([]*protos.FieldDescription, 0, len(fields)) for _, fieldDescription := range fields { - genericColType := postgresOIDToQValueKind(fieldDescription.DataTypeOID) + genericColType := c.postgresOIDToQValueKind(fieldDescription.DataTypeOID) if genericColType == qvalue.QValueKindInvalid { typeName, ok := c.customTypesMapping[fieldDescription.DataTypeOID] if ok { diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go index 8a01d90cbe..13b92dfcd4 100644 --- a/flow/connectors/postgres/qrep.go +++ b/flow/connectors/postgres/qrep.go @@ -324,12 +324,8 @@ func (c *PostgresConnector) PullQRepRecords( partitionIdLog := slog.String(string(shared.PartitionIDKey), partition.PartitionId) if partition.FullTablePartition { c.logger.Info("pulling full table partition", partitionIdLog) - executor, err := NewQRepQueryExecutorSnapshot(ctx, - c.conn, c.config.TransactionSnapshot, + executor := c.NewQRepQueryExecutorSnapshot(c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) - if err != nil { - return nil, err - } query := config.Query return executor.ExecuteAndProcessQuery(ctx, query) } @@ -368,12 +364,8 @@ func (c *PostgresConnector) PullQRepRecords( return nil, err } - executor, err := NewQRepQueryExecutorSnapshot( - ctx, c.conn, c.config.TransactionSnapshot, + executor := c.NewQRepQueryExecutorSnapshot(c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) - if err != nil { - return nil, err - } records, err := executor.ExecuteAndProcessQuery(ctx, query, rangeStart, rangeEnd) if err != nil { @@ -392,15 +384,11 @@ func (c *PostgresConnector) PullQRepRecordStream( partitionIdLog := slog.String(string(shared.PartitionIDKey), partition.PartitionId) if partition.FullTablePartition { c.logger.Info("pulling full table partition", partitionIdLog) - executor, err := NewQRepQueryExecutorSnapshot( - ctx, c.conn, c.config.TransactionSnapshot, + executor := c.NewQRepQueryExecutorSnapshot(c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) - if err != nil { - return 0, err - } query := config.Query - _, err = executor.ExecuteAndProcessQueryStream(ctx, stream, query) + _, err := executor.ExecuteAndProcessQueryStream(ctx, stream, query) return 0, err } c.logger.Info("Obtained ranges for partition for PullQRepStream", partitionIdLog) @@ -438,12 +426,8 @@ func (c *PostgresConnector) PullQRepRecordStream( return 0, err } - executor, err := NewQRepQueryExecutorSnapshot( - ctx, c.conn, c.config.TransactionSnapshot, + executor := c.NewQRepQueryExecutorSnapshot(c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) - if err != nil { - return 0, err - } numRecords, err := executor.ExecuteAndProcessQueryStream(ctx, stream, query, rangeStart, rangeEnd) if err != nil { @@ -539,13 +523,10 @@ func (c *PostgresConnector) PullXminRecordStream( query += " WHERE age(xmin) > 0 AND age(xmin) <= age($1::xid)" } - executor, err := NewQRepQueryExecutorSnapshot( - ctx, c.conn, c.config.TransactionSnapshot, + executor := c.NewQRepQueryExecutorSnapshot(c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) - if err != nil { - return 0, currentSnapshotXmin, err - } + var err error var numRecords int if partition.Range != nil { numRecords, currentSnapshotXmin, err = executor.ExecuteAndProcessQueryStreamGettingCurrentSnapshotXmin( diff --git a/flow/connectors/postgres/qrep_bench_test.go b/flow/connectors/postgres/qrep_bench_test.go index e8f514bc38..1dddc617fc 100644 --- a/flow/connectors/postgres/qrep_bench_test.go +++ b/flow/connectors/postgres/qrep_bench_test.go @@ -4,24 +4,28 @@ import ( "context" "testing" - "github.com/jackc/pgx/v5" + "github.com/PeerDB-io/peer-flow/generated/protos" ) func BenchmarkQRepQueryExecutor(b *testing.B) { - connectionString := "postgres://postgres:postgres@localhost:7132/postgres" query := "SELECT * FROM bench.large_table" ctx := context.Background() - - // Create a separate connection for non-replication queries - conn, err := pgx.Connect(ctx, connectionString) + connector, err := NewPostgresConnector(ctx, + &protos.PostgresConfig{ + Host: "localhost", + Port: 7132, + User: "postgres", + Password: "postgres", + Database: "postgres", + }) if err != nil { b.Fatalf("failed to create connection: %v", err) } - defer conn.Close(context.Background()) + defer connector.Close(ctx) // Create a new QRepQueryExecutor instance - qe := NewQRepQueryExecutor(conn, context.Background(), "test flow", "test part") + qe := connector.NewQRepQueryExecutor("test flow", "test part") // Run the benchmark b.ResetTimer() diff --git a/flow/connectors/postgres/qrep_query_executor.go b/flow/connectors/postgres/qrep_query_executor.go index 48900555f0..2b97759d0e 100644 --- a/flow/connectors/postgres/qrep_query_executor.go +++ b/flow/connectors/postgres/qrep_query_executor.go @@ -13,7 +13,6 @@ import ( "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/geo" - "github.com/PeerDB-io/peer-flow/logger" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/numeric" "github.com/PeerDB-io/peer-flow/model/qvalue" @@ -21,48 +20,32 @@ import ( ) type QRepQueryExecutor struct { - conn *pgx.Conn - snapshot string - testEnv bool - flowJobName string - partitionID string - customTypeMap map[uint32]string - logger log.Logger + *PostgresConnector + snapshot string + testEnv bool + flowJobName string + partitionID string + logger log.Logger } -func NewQRepQueryExecutor(conn *pgx.Conn, ctx context.Context, - flowJobName string, partitionID string, -) *QRepQueryExecutor { +func (c *PostgresConnector) NewQRepQueryExecutor(flowJobName string, partitionID string) *QRepQueryExecutor { return &QRepQueryExecutor{ - conn: conn, - snapshot: "", - flowJobName: flowJobName, - partitionID: partitionID, - logger: log.With( - logger.LoggerFromCtx(ctx), - slog.String(string(shared.PartitionIDKey), partitionID), - ), + PostgresConnector: c, + snapshot: "", + flowJobName: flowJobName, + partitionID: partitionID, + logger: log.With(c.logger, slog.String(string(shared.PartitionIDKey), partitionID)), } } -func NewQRepQueryExecutorSnapshot(ctx context.Context, conn *pgx.Conn, snapshot string, - flowJobName string, partitionID string, -) (*QRepQueryExecutor, error) { - CustomTypeMap, err := utils.GetCustomDataTypes(ctx, conn) - if err != nil { - return nil, fmt.Errorf("failed to get custom data types: %w", err) - } +func (c *PostgresConnector) NewQRepQueryExecutorSnapshot(snapshot string, flowJobName string, partitionID string) *QRepQueryExecutor { return &QRepQueryExecutor{ - conn: conn, - snapshot: snapshot, - flowJobName: flowJobName, - partitionID: partitionID, - customTypeMap: CustomTypeMap, - logger: log.With( - logger.LoggerFromCtx(ctx), - slog.String(string(shared.PartitionIDKey), partitionID), - ), - }, nil + PostgresConnector: c, + snapshot: snapshot, + flowJobName: flowJobName, + partitionID: partitionID, + logger: log.With(c.logger, slog.String(string(shared.PartitionIDKey), partitionID)), + } } func (qe *QRepQueryExecutor) SetTestEnv(testEnv bool) { @@ -104,9 +87,9 @@ func (qe *QRepQueryExecutor) fieldDescriptionsToSchema(fds []pgconn.FieldDescrip qfields := make([]model.QField, len(fds)) for i, fd := range fds { cname := fd.Name - ctype := postgresOIDToQValueKind(fd.DataTypeOID) + ctype := qe.postgresOIDToQValueKind(fd.DataTypeOID) if ctype == qvalue.QValueKindInvalid { - typeName, ok := qe.customTypeMap[fd.DataTypeOID] + typeName, ok := qe.customTypesMapping[fd.DataTypeOID] if ok { ctype = customTypeToQKind(typeName) } else { @@ -145,7 +128,7 @@ func (qe *QRepQueryExecutor) ProcessRows( qe.logger.Info("Processing rows") // Iterate over the rows for rows.Next() { - record, err := mapRowToQRecord(qe.logger, rows, fieldDescriptions, qe.customTypeMap) + record, err := qe.mapRowToQRecord(rows, fieldDescriptions) if err != nil { qe.logger.Error("[pg_query_executor] failed to map row to QRecord", slog.Any("error", err)) return nil, fmt.Errorf("failed to map row to QRecord: %w", err) @@ -186,7 +169,7 @@ func (qe *QRepQueryExecutor) processRowsStream( return numRows, ctx.Err() default: // Process the row as before - record, err := mapRowToQRecord(qe.logger, rows, fieldDescriptions, qe.customTypeMap) + record, err := qe.mapRowToQRecord(rows, fieldDescriptions) if err != nil { qe.logger.Error("[pg_query_executor] failed to map row to QRecord", slog.Any("error", err)) stream.Records <- model.QRecordOrError{ @@ -450,28 +433,26 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStreamWithTx( return totalRecordsFetched, nil } -func mapRowToQRecord( - logger log.Logger, +func (qe *QRepQueryExecutor) mapRowToQRecord( row pgx.Rows, fds []pgconn.FieldDescription, - customTypeMap map[uint32]string, ) ([]qvalue.QValue, error) { // make vals an empty array of QValue of size len(fds) record := make([]qvalue.QValue, len(fds)) values, err := row.Values() if err != nil { - logger.Error("[pg_query_executor] failed to get values from row", slog.Any("error", err)) + qe.logger.Error("[pg_query_executor] failed to get values from row", slog.Any("error", err)) return nil, fmt.Errorf("failed to scan row: %w", err) } for i, fd := range fds { // Check if it's a custom type first - typeName, ok := customTypeMap[fd.DataTypeOID] + typeName, ok := qe.customTypesMapping[fd.DataTypeOID] if !ok { - tmp, err := parseFieldFromPostgresOID(fd.DataTypeOID, values[i]) + tmp, err := qe.parseFieldFromPostgresOID(fd.DataTypeOID, values[i]) if err != nil { - logger.Error("[pg_query_executor] failed to parse field", slog.Any("error", err)) + qe.logger.Error("[pg_query_executor] failed to parse field", slog.Any("error", err)) return nil, fmt.Errorf("failed to parse field: %w", err) } record[i] = tmp diff --git a/flow/connectors/postgres/qrep_query_executor_test.go b/flow/connectors/postgres/qrep_query_executor_test.go index bffb28214c..32d0a1a154 100644 --- a/flow/connectors/postgres/qrep_query_executor_test.go +++ b/flow/connectors/postgres/qrep_query_executor_test.go @@ -10,31 +10,35 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5" + + "github.com/PeerDB-io/peer-flow/generated/protos" ) -func setupDB(t *testing.T) (*pgx.Conn, string) { +func setupDB(t *testing.T) (*PostgresConnector, string) { t.Helper() - config, err := pgx.ParseConfig("postgres://postgres:postgres@localhost:7132/postgres") - if err != nil { - t.Fatalf("unable to parse config: %v", err) - } - - conn, err := pgx.ConnectConfig(context.Background(), config) + connector, err := NewPostgresConnector(context.Background(), + &protos.PostgresConfig{ + Host: "localhost", + Port: 7132, + User: "postgres", + Password: "postgres", + Database: "postgres", + }) if err != nil { - t.Fatalf("unable to connect to database: %v", err) + t.Fatalf("unable to create connector: %v", err) } // Create unique schema name using current time schemaName := fmt.Sprintf("schema_%d", time.Now().Unix()) // Create the schema - _, err = conn.Exec(context.Background(), fmt.Sprintf("CREATE SCHEMA %s;", schemaName)) + _, err = connector.conn.Exec(context.Background(), fmt.Sprintf("CREATE SCHEMA %s;", schemaName)) if err != nil { t.Fatalf("unable to create schema: %v", err) } - return conn, schemaName + return connector, schemaName } func teardownDB(t *testing.T, conn *pgx.Conn, schemaName string) { @@ -47,12 +51,11 @@ func teardownDB(t *testing.T, conn *pgx.Conn, schemaName string) { } func TestExecuteAndProcessQuery(t *testing.T) { - conn, schemaName := setupDB(t) - defer conn.Close(context.Background()) - - defer teardownDB(t, conn, schemaName) - ctx := context.Background() + connector, schemaName := setupDB(t) + conn := connector.conn + defer connector.Close(ctx) + defer teardownDB(t, conn, schemaName) query := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.test(id SERIAL PRIMARY KEY, data TEXT);", schemaName) _, err := conn.Exec(ctx, query) @@ -66,7 +69,7 @@ func TestExecuteAndProcessQuery(t *testing.T) { t.Fatalf("error while inserting into test table: %v", err) } - qe := NewQRepQueryExecutor(conn, ctx, "test flow", "test part") + qe := connector.NewQRepQueryExecutor("test flow", "test part") qe.SetTestEnv(true) query = fmt.Sprintf("SELECT * FROM %s.test;", schemaName) @@ -85,13 +88,11 @@ func TestExecuteAndProcessQuery(t *testing.T) { } func TestAllDataTypes(t *testing.T) { - conn, schemaName := setupDB(t) - defer conn.Close(context.Background()) - - // Call teardownDB function after test - defer teardownDB(t, conn, schemaName) - ctx := context.Background() + connector, schemaName := setupDB(t) + conn := connector.conn + defer conn.Close(ctx) + defer teardownDB(t, conn, schemaName) // Create a table that contains every data type we want to test query := fmt.Sprintf(` @@ -170,7 +171,7 @@ func TestAllDataTypes(t *testing.T) { t.Fatalf("error while inserting into test table: %v", err) } - qe := NewQRepQueryExecutor(conn, ctx, "test flow", "test part") + qe := connector.NewQRepQueryExecutor("test flow", "test part") // Select the row back out of the table query = fmt.Sprintf("SELECT * FROM %s.test;", schemaName) rows, err := qe.ExecuteQuery(context.Background(), query) diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go index 81f19e1ec9..d54e6fa4d3 100644 --- a/flow/connectors/postgres/qvalue_convert.go +++ b/flow/connectors/postgres/qvalue_convert.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "log/slog" "math/big" "strings" "time" @@ -18,7 +17,7 @@ import ( var big10 = big.NewInt(10) -func postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind { +func (c *PostgresConnector) postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind { switch recvOID { case pgtype.BoolOID: return qvalue.QValueKindBoolean @@ -104,7 +103,11 @@ func postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind { return qvalue.QValueKindInvalid } else { - slog.Warn(fmt.Sprintf("unsupported field type: %v - type name - %s; returning as string", recvOID, typeName.Name)) + _, warned := c.hushWarnOID[recvOID] + if !warned { + c.logger.Warn(fmt.Sprintf("unsupported field type: %d - type name - %s; returning as string", recvOID, typeName.Name)) + c.hushWarnOID[recvOID] = struct{}{} + } return qvalue.QValueKindString } } @@ -376,8 +379,8 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( return val, nil } -func parseFieldFromPostgresOID(oid uint32, value interface{}) (qvalue.QValue, error) { - return parseFieldFromQValueKind(postgresOIDToQValueKind(oid), value) +func (c *PostgresConnector) parseFieldFromPostgresOID(oid uint32, value interface{}) (qvalue.QValue, error) { + return parseFieldFromQValueKind(c.postgresOIDToQValueKind(oid), value) } func numericToRat(numVal *pgtype.Numeric) (*big.Rat, error) { diff --git a/flow/e2e/bigquery/peer_flow_bq_test.go b/flow/e2e/bigquery/peer_flow_bq_test.go index 301359b3cc..0c3306b668 100644 --- a/flow/e2e/bigquery/peer_flow_bq_test.go +++ b/flow/e2e/bigquery/peer_flow_bq_test.go @@ -14,6 +14,7 @@ import ( "github.com/joho/godotenv" "github.com/stretchr/testify/require" + connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/e2e" "github.com/PeerDB-io/peer-flow/e2eshared" @@ -28,7 +29,7 @@ type PeerFlowE2ETestSuiteBQ struct { t *testing.T bqSuffix string - conn *pgx.Conn + conn *connpostgres.PostgresConnector bqHelper *BigQueryTestHelper } @@ -37,6 +38,10 @@ func (s PeerFlowE2ETestSuiteBQ) T() *testing.T { } func (s PeerFlowE2ETestSuiteBQ) Conn() *pgx.Conn { + return s.conn.Conn() +} + +func (s PeerFlowE2ETestSuiteBQ) Connector() *connpostgres.PostgresConnector { return s.conn } @@ -207,7 +212,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Complete_Flow_No_Data() { srcTableName := s.attachSchemaSuffix("test_no_data") dstTableName := "test_no_data" - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, key TEXT NOT NULL, @@ -243,7 +248,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Char_ColType_Error() { srcTableName := s.attachSchemaSuffix("test_char_coltype") dstTableName := "test_char_coltype" - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, key TEXT NOT NULL, @@ -282,7 +287,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Complete_Simple_Flow_BQ() { srcTableName := s.attachSchemaSuffix("test_simple_flow_bq") dstTableName := "test_simple_flow_bq" - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, key TEXT NOT NULL, @@ -308,7 +313,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Complete_Simple_Flow_BQ() { for i := 0; i < 10; i++ { testKey := fmt.Sprintf("test_key_%d", i) testValue := fmt.Sprintf("test_value_%d", i) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(key, value) VALUES ($1, $2) `, srcTableName), testKey, testValue) e2e.EnvNoError(s.t, env, err) @@ -329,7 +334,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Toast_BQ() { srcTableName := s.attachSchemaSuffix("test_toast_bq_1") dstTableName := "test_toast_bq_1" - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, t1 text, @@ -360,7 +365,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Toast_BQ() { 2. changes no toast column 2. changes 1 toast column */ - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` BEGIN; INSERT INTO %s(t1,t2,k) SELECT random_string(9000),random_string(9000), 1 FROM generate_series(1,2); @@ -385,7 +390,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Toast_Advance_1_BQ() { srcTableName := s.attachSchemaSuffix("test_toast_bq_3") dstTableName := "test_toast_bq_3" - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, t1 text, @@ -411,7 +416,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Toast_Advance_1_BQ() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) // complex transaction with random DMLs on a table with toast columns - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` BEGIN; INSERT INTO %s(t1,t2,k) SELECT random_string(9000),random_string(9000), 1 FROM generate_series(1,2); @@ -447,7 +452,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Toast_Advance_2_BQ() { srcTableName := s.attachSchemaSuffix("test_toast_bq_4") dstTableName := "test_toast_bq_4" - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE %s ( id SERIAL PRIMARY KEY, t1 text, @@ -472,7 +477,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Toast_Advance_2_BQ() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) // complex transaction with random DMLs on a table with toast columns - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` BEGIN; INSERT INTO %s(t1,k) SELECT random_string(9000), 1 FROM generate_series(1,1); @@ -501,7 +506,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Toast_Advance_3_BQ() { srcTableName := s.attachSchemaSuffix("test_toast_bq_5") dstTableName := "test_toast_bq_5" - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, t1 text, @@ -530,7 +535,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Toast_Advance_3_BQ() { transaction updating a single row multiple times with changed/unchanged toast columns */ - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` BEGIN; INSERT INTO %s(t1,t2,k) SELECT random_string(9000),random_string(9000), 1 FROM generate_series(1,1); @@ -556,12 +561,12 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Types_BQ() { dstTableName := "test_types_bq" createMoodEnum := "CREATE TYPE mood AS ENUM ('happy', 'sad', 'angry');" var pgErr *pgconn.PgError - _, enumErr := s.conn.Exec(context.Background(), createMoodEnum) + _, enumErr := s.Conn().Exec(context.Background(), createMoodEnum) if errors.As(enumErr, &pgErr) && pgErr.Code != pgerrcode.DuplicateObject && !utils.IsUniqueError(pgErr) { require.NoError(s.t, enumErr) } - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (id serial PRIMARY KEY,c1 BIGINT,c2 BIT,c3 VARBIT,c4 BOOLEAN, c6 BYTEA,c7 CHARACTER,c8 varchar,c9 CIDR,c11 DATE,c12 FLOAT,c13 DOUBLE PRECISION, c14 INET,c15 INTEGER,c16 INTERVAL,c17 JSON,c18 JSONB,c21 MACADDR,c22 MONEY, @@ -588,7 +593,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Types_BQ() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) /* test inserting various types*/ - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s SELECT 2,2,b'1',b'101', true,random_bytea(32),'s','test','1.1.10.2'::cidr, CURRENT_DATE,1.23,1.234,'192.168.1.5'::inet,1, @@ -651,7 +656,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_NaN_Doubles_BQ() { srcTableName := s.attachSchemaSuffix("test_nans_bq") dstTableName := "test_nans_bq" - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (id serial PRIMARY KEY,c1 double precision,c2 double precision[]); `, srcTableName)) require.NoError(s.t, err) @@ -672,7 +677,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_NaN_Doubles_BQ() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) /* test inserting various types*/ - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s SELECT 2, 'NaN'::double precision, '{NaN, Infinity, -Infinity}'; `, srcTableName)) e2e.EnvNoError(s.t, env, err) @@ -694,7 +699,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Invalid_Geo_BQ_Avro_CDC() { srcTableName := s.attachSchemaSuffix("test_invalid_geo_bq_avro_cdc") dstTableName := "test_invalid_geo_bq_avro_cdc" - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, line GEOMETRY(LINESTRING) NOT NULL, @@ -720,7 +725,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Invalid_Geo_BQ_Avro_CDC() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) // insert 4 invalid shapes and 6 valid shapes into the source table for i := 0; i < 4; i++ { - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s (line,"polyPoly") VALUES ($1,$2) `, srcTableName), "010200000001000000000000000000F03F0000000000000040", "0103000020e6100000010000000c0000001a8361d35dc64140afdb8d2b1bc3c9bf1b8ed4685fc641405ba64c"+ @@ -732,7 +737,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Invalid_Geo_BQ_Avro_CDC() { } s.t.Log("Inserted 4 invalid geography rows into the source table") for i := 4; i < 10; i++ { - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s (line,"polyPoly") VALUES ($1,$2) `, srcTableName), "010200000002000000000000000000F03F000000000000004000000000000008400000000000001040", "010300000001000000050000000000000000000000000000000000000000000000"+ @@ -778,7 +783,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Multi_Table_BQ() { srcTable2Name := s.attachSchemaSuffix("test2_bq") dstTable2Name := "test2_bq" - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE %s (id serial primary key, c1 int, c2 text); CREATE TABLE %s(id serial primary key, c1 int, c2 text); `, srcTable1Name, srcTable2Name)) @@ -800,7 +805,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Multi_Table_BQ() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) /* inserting across multiple tables*/ - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s (c1,c2) VALUES (1,'dummy_1'); INSERT INTO %s (c1,c2) VALUES (-1,'dummy_-1'); `, srcTable1Name, srcTable2Name)) @@ -834,7 +839,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Simple_Schema_Changes_BQ() { tableName := "test_simple_schema_changes" srcTableName := s.attachSchemaSuffix(tableName) - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, c1 BIGINT @@ -858,7 +863,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Simple_Schema_Changes_BQ() { go func() { // insert first row. e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1) VALUES (1)`, srcTableName)) e2e.EnvNoError(s.t, env, err) s.t.Log("Inserted initial row in the source table") @@ -866,11 +871,11 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Simple_Schema_Changes_BQ() { e2e.EnvWaitForEqualTables(env, s, "normalize insert", tableName, "id,c1") // alter source table, add column c2 and insert another row. - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` ALTER TABLE %s ADD COLUMN c2 BIGINT`, srcTableName)) e2e.EnvNoError(s.t, env, err) s.t.Log("Altered source table, added column c2") - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1,c2) VALUES (2,2)`, srcTableName)) e2e.EnvNoError(s.t, env, err) s.t.Log("Inserted row with added c2 in the source table") @@ -879,11 +884,11 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Simple_Schema_Changes_BQ() { e2e.EnvWaitForEqualTables(env, s, "normalize altered row", tableName, "id,c1,c2") // alter source table, add column c3, drop column c2 and insert another row. - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` ALTER TABLE %s DROP COLUMN c2, ADD COLUMN c3 BIGINT`, srcTableName)) e2e.EnvNoError(s.t, env, err) s.t.Log("Altered source table, dropped column c2 and added column c3") - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1,c3) VALUES (3,3)`, srcTableName)) e2e.EnvNoError(s.t, env, err) s.t.Log("Inserted row with added c3 in the source table") @@ -892,11 +897,11 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Simple_Schema_Changes_BQ() { e2e.EnvWaitForEqualTables(env, s, "normalize altered row", tableName, "id,c1,c3") // alter source table, drop column c3 and insert another row. - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` ALTER TABLE %s DROP COLUMN c3`, srcTableName)) e2e.EnvNoError(s.t, env, err) s.t.Log("Altered source table, dropped column c3") - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1) VALUES (4)`, srcTableName)) e2e.EnvNoError(s.t, env, err) s.t.Log("Inserted row after dropping all columns in the source table") @@ -917,7 +922,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Composite_PKey_BQ() { tableName := "test_simple_cpkey" srcTableName := s.attachSchemaSuffix("test_simple_cpkey") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT GENERATED ALWAYS AS IDENTITY, c1 INT GENERATED BY DEFAULT AS IDENTITY, @@ -946,7 +951,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Composite_PKey_BQ() { // insert 10 rows into the source table for i := 0; i < 10; i++ { testValue := fmt.Sprintf("test_value_%d", i) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c2,t) VALUES ($1,$2) `, srcTableName), i, testValue) e2e.EnvNoError(s.t, env, err) @@ -956,10 +961,10 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Composite_PKey_BQ() { // verify we got our 10 rows e2e.EnvWaitForEqualTables(env, s, "normalize table", tableName, "id,c1,c2,t") - _, err := s.conn.Exec(context.Background(), + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(`UPDATE %s SET c1=c1+1 WHERE MOD(c2,2)=$1`, srcTableName), 1) e2e.EnvNoError(s.t, env, err) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(`DELETE FROM %s WHERE MOD(c2,2)=$1`, srcTableName), 0) + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`DELETE FROM %s WHERE MOD(c2,2)=$1`, srcTableName), 0) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTables(env, s, "normalize update", tableName, "id,c1,c2,t") @@ -977,7 +982,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Composite_PKey_Toast_1_BQ() { srcTableName := s.attachSchemaSuffix("test_cpkey_toast1") dstTableName := "test_cpkey_toast1" - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT GENERATED ALWAYS AS IDENTITY, c1 INT GENERATED BY DEFAULT AS IDENTITY, @@ -1004,7 +1009,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Composite_PKey_Toast_1_BQ() { // and then insert, update and delete rows in the table. go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - rowsTx, err := s.conn.Begin(context.Background()) + rowsTx, err := s.Conn().Begin(context.Background()) e2e.EnvNoError(s.t, env, err) // insert 10 rows into the source table @@ -1039,7 +1044,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Composite_PKey_Toast_2_BQ() { tableName := "test_cpkey_toast2" srcTableName := s.attachSchemaSuffix("test_cpkey_toast2") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT GENERATED ALWAYS AS IDENTITY, c1 INT GENERATED BY DEFAULT AS IDENTITY, @@ -1070,7 +1075,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Composite_PKey_Toast_2_BQ() { // insert 10 rows into the source table for i := 0; i < 10; i++ { testValue := fmt.Sprintf("test_value_%d", i) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c2,t,t2) VALUES ($1,$2,random_string(9000)) `, srcTableName), i, testValue) e2e.EnvNoError(s.t, env, err) @@ -1078,10 +1083,10 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Composite_PKey_Toast_2_BQ() { s.t.Log("Inserted 10 rows into the source table") e2e.EnvWaitForEqualTables(env, s, "normalize table", tableName, "id,c2,t,t2") - _, err = s.conn.Exec(context.Background(), + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`UPDATE %s SET c1=c1+1 WHERE MOD(c2,2)=$1`, srcTableName), 1) e2e.EnvNoError(s.t, env, err) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(`DELETE FROM %s WHERE MOD(c2,2)=$1`, srcTableName), 0) + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`DELETE FROM %s WHERE MOD(c2,2)=$1`, srcTableName), 0) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTables(env, s, "normalize update", tableName, "id,c2,t,t2") @@ -1097,7 +1102,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Columns_BQ() { srcTableName := s.attachSchemaSuffix("test_peerdb_cols") dstTableName := "test_peerdb_cols_dst" - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, key TEXT NOT NULL, @@ -1122,13 +1127,13 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Columns_BQ() { // insert 1 row into the source table testKey := fmt.Sprintf("test_key_%d", 1) testValue := fmt.Sprintf("test_value_%d", 1) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(key, value) VALUES ($1, $2) `, srcTableName), testKey, testValue) e2e.EnvNoError(s.t, env, err) // delete that row - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` DELETE FROM %s WHERE id=1 `, srcTableName)) e2e.EnvNoError(s.t, env, err) @@ -1152,7 +1157,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Multi_Table_Multi_Dataset_BQ() { srcTable2Name := s.attachSchemaSuffix("test2_bq") dstTable2Name := "test2_bq" - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE %s(id serial primary key, c1 int, c2 text); CREATE TABLE %s(id serial primary key, c1 int, c2 text); `, srcTable1Name, srcTable2Name)) @@ -1177,7 +1182,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Multi_Table_Multi_Dataset_BQ() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) /* inserting across multiple tables*/ - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s (c1,c2) VALUES (1,'dummy_1'); INSERT INTO %s (c1,c2) VALUES (-1,'dummy_-1'); `, srcTable1Name, srcTable2Name)) @@ -1212,7 +1217,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Soft_Delete_Basic() { srcName := "test_softdel_src" srcTableName := s.attachSchemaSuffix(srcName) - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, c1 INT, @@ -1248,15 +1253,15 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Soft_Delete_Basic() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1,c2,t) VALUES (1,2,random_string(9000))`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTablesWithNames(env, s, "normalize insert", srcName, tableName, "id,c1,c2,t") - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` UPDATE %s SET c1=c1+4 WHERE id=1`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTablesWithNames(env, s, "normalize update", srcName, tableName, "id,c1,c2,t") - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` DELETE FROM %s WHERE id=1`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitFor(s.t, env, 3*time.Minute, "normalize delete", func() bool { @@ -1292,7 +1297,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Soft_Delete_IUD_Same_Batch() { srcTableName := fmt.Sprintf("%s_src", cmpTableName) dstTableName := "test_softdel_iud" - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, c1 INT, @@ -1328,7 +1333,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Soft_Delete_IUD_Same_Batch() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - insertTx, err := s.conn.Begin(context.Background()) + insertTx, err := s.Conn().Begin(context.Background()) e2e.EnvNoError(s.t, env, err) _, err = insertTx.Exec(context.Background(), fmt.Sprintf(` @@ -1370,7 +1375,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Soft_Delete_UD_Same_Batch() { srcTableName := s.attachSchemaSuffix(srcName) dstName := "test_softdel_ud" - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, c1 INT, @@ -1406,12 +1411,12 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Soft_Delete_UD_Same_Batch() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1,c2,t) VALUES (1,2,random_string(9000))`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTablesWithNames(env, s, "normalize insert", srcName, dstName, "id,c1,c2,t") - insertTx, err := s.conn.Begin(context.Background()) + insertTx, err := s.Conn().Begin(context.Background()) e2e.EnvNoError(s.t, env, err) _, err = insertTx.Exec(context.Background(), fmt.Sprintf(` UPDATE %s SET t=random_string(10000) WHERE id=1`, srcTableName)) @@ -1454,7 +1459,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Soft_Delete_Insert_After_Delete() { tableName := "test_softdel_iad" srcTableName := s.attachSchemaSuffix(tableName) - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, c1 INT, @@ -1490,11 +1495,11 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Soft_Delete_Insert_After_Delete() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1,c2,t) VALUES (1,2,random_string(9000))`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTables(env, s, "normalize insert", tableName, "id,c1,c2,t") - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` DELETE FROM %s WHERE id=1`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitFor(s.t, env, 3*time.Minute, "normalize delete", func() bool { @@ -1508,7 +1513,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Soft_Delete_Insert_After_Delete() { } return e2eshared.CheckEqualRecordBatches(s.t, pgRows, rows) }) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(id,c1,c2,t) VALUES (1,3,4,random_string(10000))`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTables(env, s, "normalize reinsert", tableName, "id,c1,c2,t") diff --git a/flow/e2e/bigquery/qrep_flow_bq_test.go b/flow/e2e/bigquery/qrep_flow_bq_test.go index 395a2c5ea2..e881662c7b 100644 --- a/flow/e2e/bigquery/qrep_flow_bq_test.go +++ b/flow/e2e/bigquery/qrep_flow_bq_test.go @@ -11,9 +11,9 @@ import ( ) func (s PeerFlowE2ETestSuiteBQ) setupSourceTable(tableName string, rowCount int) { - err := e2e.CreateTableForQRep(s.conn, s.bqSuffix, tableName) + err := e2e.CreateTableForQRep(s.Conn(), s.bqSuffix, tableName) require.NoError(s.t, err) - err = e2e.PopulateSourceTable(s.conn, s.bqSuffix, tableName, rowCount) + err = e2e.PopulateSourceTable(s.Conn(), s.bqSuffix, tableName, rowCount) require.NoError(s.t, err) } @@ -24,7 +24,7 @@ func (s PeerFlowE2ETestSuiteBQ) setupTimeTable(tableName string) { "mytztimestamp timestamptz", } tblFieldStr := strings.Join(tblFields, ",") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE e2e_test_%s.%s ( %s );`, s.bqSuffix, tableName, tblFieldStr)) @@ -35,7 +35,7 @@ func (s PeerFlowE2ETestSuiteBQ) setupTimeTable(tableName string) { row := `(CURRENT_TIMESTAMP,'10001-03-14 23:05:52','50001-03-14 23:05:52.216809+00')` rows = append(rows, row) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO e2e_test_%s.%s ( watermark_ts, mytimestamp, diff --git a/flow/e2e/congen.go b/flow/e2e/congen.go index 73450d130d..8a7ff6235e 100644 --- a/flow/e2e/congen.go +++ b/flow/e2e/congen.go @@ -9,8 +9,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/stretchr/testify/require" - "github.com/PeerDB-io/peer-flow/connectors/utils" - "github.com/PeerDB-io/peer-flow/e2eshared" + connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" "github.com/PeerDB-io/peer-flow/generated/protos" ) @@ -120,34 +119,35 @@ func setupPostgresSchema(t *testing.T, conn *pgx.Conn, suffix string) error { } // SetupPostgres sets up the postgres connection. -func SetupPostgres(t *testing.T, suffix string) (*pgx.Conn, error) { +func SetupPostgres(t *testing.T, suffix string) (*connpostgres.PostgresConnector, error) { t.Helper() - conn, err := pgx.Connect(context.Background(), utils.GetPGConnectionString(GetTestPostgresConf())) + connector, err := connpostgres.NewPostgresConnector(context.Background(), GetTestPostgresConf()) if err != nil { return nil, fmt.Errorf("failed to create postgres connection: %w", err) } + conn := connector.Conn() err = cleanPostgres(conn, suffix) if err != nil { - conn.Close(context.Background()) + connector.Close(context.Background()) return nil, err } err = setupPostgresSchema(t, conn, suffix) if err != nil { - conn.Close(context.Background()) + connector.Close(context.Background()) return nil, err } - return conn, nil + return connector, nil } -func TearDownPostgres[T e2eshared.Suite](s T) { +func TearDownPostgres[T Suite](s T) { t := s.T() t.Helper() - conn := s.Conn() + conn := s.Connector().Conn() if conn != nil { suffix := s.Suffix() t.Log("begin tearing down postgres schema", suffix) diff --git a/flow/e2e/postgres/peer_flow_pg_test.go b/flow/e2e/postgres/peer_flow_pg_test.go index 642bc13a94..6d8f185437 100644 --- a/flow/e2e/postgres/peer_flow_pg_test.go +++ b/flow/e2e/postgres/peer_flow_pg_test.go @@ -36,7 +36,7 @@ func (s PeerFlowE2ETestSuitePG) checkPeerdbColumns(dstSchemaQualified string, ro dstSchemaQualified, rowID) var isDeleted pgtype.Bool var syncedAt pgtype.Timestamp - err := s.conn.QueryRow(context.Background(), query).Scan(&isDeleted, &syncedAt) + err := s.Conn().QueryRow(context.Background(), query).Scan(&isDeleted, &syncedAt) if err != nil { return fmt.Errorf("failed to query row: %w", err) } @@ -63,7 +63,7 @@ func (s PeerFlowE2ETestSuitePG) WaitForSchema( s.t.Helper() e2e.EnvWaitFor(s.t, env, 3*time.Minute, reason, func() bool { s.t.Helper() - output, err := s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ + output, err := s.conn.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{dstTableName}, }) if err != nil { @@ -84,7 +84,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Simple_Flow_PG() { srcTableName := s.attachSchemaSuffix("test_simple_flow") dstTableName := s.attachSchemaSuffix("test_simple_flow_dst") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, key TEXT NOT NULL, @@ -112,7 +112,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Simple_Flow_PG() { for i := 0; i < 10; i++ { testKey := fmt.Sprintf("test_key_%d", i) testValue := fmt.Sprintf("test_value_%d", i) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(key, value, myh) VALUES ($1, $2, '"a"=>"b"') `, srcTableName), testKey, testValue) e2e.EnvNoError(s.t, env, err) @@ -135,7 +135,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Geospatial_PG() { srcTableName := s.attachSchemaSuffix("test_geospatial_pg") dstTableName := s.attachSchemaSuffix("test_geospatial_pg_dst") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, gg geography NOT NULL, @@ -157,7 +157,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Geospatial_PG() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) // insert 1 row into the source table - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(gg, gm) VALUES ('POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))','LINESTRING(0 0, 1 1, 2 2)') `, srcTableName)) e2e.EnvNoError(s.t, env, err) @@ -179,7 +179,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Types_PG() { srcTableName := s.attachSchemaSuffix("test_types_pg") dstTableName := s.attachSchemaSuffix("test_types_pg_dst") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (id serial PRIMARY KEY,c1 BIGINT,c2 BIT,c4 BOOLEAN, c7 CHARACTER,c8 varchar,c9 CIDR,c11 DATE,c12 FLOAT,c13 DOUBLE PRECISION, c14 INET,c15 INTEGER,c21 MACADDR, @@ -202,7 +202,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Types_PG() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s SELECT 2,2,b'1', true,'s','test','1.1.10.2'::cidr, CURRENT_DATE,1.23,1.234,'192.168.1.5'::inet,1, @@ -244,11 +244,11 @@ func (s PeerFlowE2ETestSuitePG) Test_Enums_PG() { dstTableName := s.attachSchemaSuffix("test_enum_flow_dst") createMoodEnum := "CREATE TYPE mood AS ENUM ('happy', 'sad', 'angry');" var pgErr *pgconn.PgError - _, enumErr := s.conn.Exec(context.Background(), createMoodEnum) + _, enumErr := s.Conn().Exec(context.Background(), createMoodEnum) if errors.As(enumErr, &pgErr) && pgErr.Code != pgerrcode.DuplicateObject && !utils.IsUniqueError(enumErr) { require.NoError(s.t, enumErr) } - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, my_mood mood, @@ -269,7 +269,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Enums_PG() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(my_mood, my_null_mood) VALUES ('happy',null) `, srcTableName)) e2e.EnvNoError(s.t, env, err) @@ -291,7 +291,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Simple_Schema_Changes_PG() { srcTableName := s.attachSchemaSuffix("test_simple_schema_changes") dstTableName := s.attachSchemaSuffix("test_simple_schema_changes_dst") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, c1 BIGINT @@ -314,7 +314,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Simple_Schema_Changes_PG() { go func() { // insert first row. e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1) VALUES ($1)`, srcTableName), 1) e2e.EnvNoError(s.t, env, err) s.t.Log("Inserted initial row in the source table") @@ -342,11 +342,11 @@ func (s PeerFlowE2ETestSuitePG) Test_Simple_Schema_Changes_PG() { }) // alter source table, add column c2 and insert another row. - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` ALTER TABLE %s ADD COLUMN c2 BIGINT`, srcTableName)) e2e.EnvNoError(s.t, env, err) s.t.Log("Altered source table, added column c2") - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1,c2) VALUES ($1,$2)`, srcTableName), 2, 2) e2e.EnvNoError(s.t, env, err) s.t.Log("Inserted row with added c2 in the source table") @@ -379,11 +379,11 @@ func (s PeerFlowE2ETestSuitePG) Test_Simple_Schema_Changes_PG() { }) // alter source table, add column c3, drop column c2 and insert another row. - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` ALTER TABLE %s DROP COLUMN c2, ADD COLUMN c3 BIGINT`, srcTableName)) e2e.EnvNoError(s.t, env, err) s.t.Log("Altered source table, dropped column c2 and added column c3") - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1,c3) VALUES ($1,$2)`, srcTableName), 3, 3) e2e.EnvNoError(s.t, env, err) s.t.Log("Inserted row with added c3 in the source table") @@ -421,11 +421,11 @@ func (s PeerFlowE2ETestSuitePG) Test_Simple_Schema_Changes_PG() { }) // alter source table, drop column c3 and insert another row. - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` ALTER TABLE %s DROP COLUMN c3`, srcTableName)) e2e.EnvNoError(s.t, env, err) s.t.Log("Altered source table, dropped column c3") - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1) VALUES ($1)`, srcTableName), 4) e2e.EnvNoError(s.t, env, err) s.t.Log("Inserted row after dropping all columns in the source table") @@ -475,7 +475,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Composite_PKey_PG() { srcTableName := s.attachSchemaSuffix("test_simple_cpkey") dstTableName := s.attachSchemaSuffix("test_simple_cpkey_dst") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT GENERATED ALWAYS AS IDENTITY, c1 INT GENERATED BY DEFAULT AS IDENTITY, @@ -503,7 +503,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Composite_PKey_PG() { // insert 10 rows into the source table for i := 0; i < 10; i++ { testValue := fmt.Sprintf("test_value_%d", i) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c2,t) VALUES ($1,$2) `, srcTableName), i, testValue) e2e.EnvNoError(s.t, env, err) @@ -514,10 +514,10 @@ func (s PeerFlowE2ETestSuitePG) Test_Composite_PKey_PG() { return s.comparePGTables(srcTableName, dstTableName, "id,c1,c2,t") == nil }) - _, err := s.conn.Exec(context.Background(), + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(`UPDATE %s SET c1=c1+1 WHERE MOD(c2,2)=$1`, srcTableName), 1) e2e.EnvNoError(s.t, env, err) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(`DELETE FROM %s WHERE MOD(c2,2)=$1`, srcTableName), 0) + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`DELETE FROM %s WHERE MOD(c2,2)=$1`, srcTableName), 0) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitFor(s.t, env, 3*time.Minute, "normalize modifications", func() bool { return s.comparePGTables(srcTableName, dstTableName, "id,c1,c2,t") == nil @@ -536,7 +536,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Composite_PKey_Toast_1_PG() { randomString := s.attachSchemaSuffix("random_string") dstTableName := s.attachSchemaSuffix("test_cpkey_toast1_dst") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT GENERATED ALWAYS AS IDENTITY, c1 INT GENERATED BY DEFAULT AS IDENTITY, @@ -565,7 +565,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Composite_PKey_Toast_1_PG() { // and then insert, update and delete rows in the table. go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - rowsTx, err := s.conn.Begin(context.Background()) + rowsTx, err := s.Conn().Begin(context.Background()) e2e.EnvNoError(s.t, env, err) // insert 10 rows into the source table @@ -604,7 +604,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Composite_PKey_Toast_2_PG() { randomString := s.attachSchemaSuffix("random_string") dstTableName := s.attachSchemaSuffix("test_cpkey_toast2_dst") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT GENERATED ALWAYS AS IDENTITY, c1 INT GENERATED BY DEFAULT AS IDENTITY, @@ -637,7 +637,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Composite_PKey_Toast_2_PG() { // insert 10 rows into the source table for i := 0; i < 10; i++ { testValue := fmt.Sprintf("test_value_%d", i) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c2,t,t2) VALUES ($1,$2,%s(9000)) `, srcTableName, randomString), i, testValue) e2e.EnvNoError(s.t, env, err) @@ -647,10 +647,10 @@ func (s PeerFlowE2ETestSuitePG) Test_Composite_PKey_Toast_2_PG() { e2e.EnvWaitFor(s.t, env, 3*time.Minute, "normalize 10 rows", func() bool { return s.comparePGTables(srcTableName, dstTableName, "id,c1,c2,t,t2") == nil }) - _, err = s.conn.Exec(context.Background(), + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`UPDATE %s SET c1=c1+1 WHERE MOD(c2,2)=$1`, srcTableName), 1) e2e.EnvNoError(s.t, env, err) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(`DELETE FROM %s WHERE MOD(c2,2)=$1`, srcTableName), 0) + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`DELETE FROM %s WHERE MOD(c2,2)=$1`, srcTableName), 0) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitFor(s.t, env, 3*time.Minute, "normalize update", func() bool { @@ -670,7 +670,7 @@ func (s PeerFlowE2ETestSuitePG) Test_PeerDB_Columns() { srcTableName := s.attachSchemaSuffix("test_peerdb_cols") dstTableName := s.attachSchemaSuffix("test_peerdb_cols_dst") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, key TEXT NOT NULL, @@ -695,13 +695,13 @@ func (s PeerFlowE2ETestSuitePG) Test_PeerDB_Columns() { // insert 1 row into the source table testKey := fmt.Sprintf("test_key_%d", 1) testValue := fmt.Sprintf("test_value_%d", 1) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(key, value) VALUES ($1, $2) `, srcTableName), testKey, testValue) e2e.EnvNoError(s.t, env, err) // delete that row - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` DELETE FROM %s WHERE id=1 `, srcTableName)) e2e.EnvNoError(s.t, env, err) @@ -724,7 +724,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Soft_Delete_Basic() { srcTableName := fmt.Sprintf("%s_src", cmpTableName) dstTableName := s.attachSchemaSuffix("test_softdel_dst") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, c1 INT, @@ -760,23 +760,23 @@ func (s PeerFlowE2ETestSuitePG) Test_Soft_Delete_Basic() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1,c2,t) VALUES (1,2,random_string(9000))`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitFor(s.t, env, 3*time.Minute, "normalize row", func() bool { return s.comparePGTables(srcTableName, dstTableName, "id,c1,c2,t") == nil }) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` UPDATE %s SET c1=c1+4 WHERE id=1`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitFor(s.t, env, 3*time.Minute, "normalize update", func() bool { return s.comparePGTables(srcTableName, dstTableName, "id,c1,c2,t") == nil }) // since we delete stuff, create another table to compare with - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE %s AS SELECT * FROM %s`, cmpTableName, srcTableName)) e2e.EnvNoError(s.t, env, err) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` DELETE FROM %s WHERE id=1`, srcTableName)) e2e.EnvNoError(s.t, env, err) @@ -810,7 +810,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Soft_Delete_IUD_Same_Batch() { srcTableName := fmt.Sprintf("%s_src", cmpTableName) dstTableName := s.attachSchemaSuffix("test_softdel_iud_dst") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, c1 INT, @@ -846,7 +846,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Soft_Delete_IUD_Same_Batch() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - insertTx, err := s.conn.Begin(context.Background()) + insertTx, err := s.Conn().Begin(context.Background()) e2e.EnvNoError(s.t, env, err) _, err = insertTx.Exec(context.Background(), fmt.Sprintf(` @@ -889,7 +889,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Soft_Delete_UD_Same_Batch() { srcTableName := fmt.Sprintf("%s_src", cmpTableName) dstTableName := s.attachSchemaSuffix("test_softdel_ud_dst") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, c1 INT, @@ -925,14 +925,14 @@ func (s PeerFlowE2ETestSuitePG) Test_Soft_Delete_UD_Same_Batch() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1,c2,t) VALUES (1,2,random_string(9000))`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitFor(s.t, env, 3*time.Minute, "normalize row", func() bool { return s.comparePGTables(srcTableName, dstTableName, "id,c1,c2,t") == nil }) - insertTx, err := s.conn.Begin(context.Background()) + insertTx, err := s.Conn().Begin(context.Background()) e2e.EnvNoError(s.t, env, err) _, err = insertTx.Exec(context.Background(), fmt.Sprintf(` UPDATE %s SET t=random_string(10000) WHERE id=1`, srcTableName)) @@ -973,7 +973,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Soft_Delete_Insert_After_Delete() { srcTableName := s.attachSchemaSuffix("test_softdel_iad") dstTableName := s.attachSchemaSuffix("test_softdel_iad_dst") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, c1 INT, @@ -1009,19 +1009,19 @@ func (s PeerFlowE2ETestSuitePG) Test_Soft_Delete_Insert_After_Delete() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1,c2,t) VALUES (1,2,random_string(9000))`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitFor(s.t, env, 3*time.Minute, "normalize row", func() bool { return s.comparePGTables(srcTableName, dstTableName, "id,c1,c2,t") == nil }) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` DELETE FROM %s WHERE id=1`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitFor(s.t, env, 3*time.Minute, "normalize delete", func() bool { return s.comparePGTables(srcTableName, dstTableName+` WHERE NOT "_PEERDB_IS_DELETED"`, "id,c1,c2,t") == nil }) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(id,c1,c2,t) VALUES (1,3,4,random_string(10000))`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitFor(s.t, env, 3*time.Minute, "normalize reinsert", func() bool { @@ -1050,7 +1050,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Supported_Mixed_Case_Table() { stmtDstTableName := fmt.Sprintf(`e2e_test_%s."%s"`, s.suffix, "testMixedCaseDst") dstTableName := s.attachSchemaSuffix("testMixedCaseDst") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( "pulseArmor" SERIAL PRIMARY KEY, "highGold" TEXT NOT NULL, @@ -1086,7 +1086,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Supported_Mixed_Case_Table() { for i := 0; i < 20; i++ { testKey := fmt.Sprintf("test_key_%d", i) testValue := fmt.Sprintf("test_value_%d", i) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s ("highGold","eVe") VALUES ($1, $2) `, stmtSrcTableName), testKey, testValue) e2e.EnvNoError(s.t, env, err) @@ -1122,7 +1122,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Dynamic_Mirror_Config_Via_Signals() { isPaused := false sentUpdate := false - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, t TEXT DEFAULT md5(random()::text)); @@ -1157,10 +1157,10 @@ func (s PeerFlowE2ETestSuitePG) Test_Dynamic_Mirror_Config_Via_Signals() { addRows := func(numRows int) { for i := 0; i < numRows; i++ { - _, err = s.conn.Exec(context.Background(), + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`INSERT INTO %s DEFAULT VALUES`, srcTable1Name)) e2e.EnvNoError(s.t, env, err) - _, err = s.conn.Exec(context.Background(), + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`INSERT INTO %s DEFAULT VALUES`, srcTable2Name)) e2e.EnvNoError(s.t, env, err) } diff --git a/flow/e2e/postgres/qrep_flow_pg_test.go b/flow/e2e/postgres/qrep_flow_pg_test.go index 1723afca5d..65a043df78 100644 --- a/flow/e2e/postgres/qrep_flow_pg_test.go +++ b/flow/e2e/postgres/qrep_flow_pg_test.go @@ -22,10 +22,9 @@ import ( type PeerFlowE2ETestSuitePG struct { t *testing.T - conn *pgx.Conn - peer *protos.Peer - connector *connpostgres.PostgresConnector - suffix string + conn *connpostgres.PostgresConnector + peer *protos.Peer + suffix string } func (s PeerFlowE2ETestSuitePG) T() *testing.T { @@ -33,6 +32,10 @@ func (s PeerFlowE2ETestSuitePG) T() *testing.T { } func (s PeerFlowE2ETestSuitePG) Conn() *pgx.Conn { + return s.conn.Conn() +} + +func (s PeerFlowE2ETestSuitePG) Connector() *connpostgres.PostgresConnector { return s.conn } @@ -58,33 +61,20 @@ func SetupSuite(t *testing.T) PeerFlowE2ETestSuitePG { suffix := "pg_" + strings.ToLower(shared.RandomString(8)) conn, err := e2e.SetupPostgres(t, suffix) - if err != nil { - require.Fail(t, "failed to setup postgres", err) - } - - connector, err := connpostgres.NewPostgresConnector(context.Background(), - &protos.PostgresConfig{ - Host: "localhost", - Port: 7132, - User: "postgres", - Password: "postgres", - Database: "postgres", - }) - require.NoError(t, err) + require.NoError(t, err, "failed to setup postgres") return PeerFlowE2ETestSuitePG{ - t: t, - conn: conn, - peer: generatePGPeer(e2e.GetTestPostgresConf()), - connector: connector, - suffix: suffix, + t: t, + conn: conn, + peer: generatePGPeer(e2e.GetTestPostgresConf()), + suffix: suffix, } } func (s PeerFlowE2ETestSuitePG) setupSourceTable(tableName string, rowCount int) { - err := e2e.CreateTableForQRep(s.conn, s.suffix, tableName) + err := e2e.CreateTableForQRep(s.Conn(), s.suffix, tableName) require.NoError(s.t, err) - err = e2e.PopulateSourceTable(s.conn, s.suffix, tableName, rowCount) + err = e2e.PopulateSourceTable(s.Conn(), s.suffix, tableName, rowCount) require.NoError(s.t, err) } @@ -108,7 +98,7 @@ func (s PeerFlowE2ETestSuitePG) checkEnums(srcSchemaQualified, dstSchemaQualifie "SELECT 1 FROM %s dst "+ "WHERE src.my_mood::text = dst.my_mood::text)) LIMIT 1;", srcSchemaQualified, dstSchemaQualified) - err := s.conn.QueryRow(context.Background(), query).Scan(&exists) + err := s.Conn().QueryRow(context.Background(), query).Scan(&exists) if err != nil { return err } @@ -122,7 +112,7 @@ func (s PeerFlowE2ETestSuitePG) checkEnums(srcSchemaQualified, dstSchemaQualifie func (s PeerFlowE2ETestSuitePG) compareQuery(srcSchemaQualified, dstSchemaQualified, selector string) error { query := fmt.Sprintf("SELECT %s FROM %s EXCEPT SELECT %s FROM %s", selector, srcSchemaQualified, selector, dstSchemaQualified) - rows, err := s.conn.Query(context.Background(), query, pgx.QueryExecModeExec) + rows, err := s.Conn().Query(context.Background(), query, pgx.QueryExecModeExec) if err != nil { return err } @@ -156,7 +146,7 @@ func (s PeerFlowE2ETestSuitePG) compareQuery(srcSchemaQualified, dstSchemaQualif func (s PeerFlowE2ETestSuitePG) checkSyncedAt(dstSchemaQualified string) error { query := fmt.Sprintf(`SELECT "_PEERDB_SYNCED_AT" FROM %s`, dstSchemaQualified) - rows, _ := s.conn.Query(context.Background(), query) + rows, _ := s.Conn().Query(context.Background(), query) defer rows.Close() for rows.Next() { @@ -176,12 +166,12 @@ func (s PeerFlowE2ETestSuitePG) checkSyncedAt(dstSchemaQualified string) error { func (s PeerFlowE2ETestSuitePG) RunInt64Query(query string) (int64, error) { var count pgtype.Int8 - err := s.conn.QueryRow(context.Background(), query).Scan(&count) + err := s.Conn().QueryRow(context.Background(), query).Scan(&count) return count.Int64, err } func (s PeerFlowE2ETestSuitePG) TestSimpleSlotCreation() { - setupTx, err := s.conn.Begin(context.Background()) + setupTx, err := s.Conn().Begin(context.Background()) require.NoError(s.t, err) // setup 3 tables in pgpeer_repl_test schema // test_1, test_2, test_3, all have 5 columns all text, c1, c2, c3, c4, c5 @@ -207,7 +197,7 @@ func (s PeerFlowE2ETestSuitePG) TestSimpleSlotCreation() { setupError := make(chan error) go func() { - setupError <- s.connector.SetupReplication(context.Background(), signal, setupReplicationInput) + setupError <- s.conn.SetupReplication(context.Background(), signal, setupReplicationInput) }() s.t.Log("waiting for slot creation to complete: ", flowJobName) @@ -230,7 +220,7 @@ func (s PeerFlowE2ETestSuitePG) Test_Complete_QRep_Flow_Multi_Insert_PG() { dstTable := "test_qrep_flow_avro_pg_2" - err := e2e.CreateTableForQRep(s.conn, s.suffix, dstTable) + err := e2e.CreateTableForQRep(s.Conn(), s.suffix, dstTable) require.NoError(s.t, err) srcSchemaQualified := fmt.Sprintf("%s_%s.%s", "e2e_test", s.suffix, srcTable) diff --git a/flow/e2e/s3/cdc_s3_test.go b/flow/e2e/s3/cdc_s3_test.go index 7295273597..3fd89ac74f 100644 --- a/flow/e2e/s3/cdc_s3_test.go +++ b/flow/e2e/s3/cdc_s3_test.go @@ -25,7 +25,7 @@ func (s PeerFlowE2ETestSuiteS3) Test_Complete_Simple_Flow_S3() { srcTableName := s.attachSchemaSuffix("test_simple_flow_s3") dstTableName := fmt.Sprintf("%s.%s", "peerdb_test_s3", "test_simple_flow_s3") flowJobName := s.attachSuffix("test_simple_flow_s3") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE %s ( id SERIAL PRIMARY KEY, key TEXT NOT NULL, @@ -49,7 +49,7 @@ func (s PeerFlowE2ETestSuiteS3) Test_Complete_Simple_Flow_S3() { for i := 1; i <= 20; i++ { testKey := fmt.Sprintf("test_key_%d", i) testValue := fmt.Sprintf("test_value_%d", i) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s (key, value) VALUES ($1, $2) `, srcTableName), testKey, testValue) e2e.EnvNoError(s.t, env, err) diff --git a/flow/e2e/s3/qrep_flow_s3_test.go b/flow/e2e/s3/qrep_flow_s3_test.go index a27d52171f..54a66f0ed0 100644 --- a/flow/e2e/s3/qrep_flow_s3_test.go +++ b/flow/e2e/s3/qrep_flow_s3_test.go @@ -7,10 +7,10 @@ import ( "testing" "time" - "github.com/jackc/pgx/v5" "github.com/joho/godotenv" "github.com/stretchr/testify/require" + connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" "github.com/PeerDB-io/peer-flow/e2e" "github.com/PeerDB-io/peer-flow/e2eshared" "github.com/PeerDB-io/peer-flow/shared" @@ -19,7 +19,7 @@ import ( type PeerFlowE2ETestSuiteS3 struct { t *testing.T - conn *pgx.Conn + conn *connpostgres.PostgresConnector s3Helper *S3TestHelper suffix string } @@ -28,7 +28,7 @@ func (s PeerFlowE2ETestSuiteS3) T() *testing.T { return s.t } -func (s PeerFlowE2ETestSuiteS3) Conn() *pgx.Conn { +func (s PeerFlowE2ETestSuiteS3) Connector() *connpostgres.PostgresConnector { return s.conn } @@ -54,9 +54,9 @@ func TestPeerFlowE2ETestSuiteGCS(t *testing.T) { } func (s PeerFlowE2ETestSuiteS3) setupSourceTable(tableName string, rowCount int) { - err := e2e.CreateTableForQRep(s.conn, s.suffix, tableName) + err := e2e.CreateTableForQRep(s.conn.Conn(), s.suffix, tableName) require.NoError(s.t, err) - err = e2e.PopulateSourceTable(s.conn, s.suffix, tableName, rowCount) + err = e2e.PopulateSourceTable(s.conn.Conn(), s.suffix, tableName, rowCount) require.NoError(s.t, err) } diff --git a/flow/e2e/snowflake/peer_flow_sf_test.go b/flow/e2e/snowflake/peer_flow_sf_test.go index f669492d45..716eee667c 100644 --- a/flow/e2e/snowflake/peer_flow_sf_test.go +++ b/flow/e2e/snowflake/peer_flow_sf_test.go @@ -14,6 +14,7 @@ import ( "github.com/joho/godotenv" "github.com/stretchr/testify/require" + connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake" "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/e2e" @@ -29,7 +30,7 @@ type PeerFlowE2ETestSuiteSF struct { t *testing.T pgSuffix string - conn *pgx.Conn + conn *connpostgres.PostgresConnector sfHelper *SnowflakeTestHelper connector *connsnowflake.SnowflakeConnector } @@ -38,10 +39,14 @@ func (s PeerFlowE2ETestSuiteSF) T() *testing.T { return s.t } -func (s PeerFlowE2ETestSuiteSF) Conn() *pgx.Conn { +func (s PeerFlowE2ETestSuiteSF) Connector() *connpostgres.PostgresConnector { return s.conn } +func (s PeerFlowE2ETestSuiteSF) Conn() *pgx.Conn { + return s.Connector().Conn() +} + func (s PeerFlowE2ETestSuiteSF) Suffix() string { return s.pgSuffix } @@ -128,7 +133,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Complete_Simple_Flow_SF() { srcTableName := s.attachSchemaSuffix(tableName) dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, tableName) - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, key TEXT NOT NULL, @@ -155,7 +160,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Complete_Simple_Flow_SF() { for i := 0; i < 20; i++ { testKey := fmt.Sprintf("test_key_%d", i) testValue := fmt.Sprintf("test_value_%d", i) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s (key, value) VALUES ($1, $2) `, srcTableName), testKey, testValue) e2e.EnvNoError(s.t, env, err) @@ -187,7 +192,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Flow_ReplicaIdentity_Index_No_Pkey() { dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test_replica_identity_no_pkey") // Create a table without a primary key and create a named unique index - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL, key TEXT NOT NULL, @@ -216,7 +221,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Flow_ReplicaIdentity_Index_No_Pkey() { for i := 0; i < 20; i++ { testKey := fmt.Sprintf("test_key_%d", i) testValue := fmt.Sprintf("test_value_%d", i) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s (id, key, value) VALUES ($1, $2, $3) `, srcTableName), i, testKey, testValue) e2e.EnvNoError(s.t, env, err) @@ -241,7 +246,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Geo_SF_Avro_CDC() { srcTableName := s.attachSchemaSuffix(tableName) dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test_invalid_geo_sf_avro_cdc") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, line GEOMETRY(LINESTRING) NOT NULL, @@ -266,7 +271,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Geo_SF_Avro_CDC() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) // insert 4 invalid shapes and 6 valid shapes into the source table for i := 0; i < 4; i++ { - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s (line,poly) VALUES ($1,$2) `, srcTableName), "010200000001000000000000000000F03F0000000000000040", "0103000020e6100000010000000c0000001a8361d35dc64140afdb8d2b1bc3c9bf1b8ed4685fc641405ba64c"+ @@ -278,7 +283,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Geo_SF_Avro_CDC() { } s.t.Log("Inserted 4 invalid geography rows into the source table") for i := 4; i < 10; i++ { - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s (line,poly) VALUES ($1,$2) `, srcTableName), "010200000002000000000000000000F03F000000000000004000000000000008400000000000001040", "010300000001000000050000000000000000000000000000000000000000000000"+ @@ -321,7 +326,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Toast_SF() { srcTableName := s.attachSchemaSuffix("test_toast_sf_1") dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test_toast_sf_1") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, t1 text, @@ -351,7 +356,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Toast_SF() { 2. changes no toast column 2. changes 1 toast column */ - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` BEGIN; INSERT INTO %s (t1,t2,k) SELECT random_string(9000),random_string(9000), 1 FROM generate_series(1,2); @@ -375,7 +380,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Toast_Advance_1_SF() { srcTableName := s.attachSchemaSuffix("test_toast_sf_3") dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test_toast_sf_3") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, t1 text, @@ -400,7 +405,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Toast_Advance_1_SF() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) // complex transaction with random DMLs on a table with toast columns - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` BEGIN; INSERT INTO %s (t1,t2,k) SELECT random_string(9000),random_string(9000), 1 FROM generate_series(1,2); @@ -435,7 +440,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Toast_Advance_2_SF() { srcTableName := s.attachSchemaSuffix("test_toast_sf_4") dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test_toast_sf_4") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, t1 text, @@ -459,7 +464,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Toast_Advance_2_SF() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) // complex transaction with random DMLs on a table with toast columns - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` BEGIN; INSERT INTO %s (t1,k) SELECT random_string(9000), 1 FROM generate_series(1,1); @@ -488,7 +493,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Toast_Advance_3_SF() { srcTableName := s.attachSchemaSuffix("test_toast_sf_5") dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test_toast_sf_5") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, t1 text, @@ -516,7 +521,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Toast_Advance_3_SF() { transaction updating a single row multiple times with changed/unchanged toast columns */ - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` BEGIN; INSERT INTO %s (t1,t2,k) SELECT random_string(9000),random_string(9000), 1 FROM generate_series(1,1); @@ -543,11 +548,11 @@ func (s PeerFlowE2ETestSuiteSF) Test_Types_SF() { dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test_types_sf") createMoodEnum := "CREATE TYPE mood AS ENUM ('happy', 'sad', 'angry');" var pgErr *pgconn.PgError - _, enumErr := s.conn.Exec(context.Background(), createMoodEnum) + _, enumErr := s.Conn().Exec(context.Background(), createMoodEnum) if errors.As(enumErr, &pgErr) && pgErr.Code != pgerrcode.DuplicateObject && !utils.IsUniqueError(pgErr) { require.NoError(s.t, enumErr) } - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (id serial PRIMARY KEY,c1 BIGINT,c2 BIT,c3 VARBIT,c4 BOOLEAN, c6 BYTEA,c7 CHARACTER,c8 varchar,c9 CIDR,c11 DATE,c12 FLOAT,c13 DOUBLE PRECISION, c14 INET,c15 INTEGER,c16 INTERVAL,c17 JSON,c18 JSONB,c21 MACADDR,c22 MONEY, @@ -574,7 +579,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Types_SF() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) /* test inserting various types*/ - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s SELECT 2,2,b'1',b'101', true,random_bytea(32),'s','test','1.1.10.2'::cidr, CURRENT_DATE,1.23,1.234,'192.168.1.5'::inet,1, @@ -640,7 +645,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Multi_Table_SF() { dstTable1Name := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test1_sf") dstTable2Name := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test2_sf") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (id serial primary key, c1 int, c2 text); CREATE TABLE IF NOT EXISTS %s (id serial primary key, c1 int, c2 text); `, srcTable1Name, srcTable2Name)) @@ -661,7 +666,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Multi_Table_SF() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) /* inserting across multiple tables*/ - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s (c1,c2) VALUES (1,'dummy_1'); INSERT INTO %s (c1,c2) VALUES (-1,'dummy_-1'); `, srcTable1Name, srcTable2Name)) @@ -693,7 +698,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Simple_Schema_Changes_SF() { srcTableName := s.attachSchemaSuffix("test_simple_schema_changes") dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test_simple_schema_changes") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, c1 BIGINT @@ -715,7 +720,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Simple_Schema_Changes_SF() { // and then insert and mutate schema repeatedly. go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1) VALUES ($1)`, srcTableName), 1) e2e.EnvNoError(s.t, env, err) s.t.Log("Inserted initial row in the source table") @@ -754,11 +759,11 @@ func (s PeerFlowE2ETestSuiteSF) Test_Simple_Schema_Changes_SF() { e2e.EnvTrue(s.t, env, e2e.CompareTableSchemas(expectedTableSchema, output.TableNameSchemaMapping[dstTableName])) // alter source table, add column c2 and insert another row. - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` ALTER TABLE %s ADD COLUMN c2 BIGINT`, srcTableName)) e2e.EnvNoError(s.t, env, err) s.t.Log("Altered source table, added column c2") - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1,c2) VALUES ($1,$2)`, srcTableName), 2, 2) e2e.EnvNoError(s.t, env, err) s.t.Log("Inserted row with added c2 in the source table") @@ -798,11 +803,11 @@ func (s PeerFlowE2ETestSuiteSF) Test_Simple_Schema_Changes_SF() { e2e.EnvEqualTables(env, s, "test_simple_schema_changes", "id,c1,c2") // alter source table, add column c3, drop column c2 and insert another row. - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` ALTER TABLE %s DROP COLUMN c2, ADD COLUMN c3 BIGINT`, srcTableName)) e2e.EnvNoError(s.t, env, err) s.t.Log("Altered source table, dropped column c2 and added column c3") - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1,c3) VALUES ($1,$2)`, srcTableName), 3, 3) e2e.EnvNoError(s.t, env, err) s.t.Log("Inserted row with added c3 in the source table") @@ -847,11 +852,11 @@ func (s PeerFlowE2ETestSuiteSF) Test_Simple_Schema_Changes_SF() { e2e.EnvEqualTables(env, s, "test_simple_schema_changes", "id,c1,c3") // alter source table, drop column c3 and insert another row. - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` ALTER TABLE %s DROP COLUMN c3`, srcTableName)) e2e.EnvNoError(s.t, env, err) s.t.Log("Altered source table, dropped column c3") - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1) VALUES ($1)`, srcTableName), 4) e2e.EnvNoError(s.t, env, err) s.t.Log("Inserted row after dropping all columns in the source table") @@ -908,7 +913,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Composite_PKey_SF() { srcTableName := s.attachSchemaSuffix("test_simple_cpkey") dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test_simple_cpkey") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT GENERATED ALWAYS AS IDENTITY, c1 INT GENERATED BY DEFAULT AS IDENTITY, @@ -936,7 +941,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Composite_PKey_SF() { // insert 10 rows into the source table for i := 0; i < 10; i++ { testValue := fmt.Sprintf("test_value_%d", i) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c2,t) VALUES ($1,$2) `, srcTableName), i, testValue) e2e.EnvNoError(s.t, env, err) @@ -945,10 +950,10 @@ func (s PeerFlowE2ETestSuiteSF) Test_Composite_PKey_SF() { e2e.EnvWaitForEqualTables(env, s, "normalize table", "test_simple_cpkey", "id,c1,c2,t") - _, err := s.conn.Exec(context.Background(), + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(`UPDATE %s SET c1=c1+1 WHERE MOD(c2,2)=$1`, srcTableName), 1) e2e.EnvNoError(s.t, env, err) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(`DELETE FROM %s WHERE MOD(c2,2)=$1`, srcTableName), 0) + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`DELETE FROM %s WHERE MOD(c2,2)=$1`, srcTableName), 0) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTables(env, s, "normalize update/delete", "test_simple_cpkey", "id,c1,c2,t") @@ -965,7 +970,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Composite_PKey_Toast_1_SF() { srcTableName := s.attachSchemaSuffix("test_cpkey_toast1") dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test_cpkey_toast1") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT GENERATED ALWAYS AS IDENTITY, c1 INT GENERATED BY DEFAULT AS IDENTITY, @@ -991,7 +996,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Composite_PKey_Toast_1_SF() { // and then insert, update and delete rows in the table. go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - rowsTx, err := s.conn.Begin(context.Background()) + rowsTx, err := s.Conn().Begin(context.Background()) e2e.EnvNoError(s.t, env, err) // insert 10 rows into the source table @@ -1028,7 +1033,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Composite_PKey_Toast_2_SF() { srcTableName := s.attachSchemaSuffix(tableName) dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, tableName) - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT GENERATED ALWAYS AS IDENTITY, c1 INT GENERATED BY DEFAULT AS IDENTITY, @@ -1058,7 +1063,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Composite_PKey_Toast_2_SF() { // insert 10 rows into the source table for i := 0; i < 10; i++ { testValue := fmt.Sprintf("test_value_%d", i) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c2,t,t2) VALUES ($1,$2,random_string(9000)) `, srcTableName), i, testValue) e2e.EnvNoError(s.t, env, err) @@ -1066,10 +1071,10 @@ func (s PeerFlowE2ETestSuiteSF) Test_Composite_PKey_Toast_2_SF() { s.t.Log("Inserted 10 rows into the source table") e2e.EnvWaitForEqualTables(env, s, "normalize table", tableName, "id,c2,t,t2") - _, err = s.conn.Exec(context.Background(), + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`UPDATE %s SET c1=c1+1 WHERE MOD(c2,2)=$1`, srcTableName), 1) e2e.EnvNoError(s.t, env, err) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(`DELETE FROM %s WHERE MOD(c2,2)=$1`, srcTableName), 0) + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`DELETE FROM %s WHERE MOD(c2,2)=$1`, srcTableName), 0) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTables(env, s, "normalize update/delete", tableName, "id,c2,t,t2") @@ -1087,7 +1092,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Column_Exclusion() { srcTableName := s.attachSchemaSuffix(tableName) dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, tableName) - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT GENERATED ALWAYS AS IDENTITY, c1 INT GENERATED BY DEFAULT AS IDENTITY, @@ -1127,7 +1132,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Column_Exclusion() { // insert 10 rows into the source table for i := 0; i < 10; i++ { testValue := fmt.Sprintf("test_value_%d", i) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c2,t,t2) VALUES ($1,$2,random_string(100)) `, srcTableName), i, testValue) e2e.EnvNoError(s.t, env, err) @@ -1135,10 +1140,10 @@ func (s PeerFlowE2ETestSuiteSF) Test_Column_Exclusion() { s.t.Log("Inserted 10 rows into the source table") e2e.EnvWaitForEqualTables(env, s, "normalize table", tableName, "id,c1,t,t2") - _, err = s.conn.Exec(context.Background(), + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`UPDATE %s SET c1=c1+1 WHERE MOD(c2,2)=1`, srcTableName)) e2e.EnvNoError(s.t, env, err) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(`DELETE FROM %s WHERE MOD(c2,2)=0`, srcTableName)) + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`DELETE FROM %s WHERE MOD(c2,2)=0`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTables(env, s, "normalize update/delete", tableName, "id,c1,t,t2") @@ -1165,7 +1170,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Soft_Delete_Basic() { srcTableName := s.attachSchemaSuffix(tableName) dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, dstName) - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, c1 INT, @@ -1201,15 +1206,15 @@ func (s PeerFlowE2ETestSuiteSF) Test_Soft_Delete_Basic() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1,c2,t) VALUES (1,2,random_string(9000))`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTablesWithNames(env, s, "normalize row", tableName, dstName, "id,c1,c2,t") - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` UPDATE %s SET c1=c1+4 WHERE id=1`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTablesWithNames(env, s, "normalize update", tableName, dstName, "id,c1,c2,t") - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` DELETE FROM %s WHERE id=1`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTablesWithNames( @@ -1241,7 +1246,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Soft_Delete_IUD_Same_Batch() { srcTableName := fmt.Sprintf("%s_src", cmpTableName) dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test_softdel_iud") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, c1 INT, @@ -1277,7 +1282,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Soft_Delete_IUD_Same_Batch() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - insertTx, err := s.conn.Begin(context.Background()) + insertTx, err := s.Conn().Begin(context.Background()) e2e.EnvNoError(s.t, env, err) _, err = insertTx.Exec(context.Background(), fmt.Sprintf(` @@ -1320,7 +1325,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Soft_Delete_UD_Same_Batch() { srcTableName := s.attachSchemaSuffix(tableName) dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, dstName) - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, c1 INT, @@ -1356,12 +1361,12 @@ func (s PeerFlowE2ETestSuiteSF) Test_Soft_Delete_UD_Same_Batch() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1,c2,t) VALUES (1,2,random_string(9000))`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTablesWithNames(env, s, "normalize insert", tableName, dstName, "id,c1,c2,t") - insertTx, err := s.conn.Begin(context.Background()) + insertTx, err := s.Conn().Begin(context.Background()) e2e.EnvNoError(s.t, env, err) _, err = insertTx.Exec(context.Background(), fmt.Sprintf(` UPDATE %s SET t=random_string(10000) WHERE id=1`, srcTableName)) @@ -1404,7 +1409,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Soft_Delete_Insert_After_Delete() { srcTableName := s.attachSchemaSuffix(tableName) dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, tableName) - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, c1 INT, @@ -1440,11 +1445,11 @@ func (s PeerFlowE2ETestSuiteSF) Test_Soft_Delete_Insert_After_Delete() { go func() { e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(c1,c2,t) VALUES (1,2,random_string(9000))`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTables(env, s, "normalize row", tableName, "id,c1,c2,t") - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` DELETE FROM %s WHERE id=1`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTablesWithNames( @@ -1456,7 +1461,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Soft_Delete_Insert_After_Delete() { "id,c1,c2,t", ) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s(id,c1,c2,t) VALUES (1,3,4,random_string(10000))`, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitForEqualTables(env, s, "normalize reinsert", tableName, "id,c1,c2,t") @@ -1480,7 +1485,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Supported_Mixed_Case_Table_SF() { srcTableName := s.attachSchemaSuffix("testMixedCase") dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "testMixedCase") - _, err := s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS e2e_test_%s."%s" ( "pulseArmor" SERIAL PRIMARY KEY, "highGold" TEXT NOT NULL, @@ -1508,7 +1513,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Supported_Mixed_Case_Table_SF() { for i := 0; i < 20; i++ { testKey := fmt.Sprintf("test_key_%d", i) testValue := fmt.Sprintf("test_value_%d", i) - _, err = s.conn.Exec(context.Background(), fmt.Sprintf(` + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO e2e_test_%s."%s"("highGold","eVe") VALUES ($1, $2) `, s.pgSuffix, "testMixedCase"), testKey, testValue) e2e.EnvNoError(s.t, env, err) diff --git a/flow/e2e/snowflake/qrep_flow_sf_test.go b/flow/e2e/snowflake/qrep_flow_sf_test.go index 9541d61c86..a662513e6a 100644 --- a/flow/e2e/snowflake/qrep_flow_sf_test.go +++ b/flow/e2e/snowflake/qrep_flow_sf_test.go @@ -12,9 +12,9 @@ import ( //nolint:unparam func (s PeerFlowE2ETestSuiteSF) setupSourceTable(tableName string, numRows int) { - err := e2e.CreateTableForQRep(s.conn, s.pgSuffix, tableName) + err := e2e.CreateTableForQRep(s.Conn(), s.pgSuffix, tableName) require.NoError(s.t, err) - err = e2e.PopulateSourceTable(s.conn, s.pgSuffix, tableName, numRows) + err = e2e.PopulateSourceTable(s.Conn(), s.pgSuffix, tableName, numRows) require.NoError(s.t, err) } diff --git a/flow/e2e/sqlserver/qrep_flow_sqlserver_test.go b/flow/e2e/sqlserver/qrep_flow_sqlserver_test.go index e94641ab92..87d4285085 100644 --- a/flow/e2e/sqlserver/qrep_flow_sqlserver_test.go +++ b/flow/e2e/sqlserver/qrep_flow_sqlserver_test.go @@ -14,6 +14,7 @@ import ( "github.com/joho/godotenv" "github.com/stretchr/testify/require" + "github.com/PeerDB-io/peer-flow/connectors/postgres" "github.com/PeerDB-io/peer-flow/e2e" "github.com/PeerDB-io/peer-flow/e2eshared" "github.com/PeerDB-io/peer-flow/generated/protos" @@ -25,7 +26,7 @@ import ( type PeerFlowE2ETestSuiteSQLServer struct { t *testing.T - conn *pgx.Conn + conn *connpostgres.PostgresConnector sqlsHelper *SQLServerHelper suffix string } @@ -35,6 +36,10 @@ func (s PeerFlowE2ETestSuiteSQLServer) T() *testing.T { } func (s PeerFlowE2ETestSuiteSQLServer) Conn() *pgx.Conn { + return s.conn.Conn() +} + +func (s PeerFlowE2ETestSuiteSQLServer) Connector() *connpostgres.PostgresConnector { return s.conn } @@ -115,10 +120,10 @@ func (s PeerFlowE2ETestSuiteSQLServer) insertRowsIntoSQLServerTable(tableName st func (s PeerFlowE2ETestSuiteSQLServer) setupPGDestinationTable(tableName string) { ctx := context.Background() - _, err := s.conn.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS e2e_test_%s.%s", s.suffix, tableName)) + _, err := s.Conn().Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS e2e_test_%s.%s", s.suffix, tableName)) require.NoError(s.t, err) - _, err = s.conn.Exec(ctx, + _, err = s.Conn().Exec(ctx, fmt.Sprintf("CREATE TABLE e2e_test_%s.%s (id TEXT, card_id TEXT, v_from TIMESTAMP, price NUMERIC, status INT)", s.suffix, tableName)) require.NoError(s.t, err) @@ -183,7 +188,7 @@ func (s PeerFlowE2ETestSuiteSQLServer) Test_Complete_QRep_Flow_SqlServer_Append( // Verify that the destination table has the same number of rows as the source table var numRowsInDest pgtype.Int8 countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", dstTableName) - err = s.conn.QueryRow(context.Background(), countQuery).Scan(&numRowsInDest) + err = s.Conn().QueryRow(context.Background(), countQuery).Scan(&numRowsInDest) require.NoError(s.t, err) require.Equal(s.t, numRows, int(numRowsInDest.Int64)) diff --git a/flow/e2e/test_utils.go b/flow/e2e/test_utils.go index 74a7ddd0f8..efcee75636 100644 --- a/flow/e2e/test_utils.go +++ b/flow/e2e/test_utils.go @@ -35,6 +35,17 @@ import ( peerflow "github.com/PeerDB-io/peer-flow/workflows" ) +type Suite interface { + T() *testing.T + Connector() *connpostgres.PostgresConnector + Suffix() string +} + +type RowSource interface { + Suite + GetRows(table, cols string) (*model.QRecordBatch, error) +} + func RegisterWorkflowsAndActivities(t *testing.T, env *testsuite.TestWorkflowEnvironment) { t.Helper() @@ -93,21 +104,21 @@ func EnvTrue(t *testing.T, env *testsuite.TestWorkflowEnvironment, val bool) { } } -func GetPgRows(conn *pgx.Conn, suffix string, table string, cols string) (*model.QRecordBatch, error) { - pgQueryExecutor := connpostgres.NewQRepQueryExecutor(conn, context.Background(), "testflow", "testpart") +func GetPgRows(conn *connpostgres.PostgresConnector, suffix string, table string, cols string) (*model.QRecordBatch, error) { + pgQueryExecutor := conn.NewQRepQueryExecutor("testflow", "testpart") pgQueryExecutor.SetTestEnv(true) return pgQueryExecutor.ExecuteAndProcessQuery( context.Background(), - fmt.Sprintf(`SELECT %s FROM e2e_test_%s."%s" ORDER BY id`, cols, suffix, table), + fmt.Sprintf(`SELECT %s FROM e2e_test_%s.%s ORDER BY id`, cols, suffix, connpostgres.QuoteIdentifier(table)), ) } -func RequireEqualTables(suite e2eshared.RowSource, table string, cols string) { +func RequireEqualTables(suite RowSource, table string, cols string) { t := suite.T() t.Helper() - pgRows, err := GetPgRows(suite.Conn(), suite.Suffix(), table, cols) + pgRows, err := GetPgRows(suite.Connector(), suite.Suffix(), table, cols) require.NoError(t, err) rows, err := suite.GetRows(table, cols) @@ -116,11 +127,11 @@ func RequireEqualTables(suite e2eshared.RowSource, table string, cols string) { require.True(t, e2eshared.CheckEqualRecordBatches(t, pgRows, rows)) } -func EnvEqualTables(env *testsuite.TestWorkflowEnvironment, suite e2eshared.RowSource, table string, cols string) { +func EnvEqualTables(env *testsuite.TestWorkflowEnvironment, suite RowSource, table string, cols string) { t := suite.T() t.Helper() - pgRows, err := GetPgRows(suite.Conn(), suite.Suffix(), table, cols) + pgRows, err := GetPgRows(suite.Connector(), suite.Suffix(), table, cols) EnvNoError(t, env, err) rows, err := suite.GetRows(table, cols) @@ -131,7 +142,7 @@ func EnvEqualTables(env *testsuite.TestWorkflowEnvironment, suite e2eshared.RowS func EnvWaitForEqualTables( env *testsuite.TestWorkflowEnvironment, - suite e2eshared.RowSource, + suite RowSource, reason string, table string, cols string, @@ -142,7 +153,7 @@ func EnvWaitForEqualTables( func EnvWaitForEqualTablesWithNames( env *testsuite.TestWorkflowEnvironment, - suite e2eshared.RowSource, + suite RowSource, reason string, srcTable string, dstTable string, @@ -154,7 +165,7 @@ func EnvWaitForEqualTablesWithNames( EnvWaitFor(t, env, 3*time.Minute, reason, func() bool { t.Helper() - pgRows, err := GetPgRows(suite.Conn(), suite.Suffix(), srcTable, cols) + pgRows, err := GetPgRows(suite.Connector(), suite.Suffix(), srcTable, cols) if err != nil { return false } @@ -506,7 +517,7 @@ func GetOwnersSelectorStringsSF() [2]string { sfFields := make([]string, 0, len(schema.Fields)) for _, field := range schema.Fields { pgFields = append(pgFields, fmt.Sprintf(`"%s"`, field.Name)) - if strings.Contains(field.Name, "geo") { + if strings.HasPrefix(field.Name, "geo") { colName := connsnowflake.SnowflakeIdentifierNormalize(field.Name) // Have to apply a WKT transformation here, diff --git a/flow/e2eshared/e2eshared.go b/flow/e2eshared/e2eshared.go index ce5c25d4fe..721302551e 100644 --- a/flow/e2eshared/e2eshared.go +++ b/flow/e2eshared/e2eshared.go @@ -8,23 +8,10 @@ import ( "strings" "testing" - "github.com/jackc/pgx/v5" - "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" ) -type Suite interface { - T() *testing.T - Conn() *pgx.Conn - Suffix() string -} - -type RowSource interface { - Suite - GetRows(table, cols string) (*model.QRecordBatch, error) -} - func RunSuite[T any](t *testing.T, setup func(t *testing.T) T, teardown func(T)) { t.Helper() t.Parallel() @@ -80,7 +67,7 @@ func CheckQRecordEquality(t *testing.T, q []qvalue.QValue, other []qvalue.QValue for i, entry := range q { otherEntry := other[i] if !entry.Equals(otherEntry) { - t.Logf("entry %d: %v != %v", i, entry, otherEntry) + t.Logf("entry %d: %T %v != %T %v", i, entry.Value, entry, otherEntry.Value, otherEntry) return false } } diff --git a/flow/geo/geo.go b/flow/geo/geo.go index 882ed97dba..9640173973 100644 --- a/flow/geo/geo.go +++ b/flow/geo/geo.go @@ -43,18 +43,3 @@ func GeoToWKB(wkt string) ([]byte, error) { return geometryObject.ToWKB(), nil } - -// compares WKTs -func GeoCompare(wkt1, wkt2 string) bool { - geom1, geoErr := geom.NewGeomFromWKT(wkt1) - if geoErr != nil { - return false - } - - geom2, geoErr := geom.NewGeomFromWKT(wkt2) - if geoErr != nil { - return false - } - - return geom1.Equals(geom2) -} diff --git a/flow/model/qvalue/qvalue.go b/flow/model/qvalue/qvalue.go index b0b556b3d7..1ad07150b6 100644 --- a/flow/model/qvalue/qvalue.go +++ b/flow/model/qvalue/qvalue.go @@ -2,6 +2,7 @@ package qvalue import ( "bytes" + "encoding/json" "fmt" "math" "math/big" @@ -11,8 +12,9 @@ import ( "cloud.google.com/go/civil" "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" + geom "github.com/twpayne/go-geos" - "github.com/PeerDB-io/peer-flow/geo" hstore_util "github.com/PeerDB-io/peer-flow/hstore" ) @@ -23,6 +25,12 @@ type QValue struct { } func (q QValue) Equals(other QValue) bool { + if q.Kind == QValueKindJSON { + return true // TODO fix + } else if q.Value == nil && other.Value == nil { + return true + } + switch q.Kind { case QValueKindEmpty: return other.Kind == QValueKindEmpty @@ -66,6 +74,10 @@ func (q QValue) Equals(other QValue) bool { return compareJSON(q.Value, other.Value) case QValueKindBit: return compareBit(q.Value, other.Value) + case QValueKindGeometry, QValueKindGeography: + return compareGeometry(q.Value, other.Value) + case QValueKindHStore: + return compareHstore(q.Value, other.Value) case QValueKindArrayFloat32: return compareNumericArrays(q.Value, other.Value) case QValueKindArrayFloat64: @@ -82,9 +94,9 @@ func (q QValue) Equals(other QValue) bool { return compareBoolArrays(q.Value, other.Value) case QValueKindArrayString: return compareArrayString(q.Value, other.Value) + default: + return false } - - return false } func (q QValue) GoTimeConvert() (string, error) { @@ -222,10 +234,6 @@ func compareBytes(value1, value2 interface{}) bool { } func compareNumeric(value1, value2 interface{}) bool { - if value1 == nil && value2 == nil { - return true - } - rat1, ok1 := getRat(value1) rat2, ok2 := getRat(value2) @@ -250,20 +258,47 @@ func compareString(value1, value2 interface{}) bool { if !ok1 || !ok2 { return false } - if str1 == str2 { - return true - } + return str1 == str2 +} - // Catch matching HStore - parsedHstore1, err := hstore_util.ParseHstore(str1) - if err == nil && parsedHstore1 == str2 { - return true +func compareHstore(value1, value2 interface{}) bool { + str2 := value2.(string) + switch v1 := value1.(type) { + case pgtype.Hstore: + bytes, err := json.Marshal(v1) + if err != nil { + panic(err) + } + return string(bytes) == str2 + case string: + parsedHStore1, err := hstore_util.ParseHstore(v1) + if err != nil { + panic(err) + } + return parsedHStore1 == str2 + default: + panic(fmt.Sprintf("invalid hstore value type %T: %v", value1, value1)) } +} - // Catch matching WKB(in Postgres)-WKT(in destination) geo values - geoConvertedWKT, err := geo.GeoValidate(str1) +func compareGeometry(value1, value2 interface{}) bool { + geo2, err := geom.NewGeomFromWKT(value2.(string)) + if err != nil { + panic(err) + } - return err == nil && geo.GeoCompare(geoConvertedWKT, str2) + switch v1 := value1.(type) { + case *geom.Geom: + return v1.Equals(geo2) + case string: + geo1, err := geom.NewGeomFromWKT(v1) + if err != nil { + panic(err) + } + return geo1.Equals(geo2) + default: + panic(fmt.Sprintf("invalid geometry value type %T: %v", value1, value1)) + } } func compareStruct(value1, value2 interface{}) bool { @@ -299,7 +334,7 @@ func compareBit(value1, value2 interface{}) bool { return false } - return bit1^bit2 == 0 + return bit1 == bit2 } func compareNumericArrays(value1, value2 interface{}) bool { @@ -546,8 +581,7 @@ func getBytes(v interface{}) ([]byte, bool) { case string: return []byte(value), true case nil: - // return empty byte array - return []byte{}, true + return nil, true default: return nil, false }