From 23ed8e4bbb12655e1099649ddae8eefba7e52c01 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Oct 2022 21:48:48 +0200 Subject: [PATCH 1/7] refactor patching logic --- light_the_torch/_packages.py | 140 ++++++++++++++++++++++++++++++++++ light_the_torch/_patch.py | 144 ++++++++--------------------------- 2 files changed, 171 insertions(+), 113 deletions(-) create mode 100644 light_the_torch/_packages.py diff --git a/light_the_torch/_packages.py b/light_the_torch/_packages.py new file mode 100644 index 0000000..f557365 --- /dev/null +++ b/light_the_torch/_packages.py @@ -0,0 +1,140 @@ +import abc +import enum +import itertools +import re + +from pip._internal.models.search_scope import SearchScope + +from . import _cb as cb + + +class Channel(enum.Enum): + STABLE = enum.auto() + TEST = enum.auto() + NIGHTLY = enum.auto() + LTS = enum.auto() + + @classmethod + def from_str(cls, string): + return cls[string.upper()] + + +class PatchedPackages: + _PATCHED_PACKAGE_CLSS_MAP = {} + + @classmethod + def _register(cls, name): + def wrapper(patched_package_cls): + cls._PATCHED_PACKAGE_CLSS_MAP[name] = patched_package_cls + return patched_package_cls + + return wrapper + + def __init__(self, options): + self._options = options + self._patched_packages_map = { + name: cls(options) for name, cls in self._PATCHED_PACKAGE_CLSS_MAP.items() + } + + def get(self, name): + return self._patched_packages_map.get(name) + + +class _PatchedPackage(abc.ABC): + def __init__(self, options): + self._options = options + + @abc.abstractmethod + def make_search_scope(self): + pass + + @abc.abstractmethod + def filter_candidates(self, candidates): + pass + + @abc.abstractmethod + def make_sort_key(self, candidate): + pass + + +class _PatchedPyTorchPackage(_PatchedPackage): + def _get_extra_index_urls(self, computation_backends, channel): + if channel == Channel.STABLE: + channel_paths = [""] + elif channel == Channel.LTS: + channel_paths = [ + f"lts/{major}.{minor}/" + for major, minor 
in [ + (1, 8), + ] + ] + else: + channel_paths = [f"{channel.name.lower()}/"] + return [ + f"https://download.pytorch.org/whl/{channel_path}{backend}" + for channel_path, backend in itertools.product( + channel_paths, sorted(computation_backends) + ) + ] + + def make_search_scope(self): + return SearchScope( + find_links=[], + index_urls=self._get_extra_index_urls( + self._options.computation_backends, self._options.channel + ), + no_index=False, + ) + + _COMPUTATION_BACKEND_PATTERN = re.compile( + r"/(?P(cpu|cu\d+|rocm([\d.]+)))/" + ) + + def _extract_local_specifier(self, candidate): + local = candidate.version.local + + if local is None: + match = self._COMPUTATION_BACKEND_PATTERN.search(candidate.link.path) + local = match["computation_backend"] if match else "any" + + # Early PyTorch distributions used the "any" local specifier to indicate a + # pure Python binary. This was changed to no local specifier later. + # Setting this to "cpu" is technically not correct as it will exclude this + # binary if a non-CPU backend is requested. Still, this is probably the + # right thing to do, since the user requested a specific backend and + # although this binary will work with it, it was not compiled against it. 
+ if local == "any": + local = "cpu" + + return local + + def filter_candidates(self, candidates): + return [ + candidate + for candidate in candidates + if self._extract_local_specifier(candidate) + in self._options.computation_backends + ] + + def make_sort_key(self, candidate): + return ( + cb.ComputationBackend.from_str(self._extract_local_specifier(candidate)), + candidate.version.base_version, + ) + + +for name in ["torch", "torchvision", "torchaudio"]: + PatchedPackages._register(name)(_PatchedPyTorchPackage) + + +@PatchedPackages._register("torchdata") +class _TorchData(_PatchedPyTorchPackage): + def make_search_scope(self): + if self._options.channel == Channel.STABLE: + return SearchScope( + find_links=[], + index_urls=["https://pypi.org/simple"], + no_index=False, + ) + + return super().make_search_scope() diff --git a/light_the_torch/_patch.py b/light_the_torch/_patch.py index a28b3d8..8eb60c5 100644 --- a/light_the_torch/_patch.py +++ b/light_the_torch/_patch.py @@ -1,41 +1,23 @@ import contextlib import dataclasses -import enum import functools import itertools import optparse import os -import re import sys import unittest.mock from typing import List, Set from unittest import mock import pip._internal.cli.cmdoptions - -from pip._internal.index.collector import CollectedSources from pip._internal.index.package_finder import CandidateEvaluator -from pip._internal.index.sources import build_source -from pip._internal.models.search_scope import SearchScope import light_the_torch as ltt - from . 
import _cb as cb - +from ._packages import Channel, PatchedPackages from ._utils import apply_fn_patch -class Channel(enum.Enum): - STABLE = enum.auto() - TEST = enum.auto() - NIGHTLY = enum.auto() - LTS = enum.auto() - - @classmethod - def from_str(cls, string): - return cls[string.upper()] - - PYTORCH_DISTRIBUTIONS = { "torch", "torch_model_archiver", @@ -168,11 +150,13 @@ def from_pip_argv(cls, argv: List[str]): def apply_patches(argv): options = LttOptions.from_pip_argv(argv) + packages = PatchedPackages(options) + patches = [ patch_cli_version(), patch_cli_options(), - patch_link_collection(options.computation_backends, options.channel), - patch_candidate_selection(options.computation_backends), + patch_link_collection(packages), + patch_candidate_selection(packages), ] with contextlib.ExitStack() as stack: @@ -239,57 +223,17 @@ def get_extra_index_urls(computation_backends, channel): @contextlib.contextmanager -def patch_link_collection(computation_backends, channel): - search_scope = SearchScope( - find_links=[], - index_urls=get_extra_index_urls(computation_backends, channel), - no_index=False, - ) - +def patch_link_collection(packages): @contextlib.contextmanager def context(input): - if input.project_name not in PYTORCH_DISTRIBUTIONS: + package = packages.get(input.project_name) + if not package: yield return - with mock.patch.object(input.self, "search_scope", search_scope): + with mock.patch.object(input.self, "search_scope", package.make_search_scope()): yield - def postprocessing(input, output): - if input.project_name not in PYTORCH_DISTRIBUTIONS: - return output - - if channel != Channel.STABLE: - return output - - # Some stable binaries are not hosted on the PyTorch indices. We check if this - # is the case for the current distribution. - for remote_file_source in output.index_urls: - candidates = list(remote_file_source.page_candidates()) - - # Cache the candidates, so `pip` doesn't has to retrieve them again later. 
- remote_file_source.page_candidates = lambda: iter(candidates) - - # If there are any candidates on the PyTorch indices, we continue normally. - if candidates: - return output - - # In case the distribution is not present on the PyTorch indices, we fall back - # to PyPI. - _, pypi_file_source = build_source( - SearchScope( - find_links=[], - index_urls=["https://pypi.org/simple"], - no_index=False, - ).get_index_urls_locations(input.project_name)[0], - candidates_from_page=input.candidates_from_page, - page_validator=input.self.session.is_secure_origin, - expand_dir=False, - cache_link_parsing=False, - ) - - return CollectedSources(find_links=[], index_urls=[pypi_file_source]) - with apply_fn_patch( "pip", "_internal", @@ -298,67 +242,43 @@ def postprocessing(input, output): "LinkCollector", "collect_sources", context=context, - postprocessing=postprocessing, ): yield @contextlib.contextmanager -def patch_candidate_selection(computation_backends): - computation_backend_pattern = re.compile( - r"/(?P(cpu|cu\d+|rocm([\d.]+)))/" - ) - - def extract_local_specifier(candidate): - local = candidate.version.local - - if local is None: - match = computation_backend_pattern.search(candidate.link.path) - local = match["computation_backend"] if match else "any" - - # Early PyTorch distributions used the "any" local specifier to indicate a - # pure Python binary. This was changed to no local specifier later. - # Setting this to "cpu" is technically not correct as it will exclude this - # binary if a non-CPU backend is requested. Still, this is probably the - # right thing to do, since the user requested a specific backend and - # although this binary will work with it, it was not compiled against it. 
- if local == "any": - local = "cpu" - - return local - +def patch_candidate_selection(packages): def preprocessing(input): if not input.candidates: return - candidates = iter(input.candidates) - candidate = next(candidates) - - if candidate.name not in PYTORCH_DISTRIBUTIONS: - # At this stage all candidates have the same name. Thus, if the first is - # not a PyTorch distribution, we don't need to check the rest and can - # return without changes. + # At this stage all candidates have the same name. Thus, if the first is + # not a PyTorch distribution, we don't need to check the rest and can + # return without changes. + package = packages.get(input.candidates[0].name) + if not package: return - input.candidates = [ - candidate - for candidate in itertools.chain([candidate], candidates) - if extract_local_specifier(candidate) in computation_backends - ] + input.candidates = list(package.filter_candidates(input.candidates)) vanilla_sort_key = CandidateEvaluator._sort_key def patched_sort_key(candidate_evaluator, candidate): # At this stage all candidates have the same name. Thus, we don't need to # mirror the exact key structure that the vanilla sort keys have. 
- return ( - vanilla_sort_key(candidate_evaluator, candidate) - if candidate.name not in PYTORCH_DISTRIBUTIONS - else ( - cb.ComputationBackend.from_str(extract_local_specifier(candidate)), - candidate.version.base_version, - ) - ) + package = packages.get(candidate.name) + if not package: + return vanilla_sort_key(candidate_evaluator, candidate) + + return package.make_sort_key(candidate) + + @contextlib.contextmanager + def context(input): + # TODO: refactor this to early return here + with unittest.mock.patch.object( + CandidateEvaluator, "_sort_key", new=patched_sort_key + ): + yield with apply_fn_patch( "pip", @@ -368,8 +288,6 @@ def patched_sort_key(candidate_evaluator, candidate): "CandidateEvaluator", "get_applicable_candidates", preprocessing=preprocessing, + context=context, ): - with unittest.mock.patch.object( - CandidateEvaluator, "_sort_key", new=patched_sort_key - ): - yield + yield From 6c544a6c73af0fa9ac3ac595222ff243dd24187c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Oct 2022 22:12:38 +0200 Subject: [PATCH 2/7] small fix --- light_the_torch/_packages.py | 17 ++++++++++++- light_the_torch/_patch.py | 49 ++++++------------------------------ 2 files changed, 23 insertions(+), 43 deletions(-) diff --git a/light_the_torch/_packages.py b/light_the_torch/_packages.py index f557365..909bd81 100644 --- a/light_the_torch/_packages.py +++ b/light_the_torch/_packages.py @@ -36,6 +36,9 @@ def __init__(self, options): name: cls(options) for name, cls in self._PATCHED_PACKAGE_CLSS_MAP.items() } + def __contains__(self, name): + return name in self._patched_packages_map + def get(self, name): return self._patched_packages_map.get(name) @@ -123,7 +126,19 @@ def make_sort_key(self, candidate): ) -for name in ["torch", "torchvision", "torchaudio"]: +# FIXME: check whether all of these are hosted on all channels +for name in { + "torch", + "torch_model_archiver", + "torch_tb_profiler", + "torcharrow", + "torchaudio", + "torchcsprng", + "torchdistx", + 
"torchserve", + "torchtext", + "torchvision", +}: PatchedPackages._register(name)(_PatchedPyTorchPackage) diff --git a/light_the_torch/_patch.py b/light_the_torch/_patch.py index 8eb60c5..a7547f7 100644 --- a/light_the_torch/_patch.py +++ b/light_the_torch/_patch.py @@ -1,7 +1,6 @@ import contextlib import dataclasses import functools -import itertools import optparse import os import sys @@ -18,21 +17,6 @@ from ._utils import apply_fn_patch -PYTORCH_DISTRIBUTIONS = { - "torch", - "torch_model_archiver", - "torch_tb_profiler", - "torcharrow", - "torchaudio", - "torchcsprng", - "torchdata", - "torchdistx", - "torchserve", - "torchtext", - "torchvision", -} - - def patch(pip_main): @functools.wraps(pip_main) def wrapper(argv=None): @@ -202,26 +186,6 @@ def postprocessing(input, output): yield -def get_extra_index_urls(computation_backends, channel): - if channel == Channel.STABLE: - channel_paths = [""] - elif channel == Channel.LTS: - channel_paths = [ - f"lts/{major}.{minor}/" - for major, minor in [ - (1, 8), - ] - ] - else: - channel_paths = [f"{channel.name.lower()}/"] - return [ - f"https://download.pytorch.org/whl/{channel_path}{backend}" - for channel_path, backend in itertools.product( - channel_paths, sorted(computation_backends) - ) - ] - - @contextlib.contextmanager def patch_link_collection(packages): @contextlib.contextmanager @@ -264,17 +228,18 @@ def preprocessing(input): vanilla_sort_key = CandidateEvaluator._sort_key def patched_sort_key(candidate_evaluator, candidate): - # At this stage all candidates have the same name. Thus, we don't need to - # mirror the exact key structure that the vanilla sort keys have. package = packages.get(candidate.name) - if not package: - return vanilla_sort_key(candidate_evaluator, candidate) - + assert package return package.make_sort_key(candidate) @contextlib.contextmanager def context(input): - # TODO: refactor this to early return here + # At this stage all candidates have the same name. 
Thus, we don't need to + # mirror the exact key structure that the vanilla sort keys have. + if not input.candidates or input.candidates[0].name not in packages: + yield + return + with unittest.mock.patch.object( CandidateEvaluator, "_sort_key", new=patched_sort_key ): From 9a170d697411f7573bd734dbaab53e57cf9283c1 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 23 Jan 2023 11:17:31 +0100 Subject: [PATCH 3/7] small update --- light_the_torch/_packages.py | 88 +++++++++++++++--------------------- light_the_torch/_patch.py | 22 ++++----- 2 files changed, 46 insertions(+), 64 deletions(-) diff --git a/light_the_torch/_packages.py b/light_the_torch/_packages.py index 909bd81..c8eac1d 100644 --- a/light_the_torch/_packages.py +++ b/light_the_torch/_packages.py @@ -1,4 +1,5 @@ import abc +import dataclasses import enum import itertools import re @@ -7,7 +8,28 @@ from . import _cb as cb +__all__ = ["packages"] + +@dataclasses.dataclass +class _Package(abc.ABC): + name: str + + @abc.abstractmethod + def make_search_scope(self, options): + pass + + @abc.abstractmethod + def filter_candidates(self, candidates, options): + pass + + @abc.abstractmethod + def make_sort_key(self, candidate, options): + pass + + +# FIXME: move this to cli patch +# create patch.cli and patch.packages class Channel(enum.Enum): STABLE = enum.auto() TEST = enum.auto() @@ -19,48 +41,10 @@ def from_str(cls, string): return cls[string.upper()] -class PatchedPackages: - _PATCHED_PACKAGE_CLSS_MAP = {} - - @classmethod - def _register(cls, name): - def wrapper(patched_package_cls): - cls._PATCHED_PACKAGE_CLSS_MAP[name] = patched_package_cls - return patched_package_cls - - return wrapper - - def __init__(self, options): - self._options = options - self._patched_packages_map = { - name: cls(options) for name, cls in self._PATCHED_PACKAGE_CLSS_MAP.items() - } - - def __contains__(self, name): - return name in self._patched_packages_map - - def get(self, name): - return 
self._patched_packages_map.get(name) - - -class _PatchedPackage(abc.ABC): - def __init__(self, options): - self._options = options - - @abc.abstractmethod - def make_search_scope(self): - pass - - @abc.abstractmethod - def filter_candidates(self, candidates): - pass - - @abc.abstractmethod - def make_sort_key(self, candidate): - pass +packages = {} -class _PatchedPyTorchPackage(_PatchedPackage): +class _PyTorchDistribution(_Package): def _get_extra_index_urls(self, computation_backends, channel): if channel == Channel.STABLE: channel_paths = [""] @@ -80,11 +64,11 @@ def _get_extra_index_urls(self, computation_backends, channel): ) ] - def make_search_scope(self): + def make_search_scope(self, options): return SearchScope( find_links=[], index_urls=self._get_extra_index_urls( - self._options.computation_backends, self._options.channel + options.computation_backends, options.channel ), no_index=False, ) @@ -111,15 +95,14 @@ def _extract_local_specifier(self, candidate): return local - def filter_candidates(self, candidates): + def filter_candidates(self, candidates, options): return [ candidate for candidate in candidates - if self._extract_local_specifier(candidate) - in self._options.computation_backends + if self._extract_local_specifier(candidate) in options.computation_backends ] - def make_sort_key(self, candidate): + def make_sort_key(self, candidate, options): return ( cb.ComputationBackend.from_str(self._extract_local_specifier(candidate)), candidate.version.base_version, @@ -127,6 +110,8 @@ def make_sort_key(self, candidate): # FIXME: check whether all of these are hosted on all channels +# If not, change `_TorchData` below to a more general class +# FIXME: check if they are valid at all for name in { "torch", "torch_model_archiver", @@ -139,17 +124,16 @@ def make_sort_key(self, candidate): "torchtext", "torchvision", }: - PatchedPackages._register(name)(_PatchedPyTorchPackage) + packages[name] = _PyTorchDistribution(name) 
-@PatchedPackages._register("torchdata") -class _TorchData(_PatchedPyTorchPackage): - def make_search_scope(self): - if self._options.channel == Channel.STABLE: +class _TorchData(_PyTorchDistribution): + def make_search_scope(self, options): + if options.channel == Channel.STABLE: return SearchScope( find_links=[], index_urls=["https://pypi.org/simple"], no_index=False, ) - return super().make_search_scope() + return super().make_search_scope(options) diff --git a/light_the_torch/_patch.py b/light_the_torch/_patch.py index a7547f7..f0bede9 100644 --- a/light_the_torch/_patch.py +++ b/light_the_torch/_patch.py @@ -13,7 +13,7 @@ import light_the_torch as ltt from . import _cb as cb -from ._packages import Channel, PatchedPackages +from ._packages import Channel, packages from ._utils import apply_fn_patch @@ -134,13 +134,11 @@ def from_pip_argv(cls, argv: List[str]): def apply_patches(argv): options = LttOptions.from_pip_argv(argv) - packages = PatchedPackages(options) - patches = [ patch_cli_version(), patch_cli_options(), - patch_link_collection(packages), - patch_candidate_selection(packages), + patch_link_collection(packages, options), + patch_candidate_selection(packages, options), ] with contextlib.ExitStack() as stack: @@ -187,7 +185,7 @@ def postprocessing(input, output): @contextlib.contextmanager -def patch_link_collection(packages): +def patch_link_collection(packages, options): @contextlib.contextmanager def context(input): package = packages.get(input.project_name) @@ -195,7 +193,9 @@ def context(input): yield return - with mock.patch.object(input.self, "search_scope", package.make_search_scope()): + with mock.patch.object( + input.self, "search_scope", package.make_search_scope(options) + ): yield with apply_fn_patch( @@ -211,7 +211,7 @@ def context(input): @contextlib.contextmanager -def patch_candidate_selection(packages): +def patch_candidate_selection(packages, options): def preprocessing(input): if not input.candidates: return @@ -223,14 +223,12 @@ 
def preprocessing(input): if not package: return - input.candidates = list(package.filter_candidates(input.candidates)) - - vanilla_sort_key = CandidateEvaluator._sort_key + input.candidates = list(package.filter_candidates(input.candidates, options)) def patched_sort_key(candidate_evaluator, candidate): package = packages.get(candidate.name) assert package - return package.make_sort_key(candidate) + return package.make_sort_key(candidate, options) @contextlib.contextmanager def context(input): From 3058d645a001c23ab5765e97242c3b0ada62bd5c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 23 Jan 2023 11:17:53 +0100 Subject: [PATCH 4/7] remove LTS --- light_the_torch/_packages.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/light_the_torch/_packages.py b/light_the_torch/_packages.py index c8eac1d..05ed797 100644 --- a/light_the_torch/_packages.py +++ b/light_the_torch/_packages.py @@ -34,7 +34,6 @@ class Channel(enum.Enum): STABLE = enum.auto() TEST = enum.auto() NIGHTLY = enum.auto() - LTS = enum.auto() @classmethod def from_str(cls, string): @@ -48,13 +47,6 @@ class _PyTorchDistribution(_Package): def _get_extra_index_urls(self, computation_backends, channel): if channel == Channel.STABLE: channel_paths = [""] - elif channel == Channel.LTS: - channel_paths = [ - f"lts/{major}.{minor}/" - for major, minor in [ - (1, 8), - ] - ] else: channel_paths = [f"{channel.name.lower()}/"] return [ From 9cbd8ff49a47c1f2a7b66f72c4775bd0cc7dea7b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 23 Jan 2023 11:29:07 +0100 Subject: [PATCH 5/7] create a patch package rather than module --- light_the_torch/_cli.py | 6 +- light_the_torch/_patch/__init__.py | 1 + light_the_torch/_patch/cli.py | 118 ++++++++++++++++++ .../{_packages.py => _patch/packages.py} | 17 +-- .../{_patch.py => _patch/patch.py} | 113 +---------------- .../{_utils.py => _patch/utils.py} | 0 tests/test_cli.py | 2 +- tests/test_computation_backend.py | 8 +- tests/test_smoke.py | 4 +- 9 files 
changed, 137 insertions(+), 132 deletions(-) create mode 100644 light_the_torch/_patch/__init__.py create mode 100644 light_the_torch/_patch/cli.py rename light_the_torch/{_packages.py => _patch/packages.py} (92%) rename light_the_torch/{_patch.py => _patch/patch.py} (50%) rename light_the_torch/{_utils.py => _patch/utils.py} (100%) diff --git a/light_the_torch/_cli.py b/light_the_torch/_cli.py index 2515ec7..c3b5227 100644 --- a/light_the_torch/_cli.py +++ b/light_the_torch/_cli.py @@ -1,5 +1,5 @@ -from pip._internal.cli.main import main as pip_main +from pip._internal.cli.main import main -from ._patch import patch +from ._patch import patch_pip_main -main = patch(pip_main) +main = patch_pip_main(main) diff --git a/light_the_torch/_patch/__init__.py b/light_the_torch/_patch/__init__.py new file mode 100644 index 0000000..7909415 --- /dev/null +++ b/light_the_torch/_patch/__init__.py @@ -0,0 +1 @@ +from .patch import patch_pip_main diff --git a/light_the_torch/_patch/cli.py b/light_the_torch/_patch/cli.py new file mode 100644 index 0000000..fe86294 --- /dev/null +++ b/light_the_torch/_patch/cli.py @@ -0,0 +1,118 @@ +import dataclasses +import enum +import optparse +import os +from typing import List, Set + +import light_the_torch._cb as cb + + +class Channel(enum.Enum): + STABLE = enum.auto() + TEST = enum.auto() + NIGHTLY = enum.auto() + + @classmethod + def from_str(cls, string): + return cls[string.upper()] + + +# adapted from https://stackoverflow.com/a/9307174 +class PassThroughOptionParser(optparse.OptionParser): + def __init__(self): + super().__init__(add_help_option=False) + + def _process_args(self, largs, rargs, values): + while rargs: + try: + super()._process_args(largs, rargs, values) + except (optparse.BadOptionError, optparse.AmbiguousOptionError) as error: + largs.append(error.opt_str) + + +@dataclasses.dataclass +class LttOptions: + computation_backends: Set[cb.ComputationBackend] = dataclasses.field( + default_factory=lambda: {cb.CPUBackend()} + 
) + channel: Channel = Channel.STABLE + + @staticmethod + def computation_backend_parser_options(): + return [ + optparse.Option( + "--pytorch-computation-backend", + help=( + "Computation backend for compiled PyTorch distributions, " + "e.g. 'cu102', 'cu115', or 'cpu'. " + "Multiple computation backends can be passed as a comma-separated " + "list, e.g 'cu102,cu113,cu116'. " + "If not specified, the computation backend is detected from the " + "available hardware, preferring CUDA over CPU." + ), + ), + optparse.Option( + "--cpuonly", + action="store_true", + help=( + "Shortcut for '--pytorch-computation-backend=cpu'. " + "If '--computation-backend' is used simultaneously, " + "it takes precedence over '--cpuonly'." + ), + ), + ] + + @staticmethod + def channel_parser_option() -> optparse.Option: + return optparse.Option( + "--pytorch-channel", + help=( + "Channel to download PyTorch distributions from, e.g. 'stable' , " + "'test', 'nightly' and 'lts'. " + "If not specified, defaults to 'stable' unless '--pre' is given in " + "which case it defaults to 'test'." 
+ ), + ) + + @staticmethod + def _parse(argv): + parser = PassThroughOptionParser() + + for option in LttOptions.computation_backend_parser_options(): + parser.add_option(option) + parser.add_option(LttOptions.channel_parser_option()) + parser.add_option("--pre", dest="pre", action="store_true") + + opts, _ = parser.parse_args(argv) + return opts + + @classmethod + def from_pip_argv(cls, argv: List[str]): + if not argv or argv[0] != "install": + return cls() + + opts = cls._parse(argv) + + if opts.pytorch_computation_backend is not None: + cbs = { + cb.ComputationBackend.from_str(string.strip()) + for string in opts.pytorch_computation_backend.split(",") + } + elif opts.cpuonly: + cbs = {cb.CPUBackend()} + elif "LTT_PYTORCH_COMPUTATION_BACKEND" in os.environ: + cbs = { + cb.ComputationBackend.from_str(string.strip()) + for string in os.environ["LTT_PYTORCH_COMPUTATION_BACKEND"].split(",") + } + else: + cbs = cb.detect_compatible_computation_backends() + + if opts.pytorch_channel is not None: + channel = Channel.from_str(opts.pytorch_channel) + elif opts.pre: + channel = Channel.TEST + else: + channel = Channel.STABLE + + return cls(cbs, channel) diff --git a/light_the_torch/_packages.py b/light_the_torch/_patch/packages.py similarity index 92% rename from light_the_torch/_packages.py rename to light_the_torch/_patch/packages.py index 05ed797..44c11e3 100644 --- a/light_the_torch/_packages.py +++ b/light_the_torch/_patch/packages.py @@ -1,12 +1,13 @@ import abc import dataclasses -import enum import itertools import re from pip._internal.models.search_scope import SearchScope -from . 
import _cb as cb +import light_the_torch._cb as cb + +from .cli import Channel __all__ = ["packages"] @@ -28,18 +29,6 @@ def make_sort_key(self, candidate, options): pass -# FIXME: move this to cli patch -# create patch.cli and patch.packages -class Channel(enum.Enum): - STABLE = enum.auto() - TEST = enum.auto() - NIGHTLY = enum.auto() - - @classmethod - def from_str(cls, string): - return cls[string.upper()] - - packages = {} diff --git a/light_the_torch/_patch.py b/light_the_torch/_patch/patch.py similarity index 50% rename from light_the_torch/_patch.py rename to light_the_torch/_patch/patch.py index f0bede9..20ba1c3 100644 --- a/light_the_torch/_patch.py +++ b/light_the_torch/_patch/patch.py @@ -1,23 +1,19 @@ import contextlib -import dataclasses import functools -import optparse -import os import sys import unittest.mock -from typing import List, Set from unittest import mock import pip._internal.cli.cmdoptions from pip._internal.index.package_finder import CandidateEvaluator import light_the_torch as ltt -from . 
import _cb as cb -from ._packages import Channel, packages -from ._utils import apply_fn_patch +from .cli import LttOptions +from .packages import packages +from .utils import apply_fn_patch -def patch(pip_main): +def patch_pip_main(pip_main): @functools.wraps(pip_main) def wrapper(argv=None): if argv is None: @@ -29,107 +25,6 @@ def wrapper(argv=None): return wrapper -# adapted from https://stackoverflow.com/a/9307174 -class PassThroughOptionParser(optparse.OptionParser): - def __init__(self): - super().__init__(add_help_option=False) - - def _process_args(self, largs, rargs, values): - while rargs: - try: - super()._process_args(largs, rargs, values) - except (optparse.BadOptionError, optparse.AmbiguousOptionError) as error: - largs.append(error.opt_str) - - -@dataclasses.dataclass -class LttOptions: - computation_backends: Set[cb.ComputationBackend] = dataclasses.field( - default_factory=lambda: {cb.CPUBackend()} - ) - channel: Channel = Channel.STABLE - - @staticmethod - def computation_backend_parser_options(): - return [ - optparse.Option( - "--pytorch-computation-backend", - help=( - "Computation backend for compiled PyTorch distributions, " - "e.g. 'cu102', 'cu115', or 'cpu'. " - "Multiple computation backends can be passed as a comma-separated " - "list, e.g 'cu102,cu113,cu116'. " - "If not specified, the computation backend is detected from the " - "available hardware, preferring CUDA over CPU." - ), - ), - optparse.Option( - "--cpuonly", - action="store_true", - help=( - "Shortcut for '--pytorch-computation-backend=cpu'. " - "If '--computation-backend' is used simultaneously, " - "it takes precedence over '--cpuonly'." - ), - ), - ] - - @staticmethod - def channel_parser_option() -> optparse.Option: - return optparse.Option( - "--pytorch-channel", - help=( - "Channel to download PyTorch distributions from, e.g. 'stable' , " - "'test', 'nightly' and 'lts'. 
"
-                "If not specified, defaults to 'stable' unless '--pre' is given in "
-                "which case it defaults to 'test'."
-            ),
-        )
-
-    @staticmethod
-    def _parse(argv):
-        parser = PassThroughOptionParser()
-
-        for option in LttOptions.computation_backend_parser_options():
-            parser.add_option(option)
-        parser.add_option(LttOptions.channel_parser_option())
-        parser.add_option("--pre", dest="pre", action="store_true")
-
-        opts, _ = parser.parse_args(argv)
-        return opts
-
-    @classmethod
-    def from_pip_argv(cls, argv: List[str]):
-        if not argv or argv[0] != "install":
-            return cls()
-
-        opts = cls._parse(argv)
-
-        if opts.pytorch_computation_backend is not None:
-            cbs = {
-                cb.ComputationBackend.from_str(string.strip())
-                for string in opts.pytorch_computation_backend.split(",")
-            }
-        elif opts.cpuonly:
-            cbs = {cb.CPUBackend()}
-        elif "LTT_PYTORCH_COMPUTATION_BACKEND" in os.environ:
-            cbs = {
-                cb.ComputationBackend.from_str(string.strip())
-                for string in os.environ["LTT_PYTORCH_COMPUTATION_BACKEND"].split(",")
-            }
-        else:
-            cbs = cb.detect_compatible_computation_backends()
-
-        if opts.pytorch_channel is not None:
-            channel = Channel.from_str(opts.pytorch_channel)
-        elif opts.pre:
-            channel = Channel.TEST
-        else:
-            channel = Channel.STABLE
-
-        return cls(cbs, channel)
-
-
 @contextlib.contextmanager
 def apply_patches(argv):
     options = LttOptions.from_pip_argv(argv)
