@@ -5,26 +5,30 @@
 import time

 import torch
-import torch_tensorrt
+import torch.distributed as dist
 from llama3_model import ModelArgs, ParallelTransformer
+from tensor_parallel_initialize_dist import (
+    cleanup_distributed_env,
+    initialize_distributed_env,
+)
 from torch.distributed._composable.fsdp import MixedPrecisionPolicy
 from torch.distributed._composable.fsdp.fully_shard import fully_shard
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
 )
+
+if not dist.is_initialized():
+    initialize_distributed_env()
+
+import torch_tensorrt
 from torch_tensorrt.dynamo.distributed.utils import (
-    cleanup_distributed_env,
     get_tensor_parallel_device_mesh,
-    initialize_distributed_env,
     initialize_logger,
 )

-if not dist.is_initialized():
-    initialize_distributed_env()
-
 device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
-logger = initialize_logger(_rank, "tensor_parallel_simple_example")
+logger = initialize_logger(_rank, "tensor_parallel_llama3")

 logger.info(f"Starting PyTorch TP example on rank {_rank}.")
 assert (
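For readers following the change, here is a minimal sketch of the ordering this hunk establishes, assuming the file runs as a standalone script: the process group is brought up through the local tensor_parallel_initialize_dist helpers before torch_tensorrt is imported, so get_tensor_parallel_device_mesh and initialize_logger see an already-initialized distributed environment. The trailing cleanup_distributed_env() call is an assumption suggested by the newly added import; the actual teardown happens outside the lines shown.

import torch.distributed as dist
from tensor_parallel_initialize_dist import (
    cleanup_distributed_env,
    initialize_distributed_env,
)

# Bring up the process group before torch_tensorrt is imported, so its
# distributed utilities find an already-initialized environment.
if not dist.is_initialized():
    initialize_distributed_env()

import torch_tensorrt  # noqa: E402  (deliberately imported after init)
from torch_tensorrt.dynamo.distributed.utils import (  # noqa: E402
    get_tensor_parallel_device_mesh,
    initialize_logger,
)

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_logger(_rank, "tensor_parallel_llama3")

# ... build, shard, and compile the tensor-parallel Llama3 model here ...

# Assumed teardown at the end of the script; this hunk only adds the import.
cleanup_distributed_env()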