Skip to content

Commit

Permalink
Add support for "char"
Browse files Browse the repository at this point in the history
Postgres offers a type "char", distinct from CHAR,
that is represented by a single byte.

Map this type to a new QValue kind (QChar). SQL Server also has CHAR,
and on ClickHouse it can be represented with FixedString(1).
  • Loading branch information
serprex committed Feb 14, 2024
1 parent c47dbb2 commit 86c8180
Show file tree
Hide file tree
Showing 13 changed files with 91 additions and 47 deletions.
1 change: 1 addition & 0 deletions flow/connectors/clickhouse/qvalue_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ var clickhouseTypeToQValueKindMap = map[string]qvalue.QValueKind{
"CHAR": qvalue.QValueKindString,
"TEXT": qvalue.QValueKindString,
"String": qvalue.QValueKindString,
"FixedString(1)": qvalue.QValueKindQChar,
"Bool": qvalue.QValueKindBoolean,
"DateTime": qvalue.QValueKindTimestamp,
"TIMESTAMP": qvalue.QValueKindTimestamp,
Expand Down
11 changes: 8 additions & 3 deletions flow/connectors/postgres/qvalue_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ func postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind {
return qvalue.QValueKindFloat32
case pgtype.Float8OID:
return qvalue.QValueKindFloat64
case pgtype.QCharOID:
return qvalue.QValueKindQChar
case pgtype.TextOID, pgtype.VarcharOID, pgtype.BPCharOID:
return qvalue.QValueKindString
case pgtype.ByteaOID:
Expand Down Expand Up @@ -121,6 +123,8 @@ func qValueKindToPostgresType(colTypeStr string) string {
return "REAL"
case qvalue.QValueKindFloat64:
return "DOUBLE PRECISION"
case qvalue.QValueKindQChar:
return "\"char\""
case qvalue.QValueKindString:
return "TEXT"
case qvalue.QValueKindBytes:
Expand Down Expand Up @@ -262,6 +266,8 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (
case qvalue.QValueKindFloat64:
floatVal := value.(float64)
val = qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: floatVal}
case qvalue.QValueKindQChar:
val = qvalue.QValue{Kind: qvalue.QValueKindQChar, Value: uint8(value.(rune))}
case qvalue.QValueKindString:
// handling all unsupported types with strings as well for now.
val = qvalue.QValue{Kind: qvalue.QValueKindString, Value: fmt.Sprint(value)}
Expand Down Expand Up @@ -501,10 +507,9 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (
}
default:
textVal, ok := value.(string)
if !ok {
return qvalue.QValue{}, fmt.Errorf("failed to parse value %v into QValueKind %v", value, qvalueKind)
if ok {
val = qvalue.QValue{Kind: qvalue.QValueKindString, Value: textVal}
}
val = qvalue.QValue{Kind: qvalue.QValueKindString, Value: textVal}
}

// parsing into pgtype failed.
Expand Down
12 changes: 9 additions & 3 deletions flow/connectors/postgres/schema_delta_test_constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ var AddAllColumnTypes = []string{
string(qvalue.QValueKindJSON),
string(qvalue.QValueKindNumeric),
string(qvalue.QValueKindString),
string(qvalue.QValueKindQChar),
string(qvalue.QValueKindTime),
string(qvalue.QValueKindTimestamp),
string(qvalue.QValueKindTimestampTZ),
Expand Down Expand Up @@ -93,21 +94,26 @@ var AddAllColumnTypesFields = []*protos.FieldDescription{
},
{
Name: "c13",
Type: string(qvalue.QValueKindTime),
Type: string(qvalue.QValueKindQChar),
TypeModifier: -1,
},
{
Name: "c14",
Type: string(qvalue.QValueKindTimestamp),
Type: string(qvalue.QValueKindTime),
TypeModifier: -1,
},
{
Name: "c15",
Type: string(qvalue.QValueKindTimestampTZ),
Type: string(qvalue.QValueKindTimestamp),
TypeModifier: -1,
},
{
Name: "c16",
Type: string(qvalue.QValueKindTimestampTZ),
TypeModifier: -1,
},
{
Name: "c17",
Type: string(qvalue.QValueKindUUID),
TypeModifier: -1,
},
Expand Down
3 changes: 3 additions & 0 deletions flow/connectors/snowflake/avro_file_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ func createQValue(t *testing.T, kind qvalue.QValueKind, placeHolder int) qvalue.
value = big.NewRat(int64(placeHolder), 1)
case qvalue.QValueKindUUID:
value = uuid.New() // assuming you have the github.com/google/uuid package
case qvalue.QValueKindQChar:
value = uint8(48)
// case qvalue.QValueKindArray:
// value = []int{1, 2, 3} // placeholder array, replace with actual logic
// case qvalue.QValueKindStruct:
Expand Down Expand Up @@ -85,6 +87,7 @@ func generateRecords(
qvalue.QValueKindNumeric,
qvalue.QValueKindBytes,
qvalue.QValueKindUUID,
qvalue.QValueKindQChar,
// qvalue.QValueKindJSON,
qvalue.QValueKindBit,
}
Expand Down
28 changes: 13 additions & 15 deletions flow/connectors/snowflake/merge_stmt_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) {
}

targetColumnName := SnowflakeIdentifierNormalize(column.Name)
switch qvalue.QValueKind(genericColumnType) {
switch qvKind {
case qvalue.QValueKindBytes, qvalue.QValueKindBit:
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:\"%s\") "+
"AS %s", toVariantColumnName, column.Name, targetColumnName))
Expand All @@ -61,21 +61,19 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) {
// flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+
// "Microseconds*1000) "+
// "AS %s", toVariantColumnName, columnName, columnName))
default:
if qvKind == qvalue.QValueKindNumeric {
precision, scale := numeric.ParseNumericTypmod(column.TypeModifier)
if column.TypeModifier == -1 || precision > 38 || scale > 37 {
precision = numeric.PeerDBNumericPrecision
scale = numeric.PeerDBNumericScale
}
numericType := fmt.Sprintf("NUMERIC(%d,%d)", precision, scale)
flattenedCastsSQLArray = append(flattenedCastsSQLArray,
fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s",
toVariantColumnName, column.Name, numericType, targetColumnName))
} else {
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS %s) AS %s",
toVariantColumnName, column.Name, sfType, targetColumnName))
case qvalue.QValueKindNumeric:
precision, scale := numeric.ParseNumericTypmod(column.TypeModifier)
if column.TypeModifier == -1 || precision > 38 || scale > 37 {
precision = numeric.PeerDBNumericPrecision
scale = numeric.PeerDBNumericScale
}
numericType := fmt.Sprintf("NUMERIC(%d,%d)", precision, scale)
flattenedCastsSQLArray = append(flattenedCastsSQLArray,
fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s",
toVariantColumnName, column.Name, numericType, targetColumnName))
default:
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS %s) AS %s",
toVariantColumnName, column.Name, sfType, targetColumnName))
}
}
flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
Expand Down
7 changes: 7 additions & 0 deletions flow/connectors/sql/query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,9 @@ func (g *GenericSQLQueryExecutor) CheckNull(ctx context.Context, schema string,
}

