From 900159e355b5d28065da154789669f4595aeb08a Mon Sep 17 00:00:00 2001 From: Georgii Kliukovkin Date: Tue, 27 Aug 2024 20:25:45 -0700 Subject: [PATCH 1/5] add interfaces flag with unit test --- mockgen/mockgen.go | 34 ++++++++++++++ mockgen/mockgen_test.go | 101 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 135 insertions(+) diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index 5d57868..9903a2a 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -71,6 +71,7 @@ var ( debugParser = flag.Bool("debug_parser", false, "Print out parser results only.") showVersion = flag.Bool("version", false, "Print version.") + interfaces = flag.String("interfaces", "", "List of interfaces to generate mocks for; if empty, mockgen will generate mocks for all interfaces found in the input file(s).") ) func main() { @@ -115,6 +116,14 @@ func main() { return } + if len(*interfaces) > 0 { + ifaces := strings.Split(*interfaces, ",") + if pkg.Interfaces, err = filterInterfaces(pkg.Interfaces, ifaces); err != nil { + log.Fatalf("Filtering interfaces failed: %v", err) + } + + } + outputPackageName := *packageOut if outputPackageName == "" { // pkg.Name in reflect mode is the base name of the import path, @@ -894,3 +903,28 @@ func parsePackageImport(srcDir string) (string, error) { } return "", errOutsideGoPath } + +func filterInterfaces(all []*model.Interface, requested []string) ([]*model.Interface, error) { + if len(requested) == 0 { + return nil, fmt.Errorf("no interfaces requested, other provide them or remove flag -interfaces") + } + requestedIfaces := make(map[string]struct{}) + for _, iface := range requested { + requestedIfaces[iface] = struct{}{} + } + result := make([]*model.Interface, 0, len(all)) + for _, iface := range all { + if _, ok := requestedIfaces[iface.Name]; ok { + result = append(result, iface) + delete(requestedIfaces, iface.Name) + } + } + if len(requestedIfaces) > 0 { + var missing []string + for iface := range requestedIfaces { + missing = append(missing, iface) + } + return nil, fmt.Errorf("missing interfaces: %s", strings.Join(missing, ", ")) + } + return result, nil +} diff --git a/mockgen/mockgen_test.go b/mockgen/mockgen_test.go index 6b17127..50e838e 100644 --- a/mockgen/mockgen_test.go +++ b/mockgen/mockgen_test.go @@ -467,3 +467,104 @@ func TestParseExcludeInterfaces(t *testing.T) { }) } } + +func Test_filterInterfaces(t *testing.T) { + type args struct { + all []*model.Interface + requested []string + } + tests := []struct { + name string + args args + want []*model.Interface + wantErr bool + }{ + { + name: "no filter", + args: args{ + all: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + requested: []string{}, + }, + want: nil, + wantErr: true, + }, + { + name: "filter by Foo", + args: args{ + all: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + requested: []string{"Foo"}, + }, + want: []*model.Interface{ + { + Name: "Foo", + }, + }, + wantErr: false, + }, + { + name: "filter by Foo and Bar", + args: args{ + all: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + requested: []string{"Foo", "Bar"}, + }, + want: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + wantErr: false, + }, + { + name: "incorrect filter by Foo and Baz", + args: args{ + all: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + requested: []string{"Foo", "Baz"}, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := filterInterfaces(tt.args.all, tt.args.requested) + if (err != nil) != tt.wantErr { + t.Errorf("filterInterfaces() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("filterInterfaces() got = %v, want %v", got, tt.want) + } + }) + } +} From 5a7af371c4541fdce16cef961522669c18b30d46 Mon Sep 17 00:00:00 2001 From: Georgii Kliukovkin Date: Wed, 28 Aug 2024 21:00:15 -0700 Subject: [PATCH 2/5] move reading to a better place --- mockgen/mockgen.go | 7 ------- mockgen/parse.go | 9 +++++++++ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index 9903a2a..ec85c47 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -116,13 +116,6 @@ func main() { return } - if len(*interfaces) > 0 { - ifaces := strings.Split(*interfaces, ",") - if pkg.Interfaces, err = filterInterfaces(pkg.Interfaces, ifaces); err != nil { - log.Fatalf("Filtering interfaces failed: %v", err) - } - - } outputPackageName := *packageOut if outputPackageName == "" { diff --git a/mockgen/parse.go b/mockgen/parse.go index f43321c..1d7f7b8 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -92,6 +92,15 @@ func sourceMode(source string) (*model.Package, error) { for pkgPath := range dotImports { pkg.DotImports = append(pkg.DotImports, pkgPath) } + + if len(*interfaces) > 0 { + ifaces := strings.Split(*interfaces, ",") + if pkg.Interfaces, err = filterInterfaces(pkg.Interfaces, ifaces); err != nil { + log.Fatalf("Filtering interfaces failed: %v", err) + } + + } + return pkg, nil } From 6cc01b16e1017aaea03481a652b64150ad4981a1 Mon Sep 17 00:00:00 2001 From: Georgii Kliukovkin Date: Wed, 28 Aug 2024 21:04:45 -0700 Subject: [PATCH 3/5] Moving to a parse.go --- mockgen/mockgen.go | 26 ---------- mockgen/mockgen_test.go | 101 +------------------------------------- mockgen/parse.go | 27 ++++++++++- mockgen/parse_test.go | 104 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 131 insertions(+), 127 deletions(-) diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index ec85c47..47216c1 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -116,7 +116,6 @@ func main() { return } - outputPackageName := *packageOut if outputPackageName == "" { // pkg.Name in reflect mode is the base name of the import path, @@ -896,28 +895,3 @@ func parsePackageImport(srcDir string) (string, error) { } return "", errOutsideGoPath } - -func filterInterfaces(all []*model.Interface, requested []string) ([]*model.Interface, error) { - if len(requested) == 0 { - return nil, fmt.Errorf("no interfaces requested, other provide them or remove flag -interfaces") - } - requestedIfaces := make(map[string]struct{}) - for _, iface := range requested { - requestedIfaces[iface] = struct{}{} - } - result := make([]*model.Interface, 0, len(all)) - for _, iface := range all { - if _, ok := requestedIfaces[iface.Name]; ok { - result = append(result, iface) - delete(requestedIfaces, iface.Name) - } - } - if len(requestedIfaces) > 0 { - var missing []string - for iface := range requestedIfaces { - missing = append(missing, iface) - } - return nil, fmt.Errorf("missing interfaces: %s", strings.Join(missing, ", ")) - } - return result, nil -} diff --git a/mockgen/mockgen_test.go b/mockgen/mockgen_test.go index 50e838e..0031345 100644 --- a/mockgen/mockgen_test.go +++ b/mockgen/mockgen_test.go @@ -468,103 +468,4 @@ func TestParseExcludeInterfaces(t *testing.T) { } } -func Test_filterInterfaces(t *testing.T) { - type args struct { - all []*model.Interface - requested []string - } - tests := []struct { - name string - args args - want []*model.Interface - wantErr bool - }{ - { - name: "no filter", - args: args{ - all: []*model.Interface{ - { - Name: "Foo", - }, - { - Name: "Bar", - }, - }, - requested: []string{}, - }, - want: nil, - wantErr: true, - }, - { - name: "filter by Foo", - args: args{ - all: []*model.Interface{ - { - Name: "Foo", - }, - { - Name: "Bar", - }, - }, - requested: []string{"Foo"}, - }, - want: []*model.Interface{ - { - Name: "Foo", - }, - }, - wantErr: false, - }, - { - name: "filter by Foo and Bar", - args: args{ - all: []*model.Interface{ - { - Name: "Foo", - }, - { - Name: "Bar", - }, - }, - requested: []string{"Foo", "Bar"}, - }, - want: []*model.Interface{ - { - Name: "Foo", - }, - { - Name: "Bar", - }, - }, - wantErr: false, - }, - { - name: "incorrect filter by Foo and Baz", - args: args{ - all: []*model.Interface{ - { - Name: "Foo", - }, - { - Name: "Bar", - }, - }, - requested: []string{"Foo", "Baz"}, - }, - want: nil, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := filterInterfaces(tt.args.all, tt.args.requested) - if (err != nil) != tt.wantErr { - t.Errorf("filterInterfaces() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("filterInterfaces() got = %v, want %v", got, tt.want) - } - }) - } -} + diff --git a/mockgen/parse.go b/mockgen/parse.go index 1d7f7b8..3b3143c 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -100,7 +100,7 @@ func sourceMode(source string) (*model.Package, error) { } } - + return pkg, nil } @@ -811,4 +811,29 @@ func packageNameOfDir(srcDir string) (string, error) { return packageImport, nil } +func filterInterfaces(all []*model.Interface, requested []string) ([]*model.Interface, error) { + if len(requested) == 0 { + return nil, fmt.Errorf("no interfaces requested, other provide them or remove flag -interfaces") + } + requestedIfaces := make(map[string]struct{}) + for _, iface := range requested { + requestedIfaces[iface] = struct{}{} + } + result := make([]*model.Interface, 0, len(all)) + for _, iface := range all { + if _, ok := requestedIfaces[iface.Name]; ok { + result = append(result, iface) + delete(requestedIfaces, iface.Name) + } + } + if len(requestedIfaces) > 0 { + var missing []string + for iface := range requestedIfaces { + missing = append(missing, iface) + } + return nil, fmt.Errorf("missing interfaces: %s", strings.Join(missing, ", ")) + } + return result, nil +} + var errOutsideGoPath = errors.New("source directory is outside GOPATH") diff --git a/mockgen/parse_test.go b/mockgen/parse_test.go index 3c4ba4c..6683a72 100644 --- a/mockgen/parse_test.go +++ b/mockgen/parse_test.go @@ -4,6 +4,9 @@ import ( "go/parser" "go/token" "testing" + "reflect" + + "go.uber.org/mock/mockgen/model" ) func TestFileParser_ParseFile(t *testing.T) { @@ -143,3 +146,104 @@ func TestParseArrayWithConstLength(t *testing.T) { } } } + +func Test_filterInterfaces(t *testing.T) { + type args struct { + all []*model.Interface + requested []string + } + tests := []struct { + name string + args args + want []*model.Interface + wantErr bool + }{ + { + name: "no filter", + args: args{ + all: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + requested: []string{}, + }, + want: nil, + wantErr: true, + }, + { + name: "filter by Foo", + args: args{ + all: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + requested: []string{"Foo"}, + }, + want: []*model.Interface{ + { + Name: "Foo", + }, + }, + wantErr: false, + }, + { + name: "filter by Foo and Bar", + args: args{ + all: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + requested: []string{"Foo", "Bar"}, + }, + want: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + wantErr: false, + }, + { + name: "incorrect filter by Foo and Baz", + args: args{ + all: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + requested: []string{"Foo", "Baz"}, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := filterInterfaces(tt.args.all, tt.args.requested) + if (err != nil) != tt.wantErr { + t.Errorf("filterInterfaces() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("filterInterfaces() got = %v, want %v", got, tt.want) + } + }) + } +} \ No newline at end of file From 255383515f9d82c7b241093bdebbfaba18fcf8f8 Mon Sep 17 00:00:00 2001 From: Georgii Kliukovkin Date: Wed, 28 Aug 2024 21:05:33 -0700 Subject: [PATCH 4/5] remove spaces --- mockgen/mockgen_test.go | 2 -- mockgen/parse_test.go | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/mockgen/mockgen_test.go b/mockgen/mockgen_test.go index 0031345..6b17127 100644 --- a/mockgen/mockgen_test.go +++ b/mockgen/mockgen_test.go @@ -467,5 +467,3 @@ func TestParseExcludeInterfaces(t *testing.T) { }) } } - - diff --git a/mockgen/parse_test.go b/mockgen/parse_test.go index 6683a72..ac6494f 100644 --- a/mockgen/parse_test.go +++ b/mockgen/parse_test.go @@ -5,7 +5,7 @@ import ( "go/token" "testing" "reflect" - + "go.uber.org/mock/mockgen/model" ) @@ -246,4 +246,4 @@ func Test_filterInterfaces(t *testing.T) { } }) } -} \ No newline at end of file +} From ec91f0c2187cf11404bf6f1bed1cce6485cfbb7f Mon Sep 17 00:00:00 2001 From: Georgii Kliukovkin Date: Sun, 13 Oct 2024 18:23:24 -0700 Subject: [PATCH 5/5] update the PR according to the comments: - Instead of using an -interfaces flag, what do you think about having source mode read positional arguments (i.e., mockgen -source InterfaceOne,InterfaceTwo) to align it with how reflect mode works? (If there are no positional arguments specified, we would parse all interfaces to keep backwards compatibility) - Instead of parsing and then dropping interfaces that aren't specified, can we simply not parse ones that aren't requested? This is similar to how the exclusion flag already works and would avoid some wasted computation. --- mockgen/mockgen.go | 1 - mockgen/parse.go | 63 ++++++++++++++++++++++++++----------------- mockgen/parse_test.go | 31 ++++++++++++++++++--- 3 files changed, 65 insertions(+), 30 deletions(-) diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index 47216c1..5d57868 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -71,7 +71,6 @@ var ( debugParser = flag.Bool("debug_parser", false, "Print out parser results only.") showVersion = flag.Bool("version", false, "Print version.") - interfaces = flag.String("interfaces", "", "List of interfaces to generate mocks for; if empty, mockgen will generate mocks for all interfaces found in the input file(s).") ) func main() { diff --git a/mockgen/parse.go b/mockgen/parse.go index 3b3143c..3429ab0 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -18,6 +18,7 @@ package main import ( "errors" + "flag" "fmt" "go/ast" "go/build" @@ -93,12 +94,17 @@ func sourceMode(source string) (*model.Package, error) { pkg.DotImports = append(pkg.DotImports, pkgPath) } - if len(*interfaces) > 0 { - ifaces := strings.Split(*interfaces, ",") + // Get positional arguments after the flags + ifaces := flag.Args() + + // If there are interfaces provided as positional arguments, filter them + if len(ifaces) > 0 { if pkg.Interfaces, err = filterInterfaces(pkg.Interfaces, ifaces); err != nil { log.Fatalf("Filtering interfaces failed: %v", err) } - + } else { + // No interfaces provided, process all interfaces for backward compatibility + log.Printf("No interfaces specified, processing all interfaces") } return pkg, nil @@ -812,28 +818,35 @@ func packageNameOfDir(srcDir string) (string, error) { } func filterInterfaces(all []*model.Interface, requested []string) ([]*model.Interface, error) { - if len(requested) == 0 { - return nil, fmt.Errorf("no interfaces requested, other provide them or remove flag -interfaces") - } - requestedIfaces := make(map[string]struct{}) - for _, iface := range requested { - requestedIfaces[iface] = struct{}{} - } - result := make([]*model.Interface, 0, len(all)) - for _, iface := range all { - if _, ok := requestedIfaces[iface.Name]; ok { - result = append(result, iface) - delete(requestedIfaces, iface.Name) - } - } - if len(requestedIfaces) > 0 { - var missing []string - for iface := range requestedIfaces { - missing = append(missing, iface) - } - return nil, fmt.Errorf("missing interfaces: %s", strings.Join(missing, ", ")) - } - return result, nil + // If no interfaces are requested, return all interfaces + if len(requested) == 0 { + return all, nil + } + + requestedIfaces := make(map[string]struct{}) + for _, iface := range requested { + requestedIfaces[iface] = struct{}{} + } + + result := make([]*model.Interface, 0, len(requestedIfaces)) + for _, iface := range all { + // Only add interfaces that are requested + if _, ok := requestedIfaces[iface.Name]; ok { + result = append(result, iface) + delete(requestedIfaces, iface.Name) // Remove matched iface from requested + } + } + + // If any requested interfaces were not found, return an error + if len(requestedIfaces) > 0 { + var missing []string + for iface := range requestedIfaces { + missing = append(missing, iface) + } + return nil, fmt.Errorf("missing interfaces: %s", strings.Join(missing, ", ")) + } + + return result, nil } var errOutsideGoPath = errors.New("source directory is outside GOPATH") diff --git a/mockgen/parse_test.go b/mockgen/parse_test.go index ac6494f..2c8da67 100644 --- a/mockgen/parse_test.go +++ b/mockgen/parse_test.go @@ -3,8 +3,8 @@ package main import ( "go/parser" "go/token" - "testing" "reflect" + "testing" "go.uber.org/mock/mockgen/model" ) @@ -159,7 +159,7 @@ func Test_filterInterfaces(t *testing.T) { wantErr bool }{ { - name: "no filter", + name: "no filter (returns all interfaces)", args: args{ all: []*model.Interface{ { @@ -171,8 +171,15 @@ func Test_filterInterfaces(t *testing.T) { }, requested: []string{}, }, - want: nil, - wantErr: true, + want: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + wantErr: false, }, { name: "filter by Foo", @@ -233,6 +240,22 @@ func Test_filterInterfaces(t *testing.T) { want: nil, wantErr: true, }, + { + name: "missing interface (Baz not found)", + args: args{ + all: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + requested: []string{"Baz"}, + }, + want: nil, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {