From 3de7fb0282a629d7c83981711a115d169cc012e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Wed, 14 Feb 2024 04:25:19 +0000 Subject: [PATCH] Misc cleanup of qvalue code Split out from #1285 --- flow/connectors/postgres/qvalue_convert.go | 5 ++-- .../snowflake/merge_stmt_generator.go | 28 +++++++++---------- flow/model/conversion_avro.go | 2 +- flow/model/qrecord_batch.go | 5 ---- flow/model/qvalue/avro_converter.go | 12 ++++---- 5 files changed, 23 insertions(+), 29 deletions(-) diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go index 4823aa2ecc..72f51dff6b 100644 --- a/flow/connectors/postgres/qvalue_convert.go +++ b/flow/connectors/postgres/qvalue_convert.go @@ -501,10 +501,9 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( } default: textVal, ok := value.(string) - if !ok { - return qvalue.QValue{}, fmt.Errorf("failed to parse value %v into QValueKind %v", value, qvalueKind) + if ok { + val = qvalue.QValue{Kind: qvalue.QValueKindString, Value: textVal} } - val = qvalue.QValue{Kind: qvalue.QValueKindString, Value: textVal} } // parsing into pgtype failed. diff --git a/flow/connectors/snowflake/merge_stmt_generator.go b/flow/connectors/snowflake/merge_stmt_generator.go index 4d8ea7dc5a..1ab579069c 100644 --- a/flow/connectors/snowflake/merge_stmt_generator.go +++ b/flow/connectors/snowflake/merge_stmt_generator.go @@ -40,7 +40,7 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) { } targetColumnName := SnowflakeIdentifierNormalize(column.Name) - switch qvalue.QValueKind(genericColumnType) { + switch qvKind { case qvalue.QValueKindBytes, qvalue.QValueKindBit: flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:\"%s\") "+ "AS %s", toVariantColumnName, column.Name, targetColumnName)) @@ -61,21 +61,19 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) { // flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+ // "Microseconds*1000) "+ // "AS %s", toVariantColumnName, columnName, columnName)) - default: - if qvKind == qvalue.QValueKindNumeric { - precision, scale := numeric.ParseNumericTypmod(column.TypeModifier) - if column.TypeModifier == -1 || precision > 38 || scale > 37 { - precision = numeric.PeerDBNumericPrecision - scale = numeric.PeerDBNumericScale - } - numericType := fmt.Sprintf("NUMERIC(%d,%d)", precision, scale) - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s", - toVariantColumnName, column.Name, numericType, targetColumnName)) - } else { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS %s) AS %s", - toVariantColumnName, column.Name, sfType, targetColumnName)) + case qvalue.QValueKindNumeric: + precision, scale := numeric.ParseNumericTypmod(column.TypeModifier) + if column.TypeModifier == -1 || precision > 38 || scale > 37 { + precision = numeric.PeerDBNumericPrecision + scale = numeric.PeerDBNumericScale } + numericType := fmt.Sprintf("NUMERIC(%d,%d)", precision, scale) + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s", + toVariantColumnName, column.Name, numericType, targetColumnName)) + default: + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS %s) AS %s", + toVariantColumnName, column.Name, sfType, targetColumnName)) } } flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",") diff --git a/flow/model/conversion_avro.go b/flow/model/conversion_avro.go index b379abc09a..b26aeaf9d7 100644 --- a/flow/model/conversion_avro.go +++ b/flow/model/conversion_avro.go @@ -29,7 +29,7 @@ func NewQRecordAvroConverter( } func (qac *QRecordAvroConverter) Convert() (map[string]interface{}, error) { - m := map[string]interface{}{} + m := make(map[string]interface{}, len(qac.QRecord)) for idx, val := range qac.QRecord { key := qac.ColNames[idx] diff --git a/flow/model/qrecord_batch.go b/flow/model/qrecord_batch.go index d7f25c7f5d..572c788630 100644 --- a/flow/model/qrecord_batch.go +++ b/flow/model/qrecord_batch.go @@ -173,11 +173,6 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) { values[i] = timestampTZ case qvalue.QValueKindUUID: - if qValue.Value == nil { - values[i] = nil - break - } - v, ok := qValue.Value.([16]byte) // treat it as byte slice if !ok { src.err = fmt.Errorf("invalid UUID value %v", qValue.Value) diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go index 7c902ffcb5..4f5488cbff 100644 --- a/flow/model/qvalue/avro_converter.go +++ b/flow/model/qvalue/avro_converter.go @@ -169,6 +169,10 @@ func NewQValueAvroConverter(value QValue, targetDWH QDWHType, nullable bool) *QV } func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { + if c.Nullable && c.Value.Value == nil { + return nil, nil + } + switch c.Value.Kind { case QValueKindInvalid: // we will attempt to convert invalid to a string @@ -457,14 +461,12 @@ func (c *QValueAvroConverter) processNullableUnion( avroType string, value interface{}, ) (interface{}, error) { - if value == nil && c.Nullable { - return nil, nil - } - if c.Nullable { + if value == nil { + return nil, nil + } return goavro.Union(avroType, value), nil } - return value, nil }