diff --git a/dynamic/binary_test.go b/dynamic/binary_test.go index 69d112e3..7c1be729 100644 --- a/dynamic/binary_test.go +++ b/dynamic/binary_test.go @@ -12,16 +12,21 @@ import ( ) func TestBinaryUnaryFields(t *testing.T) { - binaryTranslationParty(t, unaryFieldsPosMsg) - binaryTranslationParty(t, unaryFieldsNegMsg) + binaryTranslationParty(t, unaryFieldsPosMsg, false) + binaryTranslationParty(t, unaryFieldsNegMsg, false) + binaryTranslationParty(t, unaryFieldsPosInfMsg, false) + binaryTranslationParty(t, unaryFieldsNegInfMsg, false) + binaryTranslationParty(t, unaryFieldsNanMsg, true) } func TestBinaryRepeatedFields(t *testing.T) { - binaryTranslationParty(t, repeatedFieldsMsg) + binaryTranslationParty(t, repeatedFieldsMsg, false) + binaryTranslationParty(t, repeatedFieldsInfNanMsg, true) } func TestBinaryPackedRepeatedFields(t *testing.T) { - binaryTranslationParty(t, repeatedPackedFieldsMsg) + binaryTranslationParty(t, repeatedPackedFieldsMsg, false) + binaryTranslationParty(t, repeatedPackedFieldsInfNanMsg, true) } func TestBinaryMapKeyFields(t *testing.T) { @@ -31,7 +36,7 @@ func TestBinaryMapKeyFields(t *testing.T) { defaultDeterminism = false }() - binaryTranslationParty(t, mapKeyFieldsMsg) + binaryTranslationParty(t, mapKeyFieldsMsg, false) } func TestMarshalMapValueFields(t *testing.T) { @@ -41,7 +46,8 @@ func TestMarshalMapValueFields(t *testing.T) { defaultDeterminism = false }() - binaryTranslationParty(t, mapValueFieldsMsg) + binaryTranslationParty(t, mapValueFieldsMsg, false) + binaryTranslationParty(t, mapValueFieldsInfNanMsg, true) } func TestBinaryExtensionFields(t *testing.T) { @@ -127,10 +133,10 @@ func TestBinaryUnknownFields(t *testing.T) { testutil.Eq(t, buf.buf, bb) // now try a full translation party to ensure unknown bits remain correct throughout - binaryTranslationParty(t, &msg) + binaryTranslationParty(t, &msg, false) } -func binaryTranslationParty(t *testing.T, msg proto.Message) { +func binaryTranslationParty(t *testing.T, msg proto.Message, includesNaN bool) { marshalAppendSimple := func(m *Message) ([]byte, error) { // Declare a function that has the same interface as (*Message.Marshal) but uses // MarshalAppend internally so we can reuse the translation party tests to verify @@ -160,7 +166,7 @@ func binaryTranslationParty(t *testing.T, msg proto.Message) { } for _, marshalFn := range marshalMethods { - doTranslationParty(t, msg, proto.Marshal, proto.Unmarshal, marshalFn, (*Message).Unmarshal) + doTranslationParty(t, msg, proto.Marshal, proto.Unmarshal, marshalFn, (*Message).Unmarshal, includesNaN) } } diff --git a/dynamic/equal.go b/dynamic/equal.go index 502cd2e1..5fbcc245 100644 --- a/dynamic/equal.go +++ b/dynamic/equal.go @@ -70,9 +70,8 @@ func fieldsEqual(aval, bval interface{}) bool { if !ok { return false } - if !MessagesEqual(apm, bpm) { - return false - } + return MessagesEqual(apm, bpm) + } else { switch arv.Kind() { case reflect.Ptr: @@ -83,33 +82,22 @@ func fieldsEqual(aval, bval interface{}) bool { return false } bpm := bval.(proto.Message) // we know it will succeed because we know a and b have same type - if !MessagesEqual(apm, bpm) { - return false - } + return MessagesEqual(apm, bpm) case reflect.Map: - if !mapsEqual(arv, brv) { - return false - } + return mapsEqual(arv, brv) case reflect.Slice: if arv.Type() == typeOfBytes { - if !bytes.Equal(aval.([]byte), bval.([]byte)) { - return false - } + return bytes.Equal(aval.([]byte), bval.([]byte)) } else { - if !slicesEqual(arv, brv) { - return false - } + return slicesEqual(arv, brv) } default: - if aval != bval { - return false - } + return aval == bval } } - return true } func slicesEqual(a, b reflect.Value) bool { diff --git a/dynamic/json.go b/dynamic/json.go index 4179b218..f79b4ac8 100644 --- a/dynamic/json.go +++ b/dynamic/json.go @@ -367,11 +367,11 @@ func marshalKnownFieldValueJSON(b *indentBuffer, fd *desc.FieldDescriptor, v int f := rv.Float() var str string if math.IsNaN(f) { - str = "NaN" + str = `"NaN"` } else if math.IsInf(f, 1) { - str = "Infinity" + str = `"Infinity"` } else if math.IsInf(f, -1) { - str = "-Infinity" + str = `"-Infinity"` } else { var bits int if rv.Kind() == reflect.Float32 { diff --git a/dynamic/json_test.go b/dynamic/json_test.go index f1d6de07..ec5a5e3a 100644 --- a/dynamic/json_test.go +++ b/dynamic/json_test.go @@ -24,20 +24,25 @@ import ( ) func TestJSONUnaryFields(t *testing.T) { - jsonTranslationParty(t, unaryFieldsPosMsg) - jsonTranslationParty(t, unaryFieldsNegMsg) + jsonTranslationParty(t, unaryFieldsPosMsg, false) + jsonTranslationParty(t, unaryFieldsNegMsg, false) + jsonTranslationParty(t, unaryFieldsPosInfMsg, false) + jsonTranslationParty(t, unaryFieldsNegInfMsg, false) + jsonTranslationParty(t, unaryFieldsNanMsg, true) } func TestJSONRepeatedFields(t *testing.T) { - jsonTranslationParty(t, repeatedFieldsMsg) + jsonTranslationParty(t, repeatedFieldsMsg, false) + jsonTranslationParty(t, repeatedFieldsInfNanMsg, true) } func TestJSONMapKeyFields(t *testing.T) { - jsonTranslationParty(t, mapKeyFieldsMsg) + jsonTranslationParty(t, mapKeyFieldsMsg, false) } func TestJSONMapValueFields(t *testing.T) { - jsonTranslationParty(t, mapValueFieldsMsg) + jsonTranslationParty(t, mapValueFieldsMsg, false) + jsonTranslationParty(t, mapValueFieldsInfNanMsg, true) } func TestJSONExtensionFields(t *testing.T) { @@ -490,7 +495,7 @@ func TestJSONWellKnownTypeFromFileDescriptorSet(t *testing.T) { testutil.Eq(t, js, dynJs) } -func jsonTranslationParty(t *testing.T, msg proto.Message) { +func jsonTranslationParty(t *testing.T, msg proto.Message, includesNaN bool) { doTranslationParty(t, msg, func(pm proto.Message) ([]byte, error) { m := jsonpb.Marshaler{} @@ -505,5 +510,5 @@ func jsonTranslationParty(t *testing.T, msg proto.Message) { func(b []byte, pm proto.Message) error { return jsonpb.Unmarshal(bytes.NewReader(b), pm) }, - (*Message).MarshalJSON, (*Message).UnmarshalJSON) + (*Message).MarshalJSON, (*Message).UnmarshalJSON, includesNaN) } diff --git a/dynamic/marshal_test.go b/dynamic/marshal_test.go index 058595d5..b0163268 100644 --- a/dynamic/marshal_test.go +++ b/dynamic/marshal_test.go @@ -1,6 +1,7 @@ package dynamic import ( + "math" "reflect" "testing" @@ -68,11 +69,26 @@ var unaryFieldsNegMsg = &testprotos.UnaryFields{ Z: testprotos.TestEnum_SECOND.Enum(), } +var unaryFieldsPosInfMsg = &testprotos.UnaryFields{ + S: proto.Float32(float32(math.Inf(1))), + T: proto.Float64(math.Inf(1)), +} + +var unaryFieldsNegInfMsg = &testprotos.UnaryFields{ + S: proto.Float32(float32(math.Inf(-1))), + T: proto.Float64(math.Inf(-1)), +} + +var unaryFieldsNanMsg = &testprotos.UnaryFields{ + S: proto.Float32(float32(math.NaN())), + T: proto.Float64(math.NaN()), +} + var repeatedFieldsMsg = &testprotos.RepeatedFields{ - I: []int32{1, 2, 3}, - J: []int64{4, 5, 6}, - K: []int32{7, 8, 9}, - L: []int64{10, 11, 12}, + I: []int32{1, -2, 3}, + J: []int64{-4, 5, -6}, + K: []int32{7, -8, 9}, + L: []int64{-10, 11, -12}, M: []uint32{13, 14, 15}, N: []uint64{16, 17, 18}, O: []uint32{19, 20, 21}, @@ -95,11 +111,16 @@ var repeatedFieldsMsg = &testprotos.RepeatedFields{ Z: []testprotos.TestEnum{testprotos.TestEnum_SECOND, testprotos.TestEnum_THIRD, testprotos.TestEnum_FIRST}, } +var repeatedFieldsInfNanMsg = &testprotos.RepeatedFields{ + S: []float32{float32(math.Inf(1)), float32(math.Inf(-1)), float32(math.NaN())}, + T: []float64{math.Inf(1), math.Inf(-1), math.NaN()}, +} + var repeatedPackedFieldsMsg = &testprotos.RepeatedPackedFields{ - I: []int32{1, 2, 3}, - J: []int64{4, 5, 6}, - K: []int32{7, 8, 9}, - L: []int64{10, 11, 12}, + I: []int32{1, -2, 3}, + J: []int64{-4, 5, -6}, + K: []int32{7, -8, 9}, + L: []int64{-10, 11, -12}, M: []uint32{13, 14, 15}, N: []uint64{16, 17, 18}, O: []uint32{19, 20, 21}, @@ -116,11 +137,16 @@ var repeatedPackedFieldsMsg = &testprotos.RepeatedPackedFields{ V: []testprotos.TestEnum{testprotos.TestEnum_SECOND, testprotos.TestEnum_THIRD, testprotos.TestEnum_FIRST}, } +var repeatedPackedFieldsInfNanMsg = &testprotos.RepeatedPackedFields{ + S: []float32{float32(math.Inf(1)), float32(math.Inf(-1)), float32(math.NaN())}, + T: []float64{math.Inf(1), math.Inf(-1), math.NaN()}, +} + var mapKeyFieldsMsg = &testprotos.MapKeyFields{ - I: map[int32]string{1: "foo", 2: "bar", 3: "baz"}, - J: map[int64]string{4: "foo", 5: "bar", 6: "baz"}, - K: map[int32]string{7: "foo", 8: "bar", 9: "baz"}, - L: map[int64]string{10: "foo", 11: "bar", 12: "baz"}, + I: map[int32]string{1: "foo", -2: "bar", 3: "baz"}, + J: map[int64]string{-4: "foo", 5: "bar", -6: "baz"}, + K: map[int32]string{7: "foo", -8: "bar", 9: "baz"}, + L: map[int64]string{-10: "foo", 11: "bar", -12: "baz"}, M: map[uint32]string{13: "foo", 14: "bar", 15: "baz"}, N: map[uint64]string{16: "foo", 17: "bar", 18: "baz"}, O: map[uint32]string{19: "foo", 20: "bar", 21: "baz"}, @@ -132,10 +158,10 @@ var mapKeyFieldsMsg = &testprotos.MapKeyFields{ } var mapValueFieldsMsg = &testprotos.MapValFields{ - I: map[string]int32{"a": 1, "b": 2, "c": 3}, - J: map[string]int64{"a": 4, "b": 5, "c": 6}, - K: map[string]int32{"a": 7, "b": 8, "c": 9}, - L: map[string]int64{"a": 10, "b": 11, "c": 12}, + I: map[string]int32{"a": 1, "b": -2, "c": 3}, + J: map[string]int64{"a": -4, "b": 5, "c": -6}, + K: map[string]int32{"a": 7, "b": -8, "c": 9}, + L: map[string]int64{"a": -10, "b": 11, "c": -12}, M: map[string]uint32{"a": 13, "b": 14, "c": 15}, N: map[string]uint64{"a": 16, "b": 17, "c": 18}, O: map[string]uint32{"a": 19, "b": 20, "c": 21}, @@ -154,9 +180,15 @@ var mapValueFieldsMsg = &testprotos.MapValFields{ Y: map[string]testprotos.TestEnum{"a": testprotos.TestEnum_SECOND, "b": testprotos.TestEnum_THIRD, "c": testprotos.TestEnum_FIRST}, } +var mapValueFieldsInfNanMsg = &testprotos.MapValFields{ + S: map[string]float32{"a": float32(math.Inf(1)), "b": float32(math.Inf(-1)), "c": float32(math.NaN())}, + T: map[string]float64{"a": math.Inf(1), "b": math.Inf(-1), "c": math.NaN()}, +} + func doTranslationParty(t *testing.T, msg proto.Message, marshalPm func(proto.Message) ([]byte, error), unmarshalPm func([]byte, proto.Message) error, - marshalDm func(*Message) ([]byte, error), unmarshalDm func(*Message, []byte) error) { + marshalDm func(*Message) ([]byte, error), unmarshalDm func(*Message, []byte) error, + includesNaN bool) { md, err := desc.LoadMessageDescriptorForMessage(msg) testutil.Ok(t, err) @@ -179,7 +211,10 @@ func doTranslationParty(t *testing.T, msg proto.Message, err = unmarshalPm(b2a, msg2) testutil.Ok(t, err) - testutil.Ceq(t, msg, msg2, eqpm) + if !includesNaN { + // NaN fields are never equal so this would always be false + testutil.Ceq(t, msg, msg2, eqpm) + } // and back again b3, err := marshalPm(msg2) @@ -188,7 +223,9 @@ func doTranslationParty(t *testing.T, msg proto.Message, err = unmarshalDm(dm2, b3) testutil.Ok(t, err) - testutil.Ceq(t, dm, dm2, eqdm) + if !includesNaN { + testutil.Ceq(t, dm, dm2, eqdm) + } // dynamic message -> (bytes) -> dynamic message // both techniques to unmarshal are equivalent @@ -199,6 +236,8 @@ func doTranslationParty(t *testing.T, msg proto.Message, err = unmarshalDm(dm4, b2a) testutil.Ok(t, err) - testutil.Ceq(t, dm, dm3, eqdm) - testutil.Ceq(t, dm, dm4, eqdm) + if !includesNaN { + testutil.Ceq(t, dm, dm3, eqdm) + testutil.Ceq(t, dm, dm4, eqdm) + } } diff --git a/dynamic/text.go b/dynamic/text.go index e7c6899d..2d0fa043 100644 --- a/dynamic/text.go +++ b/dynamic/text.go @@ -706,25 +706,59 @@ func (m *Message) unmarshalFieldElementText(fd *desc.FieldDescriptor, tr *txtRea } expected = "string value" case descriptor.FieldDescriptorProto_TYPE_FLOAT: - if tok.tokTyp == tokenFloat { + switch tok.tokTyp { + case tokenFloat: return set(m, fd, float32(tok.val.(float64))) - } else if tok.tokTyp == tokenInt { + case tokenInt: if f, err := strconv.ParseFloat(tok.val.(string), 32); err != nil { return err } else { return set(m, fd, float32(f)) } + case tokenIdent: + ident := strings.ToLower(tok.val.(string)) + if ident == "inf" { + return set(m, fd, float32(math.Inf(1))) + } else if ident == "nan" { + return set(m, fd, float32(math.NaN())) + } + case tokenMinus: + peeked := tr.peek() + if peeked.tokTyp == tokenIdent { + ident := strings.ToLower(peeked.val.(string)) + if ident == "inf" { + tr.next() // consume peeked token + return set(m, fd, float32(math.Inf(-1))) + } + } } expected = "float value" case descriptor.FieldDescriptorProto_TYPE_DOUBLE: - if tok.tokTyp == tokenFloat { + switch tok.tokTyp { + case tokenFloat: return set(m, fd, tok.val) - } else if tok.tokTyp == tokenInt { + case tokenInt: if f, err := strconv.ParseFloat(tok.val.(string), 64); err != nil { return err } else { return set(m, fd, f) } + case tokenIdent: + ident := strings.ToLower(tok.val.(string)) + if ident == "inf" { + return set(m, fd, math.Inf(1)) + } else if ident == "nan" { + return set(m, fd, math.NaN()) + } + case tokenMinus: + peeked := tr.peek() + if peeked.tokTyp == tokenIdent { + ident := strings.ToLower(peeked.val.(string)) + if ident == "inf" { + tr.next() // consume peeked token + return set(m, fd, math.Inf(-1)) + } + } } expected = "float value" case descriptor.FieldDescriptorProto_TYPE_INT32, @@ -972,6 +1006,7 @@ const ( tokenOpenParen tokenCloseParen tokenSlash + tokenMinus ) func (t tokenType) IsSep() bool { @@ -1061,7 +1096,10 @@ func (p *txtReader) processToken(t rune, text string, pos scanner.Position) erro } case '-': // unary minus, for negative ints and floats ch := p.scanner.Peek() - if ch >= '0' && ch <= '9' { + if ch < '0' || ch > '9' { + p.peeked.tokTyp = tokenMinus + p.peeked.val = '-' + } else { t := p.scanner.Scan() if t == scanner.EOF { return io.ErrUnexpectedEOF diff --git a/dynamic/text_test.go b/dynamic/text_test.go index 3d44c724..b9644223 100644 --- a/dynamic/text_test.go +++ b/dynamic/text_test.go @@ -11,20 +11,25 @@ import ( ) func TestTextUnaryFields(t *testing.T) { - textTranslationParty(t, unaryFieldsPosMsg) - textTranslationParty(t, unaryFieldsNegMsg) + textTranslationParty(t, unaryFieldsPosMsg, false) + textTranslationParty(t, unaryFieldsNegMsg, false) + textTranslationParty(t, unaryFieldsPosInfMsg, false) + textTranslationParty(t, unaryFieldsNegInfMsg, false) + textTranslationParty(t, unaryFieldsNanMsg, true) } func TestTextRepeatedFields(t *testing.T) { - textTranslationParty(t, repeatedFieldsMsg) + textTranslationParty(t, repeatedFieldsMsg, false) + textTranslationParty(t, repeatedFieldsInfNanMsg, true) } func TestTextMapKeyFields(t *testing.T) { - textTranslationParty(t, mapKeyFieldsMsg) + textTranslationParty(t, mapKeyFieldsMsg, false) } func TestTextMapValueFields(t *testing.T) { - textTranslationParty(t, mapValueFieldsMsg) + textTranslationParty(t, mapValueFieldsMsg, false) + textTranslationParty(t, mapValueFieldsInfNanMsg, true) } func TestTextUnknownFields(t *testing.T) { @@ -154,7 +159,7 @@ func TestTextLenientParsing(t *testing.T) { } } -func textTranslationParty(t *testing.T, msg proto.Message) { +func textTranslationParty(t *testing.T, msg proto.Message, includesNaN bool) { doTranslationParty(t, msg, func(pm proto.Message) ([]byte, error) { return []byte(proto.MarshalTextString(pm)), nil @@ -162,5 +167,5 @@ func textTranslationParty(t *testing.T, msg proto.Message) { func(b []byte, pm proto.Message) error { return proto.UnmarshalText(string(b), pm) }, - (*Message).MarshalText, (*Message).UnmarshalText) + (*Message).MarshalText, (*Message).UnmarshalText, includesNaN) }