Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support simple generic types #5

Merged
merged 1 commit into from
Jun 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
eth2.0-spec-tests
*.tar.gz
*.tar.gz
.idea
78 changes: 60 additions & 18 deletions sszgen/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,12 +518,13 @@ func appendObjSignature(str string, v *Value) string {
}

type astStruct struct {
name string
obj *ast.StructType
packName string
typ ast.Expr
implFunc bool
isRef bool
name string
obj *ast.StructType
packName string
typ ast.Expr
implFunc bool
isRef bool
paramTypes map[string]int
}

func (a *astStruct) isAlias() bool {
Expand Down Expand Up @@ -565,6 +566,14 @@ func decodeASTStruct(file *ast.File) *astResult {
if ok {
// type is a struct
obj.obj = structType

// process the type params if there is any
if typeSpec.TypeParams != nil && len(typeSpec.TypeParams.List) > 0 {
obj.paramTypes = map[string]int{}
for i, typ := range typeSpec.TypeParams.List {
obj.paramTypes[typ.Names[0].Name] = i
}
}
} else {
if _, ok := typeSpec.Type.(*ast.InterfaceType); !ok {
// type is an alias (skip interfaces)
Expand Down Expand Up @@ -922,8 +931,8 @@ func (e *env) parseASTStructType(name string) (*Value, error) {

visited := map[string]struct{}{}

var getFields func(subName string) ([]*ast.Field, error)
getFields = func(subName string) ([]*ast.Field, error) {
var getFields func(subName string, genericTypes []string) ([]*ast.Field, error)
getFields = func(subName string, genericTypes []string) ([]*ast.Field, error) {
if _, ok := visited[subName]; ok {
return nil, fmt.Errorf("loop in embed types %s", subName)
}
Expand All @@ -946,24 +955,57 @@ func (e *env) parseASTStructType(name string) (*Value, error) {
// skip protobuf methods
continue
}
if itype, ok := f.Type.(*ast.Ident); ok {
if genericIndex, ok := item.paramTypes[itype.Name]; ok {
f.Type.(*ast.Ident).Name = genericTypes[genericIndex]
}
}

fields = append(fields, f)
} else if len(f.Names) == 0 {
// embed item in the same package, resolve it recursively
ident, ok := f.Type.(*ast.Ident)
if !ok {
return nil, fmt.Errorf("embed type expects a typed object in same package but %s found", reflect.TypeOf(f.Type))
}
subFields, err := getFields(ident.Name)
if err != nil {
return nil, err
switch tp := f.Type.(type) {
case *ast.Ident:
// embed item in the same package, resolve it recursively
subFields, err := getFields(tp.Name, nil)
if err != nil {
return nil, err
}
fields = append(fields, subFields...)
case *ast.IndexExpr:
// embed item is a generic type with one parameter
genericType, ok := tp.Index.(*ast.Ident)
if !ok {
return nil, fmt.Errorf("embed type expects a typed object but %s found", reflect.TypeOf(tp.Index))
}
subFields, err := getFields(tp.X.(*ast.Ident).Name, []string{genericType.Name})
if err != nil {
return nil, err
}
fields = append(fields, subFields...)
case *ast.IndexListExpr:
// embed item is a generic type with two or more parameters
var typeList []string
for _, index := range tp.Indices {
genericType, ok := index.(*ast.Ident)
if !ok {
return nil, fmt.Errorf("embed type expects a typed object but %s found", reflect.TypeOf(index))
}
typeList = append(typeList, genericType.Name)
}
subFields, err := getFields(tp.X.(*ast.Ident).Name, typeList)
if err != nil {
return nil, err
}
fields = append(fields, subFields...)
default:
return nil, fmt.Errorf("unknwon embed type %s found", reflect.TypeOf(f.Type))
}
fields = append(fields, subFields...)
}
}
return fields, nil
}

fields, err := getFields(name)
fields, err := getFields(name, nil)
if err != nil {
return nil, err
}
Expand Down
28 changes: 28 additions & 0 deletions sszgen/testcases/generics.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package testcases

//go:generate go run ../main.go --path generics.go --objs Test1,Test2

type Generic[T comparable] struct {
Value T
}

type Wrapper struct {
Generic[uint64]
}

type Test1 struct {
G Wrapper
}

type Generic2[T any, F any] struct {
Value1 T
Value2 F
}

type Wrapper2 struct {
Generic2[uint64, uint16]
}

type Test2 struct {
G Wrapper2
}
Loading
Loading