Skip to content

Commit

Permalink
Improve support for validating indirectly used types
Browse files Browse the repository at this point in the history
Detect lacking partial type in indirect types (i.e. parameters to actor messages). We cannot provide code fixers for these because they have to be reported for an entire compilation (since we need to scan types and usages), which makes the experience suboptimal. But at least we don't fail at run-time.
  • Loading branch information
kzu committed Jul 29, 2024
1 parent fd6432f commit bbfff46
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 35 deletions.
49 changes: 39 additions & 10 deletions src/CloudActors.CodeAnaysis/ActorMessageGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
using static Devlooped.CloudActors.Diagnostics;
Expand All @@ -15,13 +16,42 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.Where(IsActorMessage)
.Where(t => t.IsPartial());

context.RegisterImplementationSourceOutput(messages.Combine(options), (ctx, source) =>
var additionalTypes = messages.SelectMany((x, _) =>
x.GetMembers().OfType<IPropertySymbol>()
// Generated serializers only expose public members.
.Where(p => p.DeclaredAccessibility == Accessibility.Public)
.Select(p => p.Type)
.OfType<INamedTypeSymbol>()
.Where(t => t.IsPartial())
.Concat(x.GetMembers()
.OfType<IMethodSymbol>()
// Generated serializers only expose public members.
.Where(m => m.DeclaredAccessibility == Accessibility.Public)
.SelectMany(m => m.Parameters)
.Select(p => p.Type)
.OfType<INamedTypeSymbol>()))
// We already generate separately for actor messages.
.Where(t => !IsActorMessage(t) && t.IsPartial())
.Collect();

context.RegisterImplementationSourceOutput(messages.Combine(options), GenerateCode);
context.RegisterImplementationSourceOutput(additionalTypes.Combine(options), (ctx, source) =>
{
var (message, options) = source;
var ns = message.ContainingNamespace.ToDisplayString();
var kind = message.IsRecord ? "record" : "class";
var output =
$$"""
var (messages, options) = source;
var distinct = new HashSet<INamedTypeSymbol>(messages, SymbolEqualityComparer.Default);
foreach (var message in distinct)
GenerateCode(ctx, (message, options));
});
}

static void GenerateCode(SourceProductionContext ctx, (INamedTypeSymbol, OrleansGeneratorOptions) source)
{
var (message, options) = source;

var ns = message.ContainingNamespace.ToDisplayString();
var kind = message.IsRecord ? "record" : "class";
var output =
$$"""
// <auto-generated/>
using System.CodeDom.Compiler;
Expand All @@ -35,10 +65,9 @@ namespace {{ns}}
}
""";

var orleans = OrleansGenerator.GenerateCode(options, output, message.Name, ctx.CancellationToken);
var orleans = OrleansGenerator.GenerateCode(options, output, message.Name, ctx.CancellationToken);

ctx.AddSource($"{message.ToFileName()}.Serializable.cs", output);
ctx.AddSource($"{message.ToFileName()}.Serializable.orleans.cs", orleans);
});
ctx.AddSource($"{message.ToFileName()}.Serializable.cs", output);
ctx.AddSource($"{message.ToFileName()}.Serializable.orleans.cs", orleans);
}
}
13 changes: 12 additions & 1 deletion src/CloudActors.CodeAnaysis/Diagnostics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public static class Diagnostics
public static DiagnosticDescriptor MustBePartial { get; } = new(
"DCA001",
"Actors and messages must be partial",
"Cloud Actors require partial actor and message types.",
"Add the partial keyword to '{0}' as required for types used by actors.",
"Build",
DiagnosticSeverity.Error,
isEnabledByDefault: true);
Expand All @@ -40,6 +40,17 @@ public static class Diagnostics
DiagnosticSeverity.Error,
isEnabledByDefault: true);

/// <summary>
/// DCA004: Indirectly used types must be serializable
/// </summary>
public static DiagnosticDescriptor MustBeSerializable { get; } = new(
"DCA004",
"Types used in actor messages must have a [GenerateSerializer] attribute,",
"Annotate '{0}' with [GenerateSerializer] as it is used by at least one actor message.",
"Build",
DiagnosticSeverity.Error,
isEnabledByDefault: true);

public static SymbolDisplayFormat FullName { get; } = new(
typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces,
genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters,
Expand Down
2 changes: 0 additions & 2 deletions src/CloudActors.CodeAnaysis/OrleansGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading;
using System.Xml.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
Expand Down
55 changes: 49 additions & 6 deletions src/CloudActors.CodeAnaysis/PartialAnalyzer.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Immutable;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
Expand All @@ -17,18 +18,16 @@ public override void Initialize(AnalysisContext context)
{
context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
context.EnableConcurrentExecution();
context.RegisterCompilationAction(Analyze);
context.RegisterSyntaxNodeAction(Analyze, SyntaxKind.ClassDeclaration, SyntaxKind.RecordDeclaration);
}

void Analyze(SyntaxNodeAnalysisContext context)
static void Analyze(SyntaxNodeAnalysisContext context)
{
if (context.Node is not TypeDeclarationSyntax typeDeclaration ||
context.Compilation.GetTypeByMetadataName("Devlooped.CloudActors.IActorMessage") is not { } messageType)
return;

if (typeDeclaration.Modifiers.Any(SyntaxKind.PartialKeyword))
return;

var symbol = context.SemanticModel.GetDeclaredSymbol(typeDeclaration);
if (symbol is null)
return;
Expand All @@ -39,6 +38,50 @@ void Analyze(SyntaxNodeAnalysisContext context)
!symbol.AllInterfaces.Contains(messageType, SymbolEqualityComparer.Default))
return;

context.ReportDiagnostic(Diagnostic.Create(MustBePartial, typeDeclaration.Identifier.GetLocation()));
if (!typeDeclaration.Modifiers.Any(SyntaxKind.PartialKeyword))
context.ReportDiagnostic(Diagnostic.Create(MustBePartial, typeDeclaration.Identifier.GetLocation(), symbol.Name));
}

static void Analyze(CompilationAnalysisContext context)
{
if (context.Compilation.GetTypeByMetadataName("Devlooped.CloudActors.IActorMessage") is not { } messageType)
return;

var messageTypes = context.Compilation
.Assembly.GetAllTypes()
.OfType<INamedTypeSymbol>()
.Where(x => x.AllInterfaces.Contains(messageType, SymbolEqualityComparer.Default));

// Report also for all source-declared custom types used in the message, as constructors or properties
var indirect = new HashSet<INamedTypeSymbol>(messageTypes
.SelectMany(x => x.GetMembers())
.OfType<IPropertySymbol>()
// Generated serializers only expose public members.
.Where(p => p.DeclaredAccessibility == Accessibility.Public)
.Select(p => p.Type)
.Concat(messageTypes.SelectMany(x => x.GetMembers()
.OfType<IMethodSymbol>()
// Generated serializers only expose public members.
.Where(m => m.DeclaredAccessibility == Accessibility.Public)
.SelectMany(m => m.Parameters)
.Select(p => p.Type)))
.OfType<INamedTypeSymbol>()
// where the type is not partial
//.Where(t => !t.GetAttributes().Any(a => generateAttr.Equals(a.AttributeClass, SymbolEqualityComparer.Default))),
.Where(t =>
!t.GetAttributes().Any(IsActor) &&
!t.AllInterfaces.Contains(messageType, SymbolEqualityComparer.Default) &&
!t.IsPartial() &&
t.Locations.Any(l => l.IsInSource)),
SymbolEqualityComparer.Default);

foreach (var type in indirect)
{
// select the type declarations
if (type.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax() is not TypeDeclarationSyntax declaration)
continue;

context.ReportDiagnostic(Diagnostic.Create(MustBePartial, declaration!.Identifier.GetLocation(), type.Name));
}
}
}
67 changes: 52 additions & 15 deletions src/Tests/CodeFixers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,34 +70,41 @@ public partial class MyActor { }
}

[Fact]
public async Task NoGenerateSerializer()
public async Task AddPartialMessage()
{
var context = new CSharpAnalyzerTest<ActorMessageAnalyzer, DefaultVerifier>();
var context = new CSharpCodeFixTest<PartialAnalyzer, TypeMustBePartial, DefaultVerifier>();
context.ReferenceAssemblies = ReferenceAssemblies.Net.Net80
.AddPackages([
new PackageIdentity("Devlooped.CloudActors", "0.4.0"),
new PackageIdentity("Microsoft.Orleans.Serialization.Abstractions", "8.2.0"),
]);
.AddPackages([new PackageIdentity("Devlooped.CloudActors", "0.4.0")]);

context.TestCode =
/* lang=c#-test */
"""
using Devlooped.CloudActors;
using Orleans;
namespace Tests;
[GenerateSerializer]
public record {|DCA002:GetBalance|}() : IActorQuery<decimal>;
public record {|DCA001:GetBalance|}() : IActorQuery<decimal>;
""";

context.FixedCode =
/* lang=c#-test */
"""
using Devlooped.CloudActors;
namespace Tests;
public partial record GetBalance() : IActorQuery<decimal>;
""";

await context.RunAsync();
}


[Fact]
public async Task AddPartialMessage()
public async Task ReportPartialIndirectMessage()
{
var context = new CSharpCodeFixTest<PartialAnalyzer, TypeMustBePartial, DefaultVerifier>();
// Can't verify the codefix due to being reported for another node.
var context = new CSharpAnalyzerTest<PartialAnalyzer, DefaultVerifier>();
context.ReferenceAssemblies = ReferenceAssemblies.Net.Net80
.AddPackages([new PackageIdentity("Devlooped.CloudActors", "0.4.0")]);

Expand All @@ -108,19 +115,49 @@ public async Task AddPartialMessage()
namespace Tests;
public record {|DCA001:GetBalance|}() : IActorQuery<decimal>;
public record {|DCA001:Address|}(string Street, string City, string State, string Zip);
public partial record SetAddress(Address Address) : IActorCommand;
""";

context.FixedCode =
//context.FixedCode =
// /* lang=c#-test */
// """
// using Devlooped.CloudActors;

// namespace Tests;

// public partial record Address(string Street, string City, string State, string Zip);

// public partial record SetAddress(Address Address) : IActorCommand;
// """;

await context.RunAsync();
}

[Fact]
public async Task NoGenerateSerializer()
{
var context = new CSharpAnalyzerTest<ActorMessageAnalyzer, DefaultVerifier>();
context.ReferenceAssemblies = ReferenceAssemblies.Net.Net80
.AddPackages([
new PackageIdentity("Devlooped.CloudActors", "0.4.0"),
new PackageIdentity("Microsoft.Orleans.Serialization.Abstractions", "8.2.0"),
]);

context.TestCode =
/* lang=c#-test */
"""
using Devlooped.CloudActors;
using Orleans;
namespace Tests;
public partial record GetBalance() : IActorQuery<decimal>;
[GenerateSerializer]
public record {|DCA002:GetBalance|}() : IActorQuery<decimal>;
""";

await context.RunAsync();
}

}
2 changes: 1 addition & 1 deletion src/Tests/Customer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ await CloudStorageAccount.DevelopmentStorageAccount
}
}

public record Address(string Street, string City, string State, string Zip);
public partial record Address(string Street, string City, string State, string Zip);

public partial record SetAddress(Address Address) : IActorCommand;

Expand Down

0 comments on commit bbfff46

Please sign in to comment.