From 4ff7f982fde906f1f07157ebad3c25d614e74d5c Mon Sep 17 00:00:00 2001 From: Greg Burek Date: Wed, 25 Sep 2024 16:19:53 -0700 Subject: [PATCH] Adds support for proto3opt Attempts to detect optional fields by using `field.Oneof.Desc.IsSynthetic()` and relies on nil checking pointers for the marshall and unmarshalling --- features/json/message-marshal.go | 14 +- features/json/message-unmarshal.go | 9 +- features/json/message.go | 24 +-- testproto/proto3opt/opt.pb.go | 240 ++++++++++++++++++++++++++++- 4 files changed, 264 insertions(+), 23 deletions(-) diff --git a/features/json/message-marshal.go b/features/json/message-marshal.go index 350f6523..a7af3a97 100644 --- a/features/json/message-marshal.go +++ b/features/json/message-marshal.go @@ -143,13 +143,13 @@ nextField: // If this is the first field in a oneof, write the if statement that checks for nil // and start the switch statement for the oneof type. - if field.Oneof != nil && field == field.Oneof.Fields[0] { + if field.Oneof != nil && field == field.Oneof.Fields[0] && !field.Oneof.Desc.IsSynthetic() { // NOTE: we don't support field masks here (yet). g.P("if x.", field.Oneof.GoName, " != nil {") g.P("switch ov := x.", field.Oneof.GoName, ".(type) {") } - if field.Oneof != nil { + if field.Oneof != nil && !field.Oneof.Desc.IsSynthetic() { // If we're in a oneof, check if this is the field that's set in the oneof. g.P("case *", field.GoIdent.GoName, ":") messageOrOneofIdent = "ov" @@ -192,6 +192,12 @@ nextField: switch field.Desc.Kind() { default: // Scalar types can be written by the library. + if field.Oneof != nil && field.Oneof.Desc.IsSynthetic() { + g.P("s.Write", g.libNameForField(field), "(*", messageOrOneofIdent, ".", fieldGoName, ")") + } else { + g.P("s.Write", g.libNameForField(field), "(", messageOrOneofIdent, ".", fieldGoName, ")") + } + case protoreflect.BytesKind: g.P("s.Write", g.libNameForField(field), "(", messageOrOneofIdent, ".", fieldGoName, ")") case protoreflect.EnumKind: // If the field is of type enum, and the enum has a marshaler, use that. @@ -207,12 +213,12 @@ nextField: } // If we're not in a oneof, end the "if not zero". - if field.Oneof == nil { + if field.Oneof == nil || field.Oneof.Desc.IsSynthetic() { g.P("}") // end if x.{field.GoName} != zero value { } // If this is the last field in the oneof, close the switch and if statements. - if field.Oneof != nil && field == field.Oneof.Fields[len(field.Oneof.Fields)-1] { + if field.Oneof != nil && field == field.Oneof.Fields[len(field.Oneof.Fields)-1] && !field.Oneof.Desc.IsSynthetic() { g.P("}") // end switch v := x.{field.Oneof.GoName}.(type) { g.P("}") // end if x.{field.Oneof.GoName} != nil { } diff --git a/features/json/message-unmarshal.go b/features/json/message-unmarshal.go index dc394039..b1380af0 100644 --- a/features/json/message-unmarshal.go +++ b/features/json/message-unmarshal.go @@ -162,7 +162,7 @@ nextField: messageOrOneofIdent := "x" // If this field is in a oneof, allocate a new oneof value wrapper. - if field.Oneof != nil { + if field.Oneof != nil && !field.Oneof.Desc.IsSynthetic() { g.P("ov := &", field.GoIdent.GoName, "{}") g.P("x.", field.Oneof.GoName, " = ov") messageOrOneofIdent = "ov" @@ -190,6 +190,13 @@ nextField: switch field.Desc.Kind() { default: // Scalar types can be read by the library. + if field.Oneof != nil && field.Oneof.Desc.IsSynthetic() { + g.P("t := s.Read", g.libNameForField(field), "()") + g.P(messageOrOneofIdent, ".", fieldGoName, " = &t") + } else { + g.P(messageOrOneofIdent, ".", fieldGoName, " = s.Read", g.libNameForField(field), "()") + } + case protoreflect.BytesKind: g.P(messageOrOneofIdent, ".", fieldGoName, " = s.Read", g.libNameForField(field), "()") case protoreflect.EnumKind: // If the field is of type enum, and the enum has an unmarshaler, call the unmarshaler. diff --git a/features/json/message.go b/features/json/message.go index 9c62ad2d..48fab472 100644 --- a/features/json/message.go +++ b/features/json/message.go @@ -6,7 +6,6 @@ package json import ( "fmt" - "slices" "github.com/aperturerobotics/protobuf-go-lite/compiler/protogen" "google.golang.org/protobuf/reflect/protoreflect" @@ -28,26 +27,17 @@ func (g *jsonGenerator) genMessage(message *protogen.Message) { return } - // Check if the message has any optional fields and skip generation if so. - anyOptional := slices.ContainsFunc(message.Fields, func(f *protogen.Field) bool { - return f.Desc.HasOptionalKeyword() - }) + g.genMessageMarshaler(message) + g.genStdMessageMarshaler(message) - if !anyOptional { - g.genMessageMarshaler(message) - g.genStdMessageMarshaler(message) - - g.genMessageUnmarshaler(message) - g.genStdMessageUnmarshaler(message) - } else { - // We do not support marshaling this field, skip the entire message. - g.P("// NOTE: protobuf-go-lite json only supports proto3 and not proto3opt (optional fields).") - g.P() - } + g.genMessageUnmarshaler(message) + g.genStdMessageUnmarshaler(message) } func fieldIsNullable(field *protogen.Field) bool { - // In the supported subset of syntax (proto3 and not proto3opt) we only use pointers for messages. + if field.Oneof != nil && field.Oneof.Desc.IsSynthetic() { + return true + } nullable := field.Desc.Kind() == protoreflect.MessageKind return nullable } diff --git a/testproto/proto3opt/opt.pb.go b/testproto/proto3opt/opt.pb.go index 4090f69c..dc7abaf9 100644 --- a/testproto/proto3opt/opt.pb.go +++ b/testproto/proto3opt/opt.pb.go @@ -378,7 +378,245 @@ func (x *SimpleEnum) UnmarshalJSON(b []byte) error { return json.DefaultUnmarshalerConfig.Unmarshal(b, x) } -// NOTE: protobuf-go-lite json only supports proto3 and not proto3opt (optional fields). +// MarshalProtoJSON marshals the OptionalFieldInProto3 message to JSON. +func (x *OptionalFieldInProto3) MarshalProtoJSON(s *json.MarshalState) { + if x == nil { + s.WriteNil() + return + } + s.WriteObjectStart() + var wroteField bool + if x.OptionalInt32 != nil || s.HasField("optionalInt32") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalInt32") + s.WriteInt32(*x.OptionalInt32) + } + if x.OptionalInt64 != nil || s.HasField("optionalInt64") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalInt64") + s.WriteInt64(*x.OptionalInt64) + } + if x.OptionalUint32 != nil || s.HasField("optionalUint32") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalUint32") + s.WriteUint32(*x.OptionalUint32) + } + if x.OptionalUint64 != nil || s.HasField("optionalUint64") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalUint64") + s.WriteUint64(*x.OptionalUint64) + } + if x.OptionalSint32 != nil || s.HasField("optionalSint32") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalSint32") + s.WriteInt32(*x.OptionalSint32) + } + if x.OptionalSint64 != nil || s.HasField("optionalSint64") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalSint64") + s.WriteInt64(*x.OptionalSint64) + } + if x.OptionalFixed32 != nil || s.HasField("optionalFixed32") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalFixed32") + s.WriteUint32(*x.OptionalFixed32) + } + if x.OptionalFixed64 != nil || s.HasField("optionalFixed64") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalFixed64") + s.WriteUint64(*x.OptionalFixed64) + } + if x.OptionalSfixed32 != nil || s.HasField("optionalSfixed32") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalSfixed32") + s.WriteInt32(*x.OptionalSfixed32) + } + if x.OptionalSfixed64 != nil || s.HasField("optionalSfixed64") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalSfixed64") + s.WriteInt64(*x.OptionalSfixed64) + } + if x.OptionalFloat != nil || s.HasField("optionalFloat") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalFloat") + s.WriteFloat32(*x.OptionalFloat) + } + if x.OptionalDouble != nil || s.HasField("optionalDouble") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalDouble") + s.WriteFloat64(*x.OptionalDouble) + } + if x.OptionalBool != nil || s.HasField("optionalBool") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalBool") + s.WriteBool(*x.OptionalBool) + } + if x.OptionalString != nil || s.HasField("optionalString") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalString") + s.WriteString(*x.OptionalString) + } + if x.OptionalBytes != nil || s.HasField("optionalBytes") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalBytes") + s.WriteBytes(x.OptionalBytes) + } + if x.OptionalEnum != nil || s.HasField("optionalEnum") { + s.WriteMoreIf(&wroteField) + s.WriteObjectField("optionalEnum") + x.OptionalEnum.MarshalProtoJSON(s) + } + s.WriteObjectEnd() +} + +// MarshalJSON marshals the OptionalFieldInProto3 to JSON. +func (x *OptionalFieldInProto3) MarshalJSON() ([]byte, error) { + return json.DefaultMarshalerConfig.Marshal(x) +} + +// UnmarshalProtoJSON unmarshals the OptionalFieldInProto3 message from JSON. +func (x *OptionalFieldInProto3) UnmarshalProtoJSON(s *json.UnmarshalState) { + if s.ReadNil() { + return + } + s.ReadObject(func(key string) { + switch key { + default: + s.Skip() // ignore unknown field + case "optional_int32", "optionalInt32": + s.AddField("optional_int32") + if s.ReadNil() { + x.OptionalInt32 = nil + return + } + t := s.ReadInt32() + x.OptionalInt32 = &t + case "optional_int64", "optionalInt64": + s.AddField("optional_int64") + if s.ReadNil() { + x.OptionalInt64 = nil + return + } + t := s.ReadInt64() + x.OptionalInt64 = &t + case "optional_uint32", "optionalUint32": + s.AddField("optional_uint32") + if s.ReadNil() { + x.OptionalUint32 = nil + return + } + t := s.ReadUint32() + x.OptionalUint32 = &t + case "optional_uint64", "optionalUint64": + s.AddField("optional_uint64") + if s.ReadNil() { + x.OptionalUint64 = nil + return + } + t := s.ReadUint64() + x.OptionalUint64 = &t + case "optional_sint32", "optionalSint32": + s.AddField("optional_sint32") + if s.ReadNil() { + x.OptionalSint32 = nil + return + } + t := s.ReadInt32() + x.OptionalSint32 = &t + case "optional_sint64", "optionalSint64": + s.AddField("optional_sint64") + if s.ReadNil() { + x.OptionalSint64 = nil + return + } + t := s.ReadInt64() + x.OptionalSint64 = &t + case "optional_fixed32", "optionalFixed32": + s.AddField("optional_fixed32") + if s.ReadNil() { + x.OptionalFixed32 = nil + return + } + t := s.ReadUint32() + x.OptionalFixed32 = &t + case "optional_fixed64", "optionalFixed64": + s.AddField("optional_fixed64") + if s.ReadNil() { + x.OptionalFixed64 = nil + return + } + t := s.ReadUint64() + x.OptionalFixed64 = &t + case "optional_sfixed32", "optionalSfixed32": + s.AddField("optional_sfixed32") + if s.ReadNil() { + x.OptionalSfixed32 = nil + return + } + t := s.ReadInt32() + x.OptionalSfixed32 = &t + case "optional_sfixed64", "optionalSfixed64": + s.AddField("optional_sfixed64") + if s.ReadNil() { + x.OptionalSfixed64 = nil + return + } + t := s.ReadInt64() + x.OptionalSfixed64 = &t + case "optional_float", "optionalFloat": + s.AddField("optional_float") + if s.ReadNil() { + x.OptionalFloat = nil + return + } + t := s.ReadFloat32() + x.OptionalFloat = &t + case "optional_double", "optionalDouble": + s.AddField("optional_double") + if s.ReadNil() { + x.OptionalDouble = nil + return + } + t := s.ReadFloat64() + x.OptionalDouble = &t + case "optional_bool", "optionalBool": + s.AddField("optional_bool") + if s.ReadNil() { + x.OptionalBool = nil + return + } + t := s.ReadBool() + x.OptionalBool = &t + case "optional_string", "optionalString": + s.AddField("optional_string") + if s.ReadNil() { + x.OptionalString = nil + return + } + t := s.ReadString() + x.OptionalString = &t + case "optional_bytes", "optionalBytes": + s.AddField("optional_bytes") + if s.ReadNil() { + x.OptionalBytes = nil + return + } + x.OptionalBytes = s.ReadBytes() + case "optional_enum", "optionalEnum": + s.AddField("optional_enum") + if s.ReadNil() { + x.OptionalEnum = nil + return + } + x.OptionalEnum.UnmarshalProtoJSON(s) + } + }) +} + +// UnmarshalJSON unmarshals the OptionalFieldInProto3 from JSON. +func (x *OptionalFieldInProto3) UnmarshalJSON(b []byte) error { + return json.DefaultUnmarshalerConfig.Unmarshal(b, x) +} func (m *OptionalFieldInProto3) MarshalVT() (dAtA []byte, err error) { if m == nil {