Skip to content

Commit

Permalink
Merge pull request #679 from Cysharp/hotfix/FixNestedClassEnumFormatter
Browse files Browse the repository at this point in the history
Fix code generation for formatter of Enum nested in a class
  • Loading branch information
mayuki authored Sep 21, 2023
2 parents 7f7c2b1 + bf754e4 commit e9cf6b0
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ namespace MagicOnion.Generator.CodeAnalysis;
public interface ISerializationFormatterNameMapper
{
IWellKnownSerializationTypes WellKnownTypes { get; }
bool TryMapGeneric(MagicOnionTypeInfo type, out string formatterName);
string MapArray(MagicOnionTypeInfo type);
bool TryMapGeneric(MagicOnionTypeInfo type, out string formatterName, out string formatterConstructorArgs);
(string FormatterName, string FormatterConstructorArgs) MapArray(MagicOnionTypeInfo type);
}


Expand All @@ -30,38 +30,43 @@ public MessagePackFormatterNameMapper(string userDefinedFormatterNamespace)
this.userDefinedFormatterNamespace = userDefinedFormatterNamespace;
}

public bool TryMapGeneric(MagicOnionTypeInfo type, out string formatterName)
public bool TryMapGeneric(MagicOnionTypeInfo type, out string formatterName, out string formatterConstructorArgs)
{
formatterName = null;
formatterConstructorArgs = null;

var genericTypeArgs = string.Join(", ", type.GenericArguments.Select(x => x.FullName));
if (type is { Namespace: "MagicOnion", Name: "DynamicArgumentTuple" })
{
// MagicOnion.DynamicArgumentTuple
var ctorArguments = string.Join(", ", type.GenericArguments.Select(x => $"default({x.FullName})"));
formatterName = $"global::MagicOnion.DynamicArgumentTupleFormatter<{genericTypeArgs}>({ctorArguments})";
formatterName = $"global::MagicOnion.DynamicArgumentTupleFormatter<{genericTypeArgs}>";
formatterConstructorArgs = $"({ctorArguments})";
}
else if (MessagePackWellKnownSerializationTypes.Instance.GenericFormattersMap.TryGetValue(type.FullNameOpenType, out var mappedFormatterName))
{
// Well-known generic types (Nullable<T>, IList<T>, List<T>, Dictionary<TKey, TValue> ...)
formatterName = $"{mappedFormatterName}<{genericTypeArgs}>()";
formatterName = $"{mappedFormatterName}<{genericTypeArgs}>";
formatterConstructorArgs = "()";
}
else
{
// User-defined generic types
formatterName = $"{userDefinedFormatterNamespace}{(string.IsNullOrWhiteSpace(userDefinedFormatterNamespace) ? "" : ".")}{type.ToDisplayName(MagicOnionTypeInfo.DisplayNameFormat.Namespace | MagicOnionTypeInfo.DisplayNameFormat.WithoutGenericArguments)}Formatter<{genericTypeArgs}>()";
formatterName = $"{userDefinedFormatterNamespace}{(string.IsNullOrWhiteSpace(userDefinedFormatterNamespace) ? "" : ".")}{type.ToDisplayName(MagicOnionTypeInfo.DisplayNameFormat.Namespace | MagicOnionTypeInfo.DisplayNameFormat.WithoutGenericArguments)}Formatter<{genericTypeArgs}>";
formatterConstructorArgs = "()";
}

return formatterName != null;
}

