diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index e1f57e9c28..4ee7425962 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -963,7 +963,7 @@ func (c *BigQueryConnector) SetupNormalizedTables( columns[idx] = &bigquery.FieldSchema{ Name: colName, Type: qValueKindToBigQueryType(genericColType), - Repeated: strings.Contains(genericColType, "array"), + Repeated: qvalue.QValueKind(genericColType).IsArray(), } idx++ } diff --git a/flow/connectors/bigquery/qrecord_value_saver.go b/flow/connectors/bigquery/qrecord_value_saver.go index e724cc4a55..202ac3df4d 100644 --- a/flow/connectors/bigquery/qrecord_value_saver.go +++ b/flow/connectors/bigquery/qrecord_value_saver.go @@ -49,9 +49,10 @@ func (q QRecordValueSaver) Save() (map[string]bigquery.Value, string, error) { for i, v := range q.Record.Entries { k := q.ColumnNames[i] if v.Value == nil { - bqValues[k] = nil - if qvalue.QValueKindIsArray(v.Kind) { + if v.Kind.IsArray() { bqValues[k] = make([]interface{}, 0) + } else { + bqValues[k] = nil } continue } diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index fd4a162e15..1eec3c5bf7 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -11,6 +11,7 @@ import ( "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" @@ -535,7 +536,7 @@ func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifie for columnName, genericColumnType := range normalizedTableSchema.Columns { columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName)) pgType := qValueKindToPostgresType(genericColumnType) - if strings.Contains(genericColumnType, "array") { + if qvalue.QValueKind(genericColumnType).IsArray() { flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", strings.Trim(columnName, "\""), pgType, columnName)) @@ -589,7 +590,7 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier st primaryKeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns)) for columnName, genericColumnType := range normalizedTableSchema.Columns { pgType := qValueKindToPostgresType(genericColumnType) - if strings.Contains(genericColumnType, "array") { + if qvalue.QValueKind(genericColumnType).IsArray() { flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", strings.Trim(columnName, "\""), pgType, columnName)) diff --git a/flow/model/qvalue/kind.go b/flow/model/qvalue/kind.go index 4b728f3dad..24ea597ad2 100644 --- a/flow/model/qvalue/kind.go +++ b/flow/model/qvalue/kind.go @@ -1,6 +1,9 @@ package qvalue -import "fmt" +import ( + "fmt" + "strings" +) type QValueKind string @@ -38,17 +41,8 @@ const ( QValueKindArrayString QValueKind = "array_string" ) -func QValueKindIsArray(kind QValueKind) bool { - switch kind { - case QValueKindArrayFloat32, - QValueKindArrayFloat64, - QValueKindArrayInt32, - QValueKindArrayInt64, - QValueKindArrayString: - return true - default: - return false - } +func (kind QValueKind) IsArray() bool { + return strings.HasPrefix(string(kind), "array_") } var QValueKindToSnowflakeTypeMap = map[QValueKind]string{