diff --git a/flow/connectors/snowflake/client.go b/flow/connectors/snowflake/client.go index 14bba3d561..285b4882b8 100644 --- a/flow/connectors/snowflake/client.go +++ b/flow/connectors/snowflake/client.go @@ -10,6 +10,7 @@ import ( peersql "github.com/PeerDB-io/peer-flow/connectors/sql" "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model/qvalue" util "github.com/PeerDB-io/peer-flow/utils" ) @@ -55,7 +56,7 @@ func NewSnowflakeClient(ctx context.Context, config *protos.SnowflakeConfig) (*S } genericExecutor := *peersql.NewGenericSQLQueryExecutor( - ctx, database, snowflakeTypeToQValueKindMap, qValueKindToSnowflakeTypeMap) + ctx, database, snowflakeTypeToQValueKindMap, qvalue.QValueKindToSnowflakeTypeMap) return &SnowflakeClient{ GenericSQLQueryExecutor: genericExecutor, diff --git a/flow/connectors/snowflake/qvalue_convert.go b/flow/connectors/snowflake/qvalue_convert.go index b88517856d..421281834c 100644 --- a/flow/connectors/snowflake/qvalue_convert.go +++ b/flow/connectors/snowflake/qvalue_convert.go @@ -6,39 +6,6 @@ import ( "github.com/PeerDB-io/peer-flow/model/qvalue" ) -var qValueKindToSnowflakeTypeMap = map[qvalue.QValueKind]string{ - qvalue.QValueKindBoolean: "BOOLEAN", - qvalue.QValueKindInt16: "INTEGER", - qvalue.QValueKindInt32: "INTEGER", - qvalue.QValueKindInt64: "INTEGER", - qvalue.QValueKindFloat32: "FLOAT", - qvalue.QValueKindFloat64: "FLOAT", - qvalue.QValueKindNumeric: "NUMBER(38, 9)", - qvalue.QValueKindString: "STRING", - qvalue.QValueKindJSON: "VARIANT", - qvalue.QValueKindTimestamp: "TIMESTAMP_NTZ", - qvalue.QValueKindTimestampTZ: "TIMESTAMP_TZ", - qvalue.QValueKindTime: "TIME", - qvalue.QValueKindDate: "DATE", - qvalue.QValueKindBit: "BINARY", - qvalue.QValueKindBytes: "BINARY", - qvalue.QValueKindStruct: "STRING", - qvalue.QValueKindUUID: "STRING", - qvalue.QValueKindTimeTZ: "STRING", - qvalue.QValueKindInvalid: "STRING", - qvalue.QValueKindHStore: "STRING", - qvalue.QValueKindGeography: "GEOGRAPHY", - qvalue.QValueKindGeometry: "GEOMETRY", - qvalue.QValueKindPoint: "GEOMETRY", - - // array types will be mapped to VARIANT - qvalue.QValueKindArrayFloat32: "VARIANT", - qvalue.QValueKindArrayFloat64: "VARIANT", - qvalue.QValueKindArrayInt32: "VARIANT", - qvalue.QValueKindArrayInt64: "VARIANT", - qvalue.QValueKindArrayString: "VARIANT", -} - var snowflakeTypeToQValueKindMap = map[string]qvalue.QValueKind{ "INT": qvalue.QValueKindInt32, "BIGINT": qvalue.QValueKindInt64, @@ -67,11 +34,13 @@ var snowflakeTypeToQValueKindMap = map[string]qvalue.QValueKind{ "GEOGRAPHY": qvalue.QValueKindGeography, } -func qValueKindToSnowflakeType(colType qvalue.QValueKind) string { - if val, ok := qValueKindToSnowflakeTypeMap[colType]; ok { - return val +func qValueKindToSnowflakeType(colType qvalue.QValueKind) (string, error) { + val, err := colType.ToDWHColumnType(qvalue.QDWHTypeSnowflake) + if err != nil { + return "", err } - return "STRING" + + return val, err } func snowflakeTypeToQValueKind(name string) (qvalue.QValueKind, error) { diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 3d89d5b5f6..29dc3eaf47 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -451,9 +451,13 @@ func (c *SnowflakeConnector) ReplayTableSchemaDeltas(flowJobName string, } for _, addedColumn := range schemaDelta.AddedColumns { + sfColtype, err := qValueKindToSnowflakeType(qvalue.QValueKind(addedColumn.ColumnType)) + if err != nil { + return fmt.Errorf("failed to convert column type %s to snowflake type: %w", + addedColumn.ColumnType, err) + } _, err = tableSchemaModifyTx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN \"%s\" %s", - schemaDelta.DstTableName, strings.ToUpper(addedColumn.ColumnName), - qValueKindToSnowflakeType(qvalue.QValueKind(addedColumn.ColumnType)))) + schemaDelta.DstTableName, strings.ToUpper(addedColumn.ColumnName), sfColtype)) if err != nil { return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName, schemaDelta.DstTableName, err) @@ -876,8 +880,12 @@ func generateCreateTableSQLForNormalizedTable( createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns)) for columnName, genericColumnType := range sourceTableSchema.Columns { columnNameUpper := strings.ToUpper(columnName) - createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`"%s" %s,`, columnNameUpper, - qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType)))) + sfColType, err := qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType)) + if err != nil { + log.Warnf("failed to convert column type %s to snowflake type: %v", genericColumnType, err) + continue + } + createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`"%s" %s,`, columnNameUpper, sfColType)) } // add a _peerdb_is_deleted column to the normalized table @@ -942,7 +950,12 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement( flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns)) for columnName, genericColumnType := range normalizedTableSchema.Columns { - sfType := qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType)) + sfType, err := qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType)) + if err != nil { + return 0, fmt.Errorf("failed to convert column type %s to snowflake type: %w", + genericColumnType, err) + } + targetColumnName := fmt.Sprintf(`"%s"`, strings.ToUpper(columnName)) switch qvalue.QValueKind(genericColumnType) { case qvalue.QValueKindBytes, qvalue.QValueKindBit: diff --git a/flow/model/qvalue/kind.go b/flow/model/qvalue/kind.go index 51e2ee2d68..b2eb1b1e48 100644 --- a/flow/model/qvalue/kind.go +++ b/flow/model/qvalue/kind.go @@ -50,41 +50,47 @@ func QValueKindIsArray(kind QValueKind) bool { } } +var QValueKindToSnowflakeTypeMap = map[QValueKind]string{ + QValueKindBoolean: "BOOLEAN", + QValueKindInt16: "INTEGER", + QValueKindInt32: "INTEGER", + QValueKindInt64: "INTEGER", + QValueKindFloat32: "FLOAT", + QValueKindFloat64: "FLOAT", + QValueKindNumeric: "NUMBER(38, 9)", + QValueKindString: "STRING", + QValueKindJSON: "VARIANT", + QValueKindTimestamp: "TIMESTAMP_NTZ", + QValueKindTimestampTZ: "TIMESTAMP_TZ", + QValueKindTime: "TIME", + QValueKindDate: "DATE", + QValueKindBit: "BINARY", + QValueKindBytes: "BINARY", + QValueKindStruct: "STRING", + QValueKindUUID: "STRING", + QValueKindTimeTZ: "STRING", + QValueKindInvalid: "STRING", + QValueKindHStore: "STRING", + QValueKindGeography: "GEOGRAPHY", + QValueKindGeometry: "GEOMETRY", + QValueKindPoint: "GEOMETRY", + + // array types will be mapped to VARIANT + QValueKindArrayFloat32: "VARIANT", + QValueKindArrayFloat64: "VARIANT", + QValueKindArrayInt32: "VARIANT", + QValueKindArrayInt64: "VARIANT", + QValueKindArrayString: "VARIANT", +} + func (kind QValueKind) ToDWHColumnType(dwhType QDWHType) (string, error) { if dwhType != QDWHTypeSnowflake { return "", fmt.Errorf("unsupported DWH type: %v", dwhType) } - switch kind { - case QValueKindFloat32, QValueKindFloat64: - return "FLOAT", nil - case QValueKindInt16, QValueKindInt32, QValueKindInt64: - return "INTEGER", nil - case QValueKindBoolean: - return "BOOLEAN", nil - case QValueKindString: - return "VARCHAR", nil - case QValueKindTimestamp, QValueKindTimestampTZ: - return "TIMESTAMP_NTZ", nil // or TIMESTAMP_TZ based on your needs - case QValueKindDate: - return "DATE", nil - case QValueKindTime, QValueKindTimeTZ: - return "TIME", nil - case QValueKindNumeric: - return "NUMBER", nil - case QValueKindBytes: - return "BINARY", nil - case QValueKindUUID: - return "VARCHAR", nil // Snowflake doesn't have a native UUID type - case QValueKindJSON: - return "VARIANT", nil - case QValueKindArrayFloat32, - QValueKindArrayFloat64, - QValueKindArrayInt32, - QValueKindArrayInt64, - QValueKindArrayString: - return "ARRAY", nil - default: - return "", fmt.Errorf("unsupported QValueKind: %v for dwhtype: %v", kind, dwhType) + if val, ok := QValueKindToSnowflakeTypeMap[kind]; ok { + return val, nil + } else { + return "STRING", nil } }