Skip to content

Commit

Permalink
Fix throw in generator comparer (#76769)
Browse files Browse the repository at this point in the history
* Add a test that fails

* Clean up strategies that always have a comparer

* Cleanup comparers in nodes:

All nodes either passed in a default comparer or null, which was stored in the node table. Modifications then had the option to supply a seperate comparer to do the modification. In all cases the comparer passed into the modify call was the same as the one passed when creating the table, so we can remove the call from modify and just use the table instance. Rather than each node passing it its own default when it doesn't have a comparer, just pass in null and let the table control creating the default.

* Use a wrapped comparer as the default comparer
  • Loading branch information
chsienki authored Jan 23, 2025
1 parent 66ecfa6 commit 3a8c9a8
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,54 @@ [Attr] class D { }
Assert.Equal(e, runResults.Results.Single().Exception);
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/76765")]
public void Incremental_Generators_Exception_In_DefaultComparer()
{
var source = """
class C { }
""";
var parseOptions = TestOptions.RegularPreview;
Compilation compilation = CreateCompilation(source, options: TestOptions.DebugDllThrowing, parseOptions: parseOptions);
compilation.VerifyDiagnostics();

var syntaxTree = compilation.SyntaxTrees.Single();

var e = new InvalidOperationException("abc");
var generator = new PipelineCallbackGenerator((ctx) =>
{
var name = ctx.CompilationProvider.Select((c, _) => new ThrowWhenEqualsItem(e));
ctx.RegisterSourceOutput(name, (spc, n) => spc.AddSource("item.cs", "// generated"));
});

GeneratorDriver driver = CSharpGeneratorDriver.Create([generator.AsSourceGenerator()], parseOptions: parseOptions);
driver = driver.RunGenerators(compilation);
var runResults = driver.GetRunResult();

Assert.Empty(runResults.Diagnostics);
Assert.Equal("// generated", runResults.Results.Single().GeneratedSources.Single().SourceText.ToString());

compilation = compilation.ReplaceSyntaxTree(syntaxTree, CSharpSyntaxTree.ParseText("""
class D { }
""", parseOptions));
compilation.VerifyDiagnostics();

driver = driver.RunGenerators(compilation);
runResults = driver.GetRunResult();

VerifyGeneratorExceptionDiagnostic<InvalidOperationException>(runResults.Diagnostics.Single(), nameof(PipelineCallbackGenerator), "abc");
Assert.Empty(runResults.GeneratedTrees);
Assert.Equal(e, runResults.Results.Single().Exception);
}

class ThrowWhenEqualsItem(Exception toThrow)
{
readonly Exception _toThrow = toThrow;

public override bool Equals(object? obj) => throw _toThrow;

public override int GetHashCode() => throw new NotImplementedException();
}

[Fact]
public void Incremental_Generators_Exception_During_Execution_Doesnt_Produce_AnySource()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ public void Node_Builder_Can_Add_Entries_From_Previous_Table()
var previousTable = builder.ToImmutableAndFree();

builder = previousTable.ToBuilder(stepName: null, false);
builder.TryModifyEntries(ImmutableArray.Create(10, 11), EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified);
builder.TryModifyEntries(ImmutableArray.Create(10, 11), TimeSpan.Zero, default, EntryState.Modified);
builder.TryUseCachedEntries(TimeSpan.Zero, default, out var cachedEntries); // ((2, EntryState.Cached), (3, EntryState.Cached))
builder.TryModifyEntries(ImmutableArray.Create(20, 21, 22), EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified);
builder.TryModifyEntries(ImmutableArray.Create(20, 21, 22), TimeSpan.Zero, default, EntryState.Modified);
bool didRemoveEntries = builder.TryRemoveEntries(TimeSpan.Zero, default, out var removedEntries); //((6, EntryState.Removed))
var newTable = builder.ToImmutableAndFree();

Expand Down Expand Up @@ -185,9 +185,9 @@ public void Node_Builder_Handles_Modification_When_Both_Tables_Have_Empty_Entrie
AssertTableEntries(previousTable, expected);

builder = previousTable.ToBuilder(stepName: null, false);
Assert.True(builder.TryModifyEntries(ImmutableArray.Create(3, 2), EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntries(ImmutableArray<int>.Empty, EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntries(ImmutableArray.Create(3, 5), EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntries(ImmutableArray.Create(3, 2), TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntries(ImmutableArray<int>.Empty, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntries(ImmutableArray.Create(3, 5), TimeSpan.Zero, default, EntryState.Modified));

var newTable = builder.ToImmutableAndFree();

Expand All @@ -209,10 +209,10 @@ public void Node_Table_Doesnt_Modify_Single_Item_Multiple_Times_When_Same()
AssertTableEntries(previousTable, expected);

builder = previousTable.ToBuilder(stepName: null, false);
Assert.True(builder.TryModifyEntry(1, EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntry(2, EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntry(5, EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntry(4, EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntry(1, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntry(2, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntry(5, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntry(4, TimeSpan.Zero, default, EntryState.Modified));

var newTable = builder.ToImmutableAndFree();

Expand All @@ -232,10 +232,10 @@ public void Node_Table_Caches_Previous_Object_When_Modification_Considered_Cache
var expected = ImmutableArray.Create((1, EntryState.Added, 0), (2, EntryState.Added, 0), (3, EntryState.Added, 0));
AssertTableEntries(previousTable, expected);

builder = previousTable.ToBuilder(stepName: null, false);
Assert.True(builder.TryModifyEntry(1, EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified)); // ((1, EntryState.Cached))
Assert.True(builder.TryModifyEntry(4, EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified)); // ((4, EntryState.Modified))
Assert.True(builder.TryModifyEntry(5, new LambdaComparer<int>((i, j) => true), TimeSpan.Zero, default, EntryState.Modified)); // ((3, EntryState.Cached))
builder = previousTable.ToBuilder(stepName: null, false, new LambdaComparer<int>((i, j) => i == 3 || i == j));
Assert.True(builder.TryModifyEntry(1, TimeSpan.Zero, default, EntryState.Modified)); // ((1, EntryState.Cached))
Assert.True(builder.TryModifyEntry(4, TimeSpan.Zero, default, EntryState.Modified)); // ((4, EntryState.Modified))
Assert.True(builder.TryModifyEntry(5, TimeSpan.Zero, default, EntryState.Modified)); // ((3, EntryState.Cached))
var newTable = builder.ToImmutableAndFree();

expected = ImmutableArray.Create((1, EntryState.Cached, 0), (4, EntryState.Modified, 0), (3, EntryState.Cached, 0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ internal sealed class BatchNode<TInput> : IIncrementalGeneratorNode<ImmutableArr
private static readonly string? s_tableType = typeof(ImmutableArray<TInput>).FullName;

private readonly IIncrementalGeneratorNode<TInput> _sourceNode;
private readonly IEqualityComparer<ImmutableArray<TInput>> _comparer;
private readonly IEqualityComparer<ImmutableArray<TInput>>? _comparer;
private readonly string? _name;

public BatchNode(IIncrementalGeneratorNode<TInput> sourceNode, IEqualityComparer<ImmutableArray<TInput>>? comparer = null, string? name = null)
{
_sourceNode = sourceNode;
_comparer = comparer ?? EqualityComparer<ImmutableArray<TInput>>.Default;
_comparer = comparer;
_name = name;
}

Expand Down Expand Up @@ -136,7 +136,7 @@ public NodeStateTable<ImmutableArray<TInput>> UpdateStateTable(DriverStateTable.
}
else if (!sourceTable.IsCached || !tableBuilder.TryUseCachedEntries(stopwatch.Elapsed, sourceInputs))
{
if (!tableBuilder.TryModifyEntry(sourceValues, _comparer, stopwatch.Elapsed, sourceInputs, EntryState.Modified))
if (!tableBuilder.TryModifyEntry(sourceValues, stopwatch.Elapsed, sourceInputs, EntryState.Modified))
{
tableBuilder.AddEntry(sourceValues, EntryState.Added, stopwatch.Elapsed, sourceInputs, EntryState.Added);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public CombineNode(IIncrementalGeneratorNode<TInput1> input1, IIncrementalGenera
};

var entry = (entry1.Item, input2);
if (state != EntryState.Modified || _comparer is null || !tableBuilder.TryModifyEntry(entry, _comparer, stopwatch.Elapsed, stepInputs, state))
if (state != EntryState.Modified || _comparer is null || !tableBuilder.TryModifyEntry(entry, stopwatch.Elapsed, stepInputs, state))
{
tableBuilder.AddEntry(entry, state, stopwatch.Elapsed, stepInputs, state);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ internal sealed class InputNode<T> : IIncrementalGeneratorNode<T>
private readonly Func<DriverStateTable.Builder, ImmutableArray<T>> _getInput;
private readonly Action<IIncrementalGeneratorOutputNode> _registerOutput;
private readonly IEqualityComparer<T> _inputComparer;
private readonly IEqualityComparer<T> _comparer;
private readonly IEqualityComparer<T>? _comparer;
private readonly string? _name;

public InputNode(Func<DriverStateTable.Builder, ImmutableArray<T>> getInput, IEqualityComparer<T>? inputComparer = null)
Expand All @@ -35,7 +35,7 @@ public InputNode(Func<DriverStateTable.Builder, ImmutableArray<T>> getInput, IEq
private InputNode(Func<DriverStateTable.Builder, ImmutableArray<T>> getInput, Action<IIncrementalGeneratorOutputNode>? registerOutput, IEqualityComparer<T>? inputComparer = null, IEqualityComparer<T>? comparer = null, string? name = null)
{
_getInput = getInput;
_comparer = comparer ?? EqualityComparer<T>.Default;
_comparer = comparer;
_inputComparer = inputComparer ?? EqualityComparer<T>.Default;
_registerOutput = registerOutput ?? (o => throw ExceptionUtilities.Unreachable());
_name = name;
Expand Down Expand Up @@ -83,7 +83,7 @@ public NodeStateTable<T> UpdateStateTable(DriverStateTable.Builder graphState, N
// This allows us to correctly 'replace' items even when they aren't actually the same. In the case that the
// item really isn't modified, but a new item, we still function correctly as we mostly treat them the same,
// but will perform an extra comparison that is omitted in the pure 'added' case.
var modified = tableBuilder.TryModifyEntry(inputItems[itemIndex], _comparer, elapsedTime, noInputStepsStepInfo, EntryState.Modified);
var modified = tableBuilder.TryModifyEntry(inputItems[itemIndex], elapsedTime, noInputStepsStepInfo, EntryState.Modified);
Debug.Assert(modified);
itemsSet.Remove(inputItems[itemIndex]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public NodeStateTable<T> AsCached()
public Builder ToBuilder(string? stepName, bool stepTrackingEnabled, IEqualityComparer<T>? equalityComparer = null, int? tableCapacity = null)
=> new(this, stepName, stepTrackingEnabled, equalityComparer, tableCapacity);

public NodeStateTable<T> CreateCachedTableWithUpdatedSteps<TInput>(NodeStateTable<TInput> inputTable, string? stepName, IEqualityComparer<T> equalityComparer)
public NodeStateTable<T> CreateCachedTableWithUpdatedSteps<TInput>(NodeStateTable<TInput> inputTable, string? stepName, IEqualityComparer<T>? equalityComparer)
{
Debug.Assert(inputTable.HasTrackedSteps && inputTable.IsCached);
NodeStateTable<T>.Builder builder = ToBuilder(stepName, stepTrackingEnabled: true, equalityComparer);
Expand Down Expand Up @@ -256,7 +256,7 @@ internal Builder(
_states = ArrayBuilder<TableEntry>.GetInstance(tableCapacity ?? previous.GetTotalEntryItemCount());
_previous = previous;
_name = name;
_equalityComparer = equalityComparer ?? EqualityComparer<T>.Default;
_equalityComparer = equalityComparer ?? WrappedUserComparer<T>.Default;
if (stepTrackingEnabled)
{
_steps = ArrayBuilder<IncrementalGeneratorRunStep>.GetInstance();
Expand Down Expand Up @@ -320,7 +320,7 @@ internal bool TryUseCachedEntries(TimeSpan elapsedTime, ImmutableArray<(Incremen
return true;
}

public bool TryModifyEntry(T value, IEqualityComparer<T> comparer, TimeSpan elapsedTime, ImmutableArray<(IncrementalGeneratorRunStep InputStep, int OutputIndex)> stepInputs, EntryState overallInputState)
public bool TryModifyEntry(T value, TimeSpan elapsedTime, ImmutableArray<(IncrementalGeneratorRunStep InputStep, int OutputIndex)> stepInputs, EntryState overallInputState)
{
if (!TryGetPreviousEntry(out var previousEntry))
{
Expand All @@ -335,13 +335,13 @@ public bool TryModifyEntry(T value, IEqualityComparer<T> comparer, TimeSpan elap
}

Debug.Assert(previousEntry.Count == 1);
var (chosen, state, _) = GetModifiedItemAndState(previousEntry.GetItem(0), value, comparer);
var (chosen, state, _) = GetModifiedItemAndState(previousEntry.GetItem(0), value);
_states.Add(new TableEntry(OneOrMany.Create(chosen), state));
RecordStepInfoForLastEntry(elapsedTime, stepInputs, overallInputState);
return true;
}

public bool TryModifyEntries(ImmutableArray<T> outputs, IEqualityComparer<T> comparer, TimeSpan elapsedTime, ImmutableArray<(IncrementalGeneratorRunStep InputStep, int OutputIndex)> stepInputs, EntryState overallInputState)
public bool TryModifyEntries(ImmutableArray<T> outputs, TimeSpan elapsedTime, ImmutableArray<(IncrementalGeneratorRunStep InputStep, int OutputIndex)> stepInputs, EntryState overallInputState)
{
// Semantics:
// For each item in the row, we compare with the new matching new value.
Expand Down Expand Up @@ -384,7 +384,7 @@ public bool TryModifyEntries(ImmutableArray<T> outputs, IEqualityComparer<T> com
var previousState = previousEntry.GetState(i);
var replacementItem = outputs[i];

var (chosenItem, state, chosePrevious) = GetModifiedItemAndState(previousItem, replacementItem, comparer);
var (chosenItem, state, chosePrevious) = GetModifiedItemAndState(previousItem, replacementItem);

if (builder != null)
{
Expand Down Expand Up @@ -433,9 +433,9 @@ public bool TryModifyEntries(ImmutableArray<T> outputs, IEqualityComparer<T> com
return true;
}

public bool TryModifyEntries(ImmutableArray<T> outputs, IEqualityComparer<T> comparer, TimeSpan elapsedTime, ImmutableArray<(IncrementalGeneratorRunStep InputStep, int OutputIndex)> stepInputs, EntryState overallInputState, out TableEntry entry)
public bool TryModifyEntries(ImmutableArray<T> outputs, TimeSpan elapsedTime, ImmutableArray<(IncrementalGeneratorRunStep InputStep, int OutputIndex)> stepInputs, EntryState overallInputState, out TableEntry entry)
{
if (!TryModifyEntries(outputs, comparer, elapsedTime, stepInputs, overallInputState))
if (!TryModifyEntries(outputs, elapsedTime, stepInputs, overallInputState))
{
entry = default;
return false;
Expand Down Expand Up @@ -554,11 +554,11 @@ public NodeStateTable<T> ToImmutableAndFree()
isCached: finalStates.All(static s => s.IsCached) && _previous.GetTotalEntryItemCount() == finalStates.Sum(static s => s.Count));
}

private static (T chosen, EntryState state, bool chosePrevious) GetModifiedItemAndState(T previous, T replacement, IEqualityComparer<T> comparer)
private (T chosen, EntryState state, bool chosePrevious) GetModifiedItemAndState(T previous, T replacement)
{
// when comparing an item to check if its modified we explicitly cache the *previous* item in the case where its
// considered to be equal. This ensures that subsequent comparisons are stable across future generation passes.
return comparer.Equals(previous, replacement)
return _equalityComparer.Equals(previous, replacement)
? (previous, EntryState.Cached, chosePrevious: true)
: (replacement, EntryState.Modified, chosePrevious: false);
}
Expand Down
Loading

0 comments on commit 3a8c9a8

Please sign in to comment.