public string MapArray(MagicOnionTypeInfo type)
public (string FormatterName, string FormatterConstructorArgs) MapArray(MagicOnionTypeInfo type)
{
return type.ArrayRank switch
{
1 => $"global::MessagePack.Formatters.ArrayFormatter<{type.ElementType.FullName}>()",
2 => $"global::MessagePack.Formatters.TwoDimensionalArrayFormatter<{type.ElementType.FullName}>()",
3 => $"global::MessagePack.Formatters.ThreeDimensionalArrayFormatter<{type.ElementType.FullName}>()",
4 => $"global::MessagePack.Formatters.FourDimensionalArrayFormatter<{type.ElementType.FullName}>()",
1 => ($"global::MessagePack.Formatters.ArrayFormatter<{type.ElementType.FullName}>", "()"),
2 => ($"global::MessagePack.Formatters.TwoDimensionalArrayFormatter<{type.ElementType.FullName}>", "()"),
3 => ($"global::MessagePack.Formatters.ThreeDimensionalArrayFormatter<{type.ElementType.FullName}>", "()"),
4 => ($"global::MessagePack.Formatters.FourDimensionalArrayFormatter<{type.ElementType.FullName}>", "()"),
_ => throw new IndexOutOfRangeException($"An array of rank must be less than 5. ({type.FullName})"),
};
}
Expand Down Expand Up @@ -181,33 +186,37 @@ public MemoryPackFormatterNameMapper()
{
}

public bool TryMapGeneric(MagicOnionTypeInfo type, out string formatterName)
public bool TryMapGeneric(MagicOnionTypeInfo type, out string formatterName, out string formatterConstructorArgs)
{
formatterName = null;
formatterConstructorArgs = null;

var genericTypeArgs = string.Join(", ", type.GenericArguments.Select(x => x.FullName));
if (type is { Namespace: "MagicOnion", Name: "DynamicArgumentTuple" })
{
// MagicOnion.DynamicArgumentTuple
var ctorArguments = string.Join(", ", type.GenericArguments.Select(x => $"default({x.FullName})"));
formatterName = $"global::MagicOnion.Serialization.MemoryPack.DynamicArgumentTupleFormatter<{genericTypeArgs}>()";
formatterName = $"global::MagicOnion.Serialization.MemoryPack.DynamicArgumentTupleFormatter<{genericTypeArgs}>";
formatterConstructorArgs = "()";
}
else if (MessagePackWellKnownSerializationTypes.Instance.GenericFormattersMap.TryGetValue(type.FullNameOpenType, out var mappedFormatterName))
{
// Well-known generic types (Nullable<T>, IList<T>, List<T>, Dictionary<TKey, TValue> ...)
formatterName = $"{mappedFormatterName}<{genericTypeArgs}>()";
formatterName = $"{mappedFormatterName}<{genericTypeArgs}>";
formatterConstructorArgs = "()";
}

return formatterName != null;
}

public string MapArray(MagicOnionTypeInfo type)
public (string FormatterName, string FormatterConstructorArgs) MapArray(MagicOnionTypeInfo type)
{
return type.ArrayRank switch
{
1 => $"global::MemoryPack.Formatters.ArrayFormatter<{type.ElementType.FullName}>()",
2 => $"global::MemoryPack.Formatters.TwoDimensionalArrayFormatter<{type.ElementType.FullName}>()",
3 => $"global::MemoryPack.Formatters.ThreeDimensionalArrayFormatter<{type.ElementType.FullName}>()",
4 => $"global::MemoryPack.Formatters.FourDimensionalArrayFormatter<{type.ElementType.FullName}>()",
1 => ($"global::MemoryPack.Formatters.ArrayFormatter<{type.ElementType.FullName}>", "()"),
2 => ($"global::MemoryPack.Formatters.TwoDimensionalArrayFormatter<{type.ElementType.FullName}>", "()"),
3 => ($"global::MemoryPack.Formatters.ThreeDimensionalArrayFormatter<{type.ElementType.FullName}>", "()"),
4 => ($"global::MemoryPack.Formatters.FourDimensionalArrayFormatter<{type.ElementType.FullName}>", "()"),
_ => throw new IndexOutOfRangeException($"An array of rank must be less than 5. ({type.FullName})"),
};
}
Expand Down
13 changes: 11 additions & 2 deletions src/MagicOnion.GeneratorCore/CodeAnalysis/SerializationInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ public interface ISerializationFormatterRegisterInfo
{
string FullName { get; }
string FormatterName { get; }
string FormatterConstructorArgs { get; }
string FormatterNameWithConstructorArgs { get; } // e.g. MyEnumFormatter(), DynamicArgumentTupleFormatter<T1, T2>(default, default) ...

IReadOnlyList<string> IfDirectiveConditions { get; }
bool HasIfDirectiveConditions { get; }
Expand All @@ -15,14 +17,17 @@ public class GenericSerializationInfo : ISerializationFormatterRegisterInfo
public string FullName { get; }

public string FormatterName { get; }
public string FormatterConstructorArgs { get; }
public string FormatterNameWithConstructorArgs => FormatterName + FormatterConstructorArgs;

public IReadOnlyList<string> IfDirectiveConditions { get; }
public bool HasIfDirectiveConditions => IfDirectiveConditions.Any();

public GenericSerializationInfo(string fullName, string formatterName, IReadOnlyList<string> ifDirectiveConditions)
public GenericSerializationInfo(string fullName, string formatterName, string formatterConstructorArgs, IReadOnlyList<string> ifDirectiveConditions)
{
FullName = fullName;
FormatterName = formatterName;
FormatterConstructorArgs = formatterConstructorArgs;
IfDirectiveConditions = ifDirectiveConditions;
}
}
Expand All @@ -34,7 +39,9 @@ public class EnumSerializationInfo : ISerializationFormatterRegisterInfo
public string FullName { get; }
public string UnderlyingType { get; }

