Skip to content

Commit

Permalink
Merge pull request #198 from jhump/jh/dynamic-messages-handle-inf-and…
Browse files Browse the repository at this point in the history
…-nan

dynamic messages properly handle inf and nan float values when serializing to/from JSON or text
  • Loading branch information
jhump authored May 20, 2019
2 parents e0d034f + e0bb155 commit 7dce9ca
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 71 deletions.
24 changes: 15 additions & 9 deletions dynamic/binary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,21 @@ import (
)

func TestBinaryUnaryFields(t *testing.T) {
binaryTranslationParty(t, unaryFieldsPosMsg)
binaryTranslationParty(t, unaryFieldsNegMsg)
binaryTranslationParty(t, unaryFieldsPosMsg, false)
binaryTranslationParty(t, unaryFieldsNegMsg, false)
binaryTranslationParty(t, unaryFieldsPosInfMsg, false)
binaryTranslationParty(t, unaryFieldsNegInfMsg, false)
binaryTranslationParty(t, unaryFieldsNanMsg, true)
}

func TestBinaryRepeatedFields(t *testing.T) {
binaryTranslationParty(t, repeatedFieldsMsg)
binaryTranslationParty(t, repeatedFieldsMsg, false)
binaryTranslationParty(t, repeatedFieldsInfNanMsg, true)
}

func TestBinaryPackedRepeatedFields(t *testing.T) {
binaryTranslationParty(t, repeatedPackedFieldsMsg)
binaryTranslationParty(t, repeatedPackedFieldsMsg, false)
binaryTranslationParty(t, repeatedPackedFieldsInfNanMsg, true)
}

func TestBinaryMapKeyFields(t *testing.T) {
Expand All @@ -31,7 +36,7 @@ func TestBinaryMapKeyFields(t *testing.T) {
defaultDeterminism = false
}()

binaryTranslationParty(t, mapKeyFieldsMsg)
binaryTranslationParty(t, mapKeyFieldsMsg, false)
}

func TestMarshalMapValueFields(t *testing.T) {
Expand All @@ -41,7 +46,8 @@ func TestMarshalMapValueFields(t *testing.T) {
defaultDeterminism = false
}()

binaryTranslationParty(t, mapValueFieldsMsg)
binaryTranslationParty(t, mapValueFieldsMsg, false)
binaryTranslationParty(t, mapValueFieldsInfNanMsg, true)
}

