Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,27 @@ public static IFeatureManagementBuilder WithVariantService<TService>(this IFeatu

if (builder.Services.Any(descriptor => descriptor.ServiceType == typeof(IFeatureManager) && descriptor.Lifetime == ServiceLifetime.Scoped))
{
builder.Services.AddScoped<IVariantServiceProvider<TService>>(sp => new VariantServiceProvider<TService>(
featureName,
sp.GetRequiredService<IVariantFeatureManager>(),
sp.GetRequiredService<IEnumerable<TService>>()));
builder.Services.AddScoped<IVariantServiceProvider<TService>>(sp =>
{
IEnumerable<ServiceDescriptor> serviceDescriptors = builder.Services.Where(d => d.ServiceType == typeof(TService));
return new VariantServiceProvider<TService>(
featureName,
sp.GetRequiredService<IVariantFeatureManager>(),
serviceDescriptors,
sp);
});
}
else
{
builder.Services.AddSingleton<IVariantServiceProvider<TService>>(sp => new VariantServiceProvider<TService>(
featureName,
sp.GetRequiredService<IVariantFeatureManager>(),
sp.GetRequiredService<IEnumerable<TService>>()));
builder.Services.AddSingleton<IVariantServiceProvider<TService>>(sp =>
{
IEnumerable<ServiceDescriptor> serviceDescriptors = builder.Services.Where(d => d.ServiceType == typeof(TService));
return new VariantServiceProvider<TService>(
featureName,
sp.GetRequiredService<IVariantFeatureManager>(),
serviceDescriptors,
sp);
});
}

return builder;
Expand Down
121 changes: 99 additions & 22 deletions src/Microsoft.FeatureManagement/VariantServiceProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;

namespace Microsoft.FeatureManagement
{
Expand All @@ -16,26 +16,55 @@ namespace Microsoft.FeatureManagement
/// </summary>
internal class VariantServiceProvider<TService> : IVariantServiceProvider<TService> where TService : class
{
private readonly IEnumerable<TService> _services;
private readonly IVariantFeatureManager _featureManager;
private readonly string _featureName;
private readonly ConcurrentDictionary<string, TService> _variantServiceCache;
private readonly IServiceProvider _serviceProvider;
private readonly Dictionary<string, ServiceDescriptor> _variantNameToDescriptor; // ImplementationType/Instance descriptors mapped by variant name.
private readonly List<ServiceDescriptor> _factoryDescriptors; // Descriptors that require factory invocation to discover variant name.

/// <summary>
/// Creates a variant service provider.
/// </summary>
/// <param name="featureName">The feature flag that should be used to determine which variant of the service should be used.</param>
/// <param name="featureManager">The feature manager to get the assigned variant of the feature flag.</param>
/// <param name="services">Implementation variants of TService.</param>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="featureName"/> is null.</exception>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="featureManager"/> is null.</exception>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="services"/> is null.</exception>
public VariantServiceProvider(string featureName, IVariantFeatureManager featureManager, IEnumerable<TService> services)
/// <param name="serviceDescriptors">Service descriptors for implementation variants of TService.</param>
/// <param name="serviceProvider">The service provider / scope used to activate implementations lazily.</param>
public VariantServiceProvider(string featureName, IVariantFeatureManager featureManager, IEnumerable<ServiceDescriptor> serviceDescriptors, IServiceProvider serviceProvider)
{
_featureName = featureName ?? throw new ArgumentNullException(nameof(featureName));
_featureManager = featureManager ?? throw new ArgumentNullException(nameof(featureManager));
_services = services ?? throw new ArgumentNullException(nameof(services));
_variantServiceCache = new ConcurrentDictionary<string, TService>();
if (serviceDescriptors == null) throw new ArgumentNullException(nameof(serviceDescriptors));
_serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider));
_variantServiceCache = new ConcurrentDictionary<string, TService>(StringComparer.OrdinalIgnoreCase);
_variantNameToDescriptor = new Dictionary<string, ServiceDescriptor>(StringComparer.OrdinalIgnoreCase);
_factoryDescriptors = new List<ServiceDescriptor>();

// Precompute mapping for descriptors whose variant name can be determined without instantiation.
foreach (ServiceDescriptor descriptor in serviceDescriptors)
{
if (descriptor.ImplementationType != null)
{
string name = GetVariantName(descriptor.ImplementationType);
if (!_variantNameToDescriptor.ContainsKey(name))
{
_variantNameToDescriptor.Add(name, descriptor);
}
}
else if (descriptor.ImplementationInstance != null)
{
string name = GetVariantName(descriptor.ImplementationInstance.GetType());
if (!_variantNameToDescriptor.ContainsKey(name))
{
_variantNameToDescriptor.Add(name, descriptor);
}
}
else if (descriptor.ImplementationFactory != null)
{
// Factory descriptors require instantiation to discover variant name; hold for later.
_factoryDescriptors.Add(descriptor);
}
}
}

/// <summary>
Expand All @@ -47,25 +76,73 @@ public async ValueTask<TService> GetServiceAsync(CancellationToken cancellationT
{
Debug.Assert(_featureName != null);

Variant variant = await _featureManager.GetVariantAsync(_featureName, cancellationToken);
Variant variant = await _featureManager.GetVariantAsync(_featureName, cancellationToken).ConfigureAwait(false);

if (variant == null)
{
return null;
}

return _variantServiceCache.GetOrAdd(variant.Name, ResolveVariant);
}

private TService ResolveVariant(string variantName)
{
// Try fast path using precomputed mapping.
if (_variantNameToDescriptor.TryGetValue(variantName, out ServiceDescriptor descriptor))
{
return ActivateDescriptor(descriptor);
}

// Need to probe factory descriptors lazily.
foreach (ServiceDescriptor factoryDescriptor in _factoryDescriptors)
{
TService instance = ActivateDescriptor(factoryDescriptor);

TService implementation = null;
if (instance == null)
{
continue;
}

string discoveredName = GetVariantName(instance.GetType());

// Cache the mapping for future lookups.
if (!_variantNameToDescriptor.ContainsKey(discoveredName))
{
_variantNameToDescriptor.Add(discoveredName, factoryDescriptor);
}

if (string.Equals(discoveredName, variantName, StringComparison.OrdinalIgnoreCase))
{
return instance;
}
}

return null;
}

private TService ActivateDescriptor(ServiceDescriptor descriptor)
{
if (descriptor.ImplementationInstance != null)
{
return (TService)descriptor.ImplementationInstance;
}

if (descriptor.ImplementationType != null)
{
// Use ActivatorUtilities to honor DI for dependencies of the implementation type.
return (TService)ActivatorUtilities.GetServiceOrCreateInstance(_serviceProvider, descriptor.ImplementationType);
}

if (variant != null)
if (descriptor.ImplementationFactory != null)
{
implementation = _variantServiceCache.GetOrAdd(
variant.Name,
(_) => _services.FirstOrDefault(
service => IsMatchingVariantName(
service.GetType(),
variant.Name))
);
return (TService)descriptor.ImplementationFactory(_serviceProvider);
}

return implementation;
return null;
}

private bool IsMatchingVariantName(Type implementationType, string variantName)
private string GetVariantName(Type implementationType)
{
string implementationName = ((VariantServiceAliasAttribute)Attribute.GetCustomAttribute(implementationType, typeof(VariantServiceAliasAttribute)))?.Alias;

Expand All @@ -74,7 +151,7 @@ private bool IsMatchingVariantName(Type implementationType, string variantName)
implementationName = implementationType.Name;
}

return string.Equals(implementationName, variantName, StringComparison.OrdinalIgnoreCase);
return implementationName;
}
}
}
78 changes: 47 additions & 31 deletions tests/Tests.FeatureManagement/FeatureManagementTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1803,59 +1803,75 @@ public async Task VariantBasedInjection()
services = new ServiceCollection();

