diff --git a/devpkg/injectablegen/injectable.go b/devpkg/injectablegen/injectable.go index d71561b..35241b2 100644 --- a/devpkg/injectablegen/injectable.go +++ b/devpkg/injectablegen/injectable.go @@ -14,9 +14,9 @@ func init() { } type injectableGen struct { - publicProviderInterface *types.Interface - publicInitInterface *types.Interface - once sync.Once + publicInjectContextInterface *types.Interface + publicInitInterface *types.Interface + once sync.Once } func (*injectableGen) Name() string { @@ -27,7 +27,7 @@ func (g *injectableGen) init(c gengo.Context) { { sig := c.Package("context").Function("Cause").Signature() - g.publicProviderInterface = types.NewInterfaceType([]*types.Func{ + g.publicInjectContextInterface = types.NewInterfaceType([]*types.Func{ types.NewFunc(0, c.Package("context").Pkg(), "InjectContext", types.NewSignatureType(nil, nil, nil, types.NewTuple(sig.Params().At(0)), @@ -175,6 +175,18 @@ ctx = @FieldType'InjectContext(ctx, p.@Field) }) } } + + if !exists { + if g.hasPublicInjectContext(c, f.Type()) { + sw.Render(gengo.Snippet{ + gengo.T: ` +ctx = p.@Field.InjectContext(ctx) +`, + "Field": gengo.ID(f.Name()), + }) + continue + } + } } } @@ -359,16 +371,20 @@ if err := v.afterInit(ctx); err != nil { return nil } -func (g *injectableGen) isInjectable(c gengo.Context, t types.Type) bool { +func (g *injectableGen) hasPublicInjectContext(c gengo.Context, t types.Type) bool { switch x := t.(type) { case *types.Pointer: - return g.isInjectable(c, x.Elem()) + return g.hasPublicInjectContext(c, x.Elem()) case *types.Named: + _, isStruct := x.Underlying().(*types.Struct) + if !isStruct { + return false + } tags, _ := c.Doc(x.Obj()) - if _, ok := tags["gengo:injectable"]; ok { - return true + if _, ok := tags["gengo:injectable:provider"]; ok { + return ok } - return types.Implements(x, g.publicProviderInterface) || types.Implements(types.NewPointer(x), g.publicProviderInterface) + return types.Implements(x, g.publicInjectContextInterface) || types.Implements(types.NewPointer(x), g.publicInjectContextInterface) } return false @@ -379,10 +395,15 @@ func (g *injectableGen) hasPublicInit(c gengo.Context, t types.Type) bool { case *types.Pointer: return g.hasPublicInit(c, x.Elem()) case *types.Named: + _, isStruct := x.Underlying().(*types.Struct) + if !isStruct { + return false + } tags, _ := c.Doc(x.Obj()) - if _, ok := tags["gengo:injectable:provider"]; ok { - _, ok := x.Obj().Type().(*types.Struct) - return ok + _, injectable := tags["gengo:injectable"] + _, injectableProvider := tags["gengo:injectable:provider"] + if injectable || injectableProvider { + return true } return types.Implements(x, g.publicInitInterface) || types.Implements(types.NewPointer(x), g.publicInitInterface) }