Commit f4f338a

Change the implementation for avoiding a race condition during unzip of the TRT-LLM wheel to use a lock file
1 parent 38224c5 · commit f4f338a
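
For orientation before reading the diff: a minimal, self-contained sketch of the lock-file handshake described in the commit message, assuming an OpenMPI launch (the function name extract_once and all paths are illustrative; the actual change is to extract_wheel_file in py/torch_tensorrt/_utils.py further down).

import os
import time
import zipfile
from pathlib import Path


def extract_once(wheel_path: Path, extract_dir: Path) -> None:
    """Rank 0 unzips the wheel; other ranks poll a lock file until it is gone."""
    rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))  # set by mpirun/OpenMPI
    lock_file = extract_dir / ".extracting"

    if rank == 0:
        # Create the target directory and the lock file before extracting,
        # so waiting ranks can see that extraction is in progress.
        extract_dir.mkdir(parents=True, exist_ok=False)
        lock_file.touch(exist_ok=False)
        try:
            with zipfile.ZipFile(wheel_path) as zf:
                zf.extractall(extract_dir)
        finally:
            # Deleting the lock file signals completion to the other ranks.
            lock_file.unlink(missing_ok=True)
    else:
        # Non-zero ranks spin until rank 0 removes the lock file.
        while lock_file.exists():
            time.sleep(0.5)

The waiting ranks only poll for the lock file's removal, so the pattern relies on rank 0 creating the directory and lock before the other ranks first check; a rank that finds no lock file falls through immediately.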

File tree

7 files changed: +76 additions, -72 deletions


examples/distributed_inference/tensor_parallel_initialize_dist.py

Lines changed: 37 additions & 1 deletion
@@ -14,7 +14,9 @@
 import tensorrt as trt
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor.device_mesh import init_device_mesh
+from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
+
+logger = logging.getLogger(__name__)


 # this is kept at the application level, when mpirun is used to run the application
@@ -54,3 +56,37 @@ def cleanup_distributed_env():
     """Clean up distributed process group to prevent resource leaks."""
     if dist.is_initialized():
         dist.destroy_process_group()
+
+
+def check_tensor_parallel_device_number(world_size: int) -> None:
+    if world_size % 2 != 0:
+        raise ValueError(
+            f"TP examples require even number of GPUs, but got {world_size} gpus"
+        )
+
+
+def get_tensor_parallel_device_mesh(
+    rank: int = 0, world_size: int = 1
+) -> tuple[DeviceMesh, int, int]:
+    local_rank = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
+    )
+    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
+    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
+    rank = device_mesh.get_rank()
+    assert rank == local_rank
+    device_id = (
+        rank % torch.cuda.device_count()
+    )  # Ensure each rank gets a unique device
+    torch.cuda.set_device(device_id)
+
+    return device_mesh, world_size, rank
+
+
+def initialize_distributed_logger(rank: int, logger_file_name: str) -> logging.Logger:
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
+    fh.setLevel(logging.INFO)
+    logger.addHandler(fh)
+    return logger
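
For context, here is a condensed sketch of how the example scripts further down consume these new helpers (the log-file prefix "my_tp_example" and the logger.info call are illustrative; the real call sites appear in the rotary-embedding and simple-example diffs below).

import torch.distributed as dist
from tensor_parallel_initialize_dist import (
    cleanup_distributed_env,
    get_tensor_parallel_device_mesh,
    initialize_distributed_env,
    initialize_distributed_logger,
)

# Set up the process group once per process, then build the device mesh
# and a per-rank log file using the helpers added above.
if not dist.is_initialized():
    initialize_distributed_env()

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_distributed_logger(_rank, "my_tp_example")  # hypothetical log prefix
logger.info(f"rank {_rank} of {_world_size} ready on mesh {device_mesh}")

# ... build, shard, and run the tensor-parallel model here ...

cleanup_distributed_env()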

examples/distributed_inference/tensor_parallel_rotary_embedding.py

Lines changed: 4 additions & 8 deletions
@@ -9,25 +9,21 @@

 """

-import logging
-import os
 import time

 import torch
 import torch.distributed as dist
 from tensor_parallel_initialize_dist import (
     cleanup_distributed_env,
+    get_tensor_parallel_device_mesh,
     initialize_distributed_env,
+    initialize_distributed_logger,
 )

 if not dist.is_initialized():
     initialize_distributed_env()

 import torch_tensorrt
-from torch_tensorrt.dynamo.distributed.utils import (
-    get_tensor_parallel_device_mesh,
-    initialize_distributed_logger,
-)

 device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
 logger = initialize_distributed_logger(_rank, "tensor_parallel_rotary_embedding")
@@ -36,8 +32,8 @@

 """
 This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
-Command to run with single GPU: mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
-Command to run with 2 GPUs: mpirun -n 2 --allow-run-as-root python tensor_parallel_rotary_embedding.py
+Command to run with single GPU: USE_TRTLLM_PLUGINS=1 mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
+Command to run with 2 GPUs: USE_TRTLLM_PLUGINS=1 mpirun -n 2 --allow-run-as-root python tensor_parallel_rotary_embedding.py
 """

 BATCH = 2

examples/distributed_inference/tensor_parallel_simple_example.py

Lines changed: 3 additions & 1 deletion
@@ -16,7 +16,7 @@
 -----
 .. code-block:: bash

-    mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
+    USE_TRTLLM_PLUGINS=1 mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
 """

 import time
@@ -27,7 +27,9 @@
 import torch.nn as nn
 from tensor_parallel_initialize_dist import (
     cleanup_distributed_env,
+    get_tensor_parallel_device_mesh,
     initialize_distributed_env,
+    initialize_distributed_logger,
 )

 if not dist.is_initialized():

py/torch_tensorrt/_utils.py

Lines changed: 32 additions & 19 deletions
@@ -5,6 +5,7 @@
 import platform
 import sys
 import tempfile
+import time
 import urllib.request
 from pathlib import Path
 from typing import Any, Optional
@@ -144,47 +145,59 @@ def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:


 def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None:
-    # this will not be encountered in case of platforms not supporting torch distributed/nccl/TRT-LLM
-    from torch.distributed import barrier, get_rank, is_initialized
-
-    if not is_initialized():
-        # Single process case, just unzip
-        is_master = True
-    else:
-        is_master = get_rank() == 0  # only rank 0 does the unzip
-
-    if is_master:
+    """
+    Safely extract a wheel file to a directory with a lock to prevent concurrent extraction.
+    """
+    rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))  # MPI rank from OpenMPI
+    torch.cuda.set_device(rank)
+    lock_file = extract_dir / ".extracting"
+
+    # Rank 0 performs extraction
+    if rank == 0:
+        logger.debug(
+            f"[Rank {rank}] Starting extraction of {wheel_path} to {extract_dir}"
+        )
         try:
             import zipfile
         except ImportError as e:
             raise ImportError(
                 "zipfile module is required but not found. Please install zipfile"
             )
+        # Create lock file to signal extraction in progress
+        extract_dir.mkdir(parents=True, exist_ok=False)
+        lock_file.touch(exist_ok=False)
         try:
             with zipfile.ZipFile(wheel_path) as zip_ref:
                 zip_ref.extractall(extract_dir)
-            logger.debug(f"Extracted wheel to {extract_dir}")
-
+            logger.debug(f"[Rank {rank}] Extraction complete: {extract_dir}")
+            print(f"[Rank {rank}] Extraction complete: {extract_dir}")
         except FileNotFoundError as e:
-            # This should capture the errors in the download failure above
-            logger.error(f"Wheel file not found at {wheel_path}: {e}")
+            logger.error(f"[Rank {rank}] Wheel file not found at {wheel_path}: {e}")
             raise RuntimeError(
                 f"Failed to find downloaded wheel file at {wheel_path}"
             ) from e
         except zipfile.BadZipFile as e:
-            logger.error(f"Invalid or corrupted wheel file: {e}")
+            logger.error(f"[Rank {rank}] Invalid or corrupted wheel file: {e}")
            raise RuntimeError(
                 "Downloaded wheel file is corrupted or not a valid zip archive"
             ) from e
         except Exception as e:
-            logger.error(f"Unexpected error while extracting wheel: {e}")
+            logger.error(f"[Rank {rank}] Unexpected error while extracting wheel: {e}")
             raise RuntimeError(
                 "Unexpected error during extraction of TensorRT-LLM wheel"
             ) from e
+        finally:
+            # Remove lock file to signal completion
+            lock_file.unlink(missing_ok=True)

-    # Make sure others wait until unzip is done
-    if is_initialized():
-        barrier()
+    else:
+        # Other ranks wait for extraction to complete
+        while lock_file.exists():
+            logger.debug(
+                f"[Rank {rank}] Waiting for extraction to finish at {extract_dir}..."
+            )
+            print(f"[Rank {rank}] Waiting... device:", torch.cuda.current_device())
+            time.sleep(0.5)


 def download_and_get_plugin_lib_path() -> Optional[str]:

py/torch_tensorrt/dynamo/distributed/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

py/torch_tensorrt/dynamo/distributed/utils.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

setup.py

Lines changed: 0 additions & 1 deletion
@@ -450,7 +450,6 @@ def run(self):
         "torch_tensorrt.dynamo.conversion.impl.unary",
         "torch_tensorrt.dynamo.conversion.plugins",
         "torch_tensorrt.dynamo.debug",
-        "torch_tensorrt.dynamo.distributed",
         "torch_tensorrt.dynamo.lowering",
         "torch_tensorrt.dynamo.lowering.passes",
         "torch_tensorrt.dynamo.partitioning",
