diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index bad43d61cb..e2f2f12659 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -372,7 +372,7 @@ func (c *BigQueryConnector) syncRecordsViaAvro( ) (*model.SyncResponse, error) { tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings) streamReq := model.NewRecordsToStreamRequest(req.Records.GetRecords(), tableNameRowsMapping, syncBatchID) - streamRes, err := utils.RecordsToRawTableStream(streamReq) + stream, err := utils.RecordsToRawTableStream(streamReq) if err != nil { return nil, fmt.Errorf("failed to convert records to raw table stream: %w", err) } @@ -384,7 +384,7 @@ func (c *BigQueryConnector) syncRecordsViaAvro( } res, err := avroSync.SyncRecords(ctx, req, rawTableName, - rawTableMetadata, syncBatchID, streamRes.Stream, streamReq.TableMapping) + rawTableMetadata, syncBatchID, stream, streamReq.TableMapping) if err != nil { return nil, fmt.Errorf("failed to sync records via avro: %w", err) } diff --git a/flow/connectors/clickhouse/cdc.go b/flow/connectors/clickhouse/cdc.go index aabd51c11f..9b0911955b 100644 --- a/flow/connectors/clickhouse/cdc.go +++ b/flow/connectors/clickhouse/cdc.go @@ -82,7 +82,7 @@ func (c *ClickhouseConnector) syncRecordsViaAvro( ) (*model.SyncResponse, error) { tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings) streamReq := model.NewRecordsToStreamRequest(req.Records.GetRecords(), tableNameRowsMapping, syncBatchID) - streamRes, err := utils.RecordsToRawTableStream(streamReq) + stream, err := utils.RecordsToRawTableStream(streamReq) if err != nil { return nil, fmt.Errorf("failed to convert records to raw table stream: %w", err) } @@ -98,7 +98,7 @@ func (c *ClickhouseConnector) syncRecordsViaAvro( return nil, err } - numRecords, err := avroSyncer.SyncRecords(ctx, destinationTableSchema, streamRes.Stream, req.FlowJobName) + numRecords, err := avroSyncer.SyncRecords(ctx, 
destinationTableSchema, stream, req.FlowJobName) if err != nil { return nil, err } diff --git a/flow/connectors/eventhub/eventhub.go b/flow/connectors/eventhub/eventhub.go index a1ab93f391..d58530da54 100644 --- a/flow/connectors/eventhub/eventhub.go +++ b/flow/connectors/eventhub/eventhub.go @@ -158,7 +158,7 @@ func (c *EventHubConnector) processBatch( // partition_column is the column in the table that is used to determine // the partition key for the eventhub. partitionColumn := destination.PartitionKeyColumn - partitionValue := record.GetItems().GetColumnValue(partitionColumn).Value + partitionValue := record.GetItems().GetColumnValue(partitionColumn).Value() var partitionKey string if partitionValue != nil { partitionKey = fmt.Sprint(partitionValue) diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index e5bad73d62..056eafcccc 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -607,7 +607,7 @@ func (p *PostgresCDCSource) convertTupleToMap( } switch col.DataType { case 'n': // null - val := qvalue.QValue{Kind: qvalue.QValueKindInvalid, Value: nil} + val := qvalue.QValueNull(qvalue.QValueKindInvalid) items.AddColumn(colName, val) case 't': // text /* bytea also appears here as a hex */ @@ -649,19 +649,28 @@ func (p *PostgresCDCSource) decodeColumnData(data []byte, dataType uint32, forma // but not representable by time.Time p.logger.Warn(fmt.Sprintf("Invalidated and hence nulled %s data: %s", dt.Name, string(data))) - return qvalue.QValue{}, nil + switch dt.Name { + case "time": + return qvalue.QValueNull(qvalue.QValueKindTime), nil + case "timetz": + return qvalue.QValueNull(qvalue.QValueKindTimeTZ), nil + case "timestamp": + return qvalue.QValueNull(qvalue.QValueKindTimestamp), nil + case "timestamptz": + return qvalue.QValueNull(qvalue.QValueKindTimestampTZ), nil + } } - return qvalue.QValue{}, err + return nil, err } retVal, err := p.parseFieldFromPostgresOID(dataType, parsedData) if err != nil 
{ - return qvalue.QValue{}, err + return nil, err } return retVal, nil } else if dataType == uint32(oid.T_timetz) { // ugly TIMETZ workaround for CDC decoding. retVal, err := p.parseFieldFromPostgresOID(dataType, string(data)) if err != nil { - return qvalue.QValue{}, err + return nil, err } return retVal, nil } @@ -669,28 +678,26 @@ func (p *PostgresCDCSource) decodeColumnData(data []byte, dataType uint32, forma typeName, ok := p.customTypesMapping[dataType] if ok { customQKind := customTypeToQKind(typeName) - if customQKind == qvalue.QValueKindGeography || customQKind == qvalue.QValueKindGeometry { + switch customQKind { + case qvalue.QValueKindGeography, qvalue.QValueKindGeometry: wkt, err := geo.GeoValidate(string(data)) if err != nil { - return qvalue.QValue{ - Kind: customQKind, - Value: nil, - }, nil + return qvalue.QValueNull(customQKind), nil + } else if customQKind == qvalue.QValueKindGeography { + return qvalue.QValueGeography{Val: wkt}, nil } else { - return qvalue.QValue{ - Kind: customQKind, - Value: wkt, - }, nil + return qvalue.QValueGeometry{Val: wkt}, nil } - } else { - return qvalue.QValue{ - Kind: customQKind, - Value: string(data), - }, nil + case qvalue.QValueKindHStore: + return qvalue.QValueHStore{Val: string(data)}, nil + case qvalue.QValueKindString: + return qvalue.QValueString{Val: string(data)}, nil + default: + return nil, fmt.Errorf("unknown custom qkind: %s", customQKind) } } - return qvalue.QValue{Kind: qvalue.QValueKindString, Value: string(data)}, nil + return qvalue.QValueString{Val: string(data)}, nil } func (p *PostgresCDCSource) auditSchemaDelta(ctx context.Context, flowJobName string, rec *model.RelationRecord) error { @@ -795,7 +802,7 @@ func (p *PostgresCDCSource) recToTablePKey(req *model.PullRecordsRequest, if err != nil { return nil, fmt.Errorf("error getting pkey column value: %w", err) } - pkeyColsMerged = append(pkeyColsMerged, []byte(fmt.Sprint(pkeyColVal.Value))) + pkeyColsMerged = append(pkeyColsMerged, 
[]byte(fmt.Sprint(pkeyColVal.Value()))) } return &model.TableWithPkey{ diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index c89567a4e6..57cd7e9d86 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -415,8 +415,8 @@ func (c *PostgresConnector) SyncRecords(ctx context.Context, req *model.SyncReco var row []any switch typedRecord := record.(type) { case *model.InsertRecord: - itemsJSON, err := typedRecord.Items.ToJSONWithOptions(&model.ToJSONOptions{ - UnnestColumns: map[string]struct{}{}, + itemsJSON, err := typedRecord.Items.ToJSONWithOptions(model.ToJSONOptions{ + UnnestColumns: nil, HStoreAsJSON: false, }) if err != nil { @@ -435,15 +435,15 @@ func (c *PostgresConnector) SyncRecords(ctx context.Context, req *model.SyncReco } case *model.UpdateRecord: - newItemsJSON, err := typedRecord.NewItems.ToJSONWithOptions(&model.ToJSONOptions{ - UnnestColumns: map[string]struct{}{}, + newItemsJSON, err := typedRecord.NewItems.ToJSONWithOptions(model.ToJSONOptions{ + UnnestColumns: nil, HStoreAsJSON: false, }) if err != nil { return nil, fmt.Errorf("failed to serialize update record new items to JSON: %w", err) } - oldItemsJSON, err := typedRecord.OldItems.ToJSONWithOptions(&model.ToJSONOptions{ - UnnestColumns: map[string]struct{}{}, + oldItemsJSON, err := typedRecord.OldItems.ToJSONWithOptions(model.ToJSONOptions{ + UnnestColumns: nil, HStoreAsJSON: false, }) if err != nil { @@ -462,8 +462,8 @@ func (c *PostgresConnector) SyncRecords(ctx context.Context, req *model.SyncReco } case *model.DeleteRecord: - itemsJSON, err := typedRecord.Items.ToJSONWithOptions(&model.ToJSONOptions{ - UnnestColumns: map[string]struct{}{}, + itemsJSON, err := typedRecord.Items.ToJSONWithOptions(model.ToJSONOptions{ + UnnestColumns: nil, HStoreAsJSON: false, }) if err != nil { diff --git a/flow/connectors/postgres/qrep_query_executor.go b/flow/connectors/postgres/qrep_query_executor.go index 
a4efdfc85b..f7fb9accbd 100644 --- a/flow/connectors/postgres/qrep_query_executor.go +++ b/flow/connectors/postgres/qrep_query_executor.go @@ -458,19 +458,26 @@ func (qe *QRepQueryExecutor) mapRowToQRecord( record[i] = tmp } else { customQKind := customTypeToQKind(typeName) - if customQKind == qvalue.QValueKindGeography || customQKind == qvalue.QValueKindGeometry { - wkbString, ok := values[i].(string) - wkt, err := geo.GeoValidate(wkbString) - if err != nil || !ok { - values[i] = nil - } else { - values[i] = wkt + if values[i] == nil { + record[i] = qvalue.QValueNull(customQKind) + } else { + switch customQKind { + case qvalue.QValueKindGeography, qvalue.QValueKindGeometry: + wkbString, ok := values[i].(string) + wkt, err := geo.GeoValidate(wkbString) + if err != nil || !ok { + record[i] = qvalue.QValueNull(qvalue.QValueKindGeography) + } else if customQKind == qvalue.QValueKindGeography { + record[i] = qvalue.QValueGeography{Val: wkt} + } else { + record[i] = qvalue.QValueGeometry{Val: wkt} + } + case qvalue.QValueKindHStore: + record[i] = qvalue.QValueHStore{Val: fmt.Sprint(values[i])} + case qvalue.QValueKindString: + record[i] = qvalue.QValueString{Val: fmt.Sprint(values[i])} } } - record[i] = qvalue.QValue{ - Kind: customQKind, - Value: values[i], - } } } diff --git a/flow/connectors/postgres/qrep_query_executor_test.go b/flow/connectors/postgres/qrep_query_executor_test.go index cdb37b0349..9a5ca02330 100644 --- a/flow/connectors/postgres/qrep_query_executor_test.go +++ b/flow/connectors/postgres/qrep_query_executor_test.go @@ -75,8 +75,8 @@ func TestExecuteAndProcessQuery(t *testing.T) { t.Fatalf("expected 1 record, got %v", len(batch.Records)) } - if batch.Records[0][1].Value != "testdata" { - t.Fatalf("expected 'testdata', got %v", batch.Records[0][0].Value) + if batch.Records[0][1].Value() != "testdata" { + t.Fatalf("expected 'testdata', got %v", batch.Records[0][0].Value()) } } @@ -189,52 +189,52 @@ func TestAllDataTypes(t *testing.T) { record := 
batch.Records[0] expectedBool := true - if record[0].Value.(bool) != expectedBool { - t.Fatalf("expected %v, got %v", expectedBool, record[0].Value) + if record[0].Value().(bool) != expectedBool { + t.Fatalf("expected %v, got %v", expectedBool, record[0].Value()) } expectedInt4 := int32(2) - if record[1].Value.(int32) != expectedInt4 { - t.Fatalf("expected %v, got %v", expectedInt4, record[1].Value) + if record[1].Value().(int32) != expectedInt4 { + t.Fatalf("expected %v, got %v", expectedInt4, record[1].Value()) } expectedInt8 := int64(3) - if record[2].Value.(int64) != expectedInt8 { - t.Fatalf("expected %v, got %v", expectedInt8, record[2].Value) + if record[2].Value().(int64) != expectedInt8 { + t.Fatalf("expected %v, got %v", expectedInt8, record[2].Value()) } expectedFloat4 := float32(1.1) - if record[3].Value.(float32) != expectedFloat4 { - t.Fatalf("expected %v, got %v", expectedFloat4, record[3].Value) + if record[3].Value().(float32) != expectedFloat4 { + t.Fatalf("expected %v, got %v", expectedFloat4, record[3].Value()) } expectedFloat8 := float64(2.2) - if record[4].Value.(float64) != expectedFloat8 { - t.Fatalf("expected %v, got %v", expectedFloat8, record[4].Value) + if record[4].Value().(float64) != expectedFloat8 { + t.Fatalf("expected %v, got %v", expectedFloat8, record[4].Value()) } expectedText := "text" - if record[5].Value.(string) != expectedText { - t.Fatalf("expected %v, got %v", expectedText, record[5].Value) + if record[5].Value().(string) != expectedText { + t.Fatalf("expected %v, got %v", expectedText, record[5].Value()) } expectedBytea := []byte("bytea") - if !bytes.Equal(record[6].Value.([]byte), expectedBytea) { - t.Fatalf("expected %v, got %v", expectedBytea, record[6].Value) + if !bytes.Equal(record[6].Value().([]byte), expectedBytea) { + t.Fatalf("expected %v, got %v", expectedBytea, record[6].Value()) } expectedJSON := `{"key":"value"}` - if record[7].Value.(string) != expectedJSON { - t.Fatalf("expected %v, got %v", expectedJSON, 
record[7].Value) + if record[7].Value().(string) != expectedJSON { + t.Fatalf("expected %v, got %v", expectedJSON, record[7].Value()) } - actualUUID := record[8].Value.([16]uint8) + actualUUID := record[8].Value().([16]uint8) if !bytes.Equal(actualUUID[:], savedUUID[:]) { t.Fatalf("expected %v, got %v", savedUUID, actualUUID) } expectedNumeric := "123.456" - actualNumeric := record[10].Value.(decimal.Decimal).String() + actualNumeric := record[10].Value().(decimal.Decimal).String() if actualNumeric != expectedNumeric { t.Fatalf("expected %v, got %v", expectedNumeric, actualNumeric) } diff --git a/flow/connectors/postgres/qrep_sql_sync.go b/flow/connectors/postgres/qrep_sql_sync.go index a6ccc5e224..bd0e78f32a 100644 --- a/flow/connectors/postgres/qrep_sql_sync.go +++ b/flow/connectors/postgres/qrep_sql_sync.go @@ -92,7 +92,7 @@ func (s *QRepStagingTableSync) SyncQRepRecords( copySource, ) if err != nil { - return -1, fmt.Errorf("failed to copy records into destination table: %v", err) + return -1, fmt.Errorf("failed to copy records into destination table: %w", err) } if syncedAtCol != "" { @@ -104,7 +104,7 @@ func (s *QRepStagingTableSync) SyncQRepRecords( ) _, err = tx.Exec(context.Background(), updateSyncedAtStmt) if err != nil { - return -1, fmt.Errorf("failed to update synced_at column: %v", err) + return -1, fmt.Errorf("failed to update synced_at column: %w", err) } } } else { @@ -123,7 +123,7 @@ func (s *QRepStagingTableSync) SyncQRepRecords( stagingTableName, createStagingTableStmt), syncLog) _, err = tx.Exec(context.Background(), createStagingTableStmt) if err != nil { - return -1, fmt.Errorf("failed to create staging table: %v", err) + return -1, fmt.Errorf("failed to create staging table: %w", err) } // Step 2.2: Insert records into the staging table @@ -134,7 +134,7 @@ func (s *QRepStagingTableSync) SyncQRepRecords( copySource, ) if err != nil || numRowsSynced != int64(copySource.NumRecords()) { - return -1, fmt.Errorf("failed to copy records into 
staging table: %v", err) + return -1, fmt.Errorf("failed to copy records into staging table: %w", err) } // construct the SET clause for the upsert operation @@ -173,7 +173,7 @@ func (s *QRepStagingTableSync) SyncQRepRecords( s.connector.logger.Info("Performing upsert operation", slog.String("upsertStmt", upsertStmt), syncLog) res, err := tx.Exec(context.Background(), upsertStmt) if err != nil { - return -1, fmt.Errorf("failed to perform upsert operation: %v", err) + return -1, fmt.Errorf("failed to perform upsert operation: %w", err) } numRowsSynced = res.RowsAffected() @@ -186,7 +186,7 @@ func (s *QRepStagingTableSync) SyncQRepRecords( s.connector.logger.Info("Dropping staging table", slog.String("stagingTable", stagingTableName), syncLog) _, err = tx.Exec(context.Background(), dropStagingTableStmt) if err != nil { - return -1, fmt.Errorf("failed to drop staging table: %v", err) + return -1, fmt.Errorf("failed to drop staging table: %w", err) } } @@ -195,7 +195,7 @@ func (s *QRepStagingTableSync) SyncQRepRecords( // marshal the partition to json using protojson pbytes, err := protojson.Marshal(partition) if err != nil { - return -1, fmt.Errorf("failed to marshal partition to json: %v", err) + return -1, fmt.Errorf("failed to marshal partition to json: %w", err) } metadataTableIdentifier := pgx.Identifier{s.connector.metadataSchema, qRepMetadataTableName} @@ -214,12 +214,12 @@ func (s *QRepStagingTableSync) SyncQRepRecords( time.Now(), ) if err != nil { - return -1, fmt.Errorf("failed to execute statements in a transaction: %v", err) + return -1, fmt.Errorf("failed to execute statements in a transaction: %w", err) } err = tx.Commit(context.Background()) if err != nil { - return -1, fmt.Errorf("failed to commit transaction: %v", err) + return -1, fmt.Errorf("failed to commit transaction: %w", err) } numRowsInserted := copySource.NumRecords() diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go index 
d0e3e8cc0b..01164672e8 100644 --- a/flow/connectors/postgres/qvalue_convert.go +++ b/flow/connectors/postgres/qvalue_convert.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" "github.com/lib/pq/oid" "github.com/shopspring/decimal" @@ -195,39 +196,37 @@ func qValueKindToPostgresType(colTypeStr string) string { func parseJSON(value interface{}) (qvalue.QValue, error) { jsonVal, err := json.Marshal(value) if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to parse JSON: %w", err) + return nil, fmt.Errorf("failed to parse JSON: %w", err) } - return qvalue.QValue{Kind: qvalue.QValueKindJSON, Value: string(jsonVal)}, nil + return qvalue.QValueJSON{Val: string(jsonVal)}, nil } -func convertToArray[T any](kind qvalue.QValueKind, value interface{}) (qvalue.QValue, error) { +func convertToArray[T any](kind qvalue.QValueKind, value interface{}) ([]T, error) { switch v := value.(type) { case pgtype.Array[T]: if v.Valid { - return qvalue.QValue{Kind: kind, Value: v.Elements}, nil + return v.Elements, nil } case []T: - return qvalue.QValue{Kind: kind, Value: v}, nil + return v, nil case []interface{}: - return qvalue.QValue{Kind: kind, Value: shared.ArrayCastElements[T](v)}, nil + return shared.ArrayCastElements[T](v), nil } - return qvalue.QValue{}, fmt.Errorf("failed to parse array %s from %T: %v", kind, value, value) + return nil, fmt.Errorf("failed to parse array %s from %T: %v", kind, value, value) } func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (qvalue.QValue, error) { - val := qvalue.QValue{} - if value == nil { - return qvalue.QValue{Kind: qvalueKind, Value: nil}, nil + return qvalue.QValueNull(qvalueKind), nil } switch qvalueKind { case qvalue.QValueKindTimestamp: timestamp := value.(time.Time) - val = qvalue.QValue{Kind: qvalue.QValueKindTimestamp, Value: timestamp} + return qvalue.QValueTimestamp{Val: timestamp}, nil case qvalue.QValueKindTimestampTZ: timestamp := 
value.(time.Time) - val = qvalue.QValue{Kind: qvalue.QValueKindTimestampTZ, Value: timestamp} + return qvalue.QValueTimestampTZ{Val: timestamp}, nil case qvalue.QValueKindInterval: intervalObject := value.(pgtype.Interval) var interval peerdb_interval.PeerDBInterval @@ -241,22 +240,22 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( intervalJSON, err := json.Marshal(interval) if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to parse interval: %w", err) + return nil, fmt.Errorf("failed to parse interval: %w", err) } if !interval.Valid { - return qvalue.QValue{}, fmt.Errorf("invalid interval: %v", value) + return nil, fmt.Errorf("invalid interval: %v", value) } - return qvalue.QValue{Kind: qvalue.QValueKindString, Value: string(intervalJSON)}, nil + return qvalue.QValueString{Val: string(intervalJSON)}, nil case qvalue.QValueKindDate: date := value.(time.Time) - val = qvalue.QValue{Kind: qvalue.QValueKindDate, Value: date} + return qvalue.QValueDate{Val: date}, nil case qvalue.QValueKindTime: timeVal := value.(pgtype.Time) if timeVal.Valid { // 86399999999 to prevent 24:00:00 - val = qvalue.QValue{Kind: qvalue.QValueKindTime, Value: time.UnixMicro(min(timeVal.Microseconds, 86399999999))} + return qvalue.QValueTime{Val: time.UnixMicro(min(timeVal.Microseconds, 86399999999))}, nil } case qvalue.QValueKindTimeTZ: timeVal := value.(string) @@ -266,147 +265,186 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( timeVal = strings.Replace(timeVal, "+00", "+0000", 1) t, err := time.Parse("15:04:05.999999-0700", timeVal) if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to parse time: %w", err) + return nil, fmt.Errorf("failed to parse time: %w", err) } t = t.AddDate(1970, 0, 0) - val = qvalue.QValue{Kind: qvalue.QValueKindTimeTZ, Value: t} + return qvalue.QValueTimeTZ{Val: t}, nil case qvalue.QValueKindBoolean: boolVal := value.(bool) - val = qvalue.QValue{Kind: 
qvalue.QValueKindBoolean, Value: boolVal} + return qvalue.QValueBoolean{Val: boolVal}, nil case qvalue.QValueKindJSON: tmp, err := parseJSON(value) if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to parse JSON: %w", err) + return nil, fmt.Errorf("failed to parse JSON: %w", err) } - val = tmp + return tmp, nil case qvalue.QValueKindInt16: intVal := value.(int16) - val = qvalue.QValue{Kind: qvalue.QValueKindInt16, Value: int32(intVal)} + return qvalue.QValueInt16{Val: intVal}, nil case qvalue.QValueKindInt32: intVal := value.(int32) - val = qvalue.QValue{Kind: qvalue.QValueKindInt32, Value: intVal} + return qvalue.QValueInt32{Val: intVal}, nil case qvalue.QValueKindInt64: intVal := value.(int64) - val = qvalue.QValue{Kind: qvalue.QValueKindInt64, Value: intVal} + return qvalue.QValueInt64{Val: intVal}, nil case qvalue.QValueKindFloat32: floatVal := value.(float32) - val = qvalue.QValue{Kind: qvalue.QValueKindFloat32, Value: floatVal} + return qvalue.QValueFloat32{Val: floatVal}, nil case qvalue.QValueKindFloat64: floatVal := value.(float64) - val = qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: floatVal} + return qvalue.QValueFloat64{Val: floatVal}, nil case qvalue.QValueKindQChar: - val = qvalue.QValue{Kind: qvalue.QValueKindQChar, Value: uint8(value.(rune))} + return qvalue.QValueQChar{Val: uint8(value.(rune))}, nil case qvalue.QValueKindString: // handling all unsupported types with strings as well for now. 
- val = qvalue.QValue{Kind: qvalue.QValueKindString, Value: fmt.Sprint(value)} + return qvalue.QValueString{Val: fmt.Sprint(value)}, nil case qvalue.QValueKindUUID: - switch value.(type) { + switch v := value.(type) { case string: - val = qvalue.QValue{Kind: qvalue.QValueKindUUID, Value: value} + id, err := uuid.Parse(v) + if err != nil { + return nil, fmt.Errorf("failed to parse UUID: %w", err) + } + return qvalue.QValueUUID{Val: [16]byte(id)}, nil case [16]byte: - val = qvalue.QValue{Kind: qvalue.QValueKindUUID, Value: value} + return qvalue.QValueUUID{Val: v}, nil default: - return qvalue.QValue{}, fmt.Errorf("failed to parse UUID: %v", value) + return nil, fmt.Errorf("failed to parse UUID: %v", value) } case qvalue.QValueKindINET: switch v := value.(type) { case string: - val = qvalue.QValue{Kind: qvalue.QValueKindINET, Value: value} + return qvalue.QValueINET{Val: v}, nil case [16]byte: - val = qvalue.QValue{Kind: qvalue.QValueKindINET, Value: value} + return qvalue.QValueINET{Val: string(v[:])}, nil case netip.Prefix: - val = qvalue.QValue{Kind: qvalue.QValueKindINET, Value: v.String()} + return qvalue.QValueINET{Val: v.String()}, nil default: - return qvalue.QValue{}, fmt.Errorf("failed to parse INET: %v", v) + return nil, fmt.Errorf("failed to parse INET: %v", v) } case qvalue.QValueKindCIDR: switch v := value.(type) { case string: - val = qvalue.QValue{Kind: qvalue.QValueKindCIDR, Value: value} + return qvalue.QValueCIDR{Val: v}, nil case [16]byte: - val = qvalue.QValue{Kind: qvalue.QValueKindCIDR, Value: value} + return qvalue.QValueCIDR{Val: string(v[:])}, nil case netip.Prefix: - val = qvalue.QValue{Kind: qvalue.QValueKindCIDR, Value: v.String()} + return qvalue.QValueCIDR{Val: v.String()}, nil default: - return qvalue.QValue{}, fmt.Errorf("failed to parse CIDR: %v", value) + return nil, fmt.Errorf("failed to parse CIDR: %v", value) } case qvalue.QValueKindMacaddr: - switch value.(type) { + switch v := value.(type) { case string: - val = 
qvalue.QValue{Kind: qvalue.QValueKindMacaddr, Value: value} + return qvalue.QValueMacaddr{Val: v}, nil case [16]byte: - val = qvalue.QValue{Kind: qvalue.QValueKindMacaddr, Value: value} + return qvalue.QValueMacaddr{Val: string(v[:])}, nil default: - return qvalue.QValue{}, fmt.Errorf("failed to parse MACADDR: %v", value) + return nil, fmt.Errorf("failed to parse MACADDR: %v", value) } case qvalue.QValueKindBytes: rawBytes := value.([]byte) - val = qvalue.QValue{Kind: qvalue.QValueKindBytes, Value: rawBytes} + return qvalue.QValueBytes{Val: rawBytes}, nil case qvalue.QValueKindBit: bitsVal := value.(pgtype.Bits) if bitsVal.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindBit, Value: bitsVal.Bytes} + return qvalue.QValueBit{Val: bitsVal.Bytes}, nil } case qvalue.QValueKindNumeric: numVal := value.(pgtype.Numeric) if numVal.Valid { num, err := numericToDecimal(numVal) if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to convert numeric [%v] to decimal: %w", value, err) + return nil, fmt.Errorf("failed to convert numeric [%v] to decimal: %w", value, err) } - val = qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: num} + return num, nil } case qvalue.QValueKindArrayFloat32: - return convertToArray[float32](qvalueKind, value) + a, err := convertToArray[float32](qvalueKind, value) + if err != nil { + return nil, err + } + return qvalue.QValueArrayFloat32{Val: a}, nil case qvalue.QValueKindArrayFloat64: - return convertToArray[float64](qvalueKind, value) + a, err := convertToArray[float64](qvalueKind, value) + if err != nil { + return nil, err + } + return qvalue.QValueArrayFloat64{Val: a}, nil case qvalue.QValueKindArrayInt16: - return convertToArray[int16](qvalueKind, value) + a, err := convertToArray[int16](qvalueKind, value) + if err != nil { + return nil, err + } + return qvalue.QValueArrayInt16{Val: a}, nil case qvalue.QValueKindArrayInt32: - return convertToArray[int32](qvalueKind, value) + a, err := convertToArray[int32](qvalueKind, value) + if 
err != nil { + return nil, err + } + return qvalue.QValueArrayInt32{Val: a}, nil case qvalue.QValueKindArrayInt64: - return convertToArray[int64](qvalueKind, value) + a, err := convertToArray[int64](qvalueKind, value) + if err != nil { + return nil, err + } + return qvalue.QValueArrayInt64{Val: a}, nil case qvalue.QValueKindArrayDate, qvalue.QValueKindArrayTimestamp, qvalue.QValueKindArrayTimestampTZ: - return convertToArray[time.Time](qvalueKind, value) + a, err := convertToArray[time.Time](qvalueKind, value) + if err != nil { + return nil, err + } + switch qvalueKind { + case qvalue.QValueKindArrayDate: + return qvalue.QValueArrayDate{Val: a}, nil + case qvalue.QValueKindArrayTimestamp: + return qvalue.QValueArrayTimestamp{Val: a}, nil + case qvalue.QValueKindArrayTimestampTZ: + return qvalue.QValueArrayTimestampTZ{Val: a}, nil + } case qvalue.QValueKindArrayBoolean: - return convertToArray[bool](qvalueKind, value) + a, err := convertToArray[bool](qvalueKind, value) + if err != nil { + return nil, err + } + return qvalue.QValueArrayBoolean{Val: a}, nil case qvalue.QValueKindArrayString: - return convertToArray[string](qvalueKind, value) + a, err := convertToArray[string](qvalueKind, value) + if err != nil { + return nil, err + } + return qvalue.QValueArrayString{Val: a}, nil case qvalue.QValueKindPoint: xCoord := value.(pgtype.Point).P.X yCoord := value.(pgtype.Point).P.Y - val = qvalue.QValue{ - Kind: qvalue.QValueKindPoint, - Value: fmt.Sprintf("POINT(%f %f)", xCoord, yCoord), - } + return qvalue.QValuePoint{ + Val: fmt.Sprintf("POINT(%f %f)", xCoord, yCoord), + }, nil default: textVal, ok := value.(string) if ok { - val = qvalue.QValue{Kind: qvalue.QValueKindString, Value: textVal} + return qvalue.QValueString{Val: textVal}, nil } } // parsing into pgtype failed. 
- if val == (qvalue.QValue{}) { - return qvalue.QValue{}, fmt.Errorf("failed to parse value %v into QValueKind %v", value, qvalueKind) - } - return val, nil + return nil, fmt.Errorf("failed to parse value %v into QValueKind %v", value, qvalueKind) } func (c *PostgresConnector) parseFieldFromPostgresOID(oid uint32, value interface{}) (qvalue.QValue, error) { return parseFieldFromQValueKind(c.postgresOIDToQValueKind(oid), value) } -func numericToDecimal(numVal pgtype.Numeric) (interface{}, error) { +func numericToDecimal(numVal pgtype.Numeric) (qvalue.QValue, error) { switch { case !numVal.Valid: - return nil, errors.New("invalid numeric") + return qvalue.QValueNull(qvalue.QValueKindNumeric), errors.New("invalid numeric") case numVal.NaN, numVal.InfinityModifier == pgtype.Infinity, numVal.InfinityModifier == pgtype.NegativeInfinity: - return nil, nil + return qvalue.QValueNull(qvalue.QValueKindNumeric), nil default: - return decimal.NewFromBigInt(numVal.Int, numVal.Exp), nil + return qvalue.QValueNumeric{Val: decimal.NewFromBigInt(numVal.Int, numVal.Exp)}, nil } } diff --git a/flow/connectors/s3/s3.go b/flow/connectors/s3/s3.go index 7e95e00b43..fddd69f78f 100644 --- a/flow/connectors/s3/s3.go +++ b/flow/connectors/s3/s3.go @@ -155,11 +155,10 @@ func (c *S3Connector) SetLastOffset(ctx context.Context, jobName string, offset func (c *S3Connector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest) (*model.SyncResponse, error) { tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings) streamReq := model.NewRecordsToStreamRequest(req.Records.GetRecords(), tableNameRowsMapping, req.SyncBatchID) - streamRes, err := utils.RecordsToRawTableStream(streamReq) + recordStream, err := utils.RecordsToRawTableStream(streamReq) if err != nil { return nil, fmt.Errorf("failed to convert records to raw table stream: %w", err) } - recordStream := streamRes.Stream qrepConfig := &protos.QRepConfig{ FlowJobName: req.FlowJobName, DestinationTableIdentifier: 
"raw_table_" + req.FlowJobName, diff --git a/flow/connectors/snowflake/avro_file_writer_test.go b/flow/connectors/snowflake/avro_file_writer_test.go index 13f4a9e3e8..0006513cbb 100644 --- a/flow/connectors/snowflake/avro_file_writer_test.go +++ b/flow/connectors/snowflake/avro_file_writer_test.go @@ -17,45 +17,53 @@ import ( ) // createQValue creates a QValue of the appropriate kind for a given placeholder. -func createQValue(t *testing.T, kind qvalue.QValueKind, placeHolder int) qvalue.QValue { +func createQValue(t *testing.T, kind qvalue.QValueKind, placeholder int) qvalue.QValue { t.Helper() - var value interface{} switch kind { - case qvalue.QValueKindInt16, qvalue.QValueKindInt32, qvalue.QValueKindInt64: - value = int64(placeHolder) + case qvalue.QValueKindInt16: + return qvalue.QValueInt16{Val: int16(placeholder)} + case qvalue.QValueKindInt32: + return qvalue.QValueInt32{Val: int32(placeholder)} + case qvalue.QValueKindInt64: + return qvalue.QValueInt64{Val: int64(placeholder)} case qvalue.QValueKindFloat32: - value = float32(placeHolder) + return qvalue.QValueFloat32{Val: float32(placeholder) / 4.0} case qvalue.QValueKindFloat64: - value = float64(placeHolder) + return qvalue.QValueFloat64{Val: float64(placeholder) / 4.0} case qvalue.QValueKindBoolean: - value = placeHolder%2 == 0 + return qvalue.QValueBoolean{Val: placeholder%2 == 0} case qvalue.QValueKindString: - value = fmt.Sprintf("string%d", placeHolder) - case qvalue.QValueKindTimestamp, qvalue.QValueKindTimestampTZ, qvalue.QValueKindTime, - qvalue.QValueKindTimeTZ, qvalue.QValueKindDate: - value = time.Now() + return qvalue.QValueString{Val: fmt.Sprintf("string%d", placeholder)} + case qvalue.QValueKindTimestamp: + return qvalue.QValueTimestamp{Val: time.Now()} + case qvalue.QValueKindTimestampTZ: + return qvalue.QValueTimestampTZ{Val: time.Now()} + case qvalue.QValueKindTime: + return qvalue.QValueTime{Val: time.Now()} + case qvalue.QValueKindTimeTZ: + return qvalue.QValueTimeTZ{Val: time.Now()} + 
case qvalue.QValueKindDate: + return qvalue.QValueDate{Val: time.Now()} case qvalue.QValueKindNumeric: - value = decimal.New(int64(placeHolder), 1) + return qvalue.QValueNumeric{Val: decimal.New(int64(placeholder), 1)} case qvalue.QValueKindUUID: - value = uuid.New() // assuming you have the github.com/google/uuid package + return qvalue.QValueUUID{Val: [16]byte(uuid.New())} // assuming you have the github.com/google/uuid package case qvalue.QValueKindQChar: - value = uint8(48) - // case qvalue.QValueKindArray: - // value = []int{1, 2, 3} // placeholder array, replace with actual logic - // case qvalue.QValueKindStruct: - // value = map[string]interface{}{"key": "value"} // placeholder struct, replace with actual logic - // case qvalue.QValueKindJSON: - // value = `{"key": "value"}` // placeholder JSON, replace with actual logic - case qvalue.QValueKindBytes, qvalue.QValueKindBit: - value = []byte("sample bytes") // placeholder bytes, replace with actual logic + return qvalue.QValueQChar{Val: uint8(48 + placeholder%10)} // assuming you have the github.com/google/uuid package + // case qvalue.QValueKindArray: + // value = []int{1, 2, 3} // placeholder array, replace with actual logic + // case qvalue.QValueKindStruct: + // value = map[string]interface{}{"key": "value"} // placeholder struct, replace with actual logic + // case qvalue.QValueKindJSON: + // value = `{"key": "value"}` // placeholder JSON, replace with actual logic + case qvalue.QValueKindBytes: + return qvalue.QValueBytes{Val: []byte("sample bytes")} // placeholder bytes, replace with actual logic + case qvalue.QValueKindBit: + return qvalue.QValueBit{Val: []byte("sample bits")} // placeholder bytes, replace with actual logic default: require.Failf(t, "unsupported QValueKind", "unsupported QValueKind: %s", kind) - } - - return qvalue.QValue{ - Kind: kind, - Value: value, + return qvalue.QValueNull(kind) } } @@ -115,10 +123,10 @@ func generateRecords( entries := make([]qvalue.QValue, len(allQValueKinds)) 
for i, kind := range allQValueKinds { - placeHolder := int(row) * i - entries[i] = createQValue(t, kind, placeHolder) if allnulls { - entries[i].Value = nil + entries[i] = qvalue.QValueNull(kind) + } else { + entries[i] = createQValue(t, kind, int(row)*i) } } diff --git a/flow/connectors/snowflake/qrep_avro_consolidate.go b/flow/connectors/snowflake/qrep_avro_consolidate.go index 8d39ecf0b8..6f261d2cf8 100644 --- a/flow/connectors/snowflake/qrep_avro_consolidate.go +++ b/flow/connectors/snowflake/qrep_avro_consolidate.go @@ -142,7 +142,7 @@ func (s *SnowflakeAvroConsolidateHandler) generateUpsertMergeCommand( upsertKeyCols := s.config.WriteMode.UpsertKeyColumns // all cols are acquired from snowflake schema, so let us try to make upsert key cols match the case // and also the watermark col, then the quoting should be fine - caseMatchedCols := map[string]string{} + caseMatchedCols := make(map[string]string, len(s.allColNames)) for _, col := range s.allColNames { caseMatchedCols[strings.ToLower(col)] = col } diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 1d9ed6fe58..3844c321fc 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -449,7 +449,7 @@ func (c *SnowflakeConnector) syncRecordsViaAvro( ) (*model.SyncResponse, error) { tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings) streamReq := model.NewRecordsToStreamRequest(req.Records.GetRecords(), tableNameRowsMapping, syncBatchID) - streamRes, err := utils.RecordsToRawTableStream(streamReq) + stream, err := utils.RecordsToRawTableStream(streamReq) if err != nil { return nil, fmt.Errorf("failed to convert records to raw table stream: %w", err) } @@ -466,7 +466,7 @@ func (c *SnowflakeConnector) syncRecordsViaAvro( return nil, err } - numRecords, err := avroSyncer.SyncRecords(ctx, destinationTableSchema, streamRes.Stream, req.FlowJobName) + numRecords, err := avroSyncer.SyncRecords(ctx, 
destinationTableSchema, stream, req.FlowJobName) if err != nil { return nil, err } diff --git a/flow/connectors/sql/query_executor.go b/flow/connectors/sql/query_executor.go index 05279fdde4..447da7c712 100644 --- a/flow/connectors/sql/query_executor.go +++ b/flow/connectors/sql/query_executor.go @@ -7,6 +7,7 @@ import ( "fmt" "log/slog" "strings" + "time" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" @@ -218,7 +219,7 @@ func (g *GenericSQLQueryExecutor) processRows(ctx context.Context, rows *sqlx.Ro case qvalue.QValueKindBoolean: var b sql.NullBool values[i] = &b - case qvalue.QValueKindString: + case qvalue.QValueKindString, qvalue.QValueKindHStore: var s sql.NullString values[i] = &s case qvalue.QValueKindBytes, qvalue.QValueKindBit: @@ -321,87 +322,124 @@ func (g *GenericSQLQueryExecutor) CheckNull(ctx context.Context, schema string, func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) { if val == nil { - return qvalue.QValue{Kind: kind, Value: nil}, nil + return qvalue.QValueNull(kind), nil } switch kind { case qvalue.QValueKindInt32: if v, ok := val.(*sql.NullInt32); ok { if v.Valid { - return qvalue.QValue{Kind: qvalue.QValueKindInt32, Value: v.Int32}, nil + return qvalue.QValueInt32{Val: v.Int32}, nil } else { - return qvalue.QValue{Kind: qvalue.QValueKindInt32, Value: nil}, nil + return qvalue.QValueNull(qvalue.QValueKindInt32), nil } } case qvalue.QValueKindInt64: if v, ok := val.(*sql.NullInt64); ok { if v.Valid { - return qvalue.QValue{Kind: qvalue.QValueKindInt64, Value: v.Int64}, nil + return qvalue.QValueInt64{Val: v.Int64}, nil } else { - return qvalue.QValue{Kind: qvalue.QValueKindInt64, Value: nil}, nil + return qvalue.QValueNull(qvalue.QValueKindInt64), nil } } case qvalue.QValueKindFloat32: if v, ok := val.(*sql.NullFloat64); ok { if v.Valid { - return qvalue.QValue{Kind: qvalue.QValueKindFloat32, Value: float32(v.Float64)}, nil + return qvalue.QValueFloat32{Val: float32(v.Float64)}, nil } else { - return 
qvalue.QValue{Kind: qvalue.QValueKindFloat32, Value: nil}, nil + return qvalue.QValueNull(qvalue.QValueKindFloat32), nil } } case qvalue.QValueKindFloat64: if v, ok := val.(*sql.NullFloat64); ok { if v.Valid { - return qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: v.Float64}, nil + return qvalue.QValueFloat64{Val: v.Float64}, nil } else { - return qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: nil}, nil + return qvalue.QValueNull(qvalue.QValueKindFloat64), nil } } case qvalue.QValueKindQChar: if v, ok := val.(uint8); ok { - return qvalue.QValue{Kind: qvalue.QValueKindQChar, Value: v}, nil + return qvalue.QValueQChar{Val: v}, nil } case qvalue.QValueKindString: if v, ok := val.(*sql.NullString); ok { if v.Valid { - return qvalue.QValue{Kind: qvalue.QValueKindString, Value: v.String}, nil + return qvalue.QValueString{Val: v.String}, nil } else { - return qvalue.QValue{Kind: qvalue.QValueKindString, Value: nil}, nil + return qvalue.QValueNull(qvalue.QValueKindString), nil } } case qvalue.QValueKindBoolean: if v, ok := val.(*sql.NullBool); ok { if v.Valid { - return qvalue.QValue{Kind: qvalue.QValueKindBoolean, Value: v.Bool}, nil + return qvalue.QValueBoolean{Val: v.Bool}, nil } else { - return qvalue.QValue{Kind: qvalue.QValueKindBoolean, Value: nil}, nil + return qvalue.QValueNull(qvalue.QValueKindBoolean), nil } } - case qvalue.QValueKindTimestamp, qvalue.QValueKindTimestampTZ, qvalue.QValueKindDate, - qvalue.QValueKindTime, qvalue.QValueKindTimeTZ: + case qvalue.QValueKindTimestamp: if t, ok := val.(*sql.NullTime); ok { if t.Valid { - return qvalue.QValue{ - Kind: kind, - Value: t.Time, + return qvalue.QValueTimestamp{Val: t.Time}, nil + } else { + return qvalue.QValueNull(kind), nil + } + } + case qvalue.QValueKindTimestampTZ: + if t, ok := val.(*sql.NullTime); ok { + if t.Valid { + return qvalue.QValueTimestampTZ{Val: t.Time}, nil + } else { + return qvalue.QValueNull(kind), nil + } + } + case qvalue.QValueKindDate: + if t, ok := 
val.(*sql.NullTime); ok { + if t.Valid { + return qvalue.QValueDate{Val: t.Time}, nil + } else { + return qvalue.QValueNull(kind), nil + } + } + case qvalue.QValueKindTime: + if t, ok := val.(*sql.NullTime); ok { + if t.Valid { + tt := t.Time + // anchor on unix epoch, some drivers anchor on 0001-01-01 + return qvalue.QValueTimeTZ{ + Val: time.Date(1970, time.January, 1, tt.Hour(), tt.Minute(), tt.Second(), tt.Nanosecond(), time.UTC), }, nil } else { - return qvalue.QValue{ - Kind: kind, - Value: nil, + return qvalue.QValueNull(kind), nil + } + } + case qvalue.QValueKindTimeTZ: + if t, ok := val.(*sql.NullTime); ok { + if t.Valid { + tt := t.Time + // anchor on unix epoch, some drivers anchor on 0001-01-01 + return qvalue.QValueTimeTZ{ + Val: time.Date(1970, time.January, 1, tt.Hour(), tt.Minute(), tt.Second(), tt.Nanosecond(), tt.Location()), }, nil + } else { + return qvalue.QValueNull(kind), nil } } case qvalue.QValueKindNumeric: if v, ok := val.(*sql.Null[decimal.Decimal]); ok { if v.Valid { - return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: v.V}, nil + return qvalue.QValueNumeric{Val: v.V}, nil } else { - return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: nil}, nil + return qvalue.QValueNull(qvalue.QValueKindNumeric), nil } } - case qvalue.QValueKindBytes, qvalue.QValueKindBit: + case qvalue.QValueKindBytes: + if v, ok := val.(*[]byte); ok && v != nil { + return qvalue.QValueBytes{Val: *v}, nil + } + case qvalue.QValueKindBit: if v, ok := val.(*[]byte); ok && v != nil { - return qvalue.QValue{Kind: kind, Value: *v}, nil + return qvalue.QValueBit{Val: *v}, nil } case qvalue.QValueKindUUID: @@ -409,13 +447,9 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) { // convert byte array to string uuidVal, err := uuid.FromBytes(*v) if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to parse uuid: %v", *v) + return nil, fmt.Errorf("failed to parse uuid: %v", *v) } - return qvalue.QValue{Kind: 
qvalue.QValueKindString, Value: uuidVal.String()}, nil - } - - if v, ok := val.(*[16]byte); ok && v != nil { - return qvalue.QValue{Kind: qvalue.QValueKindString, Value: *v}, nil + return qvalue.QValueUUID{Val: [16]byte(uuidVal)}, nil } case qvalue.QValueKindJSON: @@ -430,7 +464,7 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) { var v []interface{} err := json.Unmarshal([]byte(vstring), &v) if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to parse json array: %v", vstring) + return nil, fmt.Errorf("failed to parse json array: %v", vstring) } // assume all elements in the array are of the same type @@ -456,10 +490,16 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) { } } - return qvalue.QValue{Kind: qvalue.QValueKindJSON, Value: vstring}, nil + return qvalue.QValueJSON{Val: vstring}, nil case qvalue.QValueKindHStore: - return qvalue.QValue{Kind: qvalue.QValueKindHStore, Value: val}, nil + vraw := val.(*interface{}) + vstring, ok := (*vraw).(string) + if !ok { + slog.Warn("A parsed hstore value was not a string. 
Likely a null field value") + } + + return qvalue.QValueHStore{Val: vstring}, nil case qvalue.QValueKindArrayFloat32, qvalue.QValueKindArrayFloat64, qvalue.QValueKindArrayInt16, @@ -469,96 +509,95 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) { } // If type is unsupported or doesn't match the specified kind, return error - return qvalue.QValue{}, fmt.Errorf("unsupported type %T for kind %s", val, kind) + return nil, fmt.Errorf("unsupported type %T for kind %s", val, kind) } func toQValueArray(kind qvalue.QValueKind, value interface{}) (qvalue.QValue, error) { - var result interface{} switch kind { case qvalue.QValueKindArrayFloat32: switch v := value.(type) { case []float32: - result = v + return qvalue.QValueArrayFloat32{Val: v}, nil case []interface{}: float32Array := make([]float32, len(v)) for i, val := range v { float32Array[i] = val.(float32) } - result = float32Array + return qvalue.QValueArrayFloat32{Val: float32Array}, nil default: - return qvalue.QValue{}, fmt.Errorf("failed to parse array float32: %v", value) + return nil, fmt.Errorf("failed to parse array float32: %v", value) } case qvalue.QValueKindArrayFloat64: switch v := value.(type) { case []float64: - result = v + return qvalue.QValueArrayFloat64{Val: v}, nil case []interface{}: float64Array := make([]float64, len(v)) for i, val := range v { float64Array[i] = val.(float64) } - result = float64Array + return qvalue.QValueArrayFloat64{Val: float64Array}, nil default: - return qvalue.QValue{}, fmt.Errorf("failed to parse array float64: %v", value) + return nil, fmt.Errorf("failed to parse array float64: %v", value) } case qvalue.QValueKindArrayInt16: switch v := value.(type) { case []int16: - result = v + return qvalue.QValueArrayInt16{Val: v}, nil case []interface{}: int16Array := make([]int16, len(v)) for i, val := range v { int16Array[i] = val.(int16) } - result = int16Array + return qvalue.QValueArrayInt16{Val: int16Array}, nil default: - return qvalue.QValue{}, 
fmt.Errorf("failed to parse array int16: %v", value) + return nil, fmt.Errorf("failed to parse array int16: %v", value) } case qvalue.QValueKindArrayInt32: switch v := value.(type) { case []int32: - result = v + return qvalue.QValueArrayInt32{Val: v}, nil case []interface{}: int32Array := make([]int32, len(v)) for i, val := range v { int32Array[i] = val.(int32) } - result = int32Array + return qvalue.QValueArrayInt32{Val: int32Array}, nil default: - return qvalue.QValue{}, fmt.Errorf("failed to parse array int32: %v", value) + return nil, fmt.Errorf("failed to parse array int32: %v", value) } case qvalue.QValueKindArrayInt64: switch v := value.(type) { case []int64: - result = v + return qvalue.QValueArrayInt64{Val: v}, nil case []interface{}: int64Array := make([]int64, len(v)) for i, val := range v { int64Array[i] = val.(int64) } - result = int64Array + return qvalue.QValueArrayInt64{Val: int64Array}, nil default: - return qvalue.QValue{}, fmt.Errorf("failed to parse array int64: %v", value) + return nil, fmt.Errorf("failed to parse array int64: %v", value) } case qvalue.QValueKindArrayString: switch v := value.(type) { case []string: - result = v + return qvalue.QValueArrayString{Val: v}, nil case []interface{}: stringArray := make([]string, len(v)) for i, val := range v { stringArray[i] = val.(string) } - result = stringArray + return qvalue.QValueArrayString{Val: stringArray}, nil default: - return qvalue.QValue{}, fmt.Errorf("failed to parse array string: %v", value) + return nil, fmt.Errorf("failed to parse array string: %v", value) } } - return qvalue.QValue{Kind: kind, Value: result}, nil + return qvalue.QValueNull(kind), nil } diff --git a/flow/connectors/utils/cdc_records/cdc_records_storage.go b/flow/connectors/utils/cdc_records/cdc_records_storage.go index e4a8561aea..1de0a6d2cd 100644 --- a/flow/connectors/utils/cdc_records/cdc_records_storage.go +++ b/flow/connectors/utils/cdc_records/cdc_records_storage.go @@ -16,6 +16,7 @@ import ( 
"go.temporal.io/sdk/log" "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/PeerDB-io/peer-flow/peerdbenv" "github.com/PeerDB-io/peer-flow/shared" ) @@ -73,6 +74,45 @@ func (c *cdcRecordsStore) initPebbleDB() error { gob.Register(&model.DeleteRecord{}) gob.Register(time.Time{}) gob.Register(decimal.Decimal{}) + gob.Register(qvalue.QValueNull("")) + gob.Register(qvalue.QValueInvalid{}) + gob.Register(qvalue.QValueFloat32{}) + gob.Register(qvalue.QValueFloat64{}) + gob.Register(qvalue.QValueInt16{}) + gob.Register(qvalue.QValueInt32{}) + gob.Register(qvalue.QValueInt64{}) + gob.Register(qvalue.QValueBoolean{}) + gob.Register(qvalue.QValueStruct{}) + gob.Register(qvalue.QValueQChar{}) + gob.Register(qvalue.QValueString{}) + gob.Register(qvalue.QValueTimestamp{}) + gob.Register(qvalue.QValueTimestampTZ{}) + gob.Register(qvalue.QValueDate{}) + gob.Register(qvalue.QValueTime{}) + gob.Register(qvalue.QValueTimeTZ{}) + gob.Register(qvalue.QValueInterval{}) + gob.Register(qvalue.QValueNumeric{}) + gob.Register(qvalue.QValueBytes{}) + gob.Register(qvalue.QValueUUID{}) + gob.Register(qvalue.QValueJSON{}) + gob.Register(qvalue.QValueBit{}) + gob.Register(qvalue.QValueHStore{}) + gob.Register(qvalue.QValueGeography{}) + gob.Register(qvalue.QValueGeometry{}) + gob.Register(qvalue.QValuePoint{}) + gob.Register(qvalue.QValueCIDR{}) + gob.Register(qvalue.QValueINET{}) + gob.Register(qvalue.QValueMacaddr{}) + gob.Register(qvalue.QValueArrayFloat32{}) + gob.Register(qvalue.QValueArrayFloat64{}) + gob.Register(qvalue.QValueArrayInt16{}) + gob.Register(qvalue.QValueArrayInt32{}) + gob.Register(qvalue.QValueArrayInt64{}) + gob.Register(qvalue.QValueArrayString{}) + gob.Register(qvalue.QValueArrayDate{}) + gob.Register(qvalue.QValueArrayTimestamp{}) + gob.Register(qvalue.QValueArrayTimestampTZ{}) + gob.Register(qvalue.QValueArrayBoolean{}) var err error // we don't want a WAL since cache, we don't want to overwrite another DB either diff 
--git a/flow/connectors/utils/cdc_records/cdc_records_storage_test.go b/flow/connectors/utils/cdc_records/cdc_records_storage_test.go index 5acbc5962a..0d0b6da9e2 100644 --- a/flow/connectors/utils/cdc_records/cdc_records_storage_test.go +++ b/flow/connectors/utils/cdc_records/cdc_records_storage_test.go @@ -61,18 +61,9 @@ func genKeyAndRec(t *testing.T) (model.TableWithPkey, model.Record) { "rv": 2, }, Values: []qvalue.QValue{ - { - Kind: qvalue.QValueKindInt64, - Value: 1, - }, - { - Kind: qvalue.QValueKindTime, - Value: tv, - }, - { - Kind: qvalue.QValueKindNumeric, - Value: rv, - }, + qvalue.QValueInt64{Val: 1}, + qvalue.QValueTime{Val: tv}, + qvalue.QValueNumeric{Val: rv}, }, }, } diff --git a/flow/connectors/utils/stream.go b/flow/connectors/utils/stream.go index c2d146c062..d6729ca242 100644 --- a/flow/connectors/utils/stream.go +++ b/flow/connectors/utils/stream.go @@ -11,7 +11,7 @@ import ( "github.com/PeerDB-io/peer-flow/model/qvalue" ) -func RecordsToRawTableStream(req *model.RecordsToStreamRequest) (*model.RecordsToStreamResponse, error) { +func RecordsToRawTableStream(req *model.RecordsToStreamRequest) (*model.QRecordStream, error) { recordStream := model.NewQRecordStream(1 << 17) err := recordStream.SetSchema(&model.QRecordSchema{ Fields: []model.QField{ @@ -70,9 +70,7 @@ func RecordsToRawTableStream(req *model.RecordsToStreamRequest) (*model.RecordsT close(recordStream.Records) }() - return &model.RecordsToStreamResponse{ - Stream: recordStream, - }, nil + return recordStream, nil } func recordToQRecordOrError(batchID int64, record model.Record) model.QRecordOrError { @@ -87,23 +85,10 @@ func recordToQRecordOrError(batchID int64, record model.Record) model.QRecordOrE } } - entries[3] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: itemsJSON, - } - entries[4] = qvalue.QValue{ - Kind: qvalue.QValueKindInt64, - Value: 0, - } - entries[5] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: "", - } - entries[7] = qvalue.QValue{ - Kind: 
qvalue.QValueKindString, - Value: "", - } - + entries[3] = qvalue.QValueString{Val: itemsJSON} + entries[4] = qvalue.QValueInt64{Val: 0} + entries[5] = qvalue.QValueString{Val: ""} + entries[7] = qvalue.QValueString{Val: ""} case *model.UpdateRecord: newItemsJSON, err := typedRecord.NewItems.ToJSON() if err != nil { @@ -118,22 +103,10 @@ func recordToQRecordOrError(batchID int64, record model.Record) model.QRecordOrE } } - entries[3] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: newItemsJSON, - } - entries[4] = qvalue.QValue{ - Kind: qvalue.QValueKindInt64, - Value: 1, - } - entries[5] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: oldItemsJSON, - } - entries[7] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: KeysToString(typedRecord.UnchangedToastColumns), - } + entries[3] = qvalue.QValueString{Val: newItemsJSON} + entries[4] = qvalue.QValueInt64{Val: 1} + entries[5] = qvalue.QValueString{Val: oldItemsJSON} + entries[7] = qvalue.QValueString{Val: KeysToString(typedRecord.UnchangedToastColumns)} case *model.DeleteRecord: itemsJSON, err := typedRecord.Items.ToJSON() @@ -143,22 +116,10 @@ func recordToQRecordOrError(batchID int64, record model.Record) model.QRecordOrE } } - entries[3] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: itemsJSON, - } - entries[4] = qvalue.QValue{ - Kind: qvalue.QValueKindInt64, - Value: 2, - } - entries[5] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: itemsJSON, - } - entries[7] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: KeysToString(typedRecord.UnchangedToastColumns), - } + entries[3] = qvalue.QValueString{Val: itemsJSON} + entries[4] = qvalue.QValueInt64{Val: 2} + entries[5] = qvalue.QValueString{Val: itemsJSON} + entries[7] = qvalue.QValueString{Val: KeysToString(typedRecord.UnchangedToastColumns)} default: return model.QRecordOrError{ @@ -166,22 +127,10 @@ func recordToQRecordOrError(batchID int64, record model.Record) model.QRecordOrE } } - entries[0] = 
qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: uuid.New().String(), - } - entries[1] = qvalue.QValue{ - Kind: qvalue.QValueKindInt64, - Value: time.Now().UnixNano(), - } - entries[2] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: record.GetDestinationTableName(), - } - entries[6] = qvalue.QValue{ - Kind: qvalue.QValueKindInt64, - Value: batchID, - } + entries[0] = qvalue.QValueString{Val: uuid.New().String()} + entries[1] = qvalue.QValueInt64{Val: time.Now().UnixNano()} + entries[2] = qvalue.QValueString{Val: record.GetDestinationTableName()} + entries[6] = qvalue.QValueInt64{Val: batchID} return model.QRecordOrError{ Record: entries[:], diff --git a/flow/e2e/bigquery/bigquery_helper.go b/flow/e2e/bigquery/bigquery_helper.go index 7c8ab5257c..cb8558866a 100644 --- a/flow/e2e/bigquery/bigquery_helper.go +++ b/flow/e2e/bigquery/bigquery_helper.go @@ -209,37 +209,40 @@ func (b *BigQueryTestHelper) countRowsWithDataset(dataset, tableName string, non func toQValue(bqValue bigquery.Value) (qvalue.QValue, error) { // Based on the real type of the bigquery.Value, we create a qvalue.QValue switch v := bqValue.(type) { - case int, int32: - return qvalue.QValue{Kind: qvalue.QValueKindInt32, Value: v}, nil + case int: + return qvalue.QValueInt32{Val: int32(v)}, nil + case int32: + return qvalue.QValueInt32{Val: v}, nil case int64: - return qvalue.QValue{Kind: qvalue.QValueKindInt64, Value: v}, nil + return qvalue.QValueInt64{Val: v}, nil case float32: - return qvalue.QValue{Kind: qvalue.QValueKindFloat32, Value: v}, nil + return qvalue.QValueFloat32{Val: v}, nil case float64: - return qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: v}, nil + return qvalue.QValueFloat64{Val: v}, nil case string: - return qvalue.QValue{Kind: qvalue.QValueKindString, Value: v}, nil + return qvalue.QValueString{Val: v}, nil case bool: - return qvalue.QValue{Kind: qvalue.QValueKindBoolean, Value: v}, nil + return qvalue.QValueBoolean{Val: v}, nil case civil.Date: - return 
qvalue.QValue{Kind: qvalue.QValueKindDate, Value: v.In(time.UTC)}, nil + return qvalue.QValueDate{Val: v.In(time.UTC)}, nil case civil.Time: - return qvalue.QValue{Kind: qvalue.QValueKindTime, Value: v}, nil + tm := time.Unix(int64(v.Hour)*3600+int64(v.Minute)*60+int64(v.Second), int64(v.Nanosecond)) + return qvalue.QValueTime{Val: tm}, nil case time.Time: - return qvalue.QValue{Kind: qvalue.QValueKindTimestamp, Value: v}, nil + return qvalue.QValueTimestamp{Val: v}, nil case *big.Rat: val, err := decimal.NewFromString(v.FloatString(32)) if err != nil { - return qvalue.QValue{}, fmt.Errorf("bqHelper failed to parse as decimal %v", v) + return nil, fmt.Errorf("bqHelper failed to parse as decimal %v", v) } - return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: val}, nil + return qvalue.QValueNumeric{Val: val}, nil case []uint8: - return qvalue.QValue{Kind: qvalue.QValueKindBytes, Value: v}, nil + return qvalue.QValueBytes{Val: v}, nil case []bigquery.Value: // If the type is an array, we need to convert each element // we can assume all elements are of the same type, let us use first element if len(v) == 0 { - return qvalue.QValue{Kind: qvalue.QValueKindInvalid, Value: nil}, nil + return qvalue.QValueNull(qvalue.QValueKindInvalid), nil } firstElement := v[0] @@ -249,60 +252,59 @@ func toQValue(bqValue bigquery.Value) (qvalue.QValue, error) { for _, val := range v { arr = append(arr, val.(int32)) } - return qvalue.QValue{Kind: qvalue.QValueKindArrayInt32, Value: arr}, nil + return qvalue.QValueArrayInt32{Val: arr}, nil case int64: var arr []int64 for _, val := range v { arr = append(arr, val.(int64)) } - return qvalue.QValue{Kind: qvalue.QValueKindArrayInt64, Value: arr}, nil + return qvalue.QValueArrayInt64{Val: arr}, nil case float32: var arr []float32 for _, val := range v { arr = append(arr, val.(float32)) } - return qvalue.QValue{Kind: qvalue.QValueKindArrayFloat32, Value: arr}, nil + return qvalue.QValueArrayFloat32{Val: arr}, nil case float64: var arr 
[]float64 for _, val := range v { arr = append(arr, val.(float64)) } - return qvalue.QValue{Kind: qvalue.QValueKindArrayFloat64, Value: arr}, nil + return qvalue.QValueArrayFloat64{Val: arr}, nil case string: var arr []string for _, val := range v { arr = append(arr, val.(string)) } - return qvalue.QValue{Kind: qvalue.QValueKindArrayString, Value: arr}, nil + return qvalue.QValueArrayString{Val: arr}, nil case time.Time: var arr []time.Time for _, val := range v { arr = append(arr, val.(time.Time)) } - return qvalue.QValue{Kind: qvalue.QValueKindArrayTimestamp, Value: arr}, nil + return qvalue.QValueArrayTimestamp{Val: arr}, nil case civil.Date: - var arr []civil.Date + var arr []time.Time for _, val := range v { - arr = append(arr, val.(civil.Date)) + arr = append(arr, val.(civil.Date).In(time.UTC)) } - return qvalue.QValue{Kind: qvalue.QValueKindArrayDate, Value: arr}, nil + return qvalue.QValueArrayDate{Val: arr}, nil case bool: var arr []bool - for _, val := range v { arr = append(arr, val.(bool)) } - return qvalue.QValue{Kind: qvalue.QValueKindArrayBoolean, Value: arr}, nil + return qvalue.QValueArrayBoolean{Val: arr}, nil default: // If type is unsupported, return error - return qvalue.QValue{}, fmt.Errorf("bqHelper unsupported type %T", et) + return nil, fmt.Errorf("bqHelper unsupported type %T", et) } case nil: - return qvalue.QValue{Kind: qvalue.QValueKindInvalid, Value: nil}, nil + return qvalue.QValueNull(qvalue.QValueKindInvalid), nil default: // If type is unsupported, return error - return qvalue.QValue{}, fmt.Errorf("bqHelper unsupported type %T", v) + return nil, fmt.Errorf("bqHelper unsupported type %T", v) } } @@ -512,5 +514,13 @@ func (b *BigQueryTestHelper) RunInt64Query(query string) (int64, error) { return 0, fmt.Errorf("expected only 1 record, got %d", len(recordBatch.Records)) } - return recordBatch.Records[0][0].Value.(int64), nil + switch v := recordBatch.Records[0][0].(type) { + case qvalue.QValueInt16: + return int64(v.Val), nil + case 
qvalue.QValueInt32: + return int64(v.Val), nil + case qvalue.QValueInt64: + return v.Val, nil + } + return 0, fmt.Errorf("non-integer result: %T", recordBatch.Records[0][0]) } diff --git a/flow/e2e/bigquery/peer_flow_bq_test.go b/flow/e2e/bigquery/peer_flow_bq_test.go index 51ba631e60..206f71ea27 100644 --- a/flow/e2e/bigquery/peer_flow_bq_test.go +++ b/flow/e2e/bigquery/peer_flow_bq_test.go @@ -29,14 +29,14 @@ func (s PeerFlowE2ETestSuiteBQ) checkJSONValue(tableName, colName, fieldName, va "SELECT `%s`.%s FROM `%s.%s`;", colName, fieldName, s.bqHelper.Config.DatasetId, tableName)) if err != nil { - return fmt.Errorf("json value check failed: %v", err) + return fmt.Errorf("json value check failed: %w", err) } if len(res.Records) == 0 { return fmt.Errorf("bad json: empty result set from %s", tableName) } - jsonVal := res.Records[0][0].Value + jsonVal := res.Records[0][0].Value() if jsonVal != value { return fmt.Errorf("bad json value in field %s of column %s: %v. expected: %v", fieldName, colName, jsonVal, value) } @@ -69,20 +69,8 @@ func (s *PeerFlowE2ETestSuiteBQ) checkPeerdbColumns(dstQualified string, softDel for _, record := range recordBatch.Records { for _, entry := range record { - if entry.Kind == qvalue.QValueKindBoolean { - isDeleteVal, ok := entry.Value.(bool) - if !(ok && isDeleteVal) { - return errors.New("peerdb column failed: _PEERDB_IS_DELETED is not true") - } - recordCount += 1 - } - - if entry.Kind == qvalue.QValueKindTimestamp { - _, ok := entry.Value.(time.Time) - if !ok { - return errors.New("peerdb column failed: _PEERDB_SYNCED_AT is not valid") - } - + switch entry.(type) { + case qvalue.QValueBoolean, qvalue.QValueTimestamp: recordCount += 1 } } @@ -455,14 +443,17 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Types_BQ() { // check if JSON on bigquery side is a good JSON if err := s.checkJSONValue(dstTableName, "c17", "sai", "-8.021390374331551"); err != nil { + s.t.Log(err) return false } // check if HSTORE on bigquery side is a good JSON if err 
:= s.checkJSONValue(dstTableName, "c46", "key1", "\"value1\""); err != nil { + s.t.Log(err) return false } if err := s.checkJSONValue(dstTableName, "c46", "key2", "null"); err != nil { + s.t.Log(err) return false } diff --git a/flow/e2e/snowflake/qrep_flow_sf_test.go b/flow/e2e/snowflake/qrep_flow_sf_test.go index d7ce60d81b..2d0d8c5d3a 100644 --- a/flow/e2e/snowflake/qrep_flow_sf_test.go +++ b/flow/e2e/snowflake/qrep_flow_sf_test.go @@ -29,7 +29,7 @@ func (s PeerFlowE2ETestSuiteSF) checkJSONValue(tableName, colName, fieldName, va return fmt.Errorf("bad json: empty result set from %s", tableName) } - jsonVal := res.Records[0][0].Value + jsonVal := res.Records[0][0].Value() if jsonVal != value { return fmt.Errorf("bad json value in field %s of column %s: %v. expected: %v", fieldName, colName, jsonVal, value) } diff --git a/flow/e2e/snowflake/snowflake_helper.go b/flow/e2e/snowflake/snowflake_helper.go index 14ca9dc35f..c2a6001b01 100644 --- a/flow/e2e/snowflake/snowflake_helper.go +++ b/flow/e2e/snowflake/snowflake_helper.go @@ -6,9 +6,6 @@ import ( "errors" "fmt" "os" - "time" - - "github.com/shopspring/decimal" connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake" "github.com/PeerDB-io/peer-flow/e2eshared" @@ -177,16 +174,15 @@ func (s *SnowflakeTestHelper) RunIntQuery(query string) (int, error) { return 0, fmt.Errorf("failed to execute query: %s, returned %d != 1 columns", query, len(rec)) } - switch rec[0].Kind { - case qvalue.QValueKindInt32: - return int(rec[0].Value.(int32)), nil - case qvalue.QValueKindInt64: - return int(rec[0].Value.(int64)), nil - case qvalue.QValueKindNumeric: - val := rec[0].Value.(decimal.Decimal) - return int(val.IntPart()), nil + switch v := rec[0].(type) { + case qvalue.QValueInt32: + return int(v.Val), nil + case qvalue.QValueInt64: + return int(v.Val), nil + case qvalue.QValueNumeric: + return int(v.Val.IntPart()), nil default: - return 0, fmt.Errorf("failed to execute query: %s, returned value of type %s", query, 
rec[0].Kind) + return 0, fmt.Errorf("failed to execute query: %s, returned value of type %s", query, rec[0].Kind()) } } @@ -199,12 +195,9 @@ func (s *SnowflakeTestHelper) checkSyncedAt(query string) error { for _, record := range recordBatch.Records { for _, entry := range record { - if entry.Kind != qvalue.QValueKindTimestamp { - return errors.New("synced_at column check failed: _PEERDB_SYNCED_AT is not timestamp") - } - _, ok := entry.Value.(time.Time) + _, ok := entry.(qvalue.QValueTimestamp) if !ok { - return errors.New("synced_at column failed: _PEERDB_SYNCED_AT is not valid") + return errors.New("synced_at column failed: _PEERDB_SYNCED_AT is not a timestamp") } } } diff --git a/flow/e2eshared/e2eshared.go b/flow/e2eshared/e2eshared.go index 087ff58014..9b63b41286 100644 --- a/flow/e2eshared/e2eshared.go +++ b/flow/e2eshared/e2eshared.go @@ -69,8 +69,8 @@ func CheckQRecordEquality(t *testing.T, q []qvalue.QValue, other []qvalue.QValue for i, entry := range q { otherEntry := other[i] - if !entry.Equals(otherEntry) { - t.Logf("entry %d: %T %v != %T %v", i, entry.Value, entry, otherEntry.Value, otherEntry) + if !qvalue.Equals(entry, otherEntry) { + t.Logf("entry %d: %T %+v != %T %+v", i, entry, entry, otherEntry, otherEntry) return false } } diff --git a/flow/hstore/hstore.go b/flow/hstore/hstore.go index cbb7d60c8a..aca1412ba0 100644 --- a/flow/hstore/hstore.go +++ b/flow/hstore/hstore.go @@ -14,9 +14,8 @@ import ( "strings" ) -type text struct { - String string - Valid bool +func text(s string) *string { + return &s } type hstore map[string]*string @@ -150,35 +149,35 @@ func (p *hstoreParser) consumeKVSeparator() error { return p.consumeExpected2('=', '>') } -// consumeDoubleQuotedOrNull consumes the Hstore key/value separator "=>" or returns an error. -func (p *hstoreParser) consumeDoubleQuotedOrNull() (text, error) { +// consumeDoubleQuotedOrNull consumes the string or returns an error. 
+func (p *hstoreParser) consumeDoubleQuotedOrNull() (*string, error) { // peek at the next byte if p.atEnd() { - return text{}, errors.New("found end instead of value") + return nil, errors.New("found end instead of value") } next := p.str[p.pos] if next == 'N' { // must be the exact string NULL: use consumeExpected2 twice err := p.consumeExpected2('N', 'U') if err != nil { - return text{}, err + return nil, err } err = p.consumeExpected2('L', 'L') if err != nil { - return text{}, err + return nil, err } - return text{String: "", Valid: false}, nil + return nil, nil } else if next != '"' { - return text{}, unexpectedByteErr(next, '"') + return nil, unexpectedByteErr(next, '"') } // skip the double quote p.pos += 1 s, err := p.consumeDoubleQuoted() if err != nil { - return text{}, err + return nil, err } - return text{String: s, Valid: true}, nil + return text(s), nil } func ParseHstore(s string) (string, error) { @@ -217,11 +216,7 @@ func ParseHstore(s string) (string, error) { if err != nil { return "", err } - if value.Valid { - result[key] = &value.String - } else { - result[key] = nil - } + result[key] = value } jsonBytes, err := json.Marshal(result) diff --git a/flow/model/conversion_avro.go b/flow/model/conversion_avro.go index 39e3579f8f..94ea7ddfc3 100644 --- a/flow/model/conversion_avro.go +++ b/flow/model/conversion_avro.go @@ -40,14 +40,12 @@ func (qac *QRecordAvroConverter) Convert() (map[string]interface{}, error) { key := qac.ColNames[idx] _, nullable := qac.NullableFields[key] - avroConverter := qvalue.NewQValueAvroConverter( + avroVal, err := qvalue.QValueToAvro( val, qac.TargetDWH, nullable, qac.logger, ) - - avroVal, err := avroConverter.ToAvroValue() if err != nil { return nil, fmt.Errorf("failed to convert QValue to Avro-compatible value: %w", err) } diff --git a/flow/model/model.go b/flow/model/model.go index 10fe95d5a7..0754db0c66 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -62,7 +62,7 @@ type ToJSONOptions struct { 
HStoreAsJSON bool } -func NewToJSONOptions(unnestCols []string, hstoreAsJSON bool) *ToJSONOptions { +func NewToJSONOptions(unnestCols []string, hstoreAsJSON bool) ToJSONOptions { var unnestColumns map[string]struct{} if len(unnestCols) != 0 { unnestColumns = make(map[string]struct{}, len(unnestCols)) @@ -70,7 +70,7 @@ func NewToJSONOptions(unnestCols []string, hstoreAsJSON bool) *ToJSONOptions { unnestColumns[col] = struct{}{} } } - return &ToJSONOptions{ + return ToJSONOptions{ UnnestColumns: unnestColumns, HStoreAsJSON: hstoreAsJSON, } diff --git a/flow/model/qrecord_batch.go b/flow/model/qrecord_batch.go index dd55ef7ecc..eabd187ade 100644 --- a/flow/model/qrecord_batch.go +++ b/flow/model/qrecord_batch.go @@ -9,7 +9,6 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" - "github.com/shopspring/decimal" "github.com/PeerDB-io/peer-flow/geo" "github.com/PeerDB-io/peer-flow/model/qvalue" @@ -42,7 +41,7 @@ func (q *QRecordBatch) ToQRecordStream(buffer int) (*QRecordStream, error) { } func constructArray[T any](qValue qvalue.QValue, typeName string) (*pgtype.Array[T], error) { - v, ok := qValue.Value.([]T) + v, ok := qValue.Value().([]T) if !ok { return nil, fmt.Errorf("invalid %s value", typeName) } @@ -93,154 +92,67 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) { values := make([]interface{}, numEntries) for i, qValue := range record { - if qValue.Value == nil { + if qValue.Value() == nil { values[i] = nil continue } - switch qValue.Kind { - case qvalue.QValueKindFloat32: - v, ok := qValue.Value.(float32) - if !ok { - src.err = errors.New("invalid float32 value") - return nil, src.err - } - values[i] = v - - case qvalue.QValueKindFloat64: - v, ok := qValue.Value.(float64) - if !ok { - src.err = errors.New("invalid float64 value") - return nil, src.err - } - values[i] = v - - case qvalue.QValueKindInt16, qvalue.QValueKindInt32: - v, ok := qValue.Value.(int32) - if !ok { - src.err = errors.New("invalid int32 value") - 
return nil, src.err - } - values[i] = v - - case qvalue.QValueKindInt64: - v, ok := qValue.Value.(int64) - if !ok { - src.err = errors.New("invalid int64 value") - return nil, src.err - } - values[i] = v - - case qvalue.QValueKindBoolean: - v, ok := qValue.Value.(bool) - if !ok { - src.err = errors.New("invalid boolean value") - return nil, src.err - } - values[i] = v - - case qvalue.QValueKindQChar: - v, ok := qValue.Value.(uint8) - if !ok { - src.err = errors.New("invalid \"char\" value") - return nil, src.err - } - values[i] = rune(v) - - case qvalue.QValueKindString: - v, ok := qValue.Value.(string) - if !ok { - src.err = errors.New("invalid string value") - return nil, src.err - } - values[i] = v - - case qvalue.QValueKindCIDR, qvalue.QValueKindINET: - v, ok := qValue.Value.(string) + switch v := qValue.(type) { + case qvalue.QValueFloat32: + values[i] = v.Val + case qvalue.QValueFloat64: + values[i] = v.Val + case qvalue.QValueInt16: + values[i] = v.Val + case qvalue.QValueInt32: + values[i] = v.Val + case qvalue.QValueInt64: + values[i] = v.Val + case qvalue.QValueBoolean: + values[i] = v.Val + case qvalue.QValueQChar: + values[i] = rune(v.Val) + case qvalue.QValueString: + values[i] = v.Val + case qvalue.QValueCIDR, qvalue.QValueINET: + str, ok := v.Value().(string) if !ok { src.err = errors.New("invalid INET/CIDR value") return nil, src.err } - values[i] = v + values[i] = str - case qvalue.QValueKindTime: - t, ok := qValue.Value.(time.Time) - if !ok { - src.err = errors.New("invalid Time value") - return nil, src.err - } - time := pgtype.Time{Microseconds: t.UnixMicro(), Valid: true} - values[i] = time - - case qvalue.QValueKindTimestamp: - t, ok := qValue.Value.(time.Time) - if !ok { - src.err = errors.New("invalid ExtendedTime value") - return nil, src.err - } - timestamp := pgtype.Timestamp{Time: t, Valid: true} - values[i] = timestamp - - case qvalue.QValueKindTimestampTZ: - t, ok := qValue.Value.(time.Time) - if !ok { - src.err = errors.New("invalid 
ExtendedTime value") - return nil, src.err - } - timestampTZ := pgtype.Timestamptz{Time: t, Valid: true} - values[i] = timestampTZ - - case qvalue.QValueKindUUID: - v, ok := qValue.Value.([16]byte) // treat it as byte slice - if !ok { - src.err = fmt.Errorf("invalid UUID value %v", qValue.Value) - return nil, src.err - } - values[i] = uuid.UUID(v) - - case qvalue.QValueKindNumeric: - v, ok := qValue.Value.(decimal.Decimal) - if !ok { - src.err = fmt.Errorf("invalid Numeric value %v", qValue.Value) - return nil, src.err - } - values[i] = v - - case qvalue.QValueKindBytes, qvalue.QValueKindBit: - v, ok := qValue.Value.([]byte) + case qvalue.QValueTime: + values[i] = pgtype.Time{Microseconds: v.Val.UnixMicro(), Valid: true} + case qvalue.QValueTimestamp: + values[i] = pgtype.Timestamp{Time: v.Val, Valid: true} + case qvalue.QValueTimestampTZ: + values[i] = pgtype.Timestamptz{Time: v.Val, Valid: true} + case qvalue.QValueUUID: + values[i] = uuid.UUID(v.Val) + case qvalue.QValueNumeric: + values[i] = v.Val + case qvalue.QValueBytes, qvalue.QValueBit: + bytes, ok := v.Value().([]byte) if !ok { src.err = errors.New("invalid Bytes value") return nil, src.err } - values[i] = v - - case qvalue.QValueKindDate: - t, ok := qValue.Value.(time.Time) - if !ok { - src.err = errors.New("invalid Date value") - return nil, src.err - } - date := pgtype.Date{Time: t, Valid: true} - values[i] = date + values[i] = bytes - case qvalue.QValueKindHStore: - v, ok := qValue.Value.(string) - if !ok { - src.err = errors.New("invalid HStore value") - return nil, src.err - } - - values[i] = v - case qvalue.QValueKindGeography, qvalue.QValueKindGeometry, qvalue.QValueKindPoint: - v, ok := qValue.Value.(string) + case qvalue.QValueDate: + values[i] = pgtype.Date{Time: v.Val, Valid: true} + case qvalue.QValueHStore: + values[i] = v.Val + case qvalue.QValueGeography, qvalue.QValueGeometry, qvalue.QValuePoint: + geoWkt, ok := v.Value().(string) if !ok { src.err = errors.New("invalid Geospatial value") 
return nil, src.err } - geoWkt := v - if strings.HasPrefix(v, "SRID=") { - _, wkt, found := strings.Cut(v, ";") + if strings.HasPrefix(geoWkt, "SRID=") { + _, wkt, found := strings.Cut(geoWkt, ";") if found { geoWkt = wkt } @@ -253,79 +165,74 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) { } values[i] = wkb - case qvalue.QValueKindArrayString: - v, err := constructArray[string](qValue, "ArrayString") + case qvalue.QValueArrayString: + a, err := constructArray[string](qValue, "ArrayString") if err != nil { src.err = err return nil, src.err } - values[i] = v + values[i] = a - case qvalue.QValueKindArrayDate, qvalue.QValueKindArrayTimestamp, qvalue.QValueKindArrayTimestampTZ: - v, err := constructArray[time.Time](qValue, "ArrayTime") + case qvalue.QValueArrayDate, qvalue.QValueArrayTimestamp, qvalue.QValueArrayTimestampTZ: + a, err := constructArray[time.Time](qValue, "ArrayTime") if err != nil { src.err = err return nil, src.err } - values[i] = v + values[i] = a - case qvalue.QValueKindArrayInt16: - v, err := constructArray[int16](qValue, "ArrayInt16") + case qvalue.QValueArrayInt16: + a, err := constructArray[int16](qValue, "ArrayInt16") if err != nil { src.err = err return nil, src.err } - values[i] = v + values[i] = a - case qvalue.QValueKindArrayInt32: - v, err := constructArray[int32](qValue, "ArrayInt32") + case qvalue.QValueArrayInt32: + a, err := constructArray[int32](qValue, "ArrayInt32") if err != nil { src.err = err return nil, src.err } - values[i] = v + values[i] = a - case qvalue.QValueKindArrayInt64: - v, err := constructArray[int64](qValue, "ArrayInt64") + case qvalue.QValueArrayInt64: + a, err := constructArray[int64](qValue, "ArrayInt64") if err != nil { src.err = err return nil, src.err } - values[i] = v + values[i] = a - case qvalue.QValueKindArrayFloat32: - v, err := constructArray[float32](qValue, "ArrayFloat32") + case qvalue.QValueArrayFloat32: + a, err := constructArray[float32](qValue, "ArrayFloat32") if err != 
nil { src.err = err return nil, src.err } - values[i] = v + values[i] = a - case qvalue.QValueKindArrayFloat64: - v, err := constructArray[float64](qValue, "ArrayFloat64") + case qvalue.QValueArrayFloat64: + a, err := constructArray[float64](qValue, "ArrayFloat64") if err != nil { src.err = err return nil, src.err } - values[i] = v - case qvalue.QValueKindArrayBoolean: - v, err := constructArray[bool](qValue, "ArrayBool") + values[i] = a + case qvalue.QValueArrayBoolean: + a, err := constructArray[bool](qValue, "ArrayBool") if err != nil { src.err = err return nil, src.err } - values[i] = v - case qvalue.QValueKindJSON: - v, ok := qValue.Value.(string) - if !ok { - src.err = errors.New("invalid JSON value") - return nil, src.err - } - values[i] = v + values[i] = a + case qvalue.QValueJSON: + values[i] = v.Val // And so on for the other types... default: - src.err = fmt.Errorf("unsupported value type %s", qValue.Kind) + src.err = fmt.Errorf("unsupported value type %T", qValue) return nil, src.err } } diff --git a/flow/model/qrecord_stream.go b/flow/model/qrecord_stream.go index 97546ce247..0feafe5502 100644 --- a/flow/model/qrecord_stream.go +++ b/flow/model/qrecord_stream.go @@ -51,10 +51,6 @@ func (r *RecordsToStreamRequest) GetRecords() <-chan Record { return r.records } -type RecordsToStreamResponse struct { - Stream *QRecordStream -} - func NewQRecordStream(buffer int) *QRecordStream { return &QRecordStream{ schema: make(chan QRecordSchemaOrError, 1), diff --git a/flow/model/qrecord_test.go b/flow/model/qrecord_test.go index e6c769fd69..a373452b9e 100644 --- a/flow/model/qrecord_test.go +++ b/flow/model/qrecord_test.go @@ -23,26 +23,26 @@ func TestEquals(t *testing.T) { }{ { name: "Equal - Same UUID", - q1: []qvalue.QValue{{Kind: qvalue.QValueKindUUID, Value: uuidVal1}}, - q2: []qvalue.QValue{{Kind: qvalue.QValueKindString, Value: uuidVal1.String()}}, + q1: []qvalue.QValue{qvalue.QValueUUID{Val: uuidVal1}}, + q2: []qvalue.QValue{qvalue.QValueString{Val: 
uuidVal1.String()}}, want: true, }, { name: "Not Equal - Different UUID", - q1: []qvalue.QValue{{Kind: qvalue.QValueKindUUID, Value: uuidVal1}}, - q2: []qvalue.QValue{{Kind: qvalue.QValueKindUUID, Value: uuidVal2}}, + q1: []qvalue.QValue{qvalue.QValueUUID{Val: uuidVal1}}, + q2: []qvalue.QValue{qvalue.QValueUUID{Val: uuidVal2}}, want: false, }, { name: "Equal - Same numeric", - q1: []qvalue.QValue{{Kind: qvalue.QValueKindNumeric, Value: decimal.NewFromInt(5)}}, - q2: []qvalue.QValue{{Kind: qvalue.QValueKindString, Value: "5"}}, + q1: []qvalue.QValue{qvalue.QValueNumeric{Val: decimal.NewFromInt(5)}}, + q2: []qvalue.QValue{qvalue.QValueString{Val: "5"}}, want: true, }, { name: "Not Equal - Different numeric", - q1: []qvalue.QValue{{Kind: qvalue.QValueKindNumeric, Value: decimal.NewFromInt(5)}}, - q2: []qvalue.QValue{{Kind: qvalue.QValueKindString, Value: "4.99"}}, + q1: []qvalue.QValue{qvalue.QValueNumeric{Val: decimal.NewFromInt(5)}}, + q2: []qvalue.QValue{qvalue.QValueString{Val: "4.99"}}, want: false, }, } diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go index 3df8738209..536e705db6 100644 --- a/flow/model/qvalue/avro_converter.go +++ b/flow/model/qvalue/avro_converter.go @@ -162,34 +162,30 @@ func GetAvroSchemaFromQValueKind(kind QValueKind, targetDWH QDWHType, precision } type QValueAvroConverter struct { - QValue TargetDWH QDWHType Nullable bool logger log.Logger } -func NewQValueAvroConverter(value QValue, targetDWH QDWHType, nullable bool, logger log.Logger) *QValueAvroConverter { - return &QValueAvroConverter{ - QValue: value, +func QValueToAvro(value QValue, targetDWH QDWHType, nullable bool, logger log.Logger) (interface{}, error) { + if nullable && value.Value() == nil { + return nil, nil + } + + c := &QValueAvroConverter{ TargetDWH: targetDWH, Nullable: nullable, logger: logger, } -} - -func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { - if c.Nullable && c.Value == nil { - return nil, nil - } - switch 
c.Kind { - case QValueKindInvalid: + switch v := value.(type) { + case QValueInvalid: // we will attempt to convert invalid to a string - return c.processNullableUnion("string", c.Value) - case QValueKindTime: - t, err := c.processGoTime() - if err != nil || t == nil { - return t, err + return c.processNullableUnion("string", v.Val) + case QValueTime: + t := c.processGoTime(v.Val) + if t == nil { + return nil, nil } if c.TargetDWH == QDWHTypeSnowflake { @@ -211,10 +207,10 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { return goavro.Union("long.time-micros", t.(int64)), nil } return t.(int64), nil - case QValueKindTimeTZ: - t, err := c.processGoTimeTZ() - if err != nil || t == nil { - return t, err + case QValueTimeTZ: + t := c.processGoTimeTZ(v.Val) + if t == nil { + return nil, nil } if c.TargetDWH == QDWHTypeSnowflake { if c.Nullable { @@ -235,10 +231,10 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { return goavro.Union("long.time-micros", t.(int64)), nil } return t.(int64), nil - case QValueKindTimestamp: - t, err := c.processGoTimestamp() - if err != nil || t == nil { - return t, err + case QValueTimestamp: + t := c.processGoTimestamp(v.Val) + if t == nil { + return nil, nil } if c.TargetDWH == QDWHTypeSnowflake { if c.Nullable { @@ -252,10 +248,10 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { return goavro.Union("long.timestamp-micros", t.(int64)), nil } return t.(int64), nil - case QValueKindTimestampTZ: - t, err := c.processGoTimestampTZ() - if err != nil || t == nil { - return t, err + case QValueTimestampTZ: + t := c.processGoTimestampTZ(v.Val) + if t == nil { + return nil, nil } if c.TargetDWH == QDWHTypeSnowflake { if c.Nullable { @@ -269,10 +265,10 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { return goavro.Union("long.timestamp-micros", t.(int64)), nil } return t.(int64), nil - case QValueKindDate: - t, err := c.processGoDate() - if err != nil || t == nil { - 
return t, err + case QValueDate: + t := c.processGoDate(v.Val) + if t == nil { + return nil, nil } if c.TargetDWH == QDWHTypeSnowflake { @@ -287,191 +283,135 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { return goavro.Union("int.date", t), nil } return t, nil - case QValueKindQChar: - return c.processNullableUnion("string", string(c.Value.(uint8))) - case QValueKindString, QValueKindCIDR, QValueKindINET, QValueKindMacaddr, QValueKindInterval: - if c.TargetDWH == QDWHTypeSnowflake && c.Value != nil && - (len(c.Value.(string)) > 15*1024*1024) { + case QValueQChar: + return c.processNullableUnion("string", string(v.Val)) + case QValueString, QValueCIDR, QValueINET, QValueMacaddr, QValueInterval: + if c.TargetDWH == QDWHTypeSnowflake && v.Value() != nil && + (len(v.Value().(string)) > 15*1024*1024) { slog.Warn("Truncating TEXT value > 15MB for Snowflake!") slog.Warn("Check this issue for details: https://github.com/PeerDB-io/peerdb/issues/309") return c.processNullableUnion("string", "") } - return c.processNullableUnion("string", c.Value) - case QValueKindFloat32: + return c.processNullableUnion("string", v.Value()) + case QValueFloat32: if c.TargetDWH == QDWHTypeBigQuery { - return c.processNullableUnion("double", c.Value) - } - return c.processNullableUnion("float", c.Value) - case QValueKindFloat64: - if c.TargetDWH == QDWHTypeSnowflake || c.TargetDWH == QDWHTypeBigQuery { - if f32Val, ok := c.Value.(float32); ok { - return c.processNullableUnion("double", float64(f32Val)) - } - } - return c.processNullableUnion("double", c.Value) - case QValueKindInt16, QValueKindInt32, QValueKindInt64: - return c.processNullableUnion("long", c.Value) - case QValueKindBoolean: - return c.processNullableUnion("boolean", c.Value) - case QValueKindStruct: - return nil, errors.New("QValueKindStruct not supported") - case QValueKindNumeric: - return c.processNumeric() - case QValueKindBytes, QValueKindBit: - return c.processBytes() - case QValueKindJSON: - 
return c.processJSON() - case QValueKindHStore: - return c.processHStore() - case QValueKindArrayFloat32: - return c.processArrayFloat32() - case QValueKindArrayFloat64: - return c.processArrayFloat64() - case QValueKindArrayInt16: - return c.processArrayInt16() - case QValueKindArrayInt32: - return c.processArrayInt32() - case QValueKindArrayInt64: - return c.processArrayInt64() - case QValueKindArrayString: - return c.processArrayString() - case QValueKindArrayBoolean: - return c.processArrayBoolean() - case QValueKindArrayTimestamp, QValueKindArrayTimestampTZ: - arrayTime, err := c.processArrayTime() - if err != nil || arrayTime == nil { - return arrayTime, err - } - - return arrayTime, nil - case QValueKindArrayDate: - arrayDate, err := c.processArrayDate() - if err != nil || arrayDate == nil { - return arrayDate, err - } - - return arrayDate, nil - case QValueKindUUID: - return c.processUUID() - case QValueKindGeography, QValueKindGeometry, QValueKindPoint: - return c.processGeospatial() + return c.processNullableUnion("double", float64(v.Val)) + } + return c.processNullableUnion("float", v.Val) + case QValueFloat64: + return c.processNullableUnion("double", v.Val) + case QValueInt16: + return c.processNullableUnion("long", int32(v.Val)) + case QValueInt32, QValueInt64: + return c.processNullableUnion("long", v.Value()) + case QValueBoolean: + return c.processNullableUnion("boolean", v.Val) + case QValueStruct: + return nil, errors.New("QValueStruct not supported") + case QValueNumeric: + return c.processNumeric(v.Val), nil + case QValueBytes: + return c.processBytes(v.Val), nil + case QValueBit: + return c.processBytes(v.Val), nil + case QValueJSON: + return c.processJSON(v.Val) + case QValueHStore: + return c.processHStore(v.Val) + case QValueArrayFloat32: + return c.processArrayFloat32(v.Val), nil + case QValueArrayFloat64: + return c.processArrayFloat64(v.Val), nil + case QValueArrayInt16: + return c.processArrayInt16(v.Val), nil + case QValueArrayInt32: + 
return c.processArrayInt32(v.Val), nil + case QValueArrayInt64: + return c.processArrayInt64(v.Val), nil + case QValueArrayString: + return c.processArrayString(v.Val), nil + case QValueArrayBoolean: + return c.processArrayBoolean(v.Val), nil + case QValueArrayTimestamp, QValueArrayTimestampTZ: + return c.processArrayTime(v.Value().([]time.Time)), nil + case QValueArrayDate: + return c.processArrayDate(v.Val), nil + case QValueUUID: + return c.processUUID(v.Val), nil + case QValueGeography, QValueGeometry, QValuePoint: + return c.processGeospatial(v.Value().(string)), nil default: - return nil, fmt.Errorf("[toavro] unsupported QValueKind: %s", c.Kind) + return nil, fmt.Errorf("[toavro] unsupported %T", value) } } -func (c *QValueAvroConverter) processGoTimeTZ() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - t, ok := c.Value.(time.Time) - if !ok { - return nil, errors.New("invalid TimeTZ value") - } - +func (c *QValueAvroConverter) processGoTimeTZ(t time.Time) interface{} { // Snowflake has issues with avro timestamp types, returning as string form // See: https://stackoverflow.com/questions/66104762/snowflake-date-column-have-incorrect-date-from-avro-file if c.TargetDWH == QDWHTypeSnowflake { - return t.Format("15:04:05.999999-0700"), nil + return t.Format("15:04:05.999999-0700") } - return t.UnixMicro(), nil + return t.UnixMicro() } -func (c *QValueAvroConverter) processGoTime() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - t, ok := c.Value.(time.Time) - if !ok { - return nil, errors.New("invalid Time value") - } - +func (c *QValueAvroConverter) processGoTime(t time.Time) interface{} { // Snowflake has issues with avro timestamp types, returning as string form // See: https://stackoverflow.com/questions/66104762/snowflake-date-column-have-incorrect-date-from-avro-file if c.TargetDWH == QDWHTypeSnowflake { - return t.Format("15:04:05.999999"), nil + return t.Format("15:04:05.999999") } if 
c.TargetDWH == QDWHTypeClickhouse { - return t.Format("15:04:05.999999"), nil + return t.Format("15:04:05.999999") } - return t.UnixMicro(), nil + return t.UnixMicro() } -func (c *QValueAvroConverter) processGoTimestampTZ() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - t, ok := c.Value.(time.Time) - if !ok { - return nil, errors.New("invalid TimestampTZ value") - } - +func (c *QValueAvroConverter) processGoTimestampTZ(t time.Time) interface{} { // Snowflake has issues with avro timestamp types, returning as string form // See: https://stackoverflow.com/questions/66104762/snowflake-date-column-have-incorrect-date-from-avro-file if c.TargetDWH == QDWHTypeSnowflake { - return t.Format("2006-01-02 15:04:05.999999-0700"), nil + return t.Format("2006-01-02 15:04:05.999999-0700") } // Bigquery will not allow timestamp if it is less than 1AD and more than 9999AD // So make such timestamps null if DisallowedTimestamp(c.TargetDWH, t, c.logger) { - return nil, nil + return nil } - return t.UnixMicro(), nil + return t.UnixMicro() } -func (c *QValueAvroConverter) processGoTimestamp() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - t, ok := c.Value.(time.Time) - if !ok { - return nil, errors.New("invalid Timestamp value") - } - +func (c *QValueAvroConverter) processGoTimestamp(t time.Time) interface{} { // Snowflake has issues with avro timestamp types, returning as string form // See: https://stackoverflow.com/questions/66104762/snowflake-date-column-have-incorrect-date-from-avro-file if c.TargetDWH == QDWHTypeSnowflake { - return t.Format("2006-01-02 15:04:05.999999"), nil + return t.Format("2006-01-02 15:04:05.999999") } // Bigquery will not allow timestamp if it is less than 1AD and more than 9999AD // So make such timestamps null if DisallowedTimestamp(c.TargetDWH, t, c.logger) { - return nil, nil + return nil } - return t.UnixMicro(), nil + return t.UnixMicro() } -func (c *QValueAvroConverter) 
processGoDate() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - t, ok := c.Value.(time.Time) - if !ok { - return nil, errors.New("invalid Time value for Date") - } - +func (c *QValueAvroConverter) processGoDate(t time.Time) interface{} { // Bigquery will not allow Date if it is less than 1AD and more than 9999AD // So make such Dates null if DisallowedTimestamp(c.TargetDWH, t, c.logger) { - return nil, nil + return nil } // Snowflake has issues with avro timestamp types, returning as string form // See: https://stackoverflow.com/questions/66104762/snowflake-date-column-have-incorrect-date-from-avro-file if c.TargetDWH == QDWHTypeSnowflake { - return t.Format("2006-01-02"), nil + return t.Format("2006-01-02") } - return t, nil + return t } func (c *QValueAvroConverter) processNullableUnion( @@ -487,63 +427,22 @@ func (c *QValueAvroConverter) processNullableUnion( return value, nil } -func (c *QValueAvroConverter) processNumeric() (interface{}, error) { - if c.Value == nil { - return nil, nil - } - - num, ok := c.Value.(decimal.Decimal) - if !ok { - return nil, fmt.Errorf("invalid Numeric value: expected decimal.Decimal, got %T", c.Value) - } +func (c *QValueAvroConverter) processNumeric(num decimal.Decimal) interface{} { rat := num.Rat() - if c.Nullable { - return goavro.Union("bytes.decimal", rat), nil + return goavro.Union("bytes.decimal", rat) } - - return rat, nil + return rat } -func (c *QValueAvroConverter) processBytes() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - if c.TargetDWH == QDWHTypeClickhouse { - bigNum, ok := c.Value.(decimal.Decimal) - if !ok { - return nil, fmt.Errorf("invalid Numeric value: expected float64, got %T", c.Value) - } - num, ok := bigNum.Float64() - if !ok { - return nil, fmt.Errorf("not able to convert bigNum to float64 %+v", bigNum) - } - return goavro.Union("double", num), nil - } - - byteData, ok := c.Value.([]byte) - if !ok { - return nil, 
errors.New("invalid Bytes value") - } - +func (c *QValueAvroConverter) processBytes(byteData []byte) interface{} { if c.Nullable { - return goavro.Union("bytes", byteData), nil + return goavro.Union("bytes", byteData) } - - return byteData, nil + return byteData } -func (c *QValueAvroConverter) processJSON() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - jsonString, ok := c.Value.(string) - if !ok { - return nil, fmt.Errorf("invalid JSON value %v", c.Value) - } - +func (c *QValueAvroConverter) processJSON(jsonString string) (interface{}, error) { if c.Nullable { if c.TargetDWH == QDWHTypeSnowflake && len(jsonString) > 15*1024*1024 { slog.Warn("Truncating JSON value > 15MB for Snowflake!") @@ -561,33 +460,15 @@ func (c *QValueAvroConverter) processJSON() (interface{}, error) { return jsonString, nil } -func (c *QValueAvroConverter) processArrayBoolean() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - arrayData, ok := c.Value.([]bool) - if !ok { - return nil, errors.New("invalid Boolean array value") - } - +func (c *QValueAvroConverter) processArrayBoolean(arrayData []bool) interface{} { if c.Nullable { - return goavro.Union("array", arrayData), nil + return goavro.Union("array", arrayData) } - return arrayData, nil + return arrayData } -func (c *QValueAvroConverter) processArrayTime() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - arrayTime, ok := c.Value.([]time.Time) - if !ok { - return nil, errors.New("invalid Timestamp array value") - } - +func (c *QValueAvroConverter) processArrayTime(arrayTime []time.Time) interface{} { transformedTimeArr := make([]interface{}, 0, len(arrayTime)) for _, t := range arrayTime { // Snowflake has issues with avro timestamp types, returning as string form @@ -600,22 +481,13 @@ func (c *QValueAvroConverter) processArrayTime() (interface{}, error) { } if c.Nullable { - return goavro.Union("array", transformedTimeArr), nil 
+ return goavro.Union("array", transformedTimeArr) } - return transformedTimeArr, nil + return transformedTimeArr } -func (c *QValueAvroConverter) processArrayDate() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - arrayDate, ok := c.Value.([]time.Time) - if !ok { - return nil, errors.New("invalid Date array value") - } - +func (c *QValueAvroConverter) processArrayDate(arrayDate []time.Time) interface{} { transformedTimeArr := make([]interface{}, 0, len(arrayDate)) for _, t := range arrayDate { if c.TargetDWH == QDWHTypeSnowflake { @@ -626,25 +498,16 @@ func (c *QValueAvroConverter) processArrayDate() (interface{}, error) { } if c.Nullable { - return goavro.Union("array", transformedTimeArr), nil + return goavro.Union("array", transformedTimeArr) } - return transformedTimeArr, nil + return transformedTimeArr } -func (c *QValueAvroConverter) processHStore() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - hstoreString, ok := c.Value.(string) - if !ok { - return nil, fmt.Errorf("invalid HSTORE value %v", c.Value) - } - - jsonString, err := hstore_util.ParseHstore(hstoreString) +func (c *QValueAvroConverter) processHStore(hstore string) (interface{}, error) { + jsonString, err := hstore_util.ParseHstore(hstore) if err != nil { - return "", err + return "", fmt.Errorf("cannot parse %s: %w", hstore, err) } if c.Nullable { @@ -664,60 +527,22 @@ func (c *QValueAvroConverter) processHStore() (interface{}, error) { return jsonString, nil } -func (c *QValueAvroConverter) processUUID() (interface{}, error) { - if c.Value == nil { - return nil, nil - } - - byteData, ok := c.Value.([16]byte) - if !ok { - // attempt to convert google.uuid to [16]byte - byteData, ok = c.Value.(uuid.UUID) - if !ok { - return nil, fmt.Errorf("[conversion] invalid UUID value %v", c.Value) - } - } - - u, err := uuid.FromBytes(byteData[:]) - if err != nil { - return nil, fmt.Errorf("[conversion] conversion of invalid UUID value: 
%w", err) - } - - uuidString := u.String() - +func (c *QValueAvroConverter) processUUID(byteData [16]byte) interface{} { + uuidString := uuid.UUID(byteData).String() if c.Nullable { - return goavro.Union("string", uuidString), nil + return goavro.Union("string", uuidString) } - - return uuidString, nil + return uuidString } -func (c *QValueAvroConverter) processGeospatial() (interface{}, error) { - if c.Value == nil { - return nil, nil - } - - geoString, ok := c.Value.(string) - if !ok { - return nil, fmt.Errorf("[conversion] invalid geospatial value %v", c.Value) - } - +func (c *QValueAvroConverter) processGeospatial(geoString string) interface{} { if c.Nullable { - return goavro.Union("string", geoString), nil + return goavro.Union("string", geoString) } - return geoString, nil + return geoString } -func (c *QValueAvroConverter) processArrayInt16() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - arrayData, ok := c.Value.([]int16) - if !ok { - return nil, errors.New("invalid Int16 array value") - } - +func (c *QValueAvroConverter) processArrayInt16(arrayData []int16) interface{} { // cast to int32 int32Data := make([]int32, 0, len(arrayData)) for _, v := range arrayData { @@ -725,93 +550,43 @@ func (c *QValueAvroConverter) processArrayInt16() (interface{}, error) { } if c.Nullable { - return goavro.Union("array", int32Data), nil + return goavro.Union("array", int32Data) } - return int32Data, nil + return int32Data } -func (c *QValueAvroConverter) processArrayInt32() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - arrayData, ok := c.Value.([]int32) - if !ok { - return nil, errors.New("invalid Int32 array value") - } - +func (c *QValueAvroConverter) processArrayInt32(arrayData []int32) interface{} { if c.Nullable { - return goavro.Union("array", arrayData), nil + return goavro.Union("array", arrayData) } - - return arrayData, nil + return arrayData } -func (c *QValueAvroConverter) 
processArrayInt64() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - arrayData, ok := c.Value.([]int64) - if !ok { - return nil, errors.New("invalid Int64 array value") - } - +func (c *QValueAvroConverter) processArrayInt64(arrayData []int64) interface{} { if c.Nullable { - return goavro.Union("array", arrayData), nil + return goavro.Union("array", arrayData) } - - return arrayData, nil + return arrayData } -func (c *QValueAvroConverter) processArrayFloat32() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - arrayData, ok := c.Value.([]float32) - if !ok { - return nil, errors.New("invalid Float32 array value") - } - +func (c *QValueAvroConverter) processArrayFloat32(arrayData []float32) interface{} { if c.Nullable { - return goavro.Union("array", arrayData), nil + return goavro.Union("array", arrayData) } - - return arrayData, nil + return arrayData } -func (c *QValueAvroConverter) processArrayFloat64() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - arrayData, ok := c.Value.([]float64) - if !ok { - return nil, errors.New("invalid Float64 array value") - } - +func (c *QValueAvroConverter) processArrayFloat64(arrayData []float64) interface{} { if c.Nullable { - return goavro.Union("array", arrayData), nil + return goavro.Union("array", arrayData) } - - return arrayData, nil + return arrayData } -func (c *QValueAvroConverter) processArrayString() (interface{}, error) { - if c.Value == nil && c.Nullable { - return nil, nil - } - - arrayData, ok := c.Value.([]string) - if !ok { - return nil, errors.New("invalid String array value") - } - +func (c *QValueAvroConverter) processArrayString(arrayData []string) interface{} { if c.Nullable { - return goavro.Union("array", arrayData), nil + return goavro.Union("array", arrayData) } - - return arrayData, nil + return arrayData } diff --git a/flow/model/qvalue/equals.go b/flow/model/qvalue/equals.go new file mode 100644 
index 0000000000..713f71134a --- /dev/null +++ b/flow/model/qvalue/equals.go @@ -0,0 +1,485 @@ +package qvalue + +import ( + "bytes" + "math" + "math/big" + "reflect" + "slices" + "strconv" + "strings" + "time" + + "github.com/google/uuid" + "github.com/shopspring/decimal" + geom "github.com/twpayne/go-geos" + + hstore_util "github.com/PeerDB-io/peer-flow/hstore" +) + +func valueEmpty(value any) bool { + return value == nil || value == "" || value == "null" || + (reflect.TypeOf(value).Kind() == reflect.Slice && reflect.ValueOf(value).Len() == 0) +} + +func Equals(qv QValue, other QValue) bool { + qvValue := qv.Value() + otherValue := other.Value() + if valueEmpty(qvValue) && valueEmpty(otherValue) { + return true + } + + switch q := qv.(type) { + case QValueInvalid: + return true + case QValueFloat32: + return q.compareFloat32(other) + case QValueFloat64: + return q.compareFloat64(other) + case QValueInt16: + return q.compareInt16(other) + case QValueInt32: + return q.compareInt32(other) + case QValueInt64: + return q.compareInt64(other) + case QValueBoolean: + if otherVal, ok := other.(QValueBoolean); ok { + return q.Val == otherVal.Val + } + return false + case QValueStruct: + if otherVal, ok := other.(QValueStruct); ok { + return q.compareStruct(otherVal) + } + return false + case QValueQChar: + if otherVal, ok := other.(QValueQChar); ok { + return q.Val == otherVal.Val + } + return false + case QValueString: + return compareString(q.Val, otherValue) + case QValueINET: + return compareString(q.Val, otherValue) + case QValueCIDR: + return compareString(q.Val, otherValue) + // all internally represented as a Golang time.Time + case QValueDate, QValueTimestamp, QValueTimestampTZ, QValueTime, QValueTimeTZ: + return compareGoTime(qvValue, otherValue) + case QValueNumeric: + return compareNumeric(q.Val, otherValue) + case QValueBytes: + return compareBytes(qvValue, otherValue) + case QValueUUID: + return compareUUID(qvValue, otherValue) + case QValueJSON: + // TODO 
(kaushik): fix for tests + return true + case QValueBit: + return compareBytes(qvValue, otherValue) + case QValueGeometry: + return compareGeometry(q.Val, otherValue) + case QValueGeography: + return compareGeometry(q.Val, otherValue) + case QValueHStore: + return compareHStore(q.Val, otherValue) + case QValueArrayInt32, QValueArrayInt16, QValueArrayInt64, QValueArrayFloat32, QValueArrayFloat64: + return compareNumericArrays(qvValue, otherValue) + case QValueArrayDate: + return compareDateArrays(q.Val, otherValue) + case QValueArrayTimestamp, QValueArrayTimestampTZ: + return compareTimeArrays(qvValue, otherValue) + case QValueArrayBoolean: + return compareBoolArrays(q.Val, otherValue) + case QValueArrayString: + return compareArrayString(q.Val, otherValue) + default: + return false + } +} + +func (v QValueInt16) compareInt16(value2 QValue) bool { + int2, ok2 := getInt16(value2.Value()) + return ok2 && v.Val == int2 +} + +func (v QValueInt32) compareInt32(value2 QValue) bool { + int2, ok2 := getInt32(value2.Value()) + return ok2 && v.Val == int2 +} + +func (v QValueInt64) compareInt64(value2 QValue) bool { + int2, ok2 := getInt64(value2.Value()) + return ok2 && v.Val == int2 +} + +func (v QValueFloat32) compareFloat32(value2 QValue) bool { + float2, ok2 := getFloat32(value2.Value()) + return ok2 && v.Val == float2 +} + +func (v QValueFloat64) compareFloat64(value2 QValue) bool { + float2, ok2 := getFloat64(value2.Value()) + return ok2 && v.Val == float2 +} + +func compareString(s1 string, value2 interface{}) bool { + s2, ok := value2.(string) + return ok && s1 == s2 +} + +func compareGoTime(value1, value2 interface{}) bool { + et1, ok1 := value1.(time.Time) + et2, ok2 := value2.(time.Time) + + if !ok1 || !ok2 { + return false + } + + // TODO: this is a hack, we should be comparing the actual time values + // currently this is only used for testing so that is OK. 
+ t1 := et1.UnixMicro() + t2 := et2.UnixMicro() + + return t1 == t2 +} + +func compareUUID(value1, value2 interface{}) bool { + uuid1, ok1 := getUUID(value1) + uuid2, ok2 := getUUID(value2) + + return ok1 && ok2 && uuid1 == uuid2 +} + +func compareBytes(value1, value2 interface{}) bool { + bytes1, ok1 := getBytes(value1) + bytes2, ok2 := getBytes(value2) + + return ok1 && ok2 && bytes.Equal(bytes1, bytes2) +} + +func compareNumeric(value1, value2 interface{}) bool { + num1, ok1 := getDecimal(value1) + num2, ok2 := getDecimal(value2) + + if !ok1 || !ok2 { + return false + } + + return num1.Equal(num2) +} + +func compareHStore(str1 string, value2 interface{}) bool { + str2 := value2.(string) + if str1 == str2 { + return true + } + parsedHStore1, err := hstore_util.ParseHstore(str1) + if err != nil { + panic(err) + } + return parsedHStore1 == strings.ReplaceAll(strings.ReplaceAll(str2, " ", ""), "\n", "") +} + +func compareGeometry(geoWkt string, value2 interface{}) bool { + geo2, err := geom.NewGeomFromWKT(value2.(string)) + if err != nil { + panic(err) + } + + if strings.HasPrefix(geoWkt, "SRID=") { + _, wkt, found := strings.Cut(geoWkt, ";") + if found { + geoWkt = wkt + } + } + + geo1, err := geom.NewGeomFromWKT(geoWkt) + if err != nil { + panic(err) + } + return geo1.Equals(geo2) +} + +func (v QValueStruct) compareStruct(value2 QValueStruct) bool { + struct1 := v.Val + struct2 := value2.Val + if len(struct1) != len(struct2) { + return false + } + for k, v1 := range struct1 { + v2, ok := struct2[k] + if !ok { + return false + } + q1, ok1 := v1.(QValue) + q2, ok2 := v2.(QValue) + if !ok1 || !ok2 || !Equals(q1, q2) { + return false + } + } + return true +} + +func compareNumericArrays(value1, value2 interface{}) bool { + // Helper function to convert a value to float64 + convertToFloat64 := func(val interface{}) []float64 { + switch v := val.(type) { + case []int16: + result := make([]float64, len(v)) + for i, value := range v { + result[i] = float64(value) + } + 
return result + case []int32: + result := make([]float64, len(v)) + for i, value := range v { + result[i] = float64(value) + } + return result + case []int64: + result := make([]float64, len(v)) + for i, value := range v { + result[i] = float64(value) + } + return result + case []float32: + result := make([]float64, len(v)) + for i, value := range v { + result[i] = float64(value) + } + return result + case []float64: + return v + default: + return nil + } + } + + array1 := convertToFloat64(value1) + array2 := convertToFloat64(value2) + + if array1 == nil || array2 == nil || len(array1) != len(array2) { + return false + } + + for i := range array1 { + if math.Abs(array1[i]-array2[i]) >= 1e9 { + return false + } + } + + return true +} + +func compareTimeArrays(value1, value2 interface{}) bool { + array1, ok1 := value1.([]time.Time) + array2, ok2 := value2.([]time.Time) + + if !ok1 || !ok2 || len(array1) != len(array2) { + return false + } + + for i := range array1 { + if !array1[i].Equal(array2[i]) { + return false + } + } + return true +} + +func compareDateArrays(value1, value2 interface{}) bool { + array1, ok1 := value1.([]time.Time) + array2, ok2 := value2.([]time.Time) + + if !ok1 || !ok2 || len(array1) != len(array2) { + return false + } + + for i := range array1 { + if array1[i].Year() != array2[i].Year() || + array1[i].Month() != array2[i].Month() || + array1[i].Day() != array2[i].Day() { + return false + } + } + return true +} + +func compareBoolArrays(value1, value2 interface{}) bool { + array1, ok1 := value1.([]bool) + array2, ok2 := value2.([]bool) + + if !ok1 || !ok2 || len(array1) != len(array2) { + return false + } + + for i := range array1 { + if array1[i] != array2[i] { + return false + } + } + return true +} + +func compareArrayString(value1, value2 interface{}) bool { + array1, ok1 := value1.([]string) + array2, ok2 := value2.([]string) + + if !ok1 || !ok2 { + return false + } + + return slices.Compare(array1, array2) == 0 +} + +func getInt16(v 
interface{}) (int16, bool) { + switch value := v.(type) { + case int16: + return value, true + case int32: + return int16(value), true + case int64: + return int16(value), true + case decimal.Decimal: + return int16(value.IntPart()), true + case string: + parsed, err := strconv.ParseInt(value, 10, 16) + if err == nil { + return int16(parsed), true + } + } + return 0, false +} + +func getInt32(v interface{}) (int32, bool) { + switch value := v.(type) { + case int32: + return value, true + case int64: + return int32(value), true + case decimal.Decimal: + return int32(value.IntPart()), true + case string: + parsed, err := strconv.ParseInt(value, 10, 32) + if err == nil { + return int32(parsed), true + } + } + return 0, false +} + +func getInt64(v interface{}) (int64, bool) { + switch value := v.(type) { + case int64: + return value, true + case int32: + return int64(value), true + case decimal.Decimal: + return value.IntPart(), true + case string: + parsed, err := strconv.ParseInt(value, 10, 64) + if err == nil { + return parsed, true + } + } + return 0, false +} + +func getFloat32(v interface{}) (float32, bool) { + switch value := v.(type) { + case float32: + return value, true + case float64: + return float32(value), true + case string: + parsed, err := strconv.ParseFloat(value, 32) + if err == nil { + return float32(parsed), true + } + } + return 0, false +} + +func getFloat64(v interface{}) (float64, bool) { + switch value := v.(type) { + case float64: + return value, true + case float32: + return float64(value), true + case string: + parsed, err := strconv.ParseFloat(value, 64) + if err == nil { + return parsed, true + } + } + return 0, false +} + +func getBytes(v interface{}) ([]byte, bool) { + switch value := v.(type) { + case []byte: + return value, true + case string: + return []byte(value), true + case nil: + return nil, true + default: + return nil, false + } +} + +func getUUID(v interface{}) (uuid.UUID, bool) { + switch value := v.(type) { + case 
uuid.UUID: + return value, true + case string: + parsed, err := uuid.Parse(value) + if err == nil { + return parsed, true + } + case [16]byte: + return uuid.UUID(value), true + } + + return uuid.UUID{}, false +} + +// getDecimal attempts to parse a decimal from an interface +func getDecimal(v interface{}) (decimal.Decimal, bool) { + switch value := v.(type) { + case decimal.Decimal: + return value, true + case string: + parsed, err := decimal.NewFromString(value) + if err != nil { + panic(err) + } + return parsed, true + case float64: + return decimal.NewFromFloat(value), true + case int64: + return decimal.NewFromInt(value), true + case uint64: + return decimal.NewFromBigInt(new(big.Int).SetUint64(value), 0), true + case float32: + return decimal.NewFromFloat32(value), true + case int32: + return decimal.NewFromInt(int64(value)), true + case uint32: + return decimal.NewFromInt(int64(value)), true + case int: + return decimal.NewFromInt(int64(value)), true + case uint: + return decimal.NewFromBigInt(new(big.Int).SetUint64(uint64(value)), 0), true + case int8: + return decimal.NewFromInt(int64(value)), true + case uint8: + return decimal.NewFromInt(int64(value)), true + case int16: + return decimal.NewFromInt(int64(value)), true + case uint16: + return decimal.NewFromInt(int64(value)), true + } + return decimal.Decimal{}, false +} diff --git a/flow/model/qvalue/kind.go b/flow/model/qvalue/kind.go index 9ed9ac0beb..f07df58383 100644 --- a/flow/model/qvalue/kind.go +++ b/flow/model/qvalue/kind.go @@ -8,7 +8,6 @@ import ( type QValueKind string const ( - QValueKindEmpty QValueKind = "" QValueKindInvalid QValueKind = "invalid" QValueKindFloat32 QValueKind = "float32" QValueKindFloat64 QValueKind = "float64" diff --git a/flow/model/qvalue/qvalue.go b/flow/model/qvalue/qvalue.go index bedc3decec..40ac4b6fc5 100644 --- a/flow/model/qvalue/qvalue.go +++ b/flow/model/qvalue/qvalue.go @@ -1,652 +1,479 @@ package qvalue import ( - "bytes" - "encoding/json" - "fmt" - "math" - "math/big" - "reflect" - 
"strconv" - "strings" "time" - "cloud.google.com/go/civil" - "github.com/google/uuid" - "github.com/jackc/pgx/v5/pgtype" "github.com/shopspring/decimal" - geom "github.com/twpayne/go-geos" - - hstore_util "github.com/PeerDB-io/peer-flow/hstore" ) // if new types are added, register them in gob - cdc_records_storage.go -type QValue struct { - Kind QValueKind - Value interface{} -} - -func (q QValue) Equals(other QValue) bool { - if q.Kind == QValueKindJSON { - return true // TODO fix - } else if q.Value == nil && other.Value == nil { - return true - } - - switch q.Kind { - case QValueKindEmpty: - return other.Kind == QValueKindEmpty - case QValueKindInvalid: - return true - case QValueKindFloat32: - return compareFloat32(q.Value, other.Value) - case QValueKindFloat64: - return compareFloat64(q.Value, other.Value) - case QValueKindInt16: - return compareInt16(q.Value, other.Value) - case QValueKindInt32: - return compareInt32(q.Value, other.Value) - case QValueKindInt64: - return compareInt64(q.Value, other.Value) - case QValueKindBoolean: - return compareBoolean(q.Value, other.Value) - case QValueKindStruct: - return compareStruct(q.Value, other.Value) - case QValueKindQChar: - if (q.Value == nil) == (other.Value == nil) { - return q.Value == nil || q.Value.(uint8) == other.Value.(uint8) - } else { - return false - } - case QValueKindString, QValueKindINET, QValueKindCIDR: - return compareString(q.Value, other.Value) - // all internally represented as a Golang time.Time - case QValueKindDate, - QValueKindTimestamp, QValueKindTimestampTZ: - return compareGoTime(q.Value, other.Value) - case QValueKindTime, QValueKindTimeTZ: - return compareGoCivilTime(q.Value, other.Value) - case QValueKindNumeric: - return compareNumeric(q.Value, other.Value) - case QValueKindBytes: - return compareBytes(q.Value, other.Value) - case QValueKindUUID: - return compareUUID(q.Value, other.Value) - case QValueKindJSON: - return compareJSON(q.Value, other.Value) - case QValueKindBit: - 
return compareBit(q.Value, other.Value) - case QValueKindGeometry, QValueKindGeography: - return compareGeometry(q.Value, other.Value) - case QValueKindHStore: - return compareHstore(q.Value, other.Value) - case QValueKindArrayFloat32: - return compareNumericArrays(q.Value, other.Value) - case QValueKindArrayFloat64: - return compareNumericArrays(q.Value, other.Value) - case QValueKindArrayInt32, QValueKindArrayInt16: - return compareNumericArrays(q.Value, other.Value) - case QValueKindArrayInt64: - return compareNumericArrays(q.Value, other.Value) - case QValueKindArrayDate: - return compareDateArrays(q.Value, other.Value) - case QValueKindArrayTimestamp, QValueKindArrayTimestampTZ: - return compareTimeArrays(q.Value, other.Value) - case QValueKindArrayBoolean: - return compareBoolArrays(q.Value, other.Value) - case QValueKindArrayString: - return compareArrayString(q.Value, other.Value) - default: - return false - } -} - -func (q QValue) GoTimeConvert() (string, error) { - if q.Kind == QValueKindTime || q.Kind == QValueKindTimeTZ { - return q.Value.(time.Time).Format("15:04:05.999999"), nil - // no connector supports time with timezone yet - // } else if q.Kind == QValueKindTimeTZ { - // return q.Value.(time.Time).Format("15:04:05.999999-0700"), nil - } else if q.Kind == QValueKindDate { - return q.Value.(time.Time).Format("2006-01-02"), nil - } else if q.Kind == QValueKindTimestamp { - return q.Value.(time.Time).Format("2006-01-02 15:04:05.999999"), nil - } else if q.Kind == QValueKindTimestampTZ { - return q.Value.(time.Time).Format("2006-01-02 15:04:05.999999-0700"), nil - } else { - return "", fmt.Errorf("unsupported QValueKind: %s", q.Kind) - } -} - -func compareInt16(value1, value2 interface{}) bool { - if value1 == nil && value2 == nil { - return true - } - - int1, ok1 := getInt16(value1) - int2, ok2 := getInt16(value2) - return ok1 && ok2 && int1 == int2 -} - -func compareInt32(value1, value2 interface{}) bool { - if value1 == nil && value2 == nil { - 
return true - } - - int1, ok1 := getInt32(value1) - int2, ok2 := getInt32(value2) - return ok1 && ok2 && int1 == int2 -} - -func compareInt64(value1, value2 interface{}) bool { - if value1 == nil && value2 == nil { - return true - } - - int1, ok1 := getInt64(value1) - int2, ok2 := getInt64(value2) - return ok1 && ok2 && int1 == int2 -} - -func compareFloat32(value1, value2 interface{}) bool { - if value1 == nil && value2 == nil { - return true - } - float1, ok1 := getFloat32(value1) - float2, ok2 := getFloat32(value2) - return ok1 && ok2 && float1 == float2 -} - -func compareFloat64(value1, value2 interface{}) bool { - if value1 == nil && value2 == nil { - return true - } - - float1, ok1 := getFloat64(value1) - float2, ok2 := getFloat64(value2) - return ok1 && ok2 && float1 == float2 -} - -func compareGoTime(value1, value2 interface{}) bool { - if value1 == nil && value2 == nil { - return true - } - - et1, ok1 := value1.(time.Time) - et2, ok2 := value2.(time.Time) - - if !ok1 || !ok2 { - return false - } - - // TODO: this is a hack, we should be comparing the actual time values - // currently this is only used for testing so that is OK. 
- t1 := et1.UnixMicro() - t2 := et2.UnixMicro() - - return t1 == t2 -} - -func compareGoCivilTime(value1, value2 interface{}) bool { - if value1 == nil && value2 == nil { - return true - } - - t1, ok1 := value1.(time.Time) - t2, ok2 := value2.(time.Time) - - if !ok1 || !ok2 { - if !ok2 { - // For BigQuery, we need to compare civil.Time with time.Time - ct2, ok3 := value2.(civil.Time) - if !ok3 { - return false - } - return t1.Hour() == ct2.Hour && t1.Minute() == ct2.Minute && t1.Second() == ct2.Second - } - return false - } - - return t1.Hour() == t2.Hour() && t1.Minute() == t2.Minute() && t1.Second() == t2.Second() -} +type QValue interface { + Kind() QValueKind + Value() any +} + +type QValueNull QValueKind + +func (v QValueNull) Kind() QValueKind { + return QValueKind(v) +} + +func (QValueNull) Value() any { + return nil +} + +type QValueInvalid struct { + Val string +} + +func (QValueInvalid) Kind() QValueKind { + return QValueKindInvalid +} + +func (v QValueInvalid) Value() any { + return v.Val +} + +type QValueFloat32 struct { + Val float32 +} + +func (QValueFloat32) Kind() QValueKind { + return QValueKindFloat32 +} + +func (v QValueFloat32) Value() any { + return v.Val +} + +type QValueFloat64 struct { + Val float64 +} + +func (QValueFloat64) Kind() QValueKind { + return QValueKindFloat64 +} + +func (v QValueFloat64) Value() any { + return v.Val +} + +type QValueInt16 struct { + Val int16 +} + +func (QValueInt16) Kind() QValueKind { + return QValueKindInt16 +} + +func (v QValueInt16) Value() any { + return v.Val +} + +type QValueInt32 struct { + Val int32 +} + +func (QValueInt32) Kind() QValueKind { + return QValueKindInt32 +} + +func (v QValueInt32) Value() any { + return v.Val +} + +type QValueInt64 struct { + Val int64 +} + +func (QValueInt64) Kind() QValueKind { + return QValueKindInt64 +} + +func (v QValueInt64) Value() any { + return v.Val +} + +type QValueBoolean struct { + Val bool +} + +func (QValueBoolean) Kind() QValueKind { + return 
QValueKindBoolean +} + +func (v QValueBoolean) Value() any { + return v.Val +} + +type QValueStruct struct { + Val map[string]interface{} +} + +func (QValueStruct) Kind() QValueKind { + return QValueKindStruct +} + +func (v QValueStruct) Value() any { + return v.Val +} + +type QValueQChar struct { + Val uint8 +} + +func (QValueQChar) Kind() QValueKind { + return QValueKindQChar +} + +func (v QValueQChar) Value() any { + return v.Val +} + +type QValueString struct { + Val string +} + +func (QValueString) Kind() QValueKind { + return QValueKindString +} + +func (v QValueString) Value() any { + return v.Val +} + +type QValueTimestamp struct { + Val time.Time +} + +func (QValueTimestamp) Kind() QValueKind { + return QValueKindTimestamp +} + +func (v QValueTimestamp) Value() any { + return v.Val +} + +type QValueTimestampTZ struct { + Val time.Time +} + +func (QValueTimestampTZ) Kind() QValueKind { + return QValueKindTimestampTZ +} + +func (v QValueTimestampTZ) Value() any { + return v.Val +} + +type QValueDate struct { + Val time.Time +} + +func (QValueDate) Kind() QValueKind { + return QValueKindDate +} + +func (v QValueDate) Value() any { + return v.Val +} + +type QValueTime struct { + Val time.Time +} + +func (QValueTime) Kind() QValueKind { + return QValueKindTime +} + +func (v QValueTime) Value() any { + return v.Val +} + +type QValueTimeTZ struct { + Val time.Time +} + +func (QValueTimeTZ) Kind() QValueKind { + return QValueKindTimeTZ +} + +func (v QValueTimeTZ) Value() any { + return v.Val +} + +type QValueInterval struct { + Val string +} + +func (QValueInterval) Kind() QValueKind { + return QValueKindInterval +} + +func (v QValueInterval) Value() any { + return v.Val +} + +type QValueNumeric struct { + Val decimal.Decimal +} + +func (QValueNumeric) Kind() QValueKind { + return QValueKindNumeric +} + +func (v QValueNumeric) Value() any { + return v.Val +} + +type QValueBytes struct { + Val []byte +} + +func (QValueBytes) Kind() QValueKind { + return 
QValueKindBytes +} + +func (v QValueBytes) Value() any { + return v.Val +} + +type QValueUUID struct { + Val [16]byte +} + +func (QValueUUID) Kind() QValueKind { + return QValueKindUUID +} + +func (v QValueUUID) Value() any { + return v.Val +} + +type QValueJSON struct { + Val string +} + +func (QValueJSON) Kind() QValueKind { + return QValueKindJSON +} + +func (v QValueJSON) Value() any { + return v.Val +} + +type QValueBit struct { + Val []byte +} + +func (QValueBit) Kind() QValueKind { + return QValueKindBit +} + +func (v QValueBit) Value() any { + return v.Val +} + +type QValueHStore struct { + Val string +} + +func (QValueHStore) Kind() QValueKind { + return QValueKindHStore +} + +func (v QValueHStore) Value() any { + return v.Val +} + +type QValueGeography struct { + Val string +} + +func (QValueGeography) Kind() QValueKind { + return QValueKindGeography +} + +func (v QValueGeography) Value() any { + return v.Val +} + +type QValueGeometry struct { + Val string +} + +func (QValueGeometry) Kind() QValueKind { + return QValueKindGeometry +} + +func (v QValueGeometry) Value() any { + return v.Val +} + +type QValuePoint struct { + Val string +} + +func (QValuePoint) Kind() QValueKind { + return QValueKindPoint +} + +func (v QValuePoint) Value() any { + return v.Val +} + +type QValueCIDR struct { + Val string +} + +func (QValueCIDR) Kind() QValueKind { + return QValueKindCIDR +} + +func (v QValueCIDR) Value() any { + return v.Val +} + +type QValueINET struct { + Val string +} + +func (QValueINET) Kind() QValueKind { + return QValueKindINET +} + +func (v QValueINET) Value() any { + return v.Val +} + +type QValueMacaddr struct { + Val string +} + +func (QValueMacaddr) Kind() QValueKind { + return QValueKindMacaddr +} + +func (v QValueMacaddr) Value() any { + return v.Val +} + +type QValueArrayFloat32 struct { + Val []float32 +} + +func (QValueArrayFloat32) Kind() QValueKind { + return QValueKindArrayFloat32 +} + +func (v QValueArrayFloat32) Value() any { + return 
v.Val +} + +type QValueArrayFloat64 struct { + Val []float64 +} + +func (QValueArrayFloat64) Kind() QValueKind { + return QValueKindArrayFloat64 +} + +func (v QValueArrayFloat64) Value() any { + return v.Val +} + +type QValueArrayInt16 struct { + Val []int16 +} + +func (QValueArrayInt16) Kind() QValueKind { + return QValueKindArrayInt16 +} + +func (v QValueArrayInt16) Value() any { + return v.Val +} + +type QValueArrayInt32 struct { + Val []int32 +} + +func (QValueArrayInt32) Kind() QValueKind { + return QValueKindArrayInt32 +} + +func (v QValueArrayInt32) Value() any { + return v.Val +} + +type QValueArrayInt64 struct { + Val []int64 +} + +func (QValueArrayInt64) Kind() QValueKind { + return QValueKindArrayInt64 +} -func compareUUID(value1, value2 interface{}) bool { - if value1 == nil && value2 == nil { - return true - } +func (v QValueArrayInt64) Value() any { + return v.Val +} + +type QValueArrayString struct { + Val []string +} + +func (QValueArrayString) Kind() QValueKind { + return QValueKindArrayString +} + +func (v QValueArrayString) Value() any { + return v.Val +} + +type QValueArrayDate struct { + Val []time.Time +} + +func (QValueArrayDate) Kind() QValueKind { + return QValueKindArrayDate +} + +func (v QValueArrayDate) Value() any { + return v.Val +} + +type QValueArrayTimestamp struct { + Val []time.Time +} + +func (QValueArrayTimestamp) Kind() QValueKind { + return QValueKindArrayTimestamp +} + +func (v QValueArrayTimestamp) Value() any { + return v.Val +} + +type QValueArrayTimestampTZ struct { + Val []time.Time +} + +func (QValueArrayTimestampTZ) Kind() QValueKind { + return QValueKindArrayTimestampTZ +} + +func (v QValueArrayTimestampTZ) Value() any { + return v.Val +} + +type QValueArrayBoolean struct { + Val []bool +} + +func (QValueArrayBoolean) Kind() QValueKind { + return QValueKindArrayBoolean +} - uuid1, ok1 := getUUID(value1) - uuid2, ok2 := getUUID(value2) - - return ok1 && ok2 && uuid1 == uuid2 -} - -func compareBoolean(value1, value2 interface{}) 
bool { - bool1, ok1 := value1.(bool) - bool2, ok2 := value2.(bool) - - return ok1 && ok2 && bool1 == bool2 -} - -func compareBytes(value1, value2 interface{}) bool { - bytes1, ok1 := getBytes(value1) - bytes2, ok2 := getBytes(value2) - - return ok1 && ok2 && bytes.Equal(bytes1, bytes2) -} - -func compareNumeric(value1, value2 interface{}) bool { - num1, ok1 := getDecimal(value1) - num2, ok2 := getDecimal(value2) - - if !ok1 || !ok2 { - return false - } - - return num1.Equal(num2) -} - -func compareString(value1, value2 interface{}) bool { - if value1 == nil && value2 == nil { - return true - } - - str1, ok1 := value1.(string) - str2, ok2 := value2.(string) - if !ok1 || !ok2 { - return false - } - return str1 == str2 -} - -func compareHstore(value1, value2 interface{}) bool { - str2 := value2.(string) - switch v1 := value1.(type) { - case pgtype.Hstore: - bytes, err := json.Marshal(v1) - if err != nil { - panic(err) - } - return string(bytes) == str2 - case string: - if v1 == str2 { - return true - } - parsedHStore1, err := hstore_util.ParseHstore(v1) - if err != nil { - panic(err) - } - return parsedHStore1 == strings.ReplaceAll(strings.ReplaceAll(str2, " ", ""), "\n", "") - default: - panic(fmt.Sprintf("invalid hstore value type %T: %v", value1, value1)) - } -} - -func compareGeometry(value1, value2 interface{}) bool { - geo2, err := geom.NewGeomFromWKT(value2.(string)) - if err != nil { - panic(err) - } - - switch v1 := value1.(type) { - case *geom.Geom: - return v1.Equals(geo2) - case string: - geoWkt := v1 - if strings.HasPrefix(geoWkt, "SRID=") { - _, wkt, found := strings.Cut(geoWkt, ";") - if found { - geoWkt = wkt - } - } - - geo1, err := geom.NewGeomFromWKT(geoWkt) - if err != nil { - panic(err) - } - return geo1.Equals(geo2) - default: - panic(fmt.Sprintf("invalid geometry value type %T: %v", value1, value1)) - } -} - -func compareStruct(value1, value2 interface{}) bool { - struct1, ok1 := value1.(map[string]interface{}) - struct2, ok2 := 
value2.(map[string]interface{}) - if !ok1 || !ok2 || len(struct1) != len(struct2) { - return false - } - for k, v1 := range struct1 { - v2, ok := struct2[k] - if !ok { - return false - } - q1, ok1 := v1.(QValue) - q2, ok2 := v2.(QValue) - if !ok1 || !ok2 || !q1.Equals(q2) { - return false - } - } - return true -} - -func compareJSON(value1, value2 interface{}) bool { - // TODO (kaushik): fix for tests - return true -} - -func compareBit(value1, value2 interface{}) bool { - bit1, ok1 := value1.(int) - bit2, ok2 := value2.(int) - - if !ok1 || !ok2 { - return false - } - - return bit1 == bit2 -} - -func compareNumericArrays(value1, value2 interface{}) bool { - if value1 == nil && value2 == nil { - return true - } - - if value1 == nil && value2 == "null" { - return true - } - - if value1 == nil && value2 == "" { - return true - } - - // Helper function to convert a value to float64 - convertToFloat64 := func(val interface{}) []float64 { - switch v := val.(type) { - case []int16: - result := make([]float64, len(v)) - for i, value := range v { - result[i] = float64(value) - } - return result - case []int32: - result := make([]float64, len(v)) - for i, value := range v { - result[i] = float64(value) - } - return result - case []int64: - result := make([]float64, len(v)) - for i, value := range v { - result[i] = float64(value) - } - return result - case []float32: - result := make([]float64, len(v)) - for i, value := range v { - result[i] = float64(value) - } - return result - case []float64: - return v - default: - return nil - } - } - - array1 := convertToFloat64(value1) - array2 := convertToFloat64(value2) - - if array1 == nil || array2 == nil || len(array1) != len(array2) { - return false - } - - for i := range array1 { - if math.Abs(array1[i]-array2[i]) >= 1e9 { - return false - } - } - - return true -} - -func compareTimeArrays(value1, value2 interface{}) bool { - if value1 == nil && value2 == nil { - return true - } - array1, ok1 := value1.([]time.Time) - array2, 
ok2 := value2.([]time.Time) - - if !ok1 || !ok2 { - return false - } - - if len(array1) != len(array2) { - return false - } - - for i := range array1 { - if !array1[i].Equal(array2[i]) { - return false - } - } - - return true -} - -func compareDateArrays(value1, value2 interface{}) bool { - if value1 == nil && value2 == nil { - return true - } - array1, ok1 := value1.([]time.Time) - array2, ok2 := value2.([]civil.Date) - - if !ok1 || !ok2 || len(array1) != len(array2) { - return false - } - - for i := range array1 { - if array1[i].Year() != array2[i].Year || - array1[i].Month() != array2[i].Month || - array1[i].Day() != array2[i].Day { - return false - } - } - - return true -} - -func compareBoolArrays(value1, value2 interface{}) bool { - if value1 == nil && value2 == nil { - return true - } - array1, ok1 := value1.([]bool) - array2, ok2 := value2.([]bool) - - if !ok1 || !ok2 || len(array1) != len(array2) { - return false - } - - for i := range array1 { - if array1[i] != array2[i] { - return false - } - } - - return true -} - -func compareArrayString(value1, value2 interface{}) bool { - if value1 == nil && value2 == nil { - return true - } - - // also return true if value2 is string null - if value1 == nil && value2 == "null" { - return true - } - - // nulls end up as empty 'variants' in snowflake - if value1 == nil && value2 == "" { - return true - } - - array1, ok1 := value1.([]string) - array2, ok2 := value2.([]string) - - if !ok1 || !ok2 { - return false - } - - return reflect.DeepEqual(array1, array2) -} - -func getInt16(v interface{}) (int16, bool) { - switch value := v.(type) { - case int16: - return value, true - case int32: - return int16(value), true - case int64: - return int16(value), true - case decimal.Decimal: - return int16(value.IntPart()), true - case string: - parsed, err := strconv.ParseInt(value, 10, 16) - if err == nil { - return int16(parsed), true - } - } - return 0, false -} - -func getInt32(v interface{}) (int32, bool) { - switch value := 
v.(type) { - case int32: - return value, true - case int64: - return int32(value), true - case decimal.Decimal: - return int32(value.IntPart()), true - case string: - parsed, err := strconv.ParseInt(value, 10, 32) - if err == nil { - return int32(parsed), true - } - } - return 0, false -} - -func getInt64(v interface{}) (int64, bool) { - switch value := v.(type) { - case int64: - return value, true - case int32: - return int64(value), true - case decimal.Decimal: - return value.IntPart(), true - case string: - parsed, err := strconv.ParseInt(value, 10, 64) - if err == nil { - return parsed, true - } - } - return 0, false -} - -func getFloat32(v interface{}) (float32, bool) { - switch value := v.(type) { - case float32: - return value, true - case float64: - return float32(value), true - case string: - parsed, err := strconv.ParseFloat(value, 32) - if err == nil { - return float32(parsed), true - } - } - return 0, false -} - -func getFloat64(v interface{}) (float64, bool) { - switch value := v.(type) { - case float64: - return value, true - case float32: - return float64(value), true - case string: - parsed, err := strconv.ParseFloat(value, 64) - if err == nil { - return parsed, true - } - } - return 0, false -} - -func getBytes(v interface{}) ([]byte, bool) { - switch value := v.(type) { - case []byte: - return value, true - case string: - return []byte(value), true - case nil: - return nil, true - default: - return nil, false - } -} - -func getUUID(v interface{}) (uuid.UUID, bool) { - switch value := v.(type) { - case uuid.UUID: - return value, true - case string: - parsed, err := uuid.Parse(value) - if err == nil { - return parsed, true - } - case [16]byte: - return uuid.UUID(value), true - } - - return uuid.UUID{}, false -} - -// getDecimal attempts to parse a decimal from an interface -func getDecimal(v interface{}) (decimal.Decimal, bool) { - switch value := v.(type) { - case decimal.Decimal: - return value, true - case string: - parsed, err := 
decimal.NewFromString(value) - if err != nil { - panic(err) - } - return parsed, true - case float64: - return decimal.NewFromFloat(value), true - case int64: - return decimal.NewFromInt(value), true - case uint64: - return decimal.NewFromBigInt(new(big.Int).SetUint64(value), 0), true - case float32: - return decimal.NewFromFloat32(value), true - case int32: - return decimal.NewFromInt(int64(value)), true - case uint32: - return decimal.NewFromInt(int64(value)), true - case int: - return decimal.NewFromInt(int64(value)), true - case uint: - return decimal.NewFromInt(int64(value)), true - case int8: - return decimal.NewFromInt(int64(value)), true - case uint8: - return decimal.NewFromInt(int64(value)), true - case int16: - return decimal.NewFromInt(int64(value)), true - case uint16: - return decimal.NewFromInt(int64(value)), true - } - return decimal.Decimal{}, false +func (v QValueArrayBoolean) Value() any { + return v.Val } diff --git a/flow/model/record_items.go b/flow/model/record_items.go index a258f78be5..5482cf3973 100644 --- a/flow/model/record_items.go +++ b/flow/model/record_items.go @@ -5,9 +5,8 @@ import ( "errors" "fmt" "math" - "time" - "github.com/shopspring/decimal" + "github.com/google/uuid" hstore_util "github.com/PeerDB-io/peer-flow/hstore" "github.com/PeerDB-io/peer-flow/model/qvalue" @@ -48,7 +47,7 @@ func (r *RecordItems) GetColumnValue(col string) qvalue.QValue { if idx, ok := r.ColToValIdx[col]; ok { return r.Values[idx] } - return qvalue.QValue{} + return nil } // UpdateIfNotExists takes in a RecordItems as input and updates the values of the @@ -70,7 +69,7 @@ func (r *RecordItems) UpdateIfNotExists(input *RecordItems) []string { func (r *RecordItems) GetValueByColName(colName string) (qvalue.QValue, error) { idx, ok := r.ColToValIdx[colName] if !ok { - return qvalue.QValue{}, fmt.Errorf("column name %s not found", colName) + return nil, fmt.Errorf("column name %s not found", colName) } return r.Values[idx], nil } @@ -79,26 +78,22 @@ func (r 
*RecordItems) Len() int { return len(r.Values) } -func (r *RecordItems) toMap(hstoreAsJSON bool) (map[string]interface{}, error) { +func (r *RecordItems) toMap(hstoreAsJSON bool, opts ToJSONOptions) (map[string]interface{}, error) { if r.ColToValIdx == nil { return nil, errors.New("colToValIdx is nil") } jsonStruct := make(map[string]interface{}, len(r.ColToValIdx)) for col, idx := range r.ColToValIdx { - v := r.Values[idx] - if v.Value == nil { + qv := r.Values[idx] + if qv == nil { jsonStruct[col] = nil continue } - var err error - switch v.Kind { - case qvalue.QValueKindBit, qvalue.QValueKindBytes: - bitVal, ok := v.Value.([]byte) - if !ok { - return nil, errors.New("expected []byte value") - } + switch v := qv.(type) { + case qvalue.QValueBit: + bitVal := v.Val // convert to binary string because // json.Marshal stores byte arrays as @@ -109,36 +104,55 @@ func (r *RecordItems) toMap(hstoreAsJSON bool) (map[string]interface{}, error) { } jsonStruct[col] = binStr - case qvalue.QValueKindQChar: - ch, ok := v.Value.(uint8) - if !ok { - return nil, fmt.Errorf("expected \"char\" value for column %s for %T", col, v.Value) - } + case qvalue.QValueBytes: + bitVal := v.Val - jsonStruct[col] = string(ch) - case qvalue.QValueKindString, qvalue.QValueKindJSON: - strVal, ok := v.Value.(string) - if !ok { - return nil, fmt.Errorf("expected string value for column %s for %T", col, v.Value) + // convert to binary string because + // json.Marshal stores byte arrays as + // base64 + binStr := "" + for _, b := range bitVal { + binStr += fmt.Sprintf("%08b", b) } + jsonStruct[col] = binStr + case qvalue.QValueUUID: + jsonStruct[col] = uuid.UUID(v.Val) + case qvalue.QValueQChar: + jsonStruct[col] = string(v.Val) + case qvalue.QValueString: + strVal := v.Val + if len(strVal) > 15*1024*1024 { jsonStruct[col] = "" } else { jsonStruct[col] = strVal } - case qvalue.QValueKindHStore: - hstoreVal, ok := v.Value.(string) - if !ok { - return nil, fmt.Errorf("expected string value for hstore 
column %s for value %T", col, v.Value) + case qvalue.QValueJSON: + if len(v.Val) > 15*1024*1024 { + jsonStruct[col] = "" + } else if _, ok := opts.UnnestColumns[col]; ok { + var unnestStruct map[string]interface{} + err := json.Unmarshal([]byte(v.Val), &unnestStruct) + if err != nil { + return nil, err + } + + for k, v := range unnestStruct { + jsonStruct[k] = v + } + } else { + jsonStruct[col] = v.Val } + case qvalue.QValueHStore: + hstoreVal := v.Val if !hstoreAsJSON { jsonStruct[col] = hstoreVal } else { jsonVal, err := hstore_util.ParseHstore(hstoreVal) if err != nil { - return nil, fmt.Errorf("unable to convert hstore column %s to json for value %T", col, v.Value) + return nil, fmt.Errorf("unable to convert hstore column %s to json for value %T: %w", col, v, err) } if len(jsonVal) > 15*1024*1024 { @@ -148,55 +162,39 @@ func (r *RecordItems) toMap(hstoreAsJSON bool) (map[string]interface{}, error) { } } - case qvalue.QValueKindTimestamp, qvalue.QValueKindTimestampTZ, qvalue.QValueKindDate, - qvalue.QValueKindTime, qvalue.QValueKindTimeTZ: - jsonStruct[col], err = v.GoTimeConvert() - if err != nil { - return nil, err - } - case qvalue.QValueKindArrayDate: - dateArr, ok := v.Value.([]time.Time) - if !ok { - return nil, errors.New("expected []time.Time value") - } + case qvalue.QValueTimestamp: + jsonStruct[col] = v.Val.Format("2006-01-02 15:04:05.999999") + case qvalue.QValueTimestampTZ: + jsonStruct[col] = v.Val.Format("2006-01-02 15:04:05.999999-0700") + case qvalue.QValueDate: + jsonStruct[col] = v.Val.Format("2006-01-02") + case qvalue.QValueTime: + jsonStruct[col] = v.Val.Format("15:04:05.999999") + case qvalue.QValueTimeTZ: + jsonStruct[col] = v.Val.Format("15:04:05.999999") + case qvalue.QValueArrayDate: + dateArr := v.Val formattedDateArr := make([]string, 0, len(dateArr)) for _, val := range dateArr { formattedDateArr = append(formattedDateArr, val.Format("2006-01-02")) } jsonStruct[col] = formattedDateArr - case qvalue.QValueKindNumeric: - val, ok := 
v.Value.(decimal.Decimal) - if !ok { - return nil, errors.New("expected decimal.Decimal value") - } - - jsonStruct[col] = val.String() - case qvalue.QValueKindFloat64: - floatVal, ok := v.Value.(float64) - if !ok { - return nil, errors.New("expected float64 value") - } - if math.IsNaN(floatVal) || math.IsInf(floatVal, 0) { + case qvalue.QValueNumeric: + jsonStruct[col] = v.Val.String() + case qvalue.QValueFloat64: + if math.IsNaN(v.Val) || math.IsInf(v.Val, 0) { jsonStruct[col] = nil } else { - jsonStruct[col] = floatVal - } - case qvalue.QValueKindFloat32: - floatVal, ok := v.Value.(float32) - if !ok { - return nil, errors.New("expected float32 value") + jsonStruct[col] = v.Val } - if math.IsNaN(float64(floatVal)) || math.IsInf(float64(floatVal), 0) { + case qvalue.QValueFloat32: + if math.IsNaN(float64(v.Val)) || math.IsInf(float64(v.Val), 0) { jsonStruct[col] = nil } else { - jsonStruct[col] = floatVal + jsonStruct[col] = v.Val } - case qvalue.QValueKindArrayFloat64: - floatArr, ok := v.Value.([]float64) - if !ok { - return nil, errors.New("expected []float64 value") - } - + case qvalue.QValueArrayFloat64: + floatArr := v.Val nullableFloatArr := make([]interface{}, 0, len(floatArr)) for _, val := range floatArr { if math.IsNaN(val) || math.IsInf(val, 0) { @@ -206,11 +204,8 @@ func (r *RecordItems) toMap(hstoreAsJSON bool) (map[string]interface{}, error) { } } jsonStruct[col] = nullableFloatArr - case qvalue.QValueKindArrayFloat32: - floatArr, ok := v.Value.([]float32) - if !ok { - return nil, errors.New("expected []float32 value") - } + case qvalue.QValueArrayFloat32: + floatArr := v.Val nullableFloatArr := make([]interface{}, 0, len(floatArr)) for _, val := range floatArr { if math.IsNaN(float64(val)) || math.IsInf(float64(val), 0) { @@ -222,7 +217,7 @@ func (r *RecordItems) toMap(hstoreAsJSON bool) (map[string]interface{}, error) { jsonStruct[col] = nullableFloatArr default: - jsonStruct[col] = v.Value + jsonStruct[col] = v.Value() } } @@ -231,7 +226,7 @@ func 
(r *RecordItems) toMap(hstoreAsJSON bool) (map[string]interface{}, error) { // a separate method like gives flexibility // for us to handle some data types differently -func (r *RecordItems) ToJSONWithOptions(options *ToJSONOptions) (string, error) { +func (r *RecordItems) ToJSONWithOptions(options ToJSONOptions) (string, error) { return r.ToJSONWithOpts(options) } @@ -239,30 +234,12 @@ func (r *RecordItems) ToJSON() (string, error) { return r.ToJSONWithOpts(NewToJSONOptions(nil, true)) } -func (r *RecordItems) ToJSONWithOpts(opts *ToJSONOptions) (string, error) { - jsonStruct, err := r.toMap(opts.HStoreAsJSON) +func (r *RecordItems) ToJSONWithOpts(opts ToJSONOptions) (string, error) { + jsonStruct, err := r.toMap(opts.HStoreAsJSON, opts) if err != nil { return "", err } - for col, idx := range r.ColToValIdx { - v := r.Values[idx] - if v.Kind == qvalue.QValueKindJSON { - if _, ok := opts.UnnestColumns[col]; ok { - var unnestStruct map[string]interface{} - err := json.Unmarshal([]byte(v.Value.(string)), &unnestStruct) - if err != nil { - return "", err - } - - for k, v := range unnestStruct { - jsonStruct[k] = v - } - delete(jsonStruct, col) - } - } - } - jsonBytes, err := json.Marshal(jsonStruct) if err != nil { return "", err diff --git a/flow/pua/peerdb.go b/flow/pua/peerdb.go index 778cfd366e..197be4b461 100644 --- a/flow/pua/peerdb.go +++ b/flow/pua/peerdb.go @@ -108,7 +108,7 @@ func GetRowQ(ls *lua.LState, row *model.RecordItems, col string) qvalue.QValue { qv, err := row.GetValueByColName(col) if err != nil { ls.RaiseError(err.Error()) - return qvalue.QValue{} + return nil } return qv } @@ -148,7 +148,7 @@ func LuaRowColumns(ls *lua.LState) int { func LuaRowColumnKind(ls *lua.LState) int { row, key := LuaRow.StartIndex(ls) - ls.Push(lua.LString(GetRowQ(ls, row, key).Kind)) + ls.Push(lua.LString(GetRowQ(ls, row, key).Kind())) return 1 } @@ -222,13 +222,13 @@ func qvToLTable[T any](ls *lua.LState, s []T, f func(x T) lua.LValue) *lua.LTabl } func LuaQValue(ls 
*lua.LState, qv qvalue.QValue) lua.LValue { - switch v := qv.Value.(type) { + switch v := qv.Value().(type) { case nil: return lua.LNil case bool: return lua.LBool(v) case uint8: - if qv.Kind == qvalue.QValueKindQChar { + if qv.Kind() == qvalue.QValueKindQChar { return lua.LString(rune(v)) } else { return lua.LNumber(v) @@ -244,12 +244,6 @@ func LuaQValue(ls *lua.LState, qv qvalue.QValue) lua.LValue { case float64: return lua.LNumber(v) case string: - if qv.Kind == qvalue.QValueKindUUID { - u, err := uuid.Parse(v) - if err != nil { - return LuaUuid.New(ls, u) - } - } return lua.LString(v) case time.Time: return LuaTime.New(ls, v) @@ -292,7 +286,7 @@ func LuaQValue(ls *lua.LState, qv qvalue.QValue) lua.LValue { return lua.LBool(x) }) default: - return lua.LString(fmt.Sprint(qv.Value)) + return lua.LString(fmt.Sprint(qv.Value())) } }