diff --git a/.travis.yml b/.travis.yml index 978c4af..38034d9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,7 @@ language: go go: - 1.x - - 1.11.x + - 1.12.x env: - GO111MODULE=on @@ -10,7 +10,10 @@ env: install: true script: - - go get -u golang.org/x/lint/golint - - golint -set_exit_status $(go list ./...) - - go test -v -vet=all ./... - - go test -v -race ./... \ No newline at end of file + - set -e + - fmt=$(gofmt -l .) + - test -z $fmt || (echo "please run gofmt" ; echo $fmt ; exit 1) + - go run golang.org/x/lint/golint -set_exit_status $(go list ./...) + - go test -v ./... + - go test -v -race -vet=all ./... + - git diff --quiet || (echo 'generated go files are not up to date, check go generate, go.sum and go.mod' ; git diff ; exit 1) diff --git a/binding.go b/binding.go index 0d6c22c..8872056 100644 --- a/binding.go +++ b/binding.go @@ -100,12 +100,14 @@ func (b *Binding) equal(to *Binding) bool { } // Create creates a new instance by the provider and requests injection, all provider arguments are automatically filled -func (p *Provider) Create(injector *Injector) reflect.Value { +func (p *Provider) Create(injector *Injector) (reflect.Value, error) { in := make([]reflect.Value, p.fnc.Type().NumIn()) + var err error for i := 0; i < p.fnc.Type().NumIn(); i++ { - in[i] = injector.getInstance(p.fnc.Type().In(i), "", traceCircular) + if in[i], err = injector.getInstance(p.fnc.Type().In(i), "", traceCircular); err != nil { + return reflect.Value{}, err + } } res := p.fnc.Call(in)[0] - injector.requestInjection(res, traceCircular) - return res + return res, injector.requestInjection(res, traceCircular) } diff --git a/circular_test.go b/circular_test.go index 9c097f7..6bafc3d 100644 --- a/circular_test.go +++ b/circular_test.go @@ -30,33 +30,39 @@ type ( ) func TestDingoCircula(t *testing.T) { - traceCircular = make([]circularTraceEntry, 0) + EnableCircularTracing() defer func() { traceCircular = nil }() - injector := NewInjector() + injector, err := NewInjector() + assert.NoError(t, err) + assert.Panics(t, func() { - _, ok := injector.GetInstance(new(circA)).(*circA) + i, err := injector.GetInstance(new(circA)) + assert.NoError(t, err) + _, ok := i.(*circA) if !ok { t.Fail() } }) injector.Bind(new(circCInterface)).To(circC{}) - assert.NotPanics(t, func() { - c, ok := injector.GetInstance(new(circC)).(*circC) - if !ok { - t.Fail() - } - assert.NotNil(t, c.C()) - }) + i, err := injector.GetInstance(new(circC)) + assert.NoError(t, err) + c, ok := i.(*circC) + if !ok { + t.Fail() + } + assert.NotNil(t, c.C()) var d *circD assert.NotPanics(t, func() { var ok bool - d, ok = injector.GetInstance(new(circD)).(*circD) + i, err := injector.GetInstance(new(circD)) + assert.NoError(t, err) + d, ok = i.(*circD) if !ok { t.Fail() } diff --git a/dingo.go b/dingo.go index b77d715..3698b7b 100644 --- a/dingo.go +++ b/dingo.go @@ -1,6 +1,7 @@ package dingo import ( + "errors" "fmt" "log" "reflect" @@ -14,7 +15,10 @@ const ( DEFAULT ) -var traceCircular []circularTraceEntry +var ( + traceCircular []circularTraceEntry + fmtErrorf = fmt.Errorf +) // EnableCircularTracing activates dingo's trace feature to find circular dependencies // this is super expensive (memory wise), so it should only be used for debugging purposes @@ -52,7 +56,7 @@ type ( ) // NewInjector builds up a new Injector out of a list of Modules -func NewInjector(modules ...Module) *Injector { +func NewInjector(modules ...Module) (*Injector, error) { injector := &Injector{ bindings: make(map[reflect.Type][]*Binding), multibindings: make(map[reflect.Type][]*Binding), @@ -71,23 +75,29 @@ func NewInjector(modules ...Module) *Injector { injector.BindScope(ChildSingleton) // init current modules - injector.InitModules(modules...) - - return injector + return injector, injector.InitModules(modules...) } // Child derives a child injector with a new ChildSingletonScope -func (injector *Injector) Child() *Injector { - newInjector := NewInjector() +func (injector *Injector) Child() (*Injector, error) { + if injector == nil { + return nil, errors.New("can not create a child of an uninitialized injector") + } + + newInjector, err := NewInjector() + if err != nil { + return nil, err + } + newInjector.parent = injector newInjector.Bind(Injector{}).ToInstance(newInjector) newInjector.BindScope(NewChildSingletonScope()) // bind a new child-singleton - return newInjector + return newInjector, nil } // InitModules initializes the injector with the given modules -func (injector *Injector) InitModules(modules ...Module) { +func (injector *Injector) InitModules(modules ...Module) error { injector.stage = INIT modules = resolveDependencies(modules, nil) @@ -110,7 +120,7 @@ func (injector *Injector) InitModules(modules ...Module) { } continue } - panic("cannot override unknown binding " + override.typ.String() + " (annotated with " + override.annotatedWith + ")") // todo ok? + return fmtErrorf("cannot override unknown binding %q (annotated with %q)", override.typ.String(), override.annotatedWith) // todo ok? } // make sure there are no duplicated bindings @@ -125,7 +135,7 @@ func (injector *Injector) InitModules(modules ...Module) { if binding.to != nil { duplicateBinding = fmt.Sprintf("%#v%#v", binding.to.PkgPath(), binding.to.Name()) } - panic(fmt.Sprintf("already known binding for %q with annotation %q | Known binding: %q Try %q", typ, binding.annotatedWith, knownBinding, duplicateBinding)) + return fmtErrorf("already known binding for %q with annotation %q | Known binding: %q Try %q", typ, binding.annotatedWith, knownBinding, duplicateBinding) } known[binding.annotatedWith] = binding } @@ -135,35 +145,49 @@ func (injector *Injector) InitModules(modules ...Module) { // continue with delayed injections for _, object := range injector.delayed { - injector.requestInjection(object, traceCircular) + if err := injector.requestInjection(object, traceCircular); err != nil { + return err + } } injector.delayed = nil // build eager singletons - if injector.buildEagerSingletons { - for _, bindings := range injector.bindings { - for _, binding := range bindings { - if binding.eager { - injector.getInstance(binding.typeof, binding.annotatedWith, traceCircular) + if !injector.buildEagerSingletons { + return nil + } + for _, bindings := range injector.bindings { + for _, binding := range bindings { + if binding.eager { + if _, err := injector.getInstance(binding.typeof, binding.annotatedWith, traceCircular); err != nil { + return err } } } } + return nil } // GetInstance creates a new instance of what was requested -func (injector *Injector) GetInstance(of interface{}) interface{} { - return injector.getInstance(of, "", traceCircular).Interface() +func (injector *Injector) GetInstance(of interface{}) (interface{}, error) { + i, err := injector.getInstance(of, "", traceCircular) + if err != nil { + return nil, err + } + return i.Interface(), nil } // GetAnnotatedInstance creates a new instance of what was requested with the given annotation -func (injector *Injector) GetAnnotatedInstance(of interface{}, annotatedWith string) interface{} { - return injector.getInstance(of, annotatedWith, traceCircular).Interface() +func (injector *Injector) GetAnnotatedInstance(of interface{}, annotatedWith string) (interface{}, error) { + i, err := injector.getInstance(of, annotatedWith, traceCircular) + if err != nil { + return nil, err + } + return i.Interface(), nil } // getInstance creates the new instance of typ, returns a reflect.value -func (injector *Injector) getInstance(typ interface{}, annotatedWith string, circularTrace []circularTraceEntry) reflect.Value { +func (injector *Injector) getInstance(typ interface{}, annotatedWith string, circularTrace []circularTraceEntry) (reflect.Value, error) { oftype := reflect.TypeOf(typ) if oft, ok := typ.(reflect.Type); ok { @@ -179,8 +203,7 @@ func (injector *Injector) getInstance(typ interface{}, annotatedWith string, cir func (injector *Injector) findBinding(t reflect.Type, annotation string) *Binding { if len(injector.bindings[t]) > 0 { - binding := injector.lookupBinding(t, annotation) - if binding != nil { + if binding := injector.lookupBinding(t, annotation); binding != nil { return binding } } @@ -199,40 +222,45 @@ func (injector *Injector) findBinding(t reflect.Type, annotation string) *Bindin } // resolveType resolves a requested type, with annotation -func (injector *Injector) resolveType(t reflect.Type, annotation string, optional bool, circularTrace []circularTraceEntry) reflect.Value { +func (injector *Injector) resolveType(t reflect.Type, annotation string, optional bool, circularTrace []circularTraceEntry) (reflect.Value, error) { if t.Kind() == reflect.Ptr { t = t.Elem() } var final reflect.Value + var err error if binding := injector.findBinding(t, annotation); binding != nil { if binding.scope != nil { if scope, ok := injector.scopes[reflect.TypeOf(binding.scope)]; ok { //final = scope.ResolveType(t, annotation, injector.internalResolveType) - final = scope.ResolveType(t, annotation, func(t reflect.Type, annotation string, optional bool) reflect.Value { + if final, err = scope.ResolveType(t, annotation, func(t reflect.Type, annotation string, optional bool) (reflect.Value, error) { return injector.internalResolveType(t, annotation, optional, circularTrace) - }) + }); err != nil { + return reflect.Value{}, err + } if !final.IsValid() { - panic(fmt.Sprintf("%T did no resolve %s", scope, t)) + return reflect.Value{}, fmtErrorf("%T did not resolve %s", scope, t) } } else { - panic(fmt.Sprintf("unknown scope %T for %s", binding.scope, t)) + return reflect.Value{}, fmtErrorf("unknown scope %T for %s", binding.scope, t) } } } if !final.IsValid() { - final = injector.internalResolveType(t, annotation, optional, circularTrace) + if final, err = injector.internalResolveType(t, annotation, optional, circularTrace); err != nil { + return reflect.Value{}, err + } } if !final.IsValid() { - panic("can not resolve " + t.String()) + return reflect.Value{}, fmtErrorf("can not resolve %q", t.String()) } final = injector.intercept(final, t) - return final + return final, nil } func (injector *Injector) intercept(final reflect.Value, t reflect.Type) reflect.Value { @@ -248,50 +276,71 @@ func (injector *Injector) intercept(final reflect.Value, t reflect.Type) reflect return final } +type errUnbound struct { + binding *Binding + typ reflect.Type +} + +func (err errUnbound) Error() string { + return fmt.Sprintf("binding is not bound: %v for %s", err.binding, err.typ.String()) +} + func (injector *Injector) resolveBinding(binding *Binding, t reflect.Type, optional bool, circularTrace []circularTraceEntry) (reflect.Value, error) { if binding.instance != nil { return binding.instance.ivalue, nil } if binding.provider != nil { - result := binding.provider.Create(injector) + result, err := binding.provider.Create(injector) + if err != nil { + return reflect.Value{}, err + } if result.Kind() == reflect.Slice { - result = injector.internalResolveType(result.Type(), "", optional, circularTrace) + if result, err = injector.internalResolveType(result.Type(), "", optional, circularTrace); err != nil { + return reflect.Value{}, err + } } else { - injector.requestInjection(result.Interface(), circularTrace) + if err := injector.requestInjection(result.Interface(), circularTrace); err != nil { + return reflect.Value{}, err + } } return result, nil } if binding.to != nil { if binding.to == t { - panic("circular from " + t.String() + " to " + binding.to.String() + " (annotated with: " + binding.annotatedWith + ")") + return reflect.Value{}, fmtErrorf("circular from %q to %q (annotated with: %q)", t, binding.to, binding.annotatedWith) } - return injector.resolveType(binding.to, "", optional, circularTrace), nil + return injector.resolveType(binding.to, "", optional, circularTrace) } - return reflect.Value{}, fmt.Errorf("binding is not bound: %v for %s", binding, t.String()) + return reflect.Value{}, errUnbound{binding: binding, typ: t} } // internalResolveType resolves a type request with the current injector -func (injector *Injector) internalResolveType(t reflect.Type, annotation string, optional bool, circularTrace []circularTraceEntry) reflect.Value { +func (injector *Injector) internalResolveType(t reflect.Type, annotation string, optional bool, circularTrace []circularTraceEntry) (reflect.Value, error) { if binding := injector.findBinding(t, annotation); binding != nil { r, err := injector.resolveBinding(binding, t, optional, circularTrace) + // todo: go 1.13/1.14: if err == nil || !errors.As(err, new(errUnbound)) { if err == nil { - return r + return r, nil + } else if err, ok := err.(errUnbound); !ok { + return r, err } + // todo: proper testcases if annotation != "" { return injector.resolveType(binding.typeof, "", false, circularTrace) } } // This for an injection request on a provider, such as `func() MyInstance` - if t.Kind() == reflect.Func && t.NumOut() == 1 && strings.HasSuffix(t.Name(), "Provider") { + if t.Kind() == reflect.Func && (t.NumOut() == 1 || t.NumOut() == 2) && strings.HasSuffix(t.Name(), "Provider") { + providerCanError := t.NumOut() == 2 && t.Out(1).AssignableTo(reflect.TypeOf(new(error)).Elem()) if traceCircular != nil { - return injector.createProvider(t, annotation, optional, make([]circularTraceEntry, 0)) + return injector.createProvider(t, annotation, optional, providerCanError, make([]circularTraceEntry, 0)), nil } - return injector.createProvider(t, annotation, optional, nil) + return injector.createProvider(t, annotation, optional, providerCanError, nil), nil } // This is the injection request for multibindings @@ -305,15 +354,15 @@ func (injector *Injector) internalResolveType(t reflect.Type, annotation string, } if annotation != "" && !optional { - panic("Can not automatically create an annotated injection " + t.String() + " with annotation " + annotation) + return reflect.Value{}, fmtErrorf("can not automatically create an annotated injection %q with annotation %q", t, annotation) } if t.Kind() == reflect.Interface && !optional { - panic("Can not instantiate interface " + t.String()) + return reflect.Value{}, fmtErrorf("can not instantiate interface %s.%s", t.PkgPath(), t.Name()) } if t.Kind() == reflect.Func && !optional { - panic("Can not create a new function " + t.String() + " (Do you want a provider? Then suffix type with Provider)") + return reflect.Value{}, fmtErrorf("can not create a new function %q (Do you want a provider? Then suffix type with Provider)", t) } if circularTrace != nil { @@ -331,16 +380,23 @@ func (injector *Injector) internalResolveType(t reflect.Type, annotation string, subCircularTrace = append(subCircularTrace, circularTraceEntry{t, annotation}) n := reflect.New(t) - injector.requestInjection(n.Interface(), subCircularTrace) - return n + return n, injector.requestInjection(n.Interface(), subCircularTrace) } n := reflect.New(t) - injector.requestInjection(n.Interface(), nil) - return n + return n, injector.requestInjection(n.Interface(), nil) } -func (injector *Injector) createProvider(t reflect.Type, annotation string, optional bool, circularTrace []circularTraceEntry) reflect.Value { +func reflectedError(err *error, t reflect.Type) reflect.Value { + rerr := reflect.New(reflect.TypeOf(new(error)).Elem()).Elem() + if err == nil || *err == nil { + return rerr + } + rerr.Set(reflect.ValueOf(fmtErrorf("%q: %w", t, *err))) + return rerr +} + +func (injector *Injector) createProvider(t reflect.Type, annotation string, optional bool, canError bool, circularTrace []circularTraceEntry) reflect.Value { return reflect.MakeFunc(t, func(args []reflect.Value) (results []reflect.Value) { // create a new type res := reflect.New(t.Out(0)) @@ -349,24 +405,36 @@ func (injector *Injector) createProvider(t reflect.Type, annotation string, opti res = res.Elem() } + ret := func(v reflect.Value, err error) []reflect.Value { + if err != nil && !canError { + panic(fmtErrorf("%q: %w", t, err)) + } else if canError { + return []reflect.Value{v, reflectedError(&err, t)} + } else { + return []reflect.Value{v} + } + } + // multibindings if res.Elem().Kind() == reflect.Slice { - return []reflect.Value{injector.internalResolveType(t.Out(0), annotation, optional, circularTrace)} + return ret(injector.internalResolveType(t.Out(0), annotation, optional, circularTrace)) } // mapbindings if res.Elem().Kind() == reflect.Map && res.Elem().Type().Key().Kind() == reflect.String { - return []reflect.Value{injector.internalResolveType(t.Out(0), annotation, optional, circularTrace)} + return ret(injector.internalResolveType(t.Out(0), annotation, optional, circularTrace)) } - // set to actual value - res.Set(injector.getInstance(t.Out(0), annotation, circularTrace)) - // return - return []reflect.Value{res} + r := ret(injector.getInstance(t.Out(0), annotation, circularTrace)) + + res.Set(r[0]) + r[0] = res + + return r }) } -func (injector *Injector) createProviderForBinding(t reflect.Type, binding *Binding, annotation string, optional bool, circularTrace []circularTraceEntry) reflect.Value { +func (injector *Injector) createProviderForBinding(t reflect.Type, binding *Binding, annotation string, optional bool, canError bool, circularTrace []circularTraceEntry) reflect.Value { return reflect.MakeFunc(t, func(args []reflect.Value) (results []reflect.Value) { // create a new type res := reflect.New(binding.typeof) @@ -377,12 +445,25 @@ func (injector *Injector) createProviderForBinding(t reflect.Type, binding *Bind if r, err := injector.resolveBinding(binding, t, optional, circularTrace); err == nil { res.Set(r) + if canError { + return []reflect.Value{res, reflectedError(nil, t)} + } return []reflect.Value{res} } // set to actual value - res.Set(injector.getInstance(binding.typeof, annotation, circularTrace)) + i, err := injector.getInstance(binding.typeof, annotation, circularTrace) + if err != nil { + if canError { + return []reflect.Value{res, reflectedError(&err, t)} + } + panic(fmtErrorf("%q: %w", t, err)) + } + res.Set(i) // return + if canError { + return []reflect.Value{res, reflectedError(nil, t)} + } return []reflect.Value{res} }) } @@ -405,7 +486,7 @@ func (injector *Injector) joinMultibindings(t reflect.Type, annotation string) [ return bindings[:c] } -func (injector *Injector) resolveMultibinding(t reflect.Type, annotation string, optional bool, circularTrace []circularTraceEntry) reflect.Value { +func (injector *Injector) resolveMultibinding(t reflect.Type, annotation string, optional bool, circularTrace []circularTraceEntry) (reflect.Value, error) { targetType := t.Elem() if targetType.Kind() == reflect.Ptr { targetType = targetType.Elem() @@ -413,6 +494,7 @@ func (injector *Injector) resolveMultibinding(t reflect.Type, annotation string, providerType := targetType provider := strings.HasSuffix(targetType.Name(), "Provider") && targetType.Kind() == reflect.Func + providerCanError := provider && targetType.NumOut() == 2 && targetType.Out(1).AssignableTo(reflect.TypeOf(new(error)).Elem()) if provider { targetType = targetType.Out(0) @@ -422,20 +504,20 @@ func (injector *Injector) resolveMultibinding(t reflect.Type, annotation string, n := reflect.MakeSlice(t, 0, len(bindings)) for _, binding := range bindings { if provider { - n = reflect.Append(n, injector.createProviderForBinding(providerType, binding, annotation, false, circularTrace)) + n = reflect.Append(n, injector.createProviderForBinding(providerType, binding, annotation, false, providerCanError, circularTrace)) continue } r, err := injector.resolveBinding(binding, t, optional, circularTrace) if err != nil { - panic(err) + return reflect.Value{}, err } n = reflect.Append(n, r) } - return n + return n, nil } - return reflect.MakeSlice(t, 0, 0) + return reflect.MakeSlice(t, 0, 0), nil } func (injector *Injector) joinMapbindings(t reflect.Type, annotation string) map[string]*Binding { @@ -456,7 +538,7 @@ func (injector *Injector) joinMapbindings(t reflect.Type, annotation string) map return bindings } -func (injector *Injector) resolveMapbinding(t reflect.Type, annotation string, optional bool, circularTrace []circularTraceEntry) reflect.Value { +func (injector *Injector) resolveMapbinding(t reflect.Type, annotation string, optional bool, circularTrace []circularTraceEntry) (reflect.Value, error) { targetType := t.Elem() if targetType.Kind() == reflect.Ptr { targetType = targetType.Elem() @@ -464,6 +546,7 @@ func (injector *Injector) resolveMapbinding(t reflect.Type, annotation string, o providerType := targetType provider := strings.HasSuffix(targetType.Name(), "Provider") && targetType.Kind() == reflect.Func + providerCanError := provider && targetType.NumOut() == 2 && targetType.Out(1).AssignableTo(reflect.TypeOf(new(error)).Elem()) if provider { targetType = targetType.Out(0) @@ -473,20 +556,20 @@ func (injector *Injector) resolveMapbinding(t reflect.Type, annotation string, o n := reflect.MakeMapWithSize(t, len(bindings)) for key, binding := range bindings { if provider { - n.SetMapIndex(reflect.ValueOf(key), injector.createProviderForBinding(providerType, binding, annotation, false, circularTrace)) + n.SetMapIndex(reflect.ValueOf(key), injector.createProviderForBinding(providerType, binding, annotation, false, providerCanError, circularTrace)) continue } r, err := injector.resolveBinding(binding, t, optional, circularTrace) if err != nil { - panic(err) + return reflect.Value{}, err } n.SetMapIndex(reflect.ValueOf(key), r) } - return n + return n, nil } - return reflect.MakeMap(t) + return reflect.MakeMap(t), nil } // lookupBinding search a binding with the corresponding annotation @@ -584,28 +667,36 @@ func (injector *Injector) Override(what interface{}, annotatedWith string) *Bind } // RequestInjection requests the object to have all fields annotated with `inject` to be filled -func (injector *Injector) RequestInjection(object interface{}) { +func (injector *Injector) RequestInjection(object interface{}) error { if injector.stage == INIT { injector.delayed = append(injector.delayed, object) } else { - injector.requestInjection(object, traceCircular) + return injector.requestInjection(object, traceCircular) } + return nil } -func (injector *Injector) requestInjection(object interface{}, circularTrace []circularTraceEntry) { +func (injector *Injector) requestInjection(object interface{}, circularTrace []circularTraceEntry) error { if _, ok := object.(reflect.Value); !ok { object = reflect.ValueOf(object) } var injectlist = []reflect.Value{object.(reflect.Value)} var i int var current reflect.Value + var err error - defer func() { - if e := recover(); e != nil { - log.Printf("%s: %s: injecting into %s", current.Type().PkgPath(), current.Type().Name(), current.String()) - panic(e) + wrapErr := func(err error) error { + path := current.Type().PkgPath() + if path == "" { + if current.Kind() == reflect.Ptr { + path = current.Elem().Type().PkgPath() + } } - }() + if path != "" { + path += "." + } + return fmtErrorf("injecting into %s%s:\n%w", path, current.String(), err) + } for { if i >= len(injectlist) { @@ -623,7 +714,9 @@ func (injector *Injector) requestInjection(object interface{}, circularTrace []c if setup := current.MethodByName("Inject"); setup.IsValid() { args := make([]reflect.Value, setup.Type().NumIn()) for i := range args { - args[i] = injector.getInstance(setup.Type().In(i), "", circularTrace) + if args[i], err = injector.getInstance(setup.Type().In(i), "", circularTrace); err != nil { + return wrapErr(err) + } } setup.Call(args) } @@ -637,7 +730,7 @@ func (injector *Injector) requestInjection(object interface{}, circularTrace []c field := current.Field(fieldIndex) if field.Kind() == reflect.Struct { - panic(fmt.Sprintf("Can not inject into struct %#v of %#v", field, current)) + return fmtErrorf("can not inject into struct %#v of %#v", field, current) } var optional bool @@ -649,7 +742,10 @@ func (injector *Injector) requestInjection(object interface{}, circularTrace []c } tag = strings.Split(tag, ",")[0] - instance := injector.resolveType(field.Type(), tag, optional, circularTrace) + instance, err := injector.resolveType(field.Type(), tag, optional, circularTrace) + if err != nil { + return wrapErr(err) + } if instance.Kind() == reflect.Ptr { if instance.Elem().Kind() == reflect.Func || instance.Elem().Kind() == reflect.Interface || instance.Elem().Kind() == reflect.Slice { instance = instance.Elem() @@ -667,6 +763,7 @@ func (injector *Injector) requestInjection(object interface{}, circularTrace []c continue } } + return nil } // Debug Output diff --git a/dingo_1.12.go b/dingo_1.12.go new file mode 100644 index 0000000..88122c9 --- /dev/null +++ b/dingo_1.12.go @@ -0,0 +1,14 @@ +//+build go1.12 + +package dingo + +import ( + "fmt" + "strings" +) + +func init() { + fmtErrorf = func(format string, a ...interface{}) error { + return fmt.Errorf(strings.Replace(format, "%w", "%v", 1), a...) + } +} diff --git a/dingo_child_test.go b/dingo_child_test.go new file mode 100644 index 0000000..a0538e6 --- /dev/null +++ b/dingo_child_test.go @@ -0,0 +1,48 @@ +package dingo + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type ( + childIface interface{} + childParentIface interface { + child() childIface + } + + childIfaceProvider func() childIface + + childParentIfaceImpl struct { + childInstance childIfaceProvider + } + childIfaceImpl struct{} +) + +func (i *childParentIfaceImpl) Inject(childInstance childIfaceProvider) { + i.childInstance = childInstance +} + +func (i *childParentIfaceImpl) child() childIface { + return i.childInstance() +} + +func TestChild(t *testing.T) { + injector, err := NewInjector() + assert.NoError(t, err) + injector.Bind(new(childParentIface)).To(new(childParentIfaceImpl)) + + child, err := injector.Child() + assert.NoError(t, err) + child.Bind(new(childIface)).To(new(childIfaceImpl)) + + _, err = injector.GetInstance(new(childParentIface)) + assert.NoError(t, err) + + // we can get an instance in child, because we have a binding here + i, err := child.GetInstance(new(childParentIface)) + assert.NoError(t, err) + + assert.NotNil(t, i.(childParentIface).child()) +} diff --git a/dingo_setup_test.go b/dingo_setup_test.go index 7c81505..aa720e6 100644 --- a/dingo_setup_test.go +++ b/dingo_setup_test.go @@ -25,13 +25,16 @@ func (s *setupT1) Inject(member1 string, annotated *struct { } func Test_Dingo_Setup(t *testing.T) { - injector := NewInjector() + injector, err := NewInjector() + assert.NoError(t, err) injector.Bind((*string)(nil)).ToInstance("Member 1") injector.Bind((*string)(nil)).AnnotatedWith("annotation2").ToInstance("Member 2") injector.Bind((*string)(nil)).AnnotatedWith("annotation3").ToInstance("Member 3") injector.Bind((*string)(nil)).AnnotatedWith("annotation4").ToInstance("Member 4") - test := injector.GetInstance((*setupT1)(nil)).(*setupT1) + i, err := injector.GetInstance((*setupT1)(nil)) + assert.NoError(t, err) + test := i.(*setupT1) assert.Equal(t, test.member1, "Member 1") assert.Equal(t, test.member2, "Member 2") diff --git a/dingo_test.go b/dingo_test.go index e0d1f78..d515ca5 100644 --- a/dingo_test.go +++ b/dingo_test.go @@ -22,16 +22,18 @@ type ( i int } - IfaceProvider func() Interface + IfaceProvider func() Interface + IfaceWithErrorProvider func() (Interface, error) DepTest struct { Iface Interface `inject:""` Iface2 Interface `inject:"test"` - IfaceProvider IfaceProvider `inject:""` - IfaceProvided Interface `inject:"provider"` - IfaceImpl1Provided Interface `inject:"providerimpl1"` - IfaceInstance Interface `inject:"instance"` + IfaceProvider IfaceProvider `inject:""` + IfaceWithErrorProvider IfaceWithErrorProvider `inject:""` + IfaceProvided Interface `inject:"provider"` + IfaceImpl1Provided Interface `inject:"providerimpl1"` + IfaceInstance Interface `inject:"instance"` } TestSingleton struct { @@ -77,20 +79,25 @@ func (if2 *InterfaceImpl2) Test() int { func TestDingoResolving(t *testing.T) { t.Run("Should resolve dependencies on request", func(t *testing.T) { - injector := NewInjector(new(PreTestModule), new(TestModule)) + injector, err := NewInjector(new(PreTestModule), new(TestModule)) + assert.NoError(t, err) + i, err := injector.GetInstance(new(Interface)) + assert.NoError(t, err) var iface Interface - iface = injector.GetInstance(new(Interface)).(Interface) + iface = i.(Interface) assert.Equal(t, 1, iface.Test()) - dt := *injector.GetInstance(new(DepTest)).(*DepTest) + i, err = injector.GetInstance(new(DepTest)) + assert.NoError(t, err) + dt := *i.(*DepTest) assert.Equal(t, 1, dt.Iface.Test()) assert.Equal(t, 2, dt.Iface2.Test()) var dt2 DepTest - injector.requestInjection(&dt2, nil) + assert.NoError(t, injector.requestInjection(&dt2, nil)) assert.Equal(t, 1, dt2.Iface.Test()) assert.Equal(t, 2, dt2.Iface2.Test()) @@ -100,17 +107,47 @@ func TestDingoResolving(t *testing.T) { assert.Equal(t, 2, dt.IfaceInstance.Test()) assert.Equal(t, 1, dt.IfaceProvider().Test()) + iface, err = dt.IfaceWithErrorProvider() + assert.NoError(t, err) + assert.Equal(t, 1, iface.Test()) assert.Equal(t, "Hello World", dt.IfaceProvided.(*InterfaceImpl1).foo) assert.Equal(t, "Hello World", dt.IfaceImpl1Provided.(*InterfaceImpl1).foo) }) t.Run("Should resolve scopes", func(t *testing.T) { - injector := NewInjector(new(TestModule)) + injector, err := NewInjector(new(TestModule)) + assert.NoError(t, err) + + i1, err := injector.GetInstance(TestSingleton{}) + assert.NoError(t, err) + i2, err := injector.GetInstance(TestSingleton{}) + assert.NoError(t, err) + assert.Equal(t, i1, i2) + }) - assert.Equal(t, injector.GetInstance(TestSingleton{}), injector.GetInstance(TestSingleton{})) + t.Run("Error cases", func(t *testing.T) { + var injector *Injector + _, err := injector.Child() + assert.Error(t, err) }) } +type testBoundNothingProvider func() *InterfaceImpl1 + +func TestBoundToNothing(t *testing.T) { + injector, err := NewInjector() + assert.NoError(t, err) + + injector.Bind(new(InterfaceImpl1)).AnnotatedWith("test") + + i, err := injector.GetInstance(new(testBoundNothingProvider)) + assert.NoError(t, err) + ii, ok := i.(testBoundNothingProvider) + assert.True(t, ok) + assert.NotNil(t, ii) + assert.NotNil(t, ii()) +} + // interceptors type ( AopInterface interface { @@ -154,10 +191,11 @@ func (a *AopInterceptor2) Test() string { } func TestInterceptors(t *testing.T) { - injector := NewInjector(new(AopModule)) + injector, err := NewInjector(new(AopModule)) + assert.NoError(t, err) var dep AopDep - injector.requestInjection(&dep, nil) + assert.NoError(t, injector.requestInjection(&dep, nil)) assert.Equal(t, "Test 1 2", dep.A.Test()) } @@ -169,21 +207,23 @@ func TestOptional(t *testing.T) { Optional2 string `inject:"option, optional"` } - injector := NewInjector() + injector, err := NewInjector() + assert.NoError(t, err) - assert.Panics(t, func() { - _ = injector.GetInstance(new(test)).(*test) - }, "should panic because `must` is unbound") + _, err = injector.GetInstance(new(test)) + assert.Error(t, err) injector.Bind(new(string)).AnnotatedWith("must").ToInstance("must") - i := injector.GetInstance(new(test)).(*test) - assert.Equal(t, i.Must, "must") - assert.Equal(t, i.Optional, "") - assert.Equal(t, i.Optional2, "") + i, err := injector.GetInstance(new(test)) + assert.NoError(t, err) + assert.Equal(t, i.(*test).Must, "must") + assert.Equal(t, i.(*test).Optional, "") + assert.Equal(t, i.(*test).Optional2, "") injector.Bind(new(string)).AnnotatedWith("option").ToInstance("option") - i = injector.GetInstance(new(test)).(*test) - assert.Equal(t, i.Must, "must") - assert.Equal(t, i.Optional, "option") - assert.Equal(t, i.Optional2, "option") + i, err = injector.GetInstance(new(test)) + assert.NoError(t, err) + assert.Equal(t, i.(*test).Must, "must") + assert.Equal(t, i.(*test).Optional, "option") + assert.Equal(t, i.(*test).Optional2, "option") } diff --git a/example/main.go b/example/main.go index 909a5de..53a31af 100644 --- a/example/main.go +++ b/example/main.go @@ -30,16 +30,22 @@ func (*defaultModule) Configure(injector *dingo.Injector) { func main() { // create a new injector and load modules - injector := dingo.NewInjector( + injector, err := dingo.NewInjector( new(paypal.Module), new(defaultModule), ) + if err != nil { + log.Fatal(err) + } // instantiate the application service - service := injector.GetInstance(application.Service{}).(*application.Service) + service, err := injector.GetInstance(application.Service{}) + if err != nil { + log.Fatal(err) + } // make a transaction - if err := service.MakeTransaction(99.95, "test transaction"); err != nil { + if err := service.(*application.Service).MakeTransaction(99.95, "test transaction"); err != nil { log.Fatal(err) } } diff --git a/go.mod b/go.mod index a910669..48cfced 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,9 @@ module flamingo.me/dingo -go 1.12 +go 1.13 require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/testify v1.2.2 + github.com/stretchr/testify v1.4.0 + golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f // indirect ) diff --git a/go.sum b/go.sum index e03ee77..168fe5a 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,23 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f h1:J5lckAjkw6qYlOZNj90mLYNTEKDvWeuc1yieZ8qUzUE= +golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f h1:kDxGY2VmgABOe55qheT/TFqUMtcTHnomIPS1iv3G4Ms= +golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/miniexample/main.go b/miniexample/main.go index ce3ae70..cb0ff63 100644 --- a/miniexample/main.go +++ b/miniexample/main.go @@ -23,13 +23,19 @@ func (*loggerModule) Configure(injector *dingo.Injector) { func main() { // create a new injector - injector := dingo.NewInjector( + injector, err := dingo.NewInjector( new(loggerModule), ) + if err != nil { + log.Fatal(err) + } // instantiate the log service - service := injector.GetInstance(logger.LogService{}).(*logger.LogService) + service, err := injector.GetInstance(logger.LogService{}) + if err != nil { + log.Fatal(err) + } // do a sample log using our service - service.DoLog("here is an example log") + service.(*logger.LogService).DoLog("here is an example log") } diff --git a/module.go b/module.go index 7c12dda..0fb6c06 100644 --- a/module.go +++ b/module.go @@ -30,10 +30,12 @@ func TryModule(modules ...Module) (resultingError error) { } }() - injector := NewInjector() + injector, err := NewInjector() + if err != nil { + return err + } injector.buildEagerSingletons = false - injector.InitModules(modules...) - return nil + return injector.InitModules(modules...) } // resolveDependencies tries to get a complete list of all modules, including all dependencies diff --git a/multi_dingo_test.go b/multi_dingo_test.go index 868fdc0..9752c13 100644 --- a/multi_dingo_test.go +++ b/multi_dingo_test.go @@ -36,13 +36,16 @@ type ( ) func TestMultiBinding(t *testing.T) { - injector := NewInjector() + injector, err := NewInjector() + assert.NoError(t, err) injector.BindMulti((*mapBindInterface)(nil)).ToInstance("testkey instance") injector.BindMulti((*mapBindInterface)(nil)).ToInstance("testkey2 instance") injector.BindMulti((*mapBindInterface)(nil)).ToInstance("testkey3 instance") - test := injector.GetInstance(&multiBindTest{}).(*multiBindTest) + i, err := injector.GetInstance(&multiBindTest{}) + assert.NoError(t, err) + test := i.(*multiBindTest) list := test.Mb assert.Len(t, list, 3) @@ -53,16 +56,20 @@ func TestMultiBinding(t *testing.T) { } func TestMultiBindingChild(t *testing.T) { - injector := NewInjector() + injector, err := NewInjector() + assert.NoError(t, err) injector.BindMulti((*mapBindInterface)(nil)).ToInstance("testkey instance") injector.BindMulti((*mapBindInterface)(nil)).ToInstance("testkey2 instance") injector.BindMulti((*mapBindInterface)(nil)).ToInstance("testkey3 instance") - child := injector.Child() + child, err := injector.Child() + assert.NoError(t, err) child.BindMulti((*mapBindInterface)(nil)).ToInstance("testkey4 instance") - test := injector.GetInstance(&multiBindTest{}).(*multiBindTest) + i, err := injector.GetInstance(&multiBindTest{}) + assert.NoError(t, err) + test := i.(*multiBindTest) list := test.Mb assert.Len(t, list, 3) @@ -71,7 +78,9 @@ func TestMultiBindingChild(t *testing.T) { assert.Equal(t, "testkey2 instance", list[1]) assert.Equal(t, "testkey3 instance", list[2]) - test = child.GetInstance(&multiBindTest{}).(*multiBindTest) + i, err = child.GetInstance(&multiBindTest{}) + assert.NoError(t, err) + test = i.(*multiBindTest) list = test.Mb assert.Len(t, list, 4) @@ -83,13 +92,16 @@ func TestMultiBindingChild(t *testing.T) { } func TestMultiBindingProvider(t *testing.T) { - injector := NewInjector() + injector, err := NewInjector() + assert.NoError(t, err) injector.BindMulti((*mapBindInterface)(nil)).ToInstance("testkey instance") injector.BindMulti((*mapBindInterface)(nil)).ToInstance("testkey2 instance") injector.BindMulti((*mapBindInterface)(nil)).ToInstance("testkey3 instance") - test := injector.GetInstance(&multiBindProviderTest{}).(*multiBindProviderTest) + i, err := injector.GetInstance(&multiBindProviderTest{}) + assert.NoError(t, err) + test := i.(*multiBindProviderTest) list := test.Mbp() assert.Len(t, list, 3) @@ -100,13 +112,16 @@ func TestMultiBindingProvider(t *testing.T) { } func TestMultiBindingComplex(t *testing.T) { - injector := NewInjector() + injector, err := NewInjector() + assert.NoError(t, err) injector.BindMulti((*mapBindInterface)(nil)).ToInstance("testkey instance") injector.BindMulti((*mapBindInterface)(nil)).To("testkey2 instance") injector.BindMulti((*mapBindInterface)(nil)).ToProvider(func() mapBindInterface { return "provided" }) - test := injector.GetInstance(&multiBindTest{}).(*multiBindTest) + i, err := injector.GetInstance(&multiBindTest{}) + assert.NoError(t, err) + test := i.(*multiBindTest) list := test.Mb assert.Len(t, list, 3) @@ -117,13 +132,16 @@ func TestMultiBindingComplex(t *testing.T) { } func TestMultiBindingComplexProvider(t *testing.T) { - injector := NewInjector() + injector, err := NewInjector() + assert.NoError(t, err) injector.BindMulti((*mapBindInterface)(nil)).ToInstance("testkey instance") injector.BindMulti((*mapBindInterface)(nil)).To("testkey2 instance") injector.BindMulti((*mapBindInterface)(nil)).ToProvider(func() mapBindInterface { return "provided" }) - test := injector.GetInstance(&multiBindProviderTest{}).(*multiBindProviderTest) + i, err := injector.GetInstance(&multiBindProviderTest{}) + assert.NoError(t, err) + test := i.(*multiBindProviderTest) list := test.Mbp() assert.Len(t, list, 3) @@ -134,13 +152,16 @@ func TestMultiBindingComplexProvider(t *testing.T) { } func TestMapBinding(t *testing.T) { - injector := NewInjector() + injector, err := NewInjector() + assert.NoError(t, err) injector.BindMap((*mapBindInterface)(nil), "testkey").ToInstance("testkey instance") injector.BindMap((*mapBindInterface)(nil), "testkey2").ToInstance("testkey2 instance") injector.BindMap((*mapBindInterface)(nil), "testkey3").ToInstance("testkey3 instance") - test1 := injector.GetInstance(&mapBindTest1{}).(*mapBindTest1) + i, err := injector.GetInstance(&mapBindTest1{}) + assert.NoError(t, err) + test1 := i.(*mapBindTest1) test1map := test1.Mbp() assert.Len(t, test1map, 3) @@ -148,21 +169,27 @@ func TestMapBinding(t *testing.T) { assert.Equal(t, "testkey2 instance", test1map["testkey2"]) assert.Equal(t, "testkey3 instance", test1map["testkey3"]) - test2 := injector.GetInstance(&mapBindTest2{}).(*mapBindTest2) + i, err = injector.GetInstance(&mapBindTest2{}) + assert.NoError(t, err) + test2 := i.(*mapBindTest2) assert.Equal(t, test2.Mb, "testkey instance") } func TestMapBindingChild(t *testing.T) { - injector := NewInjector() + injector, err := NewInjector() + assert.NoError(t, err) injector.BindMap((*mapBindInterface)(nil), "testkey").ToInstance("testkey instance") injector.BindMap((*mapBindInterface)(nil), "testkey2").ToInstance("testkey2 instance") injector.BindMap((*mapBindInterface)(nil), "testkey3").ToInstance("testkey3 instance") - child := injector.Child() + child, err := injector.Child() + assert.NoError(t, err) child.BindMap((*mapBindInterface)(nil), "testkey4").ToInstance("testkey4 instance") - test1 := injector.GetInstance(&mapBindTest1{}).(*mapBindTest1) + i, err := injector.GetInstance(&mapBindTest1{}) + assert.NoError(t, err) + test1 := i.(*mapBindTest1) test1map := test1.Mbp() assert.Len(t, test1map, 3) @@ -170,10 +197,14 @@ func TestMapBindingChild(t *testing.T) { assert.Equal(t, "testkey2 instance", test1map["testkey2"]) assert.Equal(t, "testkey3 instance", test1map["testkey3"]) - test2 := injector.GetInstance(&mapBindTest2{}).(*mapBindTest2) + i, err = injector.GetInstance(&mapBindTest2{}) + assert.NoError(t, err) + test2 := i.(*mapBindTest2) assert.Equal(t, test2.Mb, "testkey instance") - testChild := child.GetInstance(&mapBindTest1{}).(*mapBindTest1) + i, err = child.GetInstance(&mapBindTest1{}) + assert.NoError(t, err) + testChild := i.(*mapBindTest1) testChildmap := testChild.Mbp() assert.Len(t, testChildmap, 4) @@ -184,13 +215,16 @@ func TestMapBindingChild(t *testing.T) { } func TestMapBindingProvider(t *testing.T) { - injector := NewInjector() + injector, err := NewInjector() + assert.NoError(t, err) injector.BindMap((*mapBindInterface)(nil), "testkey").ToInstance("testkey instance") injector.BindMap((*mapBindInterface)(nil), "testkey2").ToInstance("testkey2 instance") injector.BindMap((*mapBindInterface)(nil), "testkey3").ToInstance("testkey3 instance") - test := injector.GetInstance(&mapBindTest3{}).(*mapBindTest3) + i, err := injector.GetInstance(&mapBindTest3{}) + assert.NoError(t, err) + test := i.(*mapBindTest3) testmap := test.Mbp() assert.Len(t, testmap, 3) diff --git a/scope.go b/scope.go index 21e6c3f..ab0262e 100644 --- a/scope.go +++ b/scope.go @@ -8,7 +8,7 @@ import ( type ( // Scope defines a scope's behaviour Scope interface { - ResolveType(t reflect.Type, annotation string, unscoped func(t reflect.Type, annotation string, optional bool) reflect.Value) reflect.Value + ResolveType(t reflect.Type, annotation string, unscoped func(t reflect.Type, annotation string, optional bool) (reflect.Value, error)) (reflect.Value, error) } identifier struct { @@ -47,7 +47,7 @@ func NewChildSingletonScope() *ChildSingletonScope { } // ResolveType resolves a request in this scope -func (s *SingletonScope) ResolveType(t reflect.Type, annotation string, unscoped func(t reflect.Type, annotation string, optional bool) reflect.Value) reflect.Value { +func (s *SingletonScope) ResolveType(t reflect.Type, annotation string, unscoped func(t reflect.Type, annotation string, optional bool) (reflect.Value, error)) (reflect.Value, error) { ident := identifier{t, annotation} // try to get the instance type lock @@ -60,7 +60,7 @@ func (s *SingletonScope) ResolveType(t reflect.Type, annotation string, unscoped defer l.RUnlock() instance, _ := s.instances.Load(ident) - return instance.(reflect.Value) + return instance.(reflect.Value), nil } s.instanceLock[ident] = new(sync.RWMutex) @@ -68,15 +68,15 @@ func (s *SingletonScope) ResolveType(t reflect.Type, annotation string, unscoped l.Lock() s.mu.Unlock() - instance := unscoped(t, annotation, false) + instance, err := unscoped(t, annotation, false) s.instances.Store(ident, instance) defer l.Unlock() - return instance + return instance, err } // ResolveType delegates to SingletonScope.ResolveType -func (c *ChildSingletonScope) ResolveType(t reflect.Type, annotation string, unscoped func(t reflect.Type, annotation string, optional bool) reflect.Value) reflect.Value { +func (c *ChildSingletonScope) ResolveType(t reflect.Type, annotation string, unscoped func(t reflect.Type, annotation string, optional bool) (reflect.Value, error)) (reflect.Value, error) { return (*SingletonScope)(c).ResolveType(t, annotation, unscoped) } diff --git a/scope_test.go b/scope_test.go index fd494f5..e8396b6 100644 --- a/scope_test.go +++ b/scope_test.go @@ -17,15 +17,15 @@ func testScope(t *testing.T, scope Scope) { test := reflect.TypeOf("string") test2 := reflect.TypeOf("int") - unscoped := func(t reflect.Type, annotation string, optional bool) reflect.Value { + unscoped := func(t reflect.Type, annotation string, optional bool) (reflect.Value, error) { atomic.AddInt64(&requestedUnscoped, 1) time.Sleep(1 * time.Nanosecond) if optional { - return reflect.Value{} + return reflect.Value{}, nil } - return reflect.New(t).Elem() + return reflect.New(t).Elem(), nil } runs := 100 // change to 10? 100? 1000? to trigger a bug? todo investigate @@ -34,10 +34,14 @@ func testScope(t *testing.T, scope Scope) { wg.Add(runs) for i := 0; i < runs; i++ { go func() { - t1 := scope.ResolveType(test, "", unscoped) - t12 := scope.ResolveType(test2, "", unscoped) - t2 := scope.ResolveType(test, "", unscoped) - t22 := scope.ResolveType(test2, "", unscoped) + t1, err := scope.ResolveType(test, "", unscoped) + assert.NoError(t, err) + t12, err := scope.ResolveType(test2, "", unscoped) + assert.NoError(t, err) + t2, err := scope.ResolveType(test, "", unscoped) + assert.NoError(t, err) + t22, err := scope.ResolveType(test2, "", unscoped) + assert.NoError(t, err) assert.Equal(t, t1, t2) assert.Equal(t, t12, t22) wg.Done() @@ -78,7 +82,8 @@ type ( func TestScopeWithSubDependencies(t *testing.T) { for i := 0; i < 10; i++ { t.Run(fmt.Sprintf("Run %d", i), func(t *testing.T) { - injector := NewInjector() + injector, err := NewInjector() + assert.NoError(t, err) injector.Bind(new(singletonA)).In(Singleton) injector.Bind(new(singletonB)).In(Singleton) @@ -90,7 +95,9 @@ func TestScopeWithSubDependencies(t *testing.T) { wg.Add(runs) for i := 0; i < runs; i++ { go func() { - a := injector.GetInstance(new(singletonA)).(*singletonA) + i, err := injector.GetInstance(new(singletonA)) + assert.NoError(t, err) + a := i.(*singletonA) assert.Equal(t, a.B.C, singletonC("singleton C")) wg.Done() }()