func TestBinaryExtensionFields(t *testing.T) {
Expand Down Expand Up @@ -127,10 +133,10 @@ func TestBinaryUnknownFields(t *testing.T) {
testutil.Eq(t, buf.buf, bb)

// now try a full translation party to ensure unknown bits remain correct throughout
binaryTranslationParty(t, &msg)
binaryTranslationParty(t, &msg, false)
}

func binaryTranslationParty(t *testing.T, msg proto.Message) {
func binaryTranslationParty(t *testing.T, msg proto.Message, includesNaN bool) {
marshalAppendSimple := func(m *Message) ([]byte, error) {
// Declare a function that has the same interface as (*Message.Marshal) but uses
// MarshalAppend internally so we can reuse the translation party tests to verify
Expand Down Expand Up @@ -160,7 +166,7 @@ func binaryTranslationParty(t *testing.T, msg proto.Message) {
}

for _, marshalFn := range marshalMethods {
doTranslationParty(t, msg, proto.Marshal, proto.Unmarshal, marshalFn, (*Message).Unmarshal)
doTranslationParty(t, msg, proto.Marshal, proto.Unmarshal, marshalFn, (*Message).Unmarshal, includesNaN)
}
}

Expand Down
26 changes: 7 additions & 19 deletions dynamic/equal.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ func fieldsEqual(aval, bval interface{}) bool {
if !ok {
return false
}
if !MessagesEqual(apm, bpm) {
return false
}
return MessagesEqual(apm, bpm)

} else {
switch arv.Kind() {
case reflect.Ptr:
Expand All @@ -83,33 +82,22 @@ func fieldsEqual(aval, bval interface{}) bool {
return false
}
bpm := bval.(proto.Message) // we know it will succeed because we know a and b have same type
if !MessagesEqual(apm, bpm) {
return false
}
return MessagesEqual(apm, bpm)

case reflect.Map:
if !mapsEqual(arv, brv) {
return false
}
return mapsEqual(arv, brv)

case reflect.Slice:
if arv.Type() == typeOfBytes {
if !bytes.Equal(aval.([]byte), bval.([]byte)) {
return false
}
return bytes.Equal(aval.([]byte), bval.([]byte))
} else {
if !slicesEqual(arv, brv) {
return false
}
return slicesEqual(arv, brv)
}

default:
if aval != bval {
return false
}
return aval == bval
}
}
return true
}

func slicesEqual(a, b reflect.Value) bool {
Expand Down
6 changes: 3 additions & 3 deletions dynamic/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,11 @@ func marshalKnownFieldValueJSON(b *indentBuffer, fd *desc.FieldDescriptor, v int
f := rv.Float()
var str string
if math.IsNaN(f) {
str = "NaN"
str = `"NaN"`
} else if math.IsInf(f, 1) {
str = "Infinity"
str = `"Infinity"`
} else if math.IsInf(f, -1) {
str = "-Infinity"
str = `"-Infinity"`
} else {
var bits int
if rv.Kind() == reflect.Float32 {
Expand Down
19 changes: 12 additions & 7 deletions dynamic/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,25 @@ import (
)

func TestJSONUnaryFields(t *testing.T) {
jsonTranslationParty(t, unaryFieldsPosMsg)
jsonTranslationParty(t, unaryFieldsNegMsg)
jsonTranslationParty(t, unaryFieldsPosMsg, false)
jsonTranslationParty(t, unaryFieldsNegMsg, false)
jsonTranslationParty(t, unaryFieldsPosInfMsg, false)
jsonTranslationParty(t, unaryFieldsNegInfMsg, false)
jsonTranslationParty(t, unaryFieldsNanMsg, true)
}

func TestJSONRepeatedFields(t *testing.T) {
jsonTranslationParty(t, repeatedFieldsMsg)
jsonTranslationParty(t, repeatedFieldsMsg, false)
jsonTranslationParty(t, repeatedFieldsInfNanMsg, true)
}

func TestJSONMapKeyFields(t *testing.T) {
jsonTranslationParty(t, mapKeyFieldsMsg)
jsonTranslationParty(t, mapKeyFieldsMsg, false)
}

func TestJSONMapValueFields(t *testing.T) {
jsonTranslationParty(t, mapValueFieldsMsg)
jsonTranslationParty(t, mapValueFieldsMsg, false)
jsonTranslationParty(t, mapValueFieldsInfNanMsg, true)
}

func TestJSONExtensionFields(t *testing.T) {
Expand Down Expand Up @@ -490,7 +495,7 @@ func TestJSONWellKnownTypeFromFileDescriptorSet(t *testing.T) {
testutil.Eq(t, js, dynJs)
}

func jsonTranslationParty(t *testing.T, msg proto.Message) {
func jsonTranslationParty(t *testing.T, msg proto.Message, includesNaN bool) {
doTranslationParty(t, msg,
func(pm proto.Message) ([]byte, error) {
m := jsonpb.Marshaler{}
Expand All @@ -505,5 +510,5 @@ func jsonTranslationParty(t *testing.T, msg proto.Message) {
func(b []byte, pm proto.Message) error {
return jsonpb.Unmarshal(bytes.NewReader(b), pm)
},
(*Message).MarshalJSON, (*Message).UnmarshalJSON)
(*Message).MarshalJSON, (*Message).UnmarshalJSON, includesNaN)
}
81 changes: 60 additions & 21 deletions dynamic/marshal_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dynamic

import (
"math"
"reflect"
"testing"

Expand Down Expand Up @@ -68,11 +69,26 @@ var unaryFieldsNegMsg = &testprotos.UnaryFields{
Z: testprotos.TestEnum_SECOND.Enum(),
}

var unaryFieldsPosInfMsg = &testprotos.UnaryFields{
S: proto.Float32(float32(math.Inf(1))),
T: proto.Float64(math.Inf(1)),
}

var unaryFieldsNegInfMsg = &testprotos.UnaryFields{
S: proto.Float32(float32(math.Inf(-1))),
T: proto.Float64(math.Inf(-1)),
}

var unaryFieldsNanMsg = &testprotos.UnaryFields{
S: proto.Float32(float32(math.NaN())),
T: proto.Float64(math.NaN()),
}

var repeatedFieldsMsg = &testprotos.RepeatedFields{
I: []int32{1, 2, 3},
J: []int64{4, 5, 6},
K: []int32{7, 8, 9},
L: []int64{10, 11, 12},
I: []int32{1, -2, 3},
J: []int64{-4, 5, -6},
K: []int32{7, -8, 9},
L: []int64{-10, 11, -12},
M: []uint32{13, 14, 15},
N: []uint64{16, 17, 18},
O: []uint32{19, 20, 21},
Expand All @@ -95,11 +111,16 @@ var repeatedFieldsMsg = &testprotos.RepeatedFields{
Z: []testprotos.TestEnum{testprotos.TestEnum_SECOND, testprotos.TestEnum_THIRD, testprotos.TestEnum_FIRST},
}

var repeatedFieldsInfNanMsg = &testprotos.RepeatedFields{
S: []float32{float32(math.Inf(1)), float32(math.Inf(-1)), float32(math.NaN())},
T: []float64{math.Inf(1), math.Inf(-1), math.NaN()},
}

var repeatedPackedFieldsMsg = &testprotos.RepeatedPackedFields{
I: []int32{1, 2, 3},
J: []int64{4, 5, 6},
K: []int32{7, 8, 9},
L: []int64{10, 11, 12},
I: []int32{1, -2, 3},
J: []int64{-4, 5, -6},
K: []int32{7, -8, 9},
L: []int64{-10, 11, -12},
M: []uint32{13, 14, 15},
N: []uint64{16, 17, 18},
O: []uint32{19, 20, 21},
Expand All @@ -116,11 +137,16 @@ var repeatedPackedFieldsMsg = &testprotos.RepeatedPackedFields{
V: []testprotos.TestEnum{testprotos.TestEnum_SECOND, testprotos.TestEnum_THIRD, testprotos.TestEnum_FIRST},
}

var repeatedPackedFieldsInfNanMsg = &testprotos.RepeatedPackedFields{
S: []float32{float32(math.Inf(1)), float32(math.Inf(-1)), float32(math.NaN())},
T: []float64{math.Inf(1), math.Inf(-1), math.NaN()},
}

var mapKeyFieldsMsg = &testprotos.MapKeyFields{
I: map[int32]string{1: "foo", 2: "bar", 3: "baz"},
J: map[int64]string{4: "foo", 5: "bar", 6: "baz"},
K: map[int32]string{7: "foo", 8: "bar", 9: "baz"},
L: map[int64]string{10: "foo", 11: "bar", 12: "baz"},
I: map[int32]string{1: "foo", -2: "bar", 3: "baz"},
J: map[int64]string{-4: "foo", 5: "bar", -6: "baz"},
K: map[int32]string{7: "foo", -8: "bar", 9: "baz"},
L: map[int64]string{-10: "foo", 11: "bar", -12: "baz"},
M: map[uint32]string{13: "foo", 14: "bar", 15: "baz"},
N: map[uint64]string{16: "foo", 17: "bar", 18: "baz"},
O: map[uint32]string{19: "foo", 20: "bar", 21: "baz"},
Expand All @@ -132,10 +158,10 @@ var mapKeyFieldsMsg = &testprotos.MapKeyFields{
}

var mapValueFieldsMsg = &testprotos.MapValFields{
I: map[string]int32{"a": 1, "b": 2, "c": 3},
J: map[string]int64{"a": 4, "b": 5, "c": 6},
K: map[string]int32{"a": 7, "b": 8, "c": 9},
L: map[string]int64{"a": 10, "b": 11, "c": 12},
I: map[string]int32{"a": 1, "b": -2, "c": 3},
J: map[string]int64{"a": -4, "b": 5, "c": -6},
K: map[string]int32{"a": 7, "b": -8, "c": 9},
L: map[string]int64{"a": -10, "b": 11, "c": -12},
M: map[string]uint32{"a": 13, "b": 14, "c": 15},
N: map[string]uint64{"a": 16, "b": 17, "c": 18},
O: map[string]uint32{"a": 19, "b": 20, "c": 21},
Expand All @@ -154,9 +180,15 @@ var mapValueFieldsMsg = &testprotos.MapValFields{
Y: map[string]testprotos.TestEnum{"a": testprotos.TestEnum_SECOND, "b": testprotos.TestEnum_THIRD, "c": testprotos.TestEnum_FIRST},
}

var mapValueFieldsInfNanMsg = &testprotos.MapValFields{
S: map[string]float32{"a": float32(math.Inf(1)), "b": float32(math.Inf(-1)), "c": float32(math.NaN())},
T: map[string]float64{"a": math.Inf(1), "b": math.Inf(-1), "c": math.NaN()},
}

func doTranslationParty(t *testing.T, msg proto.Message,
marshalPm func(proto.Message) ([]byte, error), unmarshalPm func([]byte, proto.Message) error,
marshalDm func(*Message) ([]byte, error), unmarshalDm func(*Message, []byte) error) {
marshalDm func(*Message) ([]byte, error), unmarshalDm func(*Message, []byte) error,
includesNaN bool) {

md, err := desc.LoadMessageDescriptorForMessage(msg)
testutil.Ok(t, err)
Expand All @@ -179,7 +211,10 @@ func doTranslationParty(t *testing.T, msg proto.Message,
err = unmarshalPm(b2a, msg2)
testutil.Ok(t, err)

testutil.Ceq(t, msg, msg2, eqpm)
if !includesNaN {
// NaN fields are never equal so this would always be false
testutil.Ceq(t, msg, msg2, eqpm)
}

// and back again
b3, err := marshalPm(msg2)
Expand All @@ -188,7 +223,9 @@ func doTranslationParty(t *testing.T, msg proto.Message,
err = unmarshalDm(dm2, b3)
testutil.Ok(t, err)

testutil.Ceq(t, dm, dm2, eqdm)
if !includesNaN {
testutil.Ceq(t, dm, dm2, eqdm)
}

// dynamic message -> (bytes) -> dynamic message
// both techniques to unmarshal are equivalent
Expand All @@ -199,6 +236,8 @@ func doTranslationParty(t *testing.T, msg proto.Message,
err = unmarshalDm(dm4, b2a)
testutil.Ok(t, err)

testutil.Ceq(t, dm, dm3, eqdm)
testutil.Ceq(t, dm, dm4, eqdm)
if !includesNaN {
testutil.Ceq(t, dm, dm3, eqdm)
testutil.Ceq(t, dm, dm4, eqdm)
}
}
Loading

0 comments on commit 7dce9ca

Please sign in to comment.