diff --git a/SourceKit.sln b/SourceKit.sln index 0b96d59..001e3b8 100644 --- a/SourceKit.sln +++ b/SourceKit.sln @@ -93,6 +93,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SourceKit.Analyzers.MemberA EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SourceKit.Analyzers.MemberAccessibility.Tests", "tests\SourceKit.Analyzers.MemberAccessibility.Tests\SourceKit.Analyzers.MemberAccessibility.Tests.csproj", "{EACEDBCD-081E-4F95-A81E-B62CD18D3028}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SourceKit.Generators.Grpc.Samples.Transitive", "samples\generators\SourceKit.Generators.Grpc.Samples.Transitive\SourceKit.Generators.Grpc.Samples.Transitive.csproj", "{B50328ED-E004-47F6-88F6-AC2E0BEABC17}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -137,6 +139,7 @@ Global {B3EEEF25-A90E-404F-9874-93A2E6990468} = {68973D47-37E1-492E-8F62-E94B002349BB} {6D49F53F-4765-4108-9AA7-1470D5FB8FD7} = {365210F4-5F33-49FD-9E14-552154E26285} {EACEDBCD-081E-4F95-A81E-B62CD18D3028} = {CB9AFB88-6DC1-436D-8F6F-398E065A07DE} + {B50328ED-E004-47F6-88F6-AC2E0BEABC17} = {84DBA1F6-2A81-452F-97EE-AEB57C0E7BC6} EndGlobalSection GlobalSection(ProjectConfigurationPlatforms) = postSolution {637C01C1-3A3C-4FC6-9874-6CFBA4319A79}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -263,5 +266,9 @@ Global {EACEDBCD-081E-4F95-A81E-B62CD18D3028}.Debug|Any CPU.Build.0 = Debug|Any CPU {EACEDBCD-081E-4F95-A81E-B62CD18D3028}.Release|Any CPU.ActiveCfg = Release|Any CPU {EACEDBCD-081E-4F95-A81E-B62CD18D3028}.Release|Any CPU.Build.0 = Release|Any CPU + {B50328ED-E004-47F6-88F6-AC2E0BEABC17}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B50328ED-E004-47F6-88F6-AC2E0BEABC17}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B50328ED-E004-47F6-88F6-AC2E0BEABC17}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B50328ED-E004-47F6-88F6-AC2E0BEABC17}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection EndGlobal diff --git a/samples/generators/SourceKit.Generators.Grpc.Samples.Transitive/Program.cs b/samples/generators/SourceKit.Generators.Grpc.Samples.Transitive/Program.cs new file mode 100644 index 0000000..cf0692d --- /dev/null +++ b/samples/generators/SourceKit.Generators.Grpc.Samples.Transitive/Program.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; + +Console.WriteLine("Hello, World!"); + +var model = new ProtoProtoModel( + pageToken: "", + values: new[] { "" }, + pageSize: 0, + intValues: new[] { 0 }, + intOneofValue: null, + stringOneofValue: null, + notNullStringValue: "", + nullIntValue: null, + nullStringValue: null, + mapValue: new Dictionary { [1] = "1" }, + m: new ProtoProtoModel.Types.InnerMessage(@enum: ProtoProtoModel.Types.InnerEnum.Aboba1)); + +var emptyModel = new ProtoEmptyMessage(); + +Console.WriteLine(model); diff --git a/samples/generators/SourceKit.Generators.Grpc.Samples.Transitive/SourceKit.Generators.Grpc.Samples.Transitive.csproj b/samples/generators/SourceKit.Generators.Grpc.Samples.Transitive/SourceKit.Generators.Grpc.Samples.Transitive.csproj new file mode 100644 index 0000000..6794d5c --- /dev/null +++ b/samples/generators/SourceKit.Generators.Grpc.Samples.Transitive/SourceKit.Generators.Grpc.Samples.Transitive.csproj @@ -0,0 +1,14 @@ + + + + Exe + net10.0 + enable + enable + + + + + + + diff --git a/src/generators/SourceKit.Generators.Grpc/Generators/ProtoMessageAliasGenerator.cs b/src/generators/SourceKit.Generators.Grpc/Generators/ProtoMessageAliasGenerator.cs index b5cff83..d6393e3 100644 --- a/src/generators/SourceKit.Generators.Grpc/Generators/ProtoMessageAliasGenerator.cs +++ b/src/generators/SourceKit.Generators.Grpc/Generators/ProtoMessageAliasGenerator.cs @@ -1,46 +1,119 @@ +using System.Collections.Immutable; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using SourceKit.Extensions; -using SourceKit.Generators.Grpc.Receivers; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace SourceKit.Generators.Grpc.Generators; [Generator] -public class ProtoMessageAliasGenerator : ISourceGenerator +public sealed class ProtoMessageAliasGenerator : IIncrementalGenerator { - public void Initialize(GeneratorInitializationContext context) + public void Initialize(IncrementalGeneratorInitializationContext context) { - context.RegisterForSyntaxNotifications(() => new ProtoMessageAliasReceiver()); + IncrementalValuesProvider allTypes = context.CompilationProvider + .SelectMany(static (compilation, ct) => EnumerateNestedTypesAndSelf(compilation.GlobalNamespace, ct)); + + allTypes = allTypes + .Where(type => type.ContainingNamespace.ToDisplayString().StartsWith("Google") is false); + + IncrementalValueProvider messageInterfaceSymbol = context.CompilationProvider + .Select(static (compilation, _) => compilation.GetTypeByMetadataName( + Constants.ProtobufMessageInterfaceFullyQualifiedName)); + + IncrementalValueProvider enumAttributeSymbol = context.CompilationProvider + .Select(static (compilation, _) => compilation.GetTypeByMetadataName( + Constants.ProtobufOriginalNameAttributeFullyQualifiedName)); + + IncrementalValuesProvider protoMessages = allTypes + .Combine(messageInterfaceSymbol) + .Where(static tuple => tuple.Right is not null) + .Where(static tuple => tuple.Left.TypeKind is TypeKind.Class) + .Where(static tuple => tuple.Left.AllInterfaces.Contains(tuple.Right!, SymbolEqualityComparer.Default)) + .Where(static tuple => tuple.Left.ContainingType is null) + .Select(static (tuple, _) => tuple.Left); + + IncrementalValuesProvider protoEnums = allTypes + .Combine(enumAttributeSymbol) + .Where(static tuple => tuple.Left.TypeKind is TypeKind.Enum) + .Where(static tuple => tuple.Left + .GetMembers() + .OfType() + .All(member => member + .GetAttributes() + .Any(attr => attr.AttributeClass?.Equals(tuple.Right, SymbolEqualityComparer.Default) is true))) + .Where(tuple => tuple.Left.ContainingType is null) + .Select((tuple, _) => tuple.Left); + + IncrementalValueProvider> protoTypes = protoMessages + .Collect() + .Combine(protoEnums.Collect()) + .SelectMany((tuple, _) => tuple.Left.Concat(tuple.Right)) + .Collect(); + + context.RegisterSourceOutput( + protoTypes, + static (context, protoTypes) => + { + if (protoTypes is []) + return; + + UsingDirectiveSyntax[] directives = protoTypes + .GroupBy(x => x.Name, (k, values) => (k, values: values.ToArray())) + .Where(x => x.values.Length is 1) + .Select(x => x.values.Single()) + .OrderBy(x => x.Name) + .Select(GenerateAlias) + .ToArray(); + + if (directives is []) + return; + + CompilationUnitSyntax unit = CompilationUnit().AddUsings(directives).NormalizeWhitespace(eol: "\n"); + string text = unit.ToFullString(); + + text = $""" + // + // This code was generated by a SourceKit.Generators.Grpc code generator. + // https://github.com/itmo-is-dev/SourceKit + // + + {text} + """; + + context.AddSource("SourceKit.Generators.Builder.ProtoAlias.cs", text); + }); } - public void Execute(GeneratorExecutionContext context) + private static IEnumerable EnumerateNestedTypesAndSelf( + INamespaceOrTypeSymbol symbol, + CancellationToken cancellationToken) { - if (context.SyntaxContextReceiver is not ProtoMessageAliasReceiver receiver) - return; - - UsingDirectiveSyntax[] directives = receiver.Symbols - .GroupBy(x => x.Name, (k, values) => (k, values: values.ToArray())) - .Where(x => x.values.Length is 1) - .Select(x => x.values.Single()) - .OrderBy(x => x.Name) - .Select(GenerateAlias) - .ToArray(); - - CompilationUnitSyntax unit = CompilationUnit().AddUsings(directives).NormalizeWhitespace(eol: "\n"); - string text = unit.ToFullString(); - - text = $""" - // - // This code was generated by a SourceKit.Generators.Grpc code generator. - // https://github.com/itmo-is-dev/SourceKit - // - - {text} - """; - - context.AddSource("SourceKit.Generators.Builder.ProtoAlias.cs", text); + cancellationToken.ThrowIfCancellationRequested(); + + if (symbol is INamedTypeSymbol namedTypeSymbol) + { + return namedTypeSymbol + .GetTypeMembers() + .SelectMany(type => EnumerateNestedTypesAndSelf(type, cancellationToken)) + .Prepend(namedTypeSymbol); + } + + if (symbol is INamespaceSymbol namespaceSymbol) + { + IEnumerable directTypes = namespaceSymbol + .GetTypeMembers() + .SelectMany(type => EnumerateNestedTypesAndSelf(type, cancellationToken)); + + IEnumerable nestedNamespaceTypes = namespaceSymbol + .GetNamespaceMembers() + .SelectMany(ns => EnumerateNestedTypesAndSelf(ns, cancellationToken)); + + return directTypes.Concat(nestedNamespaceTypes); + } + + return []; } private static UsingDirectiveSyntax GenerateAlias(INamedTypeSymbol symbol) diff --git a/src/generators/SourceKit.Generators.Grpc/Receivers/ProtoMessageAliasReceiver.cs b/src/generators/SourceKit.Generators.Grpc/Receivers/ProtoMessageAliasReceiver.cs deleted file mode 100644 index 7a69d39..0000000 --- a/src/generators/SourceKit.Generators.Grpc/Receivers/ProtoMessageAliasReceiver.cs +++ /dev/null @@ -1,60 +0,0 @@ -using Microsoft.CodeAnalysis; - -namespace SourceKit.Generators.Grpc.Receivers; - -public class ProtoMessageAliasReceiver : ISyntaxContextReceiver -{ - private readonly List _symbols = []; - - public IReadOnlyCollection Symbols => _symbols; - - public void OnVisitSyntaxNode(GeneratorSyntaxContext context) - { - INamedTypeSymbol? messageInterfaceSymbol = context.SemanticModel.Compilation - .GetTypeByMetadataName(Constants.ProtobufMessageInterfaceFullyQualifiedName); - - INamedTypeSymbol? enumAttributeSymbol = context.SemanticModel.Compilation - .GetTypeByMetadataName(Constants.ProtobufOriginalNameAttributeFullyQualifiedName); - - if (messageInterfaceSymbol is null || enumAttributeSymbol is null) - return; - - ISymbol? symbolInfo = context.SemanticModel.GetDeclaredSymbol(context.Node); - - if (symbolInfo is not INamedTypeSymbol symbol) - return; - - if (IsProtoClass(symbol) is false && IsProtoEnum(symbol) is false) - return; - - if (symbol.ContainingType is not null) - return; - - _symbols.Add(symbol); - - bool IsProtoClass(INamedTypeSymbol type) - { - return type.TypeKind is TypeKind.Class - && type.AllInterfaces.Contains(messageInterfaceSymbol, SymbolEqualityComparer.Default); - } - - bool IsProtoEnum(INamedTypeSymbol type) - { - if (type.TypeKind is not TypeKind.Enum) - return false; - - return type - .GetMembers() - .OfType() - .All(member => member - .GetAttributes() - .Any(attr => - { - if (attr.AttributeClass is null) - return false; - - return attr.AttributeClass.Equals(enumAttributeSymbol, SymbolEqualityComparer.Default); - })); - } - } -}