Skip to content

Commit

Permalink
Support simple generic types
Browse files Browse the repository at this point in the history
Signed-off-by: Mikhail Sherstennikov <[email protected]>
  • Loading branch information
shermike committed Jun 30, 2024
1 parent e6b840a commit 05d4f47
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 19 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
eth2.0-spec-tests
*.tar.gz
*.tar.gz
.idea
78 changes: 60 additions & 18 deletions sszgen/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,12 +518,13 @@ func appendObjSignature(str string, v *Value) string {
}

type astStruct struct {
name string
obj *ast.StructType
packName string
typ ast.Expr
implFunc bool
isRef bool
name string
obj *ast.StructType
packName string
typ ast.Expr
implFunc bool
isRef bool
paramTypes map[string]int
}

func (a *astStruct) isAlias() bool {
Expand Down Expand Up @@ -565,6 +566,14 @@ func decodeASTStruct(file *ast.File) *astResult {
if ok {
// type is a struct
obj.obj = structType

// process the type params if there is any
if typeSpec.TypeParams != nil && len(typeSpec.TypeParams.List) > 0 {
obj.paramTypes = map[string]int{}
for i, typ := range typeSpec.TypeParams.List {
obj.paramTypes[typ.Names[0].Name] = i
}
}
} else {
if _, ok := typeSpec.Type.(*ast.InterfaceType); !ok {
// type is an alias (skip interfaces)
Expand Down Expand Up @@ -922,8 +931,8 @@ func (e *env) parseASTStructType(name string) (*Value, error) {

visited := map[string]struct{}{}

var getFields func(subName string) ([]*ast.Field, error)
getFields = func(subName string) ([]*ast.Field, error) {
var getFields func(subName string, genericTypes []string) ([]*ast.Field, error)
getFields = func(subName string, genericTypes []string) ([]*ast.Field, error) {
if _, ok := visited[subName]; ok {
return nil, fmt.Errorf("loop in embed types %s", subName)
}
Expand All @@ -946,24 +955,57 @@ func (e *env) parseASTStructType(name string) (*Value, error) {
// skip protobuf methods
continue
}
if itype, ok := f.Type.(*ast.Ident); ok {
if genericIndex, ok := item.paramTypes[itype.Name]; ok {
f.Type.(*ast.Ident).Name = genericTypes[genericIndex]
}
}

fields = append(fields, f)
} else if len(f.Names) == 0 {
// embed item in the same package, resolve it recursively
ident, ok := f.Type.(*ast.Ident)
if !ok {
return nil, fmt.Errorf("embed type expects a typed object in same package but %s found", reflect.TypeOf(f.Type))
}
subFields, err := getFields(ident.Name)
if err != nil {
return nil, err
switch tp := f.Type.(type) {
case *ast.Ident:
// embed item in the same package, resolve it recursively
subFields, err := getFields(tp.Name, nil)
if err != nil {
return nil, err
}
fields = append(fields, subFields...)
case *ast.IndexExpr:
// embed item is a generic type with one parameter
genericType, ok := tp.Index.(*ast.Ident)
if !ok {
return nil, fmt.Errorf("embed type expects a typed object but %s found", reflect.TypeOf(tp.Index))
}
subFields, err := getFields(tp.X.(*ast.Ident).Name, []string{genericType.Name})
if err != nil {
return nil, err
}
fields = append(fields, subFields...)
case *ast.IndexListExpr:
// embed item is a generic type with two or more parameters
var typeList []string
for _, index := range tp.Indices {
genericType, ok := index.(*ast.Ident)
if !ok {
return nil, fmt.Errorf("embed type expects a typed object but %s found", reflect.TypeOf(index))
}
typeList = append(typeList, genericType.Name)
}
subFields, err := getFields(tp.X.(*ast.Ident).Name, typeList)
if err != nil {
return nil, err
}
fields = append(fields, subFields...)
default:
return nil, fmt.Errorf("unknwon embed type %s found", reflect.TypeOf(f.Type))
}
fields = append(fields, subFields...)
}
}
return fields, nil
}

fields, err := getFields(name)
fields, err := getFields(name, nil)
if err != nil {
return nil, err
}
Expand Down
28 changes: 28 additions & 0 deletions sszgen/testcases/generics.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package testcases

//go:generate go run ../main.go --path generics.go --objs Test1,Test2

type Generic[T comparable] struct {
Value T
}

type Wrapper struct {
Generic[uint64]
}

type Test1 struct {
G Wrapper
}

type Generic2[T any, F any] struct {
Value1 T
Value2 F
}

type Wrapper2 struct {
Generic2[uint64, uint16]
}

type Test2 struct {
G Wrapper2
}
Loading

0 comments on commit 05d4f47

Please sign in to comment.