Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions git_theta/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
import dataclasses
import functools
import sys
from concurrent.futures import thread
from typing import Any, Awaitable, Dict, Optional, Sequence, Tuple, TypeVar, Union

import numba
import numba.misc.numba_sysinfo
import six

if sys.version_info >= (3, 8):
Expand All @@ -14,6 +17,42 @@
from typing_extensions import Protocol


class Asyncify:
"""Wrap sync functions for use in async."""

def __init__(self, *args, **kwargs):
self.executor = thread.ThreadPoolExecutor(*args, **kwargs)
# Adapted from https://github.com/numba/numba/blob/d44573b43dec9a7b66e9a0d24ef8db94c3dc346c/numba/misc/numba_sysinfo.py#L459
try:
# check import is ok, this means the DSO linkage is working
from numba.np.ufunc import tbbpool # NOQA

# check that the version is compatible, this is a check performed at
# runtime (well, compile time), it will also ImportError if there's
# a problem.
from numba.np.ufunc.parallel import _check_tbb_version_compatible

numba.misc.numba_sysifo._check_tbb_version_compatible()
self.threadsafe = True
except ImportError as e:
try:
from numba.np.ufunc import omppool

self.threadsafe = True
except ImportError as e:
self.threadsafe = False
numba.config.THREADING_LAYER = "threadsafe" if self.threadsafe else "default"

async def __call__(self, fn, *args, **kwargs):
# If numba is thread safe then do it async!
if self.threadsafe:
return await asyncio.wrap_future(self.executor.submit(fn, *args, **kwargs))
return fn(*args, **kwargs)


asyncify = Asyncify()


def run(*args, **kwargs):
"""Run an awaitable to completion, dispatch based on python version."""
if sys.version_info < (3, 8):
Expand Down
2 changes: 2 additions & 0 deletions git_theta/git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import filecmp
import fnmatch
import functools
import io
import json
import logging
Expand Down Expand Up @@ -228,6 +229,7 @@ def remove_file(f, repo):
repo.git.rm(f)


@functools.lru_cache(32)
def get_file_version(repo, path, commit_hash):
path = get_relative_path_from_root(repo, path)
try:
Expand Down
4 changes: 3 additions & 1 deletion git_theta/scripts/git_theta_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ async def _clean(param_keys, new_param):
# Get the metadata from the previous version of the parameter
param_metadata = prev_metadata.get(param_keys)
# Create new metadata from the current value
new_tensor_metadata = metadata.TensorMetadata.from_tensor(new_param)
new_tensor_metadata = await async_utils.asyncify(
metadata.TensorMetadata.from_tensor, new_param
)

# If the parameter tensor has not changed, just keep the metadata the same
# TODO: Encapsulate this parameter check within an equality check.
Expand Down
5 changes: 3 additions & 2 deletions git_theta/updates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import numpy as np

from git_theta import checkpoints, git_utils, lsh, metadata, params, utils
from git_theta import async_utils, checkpoints, git_utils, lsh, metadata, params, utils
from git_theta.lsh.types import Signature

Parameter = np.ndarray
Expand Down Expand Up @@ -100,6 +100,7 @@ async def get_previous_metadata(
logging.debug(
f"Getting metadata for {'/'.join(param_keys)} from commit {last_commit}"
)
# Note: I tried to asyncify this and it seems to cause deadlock.
last_metadata_obj = git_utils.get_file_version(repo, path, last_commit)
last_metadata = metadata.Metadata.from_file(last_metadata_obj.data_stream)
last_param_metadata = last_metadata.flatten()[param_keys]
Expand Down Expand Up @@ -179,7 +180,7 @@ async def write(
# Calculate and hash the *new* value so that we can update the
# metadata when using side-loaded information.
new_value = await self.apply_update(update_value, previous_value)
new_hash = lsh.get_lsh().hash(new_value)
new_hash = await async_utils.asyncify(lsh.get_lsh().hash, new_value)
return await self.write_update(update_value), new_hash
else:
update_value = await self.calculate_update(param, previous_value)
Expand Down