Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support generic struct types #145

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module github.com/jfeliu007/goplantuml

go 1.17
go 1.18

require (
github.com/spf13/afero v1.8.2
golang.org/x/text v0.3.7 // indirect
)
require github.com/spf13/afero v1.8.2

require golang.org/x/text v0.3.7 // indirect
64 changes: 40 additions & 24 deletions parser/class_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ call the Render() function and this will return a string with the class diagram.

See github.com/jfeliu007/goplantuml/cmd/goplantuml/main.go for a command that uses this functions and outputs the text to
the console.

*/
package parser

Expand Down Expand Up @@ -221,17 +220,19 @@ func (p *ClassParser) parsePackage(node ast.Node) {
for fileName := range pack.Files {
sortedFiles = append(sortedFiles, fileName)
}

sort.Strings(sortedFiles)
for _, fileName := range sortedFiles {
if strings.HasSuffix(fileName, "_test.go") {
continue
}

if !strings.HasSuffix(fileName, "_test.go") {
f := pack.Files[fileName]
for _, d := range f.Imports {
p.parseImports(d)
}
for _, d := range f.Decls {
p.parseFileDeclarations(d)
}
f := pack.Files[fileName]
for _, d := range f.Imports {
p.parseImports(d)
}
for _, d := range f.Decls {
p.parseFileDeclarations(d)
}
}
}
Expand Down Expand Up @@ -267,7 +268,6 @@ func (p *ClassParser) parseFileDeclarations(node ast.Decl) {
}

func (p *ClassParser) handleFuncDecl(decl *ast.FuncDecl) {

if decl.Recv != nil {
if decl.Recv.List == nil {
return
Expand Down Expand Up @@ -296,24 +296,30 @@ func (p *ClassParser) handleFuncDecl(decl *ast.FuncDecl) {
}
}

func handleGenDecStructType(p *ClassParser, typeName string, c *ast.StructType) {
func handleGenDecStructType(p *ClassParser, typeName string, c *ast.StructType, typeParams *ast.FieldList) {
for _, f := range c.Fields.List {
p.getOrCreateStruct(typeName).AddField(f, p.allImports)
}

if typeParams == nil {
return
}

for _, tp := range typeParams.List {
p.getOrCreateStruct(typeName).AddTypeParam(tp)
}
}

func handleGenDecInterfaceType(p *ClassParser, typeName string, c *ast.InterfaceType) {
for _, f := range c.Methods.List {
switch t := f.Type.(type) {
case *ast.FuncType:
p.getOrCreateStruct(typeName).AddMethod(f, p.allImports)
break
case *ast.Ident:
f, _ := getFieldType(t, p.allImports)
st := p.getOrCreateStruct(typeName)
f = replacePackageConstant(f, st.PackageName)
st.AddToComposition(f)
break
}
}
}
Expand All @@ -338,7 +344,7 @@ func (p *ClassParser) processSpec(spec ast.Spec) {
switch c := v.Type.(type) {
case *ast.StructType:
declarationType = "class"
handleGenDecStructType(p, typeName, c)
handleGenDecStructType(p, typeName, c, v.TypeParams)
case *ast.InterfaceType:
declarationType = "interface"
handleGenDecInterfaceType(p, typeName, c)
Expand Down Expand Up @@ -379,7 +385,6 @@ func (p *ClassParser) processSpec(spec ast.Spec) {
p.allRenamedStructs[pack[0]][renamedClass] = pack[1]
}
}
return
}

// If this element is an array or a pointer, this function will return the type that is closer to these
Expand Down Expand Up @@ -465,7 +470,7 @@ func (p *ClassParser) renderStructures(pack string, structures map[string]*Struc
str.WriteLineWithDepth(2, aliasComplexNameComment)
str.WriteLineWithDepth(1, "}")
}
str.WriteLineWithDepth(0, fmt.Sprintf(`}`))
str.WriteLineWithDepth(0, `}`)
if p.renderingOptions.Compositions {
str.WriteLineWithDepth(0, composition.String())
}
Expand All @@ -479,7 +484,6 @@ func (p *ClassParser) renderStructures(pack string, structures map[string]*Struc
}

func (p *ClassParser) renderAliases(str *LineStringBuilder) {

aliasString := ""
if p.renderingOptions.ConnectionLabels {
aliasString = aliasOf
Expand All @@ -505,7 +509,6 @@ func (p *ClassParser) renderAliases(str *LineStringBuilder) {
}

func (p *ClassParser) renderStructure(structure *Struct, pack string, name string, str *LineStringBuilder, composition *LineStringBuilder, extends *LineStringBuilder, aggregations *LineStringBuilder) {

privateFields := &LineStringBuilder{}
publicFields := &LineStringBuilder{}
privateMethods := &LineStringBuilder{}
Expand All @@ -518,9 +521,24 @@ func (p *ClassParser) renderStructure(structure *Struct, pack string, name strin
case "alias":
sType = "<< (T, #FF7700) >> "
renderStructureType = "class"
}

types := ""
if structure.Generics.exists() {
types = "<"
for t := range structure.Generics.Types {
types += fmt.Sprintf("%s, ", t)
}
types = strings.TrimSuffix(types, ", ")
types += " constrains "
for _, n := range structure.Generics.Names {
types += fmt.Sprintf("%s, ", n)
}
types = strings.TrimSuffix(types, ", ")
types += ">"
}
str.WriteLineWithDepth(1, fmt.Sprintf(`%s %s %s {`, renderStructureType, name, sType))

str.WriteLineWithDepth(1, fmt.Sprintf(`%s %s%s %s {`, renderStructureType, name, types, sType))
p.renderStructFields(structure, privateFields, publicFields)
p.renderStructMethods(structure, privateMethods, publicMethods)
p.renderCompositions(structure, name, composition)
Expand All @@ -538,7 +556,7 @@ func (p *ClassParser) renderStructure(structure *Struct, pack string, name strin
if publicMethods.Len() > 0 {
str.WriteLineWithDepth(0, publicMethods.String())
}
str.WriteLineWithDepth(1, fmt.Sprintf(`}`))
str.WriteLineWithDepth(1, `}`)
}

func (p *ClassParser) renderCompositions(structure *Struct, name string, composition *LineStringBuilder) {
Expand All @@ -562,7 +580,6 @@ func (p *ClassParser) renderCompositions(structure *Struct, name string, composi
}

func (p *ClassParser) renderAggregations(structure *Struct, name string, aggregations *LineStringBuilder) {

aggregationMap := structure.Aggregations
if p.renderingOptions.AggregatePrivateMembers {
p.updatePrivateAggregations(structure, aggregationMap)
Expand All @@ -571,7 +588,6 @@ func (p *ClassParser) renderAggregations(structure *Struct, name string, aggrega
}

func (p *ClassParser) updatePrivateAggregations(structure *Struct, aggregationsMap map[string]struct{}) {

for agg := range structure.PrivateAggregations {
aggregationsMap[agg] = struct{}{}
}
Expand Down Expand Up @@ -600,13 +616,13 @@ func (p *ClassParser) renderAggregationMap(aggregationMap map[string]struct{}, s
}

func (p *ClassParser) getPackageName(t string, st *Struct) string {

packageName := st.PackageName
if isPrimitiveString(t) {
packageName = builtinPackageName
}
return packageName
}

func (p *ClassParser) renderExtends(structure *Struct, name string, extends *LineStringBuilder) {

orderedExtends := []string{}
Expand All @@ -628,7 +644,6 @@ func (p *ClassParser) renderExtends(structure *Struct, name string, extends *Lin
}

func (p *ClassParser) renderStructMethods(structure *Struct, privateMethods *LineStringBuilder, publicMethods *LineStringBuilder) {

for _, method := range structure.Functions {
accessModifier := "+"
if unicode.IsLower(rune(method.Name[0])) {
Expand Down Expand Up @@ -685,6 +700,7 @@ func (p *ClassParser) getOrCreateStruct(name string) *Struct {
Functions: make([]*Function, 0),
Fields: make([]*Field, 0),
Type: "",
Generics: NewGeneric(),
Composition: make(map[string]struct{}, 0),
Extends: make(map[string]struct{}, 0),
Aggregations: make(map[string]struct{}, 0),
Expand Down
9 changes: 5 additions & 4 deletions parser/class_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package parser

import (
"go/ast"
"io/ioutil"
"os"
"reflect"
"testing"
)
Expand Down Expand Up @@ -94,6 +94,7 @@ func TestGetOrCreateStruct(t *testing.T) {
Functions: make([]*Function, 0),
Fields: make([]*Field, 0),
Type: "",
Generics: NewGeneric(),
Composition: make(map[string]struct{}, 0),
Extends: make(map[string]struct{}, 0),
Aggregations: make(map[string]struct{}, 0),
Expand Down Expand Up @@ -181,7 +182,6 @@ func TestRenderStructFields(t *testing.T) {
}

func TestRenderStructures(t *testing.T) {

structMap := map[string]*Struct{
"MainClass": getTestStruct(),
}
Expand Down Expand Up @@ -296,6 +296,7 @@ func getTestStruct() *Struct {
ReturnValues: []string{"int"},
},
},
Generics: NewGeneric(),
}
}

Expand Down Expand Up @@ -563,7 +564,7 @@ func TestRender(t *testing.T) {
})

resultRender := parser.Render()
result, err := ioutil.ReadFile("../testingsupport/testingsupport.puml")
result, err := os.ReadFile("../testingsupport/testingsupport.puml")
if err != nil {
t.Errorf("TestRender: expected no errors reading testing file, got %s", err.Error())
}
Expand Down Expand Up @@ -592,7 +593,7 @@ func TestMultipleFolders(t *testing.T) {
}

resultRender := parser.Render()
result, err := ioutil.ReadFile("../testingsupport/subfolder1-2.puml")
result, err := os.ReadFile("../testingsupport/subfolder1-2.puml")
if err != nil {
t.Errorf("TestMultipleFolders: expected no errors reading testing file, got %s", err.Error())
}
Expand Down
30 changes: 16 additions & 14 deletions parser/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import (

const packageConstant = "{packageName}"

//Field can hold the name and type of any field
// Field can hold the name and type of any field
type Field struct {
Name string
Type string
FullType string
}

//Returns a string representation of the given expression if it was recognized.
//Refer to the implementation to see the different string representations.
// Returns a string representation of the given expression if it was recognized.
// Refer to the implementation to see the different string representations.
func getFieldType(exp ast.Expr, aliases map[string]string) (string, []string) {
switch v := exp.(type) {
case *ast.Ident:
Expand All @@ -40,12 +40,23 @@ func getFieldType(exp ast.Expr, aliases map[string]string) (string, []string) {
return getFuncType(v, aliases)
case *ast.Ellipsis:
return getEllipsis(v, aliases)
case *ast.IndexExpr:
return getIndexExpr(v, aliases)
case *ast.IndexListExpr:
return getIndexListExpr(v, aliases)
}
return "", []string{}
}

func getIdent(v *ast.Ident, aliases map[string]string) (string, []string) {
func getIndexExpr(v *ast.IndexExpr, aliases map[string]string) (string, []string) {
return getFieldType(v.X, aliases)
}

func getIndexListExpr(v *ast.IndexListExpr, aliases map[string]string) (string, []string) {
return getFieldType(v.X, aliases)
}

func getIdent(v *ast.Ident, aliases map[string]string) (string, []string) {
if isPrimitive(v) {
return v.Name, []string{}
}
Expand All @@ -59,7 +70,6 @@ func getArrayType(v *ast.ArrayType, aliases map[string]string) (string, []string
}

func getSelectorExp(v *ast.SelectorExpr, aliases map[string]string) (string, []string) {

packageName := v.X.(*ast.Ident).Name
if realPackageName, ok := aliases[packageName]; ok {
packageName = realPackageName
Expand All @@ -69,26 +79,22 @@ func getSelectorExp(v *ast.SelectorExpr, aliases map[string]string) (string, []s
}

func getMapType(v *ast.MapType, aliases map[string]string) (string, []string) {

t1, f1 := getFieldType(v.Key, aliases)
t2, f2 := getFieldType(v.Value, aliases)
return fmt.Sprintf("<font color=blue>map</font>[%s]%s", t1, t2), append(f1, f2...)
}

func getStarExp(v *ast.StarExpr, aliases map[string]string) (string, []string) {

t, f := getFieldType(v.X, aliases)
return fmt.Sprintf("*%s", t), f
}

func getChanType(v *ast.ChanType, aliases map[string]string) (string, []string) {

t, f := getFieldType(v.Value, aliases)
return fmt.Sprintf("<font color=blue>chan</font> %s", t), f
}

func getStructType(v *ast.StructType, aliases map[string]string) (string, []string) {

fieldList := make([]string, 0)
for _, field := range v.Fields.List {
t, _ := getFieldType(field.Type, aliases)
Expand All @@ -98,7 +104,6 @@ func getStructType(v *ast.StructType, aliases map[string]string) (string, []stri
}

func getInterfaceType(v *ast.InterfaceType, aliases map[string]string) (string, []string) {

methods := make([]string, 0)
for _, field := range v.Methods.List {
methodName := ""
Expand All @@ -112,17 +117,14 @@ func getInterfaceType(v *ast.InterfaceType, aliases map[string]string) (string,
}

func getFuncType(v *ast.FuncType, aliases map[string]string) (string, []string) {

function := getFunction(v, "", aliases, "")
params := make([]string, 0)
for _, pa := range function.Parameters {
params = append(params, pa.Type)
}
returns := ""
returnList := make([]string, 0)
for _, re := range function.ReturnValues {
returnList = append(returnList, re)
}
returnList = append(returnList, function.ReturnValues...)
if len(returnList) > 1 {
returns = fmt.Sprintf("(%s)", strings.Join(returnList, ", "))
} else {
Expand Down
Loading