Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/dependency_injector/providers.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,13 @@ cdef class Resource(Provider):
cpdef object _provide(self, tuple args, dict kwargs)


cdef class ContextLocalResource(Resource):
cdef object _resource_context_var
cdef object _shutdowner_context_var

cpdef object _provide(self, tuple args, dict kwargs)


cdef class Container(Provider):
cdef object _container_cls
cdef dict _overriding_providers
Expand Down
2 changes: 2 additions & 0 deletions src/dependency_injector/providers.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,8 @@ class Resource(Provider[T]):
def init(self) -> Optional[Awaitable[T]]: ...
def shutdown(self) -> Optional[Awaitable]: ...

class ContextLocalResource(Resource[T]):...

class Container(Provider[T]):
def __init__(
self,
Expand Down
129 changes: 128 additions & 1 deletion src/dependency_injector/providers.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3186,7 +3186,7 @@ cdef class ThreadLocalSingleton(BaseSingleton):
return future_result

self._storage.instance = instance

return instance

def _async_init_instance(self, future_result, result):
Expand Down Expand Up @@ -3867,6 +3867,133 @@ cdef class Resource(Provider):
return self._resource


cdef class ContextLocalResource(Resource):
_none = object()

def __init__(self, provides=None, *args, **kwargs):
self._resource_context_var = ContextVar("_resource_context_var", default=self._none)
self._shutdowner_context_var = ContextVar("_shutdowner_context_var", default=self._none)
super().__init__(provides, *args, **kwargs)

def __deepcopy__(self, memo):
"""Create and return full copy of provider."""
copied = memo.get(id(self))
if copied is not None:
return copied

if self._resource_context_var.get() != self._none:
raise Error("Can not copy initialized resource")
copied = _memorized_duplicate(self, memo)
copied.set_provides(_copy_if_provider(self.provides, memo))
copied.set_args(*deepcopy_args(self, self.args, memo))
copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo))

self._copy_overridings(copied, memo)

return copied

@property
def initialized(self):
"""Check if resource is initialized."""
return self._resource_context_var.get() != self._none


def shutdown(self):
"""Shutdown resource."""
if self._resource_context_var.get() == self._none :
self._reset_all_contex_vars()
if self._async_mode == ASYNC_MODE_ENABLED:
return NULL_AWAITABLE
return
if self._shutdowner_context_var.get() != self._none:
future = self._shutdowner_context_var.get()(None, None, None)
if __is_future_or_coroutine(future):
self._reset_all_contex_vars()
return ensure_future(self._shutdown_async(future))


self._reset_all_contex_vars()
if self._async_mode == ASYNC_MODE_ENABLED:
return NULL_AWAITABLE

def _reset_all_contex_vars(self):
self._resource_context_var.set(self._none)
self._shutdowner_context_var.set(self._none)


async def _shutdown_async(self, future) -> None:
await future


async def _handle_async_cm(self, obj) -> None:
resource = await obj.__aenter__()
return resource

async def _provide_async(self, future):
try:
obj = await future

if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
resource = await obj.__aenter__()
shutdowner = obj.__aexit__
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
resource = obj.__enter__()
shutdowner = obj.__exit__
else:
resource = obj
shutdowner = None

return resource, shutdowner
except:
raise

cpdef object _provide(self, tuple args, dict kwargs):
if self._resource_context_var.get() != self._none:
return self._resource_context_var.get()
obj = __call(
self._provides,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)

if __is_future_or_coroutine(obj):
future_result = asyncio.Future()
future = ensure_future(self._provide_async(obj))
future.add_done_callback(functools.partial(self._async_init_instance, future_result))
return future_result
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
resource = obj.__enter__()
self._resource_context_var.set(resource)
self._shutdowner_context_var.set(obj.__exit__)
elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
resource = ensure_future(self._handle_async_cm(obj))
self._resource_context_var.set(resource)
self._shutdowner_context_var.set(obj.__aexit__)
return resource
else:
self._resource_context_var.set(obj)
self._shutdowner_context_var.set(self._none)

return self._resource_context_var.get()

def _async_init_instance(self, future_result, result):
try:
resource, shutdowner = result.result()
except Exception as exception:
self._resource_context_var.set(self._none)
self._shutdowner_context_var.set(self._none)
future_result.set_exception(exception)
else:
self._resource_context_var.set(resource)
self._shutdowner_context_var.set(shutdowner)
future_result.set_result(resource)


cdef class Container(Provider):
"""Container provider provides an instance of declarative container.

Expand Down
Loading