Skip to content

Commit

Permalink
Improved resolution of generic types
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolayPianikov committed Oct 25, 2023
1 parent 0980d3a commit 5c9df17
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 44 deletions.
16 changes: 0 additions & 16 deletions src/Pure.DI.Core/Core/ContractsBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,6 @@ namespace Pure.DI.Core;

internal sealed class ContractsBuilder: IBuilder<ContractsBuildContext, ISet<Injection>>
{
private readonly IMarker _marker;
private readonly IUnboundTypeConstructor _unboundTypeConstructor;

public ContractsBuilder(
IMarker marker,
IUnboundTypeConstructor unboundTypeConstructor)
{
_marker = marker;
_unboundTypeConstructor = unboundTypeConstructor;
}

public ISet<Injection> Build(ContractsBuildContext context)
{
var binding = context.Binding;
Expand All @@ -28,11 +17,6 @@ public ISet<Injection> Build(ContractsBuildContext context)
foreach (var contract in binding.Contracts)
{
var contractType = contract.ContractType;
if (_marker.IsMarkerBased(contractType))
{
contractType = _unboundTypeConstructor.Construct(binding.SemanticModel.Compilation, contractType);
}

var contractTags = new HashSet<object?>(bindingTags);
foreach (var tag in contract.Tags)
{
Expand Down
44 changes: 32 additions & 12 deletions src/Pure.DI.Core/Core/DependencyGraphBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,36 @@ public IEnumerable<DependencyNode> TryBuild(
case INamedTypeSymbol { IsGenericType: true } geneticType:
{
// Generic
var unboundType = _unboundTypeConstructor.Construct(targetNode.Binding.SemanticModel.Compilation, injection.Type);
var unboundInjection = injection with { Type = unboundType };
if (map.TryGetValue(unboundInjection, out sourceNode))
var isGenericOk = false;
foreach (var item in map)
{
if (!Injection.EqualTags(injection.Tag, item.Key.Tag))
{
continue;
}

if (item.Key.Type is not INamedTypeSymbol { IsGenericType: true })
{
continue;
}

var typeConstructor = _typeConstructorFactory();
if (!typeConstructor.TryBind(item.Key.Type, injection.Type))
{
continue;
}

sourceNode = item.Value;
var genericBinding = CreateGenericBinding(targetNode, injection, sourceNode, typeConstructor, ++maxId);
var genericNode = CreateNodes(setup, genericBinding).Single(i => i.Variation == sourceNode.Variation);
map[injection] = genericNode;
queue.Enqueue(CreateNewProcessingNode(injection, genericNode));
isGenericOk = true;
break;
}

if (isGenericOk)
{
var newBinding = CreateGenericBinding(targetNode, injection, sourceNode, ++maxId);
var newNode = CreateNodes(setup, newBinding)
.Single(i => i.Variation == sourceNode.Variation);
map[injection] = newNode;
queue.Enqueue(CreateNewProcessingNode(injection, newNode));
continue;
}

Expand Down Expand Up @@ -260,16 +281,15 @@ private MdBinding CreateGenericBinding(
DependencyNode targetNode,
Injection injection,
DependencyNode sourceNode,
ITypeConstructor typeConstructor,
int newId)
{
var semanticModel = targetNode.Binding.SemanticModel;
var compilation = semanticModel.Compilation;
var typeConstructor = _typeConstructorFactory();
typeConstructor.Bind(sourceNode.Type, injection.Type);
var newContracts = sourceNode.Binding.Contracts
.Select(contract => contract with
{
ContractType = typeConstructor.Construct(semanticModel.Compilation, contract.ContractType),
ContractType = injection.Type,
Tags = contract.Tags.Select( tag => CreateTag(injection, tag)).Where(tag => tag.HasValue).Select(tag => tag!.Value).ToImmutableArray()
})
.ToImmutableArray();
Expand Down Expand Up @@ -309,7 +329,7 @@ private MdBinding CreateAutoBinding(
var typeConstructor = _typeConstructorFactory();
if (_marker.IsMarkerBased(injection.Type))
{
typeConstructor.Bind(injection.Type, injection.Type);
typeConstructor.TryBind(injection.Type, injection.Type);
sourceType = typeConstructor.Construct(compilation, injection.Type);
}

Expand Down
2 changes: 1 addition & 1 deletion src/Pure.DI.Core/Core/ITypeConstructor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ namespace Pure.DI.Core;

internal interface ITypeConstructor
{
void Bind(ITypeSymbol source, ITypeSymbol target);
bool TryBind(ITypeSymbol source, ITypeSymbol target);

ITypeSymbol Construct(Compilation compilation, ITypeSymbol type);
}
2 changes: 1 addition & 1 deletion src/Pure.DI.Core/Core/Models/Injection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public bool Equals(Injection other) =>
public override int GetHashCode() =>
SymbolEqualityComparer.Default.GetHashCode(Type);

private static bool EqualTags(object? tag, object? otherTag)
public static bool EqualTags(object? tag, object? otherTag)
{
if (ReferenceEquals(tag, MdTag.ContextTag))
{
Expand Down
69 changes: 55 additions & 14 deletions src/Pure.DI.Core/Core/TypeConstructor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,42 @@ namespace Pure.DI.Core;
internal sealed class TypeConstructor : ITypeConstructor
{
private readonly IMarker _marker;
private readonly Dictionary<ITypeSymbol,ITypeSymbol> _map = new(SymbolEqualityComparer.Default);
private readonly Dictionary<ITypeSymbol, ITypeSymbol> _map = new(SymbolEqualityComparer.Default);

public TypeConstructor(IMarker marker) => _marker = marker;

public void Bind(ITypeSymbol source, ITypeSymbol target)
public bool TryBind(ITypeSymbol source, ITypeSymbol target)
{
if (_marker.IsMarker(source))
{
_map[source] = target;
return;
return true;
}


var result = true;
switch (source)
{
case INamedTypeSymbol sourceNamedType when target is INamedTypeSymbol targetNamedType:
{
if (!sourceNamedType.IsGenericType && SymbolEqualityComparer.Default.Equals(source, target))
if (!SymbolEqualityComparer.Default.Equals(source.OriginalDefinition, target.OriginalDefinition))
{
return false;
}

if (!sourceNamedType.IsGenericType)
{
return;
return SymbolEqualityComparer.Default.Equals(source, target);
}

if (_map.ContainsKey(source))
{
return;
return true;
}

if (_marker.IsMarker(source))
{
_map[source] = target;
return;
return true;
}

// Constructed generic
Expand All @@ -51,31 +57,66 @@ public void Bind(ITypeSymbol source, ITypeSymbol target)
{
for (var i = 0; i < sourceArgs.Length; i++)
{
Bind(sourceArgs[i], targetArgs[i]);
result &= TryBind(sourceArgs[i], targetArgs[i]);
if (!result)
{
break;
}
}
}

return;
else
{
result = false;
}
}
}

break;
}

case IArrayTypeSymbol sourceArrayType when target is IArrayTypeSymbol targetArrayType:
Bind(sourceArrayType.ElementType, targetArrayType.ElementType);
result &= result && TryBind(sourceArrayType.ElementType, targetArrayType.ElementType);
break;

default:
result &= result && SymbolEqualityComparer.Default.Equals(source.OriginalDefinition, target.OriginalDefinition);
break;
}

if (!result)
{
return result;
}

foreach (var implementationInterfaceType in target.Interfaces)
{
Bind(source, implementationInterfaceType);
if (!SymbolEqualityComparer.Default.Equals(source.OriginalDefinition, implementationInterfaceType.OriginalDefinition))
{
continue;
}

result &= TryBind(source, implementationInterfaceType);
if (!result)
{
break;
}
}

foreach (var dependencyInterfaceType in source.Interfaces)
{
Bind(dependencyInterfaceType, target);
if (!SymbolEqualityComparer.Default.Equals(target.OriginalDefinition, dependencyInterfaceType.OriginalDefinition))
{
continue;
}

result &= TryBind(dependencyInterfaceType, target);
if (!result)
{
break;
}
}

return result;
}

public ITypeSymbol Construct(Compilation compilation, ITypeSymbol type)
Expand Down
52 changes: 52 additions & 0 deletions tests/Pure.DI.IntegrationTests/SetupTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,58 @@ public static void Main()
result.Warnings.Count(i => i.Id == LogId.WarningOverriddenBinding).ShouldBe(1);
}

[Fact]
public async Task ShouldNotOverrideBindingForDifferentMarkerTypes()
{
// Given

// When
var result = await """
using System;
using System.Collections.Generic;
using Pure.DI;

namespace Sample
{
static class Setup
{
private static void SetupComposition()
{
DI.Setup("Composition")
.Bind<Func<char, IList<TT1>>>().To<Func<char, IList<TT1>>>(_ => ch => new List<TT1>())
.Bind<Func<byte, TT1[]>>().To<Func<byte, TT1[]>>(_ => ch => new TT1[0])
.Bind<Func<int, IList<TT2>>>().To<Func<int, IList<TT2>>>(_ => ch => new List<TT2>())
.Root<Func<char, IList<string>>>("Root1")
.Root<Func<byte, int[]>>("Root2")
.Root<Func<int, IList<char>>>("Root3");
}
}

public class Program
{
public static void Main()
{
var composition = new Composition();
Console.WriteLine(composition.Root1.GetType());
Console.WriteLine(composition.Root2.GetType());
Console.WriteLine(composition.Root3.GetType());
}
}
}
""".RunAsync();

// Then
result.Success.ShouldBeTrue(result);
result.StdOut.ShouldBe(
ImmutableArray.Create(
"System.Func`2[System.Char,System.Collections.Generic.IList`1[System.String]]",
"System.Func`2[System.Byte,System.Int32[]]",
"System.Func`2[System.Int32,System.Collections.Generic.IList`1[System.Char]]"),
result);
result.Warnings.Count.ShouldBe(0);
result.Warnings.Count(i => i.Id == LogId.WarningOverriddenBinding).ShouldBe(0);
}

[Fact]
public async Task ShouldOverrideGlobalBindingWithoutWarning()
{
Expand Down

0 comments on commit 5c9df17

Please sign in to comment.