Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions docker/pytorch_base/Dockerfile.prebuilt
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,14 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
&& /opt/conda/bin/pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} \
tensorrt torch-tensorrt==${TORCH_TENSORRT_VERSION} --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | tr -d .) \
# && /opt/conda/bin/pip install "nvidia-modelopt" --extra-index-url https://pypi.nvidia.com \
&& find /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia -name '*.so*' -path '*/lib/*' -printf '%h\n' 2>/dev/null | sort -u > /etc/ld.so.conf.d/nvidia-pip.conf \
&& echo "/opt/conda/lib/python${PYTHON_VERSION}/site-packages/tensorrt_libs" >> /etc/ld.so.conf.d/nvidia-pip.conf \
&& ldconfig \
&& apt-mark auto '.*' > /dev/null \
&& apt-mark manual $preexistingAptMark ca-certificates > /dev/null \
&& find /usr/local/lib /opt/conda -type f -executable -exec ldd '{}' ';' \
&& find /usr/local/lib /opt/conda \
\( -path '*/site-packages/nvidia' -o -path '*/site-packages/torch' -o -path '*/site-packages/tensorrt*' -o -path '*/site-packages/triton' \) -prune \
-o -type f -executable -exec ldd '{}' ';' \
| awk '/=>/ { print $(NF-1) }' \
| sort -u \
| grep -v not \
Expand All @@ -55,4 +60,4 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
/usr/share/doc \
/usr/share/doc-base

CMD ["/bin/bash"]
CMD ["/bin/bash"]
6 changes: 5 additions & 1 deletion tests/unit/mazepa/test_semaphores.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@
@pytest.fixture(autouse=True)
def cleanup_semaphores():
    """Autouse fixture: after each test, best-effort removal of every
    semaphore type (and its timing tracker) so no IPC state leaks
    between tests.
    """
    yield
    # Derive the list from the Literal type itself so newly added
    # semaphore types (e.g. "tensorrt") are cleaned up automatically.
    sema_types: List[SemaphoreType] = list(get_args(SemaphoreType))
    for name in sema_types:
        try:
            # two unlinks in case grandparent semaphore exists
            semaphore(name).unlink()
            semaphore(name).unlink()
        except Exception:
            # Best-effort cleanup: the semaphore may simply not exist.
            # `except Exception` (not bare `except:`) so that
            # KeyboardInterrupt/SystemExit still propagate.
            pass
        try:
            TimingTracker(name).unlink()
        except Exception:
            pass


def test_default_semaphore_init():
Expand Down
97 changes: 90 additions & 7 deletions zetta_utils/convnet/utils.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,58 @@
from __future__ import annotations

import io
import os
from typing import Literal, Optional, Sequence, Union, overload

import cachetools
import fsspec
import onnx
import onnx2torch
import torch
import xxhash
from numpy import typing as npt
from typeguard import typechecked

from zetta_utils import builder, log, tensor_ops
from zetta_utils.mazepa import semaphore

logger = log.get_logger("zetta_utils")


TENSORRT_AVAILABLE = False
try:
import torch_tensorrt # pylint: disable=import-error

TENSORRT_AVAILABLE = True
except (ImportError, OSError) as e:
logger.info(f"torch_tensorrt is not available: {e}")


@builder.register("load_model")
@typechecked
def load_model(
    path: str,
    device: Union[str, torch.device] = "cpu",
    use_cache: bool = False,
    input_shape: Sequence[int] | None = None,
    tensorrt_enabled: bool = False,
    tensorrt_cache_dir: str = ".",  # defaults to the current working directory
) -> torch.nn.Module:  # pragma: no cover
    """Load a model from ``path`` onto ``device``.

    :param path: Model file; format is inferred from the extension.
    :param device: Target device for the loaded module.
    :param use_cache: If True, go through the in-process loader cache
        (presumably keyed on the full argument tuple -- see ``_load_model_cached``).
    :param input_shape: Example input shape, needed when TensorRT is enabled.
    :param tensorrt_enabled: If True, attempt TensorRT compilation/loading.
    :param tensorrt_cache_dir: Directory for cached TensorRT artifacts.
    :return: The loaded ``torch.nn.Module``.
    """
    # Same argument list either way; only the entry point differs.
    loader = _load_model_cached if use_cache else _load_model
    return loader(path, device, input_shape, tensorrt_enabled, tensorrt_cache_dir)


