diff --git a/instrumenter.go b/instrumenter.go index 73efd6c..f26d73c 100644 --- a/instrumenter.go +++ b/instrumenter.go @@ -11,11 +11,11 @@ import ( "go/token" "go/types" "os" + "os/exec" "path/filepath" "reflect" "runtime" "sort" - "strconv" "strings" ) @@ -106,7 +106,7 @@ func (i *instrumenter) instrument(srcDir, singleFile, dstDir string) bool { i.instrumentFile(name, file, dstDir) }) } - i.writeGobcoFiles(dstDir, pkgs) + i.writeGobcoFiles(srcDir, dstDir, pkgs) return true } @@ -623,7 +623,7 @@ var fixedTemplate string //go:embed templates/gobco_no_testmain_test.go var noTestMainTemplate string -func (i *instrumenter) writeGobcoFiles(tmpDir string, pkgs []*ast.Package) { +func (i *instrumenter) writeGobcoFiles(srcDir, tmpDir string, pkgs []*ast.Package) { pkgname := pkgs[0].Name fixPkgname := func(str string) string { str = strings.TrimPrefix(str, "//go:build ignore\n// +build ignore\n\n") @@ -636,7 +636,7 @@ func (i *instrumenter) writeGobcoFiles(tmpDir string, pkgs []*ast.Package) { writeFile(filepath.Join(tmpDir, "gobco_no_testmain_test.go"), fixPkgname(noTestMainTemplate)) } - i.writeGobcoBlackBox(pkgs, tmpDir) + i.writeGobcoBlackBox(pkgs, srcDir, tmpDir) } func (i *instrumenter) writeGobcoGo(filename, pkgname string) { @@ -664,32 +664,14 @@ func (i *instrumenter) writeGobcoGo(filename, pkgname string) { // writeGobcoBlackBox makes the function 'GobcoCover' available // to black box tests (those in 'package x_test' instead of 'package x') // by delegating to the function of the same name in the main package. -func (i *instrumenter) writeGobcoBlackBox(pkgs []*ast.Package, dstDir string) { +func (i *instrumenter) writeGobcoBlackBox(pkgs []*ast.Package, srcDir, dstDir string) { if len(pkgs) < 2 { return } - // Copy the 'import' directive from one of the existing files. - pkgName, pkgPath := "", "" - for _, pkg := range pkgs { - forEachFile(pkg, func(name string, file *ast.File) { - for _, imp := range file.Imports { - var impName string - p, err := strconv.Unquote(imp.Path.Value) - ok(err) - if imp.Name != nil { - impName = imp.Name.Name - } else { - impName = filepath.Base(p) - } - - if impName == pkgs[0].Name { - pkgName = impName - pkgPath = p - } - } - }) - } + pkgPath, err := findPackagePath(srcDir) + ok(err) + pkgName := filepath.Base(pkgPath) text := "" + "package " + pkgs[0].Name + "_test\n" + @@ -707,6 +689,35 @@ func (i *instrumenter) writeGobcoBlackBox(pkgs []*ast.Package, dstDir string) { writeFile(filepath.Join(dstDir, "gobco_bridge_test.go"), text) } +// findPackagePath finds import path of a package that srcDir indicates +func findPackagePath(srcDir string) (string, error) { + _, moduleRel, err := findInModule(srcDir) + if err != nil { + return "", err + } + + moduleName, err := getModuleName() + if err != nil { + return "", err + } + + if moduleRel == "." { + return moduleName, nil + } else { + pkgPath := fmt.Sprintf("%s/%s", moduleName, moduleRel) + return pkgPath, nil + } +} + +func getModuleName() (string, error) { + cmd := exec.Command("go", "list", "-m") + output, err := cmd.Output() + if err != nil { + return "", err + } + return strings.TrimSpace(string(output)), nil +} + func (i *instrumenter) str(expr ast.Expr) string { var sb strings.Builder ok(printer.Fprint(&sb, i.fset, expr)) diff --git a/main.go b/main.go index afd9bff..95ee4f7 100644 --- a/main.go +++ b/main.go @@ -204,27 +204,38 @@ func (g *gobco) gopaths() string { return filepath.Join(home, "go") } -func (g *gobco) findInModule(dir string) (moduleRoot, moduleRel string) { - absDir, err := filepath.Abs(dir) +func (g *gobco) findInModule(dir string) (string, string) { + moduleRoot, moduleRel, err := findInModule(dir) g.check(err) + return moduleRoot, moduleRel +} + +// findInModule finds path of moduleRoot and relative path from the moduleRoot to dir +func findInModule(dir string) (moduleRoot, moduleRel string, err error) { + absDir, err := filepath.Abs(dir) + if err != nil { + return "", "", err + } abs := absDir for { if _, err := os.Lstat(filepath.Join(abs, "go.mod")); err == nil { rel, err := filepath.Rel(abs, absDir) - g.check(err) + if err != nil { + return "", "", err + } root := abs if rel == "." { root = dir } - return root, rel + return root, rel, nil } parent := filepath.Dir(abs) if parent == abs { - return "", "" + return "", "", nil } abs = parent }