@@ -5,10 +5,28 @@
 import torch.distributed as dist
 import torch.nn as nn
 from conversion.harness import DispatchTestCase
-from distributed_utils import set_environment_variables_pytest
+
+# The distributed environment must be initialized before the torch_tensorrt import, since that import uses a barrier
+from distributed_utils import (
+    set_environment_variables_pytest,
+    set_environment_variables_pytest_multi_process,
+    set_environment_variables_pytest_single_process,
+)
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.utils import is_platform_supported_for_trtllm
+
+if "OMPI_COMM_WORLD_SIZE" in os.environ:
+    set_environment_variables_pytest_multi_process()
+else:
+    set_environment_variables_pytest_single_process()
+
+if not dist.is_initialized():
+    dist.init_process_group(
+        backend="nccl",
+        init_method="env://",
+    )
+
+from torch_tensorrt.dynamo.distributed.utils import is_platform_supported_for_trtllm
 
 
 class DistributedGatherModel(nn.Module):
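The two pytest helpers imported above live in the repo's `distributed_utils`; their bodies are not part of this diff. As a rough, hypothetical sketch of what they plausibly do — mapping Open MPI's `OMPI_COMM_WORLD_*` variables onto the `RANK`/`WORLD_SIZE`/`MASTER_ADDR`/`MASTER_PORT` variables that `init_method="env://"` reads — something like:

```python
import os


def set_environment_variables_pytest_single_process():
    # Hypothetical sketch: a one-process world rendezvousing on localhost.
    # The address and port are illustrative defaults, not the repo's values.
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")


def set_environment_variables_pytest_multi_process():
    # Hypothetical sketch: translate the Open MPI launcher's variables into
    # the names that torch.distributed's env:// rendezvous expects.
    os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
    os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
```

Either way, by the time `torch_tensorrt` is imported, `dist.init_process_group` has already run, so any barrier taken during that import sees a fully initialized process group.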
@@ -48,11 +66,9 @@ class TestNcclOpsConverter(DispatchTestCase):
     )
     @classmethod
     def setUpClass(cls):
-        set_environment_variables_pytest()
-        cls.world_size = 1
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
-        cls.group = dist.new_group(ranks=[0])
+        cls.world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
+        cls.rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
+        cls.group = dist.new_group(ranks=list(range(cls.world_size)))
         cls.group_name = cls.group.group_name
 
     @classmethod
|
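With `setUpClass` deriving the world size and rank from the `OMPI_COMM_WORLD_*` variables, the group now spans every launched rank instead of being pinned to `ranks=[0]`. A minimal sketch of how such a group backs a gather-style collective (assuming the module-level bootstrap above has already run and there is one GPU per rank; this is illustrative, not the test's actual body):

```python
import os

import torch
import torch.distributed as dist

# Assumes the env vars are set and dist.init_process_group("nccl") has run,
# as in the module-level bootstrap above.
rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
group = dist.new_group(ranks=list(range(world_size)))

x = torch.full((4,), float(rank), device=f"cuda:{rank}")
out = [torch.empty_like(x) for _ in range(world_size)]
dist.all_gather(out, x, group=group)  # every rank receives every rank's tensor
```

Launched under Open MPI (e.g. `mpirun -n 2 python <test file>`), `OMPI_COMM_WORLD_SIZE` is present and the multi-process branch runs; invoked directly through pytest, the single-process branch falls back to a world size of 1.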