Skip to content

Commit bd02455

Browse files
committed
the barrier for TRT-LLM installation
1 parent 63121fb commit bd02455

File tree

1 file changed

+45
-31
lines changed

1 file changed

+45
-31
lines changed

py/torch_tensorrt/_utils.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,55 @@ def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:
143143
)
144144

145145

146+
def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None:
147+
# this will not be encountered in case of platforms not supporting torch distributed/nccl/TRT-LLM
148+
from torch.distributed import barrier, get_rank, is_initialized
149+
150+
if not is_initialized():
151+
# Single process case, just unzip
152+
is_master = True
153+
else:
154+
is_master = get_rank() == 0 # only rank 0 does the unzip
155+
156+
if is_master:
157+
try:
158+
import zipfile
159+
except ImportError as e:
160+
raise ImportError(
161+
"zipfile module is required but not found. Please install zipfile"
162+
)
163+
try:
164+
with zipfile.ZipFile(wheel_path) as zip_ref:
165+
zip_ref.extractall(extract_dir)
166+
logger.debug(f"Extracted wheel to {extract_dir}")
167+
168+
except FileNotFoundError as e:
169+
# This should capture the errors in the download failure above
170+
logger.error(f"Wheel file not found at {wheel_path}: {e}")
171+
raise RuntimeError(
172+
f"Failed to find downloaded wheel file at {wheel_path}"
173+
) from e
174+
except zipfile.BadZipFile as e:
175+
logger.error(f"Invalid or corrupted wheel file: {e}")
176+
raise RuntimeError(
177+
"Downloaded wheel file is corrupted or not a valid zip archive"
178+
) from e
179+
except Exception as e:
180+
logger.error(f"Unexpected error while extracting wheel: {e}")
181+
raise RuntimeError(
182+
"Unexpected error during extraction of TensorRT-LLM wheel"
183+
) from e
184+
185+
# Make sure others wait until unzip is done
186+
if is_initialized():
187+
barrier()
188+
189+
146190
def download_and_get_plugin_lib_path() -> Optional[str]:
147191
"""
148192
Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary.
149-
150193
Args:
151194
platform (str): Platform identifier (e.g., 'linux_x86_64')
152-
153195
Returns:
154196
Optional[str]: Path to shared library or None if operation fails.
155197
"""
@@ -194,32 +236,7 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
194236
except OSError as e:
195237
logger.error(f"Local file write error: {e}")
196238

197-
try:
198-
import zipfile
199-
except ImportError as e:
200-
raise ImportError(
201-
"zipfile module is required but not found. Please install zipfile"
202-
)
203-
try:
204-
with zipfile.ZipFile(wheel_path) as zip_ref:
205-
zip_ref.extractall(extract_dir)
206-
logger.debug(f"Extracted wheel to {extract_dir}")
207-
except FileNotFoundError as e:
208-
# This should capture the errors in the download failure above
209-
logger.error(f"Wheel file not found at {wheel_path}: {e}")
210-
raise RuntimeError(
211-
f"Failed to find downloaded wheel file at {wheel_path}"
212-
) from e
213-
except zipfile.BadZipFile as e:
214-
logger.error(f"Invalid or corrupted wheel file: {e}")
215-
raise RuntimeError(
216-
"Downloaded wheel file is corrupted or not a valid zip archive"
217-
) from e
218-
except Exception as e:
219-
logger.error(f"Unexpected error while extracting wheel: {e}")
220-
raise RuntimeError(
221-
"Unexpected error during extraction of TensorRT-LLM wheel"
222-
) from e
239+
extract_wheel_file(wheel_path, extract_dir)
223240

224241
try:
225242
wheel_path.unlink(missing_ok=True)
@@ -238,10 +255,8 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
238255
def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
239256
"""
240257
Loads and initializes the TensorRT-LLM plugin from the given shared library path.
241-
242258
Args:
243259
plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.
244-
245260
Returns:
246261
bool: True if successful, False otherwise.
247262
"""
@@ -293,7 +308,6 @@ def load_tensorrt_llm_for_nccl() -> bool:
293308
Attempts to load the TensorRT-LLM plugin and initialize it.
294309
Either the env variable TRTLLM_PLUGINS_PATH can specify the path
295310
Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it
296-
297311
Returns:
298312
bool: True if the plugin was successfully loaded and initialized, False otherwise.
299313
"""

0 commit comments

Comments
 (0)