Skip to content

Commit d289587

Browse files
committed
the barrier for TRT-LLM installation
1 parent 63121fb commit d289587

File tree

1 file changed

+44
-31
lines changed

1 file changed

+44
-31
lines changed

py/torch_tensorrt/_utils.py

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,54 @@ def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:
143143
)
144144

145145

146+
def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None:
    """Extract the downloaded wheel archive into ``extract_dir``.

    When ``torch.distributed`` is not initialized the calling process performs
    the extraction itself. When it is initialized, only rank 0 extracts and
    every rank then meets at a barrier so that no rank continues before the
    extraction attempt has finished.

    The barrier is issued from a ``finally`` clause: if rank 0 fails while
    extracting, it still joins the collective before re-raising, so the other
    ranks are released instead of blocking forever on a barrier rank 0 would
    never reach.

    Args:
        wheel_path: Path to the wheel (zip) file to extract.
        extract_dir: Directory the archive contents are extracted into.

    Raises:
        RuntimeError: If the wheel file is missing, is not a valid zip
            archive, or any other error occurs during extraction. The
            original exception is chained as ``__cause__``.
    """
    # zipfile is part of the standard library, so a plain import is enough.
    import zipfile

    from torch.distributed import barrier, get_rank, is_initialized

    distributed = is_initialized()
    # Single-process case: just unzip. Distributed case: only rank 0 unzips.
    is_master = not distributed or get_rank() == 0

    try:
        if is_master:
            try:
                with zipfile.ZipFile(wheel_path) as zip_ref:
                    zip_ref.extractall(extract_dir)
                    logger.debug(f"Extracted wheel to {extract_dir}")
            except FileNotFoundError as e:
                # This should capture the errors in the download failure
                # that happened before this function was called.
                logger.error(f"Wheel file not found at {wheel_path}: {e}")
                raise RuntimeError(
                    f"Failed to find downloaded wheel file at {wheel_path}"
                ) from e
            except zipfile.BadZipFile as e:
                logger.error(f"Invalid or corrupted wheel file: {e}")
                raise RuntimeError(
                    "Downloaded wheel file is corrupted or not a valid zip archive"
                ) from e
            except Exception as e:
                logger.error(f"Unexpected error while extracting wheel: {e}")
                raise RuntimeError(
                    "Unexpected error during extraction of TensorRT-LLM wheel"
                ) from e
    finally:
        # Make sure others wait until unzip is done. Running this in
        # ``finally`` guarantees rank 0 reaches the collective even when the
        # extraction raised, so the remaining ranks cannot deadlock.
        if distributed:
            barrier()
187+
188+
146189
def download_and_get_plugin_lib_path() -> Optional[str]:
147190
"""
148191
Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary.
149-
150192
Args:
151193
platform (str): Platform identifier (e.g., 'linux_x86_64')
152-
153194
Returns:
154195
Optional[str]: Path to shared library or None if operation fails.
155196
"""
@@ -194,32 +235,7 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
194235
except OSError as e:
195236
logger.error(f"Local file write error: {e}")
196237

197-
try:
198-
import zipfile
199-
except ImportError as e:
200-
raise ImportError(
201-
"zipfile module is required but not found. Please install zipfile"
202-
)
203-
try:
204-
with zipfile.ZipFile(wheel_path) as zip_ref:
205-
zip_ref.extractall(extract_dir)
206-
logger.debug(f"Extracted wheel to {extract_dir}")
207-
except FileNotFoundError as e:
208-
# This should capture the errors in the download failure above
209-
logger.error(f"Wheel file not found at {wheel_path}: {e}")
210-
raise RuntimeError(
211-
f"Failed to find downloaded wheel file at {wheel_path}"
212-
) from e
213-
except zipfile.BadZipFile as e:
214-
logger.error(f"Invalid or corrupted wheel file: {e}")
215-
raise RuntimeError(
216-
"Downloaded wheel file is corrupted or not a valid zip archive"
217-
) from e
218-
except Exception as e:
219-
logger.error(f"Unexpected error while extracting wheel: {e}")
220-
raise RuntimeError(
221-
"Unexpected error during extraction of TensorRT-LLM wheel"
222-
) from e
238+
extract_wheel_file(wheel_path, extract_dir)
223239

224240
try:
225241
wheel_path.unlink(missing_ok=True)
@@ -238,10 +254,8 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
238254
def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
239255
"""
240256
Loads and initializes the TensorRT-LLM plugin from the given shared library path.
241-
242257
Args:
243258
plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.
244-
245259
Returns:
246260
bool: True if successful, False otherwise.
247261
"""
@@ -293,7 +307,6 @@ def load_tensorrt_llm_for_nccl() -> bool:
293307
Attempts to load the TensorRT-LLM plugin and initialize it.
294308
Either the env variable TRTLLM_PLUGINS_PATH can specify the path
295309
Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it
296-
297310
Returns:
298311
bool: True if the plugin was successfully loaded and initialized, False otherwise.
299312
"""

0 commit comments

Comments
 (0)