public string FormatterName => Name + "Formatter()";
public string FormatterName => $"{Name.Replace(".", "_")}Formatter";
public string FormatterConstructorArgs => "()";
public string FormatterNameWithConstructorArgs => FormatterName + FormatterConstructorArgs;

public IReadOnlyList<string> IfDirectiveConditions { get; }
public bool HasIfDirectiveConditions => IfDirectiveConditions.Any();
Expand All @@ -54,6 +61,8 @@ public class SerializationTypeHintInfo : ISerializationFormatterRegisterInfo
public string FullName { get; }

string ISerializationFormatterRegisterInfo.FormatterName => string.Empty; // Dummy
string ISerializationFormatterRegisterInfo.FormatterConstructorArgs => string.Empty; // Dummy
string ISerializationFormatterRegisterInfo.FormatterNameWithConstructorArgs => string.Empty; // Dummy

public IReadOnlyList<string> IfDirectiveConditions { get; }
public bool HasIfDirectiveConditions => IfDirectiveConditions.Any();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ public MagicOnionSerializationInfoCollection Collect(IEnumerable<TypeWithIfDirec
}

logger.Trace($"[{nameof(SerializationInfoCollector)}] Array type '{type.FullName}'");
context.Generics.Add(new GenericSerializationInfo(type.FullName, mapper.MapArray(type), typeWithDirectives.IfDirectives));

var (formatterName, formatterConstructorArgs) = mapper.MapArray(type);
context.Generics.Add(new GenericSerializationInfo(type.FullName, formatterName, formatterConstructorArgs, typeWithDirectives.IfDirectives));
mapper.MapArray(type);
}
else if (type.HasGenericArguments)
Expand All @@ -124,10 +126,10 @@ public MagicOnionSerializationInfoCollection Collect(IEnumerable<TypeWithIfDirec
continue;
}

