diff --git a/extract/extract.go b/extract/extract.go index 85f2e0422..cee98835a 100644 --- a/extract/extract.go +++ b/extract/extract.go @@ -23,6 +23,8 @@ import ( "strconv" "strings" "text/template" + + "golang.org/x/tools/go/packages" ) const model = `// Code generated by 'yaegi extract {{.ImportPath}}'. DO NOT EDIT. @@ -141,7 +143,7 @@ type Extractor struct { Tag []string // Comma separated of build tags to be added to the created package. } -func (e *Extractor) genContent(importPath string, p *types.Package) ([]byte, error) { +func (e *Extractor) genContent(importPath string, p *types.Package, fset *token.FileSet) ([]byte, error) { prefix := "_" + importPath + "_" prefix = strings.NewReplacer("/", "_", "-", "_", ".", "_", "~", "_").Replace(prefix) @@ -201,8 +203,31 @@ func (e *Extractor) genContent(importPath string, p *types.Package) ([]byte, err val[name] = Val{pname, false} } case *types.Func: - // Skip generic functions and methods. + // Generic functions and methods must be extracted as code that + // can be interpreted, since they cannot be compiled in. if s := o.Type().(*types.Signature); s.TypeParams().Len() > 0 || s.RecvTypeParams().Len() > 0 { + scope := o.Scope() + start, end := scope.Pos(), scope.End() + ff := fset.File(start) + base := token.Pos(ff.Base()) + start -= base + end -= base + + f, err := os.Open(ff.Name()) + if err != nil { + return nil, err + } + b := make([]byte, end-start) + _, err = f.ReadAt(b, int64(start)) + if err != nil { + return nil, err + } + // only add if we have a //yaegi:add directive + if !bytes.Contains(b, []byte(`//yaegi:add`)) { + continue + } + val[name] = Val{fmt.Sprintf("interp.GenericFunc(%q)", b), false} + imports["github.com/traefik/yaegi/interp"] = true continue } val[name] = Val{pname, false} @@ -447,16 +472,47 @@ func (e *Extractor) Extract(pkgIdent, importPath string, rw io.Writer) (string, return "", err } - pkg, err := importer.ForCompiler(token.NewFileSet(), "source", nil).Import(pkgIdent) - if err != nil { - return "", err + var pkg *types.Package + isRelative := strings.HasPrefix(pkgIdent, ".") + fset := token.NewFileSet() + // If we are relative with a manual import path, we cannot use modules + // and must fall back on the standard go/importer loader. + if isRelative && importPath != "" { + pkg, err = importer.ForCompiler(fset, "source", nil).Import(pkgIdent) + if err != nil { + return "", err + } + } else { + // Otherwise, we can use the much faster x/tools/go/packages loader. + if isRelative { + // We must be in the location of the module for the loader to work correctly. + err := os.Chdir(pkgIdent) + if err != nil { + return "", err + } + // Our path must point back to ourself here. + pkgIdent = filepath.Join("..", filepath.Base(pkgIdent)) + } + // NeedsSyntax is needed for getting the scopes of generic functions. + pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedTypes | packages.NeedSyntax}, pkgIdent) + if err != nil { + return "", err + } + if len(pkgs) != 1 { + return "", fmt.Errorf("expected one package, got %d", len(pkgs)) + } + ppkg := pkgs[0] + if len(ppkg.Errors) > 0 { + return "", ppkg.Errors[0] + } + pkg = ppkg.Types + fset = ppkg.Fset } - content, err := e.genContent(ipp, pkg) + content, err := e.genContent(ipp, pkg, fset) if err != nil { return "", err } - if _, err := rw.Write(content); err != nil { return "", err } diff --git a/extract/extract_test.go b/extract/extract_test.go index a08e34faf..65880adc1 100644 --- a/extract/extract_test.go +++ b/extract/extract_test.go @@ -119,6 +119,30 @@ type _guthib_com_variadic_Variadic struct { func (W _guthib_com_variadic_Variadic) Call(method string, args ...[]interface{}) (interface{}, error) { return W.WCall(method, args...) } +`[1:], + }, + { + desc: "using relative path, function is generic", + wd: "./testdata/8/src/guthib.com/generic", + arg: "../generic", + importPath: "guthib.com/generic", + expected: ` +// Code generated by 'yaegi extract guthib.com/generic'. DO NOT EDIT. + +package generic + +import ( + "github.com/traefik/yaegi/interp" + "guthib.com/generic" + "reflect" +) + +func init() { + Symbols["guthib.com/generic/generic"] = map[string]reflect.Value{ + // function, constant and variable definitions + "Hello": reflect.ValueOf(interp.GenericFunc("func Hello[T comparable](v T) *T { //yaegi:add\n\treturn &v\n}")), + } +} `[1:], }, } diff --git a/extract/testdata/8/src/guthib.com/generic/generic.go b/extract/testdata/8/src/guthib.com/generic/generic.go new file mode 100644 index 000000000..a3f6a7e7e --- /dev/null +++ b/extract/testdata/8/src/guthib.com/generic/generic.go @@ -0,0 +1,5 @@ +package generic + +func Hello[T comparable](v T) *T { //yaegi:add + return &v +} diff --git a/extract/testdata/8/src/guthib.com/generic/go.mod b/extract/testdata/8/src/guthib.com/generic/go.mod new file mode 100644 index 000000000..8cb7dd0ac --- /dev/null +++ b/extract/testdata/8/src/guthib.com/generic/go.mod @@ -0,0 +1,4 @@ +module guthib.com/generic + +go 1.21 + diff --git a/go.mod b/go.mod index e8cc07765..c093c494f 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,10 @@ module github.com/traefik/yaegi go 1.21 + +require golang.org/x/tools v0.22.0 + +require ( + golang.org/x/mod v0.18.0 // indirect + golang.org/x/sync v0.7.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 000000000..16f2f692e --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= +golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= diff --git a/interp/cfg.go b/interp/cfg.go index 90db9a545..32e240ae8 100644 --- a/interp/cfg.go +++ b/interp/cfg.go @@ -1214,11 +1214,12 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string case c0.isType(sc): // Type conversion expression - c1 := n.child[1] + var c1 *node switch len(n.child) { case 1: err = n.cfgErrorf("missing argument in conversion to %s", c0.typ.id()) case 2: + c1 = n.child[1] err = check.conversion(c1, c0.typ) default: err = n.cfgErrorf("too many arguments in conversion to %s", c0.typ.id()) @@ -1293,7 +1294,10 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string default: // The call may be on a generic function. In that case, replace the // generic function AST by an instantiated one before going further. - if isGeneric(c0.typ) { + if c0.typ == nil { + err = c0.cfgErrorf("nil type for function call: likely generic type error") + break + } else if isGeneric(c0.typ) { fun := c0.typ.node.anc var g *node var types []*itype @@ -2553,7 +2557,18 @@ func (n *node) isType(sc *scope) bool { return true // Imported source type } case identExpr: - return sc.getType(n.ident) != nil + sym, _, found := sc.lookup(n.ident) + if found { + return sym.kind == typeSym + } + // note: in case of generic functions, the type might not exist within + // the scope where the generic function was defined, so we + // fall back on comparing the scopes: anything out of scope is assumed + // to be a type. + if n.typ == nil || n.typ.scope == nil { + return false + } + return n.typ.scope.pkgID != sc.pkgID case indexExpr: // Maybe a generic type. sym, _, ok := sc.lookup(n.child[0].ident) @@ -2879,6 +2894,9 @@ func setExec(n *node) { set(n.fnext) } } + if n.gen == nil { + n.gen = nop + } n.gen(n) } diff --git a/interp/generic.go b/interp/generic.go index e1b06d0ec..748cbcf58 100644 --- a/interp/generic.go +++ b/interp/generic.go @@ -5,6 +5,13 @@ import ( "sync/atomic" ) +// GenericFunc contains the code of a generic function. +// This is used in the `yaegi extract` command to represent generic functions +// instead of the actual value of the function, since you cannot get the +// [reflect.Value] of a generic function. This is then used to interpret the +// function when it is imported in yaegi. +type GenericFunc string + // adot produces an AST dot(1) directed acyclic graph for the given node. For debugging only. // func (n *node) adot() { n.astDot(dotWriter(n.interp.dotCmd), n.ident) } @@ -58,7 +65,7 @@ func genAST(sc *scope, root *node, types []*itype) (*node, bool, error) { case fieldList: // Node is the type parameters list of a generic function. - if root.kind == funcDecl && n.anc == root.child[2] && childPos(n) == 0 { + if root.kind == funcDecl && n.anc == root.child[2] && childPos(n) == 0 && len(types) > 0 { // Fill the types lookup table used for type substitution. for _, c := range n.child { l := len(c.child) - 1 @@ -291,11 +298,15 @@ func inferTypesFromCall(sc *scope, fun *node, args []*node) ([]*itype, error) { if err != nil { return nil, err } - lt, err := inferTypes(typ, args[i].typ) - if err != nil { - return nil, err + if i < len(args) { + lt, err := inferTypes(typ, args[i].typ) + if err != nil { + return nil, err + } + types = append(types, lt...) + } else { + types = append(types, typ) } - types = append(types, lt...) } return types, nil diff --git a/interp/generic_test.go b/interp/generic_test.go new file mode 100644 index 000000000..181653e9d --- /dev/null +++ b/interp/generic_test.go @@ -0,0 +1,200 @@ +package interp + +import ( + "reflect" + "testing" +) + +func TestGenericFuncDeclare(t *testing.T) { + i := New(Options{}) + _, err := i.Eval("func Hello[T comparable](v T) *T {\n\treturn &v\n}") + if err != nil { + t.Error(err) + } + res, err := i.Eval("Hello(3)") + if err != nil { + t.Error(err) + } + if res.Elem().Interface() != 3 { + t.Error("expected &(3), got", res) + } +} + +func TestGenericFuncBasic(t *testing.T) { + i := New(Options{}) + err := i.Use(Exports{ + "guthib.com/generic/generic": map[string]reflect.Value{ + "Hello": reflect.ValueOf(GenericFunc("func Hello[T comparable](v T) *T {\n\treturn &v\n}")), + }, + }) + if err != nil { + t.Error(err) + } + res, err := i.Eval("generic.Hello(3)") + if err != nil { + t.Error(err) + } + if res.Elem().Interface() != 3 { + t.Error("expected &(3), got", res) + } +} + +func TestGenericFuncNoDotImport(t *testing.T) { + i := New(Options{}) + err := i.Use(Exports{ + "guthib.com/generic/generic": map[string]reflect.Value{ + "Hello": reflect.ValueOf(GenericFunc("func Hello[T any](v T) { println(v) }")), + }, + }) + if err != nil { + t.Error(err) + } + _, err = i.Eval(` +import "guthib.com/generic" +func main() { generic.Hello(3) } +`) + if err != nil { + t.Error(err) + } +} + +func TestGenericFuncDotImport(t *testing.T) { + i := New(Options{}) + err := i.Use(Exports{ + "guthib.com/generic/generic": map[string]reflect.Value{ + "Hello": reflect.ValueOf(GenericFunc("func Hello[T any](v T) { println(v) }")), + }, + }) + if err != nil { + t.Error(err) + } + _, err = i.Eval(` +import . "guthib.com/generic" +func main() { Hello(3) } +`) + if err != nil { + t.Error(err) + } +} + +func TestGenericFuncComplex(t *testing.T) { + i := New(Options{}) + done := false + err := i.Use(Exports{ + "guthib.com/generic/generic": map[string]reflect.Value{ + "Do": reflect.ValueOf(func() { done = true }), + "Hello": reflect.ValueOf(GenericFunc("func Hello[T comparable, F any](v T, f func(a T) F) *T {\n\tDo(); return &v\n}")), + }, + }) + i.ImportUsed() + if err != nil { + t.Error(err) + } + res, err := i.Eval("generic.Hello[int, bool](3, func(a int) bool { return true })") + if err != nil { + t.Error(err) + } + if res.Elem().Interface() != 3 { + t.Error("expected &(3), got", res) + } + if !done { + t.Error("!done") + } +} + +func TestGenericFuncTwice(t *testing.T) { + i := New(Options{}) + err := i.Use(Exports{ + "guthib.com/generic/generic": map[string]reflect.Value{ + "Do": reflect.ValueOf(GenericFunc("func Do[T any](v T) { println(v) }")), + "Hello": reflect.ValueOf(GenericFunc("func Hello[T any](v T) { Do(v) }")), + }, + }) + i.ImportUsed() + if err != nil { + t.Error(err) + } + _, err = i.Eval(` +func main() { generic.Hello[int](3) } +`) + if err != nil { + t.Error(err) + } +} + +func TestGenericFuncInfer(t *testing.T) { + i := New(Options{}) + err := i.Use(Exports{ + "guthib.com/generic/generic": map[string]reflect.Value{ + "New": reflect.ValueOf(GenericFunc("func New[T any]() *T { return new(T) }")), + "AddAt": reflect.ValueOf(GenericFunc("func AddAt[T any](init func(n *T)) { v := New[T](); init(any(v).(*T)); println(*v) }")), + }, + }) + i.ImportUsed() + if err != nil { + t.Error(err) + } + _, err = i.Eval(` +func main() { + generic.AddAt(func(w *int) { *w = 3 }) +} +`) + if err != nil { + t.Error(err) + } +} + +type Plan struct{} + +// this one failed with valueT included in inferTypes. +func TestGenericFuncInferSecondArg(t *testing.T) { + i := New(Options{}) + err := i.Use(Exports{ + "guthib.com/generic/generic": map[string]reflect.Value{ + "Plan": reflect.ValueOf((*Plan)(nil)), + }, + }) + i.ImportUsed() + if err != nil { + t.Error(err) + } + _, err = i.Eval(` +func Add[T any](p generic.Plan, v T) { } +func main() { + Add(generic.Plan{}, []int{}) +} +`) + if err != nil { + t.Error(err) + } +} + +// this one worked fine with valueT. +func TestGenericFuncInferSecondArgLocal(t *testing.T) { + i := New(Options{}) + _, err := i.Eval(` +type Plan struct{} +func Add[T any](p Plan, v T) { } +func main() { + Add(Plan{}, []int{}) +} +`) + if err != nil { + t.Error(err) + } +} + +// this one failed without more robust arg type matching in generic.go:300. +func TestGenericFuncIgnoreError(t *testing.T) { + i := New(Options{}) + _, err := i.Eval(` +func Ignore[T any](v T, err error) T { return v } +func Make() (int, error) { return 3, nil } +func main() { + a := Ignore(Make()) +} +`) + if err != nil { + t.Error(err) + } +} diff --git a/interp/gta.go b/interp/gta.go index 28f84aee2..2aefbdf32 100644 --- a/interp/gta.go +++ b/interp/gta.go @@ -3,6 +3,7 @@ package interp import ( "path" "path/filepath" + "strings" ) // gta performs a global types analysis on the AST, registering types, @@ -241,7 +242,17 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([ typ = typ.Elem() kind = typeSym } - sc.sym[n] = &symbol{kind: kind, typ: valueTOf(typ, withScope(sc)), rval: v} + if gf, ok := v.Interface().(GenericFunc); ok { + samePath := strings.HasSuffix(ipath, importPath) + if !samePath { + if _, cerr := interp.Compile(string(gf)); cerr != nil { + err = cerr + return false + } + } + } else { + sc.sym[n] = &symbol{kind: kind, typ: valueTOf(typ, withScope(sc)), rval: v} + } } default: // import symbols in package namespace if name == "" { diff --git a/interp/run.go b/interp/run.go index b8caa9086..02e4b4db3 100644 --- a/interp/run.go +++ b/interp/run.go @@ -428,7 +428,9 @@ func typeAssert(n *node, withResult, withOk bool) { v = val.value leftType = val.node.typ.rtype } else { - v = v.Elem() + if v.IsValid() && !canAssertTypes(v.Type(), rtype) { + v = v.Elem() + } leftType = v.Type() ok = true } diff --git a/interp/type.go b/interp/type.go index b3220455b..26922577d 100644 --- a/interp/type.go +++ b/interp/type.go @@ -2397,7 +2397,10 @@ func isEmptyInterface(t *itype) bool { } func isGeneric(t *itype) bool { - return t.cat == funcT && t.node != nil && len(t.node.child) > 0 && len(t.node.child[0].child) > 0 + if t.cat != funcT || t.node == nil || len(t.node.child) == 0 || t.node.child[0] == nil { + return false + } + return len(t.node.child[0].child) > 0 } func isNamedFuncSrc(t *itype) bool { diff --git a/interp/typecheck.go b/interp/typecheck.go index 3ebcd2a3c..168216014 100644 --- a/interp/typecheck.go +++ b/interp/typecheck.go @@ -967,6 +967,10 @@ func (check typecheck) arguments(n *node, child []*node, fun *node, ellipsis boo } } + if fun.typ == nil { + err := fun.cfgErrorf("typecheck arguments: nil function type: likely a syntax error above this point") + return err + } var cnt int for i, param := range params { ellip := i == l-1 && ellipsis diff --git a/interp/use.go b/interp/use.go index e4a6b6224..14e99047f 100644 --- a/interp/use.go +++ b/interp/use.go @@ -137,6 +137,18 @@ func (interp *Interpreter) Use(values Exports) error { } } + for k, v := range values { + packageName := path.Base(k) + for _, sym := range v { + if gf, ok := sym.Interface().(GenericFunc); ok { + str := fmt.Sprintf("package %s\nimport . %q\n%s", packageName, path.Dir(k), string(gf)) + if _, err := interp.Compile(str); err != nil { + return err + } + } + } + } + // Checks if input values correspond to stdlib packages by looking for one // well known stdlib package path. if _, ok := values["fmt/fmt"]; ok {