func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) {
if val == nil {
return qvalue.QValue{Kind: kind, Value: nil}, nil
}
switch kind {
case qvalue.QValueKindInt32:
if v, ok := val.(*sql.NullInt32); ok {
Expand Down Expand Up @@ -341,6 +344,10 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) {
return qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: nil}, nil
}
}
case qvalue.QValueKindQChar:
if v, ok := val.(uint8); ok {
return qvalue.QValue{Kind: qvalue.QValueKindQChar, Value: v}, nil
}
case qvalue.QValueKindString:
if v, ok := val.(*sql.NullString); ok {
if v.Valid {
Expand Down
3 changes: 2 additions & 1 deletion flow/connectors/sqlserver/qvalue_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ var qValueKindToSQLServerTypeMap = map[qvalue.QValueKind]string{
qvalue.QValueKindFloat32: "REAL",
qvalue.QValueKindFloat64: "FLOAT",
qvalue.QValueKindNumeric: "DECIMAL(38, 9)",
qvalue.QValueKindQChar: "CHAR",
qvalue.QValueKindString: "NTEXT",
qvalue.QValueKindJSON: "NTEXT", // SQL Server doesn't have a native JSON type
qvalue.QValueKindTimestamp: "DATETIME2",
Expand Down Expand Up @@ -51,7 +52,7 @@ var sqlServerTypeToQValueKindMap = map[string]qvalue.QValueKind{
"UNIQUEIDENTIFIER": qvalue.QValueKindUUID,
"SMALLINT": qvalue.QValueKindInt32,
"TINYINT": qvalue.QValueKindInt32,
"CHAR": qvalue.QValueKindString,
"CHAR": qvalue.QValueKindQChar,
"VARCHAR": qvalue.QValueKindString,
"NCHAR": qvalue.QValueKindString,
"NVARCHAR": qvalue.QValueKindString,
Expand Down
2 changes: 1 addition & 1 deletion flow/model/conversion_avro.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func NewQRecordAvroConverter(
}

func (qac *QRecordAvroConverter) Convert() (map[string]interface{}, error) {
m := map[string]interface{}{}
m := make(map[string]interface{}, len(qac.QRecord))

for idx, val := range qac.QRecord {
key := qac.ColNames[idx]
Expand Down
6 changes: 6 additions & 0 deletions flow/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,13 @@ func (r *RecordItems) toMap(hstoreAsJSON bool) (map[string]interface{}, error) {
}

jsonStruct[col] = binStr
case qvalue.QValueKindQChar:
ch, ok := v.Value.(uint8)
if !ok {
return nil, fmt.Errorf("expected \"char\" value for column %s for %T", col, v.Value)
}

jsonStruct[col] = string(ch)
case qvalue.QValueKindString, qvalue.QValueKindJSON:
strVal, ok := v.Value.(string)
if !ok {
Expand Down
13 changes: 8 additions & 5 deletions flow/model/qrecord_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,14 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) {
}
values[i] = v

case qvalue.QValueKindQChar:
v, ok := qValue.Value.(uint8)
if !ok {
src.err = fmt.Errorf("invalid \"char\" value")
return nil, src.err
}
values[i] = rune(v)

case qvalue.QValueKindString:
v, ok := qValue.Value.(string)
if !ok {
Expand Down Expand Up @@ -173,11 +181,6 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) {
values[i] = timestampTZ

case qvalue.QValueKindUUID:
if qValue.Value == nil {
values[i] = nil
break
}

v, ok := qValue.Value.([16]byte) // treat it as byte slice
if !ok {
src.err = fmt.Errorf("invalid UUID value %v", qValue.Value)
Expand Down
43 changes: 24 additions & 19 deletions flow/model/qvalue/avro_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func GetAvroSchemaFromQValueKind(kind QValueKind, targetDWH QDWHType, precision
}

switch kind {
case QValueKindString:
case QValueKindString, QValueKindQChar:
return "string", nil
case QValueKindUUID:
return AvroSchemaLogical{
Expand Down Expand Up @@ -169,6 +169,10 @@ func NewQValueAvroConverter(value QValue, targetDWH QDWHType, nullable bool) *QV
}

func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) {
if c.Nullable && c.Value.Value == nil {
return nil, nil
}

switch c.Value.Kind {
case QValueKindInvalid:
// we will attempt to convert invalid to a string
Expand All @@ -180,21 +184,21 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) {
}
if c.TargetDWH == QDWHTypeSnowflake {
if c.Nullable {
return c.processNullableUnion("string", t.(string))
return c.processNullableUnion("string", t)
} else {
return t.(string), nil
}
}

if c.TargetDWH == QDWHTypeClickhouse {
if c.Nullable {
return c.processNullableUnion("string", t.(string))
return c.processNullableUnion("string", t)
} else {
return t.(string), nil
}
}
if c.Nullable {
return goavro.Union("long.time-micros", t.(int64)), nil
return goavro.Union("long.time-micros", t), nil
}
return t.(int64), nil
case QValueKindTimeTZ:
Expand All @@ -204,21 +208,21 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) {
}
if c.TargetDWH == QDWHTypeSnowflake {
if c.Nullable {
return c.processNullableUnion("string", t.(string))
return c.processNullableUnion("string", t)
} else {
return t.(string), nil
}
}

if c.TargetDWH == QDWHTypeClickhouse {
if c.Nullable {
return c.processNullableUnion("long", t.(int64))
return c.processNullableUnion("long", t)
} else {
return t.(int64), nil
}
}
if c.Nullable {
return goavro.Union("long.time-micros", t.(int64)), nil
return goavro.Union("long.time-micros", t), nil
}
return t.(int64), nil
case QValueKindTimestamp:
Expand All @@ -228,21 +232,21 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) {
}
if c.TargetDWH == QDWHTypeSnowflake {
if c.Nullable {
return c.processNullableUnion("string", t.(string))
return c.processNullableUnion("string", t)
} else {
return t.(string), nil
}
}

if c.TargetDWH == QDWHTypeClickhouse {
if c.Nullable {
return c.processNullableUnion("long", t.(int64))
return c.processNullableUnion("long", t)
} else {
return t.(int64), nil
}
}
if c.Nullable {
return goavro.Union("long.timestamp-micros", t.(int64)), nil
return goavro.Union("long.timestamp-micros", t), nil
}
return t.(int64), nil
case QValueKindTimestampTZ:
Expand All @@ -252,21 +256,21 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) {
}
if c.TargetDWH == QDWHTypeSnowflake {
if c.Nullable {
return c.processNullableUnion("string", t.(string))
return c.processNullableUnion("string", t)
} else {
return t.(string), nil
}
}

if c.TargetDWH == QDWHTypeClickhouse {
if c.Nullable {
return c.processNullableUnion("long", t.(int64))
return c.processNullableUnion("long", t)
} else {
return t.(int64), nil
}
}
if c.Nullable {
return goavro.Union("long.timestamp-micros", t.(int64)), nil
return goavro.Union("long.timestamp-micros", t), nil
}
return t.(int64), nil
case QValueKindDate:
Expand All @@ -277,7 +281,7 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) {

if c.TargetDWH == QDWHTypeSnowflake {
if c.Nullable {
return c.processNullableUnion("string", t.(string))
return c.processNullableUnion("string", t)
} else {
return t.(string), nil
}
Expand All @@ -287,7 +291,8 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) {
return goavro.Union("int.date", t), nil
}
return t, nil

case QValueKindQChar:
return c.processNullableUnion("string", string(c.Value.Value.(uint8)))
case QValueKindString, QValueKindCIDR, QValueKindINET, QValueKindMacaddr:
if c.TargetDWH == QDWHTypeSnowflake && c.Value.Value != nil &&
(len(c.Value.Value.(string)) > 15*1024*1024) {
Expand Down Expand Up @@ -457,11 +462,11 @@ func (c *QValueAvroConverter) processNullableUnion(
avroType string,
value interface{},
) (interface{}, error) {
if value == nil && c.Nullable {
return nil, nil
}

if c.Nullable {
if value == nil {
return nil, nil
}

return goavro.Union(avroType, value), nil
}

Expand Down
Loading

0 comments on commit 86c8180

Please sign in to comment.