diff --git a/.gitignore b/.gitignore index 387c399..d20ffe7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ eth2.0-spec-tests -*.tar.gz \ No newline at end of file +*.tar.gz +.idea \ No newline at end of file diff --git a/sszgen/generator/generator.go b/sszgen/generator/generator.go index 57267e7..01fc344 100644 --- a/sszgen/generator/generator.go +++ b/sszgen/generator/generator.go @@ -518,12 +518,13 @@ func appendObjSignature(str string, v *Value) string { } type astStruct struct { - name string - obj *ast.StructType - packName string - typ ast.Expr - implFunc bool - isRef bool + name string + obj *ast.StructType + packName string + typ ast.Expr + implFunc bool + isRef bool + paramTypes map[string]int } func (a *astStruct) isAlias() bool { @@ -565,6 +566,14 @@ func decodeASTStruct(file *ast.File) *astResult { if ok { // type is a struct obj.obj = structType + + // process the type params if there is any + if typeSpec.TypeParams != nil && len(typeSpec.TypeParams.List) > 0 { + obj.paramTypes = map[string]int{} + for i, typ := range typeSpec.TypeParams.List { + obj.paramTypes[typ.Names[0].Name] = i + } + } } else { if _, ok := typeSpec.Type.(*ast.InterfaceType); !ok { // type is an alias (skip interfaces) @@ -922,8 +931,8 @@ func (e *env) parseASTStructType(name string) (*Value, error) { visited := map[string]struct{}{} - var getFields func(subName string) ([]*ast.Field, error) - getFields = func(subName string) ([]*ast.Field, error) { + var getFields func(subName string, genericTypes []string) ([]*ast.Field, error) + getFields = func(subName string, genericTypes []string) ([]*ast.Field, error) { if _, ok := visited[subName]; ok { return nil, fmt.Errorf("loop in embed types %s", subName) } @@ -946,24 +955,57 @@ func (e *env) parseASTStructType(name string) (*Value, error) { // skip protobuf methods continue } + if itype, ok := f.Type.(*ast.Ident); ok { + if genericIndex, ok := item.paramTypes[itype.Name]; ok { + f.Type.(*ast.Ident).Name = genericTypes[genericIndex] + } + } + fields = append(fields, f) } else if len(f.Names) == 0 { - // embed item in the same package, resolve it recursively - ident, ok := f.Type.(*ast.Ident) - if !ok { - return nil, fmt.Errorf("embed type expects a typed object in same package but %s found", reflect.TypeOf(f.Type)) - } - subFields, err := getFields(ident.Name) - if err != nil { - return nil, err + switch tp := f.Type.(type) { + case *ast.Ident: + // embed item in the same package, resolve it recursively + subFields, err := getFields(tp.Name, nil) + if err != nil { + return nil, err + } + fields = append(fields, subFields...) + case *ast.IndexExpr: + // embed item is a generic type with one parameter + genericType, ok := tp.Index.(*ast.Ident) + if !ok { + return nil, fmt.Errorf("embed type expects a typed object but %s found", reflect.TypeOf(tp.Index)) + } + subFields, err := getFields(tp.X.(*ast.Ident).Name, []string{genericType.Name}) + if err != nil { + return nil, err + } + fields = append(fields, subFields...) + case *ast.IndexListExpr: + // embed item is a generic type with two or more parameters + var typeList []string + for _, index := range tp.Indices { + genericType, ok := index.(*ast.Ident) + if !ok { + return nil, fmt.Errorf("embed type expects a typed object but %s found", reflect.TypeOf(index)) + } + typeList = append(typeList, genericType.Name) + } + subFields, err := getFields(tp.X.(*ast.Ident).Name, typeList) + if err != nil { + return nil, err + } + fields = append(fields, subFields...) + default: + return nil, fmt.Errorf("unknwon embed type %s found", reflect.TypeOf(f.Type)) } - fields = append(fields, subFields...) } } return fields, nil } - fields, err := getFields(name) + fields, err := getFields(name, nil) if err != nil { return nil, err } diff --git a/sszgen/testcases/generics.go b/sszgen/testcases/generics.go new file mode 100644 index 0000000..25b33eb --- /dev/null +++ b/sszgen/testcases/generics.go @@ -0,0 +1,28 @@ +package testcases + +//go:generate go run ../main.go --path generics.go --objs Test1,Test2 + +type Generic[T comparable] struct { + Value T +} + +type Wrapper struct { + Generic[uint64] +} + +type Test1 struct { + G Wrapper +} + +type Generic2[T any, F any] struct { + Value1 T + Value2 F +} + +type Wrapper2 struct { + Generic2[uint64, uint16] +} + +type Test2 struct { + G Wrapper2 +} diff --git a/sszgen/testcases/generics_encoding.go b/sszgen/testcases/generics_encoding.go new file mode 100644 index 0000000..136eae4 --- /dev/null +++ b/sszgen/testcases/generics_encoding.go @@ -0,0 +1,253 @@ +// Code generated by fastssz. DO NOT EDIT. +// Hash: 279ec70eb687a0d8b56126a61c51b8f239f2cd00422ae048f8c0977d2fc1febc +// Version: 0.1.3 +package testcases + +import ( + ssz "github.com/NilFoundation/fastssz" +) + +// MarshalSSZ ssz marshals the Wrapper object +func (w *Wrapper) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(w) +} + +// MarshalSSZTo ssz marshals the Wrapper object to a target array +func (w *Wrapper) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + + // Field (0) 'Value' + dst = ssz.MarshalUint64(dst, w.Value) + + return +} + +// UnmarshalSSZ ssz unmarshals the Wrapper object +func (w *Wrapper) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size != 8 { + return ssz.ErrSize + } + + // Field (0) 'Value' + w.Value = ssz.UnmarshallUint64(buf[0:8]) + + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Wrapper object +func (w *Wrapper) SizeSSZ() (size int) { + size = 8 + return +} + +// HashTreeRoot ssz hashes the Wrapper object +func (w *Wrapper) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(w) +} + +// HashTreeRootWith ssz hashes the Wrapper object with a hasher +func (w *Wrapper) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'Value' + hh.PutUint64(w.Value) + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Wrapper object +func (w *Wrapper) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(w) +} + +// MarshalSSZ ssz marshals the Test1 object +func (t *Test1) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(t) +} + +// MarshalSSZTo ssz marshals the Test1 object to a target array +func (t *Test1) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + + // Field (0) 'G' + if dst, err = t.G.MarshalSSZTo(dst); err != nil { + return + } + + return +} + +// UnmarshalSSZ ssz unmarshals the Test1 object +func (t *Test1) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size != 8 { + return ssz.ErrSize + } + + // Field (0) 'G' + if err = t.G.UnmarshalSSZ(buf[0:8]); err != nil { + return err + } + + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Test1 object +func (t *Test1) SizeSSZ() (size int) { + size = 8 + return +} + +// HashTreeRoot ssz hashes the Test1 object +func (t *Test1) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(t) +} + +// HashTreeRootWith ssz hashes the Test1 object with a hasher +func (t *Test1) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'G' + if err = t.G.HashTreeRootWith(hh); err != nil { + return + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Test1 object +func (t *Test1) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(t) +} + +// MarshalSSZ ssz marshals the Wrapper2 object +func (w *Wrapper2) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(w) +} + +// MarshalSSZTo ssz marshals the Wrapper2 object to a target array +func (w *Wrapper2) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + + // Field (0) 'Value1' + dst = ssz.MarshalUint64(dst, w.Value1) + + // Field (1) 'Value2' + dst = ssz.MarshalUint16(dst, w.Value2) + + return +} + +// UnmarshalSSZ ssz unmarshals the Wrapper2 object +func (w *Wrapper2) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size != 10 { + return ssz.ErrSize + } + + // Field (0) 'Value1' + w.Value1 = ssz.UnmarshallUint64(buf[0:8]) + + // Field (1) 'Value2' + w.Value2 = ssz.UnmarshallUint16(buf[8:10]) + + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Wrapper2 object +func (w *Wrapper2) SizeSSZ() (size int) { + size = 10 + return +} + +// HashTreeRoot ssz hashes the Wrapper2 object +func (w *Wrapper2) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(w) +} + +// HashTreeRootWith ssz hashes the Wrapper2 object with a hasher +func (w *Wrapper2) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'Value1' + hh.PutUint64(w.Value1) + + // Field (1) 'Value2' + hh.PutUint16(w.Value2) + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Wrapper2 object +func (w *Wrapper2) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(w) +} + +// MarshalSSZ ssz marshals the Test2 object +func (t *Test2) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(t) +} + +// MarshalSSZTo ssz marshals the Test2 object to a target array +func (t *Test2) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + + // Field (0) 'G' + if dst, err = t.G.MarshalSSZTo(dst); err != nil { + return + } + + return +} + +// UnmarshalSSZ ssz unmarshals the Test2 object +func (t *Test2) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size != 10 { + return ssz.ErrSize + } + + // Field (0) 'G' + if err = t.G.UnmarshalSSZ(buf[0:10]); err != nil { + return err + } + + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Test2 object +func (t *Test2) SizeSSZ() (size int) { + size = 10 + return +} + +// HashTreeRoot ssz hashes the Test2 object +func (t *Test2) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(t) +} + +// HashTreeRootWith ssz hashes the Test2 object with a hasher +func (t *Test2) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'G' + if err = t.G.HashTreeRootWith(hh); err != nil { + return + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Test2 object +func (t *Test2) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(t) +}