if (mapper.TryMapGeneric(type, out var formatterName))
if (mapper.TryMapGeneric(type, out var formatterName, out var formatterConstructorArgs))
{
logger.Trace($"[{nameof(SerializationInfoCollector)}] Generic type '{type.FullName}' (IfDirectives={string.Join(", ", typeWithDirectives.IfDirectives)})");
context.Generics.Add(new GenericSerializationInfo(type.FullName, formatterName, typeWithDirectives.IfDirectives));
context.Generics.Add(new GenericSerializationInfo(type.FullName, formatterName, formatterConstructorArgs, typeWithDirectives.IfDirectives));
}
}
}
Expand All @@ -144,6 +146,7 @@ static GenericSerializationInfo[] MergeResolverRegisterInfo(IEnumerable<GenericS
new GenericSerializationInfo(
serializationInfo.FullName,
serializationInfo.FormatterName,
serializationInfo.FormatterConstructorArgs,
serializationInfo.IfDirectiveConditions.Concat(serializationInfoCandidate.IfDirectiveConditions).ToArray()
)
);
Expand Down
8 changes: 4 additions & 4 deletions src/MagicOnion.GeneratorCore/CodeGen/EnumTemplate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ public virtual string TransformText()
this.Write("\r\n");
}
this.Write(" public sealed class ");
this.Write(this.ToStringHelper.ToStringWithCulture(info.Name));
this.Write("Formatter : global::MessagePack.Formatters.IMessagePackFormatter<");
this.Write(this.ToStringHelper.ToStringWithCulture(info.FormatterName));
this.Write(" : global::MessagePack.Formatters.IMessagePackFormatter<");
this.Write(this.ToStringHelper.ToStringWithCulture(info.FullName));
this.Write(">\r\n {\r\n public void Serialize(ref MessagePackWriter writer, ");
this.Write(this.ToStringHelper.ToStringWithCulture(info.FullName));
Expand All @@ -58,7 +58,7 @@ public virtual string TransformText()
}
}
this.Write("\r\n}\r\n\r\n#pragma warning restore 168\r\n#pragma warning restore 219\r\n#pragma warning " +
"restore 414\r\n#pragma warning restore 612\r\n#pragma warning restore 618");
"restore 414\r\n#pragma warning restore 612\r\n#pragma warning restore 618\r\n");
return this.GenerationEnvironment.ToString();
}
}
Expand All @@ -81,7 +81,7 @@ public class EnumTemplateBase
/// <summary>
/// The string builder that generation-time code is using to assemble generated output
/// </summary>
protected System.Text.StringBuilder GenerationEnvironment
public System.Text.StringBuilder GenerationEnvironment
{
get
{
Expand Down
4 changes: 2 additions & 2 deletions src/MagicOnion.GeneratorCore/CodeGen/EnumTemplate.tt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace <#= Namespace #>
<# if (info.HasIfDirectiveConditions) { #>
#if <#= string.Join(" || ", info.IfDirectiveConditions.Select(y => $"({y})")) #>
<# } #>
public sealed class <#= info.Name #>Formatter : global::MessagePack.Formatters.IMessagePackFormatter<<#= info.FullName #>>
public sealed class <#= info.FormatterName #> : global::MessagePack.Formatters.IMessagePackFormatter<<#= info.FullName #>>
{
public void Serialize(ref MessagePackWriter writer, <#= info.FullName #> value, MessagePackSerializerOptions options)
{
Expand All @@ -42,4 +42,4 @@ namespace <#= Namespace #>
#pragma warning restore 219
#pragma warning restore 414
#pragma warning restore 612
#pragma warning restore 618
#pragma warning restore 618
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public static void RegisterFormatters()
using (ctx.TextWriter.IfDirective(string.Join(" || ", resolverInfo.IfDirectiveConditions.Select(y => $"({y})"))))
{
ctx.TextWriter.WriteLines($$"""
global::MemoryPack.MemoryPackFormatterProvider.Register(new {{(resolverInfo.FormatterName.StartsWith("global::") ? resolverInfo.FormatterName : (string.IsNullOrWhiteSpace(ctx.FormatterNamespace) ? "" : ctx.FormatterNamespace + ".") + resolverInfo.FormatterName)}});
global::MemoryPack.MemoryPackFormatterProvider.Register(new {{(resolverInfo.FormatterName.StartsWith("global::") || string.IsNullOrWhiteSpace(ctx.FormatterNamespace) ? "" : ctx.FormatterNamespace + ".") + resolverInfo.FormatterName}}{{resolverInfo.FormatterConstructorArgs}});
""");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ internal static object GetFormatter(Type t)
using (ctx.TextWriter.IfDirective(string.Join(" || ", resolverInfo.IfDirectiveConditions.Select(y => $"({y})"))))
{
ctx.TextWriter.WriteLines($$"""
case {{index}}: return new {{(resolverInfo.FormatterName.StartsWith("global::") ? resolverInfo.FormatterName : (string.IsNullOrWhiteSpace(ctx.FormatterNamespace) ? "" : ctx.FormatterNamespace + ".") + resolverInfo.FormatterName)}};
case {{index}}: return new {{(resolverInfo.FormatterName.StartsWith("global::") || string.IsNullOrWhiteSpace(ctx.FormatterNamespace) ? "" : ctx.FormatterNamespace + ".") + resolverInfo.FormatterName}}{{resolverInfo.FormatterConstructorArgs}};
""");
}
}
Expand Down
21 changes: 11 additions & 10 deletions src/MagicOnion.GeneratorCore/MagicOnionCompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,25 +59,26 @@ public async Task GenerateFileAsync(
logger.Trace($"[{nameof(MagicOnionCompiler)}] RuntimeInformation.FrameworkDescription: {RuntimeInformation.FrameworkDescription}");

// Configure serialization
var serialization = serializerType switch
(ISerializationFormatterNameMapper Mapper, string Namespace, string InitializerName, ISerializerFormatterGenerator Generator, Func<IEnumerable<EnumSerializationInfo>, string> EnumFormatterGenerator)
serialization = serializerType switch
{
SerializerType.MemoryPack => (
Mapper: (ISerializationFormatterNameMapper)new MemoryPackFormatterNameMapper(),
Mapper: new MemoryPackFormatterNameMapper(),
Namespace: @namespace,
InitializerName: "MagicOnionMemoryPackFormatterProvider",
Generator: (ISerializerFormatterGenerator)new MemoryPackFormatterRegistrationGenerator(),
EnumFormatterGenerator: (Func<IEnumerable<EnumSerializationInfo>, string>)(enumSerializationInfo => string.Empty)
Generator: new MemoryPackFormatterRegistrationGenerator(),
EnumFormatterGenerator: _ => string.Empty
),
SerializerType.MessagePack => (
Mapper: (ISerializationFormatterNameMapper)new MessagePackFormatterNameMapper(userDefinedFormattersNamespace),
Mapper: new MessagePackFormatterNameMapper(userDefinedFormattersNamespace),
Namespace: namespaceDot + "Resolvers",
InitializerName: "MagicOnionResolver",
Generator: (ISerializerFormatterGenerator)new MessagePackFormatterResolverGenerator(),
EnumFormatterGenerator: (Func<IEnumerable<EnumSerializationInfo>, string>)(enumSerializationInfo => new EnumTemplate()
Generator: new MessagePackFormatterResolverGenerator(),
EnumFormatterGenerator: x => new EnumTemplate()
{
Namespace = namespaceDot + "Formatters",
EnumSerializationInfos = enumSerializationInfo.ToArray()
}.TransformText())
EnumSerializationInfos = x.ToArray()
}.TransformText()
),
_ => throw new NotImplementedException(),
};
Expand Down Expand Up @@ -144,7 +145,7 @@ public async Task GenerateFileAsync(

foreach (var enumSerializationInfo in serializationInfoCollection.Enums)
{
Output(NormalizePath(output, namespaceDot + "Formatters", enumSerializationInfo.Name + "Formatter"), WithAutoGenerated(serialization.EnumFormatterGenerator(new []{ enumSerializationInfo })));
Output(NormalizePath(output, namespaceDot + "Formatters", enumSerializationInfo.FormatterName), WithAutoGenerated(serialization.EnumFormatterGenerator(new []{ enumSerializationInfo })));
}

foreach (var service in serviceCollection.Services)
Expand Down
Loading

0 comments on commit e9cf6b0

Please sign in to comment.