Assert.Throws<InvalidOperationException>(() =>
{
services.AddFeatureManagement()
.WithVariantService<IAlgorithm>("DummyFeature1")
.WithVariantService<IAlgorithm>("DummyFeature2");
}
{
services.AddFeatureManagement()
.WithVariantService<IAlgorithm>("DummyFeature1")
.WithVariantService<IAlgorithm>("DummyFeature2");
}
);
}

[Fact]
public async Task VariantFeatureFlagWithContextualFeatureFilter()
public async Task VariantServiceLazyInstantiation()
{
// Reset counters
AlgorithmBeta.Instances = 0;
AlgorithmOmega.Instances = 0;

IConfiguration configuration = new ConfigurationBuilder()
.AddJsonFile("appsettings.json")
.Build();

IServiceCollection services = new ServiceCollection();

services.AddSingleton<IAlgorithm, AlgorithmBeta>();
services.AddSingleton<IAlgorithm, AlgorithmSigma>();
services.AddSingleton<IAlgorithm>(sp => new AlgorithmOmega("OMEGA"));

services.AddSingleton(configuration)
.AddFeatureManagement()
.AddFeatureFilter<ContextualTestFilter>();

ServiceProvider serviceProvider = services.BuildServiceProvider();

ContextualTestFilter contextualTestFeatureFilter = (ContextualTestFilter)serviceProvider.GetRequiredService<IEnumerable<IFeatureFilterMetadata>>().First(f => f is ContextualTestFilter);
.AddFeatureFilter<TargetingFilter>()
.WithVariantService<IAlgorithm>(Features.VariantImplementationFeature);

contextualTestFeatureFilter.ContextualCallback = (ctx, accountContext) =>
{
var allowedAccounts = new List<string>();
var targetingContextAccessor = new OnDemandTargetingContextAccessor();
services.AddSingleton<ITargetingContextAccessor>(targetingContextAccessor);

ctx.Parameters.Bind("AllowedAccounts", allowedAccounts);
ServiceProvider serviceProvider = services.BuildServiceProvider();

return allowedAccounts.Contains(accountContext.AccountId);
};
// At this point none of the implementations should have been instantiated yet because provider hasn't requested them.
Assert.Equal(0, AlgorithmBeta.Instances);
Assert.Equal(0, AlgorithmOmega.Instances);

IVariantServiceProvider<IAlgorithm> variantProvider = serviceProvider.GetRequiredService<IVariantServiceProvider<IAlgorithm>>();
IVariantFeatureManager featureManager = serviceProvider.GetRequiredService<IVariantFeatureManager>();

var context = new AppContext();

context.AccountId = "NotEnabledAccount";

Assert.False(await featureManager.IsEnabledAsync(Features.ContextualFeatureWithVariant, context));

Variant variant = await featureManager.GetVariantAsync(Features.ContextualFeatureWithVariant, context);

Assert.Equal("Small", variant.Name);

context.AccountId = "abc";
targetingContextAccessor.Current = new TargetingContext { UserId = "Guest" };
IAlgorithm algorithm = await variantProvider.GetServiceAsync(CancellationToken.None);
Assert.Null(algorithm);
Assert.Equal(0, AlgorithmBeta.Instances);
Assert.Equal(0, AlgorithmOmega.Instances);

Assert.True(await featureManager.IsEnabledAsync(Features.ContextualFeatureWithVariant, context));
targetingContextAccessor.Current = new TargetingContext { UserId = "UserBeta" };
algorithm = await variantProvider.GetServiceAsync(CancellationToken.None);
Assert.NotNull(algorithm);
Assert.Equal("Beta", algorithm.Style);
Assert.Equal(1, AlgorithmBeta.Instances);
Assert.Equal(0, AlgorithmOmega.Instances);

variant = await featureManager.GetVariantAsync(Features.ContextualFeatureWithVariant, context);
targetingContextAccessor.Current = new TargetingContext { UserId = "UserOmega" };
algorithm = await variantProvider.GetServiceAsync(CancellationToken.None);
Assert.NotNull(algorithm);
Assert.Equal("OMEGA", algorithm.Style);
Assert.Equal(1, AlgorithmBeta.Instances);
Assert.Equal(1, AlgorithmOmega.Instances);

Assert.Equal("Big", variant.Name);
// Re-resolve Beta variant should not create additional instance because singleton already constructed previously
targetingContextAccessor.Current = new TargetingContext { UserId = "UserBeta" };
algorithm = await variantProvider.GetServiceAsync(CancellationToken.None);
Assert.NotNull(algorithm);
Assert.Equal("Beta", algorithm.Style);
Assert.Equal(1, AlgorithmBeta.Instances);
Assert.Equal(1, AlgorithmOmega.Instances);
}
}

Expand Down
6 changes: 6 additions & 0 deletions tests/Tests.FeatureManagement/VariantServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,37 @@ interface IAlgorithm

class AlgorithmBeta : IAlgorithm
{
public static int Instances; // Tracks constructed instances
public string Style { get; set; }

public AlgorithmBeta()
{
Instances++;
Style = "Beta";
}
}

class AlgorithmSigma : IAlgorithm
{
public static int Instances; // Tracks constructed instances
public string Style { get; set; }

public AlgorithmSigma()
{
Instances++;
Style = "Sigma";
}
}

[VariantServiceAlias("Omega")]
class AlgorithmOmega : IAlgorithm
{
public static int Instances; // Tracks constructed instances
public string Style { get; set; }

public AlgorithmOmega(string style)
{
Instances++;
Style = style;
}
}
Expand Down