diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index 9c117824ad..93bedcc337 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -494,7 +494,7 @@ func (c *BigQueryConnector) NormalizeRecords(ctx context.Context, req *model.Nor SoftDelete: req.SoftDelete, }) if mergeErr != nil { - return nil, err + return nil, mergeErr } err = c.pgMetadata.UpdateNormalizeBatchID(ctx, req.FlowJobName, batchId) diff --git a/flow/connectors/postgres/qrep_query_executor_test.go b/flow/connectors/postgres/qrep_query_executor_test.go index c8ceaee9a2..c2bc037372 100644 --- a/flow/connectors/postgres/qrep_query_executor_test.go +++ b/flow/connectors/postgres/qrep_query_executor_test.go @@ -4,12 +4,12 @@ import ( "bytes" "context" "fmt" - "math/big" "testing" "time" "github.com/google/uuid" "github.com/jackc/pgx/v5" + "github.com/shopspring/decimal" "github.com/PeerDB-io/peer-flow/connectors/utils/catalog" ) @@ -234,7 +234,7 @@ func TestAllDataTypes(t *testing.T) { } expectedNumeric := "123.456" - actualNumeric := record[10].Value.(*big.Rat).FloatString(3) + actualNumeric := record[10].Value.(decimal.Decimal).String() if actualNumeric != expectedNumeric { t.Fatalf("expected %v, got %v", expectedNumeric, actualNumeric) } diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go index 830d9b8450..981c465c44 100644 --- a/flow/connectors/postgres/qvalue_convert.go +++ b/flow/connectors/postgres/qvalue_convert.go @@ -4,20 +4,18 @@ import ( "encoding/json" "errors" "fmt" - "math/big" "net/netip" "strings" "time" "github.com/jackc/pgx/v5/pgtype" "github.com/lib/pq/oid" + "github.com/shopspring/decimal" "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/PeerDB-io/peer-flow/shared" ) -var big10 = big.NewInt(10) - func (c *PostgresConnector) postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind { switch recvOID { case pgtype.BoolOID: @@ -217,8 +215,7 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( val := qvalue.QValue{} if value == nil { - val = qvalue.QValue{Kind: qvalueKind, Value: nil} - return val, nil + return qvalue.QValue{Kind: qvalueKind, Value: nil}, nil } switch qvalueKind { @@ -341,11 +338,11 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( case qvalue.QValueKindNumeric: numVal := value.(pgtype.Numeric) if numVal.Valid { - rat, err := numericToRat(&numVal) + num, err := numericToDecimal(numVal) if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to convert numeric [%v] to rat: %w", value, err) + return qvalue.QValue{}, fmt.Errorf("failed to convert numeric [%v] to decimal: %w", value, err) } - val = qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: rat} + val = qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: num} } case qvalue.QValueKindArrayFloat32: return convertToArray[float32](qvalueKind, value) @@ -388,31 +385,16 @@ func (c *PostgresConnector) parseFieldFromPostgresOID(oid uint32, value interfac return parseFieldFromQValueKind(c.postgresOIDToQValueKind(oid), value) } -func numericToRat(numVal *pgtype.Numeric) (*big.Rat, error) { - if numVal.Valid { - if numVal.NaN { - // set to nil if NaN - return nil, nil - } - - switch numVal.InfinityModifier { - case pgtype.NegativeInfinity, pgtype.Infinity: - return nil, nil - } - - rat := new(big.Rat).SetInt(numVal.Int) - if numVal.Exp > 0 { - mul := new(big.Int).Exp(big10, big.NewInt(int64(numVal.Exp)), nil) - rat.Mul(rat, new(big.Rat).SetInt(mul)) - } else if numVal.Exp < 0 { - mul := new(big.Int).Exp(big10, big.NewInt(int64(-numVal.Exp)), nil) - rat.Quo(rat, new(big.Rat).SetInt(mul)) - } - return rat, nil +func numericToDecimal(numVal pgtype.Numeric) (interface{}, error) { + switch { + case !numVal.Valid: + return nil, errors.New("invalid numeric") + case numVal.NaN, numVal.InfinityModifier == pgtype.Infinity, + numVal.InfinityModifier == pgtype.NegativeInfinity: + return nil, nil + default: + return decimal.NewFromBigInt(numVal.Int, numVal.Exp), nil } - - // handle invalid numeric - return nil, errors.New("invalid numeric") } func customTypeToQKind(typeName string) qvalue.QValueKind { diff --git a/flow/connectors/snowflake/avro_file_writer_test.go b/flow/connectors/snowflake/avro_file_writer_test.go index 0252b53fc8..13f4a9e3e8 100644 --- a/flow/connectors/snowflake/avro_file_writer_test.go +++ b/flow/connectors/snowflake/avro_file_writer_test.go @@ -3,12 +3,12 @@ package connsnowflake import ( "context" "fmt" - "math/big" "os" "testing" "time" "github.com/google/uuid" + "github.com/shopspring/decimal" "github.com/stretchr/testify/require" avro "github.com/PeerDB-io/peer-flow/connectors/utils/avro" @@ -36,8 +36,7 @@ func createQValue(t *testing.T, kind qvalue.QValueKind, placeHolder int) qvalue. qvalue.QValueKindTimeTZ, qvalue.QValueKindDate: value = time.Now() case qvalue.QValueKindNumeric: - // create a new big.Rat for numeric data - value = big.NewRat(int64(placeHolder), 1) + value = decimal.New(int64(placeHolder), 1) case qvalue.QValueKindUUID: value = uuid.New() // assuming you have the github.com/google/uuid package case qvalue.QValueKindQChar: diff --git a/flow/connectors/sql/query_executor.go b/flow/connectors/sql/query_executor.go index 2a06adabdb..91972a75c3 100644 --- a/flow/connectors/sql/query_executor.go +++ b/flow/connectors/sql/query_executor.go @@ -6,12 +6,12 @@ import ( "encoding/json" "fmt" "log/slog" - "math/big" "strings" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" "github.com/jmoiron/sqlx" + "github.com/shopspring/decimal" "go.temporal.io/sdk/activity" "go.temporal.io/sdk/log" @@ -212,7 +212,7 @@ func (g *GenericSQLQueryExecutor) processRows(ctx context.Context, rows *sqlx.Ro case qvalue.QValueKindBytes, qvalue.QValueKindBit: values[i] = new([]byte) case qvalue.QValueKindNumeric: - var s sql.NullString + var s sql.Null[decimal.Decimal] values[i] = &s case qvalue.QValueKindUUID: values[i] = new([]byte) @@ -380,13 +380,9 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) { } } case qvalue.QValueKindNumeric: - if v, ok := val.(*sql.NullString); ok { + if v, ok := val.(*sql.Null[decimal.Decimal]); ok { if v.Valid { - numeric := new(big.Rat) - if _, ok := numeric.SetString(v.String); !ok { - return qvalue.QValue{}, fmt.Errorf("failed to parse numeric: %v", v.String) - } - return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: numeric}, nil + return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: v.V}, nil } else { return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: nil}, nil } diff --git a/flow/connectors/utils/cdc_records/cdc_records_storage.go b/flow/connectors/utils/cdc_records/cdc_records_storage.go index d25181e75a..e4a8561aea 100644 --- a/flow/connectors/utils/cdc_records/cdc_records_storage.go +++ b/flow/connectors/utils/cdc_records/cdc_records_storage.go @@ -6,13 +6,13 @@ import ( "errors" "fmt" "log/slog" - "math/big" "os" "runtime" "sync/atomic" "time" "github.com/cockroachdb/pebble" + "github.com/shopspring/decimal" "go.temporal.io/sdk/log" "github.com/PeerDB-io/peer-flow/model" @@ -72,7 +72,7 @@ func (c *cdcRecordsStore) initPebbleDB() error { gob.Register(&model.UpdateRecord{}) gob.Register(&model.DeleteRecord{}) gob.Register(time.Time{}) - gob.Register(&big.Rat{}) + gob.Register(decimal.Decimal{}) var err error // we don't want a WAL since cache, we don't want to overwrite another DB either diff --git a/flow/connectors/utils/cdc_records/cdc_records_storage_test.go b/flow/connectors/utils/cdc_records/cdc_records_storage_test.go index c972092536..50796109f3 100644 --- a/flow/connectors/utils/cdc_records/cdc_records_storage_test.go +++ b/flow/connectors/utils/cdc_records/cdc_records_storage_test.go @@ -3,10 +3,10 @@ package cdc_records import ( "crypto/rand" "log/slog" - "math/big" "testing" "time" + "github.com/shopspring/decimal" "github.com/stretchr/testify/require" "github.com/PeerDB-io/peer-flow/model" @@ -27,9 +27,9 @@ func getTimeForTesting(t *testing.T) time.Time { return tv } -func getRatForTesting(t *testing.T) *big.Rat { +func getDecimalForTesting(t *testing.T) decimal.Decimal { t.Helper() - return big.NewRat(123456789, 987654321) + return decimal.New(9876543210, 123) } func genKeyAndRec(t *testing.T) (model.TableWithPkey, model.Record) { @@ -40,7 +40,7 @@ func genKeyAndRec(t *testing.T) (model.TableWithPkey, model.Record) { require.NoError(t, err) tv := getTimeForTesting(t) - rv := getRatForTesting(t) + rv := getDecimalForTesting(t) key := model.TableWithPkey{ TableName: "test_src_tbl", @@ -126,7 +126,7 @@ func TestRecordsTillSpill(t *testing.T) { require.NoError(t, cdcRecordsStore.Close()) } -func TestTimeAndRatEncoding(t *testing.T) { +func TestTimeAndDecimalEncoding(t *testing.T) { t.Parallel() cdcRecordsStore := NewCDCRecordsStore("test_time_encoding") diff --git a/flow/e2e/bigquery/bigquery_helper.go b/flow/e2e/bigquery/bigquery_helper.go index 76c93bd5e4..7c8ab5257c 100644 --- a/flow/e2e/bigquery/bigquery_helper.go +++ b/flow/e2e/bigquery/bigquery_helper.go @@ -12,6 +12,7 @@ import ( "cloud.google.com/go/bigquery" "cloud.google.com/go/civil" + "github.com/shopspring/decimal" "google.golang.org/api/iterator" peer_bq "github.com/PeerDB-io/peer-flow/connectors/bigquery" @@ -227,7 +228,11 @@ func toQValue(bqValue bigquery.Value) (qvalue.QValue, error) { case time.Time: return qvalue.QValue{Kind: qvalue.QValueKindTimestamp, Value: v}, nil case *big.Rat: - return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: v}, nil + val, err := decimal.NewFromString(v.FloatString(32)) + if err != nil { + return qvalue.QValue{}, fmt.Errorf("bqHelper failed to parse as decimal %v", v) + } + return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: val}, nil case []uint8: return qvalue.QValue{Kind: qvalue.QValueKindBytes, Value: v}, nil case []bigquery.Value: @@ -417,37 +422,26 @@ func (b *BigQueryTestHelper) CheckNull(tableName string, colName []string) (bool } // check if NaN, Inf double values are null -func (b *BigQueryTestHelper) CheckDoubleValues(tableName string, c1 string, c2 string) (bool, error) { - command := fmt.Sprintf("SELECT %s, %s FROM `%s.%s`", - c1, c2, b.Config.DatasetId, tableName) +func (b *BigQueryTestHelper) SelectRow(tableName string, cols ...string) ([]bigquery.Value, error) { + command := fmt.Sprintf("SELECT %s FROM `%s.%s`", + strings.Join(cols, ","), b.Config.DatasetId, tableName) q := b.client.Query(command) q.DisableQueryCache = true it, err := q.Read(context.Background()) if err != nil { - return false, fmt.Errorf("failed to run command: %w", err) + return nil, fmt.Errorf("failed to run command: %w", err) } var row []bigquery.Value for { err := it.Next(&row) if err == iterator.Done { - break + return row, nil } if err != nil { - return false, fmt.Errorf("failed to iterate over query results: %w", err) + return nil, fmt.Errorf("failed to iterate over query results: %w", err) } } - - if len(row) == 0 { - return false, nil - } - - floatArr, _ := row[1].([]float64) - if row[0] != nil || len(floatArr) > 0 { - return false, nil - } - - return true, nil } func qValueKindToBqColTypeString(val qvalue.QValueKind) (string, error) { diff --git a/flow/e2e/bigquery/peer_flow_bq_test.go b/flow/e2e/bigquery/peer_flow_bq_test.go index 3dff6e310e..678b3fc15f 100644 --- a/flow/e2e/bigquery/peer_flow_bq_test.go +++ b/flow/e2e/bigquery/peer_flow_bq_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "cloud.google.com/go/bigquery" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/require" @@ -478,7 +479,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_NaN_Doubles_BQ() { srcTableName := s.attachSchemaSuffix("test_nans_bq") dstTableName := "test_nans_bq" _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` - CREATE TABLE IF NOT EXISTS %s (id serial PRIMARY KEY,c1 double precision,c2 double precision[]); + CREATE TABLE IF NOT EXISTS %s (id serial PRIMARY KEY,c1 double precision,c2 double precision[],c3 numeric); `, srcTableName)) require.NoError(s.t, err) @@ -499,13 +500,21 @@ func (s PeerFlowE2ETestSuiteBQ) Test_NaN_Doubles_BQ() { // test inserting various types _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - INSERT INTO %s SELECT 2, 'NaN'::double precision, '{NaN, Infinity, -Infinity}'; - `, srcTableName)) + INSERT INTO %s SELECT 2, 'NaN'::double precision, '{NaN, Infinity, -Infinity}', 'NaN'::numeric; + `, srcTableName)) e2e.EnvNoError(s.t, env, err) e2e.EnvWaitFor(s.t, env, 2*time.Minute, "normalize weird floats", func() bool { - good, err := s.bqHelper.CheckDoubleValues(dstTableName, "c1", "c2") - return err == nil && good + row, err := s.bqHelper.SelectRow(dstTableName, "c1", "c2", "c3") + if err != nil { + return false + } + if len(row) == 0 { + return false + } + + floatArr, ok := row[1].([]bigquery.Value) + return ok && row[0] == nil && len(floatArr) == 0 && row[2] == nil }) env.Cancel() diff --git a/flow/e2e/postgres/qrep_flow_pg_test.go b/flow/e2e/postgres/qrep_flow_pg_test.go index abb7867d24..ec121b8930 100644 --- a/flow/e2e/postgres/qrep_flow_pg_test.go +++ b/flow/e2e/postgres/qrep_flow_pg_test.go @@ -36,15 +36,10 @@ func (s PeerFlowE2ETestSuitePG) setupSourceTable(tableName string, rowCount int) func (s PeerFlowE2ETestSuitePG) comparePGTables(srcSchemaQualified, dstSchemaQualified, selector string) error { // Execute the two EXCEPT queries - if err := s.compareQuery(srcSchemaQualified, dstSchemaQualified, selector); err != nil { - return err - } - if err := s.compareQuery(dstSchemaQualified, srcSchemaQualified, selector); err != nil { - return err - } - - // If no error is returned, then the contents of the two tables are the same - return nil + return errors.Join( + s.compareQuery(srcSchemaQualified, dstSchemaQualified, selector), + s.compareQuery(dstSchemaQualified, srcSchemaQualified, selector), + ) } func (s PeerFlowE2ETestSuitePG) checkEnums(srcSchemaQualified, dstSchemaQualified string) error { diff --git a/flow/e2e/snowflake/snowflake_helper.go b/flow/e2e/snowflake/snowflake_helper.go index c26209c379..e0d41e838d 100644 --- a/flow/e2e/snowflake/snowflake_helper.go +++ b/flow/e2e/snowflake/snowflake_helper.go @@ -5,10 +5,11 @@ import ( "encoding/json" "errors" "fmt" - "math/big" "os" "time" + "github.com/shopspring/decimal" + connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake" "github.com/PeerDB-io/peer-flow/e2eshared" "github.com/PeerDB-io/peer-flow/generated/protos" @@ -173,9 +174,8 @@ func (s *SnowflakeTestHelper) RunIntQuery(query string) (int, error) { case qvalue.QValueKindInt64: return int(rec[0].Value.(int64)), nil case qvalue.QValueKindNumeric: - // get big.Rat and convert to int - rat := rec[0].Value.(*big.Rat) - return int(rat.Num().Int64() / rat.Denom().Int64()), nil + val := rec[0].Value.(decimal.Decimal) + return int(val.IntPart()), nil default: return 0, fmt.Errorf("failed to execute query: %s, returned value of type %s", query, rec[0].Kind) } diff --git a/flow/e2e/test_utils.go b/flow/e2e/test_utils.go index dcaef74291..d2bcef6acb 100644 --- a/flow/e2e/test_utils.go +++ b/flow/e2e/test_utils.go @@ -271,7 +271,6 @@ func CreateTableForQRep(conn *pgx.Conn, suffix string, tableName string) error { "geography_linestring geography(linestring)", "geometry_polygon geometry(polygon)", "geography_polygon geography(polygon)", - "nannu NUMERIC", "myreal REAL", "myreal2 REAL", "myreal3 REAL", @@ -337,12 +336,8 @@ func PopulateSourceTable(conn *pgx.Conn, suffix string, tableName string, rowCou 'LINESTRING(0 0, 1 1, 2 2)', 'LINESTRING(-74.0060 40.7128, -73.9352 40.7306, -73.9123 40.7831)', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))','POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))', - 'NaN', - 3.14159, - 1, - 1.0, - '10.0.0.0/32', - '1.1.10.2'::cidr + 3.14159, 1, 1.0, + '10.0.0.0/32', '1.1.10.2'::cidr )`, id, uuid.New().String(), uuid.New().String(), uuid.New().String(), uuid.New().String(), uuid.New().String(), uuid.New().String()) @@ -360,12 +355,8 @@ func PopulateSourceTable(conn *pgx.Conn, suffix string, tableName string, rowCou settle_at, settlement_delay_reason, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, my_date, my_time, my_mood, myh, "geometryPoint", geography_point,geometry_linestring, geography_linestring,geometry_polygon, geography_polygon, - nannu, - myreal, - myreal2, - myreal3, - myinet, - mycidr + myreal, myreal2, myreal3, + myinet, mycidr ) VALUES %s; `, suffix, tableName, strings.Join(rows, ","))) if err != nil { diff --git a/flow/go.mod b/flow/go.mod index 5126e3146d..154a068d20 100644 --- a/flow/go.mod +++ b/flow/go.mod @@ -31,6 +31,7 @@ require ( github.com/linkedin/goavro/v2 v2.12.0 github.com/microsoft/go-mssqldb v1.7.0 github.com/orcaman/concurrent-map/v2 v2.0.1 + github.com/shopspring/decimal v1.3.1 github.com/slack-go/slack v0.12.5 github.com/snowflakedb/gosnowflake v1.8.0 github.com/stretchr/testify v1.9.0 @@ -89,7 +90,6 @@ require ( github.com/prometheus/procfs v0.13.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/segmentio/asm v1.2.0 // indirect - github.com/shopspring/decimal v1.3.1 // indirect github.com/sirupsen/logrus v1.9.3 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect diff --git a/flow/model/qrecord_batch.go b/flow/model/qrecord_batch.go index 4cf4a11017..08c5ce7770 100644 --- a/flow/model/qrecord_batch.go +++ b/flow/model/qrecord_batch.go @@ -4,11 +4,11 @@ import ( "errors" "fmt" "log/slog" - "math/big" "time" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" + "github.com/shopspring/decimal" "github.com/PeerDB-io/peer-flow/geo" "github.com/PeerDB-io/peer-flow/model/qvalue" @@ -198,24 +198,12 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) { values[i] = uuid.UUID(v) case qvalue.QValueKindNumeric: - v, ok := qValue.Value.(*big.Rat) + v, ok := qValue.Value.(decimal.Decimal) if !ok { src.err = fmt.Errorf("invalid Numeric value %v", qValue.Value) return nil, src.err } - if v == nil { - values[i] = pgtype.Numeric{ - Int: nil, - Exp: 0, - NaN: true, - InfinityModifier: pgtype.Finite, - Valid: true, - } - break - } - - // TODO: account for precision and scale issues. - values[i] = v.FloatString(38) + values[i] = v case qvalue.QValueKindBytes, qvalue.QValueKindBit: v, ok := qValue.Value.([]byte) diff --git a/flow/model/qrecord_test.go b/flow/model/qrecord_test.go index 6c685a8f41..12b59c67bb 100644 --- a/flow/model/qrecord_test.go +++ b/flow/model/qrecord_test.go @@ -1,10 +1,10 @@ package model_test import ( - "math/big" "testing" "github.com/google/uuid" + "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/PeerDB-io/peer-flow/e2eshared" @@ -35,14 +35,14 @@ func TestEquals(t *testing.T) { }, { name: "Equal - Same numeric", - q1: []qvalue.QValue{{Kind: qvalue.QValueKindNumeric, Value: big.NewRat(10, 2)}}, + q1: []qvalue.QValue{{Kind: qvalue.QValueKindNumeric, Value: decimal.NewFromInt(5)}}, q2: []qvalue.QValue{{Kind: qvalue.QValueKindString, Value: "5"}}, want: true, }, { name: "Not Equal - Different numeric", - q1: []qvalue.QValue{{Kind: qvalue.QValueKindNumeric, Value: big.NewRat(10, 2)}}, - q2: []qvalue.QValue{{Kind: qvalue.QValueKindNumeric, Value: "4.99"}}, + q1: []qvalue.QValue{{Kind: qvalue.QValueKindNumeric, Value: decimal.NewFromInt(5)}}, + q2: []qvalue.QValue{{Kind: qvalue.QValueKindString, Value: "4.99"}}, want: false, }, } diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go index ea90af00ed..0a299cf82f 100644 --- a/flow/model/qvalue/avro_converter.go +++ b/flow/model/qvalue/avro_converter.go @@ -4,11 +4,11 @@ import ( "errors" "fmt" "log/slog" - "math/big" "time" "github.com/google/uuid" "github.com/linkedin/goavro/v2" + "github.com/shopspring/decimal" "go.temporal.io/sdk/log" hstore_util "github.com/PeerDB-io/peer-flow/hstore" @@ -158,7 +158,7 @@ func GetAvroSchemaFromQValueKind(kind QValueKind, targetDWH QDWHType, precision } type QValueAvroConverter struct { - Value QValue + QValue TargetDWH QDWHType Nullable bool logger log.Logger @@ -166,7 +166,7 @@ type QValueAvroConverter struct { func NewQValueAvroConverter(value QValue, targetDWH QDWHType, nullable bool, logger log.Logger) *QValueAvroConverter { return &QValueAvroConverter{ - Value: value, + QValue: value, TargetDWH: targetDWH, Nullable: nullable, logger: logger, @@ -174,14 +174,14 @@ func NewQValueAvroConverter(value QValue, targetDWH QDWHType, nullable bool, log } func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { - if c.Nullable && c.Value.Value == nil { + if c.Nullable && c.Value == nil { return nil, nil } - switch c.Value.Kind { + switch c.Kind { case QValueKindInvalid: // we will attempt to convert invalid to a string - return c.processNullableUnion("string", c.Value.Value) + return c.processNullableUnion("string", c.Value) case QValueKindTime: t, err := c.processGoTime() if err != nil || t == nil { @@ -284,31 +284,31 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { } return t, nil case QValueKindQChar: - return c.processNullableUnion("string", string(c.Value.Value.(uint8))) + return c.processNullableUnion("string", string(c.Value.(uint8))) case QValueKindString, QValueKindCIDR, QValueKindINET, QValueKindMacaddr: - if c.TargetDWH == QDWHTypeSnowflake && c.Value.Value != nil && - (len(c.Value.Value.(string)) > 15*1024*1024) { + if c.TargetDWH == QDWHTypeSnowflake && c.Value != nil && + (len(c.Value.(string)) > 15*1024*1024) { slog.Warn("Truncating TEXT value > 15MB for Snowflake!") slog.Warn("Check this issue for details: https://github.com/PeerDB-io/peerdb/issues/309") return c.processNullableUnion("string", "") } - return c.processNullableUnion("string", c.Value.Value) + return c.processNullableUnion("string", c.Value) case QValueKindFloat32: if c.TargetDWH == QDWHTypeBigQuery { - return c.processNullableUnion("double", c.Value.Value) + return c.processNullableUnion("double", c.Value) } - return c.processNullableUnion("float", c.Value.Value) + return c.processNullableUnion("float", c.Value) case QValueKindFloat64: if c.TargetDWH == QDWHTypeSnowflake || c.TargetDWH == QDWHTypeBigQuery { - if f32Val, ok := c.Value.Value.(float32); ok { + if f32Val, ok := c.Value.(float32); ok { return c.processNullableUnion("double", float64(f32Val)) } } - return c.processNullableUnion("double", c.Value.Value) + return c.processNullableUnion("double", c.Value) case QValueKindInt16, QValueKindInt32, QValueKindInt64: - return c.processNullableUnion("long", c.Value.Value) + return c.processNullableUnion("long", c.Value) case QValueKindBoolean: - return c.processNullableUnion("boolean", c.Value.Value) + return c.processNullableUnion("boolean", c.Value) case QValueKindStruct: return nil, errors.New("QValueKindStruct not supported") case QValueKindNumeric: @@ -352,16 +352,16 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { case QValueKindGeography, QValueKindGeometry, QValueKindPoint: return c.processGeospatial() default: - return nil, fmt.Errorf("[toavro] unsupported QValueKind: %s", c.Value.Kind) + return nil, fmt.Errorf("[toavro] unsupported QValueKind: %s", c.Kind) } } func (c *QValueAvroConverter) processGoTimeTZ() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - t, ok := c.Value.Value.(time.Time) + t, ok := c.Value.(time.Time) if !ok { return nil, errors.New("invalid TimeTZ value") } @@ -375,11 +375,11 @@ func (c *QValueAvroConverter) processGoTimeTZ() (interface{}, error) { } func (c *QValueAvroConverter) processGoTime() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - t, ok := c.Value.Value.(time.Time) + t, ok := c.Value.(time.Time) if !ok { return nil, errors.New("invalid Time value") } @@ -397,11 +397,11 @@ func (c *QValueAvroConverter) processGoTime() (interface{}, error) { } func (c *QValueAvroConverter) processGoTimestampTZ() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - t, ok := c.Value.Value.(time.Time) + t, ok := c.Value.(time.Time) if !ok { return nil, errors.New("invalid TimestampTZ value") } @@ -422,11 +422,11 @@ func (c *QValueAvroConverter) processGoTimestampTZ() (interface{}, error) { } func (c *QValueAvroConverter) processGoTimestamp() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - t, ok := c.Value.Value.(time.Time) + t, ok := c.Value.(time.Time) if !ok { return nil, errors.New("invalid Timestamp value") } @@ -447,11 +447,11 @@ func (c *QValueAvroConverter) processGoTimestamp() (interface{}, error) { } func (c *QValueAvroConverter) processGoDate() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - t, ok := c.Value.Value.(time.Time) + t, ok := c.Value.(time.Time) if !ok { return nil, errors.New("invalid Time value for Date") } @@ -484,37 +484,32 @@ func (c *QValueAvroConverter) processNullableUnion( } func (c *QValueAvroConverter) processNumeric() (interface{}, error) { - if c.Value.Value == nil { + if c.Value == nil { return nil, nil } - num, ok := c.Value.Value.(*big.Rat) + num, ok := c.Value.(decimal.Decimal) if !ok { - return nil, fmt.Errorf("invalid Numeric value: expected *big.Rat, got %T", c.Value.Value) + return nil, fmt.Errorf("invalid Numeric value: expected decimal.Decimal, got %T", c.Value) } + rat := num.Rat() - if num == nil { - return nil, nil - } - - decimalValue := num.FloatString(100) - num.SetString(decimalValue) if c.Nullable { - return goavro.Union("bytes.decimal", num), nil + return goavro.Union("bytes.decimal", rat), nil } - return num, nil + return rat, nil } func (c *QValueAvroConverter) processBytes() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } if c.TargetDWH == QDWHTypeClickhouse { - bigNum, ok := c.Value.Value.(*big.Rat) + bigNum, ok := c.Value.(decimal.Decimal) if !ok { - return nil, fmt.Errorf("invalid Numeric value: expected float64, got %T", c.Value.Value) + return nil, fmt.Errorf("invalid Numeric value: expected float64, got %T", c.Value) } num, ok := bigNum.Float64() if !ok { @@ -523,7 +518,7 @@ func (c *QValueAvroConverter) processBytes() (interface{}, error) { return goavro.Union("double", num), nil } - byteData, ok := c.Value.Value.([]byte) + byteData, ok := c.Value.([]byte) if !ok { return nil, errors.New("invalid Bytes value") } @@ -536,13 +531,13 @@ func (c *QValueAvroConverter) processBytes() (interface{}, error) { } func (c *QValueAvroConverter) processJSON() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - jsonString, ok := c.Value.Value.(string) + jsonString, ok := c.Value.(string) if !ok { - return nil, fmt.Errorf("invalid JSON value %v", c.Value.Value) + return nil, fmt.Errorf("invalid JSON value %v", c.Value) } if c.Nullable { @@ -563,11 +558,11 @@ func (c *QValueAvroConverter) processJSON() (interface{}, error) { } func (c *QValueAvroConverter) processArrayBoolean() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - arrayData, ok := c.Value.Value.([]bool) + arrayData, ok := c.Value.([]bool) if !ok { return nil, errors.New("invalid Boolean array value") } @@ -580,11 +575,11 @@ func (c *QValueAvroConverter) processArrayBoolean() (interface{}, error) { } func (c *QValueAvroConverter) processArrayTime() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - arrayTime, ok := c.Value.Value.([]time.Time) + arrayTime, ok := c.Value.([]time.Time) if !ok { return nil, errors.New("invalid Timestamp array value") } @@ -608,11 +603,11 @@ func (c *QValueAvroConverter) processArrayTime() (interface{}, error) { } func (c *QValueAvroConverter) processArrayDate() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - arrayDate, ok := c.Value.Value.([]time.Time) + arrayDate, ok := c.Value.([]time.Time) if !ok { return nil, errors.New("invalid Date array value") } @@ -634,13 +629,13 @@ func (c *QValueAvroConverter) processArrayDate() (interface{}, error) { } func (c *QValueAvroConverter) processHStore() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - hstoreString, ok := c.Value.Value.(string) + hstoreString, ok := c.Value.(string) if !ok { - return nil, fmt.Errorf("invalid HSTORE value %v", c.Value.Value) + return nil, fmt.Errorf("invalid HSTORE value %v", c.Value) } jsonString, err := hstore_util.ParseHstore(hstoreString) @@ -666,16 +661,16 @@ func (c *QValueAvroConverter) processHStore() (interface{}, error) { } func (c *QValueAvroConverter) processUUID() (interface{}, error) { - if c.Value.Value == nil { + if c.Value == nil { return nil, nil } - byteData, ok := c.Value.Value.([16]byte) + byteData, ok := c.Value.([16]byte) if !ok { // attempt to convert google.uuid to [16]byte - byteData, ok = c.Value.Value.(uuid.UUID) + byteData, ok = c.Value.(uuid.UUID) if !ok { - return nil, fmt.Errorf("[conversion] invalid UUID value %v", c.Value.Value) + return nil, fmt.Errorf("[conversion] invalid UUID value %v", c.Value) } } @@ -694,13 +689,13 @@ func (c *QValueAvroConverter) processUUID() (interface{}, error) { } func (c *QValueAvroConverter) processGeospatial() (interface{}, error) { - if c.Value.Value == nil { + if c.Value == nil { return nil, nil } - geoString, ok := c.Value.Value.(string) + geoString, ok := c.Value.(string) if !ok { - return nil, fmt.Errorf("[conversion] invalid geospatial value %v", c.Value.Value) + return nil, fmt.Errorf("[conversion] invalid geospatial value %v", c.Value) } if c.Nullable { @@ -710,11 +705,11 @@ func (c *QValueAvroConverter) processGeospatial() (interface{}, error) { } func (c *QValueAvroConverter) processArrayInt16() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - arrayData, ok := c.Value.Value.([]int16) + arrayData, ok := c.Value.([]int16) if !ok { return nil, errors.New("invalid Int16 array value") } @@ -733,11 +728,11 @@ func (c *QValueAvroConverter) processArrayInt16() (interface{}, error) { } func (c *QValueAvroConverter) processArrayInt32() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - arrayData, ok := c.Value.Value.([]int32) + arrayData, ok := c.Value.([]int32) if !ok { return nil, errors.New("invalid Int32 array value") } @@ -750,11 +745,11 @@ func (c *QValueAvroConverter) processArrayInt32() (interface{}, error) { } func (c *QValueAvroConverter) processArrayInt64() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - arrayData, ok := c.Value.Value.([]int64) + arrayData, ok := c.Value.([]int64) if !ok { return nil, errors.New("invalid Int64 array value") } @@ -767,11 +762,11 @@ func (c *QValueAvroConverter) processArrayInt64() (interface{}, error) { } func (c *QValueAvroConverter) processArrayFloat32() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - arrayData, ok := c.Value.Value.([]float32) + arrayData, ok := c.Value.([]float32) if !ok { return nil, errors.New("invalid Float32 array value") } @@ -784,11 +779,11 @@ func (c *QValueAvroConverter) processArrayFloat32() (interface{}, error) { } func (c *QValueAvroConverter) processArrayFloat64() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - arrayData, ok := c.Value.Value.([]float64) + arrayData, ok := c.Value.([]float64) if !ok { return nil, errors.New("invalid Float64 array value") } @@ -801,11 +796,11 @@ func (c *QValueAvroConverter) processArrayFloat64() (interface{}, error) { } func (c *QValueAvroConverter) processArrayString() (interface{}, error) { - if c.Value.Value == nil && c.Nullable { + if c.Value == nil && c.Nullable { return nil, nil } - arrayData, ok := c.Value.Value.([]string) + arrayData, ok := c.Value.([]string) if !ok { return nil, errors.New("invalid String array value") } diff --git a/flow/model/qvalue/qvalue.go b/flow/model/qvalue/qvalue.go index 013e0ca9e1..ae0a3945ab 100644 --- a/flow/model/qvalue/qvalue.go +++ b/flow/model/qvalue/qvalue.go @@ -14,6 +14,7 @@ import ( "cloud.google.com/go/civil" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" + "github.com/shopspring/decimal" geom "github.com/twpayne/go-geos" hstore_util "github.com/PeerDB-io/peer-flow/hstore" @@ -235,18 +236,14 @@ func compareBytes(value1, value2 interface{}) bool { } func compareNumeric(value1, value2 interface{}) bool { - rat1, ok1 := getRat(value1) - rat2, ok2 := getRat(value2) + num1, ok1 := getDecimal(value1) + num2, ok2 := getDecimal(value2) if !ok1 || !ok2 { return false } - if rat1 == nil && rat2 == nil { - return true - } - - return rat1.Cmp(rat2) == 0 + return num1.Equal(num2) } func compareString(value1, value2 interface{}) bool { @@ -503,8 +500,8 @@ func getInt16(v interface{}) (int16, bool) { return int16(value), true case int64: return int16(value), true - case *big.Rat: - return int16(value.Num().Int64()), true + case decimal.Decimal: + return int16(value.IntPart()), true case string: parsed, err := strconv.ParseInt(value, 10, 16) if err == nil { @@ -520,8 +517,8 @@ func getInt32(v interface{}) (int32, bool) { return value, true case int64: return int32(value), true - case *big.Rat: - return int32(value.Num().Int64()), true + case decimal.Decimal: + return int32(value.IntPart()), true case string: parsed, err := strconv.ParseInt(value, 10, 32) if err == nil { @@ -537,8 +534,8 @@ func getInt64(v interface{}) (int64, bool) { return value, true case int32: return int64(value), true - case *big.Rat: - return value.Num().Int64(), true + case decimal.Decimal: + return value.IntPart(), true case string: parsed, err := strconv.ParseInt(value, 10, 64) if err == nil { @@ -610,55 +607,41 @@ func getUUID(v interface{}) (uuid.UUID, bool) { return uuid.UUID{}, false } -// getRat attempts to parse a big.Rat from an interface -func getRat(v interface{}) (*big.Rat, bool) { +// getDecimal attempts to parse a decimal from an interface +func getDecimal(v interface{}) (decimal.Decimal, bool) { switch value := v.(type) { - case *big.Rat: + case decimal.Decimal: return value, true case string: - //nolint:gosec - parsed, ok := new(big.Rat).SetString(value) - if ok { - return parsed, true + parsed, err := decimal.NewFromString(value) + if err != nil { + panic(err) } + return parsed, true case float64: - rat := new(big.Rat) - return rat.SetFloat64(value), true + return decimal.NewFromFloat(value), true case int64: - rat := new(big.Rat) - return rat.SetInt64(value), true + return decimal.NewFromInt(value), true case uint64: - rat := new(big.Rat) - return rat.SetUint64(value), true + return decimal.NewFromBigInt(new(big.Int).SetUint64(value), 0), true case float32: - rat := new(big.Rat) - return rat.SetFloat64(float64(value)), true + return decimal.NewFromFloat32(value), true case int32: - rat := new(big.Rat) - return rat.SetInt64(int64(value)), true + return decimal.NewFromInt(int64(value)), true case uint32: - rat := new(big.Rat) - return rat.SetUint64(uint64(value)), true + return decimal.NewFromInt(int64(value)), true case int: - rat := new(big.Rat) - return rat.SetInt64(int64(value)), true + return decimal.NewFromInt(int64(value)), true case uint: - rat := new(big.Rat) - return rat.SetUint64(uint64(value)), true + return decimal.NewFromInt(int64(value)), true case int8: - rat := new(big.Rat) - return rat.SetInt64(int64(value)), true + return decimal.NewFromInt(int64(value)), true case uint8: - rat := new(big.Rat) - return rat.SetUint64(uint64(value)), true + return decimal.NewFromInt(int64(value)), true case int16: - rat := new(big.Rat) - return rat.SetInt64(int64(value)), true + return decimal.NewFromInt(int64(value)), true case uint16: - rat := new(big.Rat) - return rat.SetUint64(uint64(value)), true - case nil: - return nil, true + return decimal.NewFromInt(int64(value)), true } - return nil, false + return decimal.Decimal{}, false } diff --git a/flow/model/record_items.go b/flow/model/record_items.go index f75d122f6d..fe9da58af9 100644 --- a/flow/model/record_items.go +++ b/flow/model/record_items.go @@ -5,9 +5,10 @@ import ( "errors" "fmt" "math" - "math/big" "time" + "github.com/shopspring/decimal" + hstore_util "github.com/PeerDB-io/peer-flow/hstore" "github.com/PeerDB-io/peer-flow/model/qvalue" ) @@ -164,16 +165,12 @@ func (r *RecordItems) toMap(hstoreAsJSON bool) (map[string]interface{}, error) { } jsonStruct[col] = formattedDateArr case qvalue.QValueKindNumeric: - bigRat, ok := v.Value.(*big.Rat) + val, ok := v.Value.(decimal.Decimal) if !ok { - return nil, errors.New("expected *big.Rat value") + return nil, errors.New("expected decimal.Decimal value") } - if bigRat == nil { - jsonStruct[col] = nil - continue - } - jsonStruct[col] = bigRat.FloatString(100) + jsonStruct[col] = val.String() case qvalue.QValueKindFloat64: floatVal, ok := v.Value.(float64) if !ok {