From 5254f5439a038dabc6e3dbf1f92bb78f92d85ad2 Mon Sep 17 00:00:00 2001 From: Marcin Iwanicki Date: Sun, 2 Nov 2025 21:23:39 +0000 Subject: [PATCH] Add support for async resolution --- Sources/SCInject/Container.swift | 200 +++++++++++++++++----- Sources/SCInject/ReferenceResolvers.swift | 91 ++++++++++ Sources/SCInject/Registry.swift | 49 ++++++ Sources/SCInject/Resolver.swift | 24 ++- Sources/SCInject/Scope.swift | 2 +- Tests/SCInject/ContainerTests.swift | 111 +++++++++++- Tests/SCInject/TestUtils/Stubs.swift | 20 +++ 7 files changed, 447 insertions(+), 50 deletions(-) create mode 100644 Sources/SCInject/ReferenceResolvers.swift diff --git a/Sources/SCInject/Container.swift b/Sources/SCInject/Container.swift index 2f0eb0c..42a171e 100644 --- a/Sources/SCInject/Container.swift +++ b/Sources/SCInject/Container.swift @@ -27,11 +27,11 @@ public protocol Container: Registry, Resolver {} /// Dependencies can be registered with or without names, and resolved accordingly. If a dependency is not found in the /// current container, it will attempt to resolve it from a parent container if one exists. /// This class is thread-safe. -public final class DefaultContainer: Container { +public final class DefaultContainer: Container, @unchecked Sendable { private let parent: DefaultContainer? private let lock = NSRecursiveLock() private let defaultScope = Scope.transient - private var resolvers: [ResolverIdentifier: ReferenceResolver] = [:] + private var resolvers: [ResolverIdentifier: ConcreteResolver] = [:] public init(parent: DefaultContainer? = nil) { self.parent = parent @@ -68,6 +68,46 @@ public final class DefaultContainer: Container { register(type: type, name: name, scope: scope, closure: closure) } + // MARK: - Async Registry + + public func registerAsync(_ type: T.Type, closure: @escaping @Sendable (Resolver) async -> T) { + register(type: type, name: nil, scope: nil, closure: closure) + } + + public func registerAsync(_ type: T.Type, _ scope: Scope, closure: @escaping @Sendable (Resolver) async -> T) { + register(type: type, name: nil, scope: scope, closure: closure) + } + + public func registerAsync(_ type: T.Type, name: String, closure: @escaping @Sendable (Resolver) async -> T) { + register(type: type, name: .init(rawValue: name), scope: nil, closure: closure) + } + + public func registerAsync( + _ type: T.Type, + name: String, + _ scope: Scope, + closure: @escaping @Sendable (Resolver) async -> T + ) { + register(type: type, name: .init(rawValue: name), scope: scope, closure: closure) + } + + public func registerAsync( + _ type: T.Type, + name: RegistrationName, + closure: @escaping @Sendable (Resolver) async -> T + ) { + register(type: type, name: name, scope: nil, closure: closure) + } + + public func registerAsync( + _ type: T.Type, + name: RegistrationName, + _ scope: Scope, + closure: @escaping @Sendable (Resolver) async -> T + ) { + register(type: type, name: name, scope: scope, closure: closure) + } + // MARK: - Resolver public func resolve(_ type: T.Type) -> T { @@ -90,6 +130,28 @@ public final class DefaultContainer: Container { return instance } + // MARK: - Async Resolver + + public func resolveAsync(_ type: T.Type) async -> T { + guard let instance = await tryResolve(type: type, name: nil, container: self) else { + ContainerError.raise(reason: "Failed to resolve given async type", type: "\(type)", name: nil) + fatalError() + } + return instance + } + + public func resolveAsync(_ type: T.Type, name: String) async -> T { + await resolveAsync(type, name: .init(rawValue: name)) + } + + public func resolveAsync(_ type: T.Type, name: RegistrationName) async -> T { + guard let instance = await tryResolve(type: type, name: name, container: self) else { + ContainerError.raise(reason: "Failed to resolve given async type", type: "\(type)", name: name.rawValue) + fatalError() + } + return instance + } + // MARK: - Public public func tryResolve(_ type: T.Type) -> T? { @@ -104,6 +166,18 @@ public final class DefaultContainer: Container { tryResolve(type: type, name: name, container: self) } + public func tryResolveAsync(_ type: T.Type) async -> T? { + await tryResolve(type: type, name: nil, container: self) + } + + public func tryResolveAsync(_ type: T.Type, name: String) async -> T? { + await tryResolve(type: type, name: .init(rawValue: name), container: self) + } + + public func tryResolveAsync(_ type: T.Type, name: RegistrationName) async -> T? { + await tryResolve(type: type, name: name, container: self) + } + /// Validates the dependency graph to ensure that all dependencies /// can be successfully resolved. /// @@ -126,7 +200,12 @@ public final class DefaultContainer: Container { public func validate() throws { try ContainerError.rethrow { for resolver in resolvers { - _ = resolver.value.resolve(with: self) + switch resolver.value { + case let .reference(resolver): + _ = resolver.resolve(with: self) + case .referenceAsync: + break + } } } try parent?.validate() @@ -146,18 +225,73 @@ public final class DefaultContainer: Container { ContainerError.raise(reason: "Given type is already registered", type: "\(type)", name: name?.rawValue) fatalError() } - resolvers[identifier] = makeResolver(scope ?? defaultScope, closure: closure) + resolvers[identifier] = .reference(makeResolver(scope ?? defaultScope, closure: closure)) } private func tryResolve(type: T.Type, name: RegistrationName? = nil, container: Container) -> T? { - lock.lock(); defer { lock.unlock() } - if let resolver = resolvers[identifier(of: type, name: name)] { + let resolver: ConcreteResolver? = { + lock.lock(); defer { lock.unlock() } + return resolvers[identifier(of: type, name: name)] + }() + + switch resolver { + case let .reference(resolver): return resolver.resolve(with: container) as? T + case .referenceAsync: + ContainerError.raise( + reason: "Given type requires async resolution", + type: "\(type)", + name: name?.rawValue + ) + return nil + case nil: + if let parent { + return parent.tryResolve(type: type, name: name, container: container) + } + return nil + } + } + + private func register( + type: T.Type, + name: RegistrationName?, + scope: Scope?, + closure: @escaping @Sendable (Resolver) async -> T + ) { + lock.lock(); defer { lock.unlock() } + let identifier = identifier(of: type, name: name) + if resolvers[identifier] != nil { + ContainerError.raise( + reason: "Given async type is already registered", + type: "\(type)", + name: name?.rawValue + ) + fatalError() } - if let parent { - return parent.tryResolve(type: type, name: name, container: container) + resolvers[identifier] = .referenceAsync(makeResolver(scope ?? defaultScope, closure: closure)) + } + + private func tryResolve( + type: T.Type, + name: RegistrationName? = nil, + container: Container + ) async -> T? { + let resolver: ConcreteResolver? = { + lock.lock(); defer { lock.unlock() } + return resolvers[identifier(of: type, name: name)] + }() + + switch resolver { + case let .reference(resolver): + return resolver.resolve(with: container) as? T + case let .referenceAsync(resolver): + return await resolver.resolve(with: container) as? T + case nil: + if let parent { + return await parent.tryResolve(type: type, name: name, container: container) + } + return nil } - return nil } private func makeResolver(_ scope: Scope, closure: @escaping (Resolver) -> some Any) -> ReferenceResolver { @@ -169,6 +303,18 @@ public final class DefaultContainer: Container { } } + private func makeResolver( + _ scope: Scope, + closure: @escaping @Sendable (Resolver) async -> some Any + ) -> ReferenceAsyncResolver { + switch scope { + case .transient: + TransientReferenceAsyncResolver(factory: closure) + case .container: + ContainerReferenceAsyncResolver(factory: closure) + } + } + private func identifier(of type: (some Any).Type, name: RegistrationName?) -> ResolverIdentifier { ResolverIdentifier( name: name, @@ -182,39 +328,9 @@ public final class DefaultContainer: Container { let typeIdentifier: ObjectIdentifier let description: String } -} - -private protocol ReferenceResolver { - func resolve(with resolver: Resolver) -> Any -} - -private final class TransientReferenceResolver: ReferenceResolver { - private let factory: (Resolver) -> Any - - init(factory: @escaping (Resolver) -> Any) { - self.factory = factory - } - - func resolve(with resolver: Resolver) -> Any { - factory(resolver) - } -} -private final class ContainerReferenceResolver: ReferenceResolver { - private var instance: Any? - - private let factory: (Resolver) -> Any - - init(factory: @escaping (Resolver) -> Any) { - self.factory = factory - } - - func resolve(with resolver: Resolver) -> Any { - if let instance { - return instance - } - let newInstance = factory(resolver) - instance = newInstance - return newInstance + private enum ConcreteResolver { + case reference(ReferenceResolver) + case referenceAsync(ReferenceAsyncResolver) } } diff --git a/Sources/SCInject/ReferenceResolvers.swift b/Sources/SCInject/ReferenceResolvers.swift new file mode 100644 index 0000000..a7ccd03 --- /dev/null +++ b/Sources/SCInject/ReferenceResolvers.swift @@ -0,0 +1,91 @@ +// +// Copyright 2024 Marcin Iwanicki and contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +// MARK: - Synchronous Resolvers + +protocol ReferenceResolver { + func resolve(with resolver: Resolver) -> Any +} + +final class TransientReferenceResolver: ReferenceResolver { + private let factory: (Resolver) -> Any + + init(factory: @escaping (Resolver) -> Any) { + self.factory = factory + } + + func resolve(with resolver: Resolver) -> Any { + factory(resolver) + } +} + +final class ContainerReferenceResolver: ReferenceResolver { + private var instance: Any? + + private let factory: (Resolver) -> Any + + init(factory: @escaping (Resolver) -> Any) { + self.factory = factory + } + + func resolve(with resolver: Resolver) -> Any { + if let instance { + return instance + } + let newInstance = factory(resolver) + instance = newInstance + return newInstance + } +} + +// MARK: - Asynchronous Resolvers + +protocol ReferenceAsyncResolver { + func resolve(with resolver: Resolver) async -> Sendable +} + +final class TransientReferenceAsyncResolver: ReferenceAsyncResolver { + private let factory: (Resolver) async -> Sendable + + init(factory: @escaping (Resolver) async -> Sendable) { + self.factory = factory + } + + func resolve(with resolver: Resolver) async -> Sendable { + await factory(resolver) + } +} + +actor ContainerReferenceAsyncResolver: ReferenceAsyncResolver { + private var instance: Sendable? + + private let factory: @Sendable (Resolver) async -> Sendable + + init(factory: @escaping @Sendable (Resolver) async -> Any) { + self.factory = factory + } + + func resolve(with resolver: Resolver) async -> Sendable { + if let instance { + return instance + } + let instance = await factory(resolver) + self.instance = instance + return instance + } +} diff --git a/Sources/SCInject/Registry.swift b/Sources/SCInject/Registry.swift index 66e53f1..3efc964 100644 --- a/Sources/SCInject/Registry.swift +++ b/Sources/SCInject/Registry.swift @@ -59,4 +59,53 @@ public protocol Registry { /// - Parameter scope: The scope in which the dependency should be resolved. /// - Parameter closure: A closure that provides the instance of the dependency. func register(_ type: T.Type, name: RegistrationName, _ scope: Scope, closure: @escaping (Resolver) -> T) + + // MARK: - Async Registration + + /// Registers an async dependency with a transient scope. + /// - Parameter type: The type of the dependency to register. + /// - Parameter closure: An async closure that provides the instance of the dependency. + func registerAsync(_ type: T.Type, closure: @escaping @Sendable (Resolver) async -> T) + + /// Registers an async dependency with a specified scope. + /// - Parameter type: The type of the dependency to register. + /// - Parameter scope: The scope in which the dependency should be resolved. + /// - Parameter closure: An async closure that provides the instance of the dependency. + func registerAsync(_ type: T.Type, _ scope: Scope, closure: @escaping @Sendable (Resolver) async -> T) + + /// Registers a named async dependency with a transient scope. + /// - Parameter type: The type of the dependency to register. + /// - Parameter name: The name associated with the dependency. + /// - Parameter closure: An async closure that provides the instance of the dependency. + func registerAsync(_ type: T.Type, name: String, closure: @escaping @Sendable (Resolver) async -> T) + + /// Registers a named async dependency with a specified scope. + /// - Parameter type: The type of the dependency to register. + /// - Parameter name: The name associated with the dependency. + /// - Parameter scope: The scope in which the dependency should be resolved. + /// - Parameter closure: An async closure that provides the instance of the dependency. + func registerAsync( + _ type: T.Type, + name: String, + _ scope: Scope, + closure: @escaping @Sendable (Resolver) async -> T + ) + + /// Registers a named async dependency with a transient scope. + /// - Parameter type: The type of the dependency to register. + /// - Parameter name: The name associated with the dependency. + /// - Parameter closure: An async closure that provides the instance of the dependency. + func registerAsync(_ type: T.Type, name: RegistrationName, closure: @escaping @Sendable (Resolver) async -> T) + + /// Registers a named async dependency with a specified scope. + /// - Parameter type: The type of the dependency to register. + /// - Parameter name: The name associated with the dependency. + /// - Parameter scope: The scope in which the dependency should be resolved. + /// - Parameter closure: An async closure that provides the instance of the dependency. + func registerAsync( + _ type: T.Type, + name: RegistrationName, + _ scope: Scope, + closure: @escaping @Sendable (Resolver) async -> T + ) } diff --git a/Sources/SCInject/Resolver.swift b/Sources/SCInject/Resolver.swift index ac592df..834b9eb 100644 --- a/Sources/SCInject/Resolver.swift +++ b/Sources/SCInject/Resolver.swift @@ -21,7 +21,7 @@ import Foundation /// Dependencies can be resolved by their type, and optionally by a name, if they were registered with one. /// Implementations of this protocol are typically provided by dependency injection containers, such as /// `DefaultContainer`. -public protocol Resolver: AnyObject { +public protocol Resolver: AnyObject, Sendable { /// Resolves a dependency by its type. /// - Parameter type: The type of the dependency to resolve. /// - Returns: An instance of the resolved dependency. @@ -41,4 +41,26 @@ public protocol Resolver: AnyObject { /// - Returns: An instance of the resolved dependency. /// - Note: The application will crash if the dependency cannot be resolved. func resolve(_ type: T.Type, name: RegistrationName) -> T + + // MARK: - Async Resolution + + /// Asynchronously resolves a dependency by its type. + /// - Parameter type: The type of the dependency to resolve. + /// - Returns: An instance of the resolved dependency. + /// - Note: The application will crash if the dependency cannot be resolved. + func resolveAsync(_ type: T.Type) async -> T + + /// Asynchronously resolves a named dependency by its type. + /// - Parameter type: The type of the dependency to resolve. + /// - Parameter name: The name associated with the dependency. + /// - Returns: An instance of the resolved dependency. + /// - Note: The application will crash if the dependency cannot be resolved. + func resolveAsync(_ type: T.Type, name: String) async -> T + + /// Asynchronously resolves a named dependency by its type. + /// - Parameter type: The type of the dependency to resolve. + /// - Parameter name: The `RegisterName` associated with the dependency. + /// - Returns: An instance of the resolved dependency. + /// - Note: The application will crash if the dependency cannot be resolved. + func resolveAsync(_ type: T.Type, name: RegistrationName) async -> T } diff --git a/Sources/SCInject/Scope.swift b/Sources/SCInject/Scope.swift index d03d989..4e36751 100644 --- a/Sources/SCInject/Scope.swift +++ b/Sources/SCInject/Scope.swift @@ -21,7 +21,7 @@ /// /// - `transient`: A new instance of the dependency is created every time it is resolved. /// - `container`: A single instance of the dependency is created and reused throughout the container's lifetime. -public enum Scope { +public enum Scope: Sendable { /// A new instance of the dependency is created every time it is resolved. case transient diff --git a/Tests/SCInject/ContainerTests.swift b/Tests/SCInject/ContainerTests.swift index 8883c08..ff9c31f 100644 --- a/Tests/SCInject/ContainerTests.swift +++ b/Tests/SCInject/ContainerTests.swift @@ -84,7 +84,7 @@ final class ContainerTests: XCTestCase { XCTAssertTrue(class2?.value !== class1_name) } - func testValidate() { + func testValidate() async throws { // Given let second: RegistrationName = .init(rawValue: "second") let container = DefaultContainer() @@ -99,18 +99,17 @@ final class ContainerTests: XCTestCase { } // When / Then - XCTAssertNoThrow(try container.validate()) + try container.validate() } - func testValidate_missingNamedType() throws { + func testValidate_missingNamedType() async throws { // Given - let second: RegistrationName = .init(rawValue: "second") let container = DefaultContainer() container.register(TestClass1.self) { _ in TestClass1(value: "TestClass1_Instance") } container.register(TestClass2.self) { r in - TestClass2(value: r.resolve(TestClass1.self, name: second)) + TestClass2(value: r.resolve(TestClass1.self, name: "second")) } // When / Then @@ -118,8 +117,108 @@ final class ContainerTests: XCTestCase { let error = error as? ContainerError XCTAssertEqual(error?.reason, "Failed to resolve given type -- TYPE=TestClass1 NAME=second") XCTAssertEqual(error?.type, "TestClass1") - XCTAssertEqual(error?.name, second.rawValue) + XCTAssertEqual(error?.name, "second") + } + } + + // MARK: - Async Tests + + func testRegisterAsync_transientActor() async { + // Given + let container = DefaultContainer() + container.registerAsync(TestActor1.self) { _ in + await TestActor1(value: "TestActor1_Instance") + } + container.registerAsync(TestActor2.self) { r in + await TestActor2(value: r.resolveAsync(TestActor1.self)) + } + + // When + let actor1 = await container.tryResolveAsync(TestActor1.self) + let actor2 = await container.tryResolveAsync(TestActor2.self) + let actor1_name = await container.tryResolveAsync(TestActor1.self, name: "Test") + let actor2_name = await container.tryResolveAsync(TestActor2.self, name: "Test") + let actor1_second = await container.tryResolveAsync(TestActor1.self) + let actor2_second = await container.tryResolveAsync(TestActor2.self) + + // Then + XCTAssertNotNil(actor1) + XCTAssertNotNil(actor2) + XCTAssertNotNil(actor1_second) + XCTAssertNotNil(actor2_second) + XCTAssertNil(actor1_name) + XCTAssertNil(actor2_name) + XCTAssertTrue(actor1 !== actor1_second) + XCTAssertTrue(actor2 !== actor2_second) + XCTAssertEqual(actor1?.rawValue, "TestActor1_Instance") + XCTAssertEqual(actor1_second?.rawValue, "TestActor1_Instance") + } + + func testRegisterAsync_transientActorWithName() async { + // Given + let second = "second" + let container = DefaultContainer() + container.registerAsync(TestActor1.self) { _ in + await TestActor1(value: "TestActor1_Instance") } + container.registerAsync(TestActor1.self, name: second) { _ in + await TestActor1(value: "TestActor1_Second_Instance") + } + container.registerAsync(TestActor2.self) { r in + await TestActor2(value: r.resolveAsync(TestActor1.self, name: second)) + } + + // When + let actor1 = await container.tryResolveAsync(TestActor1.self) + let actor2 = await container.tryResolveAsync(TestActor2.self) + let actor1_name = await container.tryResolveAsync(TestActor1.self, name: second) + let actor2_name = await container.tryResolveAsync(TestActor2.self, name: second) + + // Then + XCTAssertNotNil(actor1) + XCTAssertNotNil(actor2) + XCTAssertNotNil(actor1_name) + XCTAssertNil(actor2_name) + XCTAssertTrue(actor1 !== actor1_name) + XCTAssertEqual(actor1_name?.rawValue, "TestActor1_Second_Instance") + XCTAssertEqual(actor2?.value.rawValue, "TestActor1_Second_Instance") + } + + func testRegisterAsync_containerScope() async { + // Given + let container = DefaultContainer() + container.registerAsync(TestActor1.self, .container) { _ in + await TestActor1(value: "TestActor1_Singleton") + } + + // When + let actor1_first = await container.resolveAsync(TestActor1.self) + let actor1_second = await container.resolveAsync(TestActor1.self) + + // Then + XCTAssertTrue(actor1_first === actor1_second) + XCTAssertEqual(actor1_first.rawValue, "TestActor1_Singleton") + } + + func testMixedSyncAndAsync() async { + // Given + let container = DefaultContainer() + container.register(TestClass1.self) { _ in + TestClass1(value: "TestClass1_Instance") + } + container.registerAsync(TestActor1.self) { _ in + await TestActor1(value: "TestActor1_Instance") + } + + // When + let class1 = container.tryResolve(TestClass1.self) + let actor1 = await container.tryResolveAsync(TestActor1.self) + + // Then + XCTAssertNotNil(class1) + XCTAssertNotNil(actor1) + XCTAssertEqual(class1?.rawValue, "TestClass1_Instance") + XCTAssertEqual(actor1?.rawValue, "TestActor1_Instance") } } diff --git a/Tests/SCInject/TestUtils/Stubs.swift b/Tests/SCInject/TestUtils/Stubs.swift index 838de34..a6792a5 100644 --- a/Tests/SCInject/TestUtils/Stubs.swift +++ b/Tests/SCInject/TestUtils/Stubs.swift @@ -31,3 +31,23 @@ class TestClass2 { self.value = value } } + +actor TestActor1 { + let rawValue: String + + init(value: String) async { + // Simulate async initialization + try? await Task.sleep(nanoseconds: 1_000_000) // 1ms + rawValue = value + } +} + +actor TestActor2 { + let value: TestActor1 + + init(value: TestActor1) async { + // Simulate async initialization + try? await Task.sleep(nanoseconds: 1_000_000) // 1ms + self.value = value + } +}