From bdebcd698b72c0654e9a57a264566d521afdc521 Mon Sep 17 00:00:00 2001 From: Dan Walmsley <4672627+danwalmsley@users.noreply.github.com> Date: Wed, 11 Mar 2026 11:57:04 +0000 Subject: [PATCH 1/5] Add configurable StaticViewLocator naming rules --- README.md | 6 + .../StaticViewLocatorGeneratorRuntimeTests.cs | 100 ++++++++ ...StaticViewLocatorGeneratorSnapshotTests.cs | 216 +++++++++++++++++- .../StaticViewLocatorGeneratorVerifier.cs | 47 +++- .../StaticViewLocatorGenerator.cs | 198 +++++++++++++++- .../buildTransitive/StaticViewLocator.props | 3 + 6 files changed, 560 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index f16f0e4..1448ea4 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,9 @@ You can scope which view model namespaces are considered and opt into additional false false MyApp.Controls.ToolWindowBase + ViewModels=Views + ViewModel=View;Vm=Page + true ``` @@ -78,6 +81,9 @@ Defaults and behavior: - `StaticViewLocatorIncludeReferencedAssemblies` defaults to `false`. When `true`, view models from referenced assemblies are included. - `StaticViewLocatorIncludeInternalViewModels` defaults to `false`. When `true`, internal view models from referenced assemblies are included only if the referenced assembly exposes them via `InternalsVisibleTo`. - `StaticViewLocatorAdditionalViewBaseTypes` uses `;` or `,` separators and extends the default view base type list. +- `StaticViewLocatorNamespaceReplacementRules` uses `;` or `,` separators with `from=to` pairs and is applied sequentially to the view-model namespace when deriving the target view namespace. The default includes `ViewModels=Views`. +- `StaticViewLocatorTypeNameReplacementRules` uses `;` or `,` separators with `from=to` pairs and is applied sequentially to the view-model type name when deriving the target view name. The default includes `ViewModel=View`. +- `StaticViewLocatorStripGenericArityFromViewName` defaults to `true`. When enabled, generic arity markers like `` `1 `` are removed from the derived target view name, so `WidgetViewModel` can map to `WidgetView`. These properties are exported as `CompilerVisibleProperty` by the package, so analyzers can read them without extra project configuration. diff --git a/StaticViewLocator.Tests/StaticViewLocatorGeneratorRuntimeTests.cs b/StaticViewLocator.Tests/StaticViewLocatorGeneratorRuntimeTests.cs index c2978d9..efede81 100644 --- a/StaticViewLocator.Tests/StaticViewLocatorGeneratorRuntimeTests.cs +++ b/StaticViewLocator.Tests/StaticViewLocatorGeneratorRuntimeTests.cs @@ -9,6 +9,7 @@ using Avalonia.Headless.XUnit; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; using StaticViewLocator; using Xunit; @@ -195,6 +196,73 @@ public class ReportsView : UserControl Assert.DoesNotContain(viewsMap.Keys, key => key.FullName?.Contains("WorkspaceViewModel", StringComparison.Ordinal) == true); } + [AvaloniaFact] + public async Task ResolvesGenericViewModelsUsingGenericTypeDefinition() + { + const string source = @" +using System; +using Avalonia.Controls; +using StaticViewLocator; + +namespace TestApp +{ + [StaticViewLocator] + public partial class ViewLocator + { + } +} + +namespace TestApp.ViewModels +{ + public class WidgetViewModel + { + } +} + +namespace TestApp.Views +{ + public class WidgetView : UserControl + { + } +} +"; + + var compilation = await CreateCompilationAsync(source); + var sourceGenerator = new StaticViewLocatorGenerator().AsSourceGenerator(); + var driver = CSharpGeneratorDriver.Create( + new[] { sourceGenerator }, + parseOptions: (CSharpParseOptions)compilation.SyntaxTrees.First().Options, + optionsProvider: new TestAnalyzerConfigOptionsProvider(new Dictionary + { + ["build_property.StaticViewLocatorNamespaceReplacementRules"] = "ViewModels=Views", + ["build_property.StaticViewLocatorTypeNameReplacementRules"] = "ViewModel=View", + })); + + driver.RunGeneratorsAndUpdateCompilation(compilation, out var updatedCompilation, out var diagnostics); + + Assert.Empty(diagnostics.Where(d => d.Severity == DiagnosticSeverity.Error)); + + using var peStream = new MemoryStream(); + var emitResult = updatedCompilation.Emit(peStream); + Assert.True(emitResult.Success, string.Join(Environment.NewLine, emitResult.Diagnostics)); + + peStream.Seek(0, SeekOrigin.Begin); + var assembly = Assembly.Load(peStream.ToArray()); + + var locatorType = assembly.GetType("TestApp.ViewLocator") ?? throw new InvalidOperationException("Generated locator type not found."); + var buildMethod = locatorType.GetMethod("Build", BindingFlags.Public | BindingFlags.Instance) ?? throw new InvalidOperationException("Build method not found."); + var widgetVmType = assembly.GetType("TestApp.ViewModels.WidgetViewModel`1", throwOnError: true) ?? throw new InvalidOperationException("Generic VM type not found."); + var closedVm = Activator.CreateInstance(widgetVmType.MakeGenericType(typeof(int))) ?? throw new InvalidOperationException("Unable to instantiate closed generic VM."); + var locator = Activator.CreateInstance(locatorType) ?? throw new InvalidOperationException("Unable to instantiate generated locator."); + + _ = HeadlessUnitTestSession.GetOrStartForAssembly(typeof(StaticViewLocatorGeneratorRuntimeTests).Assembly); + + var control = (Control?)buildMethod.Invoke(locator, new[] { closedVm }); + + Assert.NotNull(control); + Assert.Equal("TestApp.Views.WidgetView", control!.GetType().FullName); + } + private static Task CreateCompilationAsync(string source) { var parseOptions = new CSharpParseOptions(LanguageVersion.Preview); @@ -255,4 +323,36 @@ private static object CreateInstance(Assembly assembly, string typeName) return Activator.CreateInstance(type) ?? throw new InvalidOperationException($"Unable to instantiate type '{typeName}'."); } + + private sealed class TestAnalyzerConfigOptionsProvider : AnalyzerConfigOptionsProvider + { + private static readonly AnalyzerConfigOptions EmptyOptions = new TestAnalyzerConfigOptions(new Dictionary()); + private readonly AnalyzerConfigOptions _globalOptions; + + public TestAnalyzerConfigOptionsProvider(IReadOnlyDictionary globalOptions) + { + _globalOptions = new TestAnalyzerConfigOptions(globalOptions); + } + + public override AnalyzerConfigOptions GlobalOptions => _globalOptions; + + public override AnalyzerConfigOptions GetOptions(SyntaxTree tree) => EmptyOptions; + + public override AnalyzerConfigOptions GetOptions(AdditionalText textFile) => EmptyOptions; + } + + private sealed class TestAnalyzerConfigOptions : AnalyzerConfigOptions + { + private readonly IReadOnlyDictionary _options; + + public TestAnalyzerConfigOptions(IReadOnlyDictionary options) + { + _options = options; + } + + public override bool TryGetValue(string key, out string value) + { + return _options.TryGetValue(key, out value!); + } + } } diff --git a/StaticViewLocator.Tests/StaticViewLocatorGeneratorSnapshotTests.cs b/StaticViewLocator.Tests/StaticViewLocatorGeneratorSnapshotTests.cs index 6464bda..6aab6bd 100644 --- a/StaticViewLocator.Tests/StaticViewLocatorGeneratorSnapshotTests.cs +++ b/StaticViewLocator.Tests/StaticViewLocatorGeneratorSnapshotTests.cs @@ -1,3 +1,4 @@ +using System.Collections.Generic; using System.Threading.Tasks; using StaticViewLocator.Tests.TestHelpers; using Xunit; @@ -87,7 +88,13 @@ public partial class ViewLocator var type = data.GetType(); - if (s_views.TryGetValue(type, out var func)) + if (!s_views.TryGetValue(type, out var func) && + type.IsGenericType) + { + s_views.TryGetValue(type.GetGenericTypeDefinition(), out func); + } + + if (func is not null) { return func.Invoke(); } @@ -191,7 +198,13 @@ public partial class AdminViewLocator var type = data.GetType(); - if (s_views.TryGetValue(type, out var func)) + if (!s_views.TryGetValue(type, out var func) && + type.IsGenericType) + { + s_views.TryGetValue(type.GetGenericTypeDefinition(), out func); + } + + if (func is not null) { return func.Invoke(); } @@ -229,7 +242,13 @@ public partial class ClientViewLocator var type = data.GetType(); - if (s_views.TryGetValue(type, out var func)) + if (!s_views.TryGetValue(type, out var func) && + type.IsGenericType) + { + s_views.TryGetValue(type.GetGenericTypeDefinition(), out func); + } + + if (func is not null) { return func.Invoke(); } @@ -317,4 +336,195 @@ await StaticViewLocatorGeneratorVerifier.VerifyGeneratedSourcesAsync( ("StaticViewLocatorAttribute.cs", expectedAttribute), ("ViewLocator_StaticViewLocator.cs", expectedLocator)); } + + [Fact] + public async Task AppliesConfiguredNamespaceAndTypeReplacementRules() + { + const string input = @" +using Avalonia.Controls; +using StaticViewLocator; + +namespace TestApp.ViewModels.Pages +{ + public class DashboardViewModel + { + } +} + +namespace TestApp.Views.Pages +{ + public class DashboardScreen : UserControl + { + } +} + +namespace TestApp +{ + [StaticViewLocator] + public partial class ViewLocator + { + } +} +"; + + const string expectedAttribute = """ +// +using System; + +namespace StaticViewLocator; + +[AttributeUsage(AttributeTargets.Class, Inherited = false, AllowMultiple = false)] +public sealed class StaticViewLocatorAttribute : Attribute +{ +} + +"""; + + const string expectedLocator = """ +// +#nullable enable +using System; +using System.Collections.Generic; +using Avalonia.Controls; + +namespace TestApp; + +public partial class ViewLocator +{ + private static Dictionary> s_views = new() + { + [typeof(TestApp.ViewModels.Pages.DashboardViewModel)] = () => new TestApp.Views.Pages.DashboardScreen(), + }; + + public Control? Build(object? data) + { + if (data is null) + { + return null; + } + + var type = data.GetType(); + + if (!s_views.TryGetValue(type, out var func) && + type.IsGenericType) + { + s_views.TryGetValue(type.GetGenericTypeDefinition(), out func); + } + + if (func is not null) + { + return func.Invoke(); + } + + throw new Exception($"Unable to create view for type: {type}"); + } +} + +"""; + + await StaticViewLocatorGeneratorVerifier.VerifyGeneratedSourcesAsync( + input, + new Dictionary + { + ["build_property.StaticViewLocatorNamespaceReplacementRules"] = "ViewModels=Views", + ["build_property.StaticViewLocatorTypeNameReplacementRules"] = "ViewModel=Screen", + }, + ("StaticViewLocatorAttribute.cs", expectedAttribute), + ("ViewLocator_StaticViewLocator.cs", expectedLocator)); + } + + [Fact] + public async Task StripsGenericArityFromConfiguredViewNamesByDefault() + { + const string input = @" +using Avalonia.Controls; +using StaticViewLocator; + +namespace TestApp.ViewModels +{ + public class WidgetViewModel + { + } +} + +namespace TestApp.Views +{ + public class WidgetView : UserControl + { + } +} + +namespace TestApp +{ + [StaticViewLocator] + public partial class ViewLocator + { + } +} +"; + + const string expectedAttribute = """ +// +using System; + +namespace StaticViewLocator; + +[AttributeUsage(AttributeTargets.Class, Inherited = false, AllowMultiple = false)] +public sealed class StaticViewLocatorAttribute : Attribute +{ +} + +"""; + + const string expectedLocator = """ +// +#nullable enable +using System; +using System.Collections.Generic; +using Avalonia.Controls; + +namespace TestApp; + +public partial class ViewLocator +{ + private static Dictionary> s_views = new() + { + [typeof(TestApp.ViewModels.WidgetViewModel<>)] = () => new TestApp.Views.WidgetView(), + }; + + public Control? Build(object? data) + { + if (data is null) + { + return null; + } + + var type = data.GetType(); + + if (!s_views.TryGetValue(type, out var func) && + type.IsGenericType) + { + s_views.TryGetValue(type.GetGenericTypeDefinition(), out func); + } + + if (func is not null) + { + return func.Invoke(); + } + + throw new Exception($"Unable to create view for type: {type}"); + } +} + +"""; + + await StaticViewLocatorGeneratorVerifier.VerifyGeneratedSourcesAsync( + input, + new Dictionary + { + ["build_property.StaticViewLocatorTypeNameReplacementRules"] = "ViewModel=View", + }, + ("StaticViewLocatorAttribute.cs", expectedAttribute), + ("ViewLocator_StaticViewLocator.cs", expectedLocator)); + } } diff --git a/StaticViewLocator.Tests/TestHelpers/StaticViewLocatorGeneratorVerifier.cs b/StaticViewLocator.Tests/TestHelpers/StaticViewLocatorGeneratorVerifier.cs index 16d17d8..3fc9d5e 100644 --- a/StaticViewLocator.Tests/TestHelpers/StaticViewLocatorGeneratorVerifier.cs +++ b/StaticViewLocator.Tests/TestHelpers/StaticViewLocatorGeneratorVerifier.cs @@ -7,6 +7,7 @@ using Avalonia.Controls; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; using StaticViewLocator; using Xunit; @@ -15,6 +16,14 @@ namespace StaticViewLocator.Tests.TestHelpers; internal static class StaticViewLocatorGeneratorVerifier { public static Task VerifyGeneratedSourcesAsync(string source, params (string hintName, string source)[] generatedSources) + { + return VerifyGeneratedSourcesAsync(source, globalOptions: null, generatedSources); + } + + public static Task VerifyGeneratedSourcesAsync( + string source, + IReadOnlyDictionary? globalOptions = null, + params (string hintName, string source)[] generatedSources) { var parseOptions = new CSharpParseOptions(LanguageVersion.Preview); var syntaxTree = CSharpSyntaxTree.ParseText(source, parseOptions); @@ -26,7 +35,11 @@ public static Task VerifyGeneratedSourcesAsync(string source, params (string hin options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); var generator = new StaticViewLocatorGenerator().AsSourceGenerator(); - GeneratorDriver driver = CSharpGeneratorDriver.Create(new[] { generator }, parseOptions: parseOptions); + GeneratorDriver driver = CSharpGeneratorDriver.Create( + generators: new[] { generator }, + additionalTexts: null, + parseOptions: parseOptions, + optionsProvider: new TestAnalyzerConfigOptionsProvider(globalOptions)); driver = driver.RunGeneratorsAndUpdateCompilation(compilation, out _, out var diagnostics); @@ -61,6 +74,38 @@ public static Task VerifyGeneratedSourcesAsync(string source, params (string hin return Task.CompletedTask; } + private sealed class TestAnalyzerConfigOptionsProvider : AnalyzerConfigOptionsProvider + { + private static readonly AnalyzerConfigOptions EmptyOptions = new TestAnalyzerConfigOptions(null); + private readonly AnalyzerConfigOptions _globalOptions; + + public TestAnalyzerConfigOptionsProvider(IReadOnlyDictionary? globalOptions) + { + _globalOptions = new TestAnalyzerConfigOptions(globalOptions); + } + + public override AnalyzerConfigOptions GlobalOptions => _globalOptions; + + public override AnalyzerConfigOptions GetOptions(SyntaxTree tree) => EmptyOptions; + + public override AnalyzerConfigOptions GetOptions(AdditionalText textFile) => EmptyOptions; + } + + private sealed class TestAnalyzerConfigOptions : AnalyzerConfigOptions + { + private readonly IReadOnlyDictionary _options; + + public TestAnalyzerConfigOptions(IReadOnlyDictionary? options) + { + _options = options ?? new Dictionary(StringComparer.Ordinal); + } + + public override bool TryGetValue(string key, out string value) + { + return _options.TryGetValue(key, out value!); + } + } + private static IReadOnlyCollection GetMetadataReferences() { var references = new List(); diff --git a/StaticViewLocator/StaticViewLocatorGenerator.cs b/StaticViewLocator/StaticViewLocatorGenerator.cs index aac7e23..92dc0c7 100644 --- a/StaticViewLocator/StaticViewLocatorGenerator.cs +++ b/StaticViewLocator/StaticViewLocatorGenerator.cs @@ -21,6 +21,9 @@ public sealed class StaticViewLocatorGenerator : IIncrementalGenerator private const string IncludeInternalViewModelsProperty = "build_property.StaticViewLocatorIncludeInternalViewModels"; private const string IncludeReferencedAssembliesProperty = "build_property.StaticViewLocatorIncludeReferencedAssemblies"; private const string AdditionalViewBaseTypesProperty = "build_property.StaticViewLocatorAdditionalViewBaseTypes"; + private const string NamespaceReplacementRulesProperty = "build_property.StaticViewLocatorNamespaceReplacementRules"; + private const string TypeNameReplacementRulesProperty = "build_property.StaticViewLocatorTypeNameReplacementRules"; + private const string StripGenericArityFromViewNameProperty = "build_property.StaticViewLocatorStripGenericArityFromViewName"; private readonly struct GeneratorOptions { @@ -28,18 +31,39 @@ public GeneratorOptions( ImmutableArray namespacePrefixes, bool includeInternalViewModels, bool includeReferencedAssemblies, - ImmutableArray additionalViewBaseTypes) + ImmutableArray additionalViewBaseTypes, + ImmutableArray namespaceReplacementRules, + ImmutableArray typeNameReplacementRules, + bool stripGenericArityFromViewName) { NamespacePrefixes = namespacePrefixes; IncludeInternalViewModels = includeInternalViewModels; IncludeReferencedAssemblies = includeReferencedAssemblies; AdditionalViewBaseTypes = additionalViewBaseTypes; + NamespaceReplacementRules = namespaceReplacementRules; + TypeNameReplacementRules = typeNameReplacementRules; + StripGenericArityFromViewName = stripGenericArityFromViewName; } public ImmutableArray NamespacePrefixes { get; } public bool IncludeInternalViewModels { get; } public bool IncludeReferencedAssemblies { get; } public ImmutableArray AdditionalViewBaseTypes { get; } + public ImmutableArray NamespaceReplacementRules { get; } + public ImmutableArray TypeNameReplacementRules { get; } + public bool StripGenericArityFromViewName { get; } + } + + private readonly struct ReplacementRule + { + public ReplacementRule(string from, string to) + { + From = from; + To = to; + } + + public string From { get; } + public string To { get; } } private const string AttributeText = @@ -161,12 +185,18 @@ private static GeneratorOptions GetGeneratorOptions(AnalyzerConfigOptionsProvide var includeInternal = GetIncludeInternalViewModels(optionsProvider); var includeReferencedAssemblies = GetIncludeReferencedAssemblies(optionsProvider); var additionalViewBaseTypes = GetAdditionalViewBaseTypes(optionsProvider); + var namespaceReplacementRules = GetReplacementRules(optionsProvider, NamespaceReplacementRulesProperty, new ReplacementRule("ViewModels", "Views")); + var typeNameReplacementRules = GetReplacementRules(optionsProvider, TypeNameReplacementRulesProperty, new ReplacementRule(ViewModelSuffix, ViewSuffix)); + var stripGenericArityFromViewName = GetStripGenericArityFromViewName(optionsProvider); return new GeneratorOptions( namespacePrefixes, includeInternal, includeReferencedAssemblies, - additionalViewBaseTypes); + additionalViewBaseTypes, + namespaceReplacementRules, + typeNameReplacementRules, + stripGenericArityFromViewName); } private static ImmutableArray GetNamespacePrefixes(AnalyzerConfigOptionsProvider optionsProvider) @@ -251,6 +281,68 @@ private static ImmutableArray GetAdditionalViewBaseTypes(AnalyzerConfigO return builder.ToImmutable(); } + private static ImmutableArray GetReplacementRules( + AnalyzerConfigOptionsProvider optionsProvider, + string propertyName, + params ReplacementRule[] defaults) + { + var builder = ImmutableArray.CreateBuilder(); + + if (!optionsProvider.GlobalOptions.TryGetValue(propertyName, out var rawValue) || + string.IsNullOrWhiteSpace(rawValue)) + { + foreach (var replacementRule in defaults) + { + builder.Add(replacementRule); + } + + return builder.ToImmutable(); + } + + var parts = rawValue.Split(new[] { ';', ',' }, StringSplitOptions.RemoveEmptyEntries); + foreach (var part in parts) + { + var trimmed = part.Trim(); + if (trimmed.Length == 0) + { + continue; + } + + var separatorIndex = trimmed.IndexOf('='); + if (separatorIndex <= 0 || separatorIndex == trimmed.Length - 1) + { + continue; + } + + var from = trimmed.Substring(0, separatorIndex).Trim(); + var to = trimmed.Substring(separatorIndex + 1).Trim(); + + if (from.Length == 0) + { + continue; + } + + builder.Add(new ReplacementRule(from, to)); + } + + foreach (var replacementRule in defaults) + { + builder.Add(replacementRule); + } + + return builder.ToImmutable(); + } + + private static bool GetStripGenericArityFromViewName(AnalyzerConfigOptionsProvider optionsProvider) + { + if (!optionsProvider.GlobalOptions.TryGetValue(StripGenericArityFromViewNameProperty, out var rawValue)) + { + return true; + } + + return !bool.TryParse(rawValue, out var stripGenericArityFromViewName) || stripGenericArityFromViewName; + } + private static bool IsViewModelType(INamedTypeSymbol symbol) { if (symbol.TypeKind != TypeKind.Class || symbol.IsAbstract) @@ -362,6 +454,78 @@ private static bool MatchesNamespace(string namespaceName, ImmutableArray replacementRules) + { + var value = input; + + foreach (var replacementRule in replacementRules) + { + value = value.Replace(replacementRule.From, replacementRule.To); + } + + return value; + } + + private static string GetMetadataTypeName(INamedTypeSymbol symbol) + { + var parts = new Stack(); + + for (var current = symbol; current is not null; current = current.ContainingType) + { + parts.Push(current.MetadataName); + } + + return string.Join(".", parts); + } + + private static string GetSourceTypeName(INamedTypeSymbol symbol) + { + var parts = new Stack(); + + for (var current = symbol; current is not null; current = current.ContainingType) + { + parts.Push(current.Name); + } + + return string.Join(".", parts); + } + + private static string GetSourceTypeReference(INamedTypeSymbol symbol) + { + var parts = new Stack(); + + for (var current = symbol; current is not null; current = current.ContainingType) + { + var part = current.Name; + if (current.IsGenericType) + { + part += "<" + new string(',', current.TypeParameters.Length - 1) + ">"; + } + + parts.Push(part); + } + + var typeName = string.Join(".", parts); + var namespaceName = symbol.ContainingNamespace.ToDisplayString(); + return string.IsNullOrEmpty(namespaceName) ? typeName : $"{namespaceName}.{typeName}"; + } + + private static string StripGenericArity(string typeName) + { + var parts = typeName.Split('.'); + + for (var i = 0; i < parts.Length; i++) + { + var tickIndex = parts[i].IndexOf('`'); + if (tickIndex >= 0) + { + parts[i] = parts[i].Substring(0, tickIndex); + } + } + + return string.Join(".", parts); + } + private static string? ProcessClass( Compilation compilation, INamedTypeSymbol locatorSymbol, @@ -449,10 +613,26 @@ public partial class {{classNameLocator}} foreach (var viewModelSymbol in relevantViewModels) { var namespaceNameViewModel = viewModelSymbol.ContainingNamespace.ToDisplayString(); - var classNameViewModel = $"{namespaceNameViewModel}.{viewModelSymbol.ToDisplayString(format)}"; - var classNameView = classNameViewModel.Replace(ViewModelSuffix, ViewSuffix); + var sourceTypeNameViewModel = GetSourceTypeName(viewModelSymbol); + var metadataTypeNameViewModel = GetMetadataTypeName(viewModelSymbol); + + var sourceNamespaceView = ApplyReplacementRules(namespaceNameViewModel, options.NamespaceReplacementRules); + var metadataNamespaceView = ApplyReplacementRules(namespaceNameViewModel, options.NamespaceReplacementRules); + + var sourceTypeNameView = ApplyReplacementRules(sourceTypeNameViewModel, options.TypeNameReplacementRules); + var metadataTypeNameView = ApplyReplacementRules(metadataTypeNameViewModel, options.TypeNameReplacementRules); + + if (options.StripGenericArityFromViewName) + { + sourceTypeNameView = StripGenericArity(sourceTypeNameView); + metadataTypeNameView = StripGenericArity(metadataTypeNameView); + } + + var classNameViewModel = GetSourceTypeReference(viewModelSymbol); + var classNameView = $"{sourceNamespaceView}.{sourceTypeNameView}"; + var metadataNameView = $"{metadataNamespaceView}.{metadataTypeNameView}"; - var viewSymbol = compilation.GetTypeByMetadataName(classNameView); + var viewSymbol = compilation.GetTypeByMetadataName(metadataNameView); var isSupportedView = false; if (viewSymbol is not null && viewBaseTypes.Count > 0) { @@ -493,7 +673,13 @@ public partial class {{classNameLocator}} var type = data.GetType(); - if (s_views.TryGetValue(type, out var func)) + if (!s_views.TryGetValue(type, out var func) && + type.IsGenericType) + { + s_views.TryGetValue(type.GetGenericTypeDefinition(), out func); + } + + if (func is not null) { return func.Invoke(); } diff --git a/StaticViewLocator/buildTransitive/StaticViewLocator.props b/StaticViewLocator/buildTransitive/StaticViewLocator.props index 4255837..ba71ae0 100644 --- a/StaticViewLocator/buildTransitive/StaticViewLocator.props +++ b/StaticViewLocator/buildTransitive/StaticViewLocator.props @@ -4,5 +4,8 @@ + + + From 258d961eca4604820f6d7aee24aa7cc86d2ca65a Mon Sep 17 00:00:00 2001 From: Dan Walmsley <4672627+danwalmsley@users.noreply.github.com> Date: Wed, 11 Mar 2026 12:02:57 +0000 Subject: [PATCH 2/5] Add base and interface fallback resolution --- .../StaticViewLocatorGeneratorRuntimeTests.cs | 167 ++++- ...StaticViewLocatorGeneratorSnapshotTests.cs | 676 +++++++++++++++++- .../StaticViewLocatorGenerator.cs | 265 ++++++- .../buildTransitive/StaticViewLocator.props | 1 + 4 files changed, 1038 insertions(+), 71 deletions(-) diff --git a/StaticViewLocator.Tests/StaticViewLocatorGeneratorRuntimeTests.cs b/StaticViewLocator.Tests/StaticViewLocatorGeneratorRuntimeTests.cs index efede81..a95909a 100644 --- a/StaticViewLocator.Tests/StaticViewLocatorGeneratorRuntimeTests.cs +++ b/StaticViewLocator.Tests/StaticViewLocatorGeneratorRuntimeTests.cs @@ -30,15 +30,6 @@ namespace TestApp [StaticViewLocator] public partial class ViewLocator { - public static Control Resolve(object vm) - { - if (vm is null) - { - throw new ArgumentNullException(nameof(vm)); - } - - return s_views[vm.GetType()](); - } } } @@ -76,14 +67,15 @@ public class SampleView : UserControl var assembly = Assembly.Load(peStream.ToArray()); var locatorType = assembly.GetType("TestApp.ViewLocator") ?? throw new InvalidOperationException("Generated locator type not found."); - var resolveMethod = locatorType.GetMethod("Resolve", BindingFlags.Public | BindingFlags.Static) ?? throw new InvalidOperationException("Resolve method not found."); + var buildMethod = locatorType.GetMethod("Build", BindingFlags.Public | BindingFlags.Instance) ?? throw new InvalidOperationException("Build method not found."); + var locator = Activator.CreateInstance(locatorType) ?? throw new InvalidOperationException("Unable to instantiate generated locator."); var sampleViewModel = CreateInstance(assembly, "TestApp.ViewModels.SampleViewModel"); var missingViewModel = CreateInstance(assembly, "TestApp.ViewModels.MissingViewModel"); _ = HeadlessUnitTestSession.GetOrStartForAssembly(typeof(StaticViewLocatorGeneratorRuntimeTests).Assembly); - var sampleControl = (Control)resolveMethod.Invoke(null, new[] { sampleViewModel })!; - var missingControl = (Control)resolveMethod.Invoke(null, new[] { missingViewModel })!; + var sampleControl = (Control)buildMethod.Invoke(locator, new[] { sampleViewModel })!; + var missingControl = (Control)buildMethod.Invoke(locator, new[] { missingViewModel })!; Assert.Equal("TestApp.Views.SampleView", sampleControl.GetType().FullName); Assert.Equal("Avalonia.Controls.TextBlock", missingControl.GetType().FullName); @@ -104,15 +96,6 @@ namespace Portal [StaticViewLocator] public partial class PortalViewLocator { - public static Control Locate(object vm) - { - if (vm is null) - { - throw new ArgumentNullException(nameof(vm)); - } - - return s_views[vm.GetType()](); - } } } @@ -166,7 +149,8 @@ public class ReportsView : UserControl var assembly = Assembly.Load(peStream.ToArray()); var locatorType = assembly.GetType("Portal.PortalViewLocator") ?? throw new InvalidOperationException("Generated locator type not found."); - var locateMethod = locatorType.GetMethod("Locate", BindingFlags.Public | BindingFlags.Static) ?? throw new InvalidOperationException("Locate method not found."); + var buildMethod = locatorType.GetMethod("Build", BindingFlags.Public | BindingFlags.Instance) ?? throw new InvalidOperationException("Build method not found."); + var locator = Activator.CreateInstance(locatorType) ?? throw new InvalidOperationException("Unable to instantiate generated locator."); var dictionaryField = locatorType.GetField("s_views", BindingFlags.NonPublic | BindingFlags.Static) ?? throw new InvalidOperationException("Dictionary field not found."); var viewsMap = (Dictionary>)dictionaryField.GetValue(null)!; @@ -174,7 +158,6 @@ public class ReportsView : UserControl { "Portal.ViewModels.HomeViewModel", "Portal.ViewModels.ReportsViewModel", - "Portal.ViewModels.SettingsViewModel", }; Assert.Equal(expectedOrder.Length, viewsMap.Count); @@ -184,9 +167,9 @@ public class ReportsView : UserControl var reportsViewModel = CreateInstance(assembly, "Portal.ViewModels.ReportsViewModel"); var settingsViewModel = CreateInstance(assembly, "Portal.ViewModels.SettingsViewModel"); - var homeControl = (Control)locateMethod.Invoke(null, new[] { homeViewModel })!; - var reportsControl = (Control)locateMethod.Invoke(null, new[] { reportsViewModel })!; - var settingsControl = (Control)locateMethod.Invoke(null, new[] { settingsViewModel })!; + var homeControl = (Control)buildMethod.Invoke(locator, new[] { homeViewModel })!; + var reportsControl = (Control)buildMethod.Invoke(locator, new[] { reportsViewModel })!; + var settingsControl = (Control)buildMethod.Invoke(locator, new[] { settingsViewModel })!; Assert.Equal("Portal.Views.HomeView", homeControl.GetType().FullName); Assert.Equal("Portal.Views.ReportsView", reportsControl.GetType().FullName); @@ -263,6 +246,138 @@ public class WidgetView : UserControl Assert.Equal("TestApp.Views.WidgetView", control!.GetType().FullName); } + [AvaloniaFact] + public async Task ResolvesUsingBaseClassBeforeInterfaceFallback() + { + const string source = @" +using System; +using Avalonia.Controls; +using StaticViewLocator; + +namespace TestApp +{ + [StaticViewLocator] + public partial class ViewLocator + { + } +} + +namespace TestApp.ViewModels +{ + public abstract class BaseViewModel + { + } + + public interface IAlternateViewModel + { + } + + public sealed class ConcreteViewModel : BaseViewModel, IAlternateViewModel + { + } +} + +namespace TestApp.Views +{ + public class BaseView : UserControl + { + } + + public class AlternateView : UserControl + { + } +} +"; + + var compilation = await CreateCompilationAsync(source); + var sourceGenerator = new StaticViewLocatorGenerator().AsSourceGenerator(); + var driver = CSharpGeneratorDriver.Create(new[] { sourceGenerator }, parseOptions: (CSharpParseOptions)compilation.SyntaxTrees.First().Options); + driver.RunGeneratorsAndUpdateCompilation(compilation, out var updatedCompilation, out var diagnostics); + + Assert.Empty(diagnostics.Where(d => d.Severity == DiagnosticSeverity.Error)); + + using var peStream = new MemoryStream(); + var emitResult = updatedCompilation.Emit(peStream); + Assert.True(emitResult.Success, string.Join(Environment.NewLine, emitResult.Diagnostics)); + + peStream.Seek(0, SeekOrigin.Begin); + var assembly = Assembly.Load(peStream.ToArray()); + + var locatorType = assembly.GetType("TestApp.ViewLocator") ?? throw new InvalidOperationException("Generated locator type not found."); + var buildMethod = locatorType.GetMethod("Build", BindingFlags.Public | BindingFlags.Instance) ?? throw new InvalidOperationException("Build method not found."); + var concreteViewModel = CreateInstance(assembly, "TestApp.ViewModels.ConcreteViewModel"); + var locator = Activator.CreateInstance(locatorType) ?? throw new InvalidOperationException("Unable to instantiate generated locator."); + + _ = HeadlessUnitTestSession.GetOrStartForAssembly(typeof(StaticViewLocatorGeneratorRuntimeTests).Assembly); + + var control = (Control?)buildMethod.Invoke(locator, new[] { concreteViewModel }); + + Assert.NotNull(control); + Assert.Equal("TestApp.Views.BaseView", control!.GetType().FullName); + } + + [AvaloniaFact] + public async Task ResolvesInterfaceMappingsByStrippingConfiguredPrefix() + { + const string source = @" +using System; +using Avalonia.Controls; +using StaticViewLocator; + +namespace TestApp +{ + [StaticViewLocator] + public partial class ViewLocator + { + } +} + +namespace TestApp.ViewModels +{ + public interface IDetailsViewModel + { + } + + public sealed class ConcreteViewModel : IDetailsViewModel + { + } +} + +namespace TestApp.Views +{ + public class DetailsView : UserControl + { + } +} +"; + + var compilation = await CreateCompilationAsync(source); + var sourceGenerator = new StaticViewLocatorGenerator().AsSourceGenerator(); + var driver = CSharpGeneratorDriver.Create(new[] { sourceGenerator }, parseOptions: (CSharpParseOptions)compilation.SyntaxTrees.First().Options); + driver.RunGeneratorsAndUpdateCompilation(compilation, out var updatedCompilation, out var diagnostics); + + Assert.Empty(diagnostics.Where(d => d.Severity == DiagnosticSeverity.Error)); + + using var peStream = new MemoryStream(); + var emitResult = updatedCompilation.Emit(peStream); + Assert.True(emitResult.Success, string.Join(Environment.NewLine, emitResult.Diagnostics)); + + peStream.Seek(0, SeekOrigin.Begin); + var assembly = Assembly.Load(peStream.ToArray()); + + var locatorType = assembly.GetType("TestApp.ViewLocator") ?? throw new InvalidOperationException("Generated locator type not found."); + var buildMethod = locatorType.GetMethod("Build", BindingFlags.Public | BindingFlags.Instance) ?? throw new InvalidOperationException("Build method not found."); + var concreteViewModel = CreateInstance(assembly, "TestApp.ViewModels.ConcreteViewModel"); + var locator = Activator.CreateInstance(locatorType) ?? throw new InvalidOperationException("Unable to instantiate generated locator."); + + _ = HeadlessUnitTestSession.GetOrStartForAssembly(typeof(StaticViewLocatorGeneratorRuntimeTests).Assembly); + + var control = (Control?)buildMethod.Invoke(locator, new[] { concreteViewModel }); + + Assert.NotNull(control); + Assert.Equal("TestApp.Views.DetailsView", control!.GetType().FullName); + } + private static Task CreateCompilationAsync(string source) { var parseOptions = new CSharpParseOptions(LanguageVersion.Preview); diff --git a/StaticViewLocator.Tests/StaticViewLocatorGeneratorSnapshotTests.cs b/StaticViewLocator.Tests/StaticViewLocatorGeneratorSnapshotTests.cs index 6aab6bd..d2283a4 100644 --- a/StaticViewLocator.Tests/StaticViewLocatorGeneratorSnapshotTests.cs +++ b/StaticViewLocator.Tests/StaticViewLocatorGeneratorSnapshotTests.cs @@ -75,10 +75,15 @@ public partial class ViewLocator { private static Dictionary> s_views = new() { - [typeof(TestApp.ViewModels.MainWindowViewModel)] = () => new TextBlock() { Text = "Not Found: TestApp.Views.MainWindowView" }, [typeof(TestApp.ViewModels.TestViewModel)] = () => new TestApp.Views.TestView(), }; + private static Dictionary s_missingViews = new() + { + [typeof(TestApp.ViewModels.IgnoredViewModel)] = "Not Found: TestApp.Views.IgnoredView", + [typeof(TestApp.ViewModels.MainWindowViewModel)] = "Not Found: TestApp.Views.MainWindowView", + }; + public Control? Build(object? data) { if (data is null) @@ -87,20 +92,139 @@ public partial class ViewLocator } var type = data.GetType(); + var func = TryGetFactory(type) ?? TryGetFactoryFromInterfaces(type); - if (!s_views.TryGetValue(type, out var func) && - type.IsGenericType) + if (func is not null) { - s_views.TryGetValue(type.GetGenericTypeDefinition(), out func); + return func.Invoke(); } - if (func is not null) + var missingView = TryGetMissingView(type) ?? TryGetMissingViewFromInterfaces(type); + if (missingView is not null) { - return func.Invoke(); + return new TextBlock { Text = missingView }; } throw new Exception($"Unable to create view for type: {type}"); } + + private static Func? TryGetFactory(Type? type) + { + if (type is null) + { + return null; + } + + if (TryGetFactoryForType(type, out var func)) + { + return func; + } + + for (var current = type.BaseType; current is not null; current = current.BaseType) + { + if (TryGetFactoryForType(current, out func)) + { + return func; + } + } + + return null; + } + + private static Func? TryGetFactoryFromInterfaces(Type type) + { + var interfaces = type.GetInterfaces(); + for (var index = interfaces.Length - 1; index >= 0; index--) + { + var interfaceType = interfaces[index]; + if (!interfaceType.Name.EndsWith("ViewModel", StringComparison.Ordinal)) + { + continue; + } + + if (TryGetFactoryForType(interfaceType, out var func)) + { + return func; + } + } + + return null; + } + + private static bool TryGetFactoryForType(Type type, out Func? func) + { + if (s_views.TryGetValue(type, out func)) + { + return true; + } + + if (type.IsGenericType && s_views.TryGetValue(type.GetGenericTypeDefinition(), out func)) + { + return true; + } + + func = null; + return false; + } + + private static string? TryGetMissingView(Type? type) + { + if (type is null) + { + return null; + } + + if (TryGetMissingViewForType(type, out var missingView)) + { + return missingView; + } + + for (var current = type.BaseType; current is not null; current = current.BaseType) + { + if (TryGetMissingViewForType(current, out missingView)) + { + return missingView; + } + } + + return null; + } + + private static string? TryGetMissingViewFromInterfaces(Type type) + { + var interfaces = type.GetInterfaces(); + for (var index = interfaces.Length - 1; index >= 0; index--) + { + var interfaceType = interfaces[index]; + if (!interfaceType.Name.EndsWith("ViewModel", StringComparison.Ordinal)) + { + continue; + } + + if (TryGetMissingViewForType(interfaceType, out var missingView)) + { + return missingView; + } + } + + return null; + } + + private static bool TryGetMissingViewForType(Type type, out string? missingView) + { + if (s_missingViews.TryGetValue(type, out missingView)) + { + return true; + } + + if (type.IsGenericType && s_missingViews.TryGetValue(type.GetGenericTypeDefinition(), out missingView)) + { + return true; + } + + missingView = null; + return false; + } } """; @@ -186,7 +310,11 @@ public partial class AdminViewLocator { [typeof(App.Modules.Admin.AdminDashboardViewModel)] = () => new App.Modules.Admin.AdminDashboardView(), [typeof(App.Modules.Client.ClientDashboardViewModel)] = () => new App.Modules.Client.ClientDashboardView(), - [typeof(App.Modules.Shared.ActivityLogViewModel)] = () => new TextBlock() { Text = "Not Found: App.Modules.Shared.ActivityLogView" }, + }; + + private static Dictionary s_missingViews = new() + { + [typeof(App.Modules.Shared.ActivityLogViewModel)] = "Not Found: App.Modules.Shared.ActivityLogView", }; public Control? Build(object? data) @@ -197,20 +325,139 @@ public partial class AdminViewLocator } var type = data.GetType(); + var func = TryGetFactory(type) ?? TryGetFactoryFromInterfaces(type); - if (!s_views.TryGetValue(type, out var func) && - type.IsGenericType) + if (func is not null) { - s_views.TryGetValue(type.GetGenericTypeDefinition(), out func); + return func.Invoke(); } - if (func is not null) + var missingView = TryGetMissingView(type) ?? TryGetMissingViewFromInterfaces(type); + if (missingView is not null) { - return func.Invoke(); + return new TextBlock { Text = missingView }; } throw new Exception($"Unable to create view for type: {type}"); } + + private static Func? TryGetFactory(Type? type) + { + if (type is null) + { + return null; + } + + if (TryGetFactoryForType(type, out var func)) + { + return func; + } + + for (var current = type.BaseType; current is not null; current = current.BaseType) + { + if (TryGetFactoryForType(current, out func)) + { + return func; + } + } + + return null; + } + + private static Func? TryGetFactoryFromInterfaces(Type type) + { + var interfaces = type.GetInterfaces(); + for (var index = interfaces.Length - 1; index >= 0; index--) + { + var interfaceType = interfaces[index]; + if (!interfaceType.Name.EndsWith("ViewModel", StringComparison.Ordinal)) + { + continue; + } + + if (TryGetFactoryForType(interfaceType, out var func)) + { + return func; + } + } + + return null; + } + + private static bool TryGetFactoryForType(Type type, out Func? func) + { + if (s_views.TryGetValue(type, out func)) + { + return true; + } + + if (type.IsGenericType && s_views.TryGetValue(type.GetGenericTypeDefinition(), out func)) + { + return true; + } + + func = null; + return false; + } + + private static string? TryGetMissingView(Type? type) + { + if (type is null) + { + return null; + } + + if (TryGetMissingViewForType(type, out var missingView)) + { + return missingView; + } + + for (var current = type.BaseType; current is not null; current = current.BaseType) + { + if (TryGetMissingViewForType(current, out missingView)) + { + return missingView; + } + } + + return null; + } + + private static string? TryGetMissingViewFromInterfaces(Type type) + { + var interfaces = type.GetInterfaces(); + for (var index = interfaces.Length - 1; index >= 0; index--) + { + var interfaceType = interfaces[index]; + if (!interfaceType.Name.EndsWith("ViewModel", StringComparison.Ordinal)) + { + continue; + } + + if (TryGetMissingViewForType(interfaceType, out var missingView)) + { + return missingView; + } + } + + return null; + } + + private static bool TryGetMissingViewForType(Type type, out string? missingView) + { + if (s_missingViews.TryGetValue(type, out missingView)) + { + return true; + } + + if (type.IsGenericType && s_missingViews.TryGetValue(type.GetGenericTypeDefinition(), out missingView)) + { + return true; + } + + missingView = null; + return false; + } } """; @@ -230,7 +477,11 @@ public partial class ClientViewLocator { [typeof(App.Modules.Admin.AdminDashboardViewModel)] = () => new App.Modules.Admin.AdminDashboardView(), [typeof(App.Modules.Client.ClientDashboardViewModel)] = () => new App.Modules.Client.ClientDashboardView(), - [typeof(App.Modules.Shared.ActivityLogViewModel)] = () => new TextBlock() { Text = "Not Found: App.Modules.Shared.ActivityLogView" }, + }; + + private static Dictionary s_missingViews = new() + { + [typeof(App.Modules.Shared.ActivityLogViewModel)] = "Not Found: App.Modules.Shared.ActivityLogView", }; public Control? Build(object? data) @@ -241,20 +492,139 @@ public partial class ClientViewLocator } var type = data.GetType(); + var func = TryGetFactory(type) ?? TryGetFactoryFromInterfaces(type); - if (!s_views.TryGetValue(type, out var func) && - type.IsGenericType) + if (func is not null) { - s_views.TryGetValue(type.GetGenericTypeDefinition(), out func); + return func.Invoke(); } - if (func is not null) + var missingView = TryGetMissingView(type) ?? TryGetMissingViewFromInterfaces(type); + if (missingView is not null) { - return func.Invoke(); + return new TextBlock { Text = missingView }; } throw new Exception($"Unable to create view for type: {type}"); } + + private static Func? TryGetFactory(Type? type) + { + if (type is null) + { + return null; + } + + if (TryGetFactoryForType(type, out var func)) + { + return func; + } + + for (var current = type.BaseType; current is not null; current = current.BaseType) + { + if (TryGetFactoryForType(current, out func)) + { + return func; + } + } + + return null; + } + + private static Func? TryGetFactoryFromInterfaces(Type type) + { + var interfaces = type.GetInterfaces(); + for (var index = interfaces.Length - 1; index >= 0; index--) + { + var interfaceType = interfaces[index]; + if (!interfaceType.Name.EndsWith("ViewModel", StringComparison.Ordinal)) + { + continue; + } + + if (TryGetFactoryForType(interfaceType, out var func)) + { + return func; + } + } + + return null; + } + + private static bool TryGetFactoryForType(Type type, out Func? func) + { + if (s_views.TryGetValue(type, out func)) + { + return true; + } + + if (type.IsGenericType && s_views.TryGetValue(type.GetGenericTypeDefinition(), out func)) + { + return true; + } + + func = null; + return false; + } + + private static string? TryGetMissingView(Type? type) + { + if (type is null) + { + return null; + } + + if (TryGetMissingViewForType(type, out var missingView)) + { + return missingView; + } + + for (var current = type.BaseType; current is not null; current = current.BaseType) + { + if (TryGetMissingViewForType(current, out missingView)) + { + return missingView; + } + } + + return null; + } + + private static string? TryGetMissingViewFromInterfaces(Type type) + { + var interfaces = type.GetInterfaces(); + for (var index = interfaces.Length - 1; index >= 0; index--) + { + var interfaceType = interfaces[index]; + if (!interfaceType.Name.EndsWith("ViewModel", StringComparison.Ordinal)) + { + continue; + } + + if (TryGetMissingViewForType(interfaceType, out var missingView)) + { + return missingView; + } + } + + return null; + } + + private static bool TryGetMissingViewForType(Type type, out string? missingView) + { + if (s_missingViews.TryGetValue(type, out missingView)) + { + return true; + } + + if (type.IsGenericType && s_missingViews.TryGetValue(type.GetGenericTypeDefinition(), out missingView)) + { + return true; + } + + missingView = null; + return false; + } } """; @@ -327,6 +697,10 @@ public partial class ViewLocator { [typeof(TestApp.ViewModels.SampleViewModel)] = () => new TestApp.Views.SampleView(), }; + + private static Dictionary s_missingViews = new() + { + }; } """; @@ -396,6 +770,10 @@ public partial class ViewLocator [typeof(TestApp.ViewModels.Pages.DashboardViewModel)] = () => new TestApp.Views.Pages.DashboardScreen(), }; + private static Dictionary s_missingViews = new() + { + }; + public Control? Build(object? data) { if (data is null) @@ -404,20 +782,139 @@ public partial class ViewLocator } var type = data.GetType(); + var func = TryGetFactory(type) ?? TryGetFactoryFromInterfaces(type); - if (!s_views.TryGetValue(type, out var func) && - type.IsGenericType) + if (func is not null) { - s_views.TryGetValue(type.GetGenericTypeDefinition(), out func); + return func.Invoke(); } - if (func is not null) + var missingView = TryGetMissingView(type) ?? TryGetMissingViewFromInterfaces(type); + if (missingView is not null) { - return func.Invoke(); + return new TextBlock { Text = missingView }; } throw new Exception($"Unable to create view for type: {type}"); } + + private static Func? TryGetFactory(Type? type) + { + if (type is null) + { + return null; + } + + if (TryGetFactoryForType(type, out var func)) + { + return func; + } + + for (var current = type.BaseType; current is not null; current = current.BaseType) + { + if (TryGetFactoryForType(current, out func)) + { + return func; + } + } + + return null; + } + + private static Func? TryGetFactoryFromInterfaces(Type type) + { + var interfaces = type.GetInterfaces(); + for (var index = interfaces.Length - 1; index >= 0; index--) + { + var interfaceType = interfaces[index]; + if (!interfaceType.Name.EndsWith("ViewModel", StringComparison.Ordinal)) + { + continue; + } + + if (TryGetFactoryForType(interfaceType, out var func)) + { + return func; + } + } + + return null; + } + + private static bool TryGetFactoryForType(Type type, out Func? func) + { + if (s_views.TryGetValue(type, out func)) + { + return true; + } + + if (type.IsGenericType && s_views.TryGetValue(type.GetGenericTypeDefinition(), out func)) + { + return true; + } + + func = null; + return false; + } + + private static string? TryGetMissingView(Type? type) + { + if (type is null) + { + return null; + } + + if (TryGetMissingViewForType(type, out var missingView)) + { + return missingView; + } + + for (var current = type.BaseType; current is not null; current = current.BaseType) + { + if (TryGetMissingViewForType(current, out missingView)) + { + return missingView; + } + } + + return null; + } + + private static string? TryGetMissingViewFromInterfaces(Type type) + { + var interfaces = type.GetInterfaces(); + for (var index = interfaces.Length - 1; index >= 0; index--) + { + var interfaceType = interfaces[index]; + if (!interfaceType.Name.EndsWith("ViewModel", StringComparison.Ordinal)) + { + continue; + } + + if (TryGetMissingViewForType(interfaceType, out var missingView)) + { + return missingView; + } + } + + return null; + } + + private static bool TryGetMissingViewForType(Type type, out string? missingView) + { + if (s_missingViews.TryGetValue(type, out missingView)) + { + return true; + } + + if (type.IsGenericType && s_missingViews.TryGetValue(type.GetGenericTypeDefinition(), out missingView)) + { + return true; + } + + missingView = null; + return false; + } } """; @@ -492,6 +989,10 @@ public partial class ViewLocator [typeof(TestApp.ViewModels.WidgetViewModel<>)] = () => new TestApp.Views.WidgetView(), }; + private static Dictionary s_missingViews = new() + { + }; + public Control? Build(object? data) { if (data is null) @@ -500,20 +1001,139 @@ public partial class ViewLocator } var type = data.GetType(); + var func = TryGetFactory(type) ?? TryGetFactoryFromInterfaces(type); - if (!s_views.TryGetValue(type, out var func) && - type.IsGenericType) + if (func is not null) { - s_views.TryGetValue(type.GetGenericTypeDefinition(), out func); + return func.Invoke(); } - if (func is not null) + var missingView = TryGetMissingView(type) ?? TryGetMissingViewFromInterfaces(type); + if (missingView is not null) { - return func.Invoke(); + return new TextBlock { Text = missingView }; } throw new Exception($"Unable to create view for type: {type}"); } + + private static Func? TryGetFactory(Type? type) + { + if (type is null) + { + return null; + } + + if (TryGetFactoryForType(type, out var func)) + { + return func; + } + + for (var current = type.BaseType; current is not null; current = current.BaseType) + { + if (TryGetFactoryForType(current, out func)) + { + return func; + } + } + + return null; + } + + private static Func? TryGetFactoryFromInterfaces(Type type) + { + var interfaces = type.GetInterfaces(); + for (var index = interfaces.Length - 1; index >= 0; index--) + { + var interfaceType = interfaces[index]; + if (!interfaceType.Name.EndsWith("ViewModel", StringComparison.Ordinal)) + { + continue; + } + + if (TryGetFactoryForType(interfaceType, out var func)) + { + return func; + } + } + + return null; + } + + private static bool TryGetFactoryForType(Type type, out Func? func) + { + if (s_views.TryGetValue(type, out func)) + { + return true; + } + + if (type.IsGenericType && s_views.TryGetValue(type.GetGenericTypeDefinition(), out func)) + { + return true; + } + + func = null; + return false; + } + + private static string? TryGetMissingView(Type? type) + { + if (type is null) + { + return null; + } + + if (TryGetMissingViewForType(type, out var missingView)) + { + return missingView; + } + + for (var current = type.BaseType; current is not null; current = current.BaseType) + { + if (TryGetMissingViewForType(current, out missingView)) + { + return missingView; + } + } + + return null; + } + + private static string? TryGetMissingViewFromInterfaces(Type type) + { + var interfaces = type.GetInterfaces(); + for (var index = interfaces.Length - 1; index >= 0; index--) + { + var interfaceType = interfaces[index]; + if (!interfaceType.Name.EndsWith("ViewModel", StringComparison.Ordinal)) + { + continue; + } + + if (TryGetMissingViewForType(interfaceType, out var missingView)) + { + return missingView; + } + } + + return null; + } + + private static bool TryGetMissingViewForType(Type type, out string? missingView) + { + if (s_missingViews.TryGetValue(type, out missingView)) + { + return true; + } + + if (type.IsGenericType && s_missingViews.TryGetValue(type.GetGenericTypeDefinition(), out missingView)) + { + return true; + } + + missingView = null; + return false; + } } """; diff --git a/StaticViewLocator/StaticViewLocatorGenerator.cs b/StaticViewLocator/StaticViewLocatorGenerator.cs index 92dc0c7..c95066c 100644 --- a/StaticViewLocator/StaticViewLocatorGenerator.cs +++ b/StaticViewLocator/StaticViewLocatorGenerator.cs @@ -24,6 +24,7 @@ public sealed class StaticViewLocatorGenerator : IIncrementalGenerator private const string NamespaceReplacementRulesProperty = "build_property.StaticViewLocatorNamespaceReplacementRules"; private const string TypeNameReplacementRulesProperty = "build_property.StaticViewLocatorTypeNameReplacementRules"; private const string StripGenericArityFromViewNameProperty = "build_property.StaticViewLocatorStripGenericArityFromViewName"; + private const string InterfacePrefixesToStripProperty = "build_property.StaticViewLocatorInterfacePrefixesToStrip"; private readonly struct GeneratorOptions { @@ -34,7 +35,8 @@ public GeneratorOptions( ImmutableArray additionalViewBaseTypes, ImmutableArray namespaceReplacementRules, ImmutableArray typeNameReplacementRules, - bool stripGenericArityFromViewName) + bool stripGenericArityFromViewName, + ImmutableArray interfacePrefixesToStrip) { NamespacePrefixes = namespacePrefixes; IncludeInternalViewModels = includeInternalViewModels; @@ -43,6 +45,7 @@ public GeneratorOptions( NamespaceReplacementRules = namespaceReplacementRules; TypeNameReplacementRules = typeNameReplacementRules; StripGenericArityFromViewName = stripGenericArityFromViewName; + InterfacePrefixesToStrip = interfacePrefixesToStrip; } public ImmutableArray NamespacePrefixes { get; } @@ -52,6 +55,7 @@ public GeneratorOptions( public ImmutableArray NamespaceReplacementRules { get; } public ImmutableArray TypeNameReplacementRules { get; } public bool StripGenericArityFromViewName { get; } + public ImmutableArray InterfacePrefixesToStrip { get; } } private readonly struct ReplacementRule @@ -87,17 +91,19 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var viewModelsProvider = context.SyntaxProvider .CreateSyntaxProvider( - static (node, _) => node is ClassDeclarationSyntax classDeclaration && - classDeclaration.Identifier.ValueText.EndsWith(ViewModelSuffix, StringComparison.Ordinal), + static (node, _) => + node is TypeDeclarationSyntax typeDeclaration && + (typeDeclaration is ClassDeclarationSyntax || typeDeclaration is InterfaceDeclarationSyntax) && + typeDeclaration.Identifier.ValueText.EndsWith(ViewModelSuffix, StringComparison.Ordinal), static (generatorContext, cancellationToken) => { - var classDeclaration = (ClassDeclarationSyntax)generatorContext.Node; - if (generatorContext.SemanticModel.GetDeclaredSymbol(classDeclaration, cancellationToken) is not { } symbol) + if (generatorContext.Node is not TypeDeclarationSyntax typeDeclaration || + generatorContext.SemanticModel.GetDeclaredSymbol(typeDeclaration, cancellationToken) is not INamedTypeSymbol symbol) { return null; } - return symbol.IsAbstract ? null : symbol; + return symbol; }) .Where(static symbol => symbol is not null) .Select(static (symbol, _) => symbol!) @@ -188,6 +194,7 @@ private static GeneratorOptions GetGeneratorOptions(AnalyzerConfigOptionsProvide var namespaceReplacementRules = GetReplacementRules(optionsProvider, NamespaceReplacementRulesProperty, new ReplacementRule("ViewModels", "Views")); var typeNameReplacementRules = GetReplacementRules(optionsProvider, TypeNameReplacementRulesProperty, new ReplacementRule(ViewModelSuffix, ViewSuffix)); var stripGenericArityFromViewName = GetStripGenericArityFromViewName(optionsProvider); + var interfacePrefixesToStrip = GetInterfacePrefixesToStrip(optionsProvider); return new GeneratorOptions( namespacePrefixes, @@ -196,7 +203,8 @@ private static GeneratorOptions GetGeneratorOptions(AnalyzerConfigOptionsProvide additionalViewBaseTypes, namespaceReplacementRules, typeNameReplacementRules, - stripGenericArityFromViewName); + stripGenericArityFromViewName, + interfacePrefixesToStrip); } private static ImmutableArray GetNamespacePrefixes(AnalyzerConfigOptionsProvider optionsProvider) @@ -343,9 +351,31 @@ private static bool GetStripGenericArityFromViewName(AnalyzerConfigOptionsProvid return !bool.TryParse(rawValue, out var stripGenericArityFromViewName) || stripGenericArityFromViewName; } + private static ImmutableArray GetInterfacePrefixesToStrip(AnalyzerConfigOptionsProvider optionsProvider) + { + if (!optionsProvider.GlobalOptions.TryGetValue(InterfacePrefixesToStripProperty, out var rawValue) || + string.IsNullOrWhiteSpace(rawValue)) + { + return ImmutableArray.Create("I"); + } + + var parts = rawValue.Split(new[] { ';', ',' }, StringSplitOptions.RemoveEmptyEntries); + var builder = ImmutableArray.CreateBuilder(parts.Length); + foreach (var part in parts) + { + var trimmed = part.Trim(); + if (trimmed.Length > 0) + { + builder.Add(trimmed); + } + } + + return builder.Count == 0 ? ImmutableArray.Create("I") : builder.ToImmutable(); + } + private static bool IsViewModelType(INamedTypeSymbol symbol) { - if (symbol.TypeKind != TypeKind.Class || symbol.IsAbstract) + if (symbol.TypeKind != TypeKind.Class && symbol.TypeKind != TypeKind.Interface) { return false; } @@ -526,6 +556,29 @@ private static string StripGenericArity(string typeName) return string.Join(".", parts); } + private static string StripInterfacePrefix(string typeName, ImmutableArray interfacePrefixesToStrip) + { + foreach (var prefix in interfacePrefixesToStrip) + { + if (string.IsNullOrEmpty(prefix)) + { + continue; + } + + var parts = typeName.Split('.'); + var last = parts[parts.Length - 1]; + if (!last.StartsWith(prefix, StringComparison.Ordinal) || last.Length <= prefix.Length) + { + continue; + } + + parts[parts.Length - 1] = last.Substring(prefix.Length); + return string.Join(".", parts); + } + + return typeName; + } + private static string? ProcessClass( Compilation compilation, INamedTypeSymbol locatorSymbol, @@ -622,6 +675,12 @@ public partial class {{classNameLocator}} var sourceTypeNameView = ApplyReplacementRules(sourceTypeNameViewModel, options.TypeNameReplacementRules); var metadataTypeNameView = ApplyReplacementRules(metadataTypeNameViewModel, options.TypeNameReplacementRules); + if (viewModelSymbol.TypeKind == TypeKind.Interface) + { + sourceTypeNameView = StripInterfacePrefix(sourceTypeNameView, options.InterfacePrefixesToStrip); + metadataTypeNameView = StripInterfacePrefix(metadataTypeNameView, options.InterfacePrefixesToStrip); + } + if (options.StripGenericArityFromViewName) { sourceTypeNameView = StripGenericArity(sourceTypeNameView); @@ -648,13 +707,66 @@ public partial class {{classNameLocator}} if (viewSymbol is null || !isSupportedView) { - source.AppendLine( - $"\t\t[typeof({classNameViewModel})] = () => new TextBlock() {{ Text = \"Not Found: {classNameView}\" }},"); + continue; + } + + source.AppendLine($"\t\t[typeof({classNameViewModel})] = () => new {classNameView}(),"); + } + + source.AppendLine("\t};"); + source.AppendLine(); + source.AppendLine("\tprivate static Dictionary s_missingViews = new()"); + source.AppendLine("\t{"); + + foreach (var viewModelSymbol in relevantViewModels) + { + var namespaceNameViewModel = viewModelSymbol.ContainingNamespace.ToDisplayString(); + var sourceTypeNameViewModel = GetSourceTypeName(viewModelSymbol); + var metadataTypeNameViewModel = GetMetadataTypeName(viewModelSymbol); + + var sourceNamespaceView = ApplyReplacementRules(namespaceNameViewModel, options.NamespaceReplacementRules); + var metadataNamespaceView = ApplyReplacementRules(namespaceNameViewModel, options.NamespaceReplacementRules); + + var sourceTypeNameView = ApplyReplacementRules(sourceTypeNameViewModel, options.TypeNameReplacementRules); + var metadataTypeNameView = ApplyReplacementRules(metadataTypeNameViewModel, options.TypeNameReplacementRules); + + if (viewModelSymbol.TypeKind == TypeKind.Interface) + { + sourceTypeNameView = StripInterfacePrefix(sourceTypeNameView, options.InterfacePrefixesToStrip); + metadataTypeNameView = StripInterfacePrefix(metadataTypeNameView, options.InterfacePrefixesToStrip); + } + + if (options.StripGenericArityFromViewName) + { + sourceTypeNameView = StripGenericArity(sourceTypeNameView); + metadataTypeNameView = StripGenericArity(metadataTypeNameView); + } + + var classNameViewModel = GetSourceTypeReference(viewModelSymbol); + var classNameView = $"{sourceNamespaceView}.{sourceTypeNameView}"; + var metadataNameView = $"{metadataNamespaceView}.{metadataTypeNameView}"; + + var viewSymbol = compilation.GetTypeByMetadataName(metadataNameView); + var isSupportedView = false; + if (viewSymbol is not null && viewBaseTypes.Count > 0) + { + for (var current = viewSymbol; current is not null; current = current.BaseType) + { + if (viewBaseTypes.Contains(current)) + { + isSupportedView = true; + break; + } + } } - else + + if (viewSymbol is not null && isSupportedView) { - source.AppendLine($"\t\t[typeof({classNameViewModel})] = () => new {classNameView}(),"); + continue; } + + source.AppendLine( + $"\t\t[typeof({classNameViewModel})] = \"Not Found: {classNameView}\","); } source.AppendLine("\t};"); @@ -672,21 +784,140 @@ public partial class {{classNameLocator}} } var type = data.GetType(); + var func = TryGetFactory(type) ?? TryGetFactoryFromInterfaces(type); - if (!s_views.TryGetValue(type, out var func) && - type.IsGenericType) + if (func is not null) { - s_views.TryGetValue(type.GetGenericTypeDefinition(), out func); + return func.Invoke(); } - if (func is not null) + var missingView = TryGetMissingView(type) ?? TryGetMissingViewFromInterfaces(type); + if (missingView is not null) { - return func.Invoke(); + return new TextBlock { Text = missingView }; } throw new Exception($"Unable to create view for type: {type}"); } + private static Func? TryGetFactory(Type? type) + { + if (type is null) + { + return null; + } + + if (TryGetFactoryForType(type, out var func)) + { + return func; + } + + for (var current = type.BaseType; current is not null; current = current.BaseType) + { + if (TryGetFactoryForType(current, out func)) + { + return func; + } + } + + return null; + } + + private static Func? TryGetFactoryFromInterfaces(Type type) + { + var interfaces = type.GetInterfaces(); + for (var index = interfaces.Length - 1; index >= 0; index--) + { + var interfaceType = interfaces[index]; + if (!interfaceType.Name.EndsWith("ViewModel", StringComparison.Ordinal)) + { + continue; + } + + if (TryGetFactoryForType(interfaceType, out var func)) + { + return func; + } + } + + return null; + } + + private static bool TryGetFactoryForType(Type type, out Func? func) + { + if (s_views.TryGetValue(type, out func)) + { + return true; + } + + if (type.IsGenericType && s_views.TryGetValue(type.GetGenericTypeDefinition(), out func)) + { + return true; + } + + func = null; + return false; + } + + private static string? TryGetMissingView(Type? type) + { + if (type is null) + { + return null; + } + + if (TryGetMissingViewForType(type, out var missingView)) + { + return missingView; + } + + for (var current = type.BaseType; current is not null; current = current.BaseType) + { + if (TryGetMissingViewForType(current, out missingView)) + { + return missingView; + } + } + + return null; + } + + private static string? TryGetMissingViewFromInterfaces(Type type) + { + var interfaces = type.GetInterfaces(); + for (var index = interfaces.Length - 1; index >= 0; index--) + { + var interfaceType = interfaces[index]; + if (!interfaceType.Name.EndsWith("ViewModel", StringComparison.Ordinal)) + { + continue; + } + + if (TryGetMissingViewForType(interfaceType, out var missingView)) + { + return missingView; + } + } + + return null; + } + + private static bool TryGetMissingViewForType(Type type, out string? missingView) + { + if (s_missingViews.TryGetValue(type, out missingView)) + { + return true; + } + + if (type.IsGenericType && s_missingViews.TryGetValue(type.GetGenericTypeDefinition(), out missingView)) + { + return true; + } + + missingView = null; + return false; + } + """); } diff --git a/StaticViewLocator/buildTransitive/StaticViewLocator.props b/StaticViewLocator/buildTransitive/StaticViewLocator.props index ba71ae0..7de6506 100644 --- a/StaticViewLocator/buildTransitive/StaticViewLocator.props +++ b/StaticViewLocator/buildTransitive/StaticViewLocator.props @@ -7,5 +7,6 @@ + From caa3fd164620a7aa0082e85037f4b16a5febfcf5 Mon Sep 17 00:00:00 2001 From: Dan Walmsley <4672627+danwalmsley@users.noreply.github.com> Date: Wed, 11 Mar 2026 12:39:26 +0000 Subject: [PATCH 3/5] Document configurable locator rules --- README.md | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 56 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 1448ea4..eaaa06d 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ Add NuGet package reference to project. ``` -Annotate view locator class with `[StaticViewLocator]` attribute, make class `partial` and imlement `Build` using `s_views` dictionary to retrieve views for `data` objects. +Annotate a view locator class with `[StaticViewLocator]`, make it `partial`, and let the generator provide the lookup tables and fallback helpers. ```csharp [StaticViewLocator] @@ -31,12 +31,19 @@ public partial class ViewLocator : IDataTemplate } var type = data.GetType(); + var func = TryGetFactory(type) ?? TryGetFactoryFromInterfaces(type); - if (s_views.TryGetValue(type, out var func)) + if (func is not null) { return func.Invoke(); } + var missingView = TryGetMissingView(type) ?? TryGetMissingViewFromInterfaces(type); + if (missingView is not null) + { + return new TextBlock { Text = missingView }; + } + throw new Exception($"Unable to create view for type: {type}"); } @@ -47,16 +54,40 @@ public partial class ViewLocator : IDataTemplate } ``` -Source generator will generate the `s_views` dictionary similar to below code using convention based on `ViewModel` suffix for view models subsituted to `View` suffix. +The generator emits: +- `s_views`: resolved mappings from `Type` to `Func` +- `s_missingViews`: unresolved mappings used for `"Not Found: ..."` fallback text +- helper methods for exact type lookup, generic type-definition lookup, base-class fallback, and interface fallback + +By default, the generated lookup order is: +1. exact runtime type +2. generic type definition for generic runtime types +3. base type chain +4. implemented interfaces in reverse order + +Source generator will generate mappings using convention-based transforms. By default: +- namespace `ViewModels` becomes `Views` +- type suffix `ViewModel` becomes `View` +- generic arity markers are removed from the target view name +- interface prefix `I` is stripped before resolving the target view name + +This allows patterns like: +- `MyApp.ViewModels.SettingsViewModel -> MyApp.Views.SettingsView` +- `MyApp.ViewModels.WidgetViewModel -> MyApp.Views.WidgetView` +- `MyApp.ViewModels.IDetailsViewModel -> MyApp.Views.DetailsView` ```csharp public partial class ViewLocator { private static Dictionary> s_views = new() { - [typeof(StaticViewLocatorDemo.ViewModels.MainWindowViewModel)] = () => new TextBlock() { Text = "Not Found: StaticViewLocatorDemo.Views.MainWindowView" }, [typeof(StaticViewLocatorDemo.ViewModels.TestViewModel)] = () => new StaticViewLocatorDemo.Views.TestView(), }; + + private static Dictionary s_missingViews = new() + { + [typeof(StaticViewLocatorDemo.ViewModels.MainWindowViewModel)] = "Not Found: StaticViewLocatorDemo.Views.MainWindowView", + }; } ``` @@ -73,6 +104,7 @@ You can scope which view model namespaces are considered and opt into additional ViewModels=Views ViewModel=View;Vm=Page true + I ``` @@ -84,9 +116,29 @@ Defaults and behavior: - `StaticViewLocatorNamespaceReplacementRules` uses `;` or `,` separators with `from=to` pairs and is applied sequentially to the view-model namespace when deriving the target view namespace. The default includes `ViewModels=Views`. - `StaticViewLocatorTypeNameReplacementRules` uses `;` or `,` separators with `from=to` pairs and is applied sequentially to the view-model type name when deriving the target view name. The default includes `ViewModel=View`. - `StaticViewLocatorStripGenericArityFromViewName` defaults to `true`. When enabled, generic arity markers like `` `1 `` are removed from the derived target view name, so `WidgetViewModel` can map to `WidgetView`. +- `StaticViewLocatorInterfacePrefixesToStrip` uses `;` or `,` separators and is applied to interface view-model names before looking up the target view. The default includes `I`. These properties are exported as `CompilerVisibleProperty` by the package, so analyzers can read them without extra project configuration. +## Supported resolution features + +- Exact type mapping +- Open generic mapping, for example `WidgetViewModel -> WidgetView` +- Base-class fallback +- Interface fallback +- Configurable namespace replacement rules +- Configurable type-name replacement rules +- Configurable interface prefix stripping +- Configurable additional allowed view base types +- Optional referenced-assembly scanning +- Optional internal view-model inclusion + +## Notes + +- Candidate discovery still starts from types whose names end with `ViewModel`. +- Missing views do not block fallback resolution. The generator keeps unresolved targets in `s_missingViews`, so a derived type can still fall back to a base-class or interface mapping before returning a `"Not Found"` placeholder. +- If you provide custom replacement rules, they take precedence over the built-in defaults. + Default view base types: - `Avalonia.Controls.UserControl` - `Avalonia.Controls.Window` From 1032f03e3d2b9e519df7533392e35aa5d27a9f95 Mon Sep 17 00:00:00 2001 From: Dan Walmsley <4672627+danwalmsley@users.noreply.github.com> Date: Wed, 11 Mar 2026 13:01:15 +0000 Subject: [PATCH 4/5] Fix generated helper emission and generic view factories --- .../StaticViewLocator.Tests.csproj | 2 +- .../StaticViewLocatorGeneratorRuntimeTests.cs | 69 ++++++++ ...StaticViewLocatorGeneratorSnapshotTests.cs | 163 ++++++++++++++++++ .../StaticViewLocatorGeneratorVerifier.cs | 37 ++++ StaticViewLocator/StaticViewLocator.csproj | 2 +- .../StaticViewLocatorGenerator.cs | 12 +- 6 files changed, 282 insertions(+), 3 deletions(-) diff --git a/StaticViewLocator.Tests/StaticViewLocator.Tests.csproj b/StaticViewLocator.Tests/StaticViewLocator.Tests.csproj index 16f82ea..cf3b6f2 100644 --- a/StaticViewLocator.Tests/StaticViewLocator.Tests.csproj +++ b/StaticViewLocator.Tests/StaticViewLocator.Tests.csproj @@ -1,7 +1,7 @@ - net9.0 + net10.0 false diff --git a/StaticViewLocator.Tests/StaticViewLocatorGeneratorRuntimeTests.cs b/StaticViewLocator.Tests/StaticViewLocatorGeneratorRuntimeTests.cs index a95909a..4759829 100644 --- a/StaticViewLocator.Tests/StaticViewLocatorGeneratorRuntimeTests.cs +++ b/StaticViewLocator.Tests/StaticViewLocatorGeneratorRuntimeTests.cs @@ -316,6 +316,75 @@ public class AlternateView : UserControl Assert.Equal("TestApp.Views.BaseView", control!.GetType().FullName); } + [AvaloniaFact] + public async Task CustomBuildCanCallGeneratedHelpers() + { + const string source = @" +using System; +using Avalonia.Controls; +using StaticViewLocator; + +namespace TestApp +{ + [StaticViewLocator] + public partial class ViewLocator + { + public Control? Build(object? data) + { + if (data is null) + { + return null; + } + + var type = data.GetType(); + var factory = TryGetFactory(type) ?? TryGetFactoryFromInterfaces(type); + return factory?.Invoke(); + } + } +} + +namespace TestApp.ViewModels +{ + public class SampleViewModel + { + } +} + +namespace TestApp.Views +{ + public class SampleView : UserControl + { + } +} +"; + + var compilation = await CreateCompilationAsync(source); + var sourceGenerator = new StaticViewLocatorGenerator().AsSourceGenerator(); + var driver = CSharpGeneratorDriver.Create(new[] { sourceGenerator }, parseOptions: (CSharpParseOptions)compilation.SyntaxTrees.First().Options); + driver.RunGeneratorsAndUpdateCompilation(compilation, out var updatedCompilation, out var diagnostics); + + Assert.Empty(diagnostics.Where(d => d.Severity == DiagnosticSeverity.Error)); + + using var peStream = new MemoryStream(); + var emitResult = updatedCompilation.Emit(peStream); + Assert.True(emitResult.Success, string.Join(Environment.NewLine, emitResult.Diagnostics)); + + peStream.Seek(0, SeekOrigin.Begin); + var assembly = Assembly.Load(peStream.ToArray()); + + var locatorType = assembly.GetType("TestApp.ViewLocator") ?? throw new InvalidOperationException("Generated locator type not found."); + var buildMethod = locatorType.GetMethod("Build", BindingFlags.Public | BindingFlags.Instance) ?? throw new InvalidOperationException("Build method not found."); + var sampleViewModel = CreateInstance(assembly, "TestApp.ViewModels.SampleViewModel"); + var locator = Activator.CreateInstance(locatorType) ?? throw new InvalidOperationException("Unable to instantiate generated locator."); + + _ = HeadlessUnitTestSession.GetOrStartForAssembly(typeof(StaticViewLocatorGeneratorRuntimeTests).Assembly); + + var control = (Control?)buildMethod.Invoke(locator, new[] { sampleViewModel }); + + Assert.NotNull(control); + Assert.Equal("TestApp.Views.SampleView", control!.GetType().FullName); + } + [AvaloniaFact] public async Task ResolvesInterfaceMappingsByStrippingConfiguredPrefix() { diff --git a/StaticViewLocator.Tests/StaticViewLocatorGeneratorSnapshotTests.cs b/StaticViewLocator.Tests/StaticViewLocatorGeneratorSnapshotTests.cs index d2283a4..a92ed4b 100644 --- a/StaticViewLocator.Tests/StaticViewLocatorGeneratorSnapshotTests.cs +++ b/StaticViewLocator.Tests/StaticViewLocatorGeneratorSnapshotTests.cs @@ -701,6 +701,124 @@ public partial class ViewLocator private static Dictionary s_missingViews = new() { }; + + private static Func? TryGetFactory(Type? type) + { + if (type is null) + { + return null; + } + + if (TryGetFactoryForType(type, out var func)) + { + return func; + } + + for (var current = type.BaseType; current is not null; current = current.BaseType) + { + if (TryGetFactoryForType(current, out func)) + { + return func; + } + } + + return null; + } + + private static Func? TryGetFactoryFromInterfaces(Type type) + { + var interfaces = type.GetInterfaces(); + for (var index = interfaces.Length - 1; index >= 0; index--) + { + var interfaceType = interfaces[index]; + if (!interfaceType.Name.EndsWith("ViewModel", StringComparison.Ordinal)) + { + continue; + } + + if (TryGetFactoryForType(interfaceType, out var func)) + { + return func; + } + } + + return null; + } + + private static bool TryGetFactoryForType(Type type, out Func? func) + { + if (s_views.TryGetValue(type, out func)) + { + return true; + } + + if (type.IsGenericType && s_views.TryGetValue(type.GetGenericTypeDefinition(), out func)) + { + return true; + } + + func = null; + return false; + } + + private static string? TryGetMissingView(Type? type) + { + if (type is null) + { + return null; + } + + if (TryGetMissingViewForType(type, out var missingView)) + { + return missingView; + } + + for (var current = type.BaseType; current is not null; current = current.BaseType) + { + if (TryGetMissingViewForType(current, out missingView)) + { + return missingView; + } + } + + return null; + } + + private static string? TryGetMissingViewFromInterfaces(Type type) + { + var interfaces = type.GetInterfaces(); + for (var index = interfaces.Length - 1; index >= 0; index--) + { + var interfaceType = interfaces[index]; + if (!interfaceType.Name.EndsWith("ViewModel", StringComparison.Ordinal)) + { + continue; + } + + if (TryGetMissingViewForType(interfaceType, out var missingView)) + { + return missingView; + } + } + + return null; + } + + private static bool TryGetMissingViewForType(Type type, out string? missingView) + { + if (s_missingViews.TryGetValue(type, out missingView)) + { + return true; + } + + if (type.IsGenericType && s_missingViews.TryGetValue(type.GetGenericTypeDefinition(), out missingView)) + { + return true; + } + + missingView = null; + return false; + } } """; @@ -1147,4 +1265,49 @@ await StaticViewLocatorGeneratorVerifier.VerifyGeneratedSourcesAsync( ("StaticViewLocatorAttribute.cs", expectedAttribute), ("ViewLocator_StaticViewLocator.cs", expectedLocator)); } + + [Fact] + public async Task DoesNotGenerateInvalidFactoryForOpenGenericView() + { + const string input = @" +using Avalonia.Controls; +using StaticViewLocator; + +namespace TestApp.ViewModels +{ + public class WidgetViewModel + { + } +} + +namespace TestApp.Views +{ + public class WidgetView : UserControl + { + } +} + +namespace TestApp +{ + [StaticViewLocator] + public partial class ViewLocator + { + } +} +"; + + var generated = await StaticViewLocatorGeneratorVerifier.GetGeneratedSourcesAsync( + input, + new Dictionary + { + ["build_property.StaticViewLocatorNamespaceReplacementRules"] = "ViewModels=Views", + ["build_property.StaticViewLocatorTypeNameReplacementRules"] = "ViewModel=View", + ["build_property.StaticViewLocatorStripGenericArityFromViewName"] = "false", + }); + + var locatorSource = generated["ViewLocator_StaticViewLocator.cs"]; + + Assert.DoesNotContain("new TestApp.Views.WidgetView()", locatorSource, StringComparison.Ordinal); + Assert.DoesNotContain("[typeof(TestApp.ViewModels.WidgetViewModel<>)] = () =>", locatorSource, StringComparison.Ordinal); + } } diff --git a/StaticViewLocator.Tests/TestHelpers/StaticViewLocatorGeneratorVerifier.cs b/StaticViewLocator.Tests/TestHelpers/StaticViewLocatorGeneratorVerifier.cs index 3fc9d5e..77308af 100644 --- a/StaticViewLocator.Tests/TestHelpers/StaticViewLocatorGeneratorVerifier.cs +++ b/StaticViewLocator.Tests/TestHelpers/StaticViewLocatorGeneratorVerifier.cs @@ -15,6 +15,43 @@ namespace StaticViewLocator.Tests.TestHelpers; internal static class StaticViewLocatorGeneratorVerifier { + public static Task> GetGeneratedSourcesAsync( + string source, + IReadOnlyDictionary? globalOptions = null) + { + var parseOptions = new CSharpParseOptions(LanguageVersion.Preview); + var syntaxTree = CSharpSyntaxTree.ParseText(source, parseOptions); + + var compilation = CSharpCompilation.Create( + assemblyName: "StaticViewLocatorGenerator.Tests", + syntaxTrees: new[] { syntaxTree }, + references: GetMetadataReferences(), + options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + var generator = new StaticViewLocatorGenerator().AsSourceGenerator(); + GeneratorDriver driver = CSharpGeneratorDriver.Create( + generators: new[] { generator }, + additionalTexts: null, + parseOptions: parseOptions, + optionsProvider: new TestAnalyzerConfigOptionsProvider(globalOptions)); + + driver = driver.RunGeneratorsAndUpdateCompilation(compilation, out _, out var diagnostics); + + var failures = diagnostics.Where(static d => d.Severity == DiagnosticSeverity.Error).ToArray(); + if (failures.Length > 0) + { + var message = string.Join(Environment.NewLine, failures.Select(static d => d.ToString())); + throw new Xunit.Sdk.XunitException($"Generator reported diagnostics:{Environment.NewLine}{message}"); + } + + var runResult = driver.GetRunResult(); + var generated = runResult.GeneratedTrees + .Select(static tree => (HintName: Path.GetFileName(tree.FilePath) ?? string.Empty, Source: tree.GetText().ToString())) + .ToDictionary(static x => x.HintName, static x => x.Source, StringComparer.Ordinal); + + return Task.FromResult>(generated); + } + public static Task VerifyGeneratedSourcesAsync(string source, params (string hintName, string source)[] generatedSources) { return VerifyGeneratedSourcesAsync(source, globalOptions: null, generatedSources); diff --git a/StaticViewLocator/StaticViewLocator.csproj b/StaticViewLocator/StaticViewLocator.csproj index 47422b5..8afc242 100644 --- a/StaticViewLocator/StaticViewLocator.csproj +++ b/StaticViewLocator/StaticViewLocator.csproj @@ -1,7 +1,7 @@  - netstandard2.0 + net10.0 true false true diff --git a/StaticViewLocator/StaticViewLocatorGenerator.cs b/StaticViewLocator/StaticViewLocatorGenerator.cs index c95066c..c0d0ac0 100644 --- a/StaticViewLocator/StaticViewLocatorGenerator.cs +++ b/StaticViewLocator/StaticViewLocatorGenerator.cs @@ -710,6 +710,11 @@ public partial class {{classNameLocator}} continue; } + if (viewSymbol.IsGenericType) + { + continue; + } + source.AppendLine($"\t\t[typeof({classNameViewModel})] = () => new {classNameView}(),"); } @@ -800,6 +805,12 @@ public partial class {{classNameLocator}} throw new Exception($"Unable to create view for type: {type}"); } +"""); + } + + source.Append( + """ + private static Func? TryGetFactory(Type? type) { if (type is null) @@ -919,7 +930,6 @@ private static bool TryGetMissingViewForType(Type type, out string? missingView) } """); - } source.AppendLine("}"); From 41936fc2122ece43ca3fe1edaf7b11ffca21eef5 Mon Sep 17 00:00:00 2001 From: Dan Walmsley <4672627+danwalmsley@users.noreply.github.com> Date: Wed, 11 Mar 2026 13:12:15 +0000 Subject: [PATCH 5/5] netstandard2.0 --- StaticViewLocator/StaticViewLocator.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/StaticViewLocator/StaticViewLocator.csproj b/StaticViewLocator/StaticViewLocator.csproj index 8afc242..47422b5 100644 --- a/StaticViewLocator/StaticViewLocator.csproj +++ b/StaticViewLocator/StaticViewLocator.csproj @@ -1,7 +1,7 @@  - net10.0 + netstandard2.0 true false true