Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get import path from go.mod #1

Merged
merged 1 commit into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 37 additions & 26 deletions instrumenter.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (
"go/token"
"go/types"
"os"
"os/exec"
"path/filepath"
"reflect"
"runtime"
"sort"
"strconv"
"strings"
)

Expand Down Expand Up @@ -106,7 +106,7 @@ func (i *instrumenter) instrument(srcDir, singleFile, dstDir string) bool {
i.instrumentFile(name, file, dstDir)
})
}
i.writeGobcoFiles(dstDir, pkgs)
i.writeGobcoFiles(srcDir, dstDir, pkgs)
return true
}

Expand Down Expand Up @@ -623,7 +623,7 @@ var fixedTemplate string
//go:embed templates/gobco_no_testmain_test.go
var noTestMainTemplate string

func (i *instrumenter) writeGobcoFiles(tmpDir string, pkgs []*ast.Package) {
func (i *instrumenter) writeGobcoFiles(srcDir, tmpDir string, pkgs []*ast.Package) {
pkgname := pkgs[0].Name
fixPkgname := func(str string) string {
str = strings.TrimPrefix(str, "//go:build ignore\n// +build ignore\n\n")
Expand All @@ -636,7 +636,7 @@ func (i *instrumenter) writeGobcoFiles(tmpDir string, pkgs []*ast.Package) {
writeFile(filepath.Join(tmpDir, "gobco_no_testmain_test.go"), fixPkgname(noTestMainTemplate))
}

i.writeGobcoBlackBox(pkgs, tmpDir)
i.writeGobcoBlackBox(pkgs, srcDir, tmpDir)
}

func (i *instrumenter) writeGobcoGo(filename, pkgname string) {
Expand Down Expand Up @@ -664,32 +664,14 @@ func (i *instrumenter) writeGobcoGo(filename, pkgname string) {
// writeGobcoBlackBox makes the function 'GobcoCover' available
// to black box tests (those in 'package x_test' instead of 'package x')
// by delegating to the function of the same name in the main package.
func (i *instrumenter) writeGobcoBlackBox(pkgs []*ast.Package, dstDir string) {
func (i *instrumenter) writeGobcoBlackBox(pkgs []*ast.Package, srcDir, dstDir string) {
if len(pkgs) < 2 {
return
}

// Copy the 'import' directive from one of the existing files.
pkgName, pkgPath := "", ""
for _, pkg := range pkgs {
forEachFile(pkg, func(name string, file *ast.File) {
for _, imp := range file.Imports {
var impName string
p, err := strconv.Unquote(imp.Path.Value)
ok(err)
if imp.Name != nil {
impName = imp.Name.Name
} else {
impName = filepath.Base(p)
}

if impName == pkgs[0].Name {
pkgName = impName
pkgPath = p
}
}
})
}
pkgPath, err := findPackagePath(srcDir)
ok(err)
pkgName := filepath.Base(pkgPath)

text := "" +
"package " + pkgs[0].Name + "_test\n" +
Expand All @@ -707,6 +689,35 @@ func (i *instrumenter) writeGobcoBlackBox(pkgs []*ast.Package, dstDir string) {
writeFile(filepath.Join(dstDir, "gobco_bridge_test.go"), text)
}

// findPackagePath finds import path of a package that srcDir indicates
func findPackagePath(srcDir string) (string, error) {
_, moduleRel, err := findInModule(srcDir)
if err != nil {
return "", err
}

moduleName, err := getModuleName()
if err != nil {
return "", err
}

if moduleRel == "." {
return moduleName, nil
} else {
pkgPath := fmt.Sprintf("%s/%s", moduleName, moduleRel)
return pkgPath, nil
}
}

func getModuleName() (string, error) {
cmd := exec.Command("go", "list", "-m")
output, err := cmd.Output()
if err != nil {
return "", err
}
return strings.TrimSpace(string(output)), nil
}

func (i *instrumenter) str(expr ast.Expr) string {
var sb strings.Builder
ok(printer.Fprint(&sb, i.fset, expr))
Expand Down
21 changes: 16 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,27 +204,38 @@ func (g *gobco) gopaths() string {
return filepath.Join(home, "go")
}

func (g *gobco) findInModule(dir string) (moduleRoot, moduleRel string) {
absDir, err := filepath.Abs(dir)
func (g *gobco) findInModule(dir string) (string, string) {
moduleRoot, moduleRel, err := findInModule(dir)
g.check(err)
return moduleRoot, moduleRel
}

// findInModule finds path of moduleRoot and relative path from the moduleRoot to dir
func findInModule(dir string) (moduleRoot, moduleRel string, err error) {
absDir, err := filepath.Abs(dir)
if err != nil {
return "", "", err
}

abs := absDir
for {
if _, err := os.Lstat(filepath.Join(abs, "go.mod")); err == nil {
rel, err := filepath.Rel(abs, absDir)
g.check(err)
if err != nil {
return "", "", err
}

root := abs
if rel == "." {
root = dir
}

return root, rel
return root, rel, nil
}

parent := filepath.Dir(abs)
if parent == abs {
return "", ""
return "", "", nil
}
abs = parent
}
Expand Down