77 changes: 77 additions & 0 deletions docsrc/contributors/resource_management.rst
@@ -0,0 +1,77 @@
.. _resource_management:

Resource Management
===================

Overview
--------

Efficient control of CPU and GPU memory is essential for successful model compilation,
especially when working with large models such as LLMs or diffusion models.
Uncontrolled memory growth can cause compilation failures or process termination.
This guide describes the symptoms of excessive memory usage and provides methods
to reduce both CPU and GPU memory consumption.

Memory Usage Control
--------------------

CPU Memory
^^^^^^^^^^

By default, Torch-TensorRT may consume up to **5x** the model size in CPU memory.
This can exceed system limits when compiling large models.

**Common symptoms of high CPU memory usage:**

- The program freezes or hangs
- The process is terminated by the operating system (out-of-memory killer)

**Ways to lower CPU memory usage:**

1. **Enable memory trimming**

Set the following environment variable:

.. code-block:: bash

export TORCHTRT_ENABLE_BUILDER_MALLOC_TRIM=1

This releases roughly **2x** the model size in redundant model copies,
limiting total CPU memory usage to about **3x** the model size.

2. **Disable CPU offloading**

In compilation settings, set:

.. code-block:: python

offload_module_to_cpu = False

This removes another model copy worth about **1x** the model size,
reducing peak CPU memory usage to about **2x** the model size.
A combined example is sketched after this list.
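
The sketch below combines both options in one compilation call. It is a minimal
illustration rather than a canonical recipe: the toy model, the input shapes, and
the ``ir="dynamo"`` choice are placeholders, and the exact set of supported
keyword arguments depends on your Torch-TensorRT version.

.. code-block:: python

   import os

   # Assumption: setting the trim flag before importing torch_tensorrt keeps it
   # in effect for the entire process, including the engine build.
   os.environ["TORCHTRT_ENABLE_BUILDER_MALLOC_TRIM"] = "1"

   import torch
   import torch_tensorrt

   # Placeholder model and inputs; substitute your own module.
   model = torch.nn.Sequential(
       torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
       torch.nn.ReLU(),
   ).eval().cuda()
   example_inputs = [torch.randn(1, 3, 224, 224).cuda()]

   trt_model = torch_tensorrt.compile(
       model,
       ir="dynamo",
       inputs=example_inputs,
       # Keep the PyTorch weights on the GPU, avoiding the extra CPU copy.
       offload_module_to_cpu=False,
   )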

GPU Memory
^^^^^^^^^^

By default, Torch-TensorRT may consume up to **2x** the model size in GPU memory.

**Common symptoms of high GPU memory usage:**

- CUDA out-of-memory errors
- TensorRT compilation errors

**Ways to lower GPU memory usage:**

1. **Enable offloading to CPU**

In compilation settings, set:

.. code-block:: python

offload_module_to_cpu = True

This shifts one model copy from GPU to CPU memory. As a result, peak GPU
memory usage decreases to about **1x** the model size, while CPU memory
usage increases by roughly **1x** the model size because the additional
copy now resides there (see the sketch below).
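
For a rough sense of the effect, the sketch below compiles with offloading enabled
and prints the free GPU memory before and after the build. The toy model and the
reported numbers are illustrative only; actual figures depend on the model, the
TensorRT version, and other processes sharing the GPU.

.. code-block:: python

   import torch
   import torch_tensorrt


   def free_gpu_mib() -> float:
       """Return the currently free GPU memory in MiB."""
       free_bytes, _total_bytes = torch.cuda.mem_get_info()
       return free_bytes / 2**20


   # Placeholder model and inputs; substitute your own module.
   model = torch.nn.Sequential(
       torch.nn.Linear(4096, 4096),
       torch.nn.ReLU(),
       torch.nn.Linear(4096, 4096),
   ).eval().cuda()
   example_inputs = [torch.randn(8, 4096).cuda()]

   print(f"Free GPU memory before compilation: {free_gpu_mib():.0f} MiB")

   trt_model = torch_tensorrt.compile(
       model,
       ir="dynamo",
       inputs=example_inputs,
       # Move the PyTorch weights to the CPU during the build so that peak GPU
       # usage stays around 1x the model size.
       offload_module_to_cpu=True,
   )

   print(f"Free GPU memory after compilation: {free_gpu_mib():.0f} MiB")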


1 change: 1 addition & 0 deletions docsrc/index.rst
@@ -234,6 +234,7 @@ Contributor Documentation
contributors/writing_dynamo_aten_lowering_passes
contributors/ts_converters
contributors/useful_links
contributors/resource_management

Indices
----------------
15 changes: 12 additions & 3 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -42,6 +42,7 @@
)
from torch_tensorrt.dynamo.utils import (
deallocate_module,
get_cpu_memory_usage,
get_flat_args_with_check,
get_output_metadata,
parse_graph_io,
@@ -681,7 +682,7 @@ def compile(
"offload_module_to_cpu": offload_module_to_cpu,
"use_distributed_mode_trace": use_distributed_mode_trace,
}

logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)
exported_program = pre_export_lowering(exported_program, settings)
@@ -695,14 +696,17 @@

# Apply lowering on the graph module
gm = post_lowering(gm, settings)
logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
logger.debug("Lowered Input graph: " + str(gm.graph))

# Move the weights in the state_dict to CPU
if offload_module_to_cpu:
deallocate_module(gm, delete_module=False)
deallocate_module(exported_program.module(), delete_module=False)
logger.info(
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
)
logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB")
else:
remaining_memory, total_memory = torch.cuda.mem_get_info()
if remaining_memory < total_memory // 2:
@@ -868,6 +872,11 @@ def preserve_module_specs(
# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those

# Delete the frozen parameters from the graph module to release CPU memory.
# Note that this does not affect the submodules; their frozen parameters are
# deleted later in the convert_module function.
for attr in dir(gm):
if attr.startswith("_frozen_param"):
delattr(gm, attr)
for name, _ in partitioned_module.named_children():
submodule = getattr(partitioned_module, name)
# filter on the GraphModule
@@ -1243,7 +1252,7 @@ def convert_exported_program_to_serialized_trt_engine(

# Prepare torch_trt inputs
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs)
device = to_torch_tensorrt_device(device)
enabled_precisions = {dtype._from(p) for p in enabled_precisions}

@@ -1330,7 +1339,7 @@ def convert_exported_program_to_serialized_trt_engine(
)

flattened_input_list = get_flat_args_with_check(
exported_program, list(trt_arg_inputs), trt_kwarg_inputs
exported_program, list(trt_arg_inputs), trt_kwarg_inputs # type: ignore
)[0]

try:
76 changes: 20 additions & 56 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -1,5 +1,4 @@
import gc
import io
import logging
import os
import warnings
@@ -50,7 +49,12 @@
from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
from torch_tensorrt.dynamo.observer import Observer
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
from torch_tensorrt.dynamo.utils import (
DYNAMIC_DIM,
deallocate_module,
get_cpu_memory_usage,
to_torch_device,
)
from torch_tensorrt.logging import TRT_LOGGER

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -65,7 +69,7 @@ class UnsupportedOperatorException(RuntimeError):


class TRTInterpreterResult(NamedTuple):
serialized_engine: bytes
engine: trt.ICudaEngine
input_names: Sequence[str]
output_names: Sequence[str]
weight_name_map: Optional[dict[Any, Any]]
@@ -512,8 +516,7 @@ def _save_weight_mapping(self) -> None:
_LOGGER.info("Building weight name mapping...")
# Stage 1: Name mapping
torch_device = to_torch_device(self.compilation_settings.device)
self.module.to(torch_device)
sd = self.module.state_dict()
sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
weight_name_map: dict[str, Any] = {}
weight_refit_map = self.ctx.weight_refit_map
constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
@@ -591,34 +594,6 @@ def _save_weight_mapping(self) -> None:
gc.collect()
torch.cuda.empty_cache()

@needs_refit # type: ignore[misc]
def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
# TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
# if not self.compilation_settings.strip_engine_weights:
# # set EXCLUDE_WEIGHTS flag to strip weights
# runtime = trt.Runtime(TRT_LOGGER)
# engine = runtime.deserialize_cuda_engine(serialized_engine)

# serialization_config = engine.create_serialization_config()
# serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
# serialized_engine = engine.serialize_with_config(
# serialization_config
# )

# Cache weighted engine for now
self.engine_cache.insert( # type: ignore[union-attr]
hash_val,
(
serialized_engine,
self._input_names,
self._output_names,
self.input_specs,
self.compilation_settings,
self.weight_name_map,
self.ctx.requires_output_allocator,
),
)

@needs_refit # type: ignore[misc]
def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
# query the cached TRT engine
@@ -671,7 +646,6 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
settings=self.compilation_settings,
weight_name_map=self.weight_name_map,
)
serialized_engine = engine.serialize()

# TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine
# # EXCLUDE_WEIGHTS flag must be cleared
@@ -684,12 +658,8 @@
# )
# # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller

with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()

return TRTInterpreterResult(
engine_str,
engine,
self._input_names,
self._output_names,
self.weight_name_map,
@@ -733,6 +703,9 @@ def run(
return interpreter_result # type: ignore[no-any-return]

self._construct_trt_network_def()
_LOGGER.debug(
f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
)

if not self.compilation_settings.immutable_weights:
self._save_weight_mapping()
@@ -750,36 +723,27 @@
self._create_timing_cache(
builder_config, self.compilation_settings.timing_cache_path
)
serialized_engine = self.builder.build_serialized_network(

cuda_engine = self.builder.build_engine_with_config(
self.ctx.net, builder_config
)
assert serialized_engine
assert cuda_engine

_LOGGER.debug(
f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB"
)

_LOGGER.info(
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
)
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")

self.ctx.clear_cpu_weights_reference_holder()

self._save_timing_cache(
builder_config, self.compilation_settings.timing_cache_path
)

# Engine caching only for refittable engines
if (
not self.compilation_settings.immutable_weights
and self.compilation_settings.cache_built_engines
and self.engine_cache is not None
):
self._insert_engine_to_cache(hash_val, serialized_engine)

with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()

return TRTInterpreterResult(
engine_str,
cuda_engine,
self._input_names,
self._output_names,
self.weight_name_map,