diff --git a/inject.go b/inject.go index d04d9bd..3ff713c 100644 --- a/inject.go +++ b/inject.go @@ -157,10 +157,29 @@ func (i *injector) Set(typ reflect.Type, val reflect.Value) TypeMapper { func (i *injector) Get(t reflect.Type) reflect.Value { val := i.values[t] + + if val.IsValid() { + return val + } + + // no concrete types found, try to find implementors + // if t is an interface + if t.Kind() == reflect.Interface { + for k, v := range i.values { + if k.Implements(t) { + val = v + break + } + } + } + + // Still no type found, try to look it up on the parent if !val.IsValid() && i.parent != nil { val = i.parent.Get(t) } + return val + } func (i *injector) SetParent(parent Injector) { diff --git a/inject_test.go b/inject_test.go index 419a328..b436dd0 100644 --- a/inject_test.go +++ b/inject_test.go @@ -1,6 +1,7 @@ package inject_test import ( + "fmt" "github.com/codegangsta/inject" "reflect" "testing" @@ -15,6 +16,14 @@ type TestStruct struct { Dep3 string } +type Greeter struct { + Name string +} + +func (g *Greeter) String() string { + return "Hello, My name is" + g.Name +} + /* Test Helpers */ func expect(t *testing.T, a interface{}, b interface{}) { if a != b { @@ -140,3 +149,11 @@ func Test_InjectorSetParent(t *testing.T) { expect(t, injector2.Get(inject.InterfaceOf((*SpecialString)(nil))).IsValid(), true) } + +func TestInjectImplementors(t *testing.T) { + injector := inject.New() + g := &Greeter{"Jeremy"} + injector.Map(g) + + expect(t, injector.Get(inject.InterfaceOf((*fmt.Stringer)(nil))).IsValid(), true) +}