Skip to content

Commit

Permalink
chore: add optional to the go2proto test
Browse files Browse the repository at this point in the history
  • Loading branch information
jvmakine committed Jan 21, 2025
1 parent 4d0e5dc commit 22610cd
Show file tree
Hide file tree
Showing 8 changed files with 1,394 additions and 676 deletions.
169 changes: 161 additions & 8 deletions cmd/go2proto/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,22 @@ type Field struct {
OriginType string // The original type of the field, eg. int, string, float32, etc.
ProtoType string // The type of the field in the generated .proto file.
ProtoGoType string // The type of the field in the generated Go protobuf code. eg. int -> int64.
Optional bool
Repeated bool
Pointer bool
Import string // required import to create this type

Kind Kind
Optional bool
OptionalWrapper bool // optional as alecthomas/types/optional.Option
Repeated bool
Pointer bool

Import string // required import to create this type
Converter *TypeConverter
Kind Kind
}

type TypeConverter struct {
// FromProto returns a result[*Type] containing a pointer to the Go type and a potential error
FromProto func(variable string) string
// ToProto returns a pointer to the proto type.
ToProto func(variable string) string
}

var reservedWords = map[string]string{
Expand Down Expand Up @@ -499,7 +509,21 @@ func extract(config Config, pkg *PkgRefs) (File, []string, error) {
return state.Dest, state.GoImports, nil
}

func isOptional(named *types.Named) bool {
path := named.Origin().Obj().Pkg().Path()
name := named.Origin().Obj().Name()
return path == "github.com/alecthomas/types/optional" && name == "Option"
}

func (s *State) extractDecl(obj types.Object, named *types.Named) error {
if isOptional(named) {
u := named.TypeArgs().At(0)
if nt, ok := u.(*types.Named); ok {
return s.extractDecl(nt.Obj(), nt)
}
return nil
}

if named.TypeParams() != nil {
return genErrorf(obj.Pos(), "generic types are not supported")
}
Expand Down Expand Up @@ -582,27 +606,156 @@ func (s *State) populateFields(decl *Message, n *types.Named) error {
field := &Field{
Name: rf.Name(),
}
if err := s.applyFieldType(rf.Type(), field); err != nil {
t := rf.Type()
if nt, ok := t.(*types.Named); ok && isOptional(nt) {
field.Optional = true
field.OptionalWrapper = true
t = nt.TypeArgs().At(0)
}
if err := s.applyFieldType(t, field); err != nil {
return fmt.Errorf("%s: %w", rf.Name(), err)
}
field.ID = tag.ID
field.Optional = tag.Optional
field.Optional = tag.Optional || field.Optional
if field.Optional && field.Repeated {
return genErrorf(n.Obj().Pos(), "%s: repeated optional fields are not supported", rf.Name())
}
if nt, ok := rf.Type().(*types.Named); ok {
if nt, ok := t.(*types.Named); ok {
if err := s.extractDecl(rf, nt); err != nil {
return fmt.Errorf("%s: %w", rf.Name(), err)
}
}
if field.Kind == KindUnspecified {
field.Kind = s.Dest.KindOf(rf.Type(), field.OriginType)
}

s.populateConverters(field)

decl.Fields = append(decl.Fields, field)
}
return errf()
}

func (f *Field) ToProto() string {
// TODO: Refactor types away from here
if f.Optional {
if f.OptionalWrapper {
if f.ProtoType == "google.protobuf.Timestamp" {
return "setNil(timestamppb.New(orZero(optionalOrNil(x." + f.Name + "))), optionalOrNil(x." + f.Name + "))"
} else if f.ProtoType == "google.protobuf.Duration" {
return "setNil(durationpb.New(orZero(optionalOrNil(x." + f.Name + "))), optionalOrNil(x." + f.Name + "))"
} else if f.Kind == KindEnum || f.Kind == KindMessage {
return "optionalOrNil(x." + f.Name + ").ToProto()"
}
return "setNil(ptr(" + f.Converter.ToProto("orZero(optionalOrNil(x."+f.Name+"))") + "), optionalOrNil(x." + f.Name + "))"
} else if f.Kind == KindEnum || f.Kind == KindMessage || f.Kind == KindSumType {
return f.Converter.ToProto("x." + f.Name)
} else if f.Pointer {
return "setNil(ptr(" + f.Converter.ToProto("orZero(x."+f.Name+")") + "), x." + f.Name + ")"
} else if f.ProtoType == "google.protobuf.Timestamp" || f.ProtoType == "google.protobuf.Duration" {
return f.Converter.ToProto("x." + f.Name)
}
return "ptr(" + f.Converter.ToProto("x."+f.Name) + ")"
} else if f.Repeated {
if f.Pointer {
return "sliceMap(x." + f.Name + ", func(v *" + f.OriginType + ") *destpb." + f.ProtoGoType + " { return " + f.Converter.ToProto("v") + " })"
} else if f.Kind == KindMessage || f.Kind == KindSumType {
return "sliceMap(x." + f.Name + ", func(v " + f.OriginType + ") *destpb." + f.ProtoGoType + " { return " + f.Converter.ToProto("v") + " })"
}
return "sliceMap(x." + f.Name + ", func(v " + f.OriginType + ") " + f.ProtoGoType + " { return " + f.Converter.ToProto("v") + " })"
}

return f.Converter.ToProto("x." + f.Name)
}

func (f *Field) FromProto() string {
// input are result[*T]
input := f.Converter.FromProto("v." + f.EscapedName())
if f.Optional {
if f.OptionalWrapper {
return "optionalR(" + input + ")"
}
if !f.Pointer {
return "orZeroR(" + input + ")"
}
return input
} else if f.Repeated {
if !f.Pointer {
if f.Kind == KindMessage || f.Kind == KindSumType {
return "sliceMapR(v." + f.EscapedName() + ", func(v *destpb." + f.ProtoGoType + ") result[" + f.OriginType + "] { return orZeroR(" + f.Converter.FromProto("v") + ") })"
}
return "sliceMapR(v." + f.EscapedName() + ", func(v " + f.ProtoGoType + ") result[" + f.OriginType + "] { return orZeroR(" + f.Converter.FromProto("v") + ") })"
}
return "sliceMapR(v." + f.EscapedName() + ", func(v *destpb." + f.ProtoGoType + ") result[*" + f.OriginType + "] { return " + f.Converter.FromProto("v") + " })"
} else if !f.Pointer {
return "orZeroR(" + f.Converter.FromProto("v."+f.EscapedName()) + ")"
}
return input
}

func (s *State) populateConverters(field *Field) {
if field.ProtoType == "google.protobuf.Timestamp" {
field.Converter = &TypeConverter{
FromProto: func(v string) string { return fmt.Sprintf("toResult(setNil(ptr(%s.AsTime()), %s), nil)", v, v) },
ToProto: func(v string) string { return fmt.Sprintf("timestamppb.New(%s)", v) },
}
} else if field.ProtoType == "google.protobuf.Duration" {
field.Converter = &TypeConverter{
FromProto: func(v string) string { return fmt.Sprintf("toResult(setNil(ptr(%s.AsDuration()), %s), nil)", v, v) },
ToProto: func(v string) string { return fmt.Sprintf("durationpb.New(%s)", v) },
}
} else if field.Kind == KindMessage {
field.Converter = &TypeConverter{
FromProto: func(v string) string { return fmt.Sprintf("toResult(%sFromProto(%s))", field.OriginType, v) },
ToProto: func(v string) string { return fmt.Sprintf("%s.ToProto()", v) },
}
} else if field.Kind == KindEnum {
field.Converter = &TypeConverter{
FromProto: func(v string) string { return fmt.Sprintf("ptrR(toResult(%sFromProto(%s)))", field.OriginType, v) },
ToProto: func(v string) string { return fmt.Sprintf("%s.ToProto()", v) },
}
} else if field.Kind == KindTextMarshaler {
field.Converter = &TypeConverter{
FromProto: func(v string) string {
if field.Pointer {
return fmt.Sprintf("toResult(unmarshallText([]byte(%s), out.%s))", v, field.Name)
}
return fmt.Sprintf("toResult(unmarshallText([]byte(%s), &out.%s))", v, field.Name)
},
ToProto: func(v string) string { return fmt.Sprintf("string(protoMust(%s.MarshalText()))", v) },
}
} else if field.Kind == KindBinaryMarshaler {
field.Converter = &TypeConverter{
FromProto: func(v string) string {
if field.Pointer {
return fmt.Sprintf("toResult(unmarshallBinary(%s, out.%s))", v, field.Name)
}
return fmt.Sprintf("toResult(unmarshallBinary(%s, &out.%s))", v, field.Name)
},
ToProto: func(v string) string { return fmt.Sprintf("protoMust(%s.MarshalBinary())", v) },
}
} else if field.Kind == KindSumType {
field.Converter = &TypeConverter{
FromProto: func(v string) string { return fmt.Sprintf("ptrR(toResult(%sFromProto(%s)))", field.OriginType, v) },
ToProto: func(v string) string { return fmt.Sprintf("%sToProto(%s)", field.OriginType, v) },
}
} else {
if field.Pointer || field.Optional {
field.Converter = &TypeConverter{
FromProto: func(v string) string {
return fmt.Sprintf("toResult(setNil(ptr(%s(orZero(%s))), %s), nil)", field.OriginType, v, v)
},
ToProto: func(v string) string { return fmt.Sprintf("%s(%s)", field.ProtoGoType, v) },
}
} else {
field.Converter = &TypeConverter{
FromProto: func(v string) string { return fmt.Sprintf("toResult(ptr(%s(%s)), nil)", field.OriginType, v) },
ToProto: func(v string) string { return fmt.Sprintf("%s(%s)", field.ProtoGoType, v) },
}
}
}
}

func (s *State) extractSumType(obj types.Object, i *types.Interface) error {
sumTypeName := obj.Name()
if _, ok := s.Seen[sumTypeName]; ok {
Expand Down
Loading

0 comments on commit 22610cd

Please sign in to comment.