Skip to content

Commit

Permalink
dynamic messages properly handle inf and nan float values when serial…
Browse files Browse the repository at this point in the history
…izing to JSON or text
  • Loading branch information
jhump committed May 20, 2019
1 parent e0d034f commit e0bb155
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 71 deletions.
24 changes: 15 additions & 9 deletions dynamic/binary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,21 @@ import (
)

func TestBinaryUnaryFields(t *testing.T) {
binaryTranslationParty(t, unaryFieldsPosMsg)
binaryTranslationParty(t, unaryFieldsNegMsg)
binaryTranslationParty(t, unaryFieldsPosMsg, false)
binaryTranslationParty(t, unaryFieldsNegMsg, false)
binaryTranslationParty(t, unaryFieldsPosInfMsg, false)
binaryTranslationParty(t, unaryFieldsNegInfMsg, false)
binaryTranslationParty(t, unaryFieldsNanMsg, true)
}

func TestBinaryRepeatedFields(t *testing.T) {
binaryTranslationParty(t, repeatedFieldsMsg)
binaryTranslationParty(t, repeatedFieldsMsg, false)
binaryTranslationParty(t, repeatedFieldsInfNanMsg, true)
}

func TestBinaryPackedRepeatedFields(t *testing.T) {
binaryTranslationParty(t, repeatedPackedFieldsMsg)
binaryTranslationParty(t, repeatedPackedFieldsMsg, false)
binaryTranslationParty(t, repeatedPackedFieldsInfNanMsg, true)
}

func TestBinaryMapKeyFields(t *testing.T) {
Expand All @@ -31,7 +36,7 @@ func TestBinaryMapKeyFields(t *testing.T) {
defaultDeterminism = false
}()

binaryTranslationParty(t, mapKeyFieldsMsg)
binaryTranslationParty(t, mapKeyFieldsMsg, false)
}

func TestMarshalMapValueFields(t *testing.T) {
Expand All @@ -41,7 +46,8 @@ func TestMarshalMapValueFields(t *testing.T) {
defaultDeterminism = false
}()

binaryTranslationParty(t, mapValueFieldsMsg)
binaryTranslationParty(t, mapValueFieldsMsg, false)
binaryTranslationParty(t, mapValueFieldsInfNanMsg, true)
}

