From 9a1416bbf043a8466aa035b835f474e516adfe68 Mon Sep 17 00:00:00 2001 From: xhd2015 Date: Tue, 2 Apr 2024 17:11:52 +0800 Subject: [PATCH] fix patching type method --- .github/workflows/go.yml | 7 ++- cmd/xgo/version.go | 4 +- runtime/core/version.go | 4 +- runtime/mock/mock.go | 59 ++++++++++++-------- runtime/mock/patch.go | 14 +++-- runtime/mock/patch_go1.17.go | 3 +- runtime/mock/patch_go1.18.go | 3 +- runtime/test/patch/patch_type_method_test.go | 21 +++++++ 8 files changed, 80 insertions(+), 35 deletions(-) create mode 100644 runtime/test/patch/patch_type_method_test.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 5b9cfa06..518e8db6 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -38,4 +38,9 @@ jobs: run: ~/.xgo/bin/xgo revision - name: Check Go Version - run: ~/.xgo/bin/xgo exec go version \ No newline at end of file + run: ~/.xgo/bin/xgo exec go version + + - name: Check spelling of files + uses: crate-ci/typos@master + with: + files: ./ \ No newline at end of file diff --git a/cmd/xgo/version.go b/cmd/xgo/version.go index 44d59df4..4b846ba7 100644 --- a/cmd/xgo/version.go +++ b/cmd/xgo/version.go @@ -3,8 +3,8 @@ package main import "fmt" const VERSION = "1.0.12" -const REVISION = "f3d7271450fef6b7575368a82d6fe254c894a97e+1" -const NUMBER = 147 +const REVISION = "6150ddb324b4f0915a0e1926c49e6fa94632f677+1" +const NUMBER = 148 func getRevision() string { return fmt.Sprintf("%s %s BUILD_%d", VERSION, REVISION, NUMBER) diff --git a/runtime/core/version.go b/runtime/core/version.go index 3bda771d..bdb868dd 100644 --- a/runtime/core/version.go +++ b/runtime/core/version.go @@ -7,8 +7,8 @@ import ( ) const VERSION = "1.0.12" -const REVISION = "f3d7271450fef6b7575368a82d6fe254c894a97e+1" -const NUMBER = 147 +const REVISION = "6150ddb324b4f0915a0e1926c49e6fa94632f677+1" +const NUMBER = 148 // these fields will be filled by compiler const XGO_VERSION = "" diff --git a/runtime/mock/mock.go b/runtime/mock/mock.go index cd277d0a..efdc3b94 100644 --- a/runtime/mock/mock.go +++ b/runtime/mock/mock.go @@ -12,31 +12,55 @@ import ( "github.com/xhd2015/xgo/runtime/trap" ) +// a marker to indicate the +// original function should be called var ErrCallOld = errors.New("mock: call old") type Interceptor func(ctx context.Context, fn *core.FuncInfo, args core.Object, results core.Object) error // Mock setup mock on given function `fn`. // `fn` can be a function or a method, -// when `fn` is a method, only the bound +// if `fn` is a method, only the bound // instance will be mocked, other instances // are not affected. // The returned function can be used to cancel // the passed interceptor. func Mock(fn interface{}, interceptor Interceptor) func() { - return mockByFunc(fn, interceptor) + recvPtr, fnInfo, funcPC, trappingPC := getFunc(fn) + return mock(recvPtr, fnInfo, funcPC, trappingPC, interceptor) } func MockByName(pkgPath string, funcName string, interceptor Interceptor) func() { - funcInfo := functab.GetFuncByPkg(pkgPath, funcName) - if funcInfo == nil { - panic(fmt.Errorf("failed to setup mock for: %s.%s", pkgPath, funcName)) - } - return mock(nil, funcInfo, 0, 0, interceptor) + recv, fn, funcPC, trappingPC := getFuncByName(pkgPath, funcName) + return mock(recv, fn, funcPC, trappingPC, interceptor) } // Can instance be nil? func MockMethodByName(instance interface{}, method string, interceptor Interceptor) func() { + recvPtr, fn, funcPC, trappingPC := getMethodByName(instance, method) + return mock(recvPtr, fn, funcPC, trappingPC, interceptor) +} + +func getFunc(fn interface{}) (recvPtr interface{}, fnInfo *core.FuncInfo, funcPC uintptr, trappingPC uintptr) { + // if the target function is a method, then a + // recv ptr must be given + recvPtr, fnInfo, funcPC, trappingPC = trap.InspectPC(fn) + if fnInfo == nil { + pc := reflect.ValueOf(fn).Pointer() + panic(fmt.Errorf("failed to setup mock for: %v", runtime.FuncForPC(pc).Name())) + } + return recvPtr, fnInfo, funcPC, trappingPC +} + +func getFuncByName(pkgPath string, funcName string) (recvPtr interface{}, fn *core.FuncInfo, funcPC uintptr, trappingPC uintptr) { + fn = functab.GetFuncByPkg(pkgPath, funcName) + if fn == nil { + panic(fmt.Errorf("failed to setup mock for: %s.%s", pkgPath, funcName)) + } + return nil, fn, 0, 0 +} + +func getMethodByName(instance interface{}, method string) (recvPtr interface{}, fn *core.FuncInfo, funcPC uintptr, trappingPC uintptr) { // extract instance's reflect.Type // use that type to query for reflect mapping in functab: // reflectTypeMapping map[reflect.Type]map[string]*funcInfo @@ -45,34 +69,25 @@ func MockMethodByName(instance interface{}, method string, interceptor Intercept if funcMapping == nil { panic(fmt.Errorf("failed to setup mock for type %T", instance)) } - fn := funcMapping[method] + fn = funcMapping[method] if fn == nil { panic(fmt.Errorf("failed to setup mock for: %T.%s", instance, method)) } addr := reflect.New(t) addr.Elem().Set(reflect.ValueOf(instance)) - return mock(addr.Interface(), fn, 0, 0, interceptor) + + return addr.Interface(), fn, 0, 0 } // Deprecated: use Mock instead func AddFuncInterceptor(fn interface{}, interceptor Interceptor) func() { - return mockByFunc(fn, interceptor) + return Mock(fn, interceptor) } // TODO: ensure them run in last? // no abort, run mocks // mocks are special in that they on run in pre stage -func mockByFunc(fn interface{}, interceptor Interceptor) func() { - // if the target function is a method, then a - // recv ptr must be given - recvPtr, fnInfo, funcPC, trappingPC := trap.InspectPC(fn) - if fnInfo == nil { - pc := reflect.ValueOf(fn).Pointer() - panic(fmt.Errorf("failed to setup mock for: %v", runtime.FuncForPC(pc).Name())) - } - return mock(recvPtr, fnInfo, funcPC, trappingPC, interceptor) -} // if mockFnInfo is a function, mockRecvPtr is always nil // if mockFnInfo is a method, @@ -109,10 +124,6 @@ func mock(mockRecvPtr interface{}, mockFnInfo *core.FuncInfo, funcPC uintptr, tr // check they pointing to the same variable re := reflect.ValueOf(recvPtr).Elem().Interface() me := reflect.ValueOf(mockRecvPtr).Elem().Interface() - if f.RecvPtr { - // compare pointer - // unsafe.Pointer(&re) - } if re != me { // if *recvPtr != *mockRecvPtr { return nil, nil diff --git a/runtime/mock/patch.go b/runtime/mock/patch.go index 27d4822c..050002bb 100644 --- a/runtime/mock/patch.go +++ b/runtime/mock/patch.go @@ -9,14 +9,16 @@ import ( ) func PatchByName(pkgPath string, funcName string, replacer interface{}) func() { - return MockByName(pkgPath, funcName, buildInterceptorFromPatch(replacer)) + recvPtr, funcInfo, funcPC, trappingPC := getFuncByName(pkgPath, funcName) + return mock(recvPtr, funcInfo, funcPC, trappingPC, buildInterceptorFromPatch(recvPtr, replacer)) } func PatchMethodByName(instance interface{}, method string, replacer interface{}) func() { - return MockMethodByName(instance, method, buildInterceptorFromPatch(replacer)) + recvPtr, funcInfo, funcPC, trappingPC := getMethodByName(instance, method) + return mock(recvPtr, funcInfo, funcPC, trappingPC, buildInterceptorFromPatch(recvPtr, replacer)) } -func buildInterceptorFromPatch(replacer interface{}) func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error { +func buildInterceptorFromPatch(recvPtr interface{}, replacer interface{}) func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error { v := reflect.ValueOf(replacer) t := v.Type() if t.Kind() != reflect.Func { @@ -26,12 +28,16 @@ func buildInterceptorFromPatch(replacer interface{}) func(ctx context.Context, f panic("replacer is nil") } nIn := t.NumIn() + + // first arg ctx: true => [recv,args[1:]...] + // first arg ctx: false => [recv, args[0:]...] return func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error { // assemble arguments callArgs := make([]reflect.Value, nIn) src := 0 dst := 0 - if fn.RecvType != "" { + if fn.RecvType != "" && recvPtr != nil { + // patching an instance method src++ } if fn.FirstArgCtx { diff --git a/runtime/mock/patch_go1.17.go b/runtime/mock/patch_go1.17.go index 4e39ac50..3e1f388a 100644 --- a/runtime/mock/patch_go1.17.go +++ b/runtime/mock/patch_go1.17.go @@ -4,5 +4,6 @@ package mock func Patch(fn interface{}, replacer interface{}) func() { - return Mock(fn, buildInterceptorFromPatch(replacer)) + recvPtr, fnInfo, funcPC, trappingPC := getFunc(fn) + return mock(recvPtr, fnInfo, funcPC, trappingPC, buildInterceptorFromPatch(recvPtr, replacer)) } diff --git a/runtime/mock/patch_go1.18.go b/runtime/mock/patch_go1.18.go index d921cad7..18ca9da3 100644 --- a/runtime/mock/patch_go1.18.go +++ b/runtime/mock/patch_go1.18.go @@ -6,5 +6,6 @@ package mock // TODO: what if `fn` is a Type function // instead of an instance method? func Patch[T any](fn T, replacer T) func() { - return Mock(fn, buildInterceptorFromPatch(replacer)) + recvPtr, fnInfo, funcPC, trappingPC := getFunc(fn) + return mock(recvPtr, fnInfo, funcPC, trappingPC, buildInterceptorFromPatch(recvPtr, replacer)) } diff --git a/runtime/test/patch/patch_type_method_test.go b/runtime/test/patch/patch_type_method_test.go new file mode 100644 index 00000000..93a7b040 --- /dev/null +++ b/runtime/test/patch/patch_type_method_test.go @@ -0,0 +1,21 @@ +package patch + +import ( + "testing" + + "github.com/xhd2015/xgo/runtime/mock" +) + +func TestPatchTypeMethod(t *testing.T) { + ins := &struct_{ + s: "world", + } + mock.Patch((*struct_).greet, func(ins *struct_) string { + return "mock " + ins.s + }) + + res := ins.greet() + if res != "mock world" { + t.Fatalf("expect patched result to be %q, actual: %q", "mock world", res) + } +}