From 251de79b58495347e81ed99f05dda6f190c95839 Mon Sep 17 00:00:00 2001 From: trivoldus28 Date: Mon, 5 Jan 2026 04:42:10 +0000 Subject: [PATCH 1/7] add tensorrt support --- zetta_utils/convnet/utils.py | 97 +++++++++++++++++++++++++++++--- zetta_utils/mazepa/semaphores.py | 9 +-- 2 files changed, 95 insertions(+), 11 deletions(-) diff --git a/zetta_utils/convnet/utils.py b/zetta_utils/convnet/utils.py index 56f49884d..56590da2a 100644 --- a/zetta_utils/convnet/utils.py +++ b/zetta_utils/convnet/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import io +import os from typing import Literal, Optional, Sequence, Union, overload import cachetools @@ -8,28 +9,50 @@ import onnx import onnx2torch import torch +import xxhash from numpy import typing as npt from typeguard import typechecked from zetta_utils import builder, log, tensor_ops +from zetta_utils.mazepa import semaphore logger = log.get_logger("zetta_utils") +TENSORRT_AVAILABLE = False +try: + import torch_tensorrt + + TENSORRT_AVAILABLE = True +except ImportError as e: + print(f"torch_tensorrt is not available: {e}") + + @builder.register("load_model") -@typechecked +# @typechecked def load_model( - path: str, device: Union[str, torch.device] = "cpu", use_cache: bool = False + path: str, + device: Union[str, torch.device] = "cpu", + use_cache: bool = False, + input_shape: Sequence[int] | None = None, + tensorrt_enabled: bool = False, + tensorrt_cache_dir: str = ".", # defaults to the current working directory ) -> torch.nn.Module: # pragma: no cover if use_cache: - result = _load_model_cached(path, device) + result = _load_model_cached( + path, device, input_shape, tensorrt_enabled, tensorrt_cache_dir + ) else: - result = _load_model(path, device) + result = _load_model(path, device, input_shape, tensorrt_enabled, tensorrt_cache_dir) return result def _load_model( - path: str, device: Union[str, torch.device] = "cpu" + path: str, + device: Union[str, torch.device] = "cpu", + input_shape: Sequence[int] | None = None, + tensorrt_enabled: bool = False, + tensorrt_cache_dir: str = ".", ) -> torch.nn.Module: # pragma: no cover logger.debug(f"Loading model from '{path}'") if path.endswith(".json"): @@ -40,8 +63,53 @@ def _load_model( elif path.endswith(".onnx"): with fsspec.open(path, "rb") as f: result = onnx2torch.convert(onnx.load(f)).to(device) + elif path.endswith(".ts"): + # load a cached TensorRT model + result = torch.export.load(path) else: raise ValueError(f"Unsupported file format: {path}") + + if tensorrt_enabled: + if not TENSORRT_AVAILABLE: + raise RuntimeError("torch_tensorrt is not installed!") + + with semaphore("trt_compilation"): + # TensorRT should not be compiled concurrently by many threads or will run out of memory + # Ideally, only by one thread and the others then load the cached model + + trt_fname = ( + str(xxhash.xxh128(str((path, tuple(input_shape))).encode("utf-8")).hexdigest()) + + ".trt.ts" + ) + cache_path = os.path.join(tensorrt_cache_dir, trt_fname) + + # Try to load the optimized model from cache + try: + with fsspec.open(cache_path, "rb") as f: + return torch_tensorrt.load(f) + except FileNotFoundError: + print(f"Cache not found. Compiling TensorRT model: {cache_path}") + except Exception as e: + print(f"Error loading TensorRT model from cache: {e}") + + example_in = torch.rand(input_shape).to(device=device) + + with torch.inference_mode(): + trace = torch.jit.trace(result, example_in) + + result = torch_tensorrt.ts.compile( + trace, + inputs=[example_in], + truncate_long_and_double=True, + enabled_precisions={torch.float, torch.half}, + debug=False, + ) + + # save optimized model + with fsspec.open(cache_path, "wb") as f: + torch_tensorrt.save(result, f, output_format="torchscript", inputs=[example_in]) + print(f"Compiled TensorRT model saved to cache: {cache_path}") + return result @@ -112,16 +180,31 @@ def load_and_run_model( @typechecked -def load_and_run_model(path, data_in, device=None, use_cache=True): # pragma: no cover +def load_and_run_model( + path, + data_in, + device=None, + use_cache=True, + tensorrt_enabled: bool = False, + tensorrt_cache_dir: str = ".", +): # pragma: no cover if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" - model = load_model(path=path, device=device, use_cache=use_cache) + model = load_model( + path=path, + device=device, + use_cache=use_cache, + input_shape=data_in.shape, + tensorrt_enabled=tensorrt_enabled, + tensorrt_cache_dir=tensorrt_cache_dir, + ) autocast_device = device.type if isinstance(device, torch.device) else str(device) with torch.inference_mode(): # uses less memory when used with JITs with torch.autocast(device_type=autocast_device): output = model(tensor_ops.convert.to_torch(data_in, device=device)) output = tensor_ops.convert.astype(output, reference=data_in, cast=True) + return output diff --git a/zetta_utils/mazepa/semaphores.py b/zetta_utils/mazepa/semaphores.py index 4894f6f41..5d386cac8 100644 --- a/zetta_utils/mazepa/semaphores.py +++ b/zetta_utils/mazepa/semaphores.py @@ -18,7 +18,7 @@ from zetta_utils.common.pprint import lrpad logger = log.get_logger("mazepa") -SemaphoreType = Literal["read", "write", "cuda", "cpu"] +SemaphoreType = Literal["read", "write", "cuda", "cpu", "trt_compilation"] DEFAULT_SEMA_COUNT = 1 TIMING_FORMAT = "dddd" # wait_time, lease_time, lease_count, start_time @@ -153,14 +153,15 @@ def configure_semaphores( Context manager for creating and destroying semaphores. """ - sema_types_to_check: List[SemaphoreType] = ["read", "write", "cuda", "cpu"] + sema_types_to_check: List[SemaphoreType] = ["read", "write", "cuda", "cpu", "trt_compilation"] if semaphores_spec is not None: for name in semaphores_spec: if name not in get_args(SemaphoreType): raise ValueError(f"`{name}` is not a valid semaphore type.") try: - for sema_type in sema_types_to_check: - assert semaphores_spec[sema_type] >= 0 + # TODO: need to make trt_compilation optional + # for sema_type in sema_types_to_check: + # assert semaphores_spec[sema_type] >= 0 semaphores_spec_ = semaphores_spec except KeyError as e: raise ValueError( From 666d8bdef50df3a0b51cd2b6dc3d7d5a8475777b Mon Sep 17 00:00:00 2001 From: Dodam Ih Date: Mon, 9 Mar 2026 11:15:27 -0700 Subject: [PATCH 2/7] feat: TensorRT shared library discovery in Docker Add ldconfig for nvidia/tensorrt .so paths, prune large site-packages dirs from ldd scan to speed up build. Co-Authored-By: Claude Opus 4.6 --- docker/pytorch_base/Dockerfile.prebuilt | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/docker/pytorch_base/Dockerfile.prebuilt b/docker/pytorch_base/Dockerfile.prebuilt index 563f5183c..35a5d1a87 100644 --- a/docker/pytorch_base/Dockerfile.prebuilt +++ b/docker/pytorch_base/Dockerfile.prebuilt @@ -35,9 +35,14 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ && /opt/conda/bin/pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} \ tensorrt torch-tensorrt==${TORCH_TENSORRT_VERSION} --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | tr -d .) \ # && /opt/conda/bin/pip install "nvidia-modelopt" --extra-index-url https://pypi.nvidia.com \ + && find /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia -name '*.so*' -path '*/lib/*' -printf '%h\n' 2>/dev/null | sort -u > /etc/ld.so.conf.d/nvidia-pip.conf \ + && echo "/opt/conda/lib/python${PYTHON_VERSION}/site-packages/tensorrt_libs" >> /etc/ld.so.conf.d/nvidia-pip.conf \ + && ldconfig \ && apt-mark auto '.*' > /dev/null \ && apt-mark manual $preexistingAptMark ca-certificates > /dev/null \ - && find /usr/local/lib /opt/conda -type f -executable -exec ldd '{}' ';' \ + && find /usr/local/lib /opt/conda \ + \( -path '*/site-packages/nvidia' -o -path '*/site-packages/torch' -o -path '*/site-packages/tensorrt*' -o -path '*/site-packages/triton' \) -prune \ + -o -type f -executable -exec ldd '{}' ';' \ | awk '/=>/ { print $(NF-1) }' \ | sort -u \ | grep -v not \ @@ -55,4 +60,4 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ /usr/share/doc \ /usr/share/doc-base -CMD ["/bin/bash"] \ No newline at end of file +CMD ["/bin/bash"] From f18787bed2e63718d6a4f2279f32a53cd2b53bbe Mon Sep 17 00:00:00 2001 From: Dodam Ih Date: Mon, 9 Mar 2026 11:15:32 -0700 Subject: [PATCH 3/7] feat: update TensorRT compilation to new API Switch from torchscript-based TRT compilation to torch_tensorrt.compile with ExportedProgram (.ep) format. Add GPU memory cleanup, improve error handling with fallback to eager mode. Co-Authored-By: Claude Opus 4.6 --- zetta_utils/convnet/utils.py | 48 +++++++++++++++++------------------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/zetta_utils/convnet/utils.py b/zetta_utils/convnet/utils.py index 56590da2a..68d6e310d 100644 --- a/zetta_utils/convnet/utils.py +++ b/zetta_utils/convnet/utils.py @@ -24,8 +24,8 @@ import torch_tensorrt TENSORRT_AVAILABLE = True -except ImportError as e: - print(f"torch_tensorrt is not available: {e}") +except Exception as e: + logger.info(f"torch_tensorrt is not available: {e}") @builder.register("load_model") @@ -73,42 +73,37 @@ def _load_model( if not TENSORRT_AVAILABLE: raise RuntimeError("torch_tensorrt is not installed!") - with semaphore("trt_compilation"): + with semaphore("tensorrt"): # TensorRT should not be compiled concurrently by many threads or will run out of memory # Ideally, only by one thread and the others then load the cached model trt_fname = ( str(xxhash.xxh128(str((path, tuple(input_shape))).encode("utf-8")).hexdigest()) - + ".trt.ts" + + ".trt.ep" ) cache_path = os.path.join(tensorrt_cache_dir, trt_fname) # Try to load the optimized model from cache try: - with fsspec.open(cache_path, "rb") as f: - return torch_tensorrt.load(f) - except FileNotFoundError: - print(f"Cache not found. Compiling TensorRT model: {cache_path}") + result = torch_tensorrt.load(cache_path).module() + logger.info(f"Loaded cached TensorRT model: {cache_path}") + return result except Exception as e: - print(f"Error loading TensorRT model from cache: {e}") + logger.info(f"Cache not found or invalid, compiling TensorRT model: {e}") example_in = torch.rand(input_shape).to(device=device) - with torch.inference_mode(): - trace = torch.jit.trace(result, example_in) - - result = torch_tensorrt.ts.compile( - trace, - inputs=[example_in], - truncate_long_and_double=True, - enabled_precisions={torch.float, torch.half}, - debug=False, - ) - - # save optimized model - with fsspec.open(cache_path, "wb") as f: - torch_tensorrt.save(result, f, output_format="torchscript", inputs=[example_in]) - print(f"Compiled TensorRT model saved to cache: {cache_path}") + try: + compiled = torch_tensorrt.compile( + result, + inputs=[example_in], + enabled_precisions={torch.float, torch.half}, + ) + torch_tensorrt.save(compiled, cache_path) + logger.info(f"Compiled TensorRT model saved to cache: {cache_path}") + result = compiled + except Exception as e: + logger.warning(f"TensorRT compilation failed, falling back to eager mode: {e}") return result @@ -204,7 +199,10 @@ def load_and_run_model( autocast_device = device.type if isinstance(device, torch.device) else str(device) with torch.inference_mode(): # uses less memory when used with JITs with torch.autocast(device_type=autocast_device): - output = model(tensor_ops.convert.to_torch(data_in, device=device)) + gpu_input = tensor_ops.convert.to_torch(data_in, device=device) + output = model(gpu_input) + del gpu_input output = tensor_ops.convert.astype(output, reference=data_in, cast=True) + torch.cuda.empty_cache() return output From 0b7b403971464558792187eb47deb005f951ec52 Mon Sep 17 00:00:00 2001 From: Dodam Ih Date: Mon, 9 Mar 2026 11:15:36 -0700 Subject: [PATCH 4/7] refactor: rename trt_compilation semaphore to tensorrt Make tensorrt semaphore optional with a default count, so existing specs without it don't break. Co-Authored-By: Claude Opus 4.6 --- zetta_utils/mazepa/semaphores.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/zetta_utils/mazepa/semaphores.py b/zetta_utils/mazepa/semaphores.py index 5d386cac8..76e73b311 100644 --- a/zetta_utils/mazepa/semaphores.py +++ b/zetta_utils/mazepa/semaphores.py @@ -18,7 +18,7 @@ from zetta_utils.common.pprint import lrpad logger = log.get_logger("mazepa") -SemaphoreType = Literal["read", "write", "cuda", "cpu", "trt_compilation"] +SemaphoreType = Literal["read", "write", "cuda", "cpu", "tensorrt"] DEFAULT_SEMA_COUNT = 1 TIMING_FORMAT = "dddd" # wait_time, lease_time, lease_count, start_time @@ -153,16 +153,14 @@ def configure_semaphores( Context manager for creating and destroying semaphores. """ - sema_types_to_check: List[SemaphoreType] = ["read", "write", "cuda", "cpu", "trt_compilation"] + sema_types_to_check: List[SemaphoreType] = ["read", "write", "cuda", "cpu", "tensorrt"] if semaphores_spec is not None: for name in semaphores_spec: if name not in get_args(SemaphoreType): raise ValueError(f"`{name}` is not a valid semaphore type.") try: - # TODO: need to make trt_compilation optional - # for sema_type in sema_types_to_check: - # assert semaphores_spec[sema_type] >= 0 - semaphores_spec_ = semaphores_spec + for sema_type in ("read", "write", "cuda", "cpu"): + assert semaphores_spec[sema_type] >= 0 except KeyError as e: raise ValueError( "`semaphores_spec` given to `execute_with_pool` must contain " @@ -170,7 +168,7 @@ def configure_semaphores( ) from e except AssertionError as e: raise ValueError("Number of semaphores must be nonnegative.") from e - semaphores_spec_ = semaphores_spec + semaphores_spec_ = {"tensorrt": DEFAULT_SEMA_COUNT, **semaphores_spec} else: semaphores_spec_ = {name: DEFAULT_SEMA_COUNT for name in sema_types_to_check} From faa7d711c972131c9e590d949956ce0b5ec7fccb Mon Sep 17 00:00:00 2001 From: Dodam Ih Date: Tue, 24 Mar 2026 16:15:50 -0700 Subject: [PATCH 5/7] fix: extract modules for type checks --- zetta_utils/convnet/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zetta_utils/convnet/utils.py b/zetta_utils/convnet/utils.py index 68d6e310d..aebbddbb4 100644 --- a/zetta_utils/convnet/utils.py +++ b/zetta_utils/convnet/utils.py @@ -29,7 +29,7 @@ @builder.register("load_model") -# @typechecked +@typechecked def load_model( path: str, device: Union[str, torch.device] = "cpu", @@ -65,7 +65,7 @@ def _load_model( result = onnx2torch.convert(onnx.load(f)).to(device) elif path.endswith(".ts"): # load a cached TensorRT model - result = torch.export.load(path) + result = torch.export.load(path).module() else: raise ValueError(f"Unsupported file format: {path}") From b9eeebb81a25ad2b3da486333ca869c679e1788a Mon Sep 17 00:00:00 2001 From: Dodam Ih Date: Tue, 24 Mar 2026 22:18:28 -0700 Subject: [PATCH 6/7] chore: mypy / pylint fix --- zetta_utils/convnet/utils.py | 12 +++++++----- zetta_utils/mazepa/semaphores.py | 1 + 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/zetta_utils/convnet/utils.py b/zetta_utils/convnet/utils.py index aebbddbb4..15987c572 100644 --- a/zetta_utils/convnet/utils.py +++ b/zetta_utils/convnet/utils.py @@ -21,10 +21,10 @@ TENSORRT_AVAILABLE = False try: - import torch_tensorrt + import torch_tensorrt # pylint: disable=import-error TENSORRT_AVAILABLE = True -except Exception as e: +except (ImportError, OSError) as e: logger.info(f"torch_tensorrt is not available: {e}") @@ -74,8 +74,10 @@ def _load_model( raise RuntimeError("torch_tensorrt is not installed!") with semaphore("tensorrt"): - # TensorRT should not be compiled concurrently by many threads or will run out of memory + # TensorRT should not be compiled concurrently by many threads + # or will run out of memory # Ideally, only by one thread and the others then load the cached model + assert input_shape is not None # mypy trt_fname = ( str(xxhash.xxh128(str((path, tuple(input_shape))).encode("utf-8")).hexdigest()) @@ -88,7 +90,7 @@ def _load_model( result = torch_tensorrt.load(cache_path).module() logger.info(f"Loaded cached TensorRT model: {cache_path}") return result - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught logger.info(f"Cache not found or invalid, compiling TensorRT model: {e}") example_in = torch.rand(input_shape).to(device=device) @@ -102,7 +104,7 @@ def _load_model( torch_tensorrt.save(compiled, cache_path) logger.info(f"Compiled TensorRT model saved to cache: {cache_path}") result = compiled - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught logger.warning(f"TensorRT compilation failed, falling back to eager mode: {e}") return result diff --git a/zetta_utils/mazepa/semaphores.py b/zetta_utils/mazepa/semaphores.py index 76e73b311..1143e6124 100644 --- a/zetta_utils/mazepa/semaphores.py +++ b/zetta_utils/mazepa/semaphores.py @@ -159,6 +159,7 @@ def configure_semaphores( if name not in get_args(SemaphoreType): raise ValueError(f"`{name}` is not a valid semaphore type.") try: + sema_type: SemaphoreType for sema_type in ("read", "write", "cuda", "cpu"): assert semaphores_spec[sema_type] >= 0 except KeyError as e: From 12620984669193f034a6f525ec47c188003aae6c Mon Sep 17 00:00:00 2001 From: Dodam Ih Date: Tue, 24 Mar 2026 22:20:30 -0700 Subject: [PATCH 7/7] chore: add tensorrt cleanup to test fixture --- tests/unit/mazepa/test_semaphores.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unit/mazepa/test_semaphores.py b/tests/unit/mazepa/test_semaphores.py index 439ea81e8..6d1e6d8f9 100644 --- a/tests/unit/mazepa/test_semaphores.py +++ b/tests/unit/mazepa/test_semaphores.py @@ -24,7 +24,7 @@ @pytest.fixture(autouse=True) def cleanup_semaphores(): yield - sema_types: List[SemaphoreType] = ["read", "write", "cuda", "cpu"] + sema_types: List[SemaphoreType] = list(get_args(SemaphoreType)) for name in sema_types: try: # two unlinks in case grandparent semaphore exists @@ -32,6 +32,10 @@ def cleanup_semaphores(): semaphore(name).unlink() except: pass + try: + TimingTracker(name).unlink() + except: + pass def test_default_semaphore_init():