func TestBinaryExtensionFields(t *testing.T) {
Expand Down Expand Up @@ -127,10 +133,10 @@ func TestBinaryUnknownFields(t *testing.T) {
testutil.Eq(t, buf.buf, bb)

// now try a full translation party to ensure unknown bits remain correct throughout
binaryTranslationParty(t, &msg)
binaryTranslationParty(t, &msg, false)
}

func binaryTranslationParty(t *testing.T, msg proto.Message) {
func binaryTranslationParty(t *testing.T, msg proto.Message, includesNaN bool) {
marshalAppendSimple := func(m *Message) ([]byte, error) {
// Declare a function that has the same interface as (*Message.Marshal) but uses
// MarshalAppend internally so we can reuse the translation party tests to verify
Expand Down Expand Up @@ -160,7 +166,7 @@ func binaryTranslationParty(t *testing.T, msg proto.Message) {
}

for _, marshalFn := range marshalMethods {
doTranslationParty(t, msg, proto.Marshal, proto.Unmarshal, marshalFn, (*Message).Unmarshal)
doTranslationParty(t, msg, proto.Marshal, proto.Unmarshal, marshalFn, (*Message).Unmarshal, includesNaN)
}
}

Expand Down
26 changes: 7 additions & 19 deletions dynamic/equal.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ func fieldsEqual(aval, bval interface{}) bool {
if !ok {
return false
}
if !MessagesEqual(apm, bpm) {
return false
}
return MessagesEqual(apm, bpm)

} else {
switch arv.Kind() {
case reflect.Ptr:
Expand All @@ -83,33 +82,22 @@ func fieldsEqual(aval, bval interface{}) bool {
return false
}
bpm := bval.(proto.Message) // we know it will succeed because we know a and b have same type
if !MessagesEqual(apm, bpm) {
return false
}
return MessagesEqual(apm, bpm)

case reflect.Map:
if !mapsEqual(arv, brv) {
return false
}
return mapsEqual(arv, brv)

case reflect.Slice:
if arv.Type() == typeOfBytes {
if !bytes.Equal(aval.([]byte), bval.([]byte)) {
return false
}
return bytes.Equal(aval.([]byte), bval.([]byte))
} else {
if !slicesEqual(arv, brv) {
return false
}
return slicesEqual(arv, brv)
}

default:
if aval != bval {
return false
}
return aval == bval
}
}
return true
}

func slicesEqual(a, b reflect.Value) bool {
Expand Down
6 changes: 3 additions & 3 deletions dynamic/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,11 @@ func marshalKnownFieldValueJSON(b *indentBuffer, fd *desc.FieldDescriptor, v int
f := rv.Float()
var str string
if math.IsNaN(f) {
str = "NaN"
str = `"NaN"`
} else if math.IsInf(f, 1) {
str = "Infinity"
str = `"Infinity"`
} else if math.IsInf(f, -1) {
str = "-Infinity"
str = `"-Infinity"`
} else {
var bits int
if rv.Kind() == reflect.Float32 {
Expand Down
19 changes: 12 additions & 7 deletions dynamic/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,25 @@ import (
)

func TestJSONUnaryFields(t *testing.T) {
jsonTranslationParty(t, unaryFieldsPosMsg)
jsonTranslationParty(t, unaryFieldsNegMsg)
jsonTranslationParty(t, unaryFieldsPosMsg, false)
jsonTranslationParty(t, unaryFieldsNegMsg, false)
jsonTranslationParty(t, unaryFieldsPosInfMsg, false)
jsonTranslationParty(t, unaryFieldsNegInfMsg, false)
jsonTranslationParty(t, unaryFieldsNanMsg, true)
}

func TestJSONRepeatedFields(t *testing.T) {
jsonTranslationParty(t, repeatedFieldsMsg)
jsonTranslationParty(t, repeatedFieldsMsg, false)
jsonTranslationParty(t, repeatedFieldsInfNanMsg, true)
}

func TestJSONMapKeyFields(t *testing.T) {
jsonTranslationParty(t, mapKeyFieldsMsg)
jsonTranslationParty(t, mapKeyFieldsMsg, false)
}

func TestJSONMapValueFields(t *testing.T) {
jsonTranslationParty(t, mapValueFieldsMsg)
jsonTranslationParty(t, mapValueFieldsMsg, false)
jsonTranslationParty(t, mapValueFieldsInfNanMsg, true)
}

func TestJSONExtensionFields(t *testing.T) {
Expand Down Expand Up @@ -490,7 +495,7 @@ func TestJSONWellKnownTypeFromFileDescriptorSet(t *testing.T) {
testutil.Eq(t, js, dynJs)
}

func jsonTranslationParty(t *testing.T, msg proto.Message) {
func jsonTranslationParty(t *testing.T, msg proto.Message, includesNaN bool) {
doTranslationParty(t, msg,
func(pm proto.Message) ([]byte, error) {
m := jsonpb.Marshaler{}
Expand All @@ -505,5 +510,5 @@ func jsonTranslationParty(t *testing.T, msg proto.Message) {
func(b []byte, pm proto.Message) error {
return jsonpb.Unmarshal(bytes.NewReader(b), pm)
},
(*Message).MarshalJSON, (*Message).UnmarshalJSON)
(*Message).MarshalJSON, (*Message).UnmarshalJSON, includesNaN)
}
81 changes: 60 additions & 21 deletions dynamic/marshal_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dynamic

import (
"math"
"reflect"
"testing"

Expand Down Expand Up @@ -68,11 +69,26 @@ var unaryFieldsNegMsg = &testprotos.UnaryFields{
Z: testprotos.TestEnum_SECOND.Enum(),
}

var unaryFieldsPosInfMsg = &testprotos.UnaryFields{
S: proto.Float32(float32(math.Inf(1))),
T: proto.Float64(math.Inf(1)),
}

var unaryFieldsNegInfMsg = &testprotos.UnaryFields{
S: proto.Float32(float32(math.Inf(-1))),
T: proto.Float64(math.Inf(-1)),
}

var unaryFieldsNanMsg = &testprotos.UnaryFields{
S: proto.Float32(float32(math.NaN())),
T: proto.Float64(math.NaN()),
}

var repeatedFieldsMsg = &testprotos.RepeatedFields{
I: []int32{1, 2, 3},
J: []int64{4, 5, 6},
K: []int32{7, 8, 9},
L: []int64{10, 11, 12},
I: []int32{1, -2, 3},
J: []int64{-4, 5, -6},
K: []int32{7, -8, 9},
L: []int64{-10, 11, -12},
M: []uint32{13, 14, 15},
N: []uint64{16, 17, 18},
O: []uint32{19, 20, 21},
Expand All @@ -95,11 +111,16 @@ var repeatedFieldsMsg = &testprotos.RepeatedFields{
Z: []testprotos.TestEnum{testprotos.TestEnum_SECOND, testprotos.TestEnum_THIRD, testprotos.TestEnum_FIRST},
}

var repeatedFieldsInfNanMsg = &testprotos.RepeatedFields{
S: []float32{float32(math.Inf(1)), float32(math.Inf(-1)), float32(math.NaN())},
T: []float64{math.Inf(1), math.Inf(-1), math.NaN()},
}

var repeatedPackedFieldsMsg = &testprotos.RepeatedPackedFields{
I: []int32{1, 2, 3},
J: []int64{4, 5, 6},
K: []int32{7, 8, 9},
L: []int64{10, 11, 12},
I: []int32{1, -2, 3},
J: []int64{-4, 5, -6},
K: []int32{7, -8, 9},
L: []int64{-10, 11, -12},
M: []uint32{13, 14, 15},
N: []uint64{16, 17, 18},
O: []uint32{19, 20, 21},
Expand All @@ -116,11 +137,16 @@ var repeatedPackedFieldsMsg = &testprotos.RepeatedPackedFields{
V: []testprotos.TestEnum{testprotos.TestEnum_SECOND, testprotos.TestEnum_THIRD, testprotos.TestEnum_FIRST},
}

var repeatedPackedFieldsInfNanMsg = &testprotos.RepeatedPackedFields{
S: []float32{float32(math.Inf(1)), float32(math.Inf(-1)), float32(math.NaN())},
T: []float64{math.Inf(1), math.Inf(-1), math.NaN()},
}

var mapKeyFieldsMsg = &testprotos.MapKeyFields{
I: map[int32]string{1: "foo", 2: "bar", 3: "baz"},
J: map[int64]string{4: "foo", 5: "bar", 6: "baz"},
K: map[int32]string{7: "foo", 8: "bar", 9: "baz"},
L: map[int64]string{10: "foo", 11: "bar", 12: "baz"},
I: map[int32]string{1: "foo", -2: "bar", 3: "baz"},
J: map[int64]string{-4: "foo", 5: "bar", -6: "baz"},
K: map[int32]string{7: "foo", -8: "bar", 9: "baz"},
L: map[int64]string{-10: "foo", 11: "bar", -12: "baz"},
M: map[uint32]string{13: "foo", 14: "bar", 15: "baz"},
N: map[uint64]string{16: "foo", 17: "bar", 18: "baz"},
O: map[uint32]string{19: "foo", 20: "bar", 21: "baz"},
Expand All @@ -132,10 +158,10 @@ var mapKeyFieldsMsg = &testprotos.MapKeyFields{
}

var mapValueFieldsMsg = &testprotos.MapValFields{
I: map[string]int32{"a": 1, "b": 2, "c": 3},
J: map[string]int64{"a": 4, "b": 5, "c": 6},
K: map[string]int32{"a": 7, "b": 8, "c": 9},
L: map[string]int64{"a": 10, "b": 11, "c": 12},
I: map[string]int32{"a": 1, "b": -2, "c": 3},
J: map[string]int64{"a": -4, "b": 5, "c": -6},
K: map[string]int32{"a": 7, "b": -8, "c": 9},
L: map[string]int64{"a": -10, "b": 11, "c": -12},
M: map[string]uint32{"a": 13, "b": 14, "c": 15},
N: map[string]uint64{"a": 16, "b": 17, "c": 18},
O: map[string]uint32{"a": 19, "b": 20, "c": 21},
Expand All @@ -154,9 +180,15 @@ var mapValueFieldsMsg = &testprotos.MapValFields{
Y: map[string]testprotos.TestEnum{"a": testprotos.TestEnum_SECOND, "b": testprotos.TestEnum_THIRD, "c": testprotos.TestEnum_FIRST},
}

var mapValueFieldsInfNanMsg = &testprotos.MapValFields{
S: map[string]float32{"a": float32(math.Inf(1)), "b": float32(math.Inf(-1)), "c": float32(math.NaN())},
T: map[string]float64{"a": math.Inf(1), "b": math.Inf(-1), "c": math.NaN()},
}

func doTranslationParty(t *testing.T, msg proto.Message,
marshalPm func(proto.Message) ([]byte, error), unmarshalPm func([]byte, proto.Message) error,
marshalDm func(*Message) ([]byte, error), unmarshalDm func(*Message, []byte) error) {
marshalDm func(*Message) ([]byte, error), unmarshalDm func(*Message, []byte) error,
includesNaN bool) {

md, err := desc.LoadMessageDescriptorForMessage(msg)
testutil.Ok(t, err)
Expand All @@ -179,7 +211,10 @@ func doTranslationParty(t *testing.T, msg proto.Message,
err = unmarshalPm(b2a, msg2)
testutil.Ok(t, err)

testutil.Ceq(t, msg, msg2, eqpm)
if !includesNaN {
// NaN fields are never equal so this would always be false
testutil.Ceq(t, msg, msg2, eqpm)
}

// and back again
b3, err := marshalPm(msg2)
Expand All @@ -188,7 +223,9 @@ func doTranslationParty(t *testing.T, msg proto.Message,
err = unmarshalDm(dm2, b3)
testutil.Ok(t, err)

testutil.Ceq(t, dm, dm2, eqdm)
if !includesNaN {
testutil.Ceq(t, dm, dm2, eqdm)
}

// dynamic message -> (bytes) -> dynamic message
// both techniques to unmarshal are equivalent
Expand All @@ -199,6 +236,8 @@ func doTranslationParty(t *testing.T, msg proto.Message,
err = unmarshalDm(dm4, b2a)
testutil.Ok(t, err)

testutil.Ceq(t, dm, dm3, eqdm)
testutil.Ceq(t, dm, dm4, eqdm)
if !includesNaN {
testutil.Ceq(t, dm, dm3, eqdm)
testutil.Ceq(t, dm, dm4, eqdm)
}
}
Loading

0 comments on commit e0bb155

Please sign in to comment.