class Asyncify:
    """Wrap sync functions for use in async.

    Calling an instance as ``await asyncify(fn, *args, **kwargs)`` runs ``fn``
    on a thread pool when numba has a threadsafe threading layer available,
    and runs it inline (blocking) otherwise.

    Attributes set by ``__init__``:
        executor: The ``ThreadPoolExecutor`` used to run wrapped functions.
        threadsafe: ``True`` when the TBB or OpenMP numba threading layer
            could be imported, ``False`` otherwise.
    """

    def __init__(self, *args, **kwargs):
        """Create the thread pool and configure numba's threading layer.

        Args:
            *args: Forwarded unchanged to ``ThreadPoolExecutor``.
            **kwargs: Forwarded unchanged to ``ThreadPoolExecutor``.
        """
        self.executor = thread.ThreadPoolExecutor(*args, **kwargs)
        self.threadsafe = self._has_threadsafe_layer()
        # Only request the "threadsafe" layer when one is actually importable.
        numba.config.THREADING_LAYER = "threadsafe" if self.threadsafe else "default"

    @staticmethod
    def _has_threadsafe_layer():
        """Return True if numba's TBB or OpenMP threading layer is importable."""
        # Adapted from https://github.com/numba/numba/blob/d44573b43dec9a7b66e9a0d24ef8db94c3dc346c/numba/misc/numba_sysinfo.py#L459
        try:
            # check import is ok, this means the DSO linkage is working
            from numba.np.ufunc import tbbpool  # NOQA

            # check that the version is compatible, this is a check performed at
            # runtime (well, compile time), it will also ImportError if there's
            # a problem.
            from numba.np.ufunc.parallel import _check_tbb_version_compatible

            # Call the function imported above. Going through a misspelled
            # module attribute raised AttributeError, which the ImportError
            # handler below does not catch.
            _check_tbb_version_compatible()
            return True
        except ImportError:
            pass
        try:
            # TBB is unavailable; fall back to probing the OpenMP layer.
            from numba.np.ufunc import omppool  # NOQA

            return True
        except ImportError:
            return False

    async def __call__(self, fn, *args, **kwargs):
        """Run ``fn(*args, **kwargs)`` and return its result.

        Args:
            fn: The synchronous callable to run.
            *args: Positional arguments passed to ``fn``.
            **kwargs: Keyword arguments passed to ``fn``.

        Returns:
            Whatever ``fn`` returns. Exceptions raised by ``fn`` propagate.
        """
        # If numba is thread safe then do it async!
        if self.threadsafe:
            return await asyncio.wrap_future(self.executor.submit(fn, *args, **kwargs))
        # Otherwise run inline; this blocks until fn returns.
        return fn(*args, **kwargs)
diff --git a/git_theta/updates/base.py b/git_theta/updates/base.py index a04917d..573e03d 100644 --- a/git_theta/updates/base.py +++ b/git_theta/updates/base.py @@ -14,7 +14,7 @@ import numpy as np -from git_theta import checkpoints, git_utils, lsh, metadata, params, utils +from git_theta import async_utils, checkpoints, git_utils, lsh, metadata, params, utils from git_theta.lsh.types import Signature Parameter = np.ndarray @@ -100,6 +100,7 @@ async def get_previous_metadata( logging.debug( f"Getting metadata for {'/'.join(param_keys)} from commit {last_commit}" ) + # Note: I tried to asyncify this and it seems to cause deadlock. last_metadata_obj = git_utils.get_file_version(repo, path, last_commit) last_metadata = metadata.Metadata.from_file(last_metadata_obj.data_stream) last_param_metadata = last_metadata.flatten()[param_keys] @@ -179,7 +180,7 @@ async def write( # Calculate and hash the *new* value so that we can update the # metadata when using side-loaded information. new_value = await self.apply_update(update_value, previous_value) - new_hash = lsh.get_lsh().hash(new_value) + new_hash = await async_utils.asyncify(lsh.get_lsh().hash, new_value) return await self.write_update(update_value), new_hash else: update_value = await self.calculate_update(param, previous_value)