diff --git a/flow/connectors/bigquery/qrep.go b/flow/connectors/bigquery/qrep.go
index 3da50c8e8f..c62489b847 100644
--- a/flow/connectors/bigquery/qrep.go
+++ b/flow/connectors/bigquery/qrep.go
@@ -8,7 +8,6 @@ import (
     "cloud.google.com/go/bigquery"
 
-    "github.com/PeerDB-io/peer-flow/datatypes"
     "github.com/PeerDB-io/peer-flow/generated/protos"
     "github.com/PeerDB-io/peer-flow/model"
     "github.com/PeerDB-io/peer-flow/model/qvalue"
@@ -39,6 +38,7 @@ func (c *BigQueryConnector) SyncQRepRecords(
         tblMetadata, stream, config.SyncedAtColName, config.SoftDeleteColName)
 }
 
+// TODO: consider removing this codepath entirely
 func (c *BigQueryConnector) replayTableSchemaDeltasQRep(
     ctx context.Context,
     config *protos.QRepConfig,
@@ -74,7 +74,7 @@ func (c *BigQueryConnector) replayTableSchemaDeltasQRep(
             tableSchemaDelta.AddedColumns = append(tableSchemaDelta.AddedColumns, &protos.FieldDescription{
                 Name:         col.Name,
                 Type:         string(col.Type),
-                TypeModifier: datatypes.MakeNumericTypmod(int32(col.Precision), int32(col.Scale)),
+                TypeModifier: col.ParsedNumericTypmod.ToTypmod(),
             },
             )
         }
diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go
index f1f8d812c6..d5a3d65183 100644
--- a/flow/connectors/bigquery/qrep_avro_sync.go
+++ b/flow/connectors/bigquery/qrep_avro_sync.go
@@ -13,6 +13,7 @@ import (
     "cloud.google.com/go/bigquery"
 
     avro "github.com/PeerDB-io/peer-flow/connectors/utils/avro"
+    "github.com/PeerDB-io/peer-flow/datatypes"
     "github.com/PeerDB-io/peer-flow/generated/protos"
     "github.com/PeerDB-io/peer-flow/model"
     "github.com/PeerDB-io/peer-flow/model/qvalue"
@@ -268,9 +269,6 @@ func DefineAvroSchema(dstTableName string,
 }
 
 func GetAvroType(bqField *bigquery.FieldSchema) (interface{}, error) {
-    avroNumericPrecision, avroNumericScale := qvalue.DetermineNumericSettingForDWH(
-        int16(bqField.Precision), int16(bqField.Scale), protos.DBType_BIGQUERY)
-
     considerRepeated := func(typ string, repeated bool) interface{} {
         if repeated {
             return qvalue.AvroSchemaArray{
@@ -341,6 +339,8 @@ func GetAvroType(bqField *bigquery.FieldSchema) (interface{}, error) {
             },
         }, nil
     case bigquery.BigNumericFieldType:
+        avroNumericPrecision, avroNumericScale := datatypes.NewConstrainedNumericTypmod(int16(bqField.Precision),
+            int16(bqField.Scale)).ToDWHNumericConstraints(protos.DBType_BIGQUERY)
         return qvalue.AvroSchemaNumeric{
             Type:        "bytes",
             LogicalType: "decimal",
diff --git a/flow/connectors/bigquery/qvalue_convert.go b/flow/connectors/bigquery/qvalue_convert.go
index d2d9d9f0c2..d5ef405e3f 100644
--- a/flow/connectors/bigquery/qvalue_convert.go
+++ b/flow/connectors/bigquery/qvalue_convert.go
@@ -26,7 +26,8 @@ func qValueKindToBigQueryType(columnDescription *protos.FieldDescription, nullab
     case qvalue.QValueKindFloat32, qvalue.QValueKindFloat64:
         bqField.Type = bigquery.FloatFieldType
     case qvalue.QValueKindNumeric:
-        precision, scale := datatypes.GetNumericTypeForWarehouse(columnDescription.TypeModifier, datatypes.BigQueryNumericCompatibility{})
+        precision, scale := datatypes.NewParsedNumericTypmod(columnDescription.TypeModifier).
+            ToDWHNumericConstraints(protos.DBType_BIGQUERY)
         bqField.Type = bigquery.BigNumericFieldType
         bqField.Precision = int64(precision)
         bqField.Scale = int64(scale)
@@ -150,11 +151,15 @@ func qValueKindToBigQueryTypeString(columnDescription *protos.FieldDescription,
 }
 
 func BigQueryFieldToQField(bqField *bigquery.FieldSchema) qvalue.QField {
+    var parsedNumericTypmod *datatypes.NumericTypmod
+    if BigQueryTypeToQValueKind(bqField) == qvalue.QValueKindNumeric {
+        parsedNumericTypmod = datatypes.NewConstrainedNumericTypmod(int16(bqField.Precision),
+            int16(bqField.Scale))
+    }
     return qvalue.QField{
-        Name:      bqField.Name,
-        Type:      BigQueryTypeToQValueKind(bqField),
-        Precision: int16(bqField.Precision),
-        Scale:     int16(bqField.Scale),
-        Nullable:  !bqField.Required,
+        Name:                bqField.Name,
+        Type:                BigQueryTypeToQValueKind(bqField),
+        Nullable:            !bqField.Required,
+        ParsedNumericTypmod: parsedNumericTypmod,
     }
 }
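Taken together, the BigQuery changes swap loose (precision, scale) pairs for a NumericTypmod that can be serialized back into the Postgres-style type modifier carried on FieldDescription. A minimal sketch of that round trip, assuming only the APIs introduced in this diff (the NUMERIC(38, 9) values are illustrative):

package main

import (
	"fmt"

	"github.com/PeerDB-io/peer-flow/datatypes"
	"github.com/PeerDB-io/peer-flow/generated/protos"
)

func main() {
	// As in BigQueryFieldToQField: wrap a destination column's metadata.
	typ := datatypes.NewConstrainedNumericTypmod(38, 9)

	// As in replayTableSchemaDeltasQRep: serialize it back into a
	// Postgres-style typmod, here ((38 << 16) | 9) + 4 = 2490381.
	fmt.Println(typ.ToTypmod())

	// And clamp it to BigQuery's BIGNUMERIC limits when emitting DDL.
	precision, scale := typ.ToDWHNumericConstraints(protos.DBType_BIGQUERY)
	fmt.Println(precision, scale) // 38 9
}

Note that NewConstrainedNumericTypmod trusts its arguments, which is why BigQueryFieldToQField only constructs one when the field is actually numeric.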
diff --git a/flow/connectors/clickhouse/normalize.go b/flow/connectors/clickhouse/normalize.go
index 735660b3a4..6e3bf69579 100644
--- a/flow/connectors/clickhouse/normalize.go
+++ b/flow/connectors/clickhouse/normalize.go
@@ -134,7 +134,8 @@ func generateCreateTableSQLForNormalizedTable(
         }
 
         if colType == qvalue.QValueKindNumeric {
-            precision, scale := datatypes.GetNumericTypeForWarehouse(column.TypeModifier, datatypes.ClickHouseNumericCompatibility{})
+            precision, scale := datatypes.NewParsedNumericTypmod(column.TypeModifier).
+                ToDWHNumericConstraints(protos.DBType_CLICKHOUSE)
             if column.Nullable {
                 stmtBuilder.WriteString(fmt.Sprintf("`%s` Nullable(DECIMAL(%d, %d)), ", dstColName, precision, scale))
             } else {
@@ -301,7 +302,8 @@ func (c *ClickhouseConnector) NormalizeRecords(
             colSelector.WriteString(fmt.Sprintf("`%s`,", dstColName))
             if clickhouseType == "" {
                 if colType == qvalue.QValueKindNumeric {
-                    precision, scale := datatypes.GetNumericTypeForWarehouse(column.TypeModifier, datatypes.ClickHouseNumericCompatibility{})
+                    precision, scale := datatypes.NewParsedNumericTypmod(column.TypeModifier).
+                        ToDWHNumericConstraints(protos.DBType_CLICKHOUSE)
                     clickhouseType = fmt.Sprintf("Decimal(%d, %d)", precision, scale)
                 } else {
                     var err error
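Here, too, the ClickHouseNumericCompatibility plumbing collapses into a single ToDWHNumericConstraints call. A short sketch of the two cases that matter for the generated DDL, assuming the limits defined in flow/datatypes/numeric.go later in this diff:

package main

import (
	"fmt"

	"github.com/PeerDB-io/peer-flow/datatypes"
	"github.com/PeerDB-io/peer-flow/generated/protos"
)

func main() {
	// An unconstrained Postgres numeric (typmod -1) falls back to
	// ClickHouse's maximum precision and default scale.
	p, s := datatypes.NewParsedNumericTypmod(-1).
		ToDWHNumericConstraints(protos.DBType_CLICKHOUSE)
	fmt.Printf("Decimal(%d, %d)\n", p, s) // Decimal(76, 38)

	// A constrained numeric(10,2) is within bounds and passes through.
	p, s = datatypes.NewParsedNumericTypmod(((10 << 16) | 2) + 4).
		ToDWHNumericConstraints(protos.DBType_CLICKHOUSE)
	fmt.Printf("Decimal(%d, %d)\n", p, s) // Decimal(10, 2)
}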
diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go
index 4570b46a46..c1fa179d8b 100644
--- a/flow/connectors/postgres/client.go
+++ b/flow/connectors/postgres/client.go
@@ -15,7 +15,7 @@ import (
     "github.com/lib/pq/oid"
 
     "github.com/PeerDB-io/peer-flow/connectors/utils"
-    numeric "github.com/PeerDB-io/peer-flow/datatypes"
+    "github.com/PeerDB-io/peer-flow/datatypes"
     "github.com/PeerDB-io/peer-flow/generated/protos"
     "github.com/PeerDB-io/peer-flow/model"
     "github.com/PeerDB-io/peer-flow/shared"
@@ -463,7 +463,7 @@ func generateCreateTableSQLForNormalizedTable(
             pgColumnType = qValueKindToPostgresType(pgColumnType)
         }
         if column.Type == "numeric" && column.TypeModifier != -1 {
-            precision, scale := numeric.ParseNumericTypmod(column.TypeModifier)
+            precision, scale := datatypes.NewParsedNumericTypmod(column.TypeModifier).PrecisionAndScale()
             pgColumnType = fmt.Sprintf("numeric(%d,%d)", precision, scale)
         }
         var notNull string
diff --git a/flow/connectors/postgres/qrep_query_executor.go b/flow/connectors/postgres/qrep_query_executor.go
index 2fa6ecd7fe..f9d06104d2 100644
--- a/flow/connectors/postgres/qrep_query_executor.go
+++ b/flow/connectors/postgres/qrep_query_executor.go
@@ -78,13 +78,11 @@ func (qe *QRepQueryExecutor) fieldDescriptionsToSchema(fds []pgconn.FieldDescrip
         // TODO fix this.
         cnullable := true
         if ctype == qvalue.QValueKindNumeric {
-            precision, scale := datatypes.ParseNumericTypmod(fd.TypeModifier)
             qfields[i] = qvalue.QField{
-                Name:      cname,
-                Type:      ctype,
-                Nullable:  cnullable,
-                Precision: precision,
-                Scale:     scale,
+                Name:                cname,
+                Type:                ctype,
+                Nullable:            cnullable,
+                ParsedNumericTypmod: datatypes.NewParsedNumericTypmod(fd.TypeModifier),
             }
         } else {
             qfields[i] = qvalue.QField{
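The Postgres-to-Postgres path is the one place that wants the typmod verbatim rather than clamped, hence PrecisionAndScale. A sketch with illustrative values; note that generateCreateTableSQLForNormalizedTable still guards on TypeModifier != -1, because an unconstrained typmod parses to (0, 0):

package main

import (
	"fmt"

	"github.com/PeerDB-io/peer-flow/datatypes"
)

func main() {
	// 655366 == ((10 << 16) | 2) + 4, the typmod Postgres stores for numeric(10,2).
	precision, scale := datatypes.NewParsedNumericTypmod(655366).PrecisionAndScale()
	fmt.Printf("numeric(%d,%d)\n", precision, scale) // numeric(10,2)

	// A bare "numeric" column has typmod -1 and parses as unconstrained,
	// so PrecisionAndScale reports (0, 0) and must not be used for DDL.
	precision, scale = datatypes.NewParsedNumericTypmod(-1).PrecisionAndScale()
	fmt.Println(precision, scale) // 0 0
}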
diff --git a/flow/connectors/snowflake/get_schema_for_tests.go b/flow/connectors/snowflake/get_schema_for_tests.go
index c7875f173f..48f1bbeef0 100644
--- a/flow/connectors/snowflake/get_schema_for_tests.go
+++ b/flow/connectors/snowflake/get_schema_for_tests.go
@@ -22,11 +22,16 @@ func (c *SnowflakeConnector) getTableSchemaForTable(ctx context.Context, tableNa
             genericColType = qvalue.QValueKindString
         }
 
-        colFields = append(colFields, &protos.FieldDescription{
+        colField := &protos.FieldDescription{
             Name:         columns[i].ColumnName,
             Type:         string(genericColType),
-            TypeModifier: datatypes.MakeNumericTypmod(sfColumn.NumericPrecision, sfColumn.NumericScale),
-        })
+            TypeModifier: -1,
+        }
+        if genericColType == qvalue.QValueKindNumeric {
+            colField.TypeModifier = datatypes.NewConstrainedNumericTypmod(int16(sfColumn.NumericPrecision),
+                int16(sfColumn.NumericScale)).ToTypmod()
+        }
+        colFields = append(colFields, colField)
     }
 
     return &protos.TableSchema{
diff --git a/flow/connectors/snowflake/merge_stmt_generator.go b/flow/connectors/snowflake/merge_stmt_generator.go
index 3f0cfbc63a..37c92ca216 100644
--- a/flow/connectors/snowflake/merge_stmt_generator.go
+++ b/flow/connectors/snowflake/merge_stmt_generator.go
@@ -5,7 +5,7 @@ import (
     "strings"
 
     "github.com/PeerDB-io/peer-flow/connectors/utils"
-    numeric "github.com/PeerDB-io/peer-flow/datatypes"
+    "github.com/PeerDB-io/peer-flow/datatypes"
     "github.com/PeerDB-io/peer-flow/generated/protos"
     "github.com/PeerDB-io/peer-flow/model/qvalue"
     "github.com/PeerDB-io/peer-flow/shared"
@@ -62,7 +62,8 @@ func (m *mergeStmtGenerator) generateMergeStmt(dstTable string) (string, error)
         // "Microseconds*1000) "+
         // "AS %s", toVariantColumnName, columnName, columnName))
         case qvalue.QValueKindNumeric:
-            precision, scale := numeric.GetNumericTypeForWarehouse(column.TypeModifier, numeric.SnowflakeNumericCompatibility{})
+            precision, scale := datatypes.NewParsedNumericTypmod(column.TypeModifier).
+                ToDWHNumericConstraints(protos.DBType_SNOWFLAKE)
             numericType := fmt.Sprintf("NUMERIC(%d,%d)", precision, scale)
             flattenedCastsSQLArray = append(flattenedCastsSQLArray,
                 fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s",
diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go
index dcc692a574..0aa55e5804 100644
--- a/flow/connectors/snowflake/snowflake.go
+++ b/flow/connectors/snowflake/snowflake.go
@@ -19,7 +19,7 @@ import (
     metadataStore "github.com/PeerDB-io/peer-flow/connectors/external_metadata"
     "github.com/PeerDB-io/peer-flow/connectors/utils"
-    numeric "github.com/PeerDB-io/peer-flow/datatypes"
+    "github.com/PeerDB-io/peer-flow/datatypes"
     "github.com/PeerDB-io/peer-flow/generated/protos"
     "github.com/PeerDB-io/peer-flow/logger"
     "github.com/PeerDB-io/peer-flow/model"
@@ -383,7 +383,8 @@ func (c *SnowflakeConnector) ReplayTableSchemaDeltas(
         }
 
         if addedColumn.Type == string(qvalue.QValueKindNumeric) {
-            precision, scale := numeric.GetNumericTypeForWarehouse(addedColumn.TypeModifier, numeric.SnowflakeNumericCompatibility{})
+            precision, scale := datatypes.NewParsedNumericTypmod(addedColumn.TypeModifier).
+                ToDWHNumericConstraints(protos.DBType_SNOWFLAKE)
             sfColtype = fmt.Sprintf("NUMERIC(%d,%d)", precision, scale)
         }
@@ -686,7 +687,8 @@ func generateCreateTableSQLForNormalizedTable(
         }
 
         if genericColumnType == "numeric" {
-            precision, scale := numeric.GetNumericTypeForWarehouse(column.TypeModifier, numeric.SnowflakeNumericCompatibility{})
+            precision, scale := datatypes.NewParsedNumericTypmod(column.TypeModifier).
+                ToDWHNumericConstraints(protos.DBType_SNOWFLAKE)
             sfColType = fmt.Sprintf("NUMERIC(%d,%d)", precision, scale)
         }
diff --git a/flow/datatypes/numeric.go b/flow/datatypes/numeric.go
index cd7791ebec..b33725686e 100644
--- a/flow/datatypes/numeric.go
+++ b/flow/datatypes/numeric.go
@@ -1,111 +1,163 @@
 package datatypes
 
-const (
-    // defaults
-    PeerDBBigQueryScale   = 20
-    PeerDBSnowflakeScale  = 20
-    PeerDBClickhouseScale = 38
-    VARHDRSZ              = 4
-)
-
-type WarehouseNumericCompatibility interface {
-    MaxPrecision() int16
-    MaxScale() int16
-    DefaultPrecisionAndScale() (int16, int16)
-}
+import (
+    "cloud.google.com/go/bigquery"
 
-type ClickHouseNumericCompatibility struct{}
+    "github.com/PeerDB-io/peer-flow/generated/protos"
+)
 
-func (ClickHouseNumericCompatibility) MaxPrecision() int16 {
-    return 76
-}
+const (
+    VARHDRSZ = 4
+
+    // default scale
+    bigQueryDefaultScale   = bigquery.BigNumericScaleDigits
+    snowflakeDefaultScale  = 20
+    clickHouseDefaultScale = 38
+    genericDefaultScale    = 20
+
+    // max scale
+    bigQueryMaxScale   = bigquery.BigNumericScaleDigits
+    snowflakeMaxScale  = 37
+    clickHouseMaxScale = 38
+    genericMaxScale    = 37
+
+    // default/max precision
+    bigQueryPrecision   = bigquery.BigNumericPrecisionDigits
+    snowflakePrecision  = 38
+    clickHousePrecision = 76
+    genericPrecision    = 38
+)
 
-func (ClickHouseNumericCompatibility) MaxScale() int16 {
-    return 38
+var defaultScaleMap = map[protos.DBType]int16{
+    protos.DBType_BIGQUERY:   bigQueryDefaultScale,
+    protos.DBType_SNOWFLAKE:  snowflakeDefaultScale,
+    protos.DBType_CLICKHOUSE: clickHouseDefaultScale,
 }
 
-func (c ClickHouseNumericCompatibility) DefaultPrecisionAndScale() (int16, int16) {
-    return c.MaxPrecision(), PeerDBClickhouseScale
+var maxScaleMap = map[protos.DBType]int16{
+    protos.DBType_BIGQUERY:   bigQueryMaxScale,
+    protos.DBType_SNOWFLAKE:  snowflakeMaxScale,
+    protos.DBType_CLICKHOUSE: clickHouseMaxScale,
 }
 
-type SnowflakeNumericCompatibility struct{}
-
-func (SnowflakeNumericCompatibility) MaxPrecision() int16 {
-    return 38
+var precisionMap = map[protos.DBType]int16{
+    protos.DBType_BIGQUERY:   bigQueryPrecision,
+    protos.DBType_SNOWFLAKE:  snowflakePrecision,
+    protos.DBType_CLICKHOUSE: clickHousePrecision,
 }
 
-func (SnowflakeNumericCompatibility) MaxScale() int16 {
-    return 37
+func getMaxPrecisionForDWH(dwh protos.DBType) int16 {
+    precision, ok := precisionMap[dwh]
+    if !ok {
+        return genericPrecision
+    }
+    return precision
 }
 
-func (s SnowflakeNumericCompatibility) DefaultPrecisionAndScale() (int16, int16) {
-    return s.MaxPrecision(), PeerDBSnowflakeScale
+func getMaxScaleForDWH(dwh protos.DBType) int16 {
+    scale, ok := maxScaleMap[dwh]
+    if !ok {
+        return genericMaxScale
+    }
+    return scale
 }
 
-type BigQueryNumericCompatibility struct{}
-
-func (BigQueryNumericCompatibility) MaxPrecision() int16 {
-    return 38
+func getDefaultPrecisionAndScaleForDWH(dwh protos.DBType) (int16, int16) {
+    defaultScale, ok := defaultScaleMap[dwh]
+    if !ok {
+        return getMaxPrecisionForDWH(dwh), genericDefaultScale
+    }
+    return getMaxPrecisionForDWH(dwh), defaultScale
 }
 
-func (BigQueryNumericCompatibility) MaxScale() int16 {
-    return 20
+func isValidPrecision(precision int16, dwh protos.DBType) bool {
+    return precision > 0 && precision <= getMaxPrecisionForDWH(dwh)
 }
 
-func (b BigQueryNumericCompatibility) DefaultPrecisionAndScale() (int16, int16) {
-    return b.MaxPrecision(), PeerDBBigQueryScale
+func isValidScale(precision int16, scale int16, dwh protos.DBType) bool {
+    return scale >= 0 &&
+        isValidPrecision(precision, dwh) &&
+        scale <= getMaxScaleForDWH(dwh)
 }
 
-type DefaultNumericCompatibility struct{}
+/*
+As far as I understand from the Postgres source code:
+ 1. The typmod itself is a 32-bit integer.
+ 2. In the case of NUMERICs, it is either -1 or greater than VARHDRSZ.
+ 3. If it is -1, the precision and scale are not specified and it is an "unconstrained" NUMERIC.
+ 4. If it is greater than VARHDRSZ, the precision is specified and the scale MAY be specified; if it is not, the scale defaults to 0.
+ 5. This is a "constrained" NUMERIC. Precision in this case ranges only from 1 to 1000, far less than the unconstrained limit.
+ 6. The scale in this case ranges from -1000 to 1000. Yes, it can be negative. Yes, it can be greater than the precision.
 
-func (DefaultNumericCompatibility) MaxPrecision() int16 {
-    return 38
+Currently, no DWH supports the two weird cases of Postgres NUMERIC scales; the expected case is 0 <= scale < precision.
+When we hit those cases, we fall back to the DWH's maximum precision and default scale.
+*/
+type NumericTypmod struct {
+    constrained bool
+    precision   int16
+    scale       int16
+}
 
-func (DefaultNumericCompatibility) MaxScale() int16 {
-    return 37
-}
-
-func (DefaultNumericCompatibility) DefaultPrecisionAndScale() (int16, int16) {
-    return 38, 20
-}
+// This reverses what Postgres's make_numeric_typmod does; logic copied from:
+// https://github.com/postgres/postgres/blob/c4d5cb71d229095a39fda1121a75ee40e6069a2a/src/backend/utils/adt/numeric.c#L929
+// Maps most "invalid" typmods to unconstrained (same as -1).
+func NewParsedNumericTypmod(typmod int32) *NumericTypmod {
+    if typmod < VARHDRSZ {
+        return &NumericTypmod{
+            constrained: false,
+        }
+    }
 
-func IsValidPrecision(precision int16, warehouseNumeric WarehouseNumericCompatibility) bool {
-    return precision <= warehouseNumeric.MaxPrecision()
+    typmod -= VARHDRSZ
+    // if precision or scale are out of bounds, switch to unconstrained and hope for the best
+    precision := int16((typmod >> 16) & 0xFFFF)
+    scale := int16(((typmod & 0x7ff) ^ 1024) - 1024)
+    if precision < 1 || precision > 1000 || scale < -1000 || scale > 1000 {
+        return &NumericTypmod{
+            constrained: false,
+        }
+    }
+    return &NumericTypmod{
+        constrained: true,
+        precision:   precision,
+        scale:       scale,
+    }
 }
 
-func IsValidPrecisionAndScale(precision, scale int16, warehouseNumeric WarehouseNumericCompatibility) bool {
-    return IsValidPrecision(precision, warehouseNumeric) && scale <= warehouseNumeric.MaxScale()
+// it is the caller's responsibility to pass in sensible values
+func NewConstrainedNumericTypmod(precision int16, scale int16) *NumericTypmod {
+    return &NumericTypmod{
+        constrained: true,
+        precision:   precision,
+        scale:       scale,
+    }
 }
 
-func MakeNumericTypmod(precision int32, scale int32) int32 {
-    if precision == 0 && scale == 0 {
+func (t *NumericTypmod) ToTypmod() int32 {
+    if t == nil || !t.constrained {
         return -1
     }
-    return (precision << 16) | (scale & 0x7ff) + VARHDRSZ
+    return ((int32(t.precision) << 16) | (int32(t.scale) & 0x7ff)) + VARHDRSZ
 }
 
-// This is to reverse what make_numeric_typmod of Postgres does:
-// https://github.com/postgres/postgres/blob/21912e3c0262e2cfe64856e028799d6927862563/src/backend/utils/adt/numeric.c#L897
-func ParseNumericTypmod(typmod int32) (int16, int16) {
-    offsetMod := typmod - VARHDRSZ
-    precision := int16((offsetMod >> 16) & 0x7FFF)
-    scale := int16(offsetMod & 0x7FFF)
-    return precision, scale
+func (t *NumericTypmod) PrecisionAndScale() (int16, int16) {
+    if t == nil {
+        return 0, 0
+    }
+    return t.precision, t.scale
 }
 
-func GetNumericTypeForWarehouse(typmod int32, warehouseNumeric WarehouseNumericCompatibility) (int16, int16) {
-    if typmod == -1 {
-        return warehouseNumeric.DefaultPrecisionAndScale()
+func (t *NumericTypmod) ToDWHNumericConstraints(dwh protos.DBType) (int16, int16) {
+    if t == nil || !t.constrained {
+        return getDefaultPrecisionAndScaleForDWH(dwh)
     }
 
-    precision, scale := ParseNumericTypmod(typmod)
-    if !IsValidPrecision(precision, warehouseNumeric) {
-        precision = warehouseNumeric.MaxPrecision()
+    precision, scale := t.precision, t.scale
+    if !isValidPrecision(t.precision, dwh) {
+        precision = getMaxPrecisionForDWH(dwh)
     }
-
-    if !IsValidPrecisionAndScale(precision, scale, warehouseNumeric) {
-        precision, scale = warehouseNumeric.DefaultPrecisionAndScale()
+    if !isValidScale(t.precision, t.scale, dwh) {
+        precision, scale = getDefaultPrecisionAndScaleForDWH(dwh)
     }
 
     return precision, scale
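To make the bit layout concrete: the precision sits in the upper 16 bits, the scale is sign-extended out of the low 11 bits via (x ^ 1024) - 1024, and the whole value is offset by VARHDRSZ. A standalone worked example that mirrors ToTypmod and NewParsedNumericTypmod (the unexported logic is re-derived locally, since the struct fields are not exported):

package main

import "fmt"

const varhdrsz = 4 // mirrors datatypes.VARHDRSZ

// encode mirrors ToTypmod for a constrained numeric.
func encode(precision, scale int32) int32 {
	return ((precision << 16) | (scale & 0x7ff)) + varhdrsz
}

// decode mirrors NewParsedNumericTypmod for an in-range typmod.
func decode(typmod int32) (int16, int16) {
	typmod -= varhdrsz
	precision := int16((typmod >> 16) & 0xFFFF)
	// (x ^ 1024) - 1024 sign-extends the 11-bit field: 1048 becomes -1000
	scale := int16(((typmod & 0x7ff) ^ 1024) - 1024)
	return precision, scale
}

func main() {
	// numeric(10,2): ((10 << 16) | 2) + 4 = 655366
	fmt.Println(encode(10, 2)) // 655366
	p, s := decode(655366)
	fmt.Println(p, s) // 10 2

	// numeric(1,-1000): -1000 & 0x7ff = 1048, so ((1 << 16) | 1048) + 4 = 66588,
	// matching the 66588 test vector in numeric_test.go below.
	fmt.Println(encode(1, -1000)) // 66588
	p, s = decode(66588)
	fmt.Println(p, s) // 1 -1000
}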
diff --git a/flow/datatypes/numeric_test.go b/flow/datatypes/numeric_test.go
new file mode 100644
index 0000000000..ac9715d47d
--- /dev/null
+++ b/flow/datatypes/numeric_test.go
@@ -0,0 +1,59 @@
+package datatypes
+
+import (
+    "testing"
+
+    "github.com/stretchr/testify/assert"
+
+    "github.com/PeerDB-io/peer-flow/generated/protos"
+)
+
+var parsingTypmodTests = map[int32]NumericTypmod{
+    -1:       {constrained: false},
+    0:        {constrained: false},
+    3:        {constrained: false},
+    65540:    {constrained: true, precision: 1, scale: 0},
+    65541:    {constrained: true, precision: 1, scale: 1},
+    66588:    {constrained: true, precision: 1, scale: -1000},
+    65536004: {constrained: true, precision: 1000, scale: 0},
+    65537004: {constrained: true, precision: 1000, scale: 1000},
+    65537052: {constrained: true, precision: 1000, scale: -1000},
+    65538051: {constrained: true, precision: 1000, scale: -1},
+    // precision 1001 onwards should trigger unconstrained
+    65601540: {constrained: false},
+}
+
+func TestNewParsedNumericTypmod(t *testing.T) {
+    for typmod, expected := range parsingTypmodTests {
+        parsed := NewParsedNumericTypmod(typmod)
+        assert.Equal(t, expected, *parsed)
+    }
+}
+
+func TestParsedNumericTypmod_ToTypmod(t *testing.T) {
+    for expected, parsed := range parsingTypmodTests {
+        typmod := parsed.ToTypmod()
+        if !parsed.constrained {
+            assert.Equal(t, int32(-1), typmod)
+            continue
+        }
+        assert.Equal(t, expected, typmod)
+    }
+}
+
+func TestParsedNumericTypmod_ToDWHNumericConstraints(t *testing.T) {
+    for _, parsed := range parsingTypmodTests {
+        for _, dwh := range []protos.DBType{
+            protos.DBType_BIGQUERY,
+            protos.DBType_SNOWFLAKE,
+            protos.DBType_CLICKHOUSE,
+        } {
+            precision, scale := parsed.ToDWHNumericConstraints(dwh)
+            assert.LessOrEqual(t, precision, getMaxPrecisionForDWH(dwh))
+            assert.LessOrEqual(t, scale, getMaxScaleForDWH(dwh))
+            assert.Positive(t, precision)
+            assert.GreaterOrEqual(t, scale, int16(0))
+            assert.LessOrEqual(t, scale, precision)
+        }
+    }
+}
diff --git a/flow/e2e/bigquery/peer_flow_bq_test.go b/flow/e2e/bigquery/peer_flow_bq_test.go
index f1c947593e..bbb8b9a407 100644
--- a/flow/e2e/bigquery/peer_flow_bq_test.go
+++ b/flow/e2e/bigquery/peer_flow_bq_test.go
@@ -404,7 +404,8 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Types_BQ() {
         CURRENT_DATE,1.23,1.234,'10.0.0.0/32'::inet,1,
         '5 years 2 months 29 days 1 minute 2 seconds 200 milliseconds 20000 microseconds'::interval,
         '{"sai":-8.02139037433155}'::json,'{"sai":1}'::jsonb,'08:00:2b:01:02:03'::macaddr,
-        1.2,1.23,4::oid,1.23,1,1,1,'test',now(),now(),now()::time,now()::timetz,
+        1.2,11123984799238749821734987124298343982.73497239472819454350349549120834318911,
+        4::oid,1.23,1,1,1,'test',now(),now(),now()::time,now()::timetz,
         'fat & rat'::tsquery,'a fat cat sat on a mat and ate a fat rat'::tsvector,
         txid_current_snapshot(),
         '66073c38-b8df-4bdb-bbca-1c97596b8940'::uuid,xmlcomment('hello'),
diff --git a/flow/e2e/clickhouse/clickhouse.go b/flow/e2e/clickhouse/clickhouse.go
index 404ca2cc6a..e4eb15adc0 100644
--- a/flow/e2e/clickhouse/clickhouse.go
+++ b/flow/e2e/clickhouse/clickhouse.go
@@ -121,11 +121,10 @@ func (s ClickHouseSuite) GetRows(table string, cols string) (*model.QRecordBatch
             return nil, fmt.Errorf("failed to resolve QValueKind for %s", ty.DatabaseTypeName())
         }
         batch.Schema.Fields = append(batch.Schema.Fields, qvalue.QField{
-            Name:      ty.Name(),
-            Type:      qkind,
-            Precision: 0,
-            Scale:     0,
-            Nullable:  nullable,
+            Name:                ty.Name(),
+            Type:                qkind,
+            Nullable:            nullable,
+            ParsedNumericTypmod: nil,
         })
     }
diff --git a/flow/model/conversion_avro.go b/flow/model/conversion_avro.go
index 8f52c44611..15233aa7d1 100644
--- a/flow/model/conversion_avro.go
+++ b/flow/model/conversion_avro.go
@@ -75,7 +75,7 @@ func GetAvroSchemaDefinition(
     avroFields := make([]QRecordAvroField, 0, len(qRecordSchema.Fields))
 
     for _, qField := range qRecordSchema.Fields {
-        avroType, err := qvalue.GetAvroSchemaFromQValueKind(qField.Type, targetDWH, qField.Precision, qField.Scale)
+        avroType, err := qvalue.GetAvroSchemaFromQValueKind(qField.Type, qField.ParsedNumericTypmod, targetDWH)
         if err != nil {
             return nil, err
         }
diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go
index 648a6aa7ac..9ccc1040d2 100644
--- a/flow/model/qvalue/avro_converter.go
+++ b/flow/model/qvalue/avro_converter.go
@@ -50,10 +50,12 @@ type AvroSchemaField struct {
     LogicalType string `json:"logicalType,omitempty"`
 }
 
-func TruncateOrLogNumeric(num decimal.Decimal, precision int16, scale int16, targetDB protos.DBType) (decimal.Decimal, error) {
-    if targetDB == protos.DBType_SNOWFLAKE || targetDB == protos.DBType_BIGQUERY {
+func truncateOrLogNumeric(num decimal.Decimal,
+    parsedNumericTypmod *datatypes.NumericTypmod, targetDWH protos.DBType,
+) (decimal.Decimal, error) {
+    if targetDWH == protos.DBType_SNOWFLAKE || targetDWH == protos.DBType_BIGQUERY {
         bidigi := datatypes.CountDigits(num.BigInt())
-        avroPrecision, avroScale := DetermineNumericSettingForDWH(precision, scale, targetDB)
+        avroPrecision, avroScale := parsedNumericTypmod.ToDWHNumericConstraints(targetDWH)
         if bidigi+int(avroScale) > int(avroPrecision) {
             slog.Warn("Clearing NUMERIC value with too many digits", slog.Any("number", num))
             return num, errors.New("invalid numeric")
@@ -72,7 +74,7 @@ func TruncateOrLogNumeric(num decimal.Decimal, precision int16, scale int16, tar
 //
 // For example, QValueKindInt64 would return an AvroLogicalSchema of "long". Unsupported QValueKinds
 // will return an error.
-func GetAvroSchemaFromQValueKind(kind QValueKind, targetDWH protos.DBType, precision int16, scale int16) (interface{}, error) {
+func GetAvroSchemaFromQValueKind(kind QValueKind, parsedNumericTypmod *datatypes.NumericTypmod, targetDWH protos.DBType) (interface{}, error) {
     switch kind {
     case QValueKindString:
         return "string", nil
@@ -98,7 +100,7 @@ func GetAvroSchemaFromQValueKind(kind QValueKind, targetDWH protos.DBType, preci
     case QValueKindBytes:
         return "bytes", nil
     case QValueKindNumeric:
-        avroNumericPrecision, avroNumericScale := DetermineNumericSettingForDWH(precision, scale, targetDWH)
+        avroNumericPrecision, avroNumericScale := parsedNumericTypmod.ToDWHNumericConstraints(targetDWH)
         return AvroSchemaNumeric{
             Type:        "bytes",
             LogicalType: "decimal",
@@ -442,7 +444,7 @@ func (c *QValueAvroConverter) processNullableUnion(
 }
 
 func (c *QValueAvroConverter) processNumeric(num decimal.Decimal) interface{} {
-    num, err := TruncateOrLogNumeric(num, c.Precision, c.Scale, c.TargetDWH)
+    num, err := truncateOrLogNumeric(num, c.ParsedNumericTypmod, c.TargetDWH)
     if err != nil {
         return nil
     }
diff --git a/flow/model/qvalue/dwh.go b/flow/model/qvalue/dwh.go
index 49c359b885..091c3aa411 100644
--- a/flow/model/qvalue/dwh.go
+++ b/flow/model/qvalue/dwh.go
@@ -5,26 +5,9 @@ import (
 
     "go.temporal.io/sdk/log"
 
-    numeric "github.com/PeerDB-io/peer-flow/datatypes"
     "github.com/PeerDB-io/peer-flow/generated/protos"
 )
 
-func DetermineNumericSettingForDWH(precision int16, scale int16, dwh protos.DBType) (int16, int16) {
-    var warehouseNumeric numeric.WarehouseNumericCompatibility
-    switch dwh {
-    case protos.DBType_CLICKHOUSE:
-        warehouseNumeric = numeric.ClickHouseNumericCompatibility{}
-    case protos.DBType_SNOWFLAKE:
-        warehouseNumeric = numeric.SnowflakeNumericCompatibility{}
-    case protos.DBType_BIGQUERY:
-        warehouseNumeric = numeric.BigQueryNumericCompatibility{}
-    default:
-        warehouseNumeric = numeric.DefaultNumericCompatibility{}
-    }
-
-    return numeric.GetNumericTypeForWarehouse(numeric.MakeNumericTypmod(int32(precision), int32(scale)), warehouseNumeric)
-}
-
 // Bigquery will not allow timestamp if it is less than 1AD and more than 9999AD
 func DisallowedTimestamp(dwh protos.DBType, t time.Time, logger log.Logger) bool {
     if dwh == protos.DBType_BIGQUERY {
diff --git a/flow/model/qvalue/qschema.go b/flow/model/qvalue/qschema.go
index a956968ac1..535531a922 100644
--- a/flow/model/qvalue/qschema.go
+++ b/flow/model/qvalue/qschema.go
@@ -2,14 +2,16 @@ package qvalue
 
 import (
     "strings"
+
+    "github.com/PeerDB-io/peer-flow/datatypes"
 )
 
 type QField struct {
-    Name      string
-    Type      QValueKind
-    Precision int16
-    Scale     int16
-    Nullable  bool
+    Name string
+    // nil if not a numeric column
+    ParsedNumericTypmod *datatypes.NumericTypmod
+    Type                QValueKind
+    Nullable            bool
 }
 
 type QRecordSchema struct {
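Finally, with the reordered signature, Avro schema generation receives the parsed typmod directly, and a nil typmod (such as the ParsedNumericTypmod: nil fields above) degrades to the warehouse defaults rather than panicking, because the NumericTypmod methods tolerate nil receivers. A sketch of the new call shape; the Precision and Scale fields of AvroSchemaNumeric are assumed from context, since this diff shows only its Type and LogicalType fields:

package main

import (
	"fmt"

	"github.com/PeerDB-io/peer-flow/generated/protos"
	"github.com/PeerDB-io/peer-flow/model/qvalue"
)

func main() {
	// A numeric QField with no typmod information: Snowflake gets its
	// defaults, precision 38 and scale 20, in the decimal logical type.
	schema, err := qvalue.GetAvroSchemaFromQValueKind(
		qvalue.QValueKindNumeric, nil, protos.DBType_SNOWFLAKE)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", schema)
}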