From b8f06bc2efbee603f1faaf7b9597c33142b3a87c Mon Sep 17 00:00:00 2001 From: Brian Lester Date: Fri, 14 Apr 2023 13:09:25 -0400 Subject: [PATCH 1/5] Try to speed things up with async This PR adds an `asyncify` function to try to turn sync code into async code. Still testing if this speeds things up. --- git_theta/async_utils.py | 14 ++++++++++++++ git_theta/git_utils.py | 2 ++ git_theta/scripts/git_theta_filter.py | 4 +++- git_theta/updates/base.py | 5 +++-- 4 files changed, 22 insertions(+), 3 deletions(-) diff --git a/git_theta/async_utils.py b/git_theta/async_utils.py index 40879ea..d38e5a0 100644 --- a/git_theta/async_utils.py +++ b/git_theta/async_utils.py @@ -4,6 +4,7 @@ import dataclasses import functools import sys +from concurrent.futures import thread from typing import Any, Awaitable, Dict, Optional, Sequence, Tuple, TypeVar, Union import six @@ -14,6 +15,19 @@ from typing_extensions import Protocol +class Asyncify: + """Wrap sync functions for use in async.""" + + def __init__(self, *args, **kwargs): + self.executor = thread.ThreadPoolExecutor(*args, **kwargs) + + async def __call__(self, fn, *args, **kwargs): + return await asyncio.wrap_future(self.executor.submit(fn, *args, **kwargs)) + + +asyncify = Asyncify() + + def run(*args, **kwargs): """Run an awaitable to completion, dispatch based on python version.""" if sys.version_info < (3, 8): diff --git a/git_theta/git_utils.py b/git_theta/git_utils.py index d36b598..a1586c2 100644 --- a/git_theta/git_utils.py +++ b/git_theta/git_utils.py @@ -2,6 +2,7 @@ import filecmp import fnmatch +import functools import io import json import logging @@ -228,6 +229,7 @@ def remove_file(f, repo): repo.git.rm(f) +@functools.lru_cache(32) def get_file_version(repo, path, commit_hash): path = get_relative_path_from_root(repo, path) try: diff --git a/git_theta/scripts/git_theta_filter.py b/git_theta/scripts/git_theta_filter.py index 6d8a392..1f26f80 100755 --- a/git_theta/scripts/git_theta_filter.py +++ 
b/git_theta/scripts/git_theta_filter.py @@ -64,7 +64,9 @@ async def _clean(param_keys, new_param): # Get the metadata from the previous version of the parameter param_metadata = prev_metadata.get(param_keys) # Create new metadata from the current value - new_tensor_metadata = metadata.TensorMetadata.from_tensor(new_param) + new_tensor_metadata = await async_utils.asyncify( + metadata.TensorMetadata.from_tensor, new_param + ) # If the parameter tensor has not changed, just keep the metadata the same # TODO: Encapsulate this parameter check within an equality check. diff --git a/git_theta/updates/base.py b/git_theta/updates/base.py index a04917d..573e03d 100644 --- a/git_theta/updates/base.py +++ b/git_theta/updates/base.py @@ -14,7 +14,7 @@ import numpy as np -from git_theta import checkpoints, git_utils, lsh, metadata, params, utils +from git_theta import async_utils, checkpoints, git_utils, lsh, metadata, params, utils from git_theta.lsh.types import Signature Parameter = np.ndarray @@ -100,6 +100,7 @@ async def get_previous_metadata( logging.debug( f"Getting metadata for {'/'.join(param_keys)} from commit {last_commit}" ) + # Note: I tried to asyncify this and it seems to cause deadlock. last_metadata_obj = git_utils.get_file_version(repo, path, last_commit) last_metadata = metadata.Metadata.from_file(last_metadata_obj.data_stream) last_param_metadata = last_metadata.flatten()[param_keys] @@ -179,7 +180,7 @@ async def write( # Calculate and hash the *new* value so that we can update the # metadata when using side-loaded information. 
new_value = await self.apply_update(update_value, previous_value) - new_hash = lsh.get_lsh().hash(new_value) + new_hash = await async_utils.asyncify(lsh.get_lsh().hash, new_value) return await self.write_update(update_value), new_hash else: update_value = await self.calculate_update(param, previous_value) From 6c940b2267ddb50f06e77779d0bc5e6d63d45005 Mon Sep 17 00:00:00 2001 From: Brian Lester Date: Fri, 14 Apr 2023 13:37:19 -0400 Subject: [PATCH 2/5] use better threading library --- git_theta/__init__.py | 4 ++++ setup.py | 1 + 2 files changed, 5 insertions(+) diff --git a/git_theta/__init__.py b/git_theta/__init__.py index 9f27665..e8de2ab 100644 --- a/git_theta/__init__.py +++ b/git_theta/__init__.py @@ -1,5 +1,9 @@ __version__ = "0.0.2" +import numba + +numba.config.THREADING_LAYER = "tbb" + from git_theta import ( checkpoints, git_utils, diff --git a/setup.py b/setup.py index 11e2688..c872ca9 100644 --- a/setup.py +++ b/setup.py @@ -78,6 +78,7 @@ def get_version(file_name: str, version_variable: str = "__version__") -> str: 'importlib_resources; python_version < "3.9.0"', 'importlib_metadata; python_version < "3.10.0"', 'typing_extensions; python_version < "3.8.0"', + "tbb", ], extras_require={ **frameworks_require, From 4fc2143f38867c230076247a924909a3507c7f5f Mon Sep 17 00:00:00 2001 From: Brian Lester Date: Fri, 14 Apr 2023 13:51:05 -0400 Subject: [PATCH 3/5] use better threading library --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index c872ca9..ed76850 100644 --- a/setup.py +++ b/setup.py @@ -79,6 +79,7 @@ def get_version(file_name: str, version_variable: str = "__version__") -> str: 'importlib_metadata; python_version < "3.10.0"', 'typing_extensions; python_version < "3.8.0"', "tbb", + "tbb-devel", ], extras_require={ **frameworks_require, From 2ce9f20980c96190a3e8edfaefac1a366e1683ba Mon Sep 17 00:00:00 2001 From: Brian Lester Date: Tue, 25 Apr 2023 13:34:09 -0400 Subject: [PATCH 4/5] ups --- git_theta/__init__.py | 
2 +- setup.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/git_theta/__init__.py b/git_theta/__init__.py index e8de2ab..18ed8e9 100644 --- a/git_theta/__init__.py +++ b/git_theta/__init__.py @@ -2,7 +2,7 @@ import numba -numba.config.THREADING_LAYER = "tbb" +numba.config.THREADING_LAYER = "threadsafe" from git_theta import ( checkpoints, diff --git a/setup.py b/setup.py index ed76850..11e2688 100644 --- a/setup.py +++ b/setup.py @@ -78,8 +78,6 @@ def get_version(file_name: str, version_variable: str = "__version__") -> str: 'importlib_resources; python_version < "3.9.0"', 'importlib_metadata; python_version < "3.10.0"', 'typing_extensions; python_version < "3.8.0"', - "tbb", - "tbb-devel", ], extras_require={ **frameworks_require, From cfd959466813d454b7276fc59eacd35991fef504 Mon Sep 17 00:00:00 2001 From: Brian Lester Date: Tue, 25 Apr 2023 14:04:41 -0400 Subject: [PATCH 5/5] detect threading library --- git_theta/__init__.py | 4 ---- git_theta/async_utils.py | 27 ++++++++++++++++++++++++++- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/git_theta/__init__.py b/git_theta/__init__.py index 18ed8e9..9f27665 100644 --- a/git_theta/__init__.py +++ b/git_theta/__init__.py @@ -1,9 +1,5 @@ __version__ = "0.0.2" -import numba - -numba.config.THREADING_LAYER = "threadsafe" - from git_theta import ( checkpoints, git_utils, diff --git a/git_theta/async_utils.py b/git_theta/async_utils.py index d38e5a0..2d65c9a 100644 --- a/git_theta/async_utils.py +++ b/git_theta/async_utils.py @@ -7,6 +7,8 @@ from concurrent.futures import thread from typing import Any, Awaitable, Dict, Optional, Sequence, Tuple, TypeVar, Union +import numba +import numba.misc.numba_sysinfo import six if sys.version_info >= (3, 8): @@ -20,9 +22,32 @@ class Asyncify: def __init__(self, *args, **kwargs): self.executor = thread.ThreadPoolExecutor(*args, **kwargs) + # Adapted from 
https://github.com/numba/numba/blob/d44573b43dec9a7b66e9a0d24ef8db94c3dc346c/numba/misc/numba_sysinfo.py#L459 + try: + # check import is ok, this means the DSO linkage is working + from numba.np.ufunc import tbbpool # NOQA + + # check that the version is compatible, this is a check performed at + # runtime (well, compile time), it will also ImportError if there's + # a problem. + from numba.np.ufunc.parallel import _check_tbb_version_compatible + + _check_tbb_version_compatible() + self.threadsafe = True + except ImportError as e: + try: + from numba.np.ufunc import omppool + + self.threadsafe = True + except ImportError as e: + self.threadsafe = False + numba.config.THREADING_LAYER = "threadsafe" if self.threadsafe else "default" async def __call__(self, fn, *args, **kwargs): - return await asyncio.wrap_future(self.executor.submit(fn, *args, **kwargs)) + # If numba is thread safe then do it async! + if self.threadsafe: + return await asyncio.wrap_future(self.executor.submit(fn, *args, **kwargs)) + return fn(*args, **kwargs) asyncify = Asyncify()