Skip to content

Commit

Permalink
Fixed code generation issue when IEnumerable<T> is injected inside …
Browse files Browse the repository at this point in the history
…a factory
  • Loading branch information
Nikolay Pyanikov committed Aug 4, 2023
1 parent d1cb43e commit 8867150
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 14 deletions.
19 changes: 12 additions & 7 deletions src/Pure.DI.Core/Core/Code/CompositionBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ protected override void VisitImplementation(
in DpImplementation implementation,
CancellationToken cancellationToken)
{
if (!instantiation.Target.IsCreationRequired(instantiation.Target.Node))
{
return;
}

base.VisitImplementation(context, root, instantiation, implementation, cancellationToken);
AddReturnStatement(context, root, instantiation);
instantiation.Target.IsCreated = true;
Expand Down Expand Up @@ -218,7 +223,7 @@ protected override void VisitEnumerableConstruct(
Instantiation instantiation,
CancellationToken cancellationToken)
{
if (instantiation.Target.IsCreated)
if (!instantiation.Target.IsCreationRequired(instantiation.Target.Node))
{
return;
}
Expand All @@ -240,7 +245,7 @@ protected override void VisitEnumerableConstruct(

if (arg.Node.Lifetime != Lifetime.PerResolve)
{
arg.IsCreated = false;
arg.AllowCreation();
VisitRootVariable(context with { IsRootContext = false }, dependencyGraph, context.Variables, arg, cancellationToken);
}

Expand Down Expand Up @@ -269,7 +274,7 @@ protected override void VisitArrayConstruct(
in DpConstruct construct,
Instantiation instantiation)
{
if (instantiation.Target.IsCreated)
if (!instantiation.Target.IsCreationRequired(instantiation.Target.Node))
{
return;
}
Expand All @@ -289,7 +294,7 @@ protected override void VisitSpanConstruct(
in DpConstruct construct,
Instantiation instantiation)
{
if (instantiation.Target.IsCreated)
if (!instantiation.Target.IsCreationRequired(instantiation.Target.Node))
{
return;
}
Expand All @@ -315,7 +320,7 @@ protected override void VisitCompositionConstruct(
in DpConstruct construct,
Instantiation instantiation)
{
if (instantiation.Target.IsCreated)
if (!instantiation.Target.IsCreationRequired(instantiation.Target.Node))
{
return;
}
Expand All @@ -327,7 +332,7 @@ protected override void VisitCompositionConstruct(

protected override void VisitOnCannotResolve(BuildContext context, Variable root, in DpConstruct construct, Instantiation instantiation)
{
if (instantiation.Target.IsCreated)
if (!instantiation.Target.IsCreationRequired(instantiation.Target.Node))
{
return;
}
Expand Down Expand Up @@ -357,7 +362,7 @@ protected override void VisitFactory(
in DpFactory factory,
CancellationToken cancellationToken)
{
if (instantiation.Target.IsCreated)
if (!instantiation.Target.IsCreationRequired(instantiation.Target.Node))
{
return;
}
Expand Down
10 changes: 4 additions & 6 deletions src/Pure.DI.Core/Core/CodeGraphWalker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ void ProcessVariable(Variable var)
if (targetVariable.Node.Construct is { Source.Kind: MdConstructKind.Enumerable })
{
// Will be created lazy in a local function
var.IsCreated = var.Node.Lifetime != Lifetime.PerResolve;
if (var.Node.Lifetime != Lifetime.PerResolve)
{
var.Owner = targetVariable.Node;
}
}

arguments.Add(var);
Expand Down Expand Up @@ -194,11 +197,6 @@ protected virtual void VisitImplementation(
in DpImplementation implementation,
CancellationToken cancellationToken)
{
if (instantiation.Target.IsCreated)
{
return;
}

var args = instantiation.Arguments.ToList();
var argsWalker = new DependenciesToVariablesWalker(args);
argsWalker.VisitConstructor(implementation.Constructor);
Expand Down
13 changes: 13 additions & 0 deletions src/Pure.DI.Core/Core/Models/Variable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ public string Name

public virtual bool IsCreated { get; set; }

public virtual DependencyNode? Owner { get; set; }

public bool IsCreationRequired(in DependencyNode node) =>
!IsCreated && (!Owner.HasValue || Owner.Equals(node));

public void AllowCreation() => Owner = default;

public virtual bool IsBlockRoot { get; init; }

public override string ToString() => Name;
Expand All @@ -74,6 +81,12 @@ public override bool IsCreated
set => _variable.IsCreated = value;
}

public override DependencyNode? Owner
{
get => _variable.Owner;
set => _variable.Owner = value;
}

public override bool IsBlockRoot => _variable.IsBlockRoot;

public override string ToString() => _variable.ToString();
Expand Down
81 changes: 81 additions & 0 deletions tests/Pure.DI.IntegrationTests/EnumerableInjectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -270,4 +270,85 @@ public static void Main()
result.Success.ShouldBeTrue(result);
result.StdOut.ShouldBe(ImmutableArray.Create("Dependency created", "Service creating", "Dependency created", "Dependency created"), result);
}

[Fact]
public async Task ShouldSupportEnumerableWhenPerResolveInFunc()
{
// Given

// When
var result = await """
using System;
using System.Collections.Generic;
using System.Linq;
using Pure.DI;
namespace Sample
{
interface IDependency
{
int Id { get; }
}
class Dependency: IDependency
{
private static int _nextId;
public Dependency()
{
Id = _nextId++;
Console.WriteLine($"Dependency {_nextId} created");
}
public int Id { get; set; }
}
interface IService
{
}
class Service: IService
{
public Service(IEnumerable<int> deps)
{
Console.WriteLine("Service creating");
foreach (var dep in deps)
{
}
}
}
static class Setup
{
private static void SetupComposition()
{
// FormatCode = On
DI.Setup("Composition")
.Bind<IDependency>(1).To<Dependency>()
.Bind<IDependency>(2).As(Lifetime.PerResolve).To<Dependency>()
.Bind<IDependency>(3).To<Dependency>()
.Bind<IEnumerable<int>>().To(ctx =>
{
ctx.Inject(out IEnumerable<IDependency> dependencies);
return dependencies.Select(i => i.Id);
})
.Bind<IService>().To<Service>()
.Root<IService>("Service");
}
}
public class Program
{
public static void Main()
{
var composition = new Composition();
var service = composition.Service;
}
}
}
""".RunAsync();

// Then
result.Success.ShouldBeTrue(result);
result.StdOut.ShouldBe(ImmutableArray.Create("Dependency 1 created", "Service creating", "Dependency 2 created", "Dependency 3 created"), result);
}
}
2 changes: 1 addition & 1 deletion tests/Pure.DI.IntegrationTests/TestExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ internal static async Task<Result> RunAsync(this string setupCode, Options? opti
.WithOptions(new CSharpCompilationOptions(OutputKind.ConsoleApplication).WithNullableContextOptions(runOptions.NullableContextOptions))
.AddSyntaxTrees(generatedApiSources.Select(api => CSharpSyntaxTree.ParseText(api.SourceText, parseOptions)))
.AddSyntaxTrees(CSharpSyntaxTree.ParseText(setupCode, parseOptions));
//.Check(stdOut, options);
// .Check(stdOut, options);

var globalOptions = new TestAnalyzerConfigOptions(new Dictionary<string, string>
{
Expand Down

0 comments on commit 8867150

Please sign in to comment.