From 84aaef251cdc1752e2af9cd4b786feddd3335e84 Mon Sep 17 00:00:00 2001 From: Amogh Bharadwaj Date: Thu, 18 Jan 2024 01:28:56 +0530 Subject: [PATCH] Support more data types (#1089) This PR adds support for the following types across PG, SF and BQ: - Array of Boolean - Array of Date - Array of Timestamp(/TZ) - Array of Int16 (smallint) - CIDR - MacAddr - INET The latter 3 is more relevant for Postgres. This PR achieves the correct AVRO type mapping for Date in BigQuery - eliminating the hack which was being done in merge transforms. This PR also writes date and time values as human readable text for BQ QRep Tests added as well Fixes #666 Fixes #20 --- .../bigquery/avro_transform_test.go | 5 - .../bigquery/merge_stmt_generator.go | 6 +- flow/connectors/bigquery/qrep_avro_sync.go | 96 +++++---- flow/connectors/bigquery/qvalue_convert.go | 8 +- flow/connectors/postgres/cdc.go | 4 +- flow/connectors/postgres/qvalue_convert.go | 148 +++++++++++++- flow/connectors/sql/query_executor.go | 19 +- flow/e2e/bigquery/bigquery_helper.go | 26 ++- flow/e2e/bigquery/peer_flow_bq_test.go | 13 +- flow/e2e/postgres/peer_flow_pg_test.go | 75 +++++++ flow/e2e/snowflake/peer_flow_sf_test.go | 12 +- flow/e2e/test_utils.go | 15 +- flow/model/model.go | 26 +++ flow/model/qrecord_batch.go | 92 +++++---- flow/model/qvalue/avro_converter.go | 190 +++++++++++++++++- flow/model/qvalue/kind.go | 35 +++- flow/model/qvalue/qvalue.go | 81 +++++++- 17 files changed, 730 insertions(+), 121 deletions(-) diff --git a/flow/connectors/bigquery/avro_transform_test.go b/flow/connectors/bigquery/avro_transform_test.go index 0a9332fc87..75dc7b65f1 100644 --- a/flow/connectors/bigquery/avro_transform_test.go +++ b/flow/connectors/bigquery/avro_transform_test.go @@ -17,10 +17,6 @@ func TestAvroTransform(t *testing.T) { Name: "col2", Type: bigquery.JSONFieldType, }, - &bigquery.FieldSchema{ - Name: "col3", - Type: bigquery.DateFieldType, - }, &bigquery.FieldSchema{ Name: "camelCol4", Type: bigquery.StringFieldType, @@ -34,7 +30,6 @@ func TestAvroTransform(t *testing.T) { expectedTransformCols := []string{ "ST_GEOGFROMTEXT(`col1`) AS `col1`", "PARSE_JSON(`col2`,wide_number_mode=>'round') AS `col2`", - "CAST(`col3` AS DATE) AS `col3`", "`camelCol4`", } transformedCols := getTransformedColumns(dstSchema, "sync_col", "del_col") diff --git a/flow/connectors/bigquery/merge_stmt_generator.go b/flow/connectors/bigquery/merge_stmt_generator.go index eb861b4570..e93a139a73 100644 --- a/flow/connectors/bigquery/merge_stmt_generator.go +++ b/flow/connectors/bigquery/merge_stmt_generator.go @@ -53,8 +53,10 @@ func (m *mergeStmtGenerator) generateFlattenedCTE() string { case qvalue.QValueKindBytes, qvalue.QValueKindBit: castStmt = fmt.Sprintf("FROM_BASE64(JSON_VALUE(_peerdb_data,'$.%s')) AS `%s`", colName, shortCol) - case qvalue.QValueKindArrayFloat32, qvalue.QValueKindArrayFloat64, - qvalue.QValueKindArrayInt32, qvalue.QValueKindArrayInt64, qvalue.QValueKindArrayString: + case qvalue.QValueKindArrayFloat32, qvalue.QValueKindArrayFloat64, qvalue.QValueKindArrayInt16, + qvalue.QValueKindArrayInt32, qvalue.QValueKindArrayInt64, qvalue.QValueKindArrayString, + qvalue.QValueKindArrayBoolean, qvalue.QValueKindArrayTimestamp, qvalue.QValueKindArrayTimestampTZ, + qvalue.QValueKindArrayDate: castStmt = fmt.Sprintf("ARRAY(SELECT CAST(element AS %s) FROM "+ "UNNEST(CAST(JSON_VALUE_ARRAY(_peerdb_data, '$.%s') AS ARRAY)) AS element WHERE element IS NOT null) AS `%s`", bqType, colName, shortCol) diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go index bf784dd494..ac609002d8 100644 --- a/flow/connectors/bigquery/qrep_avro_sync.go +++ b/flow/connectors/bigquery/qrep_avro_sync.go @@ -136,9 +136,6 @@ func getTransformedColumns(dstSchema *bigquery.Schema, syncedAtCol string, softD case bigquery.JSONFieldType: transformedColumns = append(transformedColumns, fmt.Sprintf("PARSE_JSON(`%s`,wide_number_mode=>'round') AS `%s`", col.Name, col.Name)) - case bigquery.DateFieldType: - transformedColumns = append(transformedColumns, - fmt.Sprintf("CAST(`%s` AS DATE) AS `%s`", col.Name, col.Name)) default: transformedColumns = append(transformedColumns, fmt.Sprintf("`%s`", col.Name)) } @@ -290,9 +287,9 @@ func DefineAvroSchema(dstTableName string, func GetAvroType(bqField *bigquery.FieldSchema) (interface{}, error) { considerRepeated := func(typ string, repeated bool) interface{} { if repeated { - return map[string]interface{}{ - "type": "array", - "items": typ, + return qvalue.AvroSchemaArray{ + Type: "array", + Items: typ, } } else { return typ @@ -309,64 +306,79 @@ func GetAvroType(bqField *bigquery.FieldSchema) (interface{}, error) { case bigquery.FloatFieldType: return considerRepeated("double", bqField.Repeated), nil case bigquery.BooleanFieldType: - return "boolean", nil + return considerRepeated("boolean", bqField.Repeated), nil case bigquery.TimestampFieldType: - return map[string]string{ - "type": "long", - "logicalType": "timestamp-micros", - }, nil + timestampSchema := qvalue.AvroSchemaField{ + Type: "long", + LogicalType: "timestamp-micros", + } + if bqField.Repeated { + return qvalue.AvroSchemaComplexArray{ + Type: "array", + Items: timestampSchema, + }, nil + } + return timestampSchema, nil case bigquery.DateFieldType: - return map[string]string{ - "type": "long", - "logicalType": "timestamp-micros", - }, nil + dateSchema := qvalue.AvroSchemaField{ + Type: "int", + LogicalType: "date", + } + if bqField.Repeated { + return qvalue.AvroSchemaComplexArray{ + Type: "array", + Items: dateSchema, + }, nil + } + return dateSchema, nil + case bigquery.TimeFieldType: - return map[string]string{ - "type": "long", - "logicalType": "timestamp-micros", + return qvalue.AvroSchemaField{ + Type: "long", + LogicalType: "timestamp-micros", }, nil case bigquery.DateTimeFieldType: - return map[string]interface{}{ - "type": "record", - "name": "datetime", - "fields": []map[string]string{ + return qvalue.AvroSchemaRecord{ + Type: "record", + Name: "datetime", + Fields: []qvalue.AvroSchemaField{ { - "name": "date", - "type": "int", - "logicalType": "date", + Name: "date", + Type: "int", + LogicalType: "date", }, { - "name": "time", - "type": "long", - "logicalType": "time-micros", + Name: "time", + Type: "long", + LogicalType: "time-micros", }, }, }, nil case bigquery.NumericFieldType: - return map[string]interface{}{ - "type": "bytes", - "logicalType": "decimal", - "precision": 38, - "scale": 9, + return qvalue.AvroSchemaNumeric{ + Type: "bytes", + LogicalType: "decimal", + Precision: 38, + Scale: 9, }, nil case bigquery.RecordFieldType: - avroFields := []map[string]interface{}{} + avroFields := []qvalue.AvroSchemaField{} for _, bqSubField := range bqField.Schema { avroType, err := GetAvroType(bqSubField) if err != nil { return nil, err } - avroFields = append(avroFields, map[string]interface{}{ - "name": bqSubField.Name, - "type": avroType, + avroFields = append(avroFields, qvalue.AvroSchemaField{ + Name: bqSubField.Name, + Type: avroType, }) } - return map[string]interface{}{ - "type": "record", - "name": bqField.Name, - "fields": avroFields, + return qvalue.AvroSchemaRecord{ + Type: "record", + Name: bqField.Name, + Fields: avroFields, }, nil - // TODO(kaushik/sai): Add other field types as needed + default: return nil, fmt.Errorf("unsupported BigQuery field type: %s", bqField.Type) } diff --git a/flow/connectors/bigquery/qvalue_convert.go b/flow/connectors/bigquery/qvalue_convert.go index 7e98eabd15..d4e5032182 100644 --- a/flow/connectors/bigquery/qvalue_convert.go +++ b/flow/connectors/bigquery/qvalue_convert.go @@ -42,10 +42,16 @@ func qValueKindToBigQueryType(colType string) bigquery.FieldType { // For Arrays we return the types of the individual elements, // and wherever this function is called, the 'Repeated' attribute of // FieldSchema must be set to true. - case qvalue.QValueKindArrayInt32, qvalue.QValueKindArrayInt64: + case qvalue.QValueKindArrayInt16, qvalue.QValueKindArrayInt32, qvalue.QValueKindArrayInt64: return bigquery.IntegerFieldType case qvalue.QValueKindArrayFloat32, qvalue.QValueKindArrayFloat64: return bigquery.FloatFieldType + case qvalue.QValueKindArrayBoolean: + return bigquery.BooleanFieldType + case qvalue.QValueKindArrayTimestamp, qvalue.QValueKindArrayTimestampTZ: + return bigquery.TimestampFieldType + case qvalue.QValueKindArrayDate: + return bigquery.DateFieldType case qvalue.QValueKindGeography, qvalue.QValueKindGeometry, qvalue.QValueKindPoint: return bigquery.GeographyFieldType // rest will be strings diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index 4fd1b3dd79..42baaf3199 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -741,8 +741,8 @@ func (p *PostgresCDCSource) decodeColumnData(data []byte, dataType uint32, forma var parsedData any var err error if dt, ok := p.typeMap.TypeForOID(dataType); ok { - if dt.Name == "uuid" { - // below is required to decode uuid to string + if dt.Name == "uuid" || dt.Name == "cidr" || dt.Name == "inet" || dt.Name == "macaddr" { + // below is required to decode above types to string parsedData, err = dt.Codec.DecodeDatabaseSQLValue(p.typeMap, dataType, pgtype.TextFormatCode, data) } else { parsedData, err = dt.Codec.DecodeValue(p.typeMap, dataType, formatCode, data) diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go index 83feea3bef..d357495810 100644 --- a/flow/connectors/postgres/qvalue_convert.go +++ b/flow/connectors/postgres/qvalue_convert.go @@ -41,6 +41,12 @@ func postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind { return qvalue.QValueKindTime case pgtype.DateOID: return qvalue.QValueKindDate + case pgtype.CIDROID: + return qvalue.QValueKindCIDR + case pgtype.MacaddrOID: + return qvalue.QValueKindMacaddr + case pgtype.InetOID: + return qvalue.QValueKindINET case pgtype.TimestampOID: return qvalue.QValueKindTimestamp case pgtype.TimestamptzOID: @@ -50,7 +56,7 @@ func postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind { case pgtype.BitOID, pgtype.VarbitOID: return qvalue.QValueKindBit case pgtype.Int2ArrayOID: - return qvalue.QValueKindArrayInt32 + return qvalue.QValueKindArrayInt16 case pgtype.Int4ArrayOID: return qvalue.QValueKindArrayInt32 case pgtype.Int8ArrayOID: @@ -61,6 +67,14 @@ func postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind { return qvalue.QValueKindArrayFloat32 case pgtype.Float8ArrayOID: return qvalue.QValueKindArrayFloat64 + case pgtype.BoolArrayOID: + return qvalue.QValueKindArrayBoolean + case pgtype.DateArrayOID: + return qvalue.QValueKindArrayDate + case pgtype.TimestampArrayOID: + return qvalue.QValueKindArrayTimestamp + case pgtype.TimestamptzArrayOID: + return qvalue.QValueKindArrayTimestampTZ case pgtype.TextArrayOID, pgtype.VarcharArrayOID, pgtype.BPCharArrayOID: return qvalue.QValueKindArrayString default: @@ -110,13 +124,15 @@ func qValueKindToPostgresType(qvalueKind string) string { case qvalue.QValueKindBytes: return "BYTEA" case qvalue.QValueKindJSON: - return "JSONB" + return "JSON" case qvalue.QValueKindHStore: return "HSTORE" case qvalue.QValueKindUUID: return "UUID" case qvalue.QValueKindTime: return "TIME" + case qvalue.QValueKindTimeTZ: + return "TIMETZ" case qvalue.QValueKindDate: return "DATE" case qvalue.QValueKindTimestamp: @@ -127,6 +143,14 @@ func qValueKindToPostgresType(qvalueKind string) string { return "NUMERIC" case qvalue.QValueKindBit: return "BIT" + case qvalue.QValueKindINET: + return "INET" + case qvalue.QValueKindCIDR: + return "CIDR" + case qvalue.QValueKindMacaddr: + return "MACADDR" + case qvalue.QValueKindArrayInt16: + return "SMALLINT[]" case qvalue.QValueKindArrayInt32: return "INTEGER[]" case qvalue.QValueKindArrayInt64: @@ -135,6 +159,14 @@ func qValueKindToPostgresType(qvalueKind string) string { return "REAL[]" case qvalue.QValueKindArrayFloat64: return "DOUBLE PRECISION[]" + case qvalue.QValueKindArrayDate: + return "DATE[]" + case qvalue.QValueKindArrayTimestamp: + return "TIMESTAMP[]" + case qvalue.QValueKindArrayTimestampTZ: + return "TIMESTAMPTZ[]" + case qvalue.QValueKindArrayBoolean: + return "BOOLEAN[]" case qvalue.QValueKindArrayString: return "TEXT[]" case qvalue.QValueKindGeography: @@ -241,6 +273,33 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( default: return qvalue.QValue{}, fmt.Errorf("failed to parse UUID: %v", value) } + case qvalue.QValueKindINET: + switch value.(type) { + case string: + val = qvalue.QValue{Kind: qvalue.QValueKindINET, Value: value} + case [16]byte: + val = qvalue.QValue{Kind: qvalue.QValueKindINET, Value: value} + default: + return qvalue.QValue{}, fmt.Errorf("failed to parse INET: %v", value) + } + case qvalue.QValueKindCIDR: + switch value.(type) { + case string: + val = qvalue.QValue{Kind: qvalue.QValueKindCIDR, Value: value} + case [16]byte: + val = qvalue.QValue{Kind: qvalue.QValueKindCIDR, Value: value} + default: + return qvalue.QValue{}, fmt.Errorf("failed to parse CIDR: %v", value) + } + case qvalue.QValueKindMacaddr: + switch value.(type) { + case string: + val = qvalue.QValue{Kind: qvalue.QValueKindMacaddr, Value: value} + case [16]byte: + val = qvalue.QValue{Kind: qvalue.QValueKindMacaddr, Value: value} + default: + return qvalue.QValue{}, fmt.Errorf("failed to parse MACADDR: %v", value) + } case qvalue.QValueKindBytes: rawBytes := value.([]byte) val = qvalue.QValue{Kind: qvalue.QValueKindBytes, Value: rawBytes} @@ -292,6 +351,23 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( default: return qvalue.QValue{}, fmt.Errorf("failed to parse array float64: %v", value) } + case qvalue.QValueKindArrayInt16: + switch v := value.(type) { + case pgtype.Array[int16]: + if v.Valid { + val = qvalue.QValue{Kind: qvalue.QValueKindArrayInt16, Value: v.Elements} + } + case []int16: + val = qvalue.QValue{Kind: qvalue.QValueKindArrayInt16, Value: v} + case []interface{}: + int16Array := make([]int16, len(v)) + for i, val := range v { + int16Array[i] = val.(int16) + } + val = qvalue.QValue{Kind: qvalue.QValueKindArrayInt16, Value: int16Array} + default: + return qvalue.QValue{}, fmt.Errorf("failed to parse array int16: %v", value) + } case qvalue.QValueKindArrayInt32: switch v := value.(type) { case pgtype.Array[int32]: @@ -326,6 +402,74 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( default: return qvalue.QValue{}, fmt.Errorf("failed to parse array int64: %v", value) } + case qvalue.QValueKindArrayDate: + switch v := value.(type) { + case pgtype.Array[time.Time]: + if v.Valid { + val = qvalue.QValue{Kind: qvalue.QValueKindArrayDate, Value: v.Elements} + } + case []time.Time: + val = qvalue.QValue{Kind: qvalue.QValueKindArrayDate, Value: v} + case []interface{}: + dateArray := make([]time.Time, len(v)) + for i, val := range v { + dateArray[i] = val.(time.Time) + } + val = qvalue.QValue{Kind: qvalue.QValueKindArrayDate, Value: dateArray} + default: + return qvalue.QValue{}, fmt.Errorf("failed to parse array date: %v", value) + } + case qvalue.QValueKindArrayTimestamp: + switch v := value.(type) { + case pgtype.Array[time.Time]: + if v.Valid { + val = qvalue.QValue{Kind: qvalue.QValueKindArrayTimestamp, Value: v.Elements} + } + case []time.Time: + val = qvalue.QValue{Kind: qvalue.QValueKindArrayTimestamp, Value: v} + case []interface{}: + timestampArray := make([]time.Time, len(v)) + for i, val := range v { + timestampArray[i] = val.(time.Time) + } + val = qvalue.QValue{Kind: qvalue.QValueKindArrayTimestamp, Value: timestampArray} + default: + return qvalue.QValue{}, fmt.Errorf("failed to parse array timestamp: %v", value) + } + case qvalue.QValueKindArrayTimestampTZ: + switch v := value.(type) { + case pgtype.Array[time.Time]: + if v.Valid { + val = qvalue.QValue{Kind: qvalue.QValueKindArrayTimestampTZ, Value: v.Elements} + } + case []time.Time: + val = qvalue.QValue{Kind: qvalue.QValueKindArrayTimestampTZ, Value: v} + case []interface{}: + timestampTZArray := make([]time.Time, len(v)) + for i, val := range v { + timestampTZArray[i] = val.(time.Time) + } + val = qvalue.QValue{Kind: qvalue.QValueKindArrayTimestampTZ, Value: timestampTZArray} + default: + return qvalue.QValue{}, fmt.Errorf("failed to parse array timestamptz: %v", value) + } + case qvalue.QValueKindArrayBoolean: + switch v := value.(type) { + case pgtype.Array[bool]: + if v.Valid { + val = qvalue.QValue{Kind: qvalue.QValueKindArrayBoolean, Value: v.Elements} + } + case []bool: + val = qvalue.QValue{Kind: qvalue.QValueKindArrayBoolean, Value: v} + case []interface{}: + boolArray := make([]bool, len(v)) + for i, val := range v { + boolArray[i] = val.(bool) + } + val = qvalue.QValue{Kind: qvalue.QValueKindArrayBoolean, Value: boolArray} + default: + return qvalue.QValue{}, fmt.Errorf("failed to parse array boolean: %v", value) + } case qvalue.QValueKindArrayString: switch v := value.(type) { case pgtype.Array[string]: diff --git a/flow/connectors/sql/query_executor.go b/flow/connectors/sql/query_executor.go index 0e5c0a08c1..751230a643 100644 --- a/flow/connectors/sql/query_executor.go +++ b/flow/connectors/sql/query_executor.go @@ -432,6 +432,8 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) { kind = qvalue.QValueKindArrayFloat32 case float64: kind = qvalue.QValueKindArrayFloat64 + case int16: + kind = qvalue.QValueKindArrayInt16 case int32: kind = qvalue.QValueKindArrayInt32 case int64: @@ -447,13 +449,12 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) { return qvalue.QValue{Kind: qvalue.QValueKindJSON, Value: vstring}, nil case qvalue.QValueKindHStore: - // TODO fix this. return qvalue.QValue{Kind: qvalue.QValueKindHStore, Value: val}, nil case qvalue.QValueKindArrayFloat32, qvalue.QValueKindArrayFloat64, + qvalue.QValueKindArrayInt16, qvalue.QValueKindArrayInt32, qvalue.QValueKindArrayInt64, qvalue.QValueKindArrayString: - // TODO fix this. return toQValueArray(kind, val) } @@ -492,6 +493,20 @@ func toQValueArray(kind qvalue.QValueKind, value interface{}) (qvalue.QValue, er return qvalue.QValue{}, fmt.Errorf("failed to parse array float64: %v", value) } + case qvalue.QValueKindArrayInt16: + switch v := value.(type) { + case []int16: + result = v + case []interface{}: + int16Array := make([]int16, len(v)) + for i, val := range v { + int16Array[i] = val.(int16) + } + result = int16Array + default: + return qvalue.QValue{}, fmt.Errorf("failed to parse array int16: %v", value) + } + case qvalue.QValueKindArrayInt32: switch v := value.(type) { case []int32: diff --git a/flow/e2e/bigquery/bigquery_helper.go b/flow/e2e/bigquery/bigquery_helper.go index 267ff3d69d..445488266e 100644 --- a/flow/e2e/bigquery/bigquery_helper.go +++ b/flow/e2e/bigquery/bigquery_helper.go @@ -233,7 +233,7 @@ func toQValue(bqValue bigquery.Value) (qvalue.QValue, error) { } firstElement := v[0] - switch firstElement.(type) { + switch et := firstElement.(type) { case int, int32: var arr []int32 for _, val := range v { @@ -264,10 +264,30 @@ func toQValue(bqValue bigquery.Value) (qvalue.QValue, error) { arr = append(arr, val.(string)) } return qvalue.QValue{Kind: qvalue.QValueKindArrayString, Value: arr}, nil + case time.Time: + var arr []time.Time + for _, val := range v { + arr = append(arr, val.(time.Time)) + } + return qvalue.QValue{Kind: qvalue.QValueKindArrayTimestamp, Value: arr}, nil + case civil.Date: + var arr []civil.Date + for _, val := range v { + arr = append(arr, val.(civil.Date)) + } + return qvalue.QValue{Kind: qvalue.QValueKindArrayDate, Value: arr}, nil + case bool: + var arr []bool + + for _, val := range v { + arr = append(arr, val.(bool)) + } + return qvalue.QValue{Kind: qvalue.QValueKindArrayBoolean, Value: arr}, nil + default: + // If type is unsupported, return error + return qvalue.QValue{}, fmt.Errorf("bqHelper unsupported type %T", et) } - // If type is unsupported, return error - return qvalue.QValue{}, fmt.Errorf("bqHelper unsupported type %T", v) case nil: return qvalue.QValue{Kind: qvalue.QValueKindInvalid, Value: nil}, nil default: diff --git a/flow/e2e/bigquery/peer_flow_bq_test.go b/flow/e2e/bigquery/peer_flow_bq_test.go index 398b673488..efa47366db 100644 --- a/flow/e2e/bigquery/peer_flow_bq_test.go +++ b/flow/e2e/bigquery/peer_flow_bq_test.go @@ -708,7 +708,8 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Types_BQ() { c14 INET,c15 INTEGER,c16 INTERVAL,c17 JSON,c18 JSONB,c21 MACADDR,c22 MONEY, c23 NUMERIC,c24 OID,c28 REAL,c29 SMALLINT,c30 SMALLSERIAL,c31 SERIAL,c32 TEXT, c33 TIMESTAMP,c34 TIMESTAMPTZ,c35 TIME, c36 TIMETZ,c37 TSQUERY,c38 TSVECTOR, - c39 TXID_SNAPSHOT,c40 UUID,c41 XML, c42 INT[], c43 FLOAT[], c44 TEXT[], c45 mood, c46 HSTORE); + c39 TXID_SNAPSHOT,c40 UUID,c41 XML, c42 INT[], c43 FLOAT[], c44 TEXT[], c45 mood, c46 HSTORE, + c47 DATE[], c48 TIMESTAMPTZ[], c49 TIMESTAMP[], c50 BOOLEAN[], c51 SMALLINT[]); `, srcTableName)) require.NoError(s.t, err) @@ -745,7 +746,12 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Types_BQ() { ARRAY[10299301,2579827], ARRAY[0.0003, 8902.0092], ARRAY['hello','bye'],'happy', - 'key1=>value1, key2=>NULL'::hstore + 'key1=>value1, key2=>NULL'::hstore, + '{2020-01-01, 2020-01-02}'::date[], + '{"2020-01-01 01:01:01+00", "2020-01-02 01:01:01+00"}'::timestamptz[], + '{"2020-01-01 01:01:01", "2020-01-02 01:01:01"}'::timestamp[], + '{true, false}'::boolean[], + '{1, 2}'::smallint[]; `, srcTableName)) e2e.EnvNoError(s.t, env, err) }() @@ -763,7 +769,8 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Types_BQ() { "c41", "c1", "c2", "c3", "c4", "c6", "c39", "c40", "id", "c9", "c11", "c12", "c13", "c14", "c15", "c16", "c17", "c18", "c21", "c22", "c23", "c24", "c28", "c29", "c30", "c31", "c33", "c34", "c35", "c36", - "c37", "c38", "c7", "c8", "c32", "c42", "c43", "c44", "c45", "c46", + "c37", "c38", "c7", "c8", "c32", "c42", "c43", "c44", "c45", "c46", "c47", "c48", + "c49", "c50", "c51", }) if err != nil { s.t.Log(err) diff --git a/flow/e2e/postgres/peer_flow_pg_test.go b/flow/e2e/postgres/peer_flow_pg_test.go index 568bb36a5a..b02d03df05 100644 --- a/flow/e2e/postgres/peer_flow_pg_test.go +++ b/flow/e2e/postgres/peer_flow_pg_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strings" "time" "github.com/PeerDB-io/peer-flow/connectors/utils" @@ -187,6 +188,80 @@ func (s PeerFlowE2ETestSuitePG) Test_Geospatial_PG() { require.NoError(s.t, err) } +func (s PeerFlowE2ETestSuitePG) Test_Types_PG() { + env := e2e.NewTemporalTestWorkflowEnvironment() + e2e.RegisterWorkflowsAndActivities(s.t, env) + + srcTableName := s.attachSchemaSuffix("test_types_pg") + dstTableName := s.attachSchemaSuffix("test_types_pg_dst") + + _, err := s.pool.Exec(context.Background(), fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s (id serial PRIMARY KEY,c1 BIGINT,c2 BIT,c4 BOOLEAN, + c7 CHARACTER,c8 varchar,c9 CIDR,c11 DATE,c12 FLOAT,c13 DOUBLE PRECISION, + c14 INET,c15 INTEGER,c21 MACADDR, + c29 SMALLINT,c32 TEXT, + c33 TIMESTAMP,c34 TIMESTAMPTZ,c35 TIME, c36 TIMETZ, + c40 UUID, c42 INT[], c43 FLOAT[], c44 TEXT[], + c46 DATE[], c47 TIMESTAMPTZ[], c48 TIMESTAMP[], c49 BOOLEAN[], c50 SMALLINT[]); + `, srcTableName)) + require.NoError(s.t, err) + + connectionGen := e2e.FlowConnectionGenerationConfig{ + FlowJobName: s.attachSuffix("test_types_pg"), + TableNameMapping: map[string]string{srcTableName: dstTableName}, + PostgresPort: e2e.PostgresPort, + Destination: s.peer, + } + + flowConnConfig := connectionGen.GenerateFlowConnectionConfigs() + + limits := peerflow.CDCFlowLimits{ + ExitAfterRecords: 1, + MaxBatchSize: 100, + } + + go func() { + e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) + _, err = s.pool.Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s SELECT 2,2,b'1', + true,'s','test','1.1.10.2'::cidr, + CURRENT_DATE,1.23,1.234,'192.168.1.5'::inet,1, + '08:00:2b:01:02:03'::macaddr, + 1,'test',now(),now(),now()::time,now()::timetz, + '66073c38-b8df-4bdb-bbca-1c97596b8940'::uuid, + ARRAY[10299301,2579827], + ARRAY[0.0003, 8902.0092], + ARRAY['hello','bye'], + '{2020-01-01, 2020-01-02}'::date[], + '{"2020-01-01 01:01:01+00", "2020-01-02 01:01:01+00"}'::timestamptz[], + '{"2020-01-01 01:01:01", "2020-01-02 01:01:01"}'::timestamp[], + '{true, false}'::boolean[], + '{1,2}'::smallint[]; + `, srcTableName)) + e2e.EnvNoError(s.t, env, err) + + s.t.Log("Inserted 1 row into the source table") + }() + + env.ExecuteWorkflow(peerflow.CDCFlowWorkflowWithConfig, flowConnConfig, &limits, nil) + + // Verify workflow completes without error + require.True(s.t, env.IsWorkflowCompleted()) + err = env.GetWorkflowError() + + // allow only continue as new error + require.Contains(s.t, err.Error(), "continue as new") + + allCols := []string{ + "c1", "c2", "c4", + "c40", "id", "c9", "c11", "c12", "c13", "c14", "c15", + "c21", "c29", "c33", "c34", "c35", "c36", + "c7", "c8", "c32", "c42", "c43", "c44", "c46", "c47", "c48", "c49", "c50", + } + err = s.comparePGTables(srcTableName, dstTableName, strings.Join(allCols, ",")) + require.NoError(s.t, err) +} + func (s PeerFlowE2ETestSuitePG) Test_Enums_PG() { env := e2e.NewTemporalTestWorkflowEnvironment() e2e.RegisterWorkflowsAndActivities(s.t, env) diff --git a/flow/e2e/snowflake/peer_flow_sf_test.go b/flow/e2e/snowflake/peer_flow_sf_test.go index 056de3a1fb..bb4d64636a 100644 --- a/flow/e2e/snowflake/peer_flow_sf_test.go +++ b/flow/e2e/snowflake/peer_flow_sf_test.go @@ -676,8 +676,8 @@ func (s PeerFlowE2ETestSuiteSF) Test_Types_SF() { c23 NUMERIC,c24 OID,c28 REAL,c29 SMALLINT,c30 SMALLSERIAL,c31 SERIAL,c32 TEXT, c33 TIMESTAMP,c34 TIMESTAMPTZ,c35 TIME, c36 TIMETZ,c37 TSQUERY,c38 TSVECTOR, c39 TXID_SNAPSHOT,c40 UUID,c41 XML, c42 GEOMETRY(POINT), c43 GEOGRAPHY(POINT), - c44 GEOGRAPHY(POLYGON), c45 GEOGRAPHY(LINESTRING), c46 GEOMETRY(LINESTRING), c47 GEOMETRY(POLYGON), - c48 mood, c49 HSTORE); + c44 GEOGRAPHY(POLYGON), c45 GEOGRAPHY(LINESTRING), c46 GEOMETRY(LINESTRING), c47 GEOMETRY(POLYGON), + c48 mood, c49 HSTORE, c50 DATE[], c51 TIMESTAMPTZ[], c52 TIMESTAMP[], c53 BOOLEAN[],c54 SMALLINT[]); `, srcTableName)) require.NoError(s.t, err) @@ -712,7 +712,12 @@ func (s PeerFlowE2ETestSuiteSF) Test_Types_SF() { '66073c38-b8df-4bdb-bbca-1c97596b8940'::uuid,xmlcomment('hello'), 'POINT(1 2)','POINT(40.7128 -74.0060)','POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))', 'LINESTRING(-74.0060 40.7128, -73.9352 40.7306, -73.9123 40.7831)','LINESTRING(0 0, 1 1, 2 2)', - 'POLYGON((-74.0060 40.7128, -73.9352 40.7306, -73.9123 40.7831, -74.0060 40.7128))', 'happy','"a"=>"a\"quote\"", "b"=>NULL'; + 'POLYGON((-74.0060 40.7128, -73.9352 40.7306, -73.9123 40.7831, -74.0060 40.7128))', 'happy','"a"=>"a\"quote\"", "b"=>NULL', + '{2020-01-01, 2020-01-02}'::date[], + '{"2020-01-01 01:01:01+00", "2020-01-02 01:01:01+00"}'::timestamptz[], + '{"2020-01-01 01:01:01", "2020-01-02 01:01:01"}'::timestamp[], + '{true, false}'::boolean[], + '{1,2}'::smallint[]; `, srcTableName)) e2e.EnvNoError(s.t, env, err) }() @@ -731,6 +736,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Types_SF() { "c6", "c39", "c40", "id", "c9", "c11", "c12", "c13", "c14", "c15", "c16", "c17", "c18", "c21", "c22", "c23", "c24", "c28", "c29", "c30", "c31", "c33", "c34", "c35", "c36", "c37", "c38", "c7", "c8", "c32", "c42", "c43", "c44", "c45", "c46", "c47", "c48", "c49", + "c50", "c51", "c52", "c53", "c54", }) if err != nil { s.t.Log(err) diff --git a/flow/e2e/test_utils.go b/flow/e2e/test_utils.go index 6d104a8750..69107cdc8f 100644 --- a/flow/e2e/test_utils.go +++ b/flow/e2e/test_utils.go @@ -241,6 +241,11 @@ func CreateTableForQRep(pool *pgxpool.Pool, suffix string, tableName string) err "f6 jsonb", "f7 jsonb", "f8 smallint", + "f9 date[]", + "f10 timestamp with time zone[]", + "f11 timestamp without time zone[]", + "f12 boolean[]", + "f13 smallint[]", "my_date DATE", "my_mood mood", "myh HSTORE", @@ -299,7 +304,12 @@ func PopulateSourceTable(pool *pgxpool.Pool, suffix string, tableName string, ro CURRENT_TIMESTAMP, 1, ARRAY['text1', 'text2'], ARRAY[123, 456], ARRAY[789, 012], ARRAY['varchar1', 'varchar2'], '{"key": -8.02139037433155}', '[{"key1": "value1", "key2": "value2", "key3": "value3"}]', - '{"key": "value"}', 15, CURRENT_DATE, 'happy', '"a"=>"b"','POINT(1 2)','POINT(40.7128 -74.0060)', + '{"key": "value"}', 15,'{2023-09-09,2029-08-10}', + '{"2024-01-15 17:00:00+00","2024-01-16 14:30:00+00"}', + '{"2026-01-17 10:00:00","2026-01-18 13:45:00"}', + '{true, false}', + '{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}', + CURRENT_DATE, 'happy', '"a"=>"b"','POINT(1 2)','POINT(40.7128 -74.0060)', 'LINESTRING(0 0, 1 1, 2 2)', 'LINESTRING(-74.0060 40.7128, -73.9352 40.7306, -73.9123 40.7831)', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))','POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))' @@ -317,7 +327,7 @@ func PopulateSourceTable(pool *pgxpool.Pool, suffix string, tableName string, ro deal_id, ethereum_transaction_id, ignore_price, card_eth_value, paid_eth_price, card_bought_notified, address, account_id, asset_id, status, transaction_id, settled_at, reference_id, - settle_at, settlement_delay_reason, f1, f2, f3, f4, f5, f6, f7, f8, my_date, my_mood, myh, + settle_at, settlement_delay_reason, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, my_date, my_mood, myh, "geometryPoint", geography_point,geometry_linestring, geography_linestring,geometry_polygon, geography_polygon ) VALUES %s; `, suffix, tableName, strings.Join(rows, ","))) @@ -438,6 +448,7 @@ func GetOwnersSchema() *model.QRecordSchema { {Name: "f6", Type: qvalue.QValueKindJSON, Nullable: true}, {Name: "f7", Type: qvalue.QValueKindJSON, Nullable: true}, {Name: "f8", Type: qvalue.QValueKindInt16, Nullable: true}, + {Name: "f13", Type: qvalue.QValueKindArrayInt16, Nullable: true}, {Name: "my_date", Type: qvalue.QValueKindDate, Nullable: true}, {Name: "my_mood", Type: qvalue.QValueKindString, Nullable: true}, {Name: "geometryPoint", Type: qvalue.QValueKindGeometry, Nullable: true}, diff --git a/flow/model/model.go b/flow/model/model.go index 33510e4e12..b3bd44b0ed 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -146,6 +146,22 @@ func (r *RecordItems) toMap(hstoreAsJSON bool) (map[string]interface{}, error) { var err error switch v.Kind { + case qvalue.QValueKindBit, qvalue.QValueKindBytes: + bitVal, ok := v.Value.([]byte) + if !ok { + return nil, errors.New("expected []byte value") + } + + // convert to binary string because + // json.Marshal stores byte arrays as + // base64 + binStr := "" + for _, b := range bitVal { + binStr += fmt.Sprintf("%08b", b) + } + + jsonStruct[col] = binStr + case qvalue.QValueKindString, qvalue.QValueKindJSON: strVal, ok := v.Value.(string) if !ok { @@ -184,6 +200,16 @@ func (r *RecordItems) toMap(hstoreAsJSON bool) (map[string]interface{}, error) { if err != nil { return nil, err } + case qvalue.QValueKindArrayDate: + dateArr, ok := v.Value.([]time.Time) + if !ok { + return nil, errors.New("expected []time.Time value") + } + formattedDateArr := make([]string, 0, len(dateArr)) + for _, val := range dateArr { + formattedDateArr = append(formattedDateArr, val.Format("2006-01-02")) + } + jsonStruct[col] = formattedDateArr case qvalue.QValueKindNumeric: bigRat, ok := v.Value.(*big.Rat) if !ok { diff --git a/flow/model/qrecord_batch.go b/flow/model/qrecord_batch.go index bbbf7006a4..4729e04baa 100644 --- a/flow/model/qrecord_batch.go +++ b/flow/model/qrecord_batch.go @@ -41,6 +41,18 @@ func (q *QRecordBatch) ToQRecordStream(buffer int) (*QRecordStream, error) { return stream, nil } +func constructArray[T any](qValue qvalue.QValue, typeName string) (*pgtype.Array[T], error) { + v, ok := qValue.Value.([]T) + if !ok { + return nil, fmt.Errorf("invalid %s value", typeName) + } + return &pgtype.Array[T]{ + Elements: v, + Dims: []pgtype.ArrayDimension{{Length: int32(len(v)), LowerBound: 1}}, + Valid: true, + }, nil +} + type QRecordBatchCopyFromSource struct { numRecords int stream *QRecordStream @@ -215,65 +227,67 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) { values[i] = wkb case qvalue.QValueKindArrayString: - v, ok := qValue.Value.([]string) - if !ok { - src.err = fmt.Errorf("invalid ArrayString value") + v, err := constructArray[string](qValue, "ArrayString") + if err != nil { + src.err = err return nil, src.err } - values[i] = pgtype.Array[string]{ - Elements: v, - Dims: []pgtype.ArrayDimension{{Length: int32(len(v)), LowerBound: 1}}, - Valid: true, + values[i] = v + + case qvalue.QValueKindArrayDate, qvalue.QValueKindArrayTimestamp, qvalue.QValueKindArrayTimestampTZ: + v, err := constructArray[time.Time](qValue, "ArrayTime") + if err != nil { + src.err = err + return nil, src.err } + values[i] = v - case qvalue.QValueKindArrayInt32: - v, ok := qValue.Value.([]int32) - if !ok { - src.err = fmt.Errorf("invalid ArrayInt32 value") + case qvalue.QValueKindArrayInt16: + v, err := constructArray[int16](qValue, "ArrayInt16") + if err != nil { + src.err = err return nil, src.err } - values[i] = pgtype.Array[int32]{ - Elements: v, - Dims: []pgtype.ArrayDimension{{Length: int32(len(v)), LowerBound: 1}}, - Valid: true, + values[i] = v + + case qvalue.QValueKindArrayInt32: + v, err := constructArray[int32](qValue, "ArrayInt32") + if err != nil { + src.err = err + return nil, src.err } + values[i] = v case qvalue.QValueKindArrayInt64: - v, ok := qValue.Value.([]int64) - if !ok { - src.err = fmt.Errorf("invalid ArrayInt64 value") + v, err := constructArray[int64](qValue, "ArrayInt64") + if err != nil { + src.err = err return nil, src.err } - values[i] = pgtype.Array[int64]{ - Elements: v, - Dims: []pgtype.ArrayDimension{{Length: int32(len(v)), LowerBound: 1}}, - Valid: true, - } + values[i] = v case qvalue.QValueKindArrayFloat32: - v, ok := qValue.Value.([]float32) - if !ok { - src.err = fmt.Errorf("invalid ArrayFloat32 value") + v, err := constructArray[float32](qValue, "ArrayFloat32") + if err != nil { + src.err = err return nil, src.err } - values[i] = pgtype.Array[float32]{ - Elements: v, - Dims: []pgtype.ArrayDimension{{Length: int32(len(v)), LowerBound: 1}}, - Valid: true, - } + values[i] = v case qvalue.QValueKindArrayFloat64: - v, ok := qValue.Value.([]float64) - if !ok { - src.err = fmt.Errorf("invalid ArrayFloat64 value") + v, err := constructArray[float64](qValue, "ArrayFloat64") + if err != nil { + src.err = err return nil, src.err } - values[i] = pgtype.Array[float64]{ - Elements: v, - Dims: []pgtype.ArrayDimension{{Length: int32(len(v)), LowerBound: 1}}, - Valid: true, + values[i] = v + case qvalue.QValueKindArrayBoolean: + v, err := constructArray[bool](qValue, "ArrayBool") + if err != nil { + src.err = err + return nil, src.err } - + values[i] = v case qvalue.QValueKindJSON: v, ok := qValue.Value.(string) if !ok { diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go index 4e8f0a60d3..b682d277f5 100644 --- a/flow/model/qvalue/avro_converter.go +++ b/flow/model/qvalue/avro_converter.go @@ -17,6 +17,11 @@ type AvroSchemaArray struct { Items string `json:"items"` } +type AvroSchemaComplexArray struct { + Type string `json:"type"` + Items AvroSchemaField `json:"items"` +} + type AvroSchemaNumeric struct { Type string `json:"type"` LogicalType string `json:"logicalType"` @@ -24,6 +29,18 @@ type AvroSchemaNumeric struct { Scale int `json:"scale"` } +type AvroSchemaRecord struct { + Type string `json:"type"` + Name string `json:"name"` + Fields []AvroSchemaField `json:"fields"` +} + +type AvroSchemaField struct { + Name string `json:"name"` + Type interface{} `json:"type"` + LogicalType string `json:"logicalType,omitempty"` +} + // GetAvroSchemaFromQValueKind returns the Avro schema for a given QValueKind. // The function takes in two parameters, a QValueKind and a boolean indicating if the // Avro schema should respect null values. It returns a QValueKindAvroSchema object @@ -74,7 +91,7 @@ func GetAvroSchemaFromQValueKind(kind QValueKind, targetDWH QDWHType) (interface Type: "array", Items: "double", }, nil - case QValueKindArrayInt32: + case QValueKindArrayInt32, QValueKindArrayInt16: return AvroSchemaArray{ Type: "array", Items: "int", @@ -84,6 +101,21 @@ func GetAvroSchemaFromQValueKind(kind QValueKind, targetDWH QDWHType) (interface Type: "array", Items: "long", }, nil + case QValueKindArrayBoolean: + return AvroSchemaArray{ + Type: "array", + Items: "boolean", + }, nil + case QValueKindArrayDate: + return AvroSchemaArray{ + Type: "array", + Items: "string", + }, nil + case QValueKindArrayTimestamp, QValueKindArrayTimestampTZ: + return AvroSchemaArray{ + Type: "array", + Items: "string", + }, nil case QValueKindArrayString: return AvroSchemaArray{ Type: "array", @@ -116,7 +148,7 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { case QValueKindInvalid: // we will attempt to convert invalid to a string return c.processNullableUnion("string", c.Value.Value) - case QValueKindTime, QValueKindTimeTZ, QValueKindDate, QValueKindTimestamp, QValueKindTimestampTZ: + case QValueKindTime, QValueKindTimeTZ, QValueKindTimestamp, QValueKindTimestampTZ: t, err := c.processGoTime() if err != nil || t == nil { return t, err @@ -138,10 +170,29 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { } if c.Nullable { return goavro.Union("long.timestamp-micros", t.(int64)), nil + } + return t.(int64), nil + + case QValueKindDate: + t, err := c.processGoDate() + if err != nil || t == nil { + return t, err + } + if c.TargetDWH == QDWHTypeSnowflake { + if c.Nullable { + return c.processNullableUnion("string", t.(string)) + } else { + return t.(string), nil + } + } + + if c.Nullable { + return goavro.Union("int.date", t), nil } else { - return t.(int64), nil + return t, nil } - case QValueKindString: + + case QValueKindString, QValueKindCIDR, QValueKindINET, QValueKindMacaddr: if c.TargetDWH == QDWHTypeSnowflake && c.Value.Value != nil && (len(c.Value.Value.(string)) > 15*1024*1024) { slog.Warn("Truncating TEXT value > 15MB for Snowflake!") @@ -176,12 +227,30 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { return c.processArrayFloat32() case QValueKindArrayFloat64: return c.processArrayFloat64() + case QValueKindArrayInt16: + return c.processArrayInt16() case QValueKindArrayInt32: return c.processArrayInt32() case QValueKindArrayInt64: return c.processArrayInt64() case QValueKindArrayString: return c.processArrayString() + case QValueKindArrayBoolean: + return c.processArrayBoolean() + case QValueKindArrayTimestamp, QValueKindArrayTimestampTZ: + arrayTime, err := c.processArrayTime() + if err != nil || arrayTime == nil { + return arrayTime, err + } + + return arrayTime, nil + case QValueKindArrayDate: + arrayDate, err := c.processArrayDate() + if err != nil || arrayDate == nil { + return arrayDate, err + } + + return arrayDate, nil case QValueKindUUID: return c.processUUID() case QValueKindGeography, QValueKindGeometry, QValueKindPoint: @@ -210,6 +279,25 @@ func (c *QValueAvroConverter) processGoTime() (interface{}, error) { return ret, nil } +func (c *QValueAvroConverter) processGoDate() (interface{}, error) { + if c.Value.Value == nil && c.Nullable { + return nil, nil + } + + t, ok := c.Value.Value.(time.Time) + if !ok { + return nil, fmt.Errorf("invalid Time value for Date") + } + + // Snowflake has issues with avro timestamp types, returning as string form of the int64 + // See: https://stackoverflow.com/questions/66104762/snowflake-date-column-have-incorrect-date-from-avro-file + if c.TargetDWH == QDWHTypeSnowflake { + ret := t.UnixMicro() + return fmt.Sprint(ret), nil + } + return t, nil +} + func (c *QValueAvroConverter) processNullableUnion( avroType string, value interface{}, @@ -298,6 +386,77 @@ func (c *QValueAvroConverter) processJSON() (interface{}, error) { return jsonString, nil } +func (c *QValueAvroConverter) processArrayBoolean() (interface{}, error) { + if c.Value.Value == nil && c.Nullable { + return nil, nil + } + + arrayData, ok := c.Value.Value.([]bool) + if !ok { + return nil, fmt.Errorf("invalid Boolean array value") + } + + if c.Nullable { + return goavro.Union("array", arrayData), nil + } + + return arrayData, nil +} + +func (c *QValueAvroConverter) processArrayTime() (interface{}, error) { + if c.Value.Value == nil && c.Nullable { + return nil, nil + } + + arrayTime, ok := c.Value.Value.([]time.Time) + if !ok { + return nil, fmt.Errorf("invalid Timestamp array value") + } + + transformedTimeArr := make([]interface{}, 0, len(arrayTime)) + for _, t := range arrayTime { + // Snowflake has issues with avro timestamp types, returning as string form of the int64 + // See: https://stackoverflow.com/questions/66104762/snowflake-date-column-have-incorrect-date-from-avro-file + if c.TargetDWH == QDWHTypeSnowflake { + transformedTimeArr = append(transformedTimeArr, t.String()) + } else { + transformedTimeArr = append(transformedTimeArr, t) + } + } + + if c.Nullable { + return goavro.Union("array", transformedTimeArr), nil + } + + return transformedTimeArr, nil +} + +func (c *QValueAvroConverter) processArrayDate() (interface{}, error) { + if c.Value.Value == nil && c.Nullable { + return nil, nil + } + + arrayDate, ok := c.Value.Value.([]time.Time) + if !ok { + return nil, fmt.Errorf("invalid Date array value") + } + + transformedTimeArr := make([]interface{}, 0, len(arrayDate)) + for _, t := range arrayDate { + if c.TargetDWH == QDWHTypeSnowflake { + transformedTimeArr = append(transformedTimeArr, t.Format("2006-01-02")) + } else { + transformedTimeArr = append(transformedTimeArr, t) + } + } + + if c.Nullable { + return goavro.Union("array", transformedTimeArr), nil + } + + return transformedTimeArr, nil +} + func (c *QValueAvroConverter) processHStore() (interface{}, error) { if c.Value.Value == nil && c.Nullable { return nil, nil @@ -374,6 +533,29 @@ func (c *QValueAvroConverter) processGeospatial() (interface{}, error) { return geoString, nil } +func (c *QValueAvroConverter) processArrayInt16() (interface{}, error) { + if c.Value.Value == nil && c.Nullable { + return nil, nil + } + + arrayData, ok := c.Value.Value.([]int16) + if !ok { + return nil, fmt.Errorf("invalid Int16 array value") + } + + // cast to int32 + int32Data := make([]int32, 0, len(arrayData)) + for _, v := range arrayData { + int32Data = append(int32Data, int32(v)) + } + + if c.Nullable { + return goavro.Union("array", int32Data), nil + } + + return int32Data, nil +} + func (c *QValueAvroConverter) processArrayInt32() (interface{}, error) { if c.Value.Value == nil && c.Nullable { return nil, nil diff --git a/flow/model/qvalue/kind.go b/flow/model/qvalue/kind.go index 175430f92e..bba156bb88 100644 --- a/flow/model/qvalue/kind.go +++ b/flow/model/qvalue/kind.go @@ -33,12 +33,22 @@ const ( QValueKindGeometry QValueKind = "geometry" QValueKindPoint QValueKind = "point" + // network types + QValueKindCIDR QValueKind = "cidr" + QValueKindINET QValueKind = "inet" + QValueKindMacaddr QValueKind = "macaddr" + // array types - QValueKindArrayFloat32 QValueKind = "array_float32" - QValueKindArrayFloat64 QValueKind = "array_float64" - QValueKindArrayInt32 QValueKind = "array_int32" - QValueKindArrayInt64 QValueKind = "array_int64" - QValueKindArrayString QValueKind = "array_string" + QValueKindArrayFloat32 QValueKind = "array_float32" + QValueKindArrayFloat64 QValueKind = "array_float64" + QValueKindArrayInt16 QValueKind = "array_int16" + QValueKindArrayInt32 QValueKind = "array_int32" + QValueKindArrayInt64 QValueKind = "array_int64" + QValueKindArrayString QValueKind = "array_string" + QValueKindArrayDate QValueKind = "array_date" + QValueKindArrayTimestamp QValueKind = "array_timestamp" + QValueKindArrayTimestampTZ QValueKind = "array_timestamptz" + QValueKindArrayBoolean QValueKind = "array_bool" ) func (kind QValueKind) IsArray() bool { @@ -71,11 +81,16 @@ var QValueKindToSnowflakeTypeMap = map[QValueKind]string{ QValueKindPoint: "GEOMETRY", // array types will be mapped to VARIANT - QValueKindArrayFloat32: "VARIANT", - QValueKindArrayFloat64: "VARIANT", - QValueKindArrayInt32: "VARIANT", - QValueKindArrayInt64: "VARIANT", - QValueKindArrayString: "VARIANT", + QValueKindArrayFloat32: "VARIANT", + QValueKindArrayFloat64: "VARIANT", + QValueKindArrayInt32: "VARIANT", + QValueKindArrayInt64: "VARIANT", + QValueKindArrayInt16: "VARIANT", + QValueKindArrayString: "VARIANT", + QValueKindArrayDate: "VARIANT", + QValueKindArrayTimestamp: "VARIANT", + QValueKindArrayTimestampTZ: "VARIANT", + QValueKindArrayBoolean: "VARIANT", } var QValueKindToClickhouseTypeMap = map[QValueKind]string{ diff --git a/flow/model/qvalue/qvalue.go b/flow/model/qvalue/qvalue.go index 83fb2e9db7..786065227b 100644 --- a/flow/model/qvalue/qvalue.go +++ b/flow/model/qvalue/qvalue.go @@ -9,6 +9,7 @@ import ( "strconv" "time" + "cloud.google.com/go/civil" "github.com/PeerDB-io/peer-flow/geo" hstore_util "github.com/PeerDB-io/peer-flow/hstore" "github.com/google/uuid" @@ -60,10 +61,16 @@ func (q QValue) Equals(other QValue) bool { return compareNumericArrays(q.Value, other.Value) case QValueKindArrayFloat64: return compareNumericArrays(q.Value, other.Value) - case QValueKindArrayInt32: + case QValueKindArrayInt32, QValueKindArrayInt16: return compareNumericArrays(q.Value, other.Value) case QValueKindArrayInt64: return compareNumericArrays(q.Value, other.Value) + case QValueKindArrayDate: + return compareDateArrays(q.Value, other.Value) + case QValueKindArrayTimestamp, QValueKindArrayTimestampTZ: + return compareTimeArrays(q.Value, other.Value) + case QValueKindArrayBoolean: + return compareBoolArrays(q.Value, other.Value) case QValueKindArrayString: return compareArrayString(q.Value, other.Value) } @@ -274,6 +281,12 @@ func compareNumericArrays(value1, value2 interface{}) bool { // Helper function to convert a value to float64 convertToFloat64 := func(val interface{}) []float64 { switch v := val.(type) { + case []int16: + result := make([]float64, len(v)) + for i, value := range v { + result[i] = float64(value) + } + return result case []int32: result := make([]float64, len(v)) for i, value := range v { @@ -315,6 +328,72 @@ func compareNumericArrays(value1, value2 interface{}) bool { return true } +func compareTimeArrays(value1, value2 interface{}) bool { + if value1 == nil && value2 == nil { + return true + } + array1, ok1 := value1.([]time.Time) + array2, ok2 := value2.([]time.Time) + + if !ok1 || !ok2 { + return false + } + + if len(array1) != len(array2) { + return false + } + + for i := range array1 { + if !array1[i].Equal(array2[i]) { + return false + } + } + + return true +} + +func compareDateArrays(value1, value2 interface{}) bool { + if value1 == nil && value2 == nil { + return true + } + array1, ok1 := value1.([]time.Time) + array2, ok2 := value2.([]civil.Date) + + if !ok1 || !ok2 || len(array1) != len(array2) { + return false + } + + for i := range array1 { + if array1[i].Year() != array2[i].Year || + array1[i].Month() != array2[i].Month || + array1[i].Day() != array2[i].Day { + return false + } + } + + return true +} + +func compareBoolArrays(value1, value2 interface{}) bool { + if value1 == nil && value2 == nil { + return true + } + array1, ok1 := value1.([]bool) + array2, ok2 := value2.([]bool) + + if !ok1 || !ok2 || len(array1) != len(array2) { + return false + } + + for i := range array1 { + if array1[i] != array2[i] { + return false + } + } + + return true +} + func compareArrayString(value1, value2 interface{}) bool { if value1 == nil && value2 == nil { return true