diff --git a/light_the_torch/_utils.py b/light_the_torch/_patch/utils.py
similarity index 100%
rename from light_the_torch/_utils.py
rename to light_the_torch/_patch/utils.py
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 766d05b..ef53b15 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -77,7 +77,7 @@ def check_fn(text):
 @pytest.fixture
 def set_argv(mocker):
     def patch(*options):
-        return mocker.patch.object(sys, "argv", ["ltt", *options])
+        return mocker.patch.object(sys, "argv", ["ltt", *options])
 
     return patch
 
 
diff --git a/tests/test_computation_backend.py b/tests/test_computation_backend.py
index b26bca3..95b050b 100644
--- a/tests/test_computation_backend.py
+++ b/tests/test_computation_backend.py
@@ -152,7 +152,7 @@ def test_cuda_vs_rocm(self):
 @pytest.fixture
 def patch_nvidia_driver_version(mocker):
     def factory(version):
-        return mocker.patch(
+        return mocker.patch(
             "light_the_torch._cb.subprocess.run",
             return_value=SimpleNamespace(stdout=f"driver_version\n{version}"),
         )
@@ -208,7 +208,7 @@ def cuda_backends_params():
 class TestDetectCompatibleComputationBackends:
     def test_no_nvidia_driver(self, mocker):
-        mocker.patch(
+        mocker.patch(
             "light_the_torch._cb.subprocess.run",
             side_effect=subprocess.CalledProcessError(1, ""),
         )
@@ -224,7 +224,9 @@ def test_cuda_backends(
         nvidia_driver_version,
         compatible_cuda_backends,
     ):
-        mocker.patch("light_the_torch._cb.platform.system", return_value=system)
+        mocker.patch(
+            "light_the_torch._cb.platform.system", return_value=system
+        )
         patch_nvidia_driver_version(nvidia_driver_version)
 
         backends = cb.detect_compatible_computation_backends()
diff --git a/tests/test_smoke.py b/tests/test_smoke.py
index 3a69e0a..d9444a3 100644
--- a/tests/test_smoke.py
+++ b/tests/test_smoke.py
@@ -103,12 +103,12 @@ def patched_import(name, globals, locals, fromlist, level):
             return __import__(name, globals, locals, fromlist, level)
 
-    mocker.patch.object(builtins, "__import__", new=patched_import)
+    mocker.patch.object(builtins, "__import__", new=patched_import)
     values = {
         name: module for name, module in sys.modules.items() if retain_condition(name)
     }
-    mocker.patch.dict(sys.modules, clear=True, values=values)
+    mocker.patch.dict(sys.modules, clear=True, values=values)
 
 
 def test_version_not_installed(mocker):

From 7df4d050f2bb129614eee75044a85887754859da Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 23 Jan 2023 14:05:15 +0100
Subject: [PATCH 6/7] partially port #117

---
 light_the_torch/_patch/utils.py | 58 ++++++++++++++++++++++++---------
 1 file changed, 42 insertions(+), 16 
deletions(-) diff --git a/light_the_torch/_patch/utils.py b/light_the_torch/_patch/utils.py index 19fa064..8cc3b3b 100644 --- a/light_the_torch/_patch/utils.py +++ b/light_the_torch/_patch/utils.py @@ -6,17 +6,31 @@ from unittest import mock +from pip._vendor.packaging.requirements import Requirement -class InternalError(RuntimeError): - def __init__(self) -> None: - # TODO: check against pip version - # TODO: fix wording - msg = ( - "Unexpected internal pytorch-pip-shim error. If you ever encounter this " - "message during normal operation, please submit a bug report at " - "https://github.com/pmeier/pytorch-pip-shim/issues" +from light_the_torch._compat import importlib_metadata + + +class UnexpectedInternalError(Exception): + def __init__(self, msg) -> None: + actual_pip_version = Requirement(f"pip=={importlib_metadata.version('pip')}") + required_pip_version = next( + requirement + for requirement in ( + Requirement(requirement_string) + for requirement_string in importlib_metadata.requires("light_the_torch") + ) + if requirement.name == "pip" + ) + super().__init__( + f"{msg}\n\n" + f"This can happen when the actual pip version (`{actual_pip_version}`) " + f"and the one required by light-the-torch (`{required_pip_version}`) " + f"are out of sync.\n" + f"If that is the case, please reinstall light-the-torch. 
" + f"Otherwise, please submit a bug report at " + f"https://github.com/pmeier/light-the-torch/issues" ) - super().__init__(msg) class Input(dict): @@ -77,7 +91,7 @@ def apply_fn_patch( postprocessing=lambda input, output: output, ): target = ".".join(parts) - fn = import_fn(target) + fn = import_obj(target) @functools.wraps(fn) def new(*args, **kwargs): @@ -93,7 +107,7 @@ def new(*args, **kwargs): yield -def import_fn(target: str): +def import_obj(target: str): attrs = [] name = target while name: @@ -101,13 +115,25 @@ def import_fn(target: str): module = importlib.import_module(name) break except ImportError: - name, attr = name.rsplit(".", 1) - attrs.append(attr) + try: + name, attr = name.rsplit(".", 1) + except ValueError: + attr = name + name = "" + attrs.insert(0, attr) else: - raise InternalError + raise UnexpectedInternalError( + f"Tried to import `{target}`, " + f"but the top-level namespace `{attrs[0]}` doesn't seem to be a module." + ) obj = module - for attr in attrs[::-1]: - obj = getattr(obj, attr) + for attr in attrs: + try: + obj = getattr(obj, attr) + except AttributeError: + raise UnexpectedInternalError( + f"Failed to access `{attr}` from `{obj.__name__}`" + ) from None return obj From c01c79a796a2c4bc370a90c72274bb4512678f3a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 23 Jan 2023 14:06:12 +0100 Subject: [PATCH 7/7] fix --- light_the_torch/_compat.py | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 light_the_torch/_compat.py diff --git a/light_the_torch/_compat.py b/light_the_torch/_compat.py new file mode 100644 index 0000000..7aeb237 --- /dev/null +++ b/light_the_torch/_compat.py @@ -0,0 +1,8 @@ +import sys + +if sys.version_info >= (3, 8): + import importlib.metadata as importlib_metadata +else: + import importlib_metadata + +__all__ = ["importlib_metadata"]