diff --git a/docker/pytorch_base/Dockerfile.prebuilt b/docker/pytorch_base/Dockerfile.prebuilt index 563f5183c..35a5d1a87 100644 --- a/docker/pytorch_base/Dockerfile.prebuilt +++ b/docker/pytorch_base/Dockerfile.prebuilt @@ -35,9 +35,14 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ && /opt/conda/bin/pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} \ tensorrt torch-tensorrt==${TORCH_TENSORRT_VERSION} --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | tr -d .) \ # && /opt/conda/bin/pip install "nvidia-modelopt" --extra-index-url https://pypi.nvidia.com \ + && find /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia -name '*.so*' -path '*/lib/*' -printf '%h\n' 2>/dev/null | sort -u > /etc/ld.so.conf.d/nvidia-pip.conf \ + && echo "/opt/conda/lib/python${PYTHON_VERSION}/site-packages/tensorrt_libs" >> /etc/ld.so.conf.d/nvidia-pip.conf \ + && ldconfig \ && apt-mark auto '.*' > /dev/null \ && apt-mark manual $preexistingAptMark ca-certificates > /dev/null \ - && find /usr/local/lib /opt/conda -type f -executable -exec ldd '{}' ';' \ + && find /usr/local/lib /opt/conda \ + \( -path '*/site-packages/nvidia' -o -path '*/site-packages/torch' -o -path '*/site-packages/tensorrt*' -o -path '*/site-packages/triton' \) -prune \ + -o -type f -executable -exec ldd '{}' ';' \ | awk '/=>/ { print $(NF-1) }' \ | sort -u \ | grep -v not \ @@ -55,4 +60,4 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ /usr/share/doc \ /usr/share/doc-base -CMD ["/bin/bash"] \ No newline at end of file +CMD ["/bin/bash"] diff --git a/tests/unit/mazepa/test_semaphores.py b/tests/unit/mazepa/test_semaphores.py index 439ea81e8..732ccedc1 100644 --- a/tests/unit/mazepa/test_semaphores.py +++ b/tests/unit/mazepa/test_semaphores.py @@ -24,7 +24,7 @@ @pytest.fixture(autouse=True) def cleanup_semaphores(): yield - sema_types: List[SemaphoreType] = ["read", "write", "cuda", "cpu"] + sema_types: List[SemaphoreType] = list(get_args(SemaphoreType)) for name in sema_types: try: # two unlinks in case grandparent semaphore exists @@ -32,6 +32,10 @@ def cleanup_semaphores(): semaphore(name).unlink() except: pass + try: + TimingTracker(name, pid=os.getpid()).unlink() + except: + pass def test_default_semaphore_init(): diff --git a/zetta_utils/convnet/utils.py b/zetta_utils/convnet/utils.py index 56f49884d..31fcb97c5 100644 --- a/zetta_utils/convnet/utils.py +++ b/zetta_utils/convnet/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import io +import os from typing import Literal, Optional, Sequence, Union, overload import cachetools @@ -8,28 +9,50 @@ import onnx import onnx2torch import torch +import xxhash from numpy import typing as npt from typeguard import typechecked from zetta_utils import builder, log, tensor_ops +from zetta_utils.mazepa import semaphore logger = log.get_logger("zetta_utils") +TENSORRT_AVAILABLE = False +try: + import torch_tensorrt # pylint: disable=import-error + + TENSORRT_AVAILABLE = True # pragma: no cover +except (ImportError, OSError) as e: + logger.info(f"torch_tensorrt is not available: {e}") + + @builder.register("load_model") @typechecked def load_model( - path: str, device: Union[str, torch.device] = "cpu", use_cache: bool = False + path: str, + device: Union[str, torch.device] = "cpu", + use_cache: bool = False, + input_shape: Sequence[int] | None = None, + tensorrt_enabled: bool = False, + tensorrt_cache_dir: str = ".", # defaults to the current working directory ) -> torch.nn.Module: # pragma: no cover if use_cache: - result = _load_model_cached(path, device) + result = _load_model_cached( + path, device, input_shape, tensorrt_enabled, tensorrt_cache_dir + ) else: - result = _load_model(path, device) + result = _load_model(path, device, input_shape, tensorrt_enabled, tensorrt_cache_dir) return result def _load_model( - path: str, device: Union[str, torch.device] = "cpu" + path: str, + device: Union[str, torch.device] = "cpu", + input_shape: Sequence[int] | None = None, + tensorrt_enabled: bool = False, + tensorrt_cache_dir: str = ".", ) -> torch.nn.Module: # pragma: no cover logger.debug(f"Loading model from '{path}'") if path.endswith(".json"): @@ -40,8 +63,55 @@ def _load_model( elif path.endswith(".onnx"): with fsspec.open(path, "rb") as f: result = onnx2torch.convert(onnx.load(f)).to(device) + elif path.endswith(".ts"): + # load a cached TensorRT model + result = torch.export.load(path).module() else: raise ValueError(f"Unsupported file format: {path}") + + if tensorrt_enabled: + if not TENSORRT_AVAILABLE: + raise RuntimeError("torch_tensorrt is not installed!") + + with semaphore("tensorrt"): + # TensorRT should not be compiled concurrently by many threads + # or will run out of memory + # Ideally, only by one thread and the others then load the cached model + assert input_shape is not None # mypy + + gpu_capability = torch.cuda.get_device_capability(device) + trt_fname = ( + str( + xxhash.xxh128( + str((path, tuple(input_shape), gpu_capability)).encode("utf-8") + ).hexdigest() + ) + + ".trt.ep" + ) + cache_path = os.path.join(tensorrt_cache_dir, trt_fname) + + # Try to load the optimized model from cache + try: + result = torch_tensorrt.load(cache_path).module() + logger.info(f"Loaded cached TensorRT model: {cache_path}") + return result + except Exception as e: # pylint: disable=broad-exception-caught + logger.info(f"Cache not found or invalid, compiling TensorRT model: {e}") + + example_in = torch.rand(input_shape).to(device=device) + + try: + compiled = torch_tensorrt.compile( + result, + inputs=[example_in], + enabled_precisions={torch.float, torch.half}, + ) + torch_tensorrt.save(compiled, cache_path) + logger.info(f"Compiled TensorRT model saved to cache: {cache_path}") + result = compiled + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning(f"TensorRT compilation failed, falling back to eager mode: {e}") + return result @@ -112,16 +182,34 @@ def load_and_run_model( @typechecked -def load_and_run_model(path, data_in, device=None, use_cache=True): # pragma: no cover +def load_and_run_model( + path, + data_in, + device=None, + use_cache=True, + tensorrt_enabled: bool = False, + tensorrt_cache_dir: str = ".", +): # pragma: no cover if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" - model = load_model(path=path, device=device, use_cache=use_cache) + model = load_model( + path=path, + device=device, + use_cache=use_cache, + input_shape=data_in.shape, + tensorrt_enabled=tensorrt_enabled, + tensorrt_cache_dir=tensorrt_cache_dir, + ) autocast_device = device.type if isinstance(device, torch.device) else str(device) with torch.inference_mode(): # uses less memory when used with JITs with torch.autocast(device_type=autocast_device): - output = model(tensor_ops.convert.to_torch(data_in, device=device)) + gpu_input = tensor_ops.convert.to_torch(data_in, device=device) + output = model(gpu_input) + del gpu_input output = tensor_ops.convert.astype(output, reference=data_in, cast=True) + torch.cuda.empty_cache() + return output diff --git a/zetta_utils/mazepa/semaphores.py b/zetta_utils/mazepa/semaphores.py index 4894f6f41..bf0bcb7f4 100644 --- a/zetta_utils/mazepa/semaphores.py +++ b/zetta_utils/mazepa/semaphores.py @@ -18,7 +18,7 @@ from zetta_utils.common.pprint import lrpad logger = log.get_logger("mazepa") -SemaphoreType = Literal["read", "write", "cuda", "cpu"] +SemaphoreType = Literal["read", "write", "cuda", "cpu", "tensorrt"] DEFAULT_SEMA_COUNT = 1 TIMING_FORMAT = "dddd" # wait_time, lease_time, lease_count, start_time @@ -153,15 +153,15 @@ def configure_semaphores( Context manager for creating and destroying semaphores. """ - sema_types_to_check: List[SemaphoreType] = ["read", "write", "cuda", "cpu"] + sema_types_to_check: List[SemaphoreType] = ["read", "write", "cuda", "cpu", "tensorrt"] if semaphores_spec is not None: for name in semaphores_spec: if name not in get_args(SemaphoreType): raise ValueError(f"`{name}` is not a valid semaphore type.") try: - for sema_type in sema_types_to_check: + _required: tuple[SemaphoreType, ...] = ("read", "write", "cuda", "cpu") + for sema_type in _required: assert semaphores_spec[sema_type] >= 0 - semaphores_spec_ = semaphores_spec except KeyError as e: raise ValueError( "`semaphores_spec` given to `execute_with_pool` must contain " @@ -169,7 +169,7 @@ def configure_semaphores( ) from e except AssertionError as e: raise ValueError("Number of semaphores must be nonnegative.") from e - semaphores_spec_ = semaphores_spec + semaphores_spec_ = {"tensorrt": DEFAULT_SEMA_COUNT, **semaphores_spec} else: semaphores_spec_ = {name: DEFAULT_SEMA_COUNT for name in sema_types_to_check}