Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 158 additions & 42 deletions Sources/SCInject/Container.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ public protocol Container: Registry, Resolver {}
/// Dependencies can be registered with or without names, and resolved accordingly. If a dependency is not found in the
/// current container, it will attempt to resolve it from a parent container if one exists.
/// This class is thread-safe.
public final class DefaultContainer: Container {
public final class DefaultContainer: Container, @unchecked Sendable {
private let parent: DefaultContainer?
private let lock = NSRecursiveLock()
private let defaultScope = Scope.transient
private var resolvers: [ResolverIdentifier: ReferenceResolver] = [:]
private var resolvers: [ResolverIdentifier: ConcreteResolver] = [:]

public init(parent: DefaultContainer? = nil) {
self.parent = parent
Expand Down Expand Up @@ -68,6 +68,46 @@ public final class DefaultContainer: Container {
register(type: type, name: name, scope: scope, closure: closure)
}

// MARK: - Async Registry

public func registerAsync<T>(_ type: T.Type, closure: @escaping @Sendable (Resolver) async -> T) {
register(type: type, name: nil, scope: nil, closure: closure)
}

public func registerAsync<T>(_ type: T.Type, _ scope: Scope, closure: @escaping @Sendable (Resolver) async -> T) {
register(type: type, name: nil, scope: scope, closure: closure)
}

public func registerAsync<T>(_ type: T.Type, name: String, closure: @escaping @Sendable (Resolver) async -> T) {
register(type: type, name: .init(rawValue: name), scope: nil, closure: closure)
}

public func registerAsync<T>(
_ type: T.Type,
name: String,
_ scope: Scope,
closure: @escaping @Sendable (Resolver) async -> T
) {
register(type: type, name: .init(rawValue: name), scope: scope, closure: closure)
}

public func registerAsync<T>(
_ type: T.Type,
name: RegistrationName,
closure: @escaping @Sendable (Resolver) async -> T
) {
register(type: type, name: name, scope: nil, closure: closure)
}

public func registerAsync<T>(
_ type: T.Type,
name: RegistrationName,
_ scope: Scope,
closure: @escaping @Sendable (Resolver) async -> T
) {
register(type: type, name: name, scope: scope, closure: closure)
}

// MARK: - Resolver

public func resolve<T>(_ type: T.Type) -> T {
Expand All @@ -90,6 +130,28 @@ public final class DefaultContainer: Container {
return instance
}

// MARK: - Async Resolver

public func resolveAsync<T>(_ type: T.Type) async -> T {
guard let instance = await tryResolve(type: type, name: nil, container: self) else {
ContainerError.raise(reason: "Failed to resolve given async type", type: "\(type)", name: nil)
fatalError()
}
return instance
}

public func resolveAsync<T>(_ type: T.Type, name: String) async -> T {
await resolveAsync(type, name: .init(rawValue: name))
}

public func resolveAsync<T>(_ type: T.Type, name: RegistrationName) async -> T {
guard let instance = await tryResolve(type: type, name: name, container: self) else {
ContainerError.raise(reason: "Failed to resolve given async type", type: "\(type)", name: name.rawValue)
fatalError()
}
return instance
}

// MARK: - Public

public func tryResolve<T>(_ type: T.Type) -> T? {
Expand All @@ -104,6 +166,18 @@ public final class DefaultContainer: Container {
tryResolve(type: type, name: name, container: self)
}

public func tryResolveAsync<T>(_ type: T.Type) async -> T? {
await tryResolve(type: type, name: nil, container: self)
}

public func tryResolveAsync<T>(_ type: T.Type, name: String) async -> T? {
await tryResolve(type: type, name: .init(rawValue: name), container: self)
}

public func tryResolveAsync<T>(_ type: T.Type, name: RegistrationName) async -> T? {
await tryResolve(type: type, name: name, container: self)
}

/// Validates the dependency graph to ensure that all dependencies
/// can be successfully resolved.
///
Expand All @@ -126,7 +200,12 @@ public final class DefaultContainer: Container {
public func validate() throws {
try ContainerError.rethrow {
for resolver in resolvers {
_ = resolver.value.resolve(with: self)
switch resolver.value {
case let .reference(resolver):
_ = resolver.resolve(with: self)
case .referenceAsync:
break
}
}
}
try parent?.validate()
Expand All @@ -146,18 +225,73 @@ public final class DefaultContainer: Container {
ContainerError.raise(reason: "Given type is already registered", type: "\(type)", name: name?.rawValue)
fatalError()
}
resolvers[identifier] = makeResolver(scope ?? defaultScope, closure: closure)
resolvers[identifier] = .reference(makeResolver(scope ?? defaultScope, closure: closure))
}

private func tryResolve<T>(type: T.Type, name: RegistrationName? = nil, container: Container) -> T? {
lock.lock(); defer { lock.unlock() }
if let resolver = resolvers[identifier(of: type, name: name)] {
let resolver: ConcreteResolver? = {
lock.lock(); defer { lock.unlock() }
return resolvers[identifier(of: type, name: name)]
}()

switch resolver {
case let .reference(resolver):
return resolver.resolve(with: container) as? T
case .referenceAsync:
ContainerError.raise(
reason: "Given type requires async resolution",
type: "\(type)",
name: name?.rawValue
)
return nil
case nil:
if let parent {
return parent.tryResolve(type: type, name: name, container: container)
}
return nil
}
}

private func register<T>(
type: T.Type,
name: RegistrationName?,
scope: Scope?,
closure: @escaping @Sendable (Resolver) async -> T
) {
lock.lock(); defer { lock.unlock() }
let identifier = identifier(of: type, name: name)
if resolvers[identifier] != nil {
ContainerError.raise(
reason: "Given async type is already registered",
type: "\(type)",
name: name?.rawValue
)
fatalError()
}
if let parent {
return parent.tryResolve(type: type, name: name, container: container)
resolvers[identifier] = .referenceAsync(makeResolver(scope ?? defaultScope, closure: closure))
}

private func tryResolve<T>(
type: T.Type,
name: RegistrationName? = nil,
container: Container
) async -> T? {
let resolver: ConcreteResolver? = {
lock.lock(); defer { lock.unlock() }
return resolvers[identifier(of: type, name: name)]
}()

switch resolver {
case let .reference(resolver):
return resolver.resolve(with: container) as? T
case let .referenceAsync(resolver):
return await resolver.resolve(with: container) as? T
case nil:
if let parent {
return await parent.tryResolve(type: type, name: name, container: container)
}
return nil
}
return nil
}

private func makeResolver(_ scope: Scope, closure: @escaping (Resolver) -> some Any) -> ReferenceResolver {
Expand All @@ -169,6 +303,18 @@ public final class DefaultContainer: Container {
}
}

private func makeResolver(
_ scope: Scope,
closure: @escaping @Sendable (Resolver) async -> some Any
) -> ReferenceAsyncResolver {
switch scope {
case .transient:
TransientReferenceAsyncResolver(factory: closure)
case .container:
ContainerReferenceAsyncResolver(factory: closure)
}
}

private func identifier(of type: (some Any).Type, name: RegistrationName?) -> ResolverIdentifier {
ResolverIdentifier(
name: name,
Expand All @@ -182,39 +328,9 @@ public final class DefaultContainer: Container {
let typeIdentifier: ObjectIdentifier
let description: String
}
}

private protocol ReferenceResolver {
func resolve(with resolver: Resolver) -> Any
}

private final class TransientReferenceResolver: ReferenceResolver {
private let factory: (Resolver) -> Any

init(factory: @escaping (Resolver) -> Any) {
self.factory = factory
}

func resolve(with resolver: Resolver) -> Any {
factory(resolver)
}
}

private final class ContainerReferenceResolver: ReferenceResolver {
private var instance: Any?

private let factory: (Resolver) -> Any

init(factory: @escaping (Resolver) -> Any) {
self.factory = factory
}

func resolve(with resolver: Resolver) -> Any {
if let instance {
return instance
}
let newInstance = factory(resolver)
instance = newInstance
return newInstance
private enum ConcreteResolver {
case reference(ReferenceResolver)
case referenceAsync(ReferenceAsyncResolver)
}
}
91 changes: 91 additions & 0 deletions Sources/SCInject/ReferenceResolvers.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
//
// Copyright 2024 Marcin Iwanicki and contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

import Foundation

// MARK: - Synchronous Resolvers

protocol ReferenceResolver {
func resolve(with resolver: Resolver) -> Any
}

final class TransientReferenceResolver: ReferenceResolver {
private let factory: (Resolver) -> Any

init(factory: @escaping (Resolver) -> Any) {
self.factory = factory
}

func resolve(with resolver: Resolver) -> Any {
factory(resolver)
}
}

final class ContainerReferenceResolver: ReferenceResolver {
private var instance: Any?

private let factory: (Resolver) -> Any

init(factory: @escaping (Resolver) -> Any) {
self.factory = factory
}

func resolve(with resolver: Resolver) -> Any {
if let instance {
return instance
}
let newInstance = factory(resolver)
instance = newInstance
return newInstance
}
}

// MARK: - Asynchronous Resolvers

protocol ReferenceAsyncResolver {
func resolve(with resolver: Resolver) async -> Sendable
}

final class TransientReferenceAsyncResolver: ReferenceAsyncResolver {
private let factory: (Resolver) async -> Sendable

init(factory: @escaping (Resolver) async -> Sendable) {
self.factory = factory
}

func resolve(with resolver: Resolver) async -> Sendable {
await factory(resolver)
}
}

actor ContainerReferenceAsyncResolver: ReferenceAsyncResolver {
private var instance: Sendable?

private let factory: @Sendable (Resolver) async -> Sendable

init(factory: @escaping @Sendable (Resolver) async -> Any) {
self.factory = factory
}

func resolve(with resolver: Resolver) async -> Sendable {
if let instance {
return instance
}
let instance = await factory(resolver)
self.instance = instance
return instance
}
}
Loading