diff --git a/stdsym/main.go b/stdsym/main.go index c91be75..85d9ab0 100644 --- a/stdsym/main.go +++ b/stdsym/main.go @@ -1,9 +1,11 @@ package main import ( - "bytes" + "bufio" "log" "os" + "regexp" + "strings" "github.com/lotusirous/gostdsym" ) @@ -17,20 +19,18 @@ func main() { } func run() error { - cwd, err := os.Getwd() + stdPattern := "std" + pkgs, err := gostdsym.LoadPackages(stdPattern) if err != nil { return err } - - pkgs, err := gostdsym.LoadPackages("std") - if err != nil { - return err - } - w := os.Stdout - var buf bytes.Buffer - for _, v := range pkgs { - out, err := gostdsym.GetPackageSymbols(v, cwd) + buf := bufio.NewWriter(w) + for _, pattern := range pkgs { + if isSkipPackage(pattern) { + continue + } + out, err := gostdsym.GetPackageSymbols(pattern) if err != nil { return err } @@ -38,6 +38,11 @@ func run() error { buf.WriteString(sym + "\n") } } - _, err = buf.WriteTo(w) - return err + return buf.Flush() +} + +var internalPkg = regexp.MustCompile(`(^|/)internal($|/)`) + +func isSkipPackage(v string) bool { + return internalPkg.MatchString(v) || strings.HasPrefix(v, "vendor/") && v != "" } diff --git a/symbol.go b/symbol.go index c5bb256..b0e6b8a 100644 --- a/symbol.go +++ b/symbol.go @@ -7,20 +7,13 @@ import ( "go/parser" "go/token" "io/fs" - "regexp" + "os" "slices" - "strings" "golang.org/x/tools/go/packages" ) -var internalPkg = regexp.MustCompile(`(^|/)internal($|/)`) - -func isSkipPackage(v string) bool { - return internalPkg.MatchString(v) || strings.HasPrefix(v, "vendor/") -} - -// LoadPackages returns a list of packages. +// LoadPackages returns all packages from a given pattern. func LoadPackages(pattern string) ([]string, error) { pkgs, err := packages.Load(nil, pattern) if err != nil { @@ -28,23 +21,23 @@ func LoadPackages(pattern string) ([]string, error) { } out := make([]string, len(pkgs)) for i := 0; i < len(pkgs); i++ { - path := pkgs[i].PkgPath - if isSkipPackage(path) { - continue - } out[i] = pkgs[i].PkgPath } return out, nil } // GetPackageSymbols extracts all exported symbols from a package. -func GetPackageSymbols(name, srcDir string) ([]string, error) { - buildPkg, err := build.Import(name, srcDir, build.ImportComment) +func GetPackageSymbols(pattern string) ([]string, error) { + wd, err := os.Getwd() + if err != nil { + return nil, err + } + buildPkg, err := build.Import(pattern, wd, build.ImportComment) if err != nil { return nil, err } - syms, err := buildSymbols(buildPkg) + syms, err := parsePackage(buildPkg) if err != nil { return nil, err } @@ -57,8 +50,10 @@ func GetPackageSymbols(name, srcDir string) ([]string, error) { return syms, nil } -func buildSymbols(pkg *build.Package) ([]string, error) { - fset := token.NewFileSet() +func parsePackage(pkg *build.Package) ([]string, error) { + // include tells parser.ParseDir which files to include. + // That means the file must be in the build package's GoFiles or CgoFiles + // list only (no tag-ignored files, tests, swig or other non-Go files). include := func(info fs.FileInfo) bool { for _, name := range pkg.GoFiles { if name == info.Name() { @@ -67,14 +62,24 @@ func buildSymbols(pkg *build.Package) ([]string, error) { } return false } + fset := token.NewFileSet() pkgs, err := parser.ParseDir(fset, pkg.Dir, include, parser.ParseComments) if err != nil { return nil, err } - astPkg, ok := pkgs[pkg.Name] - if !ok { - return nil, fmt.Errorf("not found package name: %s", pkg.Name) + if len(pkgs) == 0 { + return nil, fmt.Errorf("no source-code package in directory %s", pkg.Dir) } + astPkg := pkgs[pkg.Name] + + // TODO: go/doc does not include typed constants in the constants + // list, which is what we want. For instance, time.Sunday is of type + // time.Weekday, so it is defined in the type but not in the + // Consts list for the package. This prevents + // go doc time.Sunday + // from finding the symbol. Work around this for now, but we + // should fix it in go/doc. + // A similar story applies to factory functions. docPkg := doc.New(astPkg, pkg.ImportPath, doc.AllDecls) typs := types(docPkg) diff --git a/symbol_test.go b/symbol_test.go index 323b13a..263c7f0 100644 --- a/symbol_test.go +++ b/symbol_test.go @@ -1,21 +1,23 @@ package gostdsym import ( - "os" "reflect" "sort" "testing" ) func TestAll(t *testing.T) { - wd, _ := os.Getwd() + _, err := LoadPackages("std") + if err != nil { + t.Fatal(err) + } + for _, test := range []struct { in string - dir string want []string }{ - {in: "cmp", dir: wd, want: []string{"cmp", "cmp.Less", "cmp.Ordered", "cmp.Compare"}}, - {in: "html/template", dir: wd, want: []string{ + {in: "cmp", want: []string{"cmp", "cmp.Less", "cmp.Ordered", "cmp.Compare"}}, + {in: "html/template", want: []string{ "html/template", "html/template.CSS", "html/template.ErrAmbigContext", @@ -78,8 +80,8 @@ func TestAll(t *testing.T) { "html/template.parseGlob", }}, { - in: "container/list", - dir: wd, + in: "container/list", + want: []string{ "container/list", "container/list.Element", @@ -106,8 +108,8 @@ func TestAll(t *testing.T) { }, { - in: "context", - dir: wd, + in: "context", + want: []string{ "context", "context.AfterFunc", @@ -130,8 +132,7 @@ func TestAll(t *testing.T) { }, }, { - in: "errors", - dir: wd, + in: "errors", want: []string{ "errors", "errors.Is", @@ -143,7 +144,7 @@ func TestAll(t *testing.T) { }, }, } { - got, err := GetPackageSymbols(test.in, test.dir) + got, err := GetPackageSymbols(test.in) if err != nil { t.Fatalf("want no error for MustExtract, got: %v", err) }