def _load_model(
path: str, device: Union[str, torch.device] = "cpu"
path: str,
device: Union[str, torch.device] = "cpu",
input_shape: Sequence[int] | None = None,
tensorrt_enabled: bool = False,
tensorrt_cache_dir: str = ".",
) -> torch.nn.Module: # pragma: no cover
logger.debug(f"Loading model from '{path}'")
if path.endswith(".json"):
Expand All @@ -40,8 +63,50 @@ def _load_model(
elif path.endswith(".onnx"):
with fsspec.open(path, "rb") as f:
result = onnx2torch.convert(onnx.load(f)).to(device)
elif path.endswith(".ts"):
# load a cached TensorRT model
result = torch.export.load(path).module()
else:
raise ValueError(f"Unsupported file format: {path}")

if tensorrt_enabled:
if not TENSORRT_AVAILABLE:
raise RuntimeError("torch_tensorrt is not installed!")

with semaphore("tensorrt"):
# TensorRT should not be compiled concurrently by many threads
# or will run out of memory
# Ideally, only by one thread and the others then load the cached model
assert input_shape is not None # mypy

trt_fname = (
str(xxhash.xxh128(str((path, tuple(input_shape))).encode("utf-8")).hexdigest())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you also need to add the GPU model or compute capability to the hash. I'm not sure whether a T4 can run a model compiled on an L4, or vice versa.

+ ".trt.ep"
)
cache_path = os.path.join(tensorrt_cache_dir, trt_fname)

# Try to load the optimized model from cache
try:
result = torch_tensorrt.load(cache_path).module()
logger.info(f"Loaded cached TensorRT model: {cache_path}")
return result
except Exception as e: # pylint: disable=broad-exception-caught
logger.info(f"Cache not found or invalid, compiling TensorRT model: {e}")

example_in = torch.rand(input_shape).to(device=device)

try:
compiled = torch_tensorrt.compile(
result,
inputs=[example_in],
enabled_precisions={torch.float, torch.half},
)
torch_tensorrt.save(compiled, cache_path)
logger.info(f"Compiled TensorRT model saved to cache: {cache_path}")
result = compiled
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning(f"TensorRT compilation failed, falling back to eager mode: {e}")

return result


Expand Down Expand Up @@ -112,16 +177,34 @@ def load_and_run_model(


@typechecked
def load_and_run_model(
    path,
    data_in,
    device=None,
    use_cache=True,
    tensorrt_enabled: bool = False,
    tensorrt_cache_dir: str = ".",
):  # pragma: no cover
    """Load the model at ``path`` and run it on ``data_in``.

    The output is converted back to the type/dtype of ``data_in`` (with
    casting) before being returned. ``data_in.shape`` is forwarded as the
    example ``input_shape`` for optional TensorRT compilation.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = load_model(
        path=path,
        device=device,
        use_cache=use_cache,
        input_shape=data_in.shape,
        tensorrt_enabled=tensorrt_enabled,
        tensorrt_cache_dir=tensorrt_cache_dir,
    )

    # torch.autocast wants a device-type string, not a torch.device object.
    autocast_device = device.type if isinstance(device, torch.device) else str(device)
    # inference_mode uses less memory when used with JITs
    with torch.inference_mode(), torch.autocast(device_type=autocast_device):
        device_input = tensor_ops.convert.to_torch(data_in, device=device)
        output = model(device_input)
        # Drop the device-side input before converting the output so its
        # memory can be reclaimed by empty_cache below.
        del device_input
        output = tensor_ops.convert.astype(output, reference=data_in, cast=True)
        torch.cuda.empty_cache()

    return output
10 changes: 5 additions & 5 deletions zetta_utils/mazepa/semaphores.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from zetta_utils.common.pprint import lrpad

logger = log.get_logger("mazepa")
SemaphoreType = Literal["read", "write", "cuda", "cpu"]
SemaphoreType = Literal["read", "write", "cuda", "cpu", "tensorrt"]

DEFAULT_SEMA_COUNT = 1
TIMING_FORMAT = "dddd" # wait_time, lease_time, lease_count, start_time
Expand Down Expand Up @@ -152,23 +152,23 @@ def configure_semaphores(
Context manager for creating and destroying semaphores.
"""

sema_types_to_check: List[SemaphoreType] = ["read", "write", "cuda", "cpu"]
sema_types_to_check: List[SemaphoreType] = ["read", "write", "cuda", "cpu", "tensorrt"]
if semaphores_spec is not None:
for name in semaphores_spec:
if name not in get_args(SemaphoreType):
raise ValueError(f"`{name}` is not a valid semaphore type.")
try:
for sema_type in sema_types_to_check:
sema_type: SemaphoreType
for sema_type in ("read", "write", "cuda", "cpu"):
assert semaphores_spec[sema_type] >= 0
semaphores_spec_ = semaphores_spec
except KeyError as e:
raise ValueError(
"`semaphores_spec` given to `execute_with_pool` must contain "
"`read`, `write`, `cuda`, and `cpu`."
) from e
except AssertionError as e:
raise ValueError("Number of semaphores must be nonnegative.") from e
semaphores_spec_ = semaphores_spec
semaphores_spec_ = {"tensorrt": DEFAULT_SEMA_COUNT, **semaphores_spec}
else:
semaphores_spec_ = {name: DEFAULT_SEMA_COUNT for name in sema_types_to_check}

Expand Down
Loading