From c7e9def121fac51a5c37c5db772bb2536dd5040d Mon Sep 17 00:00:00 2001 From: Mostafa Date: Wed, 19 Jun 2024 02:21:31 +0800 Subject: [PATCH] display all maps as their own key/value --- template.go | 39 ++++++++++++++++++++++++++++++++++++++- template_test.go | 6 +++--- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/template.go b/template.go index 268a242..b8b5224 100644 --- a/template.go +++ b/template.go @@ -6,9 +6,9 @@ import ( "sort" "strings" + "github.com/golang/protobuf/protoc-gen-go/descriptor" "github.com/pactus-project/protoc-gen-doc/extensions" "github.com/pseudomuto/protokit" - "github.com/golang/protobuf/protoc-gen-go/descriptor" ) // Template is a type for encapsulating all the parsed files, messages, fields, enums, services, extensions, etc. into @@ -63,6 +63,33 @@ func NewTemplate(descs []*protokit.FileDescriptor) *Template { addFromMessage(m) } + /// Post processing the messages + for _, msg := range file.Messages { + for _, f := range msg.Fields { + if f.IsMap { + index, msg := getMessageByName(&file.Messages, f.Type) + if msg == nil || len(msg.Fields) != 2 { + panic(fmt.Sprintf("unable to find key/va;ue for %s", f.Name)) + } + + keyField := msg.Fields[0] + valueField := msg.Fields[1] + if keyField.Name != "key" { + panic(fmt.Sprintf("expected map entry's first field to be 'key', not '%s'", keyField.Name)) + } + if valueField.Name != "value" { + panic(fmt.Sprintf("expected map entry's first field to be 'value', not '%s'", valueField.Name)) + } + typeName := fmt.Sprintf("map<%s, %s>", keyField.Type, valueField.Type) + f.Type = typeName + f.FullType = typeName + f.LongType = typeName + + file.Messages = append(file.Messages[:index], file.Messages[index+1:]...) + } + } + } + for _, s := range f.Services { file.Services = append(file.Services, parseService(s)) } @@ -78,6 +105,16 @@ func NewTemplate(descs []*protokit.FileDescriptor) *Template { return &Template{Files: files, Scalars: makeScalars()} } +func getMessageByName(orderedMessages *orderedMessages, name string) (int, *Message) { + for index, msg := range *orderedMessages { + if msg.Name == name { + return index, msg + } + } + + return -1, nil +} + func makeScalars() []*ScalarValue { var scalars []*ScalarValue json.Unmarshal(scalarsJSON, &scalars) diff --git a/template_test.go b/template_test.go index 4f6825e..90b5cc1 100644 --- a/template_test.go +++ b/template_test.go @@ -278,9 +278,9 @@ func TestFieldProperties(t *testing.T) { field = findField("properties", findMessage("Vehicle", vehicleFile)) require.Equal(t, "properties", field.Name) require.Equal(t, "repeated", field.Label) - require.Equal(t, "PropertiesEntry", field.Type) - require.Equal(t, "Vehicle.PropertiesEntry", field.LongType) - require.Equal(t, "com.example.Vehicle.PropertiesEntry", field.FullType) + require.Equal(t, "map", field.Type) + require.Equal(t, "map", field.LongType) + require.Equal(t, "map", field.FullType) require.Empty(t, field.DefaultValue) require.True(t, field.IsMap) require.False(t, field.IsOneof)