diff --git a/proto/all_test.go b/proto/all_test.go index c294596c9c..f391af74c3 100644 --- a/proto/all_test.go +++ b/proto/all_test.go @@ -2343,6 +2343,28 @@ func TestDeterministicErrorOnCustomMarshaler(t *testing.T) { } } +func TestRequired(t *testing.T) { + // The F_BoolRequired field appears after all of the required fields. + // It should still be handled even after multiple required field violations. + m := &GoTest{F_BoolRequired: Bool(true)} + got, err := Marshal(m) + if _, ok := err.(*RequiredNotSetError); !ok { + t.Errorf("Marshal() = %v, want RequiredNotSetError error", err) + } + if want := []byte{0x50, 0x01}; !bytes.Equal(got, want) { + t.Errorf("Marshal() = %x, want %x", got, want) + } + + m = new(GoTest) + err = Unmarshal(got, m) + if _, ok := err.(*RequiredNotSetError); !ok { + t.Errorf("Marshal() = %v, want RequiredNotSetError error", err) + } + if !m.GetF_BoolRequired() { + t.Error("m.F_BoolRequired = false, want true") + } +} + // Benchmarks func testMsg() *GoTest { diff --git a/proto/table_marshal.go b/proto/table_marshal.go index 92bd375c6e..ba58c49a43 100644 --- a/proto/table_marshal.go +++ b/proto/table_marshal.go @@ -279,11 +279,13 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte b = append(b, s...) } for _, f := range u.fields { - if f.required && errLater == nil { + if f.required { if f.isPointer && ptr.offset(f.field).getPointer().isNil() { // Required field is not set. // We record the error but keep going, to give a complete marshaling. - errLater = &RequiredNotSetError{f.name} + if errLater == nil { + errLater = &RequiredNotSetError{f.name} + } continue } } @@ -2825,7 +2827,7 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de p := toAddrPointer(&v, ei.isptr) b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic) b = append(b, 1<<3|WireEndGroup) - if nerr.Merge(err) { + if !nerr.Merge(err) { return b, err } } diff --git a/proto/table_unmarshal.go b/proto/table_unmarshal.go index 9fd5eb5457..e6b15c76ca 100644 --- a/proto/table_unmarshal.go +++ b/proto/table_unmarshal.go @@ -177,10 +177,12 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error { reqMask |= f.reqMask continue } - if r, ok := err.(*RequiredNotSetError); ok && errLater == nil { + if r, ok := err.(*RequiredNotSetError); ok { // Remember this error, but keep parsing. We need to produce // a full parse even if a required field is missing. - errLater = r + if errLater == nil { + errLater = r + } reqMask |= f.reqMask continue }