diff --git a/.github/workflows/build_whl.yml b/.github/workflows/build_whl.yml
index 9116b598..05193be8 100644
--- a/.github/workflows/build_whl.yml
+++ b/.github/workflows/build_whl.yml
@@ -37,12 +37,12 @@ jobs:
          password: ${{ secrets.DOCKERHUB_TOKEN }}
 
     - name: Pull Docker image
-      run: docker pull pytorch/manylinux-cuda113:latest
+      run: docker pull pytorch/manylinux2_28-builder:cuda12.4
 
     - name: Run Docker image and execute script
       run: |
        version=${{ matrix.python-version }}
-        docker run -e BUILD_DOCKER_ENV=1 -e CUDACXX=/usr/local/cuda-11.3/bin/nvcc -e PATH="/opt/rh/devtoolset-9/root/usr/bin:$PATH" -e LD_LIBRARY_PATH="/opt/rh/devtoolset-9/root/usr/lib64:/opt/rh/devtoolset-9/root/usr/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH" -v ${{ github.workspace }}:/workspace/BMTrain -i pytorch/manylinux-cuda113:latest /bin/bash -c "cd /workspace/BMTrain;/opt/python/cp${version}*/bin/pip install build; /opt/python/cp${version}*/bin/python -m build .;for file in dist/*-linux_x86_64.whl; do mv \"\$file\" \"\${file//-linux_x86_64/-manylinux2014_x86_64}\"; done"
+        docker run -e BUILD_DOCKER_ENV=1 -e CUDACXX=/usr/local/cuda-12.4/bin/nvcc -e PATH="/opt/rh/devtoolset-9/root/usr/bin:$PATH" -e LD_LIBRARY_PATH="/opt/rh/devtoolset-9/root/usr/lib64:/opt/rh/devtoolset-9/root/usr/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH" -v ${{ github.workspace }}:/workspace/BMTrain -i pytorch/manylinux2_28-builder:cuda12.4 /bin/bash -c "cd /workspace/BMTrain;/opt/python/cp${version}*/bin/pip install build; /opt/python/cp${version}*/bin/python -m build .;for file in dist/*-linux_x86_64.whl; do mv \"\$file\" \"\${file//-linux_x86_64/-manylinux_2_28_x86_64}\"; done"
 
     - name: Upload wheels as artifacts
       uses: actions/upload-artifact@v4
diff --git a/Dockerfile b/Dockerfile
index 8e6cbddf..8859976f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,20 +1,31 @@
-FROM nvidia/cuda:10.2-devel
+FROM nvidia/cuda:12.8.0-devel-ubuntu22.04
 
 WORKDIR /build
+
 RUN apt update && apt install -y --no-install-recommends \
     build-essential \
     python3-dev \
    python3-pip \
    python3-setuptools \
-    python3-wheel
-RUN pip3 install torch==1.10.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
-RUN pip3 install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple
-RUN apt install iputils-ping opensm libopensm-dev libibverbs1 libibverbs-dev -y --no-install-recommends
-ENV TORCH_CUDA_ARCH_LIST=6.1;7.0;7.5
+    python3-wheel \
+    cmake \
+    ninja-build \
+    git \
+    iputils-ping opensm libopensm-dev libibverbs1 libibverbs-dev
+
+RUN pip3 install --upgrade pip setuptools wheel -i https://pypi.tuna.tsinghua.edu.cn/simple
+
+RUN pip3 install --break-system-packages torch==2.8.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
+RUN pip3 install --break-system-packages numpy -i https://pypi.tuna.tsinghua.edu.cn/simple
+
+ENV TORCH_CUDA_ARCH_LIST="6.1;7.0;7.5;8.0;8.6;8.9;9.0"
 ENV BMT_AVX512=1
+
 ADD other_requirements.txt other_requirements.txt
-RUN pip3 install --upgrade pip && pip3 install -r other_requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
+RUN pip3 install --break-system-packages -r other_requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
+
 ADD . .
-RUN python3 setup.py install
+RUN pip3 install --break-system-packages .
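+
+# Optional smoke test (an assumption, not required by the build): importing bmtrain
+# needs only the CUDA runtime libraries already present in this devel image, no GPU.
+RUN python3 -c "import bmtrain"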
 
 WORKDIR /root
-ADD example example
\ No newline at end of file
+ADD example example
+ADD tests tests
\ No newline at end of file
diff --git a/bmtrain/benchmark/all_gather.py b/bmtrain/benchmark/all_gather.py
index b2f2ee7c..89e8bf2b 100644
--- a/bmtrain/benchmark/all_gather.py
+++ b/bmtrain/benchmark/all_gather.py
@@ -18,7 +18,7 @@ def all_gather():
     end_evt = torch.cuda.Event(enable_timing=True)
 
     current_stream.record_event(start_evt)
-    nccl.allGather(partition_tensor.storage(), global_tensor.storage(), config['comm'])
+    nccl.allGather(partition_tensor.view(-1), global_tensor.view(-1), config['comm'])
     current_stream.record_event(end_evt)
     current_stream.synchronize()
     time_usage = start_evt.elapsed_time(end_evt)
diff --git a/bmtrain/benchmark/reduce_scatter.py b/bmtrain/benchmark/reduce_scatter.py
index 75733556..8bfc0cac 100644
--- a/bmtrain/benchmark/reduce_scatter.py
+++ b/bmtrain/benchmark/reduce_scatter.py
@@ -18,7 +18,7 @@ def reduce_scatter():
     end_evt = torch.cuda.Event(enable_timing=True)
 
     current_stream.record_event(start_evt)
-    nccl.reduceScatter(global_tensor.storage(), partition_tensor.storage(), 'avg', config['comm'])
+    nccl.reduceScatter(global_tensor.view(-1), partition_tensor.view(-1), 'avg', config['comm'])
     current_stream.record_event(end_evt)
     current_stream.synchronize()
     time_usage = start_evt.elapsed_time(end_evt)
diff --git a/bmtrain/benchmark/send_recv.py b/bmtrain/benchmark/send_recv.py
index e3c971e4..2d71a1cb 100644
--- a/bmtrain/benchmark/send_recv.py
+++ b/bmtrain/benchmark/send_recv.py
@@ -18,9 +18,9 @@ def send_recv():
     current_stream.record_event(start_evt)
     nccl.groupStart()
     if config['rank'] in [0,2,4,6]:
-        nccl.send(send_buffer.storage(), config['rank']+1, config['comm'])
+        nccl.send(send_buffer.view(-1), config['rank']+1, config['comm'])
     else:
-        nccl.recv(recv_buffer.storage(), config['rank']-1, config['comm'])
+        nccl.recv(recv_buffer.view(-1), config['rank']-1, config['comm'])
     nccl.groupEnd()
     current_stream.record_event(end_evt)
     current_stream.synchronize()
diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py
index 216d77b2..fa7c4872 100644
--- a/bmtrain/block_layer.py
+++ b/bmtrain/block_layer.py
@@ -11,31 +11,6 @@
 from torch.utils.checkpoint import checkpoint
 
 
-def storage_type_cuda(storage_type):
-    """Convert storage_type to cuda storage_type."""
-    STORAGE_MAP = {
-        torch.FloatStorage: torch.cuda.FloatStorage,
-        torch.DoubleStorage: torch.cuda.DoubleStorage,
-        torch.HalfStorage: torch.cuda.HalfStorage,
-        torch.BFloat16Storage: torch.cuda.BFloat16Storage,
-        torch.CharStorage: torch.cuda.CharStorage,
-        torch.ByteStorage: torch.cuda.ByteStorage,
-        torch.ShortStorage: torch.cuda.ShortStorage,
-        torch.IntStorage: torch.cuda.IntStorage,
-        torch.cuda.FloatStorage: torch.cuda.FloatStorage,
-        torch.cuda.DoubleStorage: torch.cuda.DoubleStorage,
-        torch.cuda.HalfStorage: torch.cuda.HalfStorage,
-        torch.cuda.BFloat16Storage: torch.cuda.BFloat16Storage,
-        torch.cuda.CharStorage: torch.cuda.CharStorage,
-        torch.cuda.ByteStorage: torch.cuda.ByteStorage,
-        torch.cuda.ShortStorage: torch.cuda.ShortStorage,
-        torch.cuda.IntStorage: torch.cuda.IntStorage,
-    }
-    if storage_type not in STORAGE_MAP:
-        raise ValueError("Unknown storage type: {}".format(storage_type))
-    return STORAGE_MAP[storage_type]
-
-
 def _get_param_kw(param: DistributedParameter):
     """Get DistributedParameter kw name."""
     type_name = str(param.dtype).split(".")[-1]
@@ -121,7 +96,6 @@ def init_param_storage(self):
                 "All parameters in checkpoint block must be DistributedParameter."
             )
 
-            storage_type = storage_type_cuda(param.storage_type())
             kw_name = _get_param_kw(param)
 
             if kw_name not in self._storage_info:
@@ -136,7 +110,7 @@ def init_param_storage(self):
 
                 self._storage_info[kw_name] = {
                     "total": 0,
-                    "storage_type": storage_type,
+                    "dtype": param.dtype,
                     "requires_grad": param.requires_grad,
                     "group": param.group,
                     "zero_comm": zero_comm,
@@ -165,16 +139,10 @@ def init_param_storage(self):
             val["end"] = (rank + 1) * partition_size
             offsets[kw] = 0
 
-            storage_type = val["storage_type"]
-
-            storage_param_buffer = storage_type(partition_size)
-
-            dtype = storage_param_buffer.dtype
-            device = storage_param_buffer.device
-
+            dtype = val["dtype"]
             # bind storage to buffer tensor
             storage_param = torch.nn.Parameter(
-                torch.tensor([], dtype=dtype, device=device).set_(storage_param_buffer)
+                torch.empty(partition_size, dtype=dtype, device="cuda")
             )
             if val["requires_grad"]:
                 storage_param.requires_grad_(True)
@@ -223,23 +191,12 @@ def init_param_storage(self):
                 to_offset_end = offset_end + param_st - storage_st
 
                 # copy to buffer
-                # PyTorch 1.11 changed the API of storage.__getitem__
-                d_dtype = self._storage_params[kw_name].dtype
-                d_device = self._storage_params[kw_name].device
-                param.data = torch.tensor(
-                    [], dtype=param.dtype, device=param.device
-                ).set_(
-                    self._storage_params[kw_name].storage(),
-                    to_offset_st,
-                    (to_offset_end - to_offset_st,),
-                )
+                param.data = self._storage_params[kw_name].view(-1)[to_offset_st:to_offset_end]
                 self._param_info[-1]["begin"] = to_offset_st
                 self._param_info[-1]["end"] = (to_offset_end - to_offset_st,)
                 setattr(param, "_start_partition", offset_st)
                 setattr(param, "_end_partition", offset_end)
-                param.data[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_(
-                    contiguous_param.storage(), offset_st, (offset_end - offset_st,)
-                )[:]
+                param.data[:] = contiguous_param.view(-1)[offset_st:offset_end]
                 del contiguous_param
             else:
                 param.data = torch.tensor([], dtype=param.dtype, device=param.device)
@@ -424,18 +381,10 @@ def _load_from_state_dict(
                 to_offset_end = offset_end + param_st - storage_st
 
                 # copy to buffer
-                # PyTorch 1.11 changed the API of storage.__getitem__
-                d_dtype = self._storage_params[kw_name].dtype
-                d_device = self._storage_params[kw_name].device
-                torch.tensor([], dtype=d_dtype, device=d_device).set_(
-                    self._storage_params[kw_name].storage(),
-                    to_offset_st,
-                    (to_offset_end - to_offset_st,),
-                )[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_(
-                    contiguous_param.storage(), offset_st, (offset_end - offset_st,)
-                )[
-                    :
-                ]
+                with torch.no_grad():
+                    self._storage_params[kw_name].data.view(-1)[
+                        to_offset_st:to_offset_end
+                    ].copy_(contiguous_param.view(-1)[offset_st:offset_end])
                 del contiguous_param
             elif strict:
                 missing_keys.append(key)
@@ -549,12 +498,7 @@ def init_parameters(self):
                 assert offset_st < offset_end
 
                 # copy to buffer
-                # PyTorch 1.11 changed the API of storage.__getitem__
-                d_dtype = self._storage_params[kw_name].dtype
-                d_device = self._storage_params[kw_name].device
-                param.data[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_(
-                    tmp_tensor.storage(), offset_st, (offset_end - offset_st,)
-                )[:]
+                param.data[:] = tmp_tensor.view(-1)[offset_st:offset_end]
                 del tmp_tensor
 
     def _named_members(self, get_members_fn, prefix="", recurse=True, **kwargs):
diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py
index d1b489e2..e5f91586 100644
--- a/bmtrain/distributed/ops.py
+++ b/bmtrain/distributed/ops.py
@@ -20,12 +20,12 @@
 ]
 
 def send_activations(hidden_state, next_rank, comm):
     send_meta(hidden_state, next_rank, comm)
-    ncclSend(hidden_state.storage(), next_rank, comm)
+    ncclSend(hidden_state.contiguous().view(-1), next_rank, comm)
 
 def recv_activations(prev_rank, comm):
     dtype, shape = recv_meta(prev_rank, comm)
     hidden_state = torch.empty(shape, dtype=dtype, device="cuda")
-    ncclRecv(hidden_state.storage(), prev_rank, comm)
+    ncclRecv(hidden_state.view(-1), prev_rank, comm)
     return hidden_state
 
 def send_meta(x, next_rank, comm):
@@ -34,11 +34,11 @@ def send_meta(x, next_rank, comm):
     meta_data[1] = DTYPE_LIST.index(x.dtype)
     meta_data[2:len(x.size())+2] = torch.tensor(x.size(), device="cuda", dtype=torch.int)
     meta_data = meta_data.contiguous()
-    ncclSend(meta_data.storage(), next_rank, comm)
+    ncclSend(meta_data, next_rank, comm)
 
 def recv_meta(prev_rank, comm):
     meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int)
-    ncclRecv(meta_data.storage(), prev_rank, comm)
+    ncclRecv(meta_data, prev_rank, comm)
     n_dims = meta_data[0].item()
     dtype = DTYPE_LIST[meta_data[1].item()]
     shape = meta_data[2:n_dims+2].tolist()
@@ -52,7 +52,7 @@ def forward(ctx, src, root, comm = None):
             comm = config["comm"]
         ctx.comm = comm
         outputs = torch.empty_like(src, dtype = src.dtype, device = src.device)
-        ncclBroadcast(src.storage(), outputs.storage(), root, comm)
+        ncclBroadcast(src.contiguous().view(-1), outputs.view(-1), root, comm)
         return outputs
 
     @staticmethod
@@ -74,13 +74,14 @@ def forward(ctx, input : torch.Tensor, comm = None):
         world_size = commCount(comm)
         if not input.is_contiguous():
             input = input.contiguous()
-        if input.storage_offset() != 0 or input.storage().size() != input.numel():
+        # Clone if storage_offset != 0 so data_ptr points to the start of the tensor data.
+        if input.storage_offset() != 0:
             input = input.clone()
         output = torch.empty( (world_size,) + input.size(), dtype=input.dtype, device=input.device)
         ctx.comm = comm
         ncclAllGather(
-            input.storage(),
-            output.storage(),
+            input.view(-1),
+            output.view(-1),
             comm
         )
         return output
@@ -115,13 +116,14 @@ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None):
         assert input.shape[0] % commCount(comm) == 0, "The dimension 0 must be divisible by the number of communication processes"
         if not input.is_contiguous():
            input = input.contiguous()
-        if input.storage_offset() != 0 or input.storage().size() != input.numel():
+        # Ensure data_ptr starts at offset 0 for NCCL.
+        if input.storage_offset() != 0:
             input = input.clone()
         output_shape = (input.shape[0] // commCount(comm), *input.shape[1:])
         output = torch.empty( output_shape, dtype=input.dtype, device=input.device )
         ncclReduceScatter(
-            input.storage(),
-            output.storage(),
+            input.view(-1),
+            output.view(-1),
             op,
             comm
         )
@@ -171,13 +173,14 @@ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None):
         ctx.comm = comm
         if not input.is_contiguous():
             input = input.contiguous()
-        if input.storage_offset() != 0 or input.storage().size() != input.numel():
+        # Ensure data_ptr starts at offset 0 for NCCL.
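+        # (clone() yields an offset-0 contiguous copy, so NCCL reads exactly numel() elements from data_ptr().)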
+        if input.storage_offset() != 0:
             input = input.clone()
         output = torch.empty( input.size(), dtype=input.dtype, device=input.device)
         ncclAllReduce(
-            input.storage(),
-            output.storage(),
+            input.view(-1),
+            output.view(-1),
             op,
             comm
         )
diff --git a/bmtrain/inspect/model.py b/bmtrain/inspect/model.py
index fc54f0d6..e11f8fdb 100644
--- a/bmtrain/inspect/model.py
+++ b/bmtrain/inspect/model.py
@@ -10,25 +10,24 @@
 def _gather_value(value : torch.Tensor, partition_size, origin_size):
     global_size = partition_size * config['world_size']
 
-    storage = value.storage_type()(global_size)
+    global_buffer = torch.empty(global_size, dtype=value.dtype, device=value.device)
 
-    if value.storage().size() != partition_size:
+    if value.numel() != partition_size:
         tmp_buf = torch.zeros(partition_size, dtype=value.dtype, device=value.device)
         tmp_buf[:value.numel()] = value[:]
         nccl.allGather(
-            tmp_buf.storage(),
-            storage,
+            tmp_buf,
+            global_buffer,
             config['comm']
         )
     else:
         nccl.allGather(
-            value.storage(),
-            storage,
+            value,
+            global_buffer,
             config['comm']
         )
 
-    output_tensor = torch.tensor([], dtype=value.dtype, device="cuda")
-    output_tensor.set_(storage, 0, origin_size)
+    output_tensor = global_buffer[:origin_size.numel()].view(origin_size)
 
     return output_tensor
 
@@ -52,22 +51,27 @@ def inspect_pipeline_transformer_block_list(pipe_model: PipelineTransformerBlockList, param_name : str, prefix : str = ''):
         _param_buffer = {}
         _grad_buffer = {}
         for kw, val in model._storage_info.items():
-            storage_type = model._storage_params[kw].storage_type()
-
-            _param_buffer[kw] = storage_type(val["partition_size"] * val['world_size'])
-            if model._storage_params[kw].grad is not None:
-                _grad_buffer[kw] = storage_type(val["partition_size"] * val['world_size'])
+            local_param = model._storage_params[kw]
+            _param_buffer[kw] = torch.empty(
+                val["partition_size"] * val['world_size'],
+                dtype=local_param.dtype, device=local_param.device
+            )
+            if local_param.grad is not None:
+                _grad_buffer[kw] = torch.empty(
+                    val["partition_size"] * val['world_size'],
+                    dtype=local_param.dtype, device=local_param.device
+                )
         nccl.groupStart()
         for kw, val in model._storage_info.items():
             nccl.allGather(
-                model._storage_params[kw].storage(),
+                model._storage_params[kw],
                 _param_buffer[kw],
                 val["zero_comm"]
             )
             if model._storage_params[kw].grad is not None:
                 nccl.allGather(
-                    model._storage_params[kw].grad.storage(),
+                    model._storage_params[kw].grad,
                     _grad_buffer[kw],
                     val["zero_comm"]
                 )
@@ -77,13 +81,12 @@
             abs_name = prefix + param["name"]
             if fnmatch.fnmatch(abs_name, param_name):
                 kw_name = param["kw_name"]
-                dtype = _param_buffer[kw_name].dtype
-                device = _param_buffer[kw_name].device
                 offset = param["offset"]
                 shape = param["shape"]
-                p = torch.tensor([], dtype=dtype, device=device).set_(_param_buffer[kw_name], offset, shape)
+                numel = shape.numel()
+                p = _param_buffer[kw_name][offset:offset+numel].view(shape)
                 if kw_name in _grad_buffer:
-                    g = torch.tensor([], dtype=dtype, device=device).set_(_grad_buffer[kw_name], offset, shape)
+                    g = _grad_buffer[kw_name][offset:offset+numel].view(shape)
                 info = {
                     "name": abs_name,
                     "shape": tuple(shape),
@@ -131,22 +134,27 @@ def inspect_block(model : Block, param_name : str, prefix : str = ''):
     _param_buffer = {}
     _grad_buffer = {}
     for kw, val in model._storage_info.items():
-        storage_type = model._storage_params[kw].storage_type()
-
-        _param_buffer[kw] = storage_type(val["partition_size"] * config['world_size'])
-        if model._storage_params[kw].grad is not None:
-            _grad_buffer[kw] = storage_type(val["partition_size"] * config['world_size'])
+        local_param = model._storage_params[kw]
+        _param_buffer[kw] = torch.empty(
+            val["partition_size"] * config['world_size'],
+            dtype=local_param.dtype, device=local_param.device
+        )
+        if local_param.grad is not None:
+            _grad_buffer[kw] = torch.empty(
+                val["partition_size"] * config['world_size'],
+                dtype=local_param.dtype, device=local_param.device
+            )
     nccl.groupStart()
     for kw, val in model._storage_info.items():
         nccl.allGather(
-            model._storage_params[kw].storage(),
+            model._storage_params[kw],
             _param_buffer[kw],
             config["comm"]
         )
         if model._storage_params[kw].grad is not None:
             nccl.allGather(
-                model._storage_params[kw].grad.storage(),
+                model._storage_params[kw].grad,
                 _grad_buffer[kw],
                 config["comm"]
             )
@@ -157,13 +165,12 @@ def inspect_block(model : Block, param_name : str, prefix : str = ''):
         abs_name = prefix + param["name"]
         if fnmatch.fnmatch(abs_name, param_name):
             kw_name = param["kw_name"]
-            dtype = _param_buffer[kw_name].dtype
-            device = _param_buffer[kw_name].device
             offset = param["offset"]
             shape = param["shape"]
-            p = torch.tensor([], dtype=dtype, device=device).set_(_param_buffer[kw_name], offset, shape)
+            numel = shape.numel()
+            p = _param_buffer[kw_name][offset:offset+numel].view(shape)
             if kw_name in _grad_buffer:
-                g = torch.tensor([], dtype=dtype, device=device).set_(_grad_buffer[kw_name], offset, shape)
+                g = _grad_buffer[kw_name][offset:offset+numel].view(shape)
             ret.append({
                 "name": abs_name,
                 "shape": tuple(shape),
@@ -217,7 +224,7 @@ def inspect_model(model : torch.nn.Module, param_name : str, prefix : str = ''):
     for name, param in model._parameters.items():
         if fnmatch.fnmatch(prefix + name, param_name):
             if isinstance(param, DistributedParameter):
-                p = _gather_value(param.data, param.storage().size(), param._original_shape)
+                p = _gather_value(param.data, param._partition_size, param._original_shape)
             else:
                 p = param
             if p is None:
@@ -232,7 +239,7 @@ def inspect_model(model : torch.nn.Module, param_name : str, prefix : str = ''):
             }
             if param.grad is not None:
                 if isinstance(param, DistributedParameter):
-                    g = _gather_value(param.grad.data, param.storage().size(), param._original_shape)
+                    g = _gather_value(param.grad.data, param._partition_size, param._original_shape)
                 else:
                     g = param.grad
                 stats["grad_std"] = g.std().cpu().item()
diff --git a/bmtrain/inspect/tensor.py b/bmtrain/inspect/tensor.py
index 9d003f82..19281fb5 100644
--- a/bmtrain/inspect/tensor.py
+++ b/bmtrain/inspect/tensor.py
@@ -229,7 +229,7 @@ def get_summary(self):
                 info = torch.empty(2, dtype=x.dtype, device=x.device)
                 info[0] = x.mean()
                 info[1] = x.var()
-                nccl.allReduce(info.storage(), info.storage(), "sum", comm)
+                nccl.allReduce(info, info, "sum", comm)
                 info = info / nccl.commCount(comm)
                 x_mean = info[0].cpu().item()
                 x_std = math.sqrt(info[1].cpu().item())
@@ -242,7 +242,7 @@ def get_summary(self):
                 info[1] = x.var()
                 info[2] = x.grad.mean()
                 info[3] = x.grad.var()
-                nccl.allReduce(info.storage(), info.storage(), "sum", comm)
+                nccl.allReduce(info, info, "sum", comm)
                 info = info / nccl.commCount(comm)
                 x_mean = info[0].cpu().item()
                 x_std = math.sqrt(info[1].cpu().item())
@@ -251,7 +251,7 @@ def get_summary(self):
 
                 info[0] = x.max()
                 info[1] = -x.min()
-                nccl.allReduce(info.storage(), info.storage(), "max", comm)
+                nccl.allReduce(info, info, "max", comm)
                 x_max = info[0].cpu().item()
                 x_min = -info[1].cpu().item()
 
diff --git a/bmtrain/nccl/__init__.py b/bmtrain/nccl/__init__.py
index 0f4129d5..c339e6b6 100644
--- a/bmtrain/nccl/__init__.py
+++ b/bmtrain/nccl/__init__.py
@@ -101,34 +101,34 @@ def commRank(comm : NCCLCommunicator):
     """
     return C.ncclCommUserRank(comm.ptr)
 
 def allReduce(
-        src : torch.storage._StorageBase,
-        dst : torch.storage._StorageBase,
+        src : torch.Tensor,
+        dst : torch.Tensor,
         op : Literal["sum", "prod", "max", "min", "avg"],
         comm : NCCLCommunicator
     ):
     """NCCL API: `ncclAllReduce `_
 
     Args:
-        src (torch.storage._StorageBase): Source buffer.
-        dst (torch.storage._StorageBase): Destination buffer.
+        src (torch.Tensor): Source tensor.
+        dst (torch.Tensor): Destination tensor.
         op (Literal["sum", "prod", "max", "min", "avg"]): Reduction operation.
         comm (NCCLCommunicator): NCCL communicator.
 
-    The src and dst buffers must be the same size, type and on the same device.
+    The src and dst tensors must be the same size, type and on the same device.
 
-    If src == dst, the operation is performed in-place.
+    If src is dst, the operation is performed in-place.
 
     """
-    assert src.dtype == dst.dtype, "send and recv buffers must be the same time"
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
     assert src.is_cuda and dst.is_cuda
 
     sendbuff = src.data_ptr()
     recvbuff = dst.data_ptr()
-    count = src.size()
+    count = src.numel()
     datatype = dtype2nccl(src.dtype)
     operator = op2nccl(op)
 
-    assert src.size() == dst.size(), "Buffer size not aligned"
+    assert src.numel() == dst.numel(), "Buffer size not aligned"
     C.ncclAllReduce(
         sendbuff,
         recvbuff,
@@ -138,20 +138,20 @@ def allReduce(
         comm.ptr,
         torch.cuda.current_stream().cuda_stream
     )
 
-def send(src : torch.storage._StorageBase,
+def send(src : torch.Tensor,
         peer : int,
         comm : NCCLCommunicator
     ):
     """NCCL API: `ncclsend `_
 
     Args:
-        src (torch.storage._StorageBase): Source buffer.
+        src (torch.Tensor): Source tensor.
         peer (int): rank peer needs to call ncclRecv
         comm (NCCLCommunicator): NCCL communicator.
 
     """
     sendbuff = src.data_ptr()
-    count = src.size()
+    count = src.numel()
     datatype = dtype2nccl(src.dtype)
     C.ncclSend(
         sendbuff,
@@ -161,12 +161,12 @@ def send(src : torch.storage._StorageBase,
         comm.ptr,
         torch.cuda.current_stream().cuda_stream
     )
 
-def recv(dst : torch.storage._StorageBase,
+def recv(dst : torch.Tensor,
         peer : int,
         comm : NCCLCommunicator
     ):
     recvbuff = dst.data_ptr()
-    count = dst.size()
+    count = dst.numel()
     datatype = dtype2nccl(dst.dtype)
     C.ncclRecv(
         recvbuff,
@@ -178,34 +178,34 @@
     )
 
 def broadcast(
-        src : torch.storage._StorageBase,
-        dst : torch.storage._StorageBase,
+        src : torch.Tensor,
+        dst : torch.Tensor,
         root : int,
         comm : NCCLCommunicator
     ):
     """NCCL API: `ncclBroadcast `_
 
     Args:
-        src (torch.storage._StorageBase): Source buffer.
-        dst (torch.storage._StorageBase): Destination buffer.
+        src (torch.Tensor): Source tensor.
+        dst (torch.Tensor): Destination tensor.
         root (int): Rank of the root.
         comm (NCCLCommunicator): NCCL communicator.
 
-    The src and dst buffers must be the same size, type and on the same device.
+    The src and dst tensors must be the same size, type and on the same device.
 
-    If src == dst, the operation is performed in-place.
+    If src is dst, the operation is performed in-place.
""" - assert src.dtype == dst.dtype, "send and recv buffers must be the same time" + assert src.dtype == dst.dtype, "send and recv buffers must be the same type" assert src.is_cuda and dst.is_cuda sendbuff = src.data_ptr() recvbuff = dst.data_ptr() - count = src.size() + count = src.numel() datatype = dtype2nccl(src.dtype) - assert dst.size() == src.size(), "Buffer size not aligned" + assert dst.numel() == src.numel(), "Buffer size not aligned" C.ncclBroadcast( sendbuff, recvbuff, @@ -217,8 +217,8 @@ def broadcast( ) def reduce( - src : torch.storage._StorageBase, - dst : torch.storage._StorageBase, + src : torch.Tensor, + dst : torch.Tensor, op : Literal["sum", "prod", "max", "min", "avg"], root : int, comm : NCCLCommunicator @@ -226,54 +226,52 @@ def reduce( """NCCL API: `ncclReduce `_ Args: - src (torch.storage._StorageBase): Source buffer. - dst (torch.storage._StorageBase): Destination buffer. + src (torch.Tensor): Source tensor. + dst (torch.Tensor): Destination tensor. op (Literal["sum", "prod", "max", "min", "avg"]): Reduction operation. root (int): Rank of the root. comm (NCCLCommunicator): NCCL communicator. - The src and dst buffers must be the same size, type and on the same device. + The src and dst tensors must be the same size, type and on the same device. - If src == dst, the operation is performed in-place. + If src is dst, the operation is performed in-place. """ - assert src.dtype == dst.dtype, "send and recv buffers must be the same time" + assert src.dtype == dst.dtype, "send and recv buffers must be the same type" assert src.is_cuda and dst.is_cuda sendbuff = src.data_ptr() recvbuff = dst.data_ptr() - count = src.size() + count = src.numel() datatype = dtype2nccl(src.dtype) operator = op2nccl(op) - assert dst.size() == src.size(), "Buffer size not aligned" + assert dst.numel() == src.numel(), "Buffer size not aligned" C.ncclReduce(sendbuff, recvbuff, count, datatype, operator, root, comm.ptr, torch.cuda.current_stream().cuda_stream) def allGather( - src : torch.storage._StorageBase, - dst : torch.storage._StorageBase, + src : torch.Tensor, + dst : torch.Tensor, comm : NCCLCommunicator ): """NCCL API: `ncclAllGather `_ Args: - src (torch.storage._StorageBase): Source buffer. - dst (torch.storage._StorageBase): Destination buffer. + src (torch.Tensor): Source tensor. + dst (torch.Tensor): Destination tensor. comm (NCCLCommunicator): NCCL communicator. - The size of the dst buffer must be equal to the size of src buffer * world_size. - - The dst buffer is only used on rank root. + The size of the dst tensor must be equal to the size of src tensor * world_size. """ - assert src.dtype == dst.dtype, "send and recv buffers must be the same time" + assert src.dtype == dst.dtype, "send and recv buffers must be the same type" assert src.is_cuda and dst.is_cuda sendbuff = src.data_ptr() recvbuff = dst.data_ptr() - sendcount = src.size() + sendcount = src.numel() datatype = dtype2nccl(src.dtype) - assert dst.size() % sendcount == 0, "Buffer size not aligned" + assert dst.numel() % sendcount == 0, "Buffer size not aligned" C.ncclAllGather( sendbuff, recvbuff, @@ -285,34 +283,34 @@ def allGather( def reduceScatter( - src : torch.storage._StorageBase, - dst : torch.storage._StorageBase, + src : torch.Tensor, + dst : torch.Tensor, op : Literal["sum", "prod", "max", "min", "avg"], comm : NCCLCommunicator ): """NCCL API: `ncclReduceScatter `_ Args: - src (torch.storage._StorageBase): Source buffer. - dst (torch.storage._StorageBase): Destination buffer. 
+        src (torch.Tensor): Source tensor.
+        dst (torch.Tensor): Destination tensor.
         comm (NCCLCommunicator): NCCL communicator.
 
-    The size of the dst buffer must be equal to the size of src buffer * world_size.
-
-    The dst buffer is only used on rank root.
+    The size of the dst tensor must be equal to the size of src tensor * world_size.
 
     """
-    assert src.dtype == dst.dtype, "send and recv buffers must be the same time"
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
     assert src.is_cuda and dst.is_cuda
 
     sendbuff = src.data_ptr()
     recvbuff = dst.data_ptr()
-    sendcount = src.size()
+    sendcount = src.numel()
     datatype = dtype2nccl(src.dtype)
 
-    assert dst.size() % sendcount == 0, "Buffer size not aligned"
+    assert dst.numel() % sendcount == 0, "Buffer size not aligned"
     C.ncclAllGather(
         sendbuff,
         recvbuff,
@@ -285,34 +283,34 @@ def allGather(
     )
 
 
 def reduceScatter(
-        src : torch.storage._StorageBase,
-        dst : torch.storage._StorageBase,
+        src : torch.Tensor,
+        dst : torch.Tensor,
         op : Literal["sum", "prod", "max", "min", "avg"],
         comm : NCCLCommunicator
     ):
     """NCCL API: `ncclReduceScatter `_
 
     Args:
-        src (torch.storage._StorageBase): Source buffer.
-        dst (torch.storage._StorageBase): Destination buffer.
+        src (torch.Tensor): Source tensor.
+        dst (torch.Tensor): Destination tensor.
         op (Literal["sum", "prod", "max", "min", "avg"]): Reduction operation.
         comm (NCCLCommunicator): NCCL communicator.
 
-    The size of the dst buffer must be equal to the size of src buffer / world_size.
+    The size of the dst tensor must be equal to the size of src tensor / world_size.
 
-    The dst buffer on rank `i` will contail the i-th block of the reduced result.
+    The dst tensor on rank `i` will contain the i-th block of the reduced result.
 
     """
-    assert src.dtype == dst.dtype, "send and recv buffers must be the same time"
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
     assert src.is_cuda and dst.is_cuda
 
     sendbuff = src.data_ptr()
     recvbuff = dst.data_ptr()
-    recvcount = dst.size()
+    recvcount = dst.numel()
     datatype = dtype2nccl(src.dtype)
     operator = op2nccl(op)
 
-    assert src.size() % recvcount == 0, "Buffer size not aligned"
+    assert src.numel() % recvcount == 0, "Buffer size not aligned"
     C.ncclReduceScatter(
         sendbuff,
         recvbuff,
diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py
index 60fed663..6790b2a6 100644
--- a/bmtrain/nn/__init__.py
+++ b/bmtrain/nn/__init__.py
@@ -2,4 +2,5 @@
 from .column_parallel_linear import ColumnParallelLinear
 from .row_parallel_linear import RowParallelLinear
 from .parallel_embedding import VPEmbedding
+from .parallel_projection import Projection, VPProjection
 from .parallel_linear_func import OpParallelLinear
diff --git a/bmtrain/nn/parallel_embedding.py b/bmtrain/nn/parallel_embedding.py
index 3bdc4e56..3f721bda 100644
--- a/bmtrain/nn/parallel_embedding.py
+++ b/bmtrain/nn/parallel_embedding.py
@@ -51,9 +51,10 @@ def forward(self, x: torch.Tensor, projection=False):
             out = F.embedding(x, weight)
             return out
         else:
-            x = bmt.distributed.all_gather(x, comm=bmt.config["tp_comm"]).view(
-                x.shape[0], -1, x.shape[-1]
-            )
+            # Same as VPProjection: gather TP shards, merge to full hidden, then linear to this vocab partition.
+            shape = x.shape
+            g = bmt.distributed.all_gather(x, comm=bmt.config["tp_comm"])
+            x = g.permute(1, 0, *range(2, g.ndim)).reshape(*shape[:-1], -1)
             return bmt.nn.OpParallelLinear.apply(
                 x, self.weight, None, False, False, False, None, 1
             )
diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py
index e389cde6..c7fa04dd 100644
--- a/bmtrain/nn/parallel_linear_func.py
+++ b/bmtrain/nn/parallel_linear_func.py
@@ -91,7 +91,7 @@ def async_reduce_scatter_linear_func(input, weight, bias, async_chunks=2):
             shape[0] = shape[0] // config["tp_size"]
             outputs[i] = torch.empty(shape, dtype=out.dtype, device=out.device)
             nccl.reduceScatter(
-                out.storage(), outputs[i].storage(), "sum", config["tp_comm"]
+                out.contiguous().view(-1), outputs[i].view(-1), "sum", config["tp_comm"]
             )
 
     current_stream.wait_stream(comm_stream)
@@ -254,7 +254,7 @@ def forward(
             return out
 
         if reduce_output_type == ReduceType.ALL_REDUCE:
-            nccl.allReduce(out.storage(), out.storage(), "sum", config["tp_comm"])
+            nccl.allReduce(out.contiguous().view(-1), out.view(-1), "sum", config["tp_comm"])
             return out
         else:
             assert False, "no support reduce type{}".format(reduce_output_type)
@@ -309,8 +309,8 @@ def backward(ctx, grad_output):
                     grad_input.record_stream(config["tp_comm_stream"])
                     grad_all_input.record_stream(config["tp_comm_stream"])
                     nccl.reduceScatter(
-                        grad_all_input.storage(),
-                        grad_input.storage(),
+                        grad_all_input.contiguous().view(-1),
+                        grad_input.view(-1),
                         "sum",
                         config["tp_comm"],
                     )
@@ -319,8 +319,8 @@ def backward(ctx, grad_output):
                 config["tp_comm_stream"].wait_stream(current_stream)
                 grad_input.record_stream(config["tp_comm_stream"])
                 nccl.allReduce(
-                    grad_all_input.storage(),
-                    grad_all_input.storage(),
+                    grad_all_input.contiguous().view(-1),
+                    grad_all_input.view(-1),
                     "sum",
                     config["tp_comm"],
                 )
diff --git a/bmtrain/nn/parallel_projection.py b/bmtrain/nn/parallel_projection.py
new file mode 100644
index 00000000..4883b310
--- /dev/null
+++ b/bmtrain/nn/parallel_projection.py
@@ -0,0 +1,76 @@
+import torch
+import bmtrain as bmt
+from bmtrain.global_var import config
+from .linear import OpLinear
+from .parallel_linear_func import OpParallelLinear
+
+
+class Projection(bmt.DistributedModule):
+    """Output projection: linear map from hidden size to full vocabulary (reference / non-TP).
+
+    Args:
+        vocab_size: number of classes / vocabulary size.
+        embedding_size: hidden dimension (input features).
+        dtype: parameter dtype.
+        init_mean, init_std: arguments for :func:`torch.nn.init.normal_` on the full weight matrix.
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        embedding_size: int,
+        dtype: torch.dtype = torch.half,
+        init_mean: float = 0.0,
+        init_std: float = 1.0,
+    ):
+        super().__init__()
+        self.dim_model = embedding_size
+        self.vocab_size = vocab_size
+        self.weight = bmt.DistributedParameter(
+            torch.empty(vocab_size, embedding_size, dtype=dtype),
+            init_method=bmt.ParameterInitializer(
+                torch.nn.init.normal_, mean=init_mean, std=init_std
+            ),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return OpLinear.apply(x, self.weight, None)
+
+
+class VPProjection(bmt.DistributedModule):
+    """Vocabulary-parallel output projection (weight sharded on dim 0).
+
+    Each rank accepts a slice of the hidden dimension; tensors are gathered before linear.
+    Matches :class:`VPEmbedding` forward with ``projection=True``.
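+
+    Example (a sketch; assumes ``bmt.init_distributed()`` ran with ``tp_size > 1`` and the shapes are illustrative)::
+
+        proj = VPProjection(vocab_size=32000, embedding_size=4096)
+        logits = proj(hidden)   # logits for this rank's vocabulary partition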
+ """ + + def __init__( + self, + vocab_size: int, + embedding_size: int, + dtype: torch.dtype = torch.half, + init_mean: float = 0.0, + init_std: float = 1.0, + ): + super().__init__() + assert vocab_size % config["tp_size"] == 0 + self.dim_model = embedding_size + self.vocab_size_per_partition = vocab_size // config["tp_size"] + self.weight = bmt.DistributedParameter( + torch.empty(self.vocab_size_per_partition, embedding_size, dtype=dtype), + init_method=bmt.ParameterInitializer( + torch.nn.init.normal_, mean=init_mean, std=init_std + ), + tp_split_dim=0, + tp_mode=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Each rank holds a slice of the last (hidden) dim. all_gather returns (tp_size, *x.shape); + # permute so batch leads, then reshape to full embedding_size on the last dim for F.linear. + shape = x.shape + g = bmt.distributed.all_gather(x, comm=config["tp_comm"]) + x = g.permute(1, 0, *range(2, g.ndim)).reshape(*shape[:-1], -1) + return OpParallelLinear.apply( + x, self.weight, None, False, False, False, None, 1 + ) diff --git a/bmtrain/optim/adam.py b/bmtrain/optim/adam.py index f99c483c..2eddc949 100644 --- a/bmtrain/optim/adam.py +++ b/bmtrain/optim/adam.py @@ -1,16 +1,48 @@ import torch from ..global_var import config from . import _function as F -import torch.optim._functional +# torch.optim._functional was removed in some PyTorch versions; fall back to manual Adam if unavailable. +try: + import torch.optim._functional as _optim_functional + _has_torch_adam = True +except ImportError: + _has_torch_adam = False from .. import C from .. import nccl import inspect -from ..utils import check_torch_version from copy import deepcopy from itertools import chain from collections import defaultdict +def _functional_adam_state_step(step) -> torch.Tensor: + """Scalar step tensor for torch.optim._functional.adam. + + PyTorch 2.x _multi_tensor_adam groups state tensors with params; step is only + exempted when it is CPU float32/float64. torch.tensor(int) yields int64 and + triggers RuntimeError in _group_tensors_by_device_and_dtype. + """ + if isinstance(step, torch.Tensor): + return step.detach().to(dtype=torch.float64, device="cpu").reshape(()) + return torch.tensor(float(step), dtype=torch.float64, device="cpu") + + +def _torch_functional_adam_other_kwargs(): + """Build extra kwargs needed by torch.optim._functional.adam across versions.""" + other_kwargs = {} + if not _has_torch_adam: + return other_kwargs + sig = inspect.signature(_optim_functional.adam).parameters + if "maximize" in sig: + other_kwargs["maximize"] = False + # PyTorch 2.4+ defaults foreach=True on CUDA; int64 step used to break grouping + if "foreach" in sig: + other_kwargs["foreach"] = False + if "fused" in sig: + other_kwargs["fused"] = False + return other_kwargs + + class AdamOptimizer(torch.optim.Optimizer): """ Adam optimizer support fp16 and bf16. 
@@ -112,33 +144,39 @@ def step(self, closure=None, scale=1):
                     grad = p.grad
 
                     if p.dtype == torch.float32:
-                        other_kwargs = {}
-                        if (
-                            "maximize"
-                            in inspect.signature(
-                                torch.optim._functional.adam
-                            ).parameters
-                        ):
-                            other_kwargs["maximize"] = False
-                        torch.optim._functional.adam(
-                            [p],
-                            [grad / scale],
-                            [state["exp_avg"]],
-                            [state["exp_avg_sq"]],
-                            [],
-                            (
-                                [state["step"]]
-                                if check_torch_version("1.12.0") < 0
-                                else [torch.tensor(state["step"])]
-                            ),
-                            amsgrad=False,
-                            beta1=group["betas"][0],
-                            beta2=group["betas"][1],
-                            lr=0.0 if state["step"] < self._hold_steps else group["lr"],
-                            weight_decay=group["weight_decay"],
-                            eps=group["eps"],
-                            **other_kwargs
-                        )
+                        if _has_torch_adam:
+                            other_kwargs = _torch_functional_adam_other_kwargs()
+                            _optim_functional.adam(
+                                [p],
+                                [grad / scale],
+                                [state["exp_avg"]],
+                                [state["exp_avg_sq"]],
+                                [],
+                                [_functional_adam_state_step(state["step"])],
+                                amsgrad=False,
+                                beta1=group["betas"][0],
+                                beta2=group["betas"][1],
+                                lr=0.0 if state["step"] < self._hold_steps else group["lr"],
+                                weight_decay=group["weight_decay"],
+                                eps=group["eps"],
+                                **other_kwargs
+                            )
+                        else:
+                            # Fallback: manual Adam when torch.optim._functional is unavailable.
+                            lr = 0.0 if state["step"] < self._hold_steps else group["lr"]
+                            beta1, beta2 = group["betas"]
+                            eps = group["eps"]
+                            weight_decay = group["weight_decay"]
+                            g = grad / scale
+                            if weight_decay != 0:
+                                g = g.add(p, alpha=weight_decay)
+                            state["exp_avg"].mul_(beta1).add_(g, alpha=1 - beta1)
+                            state["exp_avg_sq"].mul_(beta2).addcmul_(g, g, value=1 - beta2)
+                            bias_correction1 = 1 - beta1 ** (state["step"] + 1)
+                            bias_correction2 = 1 - beta2 ** (state["step"] + 1)
+                            step_size = lr / bias_correction1
+                            denom = (state["exp_avg_sq"].sqrt() / (bias_correction2 ** 0.5)).add_(eps)
+                            p.addcdiv_(state["exp_avg"], denom, value=-step_size)
                         state["step"] += 1
                     else:
                         f = F.adam_fp16 if p.dtype == torch.float16 else F.adam_bf16
diff --git a/bmtrain/optim/adam_offload.py b/bmtrain/optim/adam_offload.py
index f6ea97ba..4182a406 100644
--- a/bmtrain/optim/adam_offload.py
+++ b/bmtrain/optim/adam_offload.py
@@ -2,12 +2,17 @@
 from ..global_var import config
 from . import _function as F
 from .. import nccl
-import inspect
-from ..utils import check_torch_version
+# torch.optim._functional was removed in some PyTorch versions; fall back to manual Adam if unavailable.
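+# (The probe mirrors bmtrain/optim/adam.py, which also defines the shared helpers imported below.)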
+try:
+    import torch.optim._functional as _optim_functional
+    _has_torch_adam = True
+except ImportError:
+    _has_torch_adam = False
 from copy import deepcopy
 from itertools import chain
 from collections import defaultdict
 from ._distributed import state_dict_gather
+from .adam import _functional_adam_state_step, _torch_functional_adam_other_kwargs
 
 
 class AdamOffloadOptimizer(torch.optim.Optimizer):
@@ -160,31 +165,36 @@ def step(self, closure=None, scale=1):
                         grad = -state["_grad_fp32"]
                     else:
                         grad = state["_grad_fp32"]
-                    other_kwargs = {}
-                    if (
-                        "maximize"
-                        in inspect.signature(torch.optim._functional.adam).parameters
-                    ):
-                        other_kwargs["maximize"] = False
-                    torch.optim._functional.adam(
-                        [state["_param_fp32"]],
-                        [grad],
-                        [state["exp_avg"]],
-                        [state["exp_avg_sq"]],
-                        [],
-                        (
-                            [state["step"]]
-                            if check_torch_version("1.12.0") < 0
-                            else [torch.tensor(state["step"])]
-                        ),
-                        amsgrad=False,
-                        beta1=beta1,
-                        beta2=beta2,
-                        lr=0.0 if state["step"] < self._hold_steps else lr,
-                        weight_decay=weight_decay,
-                        eps=eps,
-                        **other_kwargs
-                    )
+                    if _has_torch_adam:
+                        other_kwargs = _torch_functional_adam_other_kwargs()
+                        _optim_functional.adam(
+                            [state["_param_fp32"]],
+                            [grad],
+                            [state["exp_avg"]],
+                            [state["exp_avg_sq"]],
+                            [],
+                            [_functional_adam_state_step(state["step"])],
+                            amsgrad=False,
+                            beta1=beta1,
+                            beta2=beta2,
+                            lr=0.0 if state["step"] < self._hold_steps else lr,
+                            weight_decay=weight_decay,
+                            eps=eps,
+                            **other_kwargs
+                        )
+                    else:
+                        # Fallback: manual Adam when torch.optim._functional is unavailable.
+                        actual_lr = 0.0 if state["step"] < self._hold_steps else lr
+                        g = grad.clone()
+                        if weight_decay != 0:
+                            g = g.add(state["_param_fp32"], alpha=weight_decay)
+                        state["exp_avg"].mul_(beta1).add_(g, alpha=1 - beta1)
+                        state["exp_avg_sq"].mul_(beta2).addcmul_(g, g, value=1 - beta2)
+                        bias_correction1 = 1 - beta1 ** (state["step"] + 1)
+                        bias_correction2 = 1 - beta2 ** (state["step"] + 1)
+                        step_size = actual_lr / bias_correction1
+                        denom = (state["exp_avg_sq"].sqrt() / (bias_correction2 ** 0.5)).add_(eps)
+                        state["_param_fp32"].addcdiv_(state["exp_avg"], denom, value=-step_size)
                     # transfer parameters back to device asynchronously
                     param.copy_(state["_param_fp32"], non_blocking=True)
                     state["step"] += 1
diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py
index 1a98ed92..dd6a5c9c 100644
--- a/bmtrain/optim/optim_manager.py
+++ b/bmtrain/optim/optim_manager.py
@@ -7,17 +7,18 @@
 from ..global_var import config
 
 def check_overflow(param_groups):
-    # check overflow
-    has_inf_or_nan = torch.zeros(1, dtype=torch.uint8, device="cuda")[0]
+    # Use a 1-element tensor (not a scalar) so nccl.allReduce receives a valid data_ptr.
+    has_inf_or_nan = torch.zeros(1, dtype=torch.uint8, device="cuda")
     for group in param_groups:
        for p in group['params']:
            if p.grad is not None:
                if p.dtype != torch.float:
-                    has_inf_nan(p.grad, has_inf_or_nan)
+                    # Pass a 1-element slice to keep the tensor contiguous with nonzero numel.
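+                    # (The slice shares storage with has_inf_or_nan, so the flag set here is what allReduce reduces below.)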
+                    has_inf_nan(p.grad, has_inf_or_nan[0:1])
     if "comm" in config:
-        nccl.allReduce(has_inf_or_nan.storage(), has_inf_or_nan.storage(), "max", config["comm"])
+        nccl.allReduce(has_inf_or_nan, has_inf_or_nan, "max", config["comm"])
 
-    if has_inf_or_nan > 0:
+    if has_inf_or_nan[0] > 0:
         raise OverflowError("Gradient overflow")
 
 def grad_rescale(param_groups, scale):
@@ -179,16 +180,17 @@ def clip_grad_norm(self, param_groups, max_norm, norm_type=2, eps=1e-6):
                     grads.append(torch.zeros_like(p.data))
 
        if norm_type == 'inf':
-            total_norm_cuda = max(g.data.abs().max() for g in grads).detach()
-            nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "max", config["comm"])
-            total_norm = total_norm_cuda
+            # unsqueeze to 1-element tensor so nccl.allReduce gets a valid buffer.
+            total_norm_cuda = max(g.data.abs().max() for g in grads).detach().unsqueeze(0)
+            nccl.allReduce(total_norm_cuda, total_norm_cuda, "max", config["comm"])
+            total_norm = total_norm_cuda[0]
        else:
            norm_type = float(norm_type)
-            total_norm_cuda = torch.cuda.FloatTensor([0])
+            total_norm_cuda = torch.tensor([0.0], dtype=torch.float32, device="cuda")
            for index, g in enumerate(grads):
                param_norm = g.data.float().norm(norm_type)
                total_norm_cuda += param_norm ** norm_type
-            nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "sum", config["comm"])
+            nccl.allReduce(total_norm_cuda, total_norm_cuda, "sum", config["comm"])
            total_norm = total_norm_cuda[0] ** (1. / norm_type)
        # total_norm = total_norm / scale
        # clip_coef = float(max_norm) / (total_norm + eps)
@@ -208,8 +210,17 @@ def _justify_scale(self, scale):
         self.steps_since_last_scale = 0
 
     def state_dict(self, gather_opt=False) -> dict:
+        def _optimizer_state_dict(opt):
+            # BMTrain optimizers (e.g. AdamOffloadOptimizer) accept gather=; std PyTorch Optimizer does not.
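+            # Trying the gather signature first and falling back on TypeError keeps both kinds usable.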
+            if gather_opt:
+                try:
+                    return opt.state_dict(gather_opt)
+                except TypeError:
+                    pass
+            return opt.state_dict()
+
         return {
-            "optimizers": [opt.state_dict(gather_opt) for opt in self.optimizers],
+            "optimizers": [_optimizer_state_dict(opt) for opt in self.optimizers],
             "lr_schedulers": [lrs.state_dict() if lrs else None for lrs in self.lr_schedulers],
             "loss_scale": self.loss_scale,
             "loss_scale_enabled": self.loss_scale_enabled,
diff --git a/bmtrain/param_init.py b/bmtrain/param_init.py
index 21f95f25..bbf6f86f 100644
--- a/bmtrain/param_init.py
+++ b/bmtrain/param_init.py
@@ -18,11 +18,9 @@ def init_distributed_parameter(params: Iterable[torch.nn.Parameter]):
         if param._init_method is None:
             continue
         with torch.no_grad():
-            partition_size = param.storage().size()
+            partition_size = param._partition_size
             global_size = partition_size * config["tp_zero_size"] * config["tp_size"]
-            tmp_storage = param.storage_type()(global_size)
-            tmp_tensor = torch.tensor([], dtype=param.dtype, device="cuda")
-            tmp_tensor.set_(tmp_storage, 0, param._tp_original_shape)
+            tmp_tensor = torch.empty(param._tp_original_shape, dtype=param.dtype, device="cuda")
 
             param._init_method(tmp_tensor)
 
             if param._tp_mode and param._tp_split_dim >= 0:
@@ -39,17 +37,10 @@ def init_distributed_parameter(params: Iterable[torch.nn.Parameter]):
                 begin = config["tp_zero_rank"]
             else:
                 begin = config["zero_rank"]
-            end = begin + 1
-
-            # Pytorch 1.11 changed the API of storage.__getitem__
-            torch.tensor([], dtype=param.dtype, device=param.device).set_(
-                param.storage()
-            )[:] = torch.tensor([], dtype=param.dtype, device=param.device).set_(
-                tmp_tensor.storage()
-            )[
-                partition_size * begin : partition_size * end
-            ]
-            # param.storage().copy_(tmp_storage[partition_size * config['rank'] : partition_size * (config['rank'] + 1)])
+
+            tmp_flat = tmp_tensor.view(-1)
+            src_slice = tmp_flat[partition_size * begin : partition_size * (begin + 1)]
+            param.data[:src_slice.numel()] = src_slice[:param.numel()]
 
 
 def iterate_parameters(model: torch.nn.Module):
diff --git a/bmtrain/parameter.py b/bmtrain/parameter.py
index 2dad4a3d..dfbabfd5 100644
--- a/bmtrain/parameter.py
+++ b/bmtrain/parameter.py
@@ -43,7 +43,6 @@ def __new__(
 
         num_of_elements = data.numel()
 
-        cuda_tensor = torch.tensor([], dtype=data.dtype, device="cuda")
         if tp_mode:
             comm = config["tp_zero_comm"]
         else:
@@ -58,16 +57,23 @@ def __new__(
             tp_original_shape = list(original_shape)
             tp_original_shape[tp_split_dim] *= config["tp_size"]
 
-        cuda_storage = cuda_tensor.storage_type()(cuda_storage_size)
+        storage_tensor = torch.empty(cuda_storage_size, dtype=data.dtype, device="cuda")
 
         start_of_partition = cuda_storage_size * rank
         end_of_partition = min(num_of_elements, cuda_storage_size * (rank + 1))
 
-        # FX: cuda_tensor_size < 0 if num_of_elements is too small
         cuda_tensor_size = max(end_of_partition - start_of_partition, 0)
 
-        cuda_tensor.set_(cuda_storage, 0, (cuda_tensor_size,))
-        cuda_tensor.copy_(data.view(-1)[start_of_partition:end_of_partition])
+        cuda_tensor = storage_tensor[:cuda_tensor_size]
+        # Detach, flatten, clone on CUDA so we read plain tensor storage. Copying from
+        # `nn.Parameter(...).view(-1)[...]` alone can yield zeros for the common pattern
+        # `self.weight = nn.Parameter(torch.empty(...)); init_(self.weight)` (see tests).
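+        # The copy runs under no_grad so constructing the parameter records no autograd history.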
+        with torch.no_grad():
+            src = data.detach()
+            if src.device.type != "cuda":
+                src = src.cuda()
+            src_flat = src.reshape(-1).contiguous().clone()
+            cuda_tensor.copy_(src_flat[start_of_partition:end_of_partition])
 
         ret = torch.Tensor._make_subclass(cls, cuda_tensor, requires_grad)
 
         setattr(ret, "_original_shape", original_shape)
@@ -81,6 +87,7 @@ def __new__(
         setattr(ret, "_zero_comm", comm)
         setattr(ret, "_tp_split_dim", tp_split_dim)
         setattr(ret, "_tp_original_shape", tp_original_shape)
+        setattr(ret, "_partition_size", cuda_storage_size)
         return ret
 
     @property
@@ -103,7 +110,7 @@ def gather(self) -> torch.Tensor:
             current_stream.wait_stream(config["load_stream"])
         return output_tensor
 
-    def gather_all(self) -> torch.tensor:
+    def gather_all(self) -> torch.Tensor:
        """Gather the data from ZeRO and Tensor Parallel distributed nodes.
 
        Return:
@@ -124,7 +131,7 @@ def gather_all(self) -> torch.tensor:
        else:
            return zero_param
 
-    def tp_gather(self) -> torch.tensor:
+    def tp_gather(self) -> torch.Tensor:
        """Gather the data from Tensor Parallel distributed nodes.
 
        Return:
@@ -145,28 +152,40 @@ def tp_gather(self) -> torch.tensor:
            return self
 
     def _copy_data(self, data: torch.Tensor):
-        """Copy data to self.data."""
-        self.data.copy_(data.view(-1)[self._start_partition : self._end_partition])
+        """Copy data to self.data.
+
+        Detach → move to CUDA → flatten → clone to ensure we read from a
+        materialized contiguous buffer, avoiding stale-storage issues when the
+        source comes from nn.Parameter wrappers.
+        """
+        with torch.no_grad():
+            src = data.detach()
+            if src.device.type != "cuda":
+                src = src.cuda()
+            flat = src.reshape(-1).contiguous().clone()
+            self.data.copy_(flat[self._start_partition : self._end_partition])
 
 
 class OpAllGather(torch.autograd.Function):
     @staticmethod
     def forward(ctx, value: DistributedParameter):
         assert isinstance(value, DistributedParameter)
-        comm = value._zero_comm  # config['zero_comm']
+        comm = value._zero_comm
         world_size = nccl.commCount(comm)
         ctx.comm = comm
         ctx.world_size = world_size
 
-        partition_size = value.storage().size()
+        partition_size = value._partition_size
         global_size = partition_size * world_size
 
-        storage = value.storage_type()(global_size)
+        global_buffer = torch.empty(global_size, dtype=value.dtype, device="cuda")
+        # Pad local data to partition_size; value.numel() may be smaller on the last rank.
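+        # (Uninitialized pad elements are gathered too, but they fall past _original_shape.numel() and are sliced off.)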
+        local_buffer = torch.empty(partition_size, dtype=value.dtype, device="cuda")
+        local_buffer[:value.numel()].copy_(value)
 
-        nccl.allGather(value.storage(), storage, comm)
+        nccl.allGather(local_buffer, global_buffer, comm)
 
-        output_tensor = torch.tensor([], dtype=value.dtype, device="cuda")
-        output_tensor.set_(storage, 0, value._original_shape)
+        output_tensor = global_buffer[:value._original_shape.numel()].view(value._original_shape)
 
         ctx.partition_size = partition_size
         ctx.tensor_size = value.size(0)
@@ -177,16 +196,20 @@ def backward(ctx, grad_output: torch.Tensor):
         if not grad_output.is_contiguous():
             grad_output = grad_output.contiguous()
 
-        grad_storage = grad_output.storage_type()(ctx.partition_size)
-        grad_output_storage = grad_output.storage()
-        if grad_output_storage.size() == ctx.partition_size * ctx.world_size:
-            pass
-        else:
-            grad_output_storage.resize_(ctx.partition_size * ctx.world_size)
-        nccl.reduceScatter(grad_output_storage, grad_storage, "sum", ctx.comm)
-        grad_tensor = torch.tensor([], dtype=grad_output.dtype, device="cuda")
-        grad_tensor.set_(grad_storage, 0, (ctx.tensor_size,))
-        return grad_tensor
+        expected_size = ctx.partition_size * ctx.world_size
+        # Pad or truncate grad to match the allGather buffer size (partition_size * world_size),
+        # because the original tensor may have been smaller due to non-divisible partitioning.
+        grad_flat = grad_output.reshape(-1)
+        if grad_flat.numel() < expected_size:
+            padded = torch.zeros(expected_size, dtype=grad_output.dtype, device=grad_output.device)
+            padded[:grad_flat.numel()].copy_(grad_flat)
+            grad_flat = padded
+        elif grad_flat.numel() > expected_size:
+            grad_flat = grad_flat[:expected_size].contiguous()
+
+        grad_partition = torch.empty(ctx.partition_size, dtype=grad_output.dtype, device=grad_output.device)
+        nccl.reduceScatter(grad_flat, grad_partition, "sum", ctx.comm)
+        return grad_partition[:ctx.tensor_size]
 
 
 class ParameterInitializer:
diff --git a/bmtrain/store.py b/bmtrain/store.py
index 2a3ee02c..a91462f3 100644
--- a/bmtrain/store.py
+++ b/bmtrain/store.py
@@ -33,7 +33,8 @@ def _save_to_local_rank0(model : torch.nn.Module, destination=None, prefix=''):
     for name, module in model._modules.items():
         if module is not None:
             _save_to_local_rank0(module, destination, prefix + name + '.')
-    for hook in model._state_dict_hooks.values():
+    # _state_dict_hooks may not exist in newer PyTorch versions; use getattr for safety.
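+    # (With the empty-dict default, the hook loop simply becomes a no-op.)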
+    for hook in getattr(model, '_state_dict_hooks', {}).values():
         hook_result = hook(model, destination, prefix, local_metadata)
         if hook_result is not None:
             destination = hook_result
@@ -50,7 +51,7 @@ def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''):
     for name, module in model._modules.items():
         if module is not None:
             _save_to_rank0(module, destination, prefix + name + '.')
-    for hook in model._state_dict_hooks.values():
+    for hook in getattr(model, '_state_dict_hooks', {}).values():
         hook_result = hook(model, destination, prefix, local_metadata)
         if hook_result is not None:
             destination = hook_result
@@ -74,7 +75,7 @@ def _save_to_infer_model(model : torch.nn.Module, infer_model, destination=None, prefix=''):
                 infer_model.load_layer_state_dict(local_state_dict)
         else:
             _save_to_infer_model(module, infer_model, destination, prefix + name + '.')
-    for hook in model._state_dict_hooks.values():
+    for hook in getattr(model, '_state_dict_hooks', {}).values():
         hook_result = hook(model, destination, prefix, local_metadata)
         if hook_result is not None:
             destination = hook_result
@@ -147,8 +148,8 @@ def allgather_objects(obj):
     max_data_length = gathered_length.max().item()
 
     gpu_data_bytes = torch.zeros(max_data_length, dtype=torch.uint8, device="cuda")
-    byte_storage = torch.ByteStorage.from_buffer(data_bytes)
-    gpu_data_bytes[:data_length] = torch.ByteTensor(byte_storage)
+    byte_tensor = torch.frombuffer(bytearray(data_bytes), dtype=torch.uint8)
+    gpu_data_bytes[:data_length] = byte_tensor
 
     gathered_data = bmt.distributed.all_gather(gpu_data_bytes).cpu()
 
@@ -162,38 +163,34 @@ def broadcast_object(obj, comm, src = 0):
     if nccl.commRank(comm) == src:
         f = io.BytesIO()
         _pickler(f).dump(obj)
-        byte_storage = torch.ByteStorage.from_buffer(f.getvalue())
-        # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
-        # Otherwise, it will casue 100X slowdown.
-        # See: https://github.com/pytorch/pytorch/issues/65696
-        byte_tensor = torch.ByteTensor(byte_storage).cuda()
-        local_size = torch.LongTensor([byte_tensor.numel()]).cuda()
+        byte_tensor = torch.frombuffer(bytearray(f.getvalue()), dtype=torch.uint8).cuda()
+        local_size = torch.tensor([byte_tensor.numel()], dtype=torch.long, device="cuda")
         nccl.broadcast(
-            local_size.storage(),
-            local_size.storage(),
+            local_size,
+            local_size,
             src,
             comm
         )
         nccl.broadcast(
-            byte_tensor.storage(),
-            byte_tensor.storage(),
+            byte_tensor,
+            byte_tensor,
             src,
             comm
         )
     else:
-        local_size = torch.LongTensor([0]).cuda()
+        local_size = torch.tensor([0], dtype=torch.long, device="cuda")
         nccl.broadcast(
-            local_size.storage(),
-            local_size.storage(),
+            local_size,
+            local_size,
             src,
             comm
         )
         byte_tensor_size = local_size[0].item()
         byte_tensor = torch.empty(int(byte_tensor_size), dtype=torch.uint8, device="cuda")
         nccl.broadcast(
-            byte_tensor.storage(),
-            byte_tensor.storage(),
+            byte_tensor,
+            byte_tensor,
             src,
             comm
         )
@@ -219,15 +216,15 @@ def broadcast(self):
                 input_param = input_param.cuda().contiguous()
 
                 nccl.broadcast(
-                    input_param.storage(),
-                    output_param.storage(),
+                    input_param.view(-1),
+                    output_param.view(-1),
                     0,
                     config['comm']
                 )
             else:
                 nccl.broadcast(
-                    output_param.storage(),
-                    output_param.storage(),
+                    output_param.view(-1),
+                    output_param.view(-1),
                     0,
                     config['comm']
                 )
@@ -266,8 +263,8 @@ def __getitem__(self, key : str):
         tmp_shape[2:2 + shape_list.size(0)] = shape_list
 
         nccl.broadcast(
-            tmp_shape.storage(),
-            tmp_shape.storage(),
+            tmp_shape,
+            tmp_shape,
             0,
             config['comm']
         )
@@ -313,7 +310,8 @@ def load(model : torch.nn.Module, file_name : str, strict : bool = True):
         >>> bmtrain.load(model, "model.pt", strict=True)
     """
     if config['rank'] == 0:
-        state_dict = DistributedStateDictWrapper(torch.load(file_name))
+        # weights_only=False: BMTrain checkpoints may contain non-tensor objects (e.g. metadata).
+        state_dict = DistributedStateDictWrapper(torch.load(file_name, weights_only=False))
     else:
         state_dict = DistributedStateDictWrapper({})
 
diff --git a/bmtrain/synchronize.py b/bmtrain/synchronize.py
index 87619159..8713ded2 100644
--- a/bmtrain/synchronize.py
+++ b/bmtrain/synchronize.py
@@ -13,8 +13,8 @@ def synchronize():
         raise RuntimeError("BMTrain is not initialized")
 
     with torch.cuda.stream(config["barrier_stream"]):
-        barrier = torch.cuda.FloatTensor([1])
-        nccl.allReduce(barrier.storage(), barrier.storage(), "sum", config["comm"])
+        barrier = torch.tensor([1.0], dtype=torch.float32, device="cuda")
+        nccl.allReduce(barrier, barrier, "sum", config["comm"])
     config["barrier_stream"].synchronize()
 
 
@@ -53,8 +53,8 @@ def gather_result(result: torch.Tensor):
         "bmtrain.gather_result is deprecated and will be removed in later version. Use bmtrain.distributed.all_gather instead.",
         DeprecationWarning,
     )
-    if result.storage_offset() != 0 or result.storage().size() != result.numel():
-        # Create a clone of the original tensor if it's a slice
+    # Clone sliced or non-contiguous tensors so data_ptr is at offset 0 for NCCL.
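+    # (The old storage().size() != numel() condition is obsolete: counts now come from numel(), not storage size.)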
+    if result.storage_offset() != 0 or not result.is_contiguous():
         result = result.clone()
 
     output_cuda = True
@@ -66,7 +66,7 @@ def gather_result(result: torch.Tensor):
         device=result.device,
         dtype=result.dtype,
     )
-    nccl.allGather(result.storage(), ret.storage(), config["comm"])
+    nccl.allGather(result, ret, config["comm"])
     if output_cuda:
         return ret
     else:
diff --git a/bmtrain/utils.py b/bmtrain/utils.py
index daa4c595..b82fb1c6 100644
--- a/bmtrain/utils.py
+++ b/bmtrain/utils.py
@@ -35,8 +35,12 @@ def load_nccl_pypi():
     """
     try:
         import nvidia.nccl
-    except:
-        raise ImportError("Run pip install nvidia-nccl-cu11 >=2.14.3 first")
+    except ImportError:
+        raise ImportError(
+            "NCCL not found. Install the appropriate package:\n"
+            "  pip install nvidia-nccl-cu12>=2.14.3  (for CUDA 12.x)\n"
+            "  pip install nvidia-nccl-cu11>=2.14.3  (for CUDA 11.x)"
+        )
     path = os.path.join(os.path.dirname(nvidia.nccl.__file__), "lib")
     for file_so in os.listdir(path):
diff --git a/bmtrain/zero_context.py b/bmtrain/zero_context.py
index 8a74b3f8..dfa98f1d 100644
--- a/bmtrain/zero_context.py
+++ b/bmtrain/zero_context.py
@@ -20,8 +20,6 @@ def __init__(self, block: "Block", ctx_dict: dict = None, pipe=False) -> None:
         self.ctx_dict = ctx_dict
         self._param_buffer = {}
         self._grad_buffer = {}
-        self._param_tensor = {}
-        self._grad_tensor = {}
         self._need_release = False
 
     def enter(self, flag=0, requires_grad=False):
@@ -41,35 +39,24 @@ def enter(self, flag=0, requires_grad=False):
             assert kw not in self._param_buffer
             local_param = self.block._storage_params[kw]
 
-            storage_type = local_param.storage_type()
             if flag != 2:
-                self._param_buffer[kw] = storage_type(
-                    val["partition_size"] * val["world_size"]
+                self._param_buffer[kw] = torch.empty(
+                    val["partition_size"] * val["world_size"],
+                    dtype=local_param.dtype,
+                    device=local_param.device,
                 )
-                self._param_tensor[kw] = torch.tensor(
-                    [],
-                    dtype=self._param_buffer[kw].dtype,
-                    device=self._param_buffer[kw].device,
-                ).set_(self._param_buffer[kw])
 
             if requires_grad and local_param.requires_grad:
-                self._grad_buffer[kw] = storage_type(
-                    val["partition_size"] * val["world_size"]
-                )
-                self._grad_tensor[kw] = (
-                    torch.tensor(
-                        [],
-                        dtype=self._grad_buffer[kw].dtype,
-                        device=self._grad_buffer[kw].device,
-                    )
-                    .set_(self._grad_buffer[kw])
-                    .zero_()
+                self._grad_buffer[kw] = torch.zeros(
+                    val["partition_size"] * val["world_size"],
+                    dtype=local_param.dtype,
+                    device=local_param.device,
                 )
         if flag != 2:
             nccl.groupStart()
             for kw, val in self.block._storage_info.items():
                 nccl.allGather(
-                    self.block._storage_params[kw].storage(),
+                    self.block._storage_params[kw],
                     self._param_buffer[kw],
                     val["zero_comm"],
                 )
@@ -78,40 +65,29 @@ def enter(self, flag=0, requires_grad=False):
         current_stream = torch.cuda.current_stream()
         current_stream.wait_stream(config["load_stream"])
 
-        # set wait stream for each storage
         for kw in self.block._storage_info.keys():
             if flag != 2:
-                self._param_tensor[kw].record_stream(current_stream)
-            if requires_grad and kw in self._grad_tensor:
-                self._grad_tensor[kw].record_stream(current_stream)
+                self._param_buffer[kw].record_stream(current_stream)
+            if requires_grad and kw in self._grad_buffer:
+                self._grad_buffer[kw].record_stream(current_stream)
 
-        # update parameters in block
         for param in self.block._param_info:
             kw_name = param["kw_name"]
             offset = param["offset"]
             shape = param["shape"]
+            numel = shape.numel()
 
             if flag != 2:
-                dtype = self._param_buffer[kw_name].dtype
-                device = self._param_buffer[kw_name].device
-                param["parameter"].data = torch.tensor(
-                    [], dtype=dtype, device=device
-                ).set_(self._param_buffer[kw_name], offset, shape)
diff --git a/bmtrain/zero_context.py b/bmtrain/zero_context.py
index 8a74b3f8..dfa98f1d 100644
--- a/bmtrain/zero_context.py
+++ b/bmtrain/zero_context.py
@@ -20,8 +20,6 @@ def __init__(self, block: "Block", ctx_dict: dict = None, pipe=False) -> None:
         self.ctx_dict = ctx_dict
         self._param_buffer = {}
         self._grad_buffer = {}
-        self._param_tensor = {}
-        self._grad_tensor = {}
         self._need_release = False

     def enter(self, flag=0, requires_grad=False):
@@ -41,35 +39,24 @@ def enter(self, flag=0, requires_grad=False):
                 assert kw not in self._param_buffer
                 local_param = self.block._storage_params[kw]

-                storage_type = local_param.storage_type()
                 if flag != 2:
-                    self._param_buffer[kw] = storage_type(
-                        val["partition_size"] * val["world_size"]
+                    self._param_buffer[kw] = torch.empty(
+                        val["partition_size"] * val["world_size"],
+                        dtype=local_param.dtype,
+                        device=local_param.device,
                     )
-                    self._param_tensor[kw] = torch.tensor(
-                        [],
-                        dtype=self._param_buffer[kw].dtype,
-                        device=self._param_buffer[kw].device,
-                    ).set_(self._param_buffer[kw])

                 if requires_grad and local_param.requires_grad:
-                    self._grad_buffer[kw] = storage_type(
-                        val["partition_size"] * val["world_size"]
-                    )
-                    self._grad_tensor[kw] = (
-                        torch.tensor(
-                            [],
-                            dtype=self._grad_buffer[kw].dtype,
-                            device=self._grad_buffer[kw].device,
-                        )
-                        .set_(self._grad_buffer[kw])
-                        .zero_()
+                    self._grad_buffer[kw] = torch.zeros(
+                        val["partition_size"] * val["world_size"],
+                        dtype=local_param.dtype,
+                        device=local_param.device,
                     )
         if flag != 2:
             nccl.groupStart()
             for kw, val in self.block._storage_info.items():
                 nccl.allGather(
-                    self.block._storage_params[kw].storage(),
+                    self.block._storage_params[kw],
                     self._param_buffer[kw],
                     val["zero_comm"],
                 )
@@ -78,40 +65,29 @@ def enter(self, flag=0, requires_grad=False):
         current_stream = torch.cuda.current_stream()
         current_stream.wait_stream(config["load_stream"])

-        # set wait stream for each storage
        for kw in self.block._storage_info.keys():
             if flag != 2:
-                self._param_tensor[kw].record_stream(current_stream)
-                if requires_grad and kw in self._grad_tensor:
-                    self._grad_tensor[kw].record_stream(current_stream)
+                self._param_buffer[kw].record_stream(current_stream)
+                if requires_grad and kw in self._grad_buffer:
+                    self._grad_buffer[kw].record_stream(current_stream)

-        # update parameters in block
         for param in self.block._param_info:
             kw_name = param["kw_name"]
             offset = param["offset"]
             shape = param["shape"]
+            numel = shape.numel()

             if flag != 2:
-                dtype = self._param_buffer[kw_name].dtype
-                device = self._param_buffer[kw_name].device
-                param["parameter"].data = torch.tensor(
-                    [], dtype=dtype, device=device
-                ).set_(self._param_buffer[kw_name], offset, shape)
+                param["parameter"].data = self._param_buffer[kw_name][offset:offset+numel].view(shape)
             else:
-                dtype = param["parameter"].data.dtype
-                device = param["parameter"].data.device
-                param["parameter"].data = torch.tensor(
-                    [], dtype=dtype, device=device
-                ).set_(self.ctx_dict[kw_name], offset, shape)
+                param["parameter"].data = self.ctx_dict[kw_name][offset:offset+numel].view(shape)

             if (
                 requires_grad
                 and kw_name in self._grad_buffer
                 and param["parameter"].requires_grad
             ):
-                param["parameter"].grad = torch.tensor(
-                    [], dtype=dtype, device=device
-                ).set_(self._grad_buffer[kw_name], offset, shape)
+                param["parameter"].grad = self._grad_buffer[kw_name][offset:offset+numel].view(shape)

     def __enter__(self):
         self.enter()
@@ -128,48 +104,38 @@ def exit(self, flag=0, backward=False):
             for kw, val in self.block._storage_info.items():
                 local_param = self.block._storage_params[kw]

-                # accumulate previous gradient
                 if local_param.requires_grad:
                     if local_param.grad is None:
-                        grad_storage = val["storage_type"](
-                            val["partition_size"]
-                        )  # initialize gradient if not exist
-                        local_param.grad = (
-                            torch.tensor(
-                                [], dtype=grad_storage.dtype, device=grad_storage.device
-                            )
-                            .set_(grad_storage)
-                            .zero_()
+                        local_param.grad = torch.zeros(
+                            val["partition_size"],
+                            dtype=local_param.dtype,
+                            device=local_param.device,
                         )
                     else:
-                        self._grad_tensor[kw][
+                        self._grad_buffer[kw][
                             val["begin"] : val["end"]
                         ] += local_param.grad

             current_stream = torch.cuda.current_stream()
-            config["load_stream"].wait_stream(current_stream)  # wait for backward
+            config["load_stream"].wait_stream(current_stream)

             with torch.cuda.stream(config["load_stream"]):
                 nccl.groupStart()
                 for kw, val in self.block._storage_info.items():
                     local_param = self.block._storage_params[kw]

-                    # scatter gradient
                     if local_param.requires_grad:
                         nccl.reduceScatter(
                             self._grad_buffer[kw],
-                            local_param.grad.storage(),
+                            local_param.grad,
                             "sum",
                             val["zero_comm"],
                         )
                 nccl.groupEnd()

-            # set wait stream for each storage
-            for kw in self._grad_tensor.keys():
-                # grads can not be freed until reduce ops finish
-                self._grad_tensor[kw].record_stream(config["load_stream"])
+            for kw in self._grad_buffer.keys():
+                self._grad_buffer[kw].record_stream(config["load_stream"])

-        # Release all parameters from buffer to block_storge
         for param in self.block._param_info:
             kw_name = param["kw_name"]
             dtype = self.block._storage_params[kw_name].dtype
@@ -180,24 +146,18 @@ def exit(self, flag=0, backward=False):
                 continue
             begin = param["begin"]
             end = param["end"]
-            param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(
-                self.block._storage_params[kw_name].storage(), begin, end
-            )
+            size = end[0] if isinstance(end, tuple) else end
+            param["parameter"].data = self.block._storage_params[kw_name].view(-1)[begin:begin+size]
             if (
                 param["parameter"].requires_grad
                 and self.block._storage_params[kw_name].grad is not None
             ):
-                param["parameter"].grad = torch.tensor(
-                    [], dtype=dtype, device=device
-                ).set_(self.block._storage_params[kw_name].grad.storage(), begin, end)
+                param["parameter"].grad = self.block._storage_params[kw_name].grad.view(-1)[begin:begin+size]
         if flag == 1:
             for i in self._param_buffer:
                 self.ctx_dict[i] = self._param_buffer[i]
-        self._grad_tensor = {}
-        self._param_tensor = {}
         self._grad_buffer = {}
         self._param_buffer = {}

     def __exit__(self, exc_type, exc_val, exc_tb):
-        # reduce scatter gradients
         self.exit()
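Throughout zero_context.py the old `torch.tensor([]).set_(storage, offset, shape)` aliasing trick becomes a slice-and-view on a flat tensor. Both produce a zero-copy alias into the ZeRO buffer; a sketch of the equivalence (values are illustrative):

```python
import torch

buf = torch.empty(16, dtype=torch.float32, device="cuda")  # flat ZeRO buffer
offset, shape = 4, torch.Size([3, 4])

# New idiom: slice the flat buffer, then view with the parameter's shape.
p = buf[offset : offset + shape.numel()].view(shape)

# Still an alias into `buf`, not a copy:
assert p.data_ptr() == buf.data_ptr() + offset * buf.element_size()
```

Because the buffers are now ordinary tensors, `record_stream` can be called on them directly, which is what lets the `_param_tensor`/`_grad_tensor` mirror dictionaries be dropped.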
diff --git a/docs/source-en/notes/installation.md b/docs/source-en/notes/installation.md
index 330bfb21..aef36f0d 100644
--- a/docs/source-en/notes/installation.md
+++ b/docs/source-en/notes/installation.md
@@ -13,7 +13,7 @@ $ pip install bmtrain
 ```shell
 $ git clone https://github.com/OpenBMB/BMTrain.git
 $ cd BMTrain
-$ python3 setup.py install
+$ pip install .
 ```

 ## Compilation Options
@@ -27,7 +27,7 @@ By setting environment variables, you can configure the compilation options of B

 ### CUDA Compute Capability

-`TORCH_CUDA_ARCH_LIST=6.0 6.1 7.0 7.5 8.0+PTX`
+`TORCH_CUDA_ARCH_LIST="6.1 7.0 7.5 8.0 8.6 8.9 9.0"`

 ## Recommended Configuration
diff --git a/docs/source/notes/installation.md b/docs/source/notes/installation.md
index 3fff4eb6..b8995fb4 100644
--- a/docs/source/notes/installation.md
+++ b/docs/source/notes/installation.md
@@ -13,7 +13,7 @@ $ pip install bmtrain
 ```shell
 $ git clone https://github.com/OpenBMB/BMTrain.git
 $ cd BMTrain
-$ python3 setup.py install
+$ pip install .
 ```

 ## 编译选项
@@ -27,7 +27,7 @@ $ python3 setup.py install

 ### CUDA计算兼容性

-`TORCH_CUDA_ARCH_LIST=6.0 6.1 7.0 7.5 8.0+PTX`
+`TORCH_CUDA_ARCH_LIST="6.1 7.0 7.5 8.0 8.6 8.9 9.0"`

 ## 推荐配置
diff --git a/example/layers/attention.py b/example/layers/attention.py
index 0f5155d4..a9d6bf96 100644
--- a/example/layers/attention.py
+++ b/example/layers/attention.py
@@ -88,13 +88,13 @@ def forward(self,
         score = torch.where(
             mask.view(batch_size, 1, seq_q, seq_kv),
             score,
-            torch.scalar_tensor(float('-inf'), device=score.device, dtype=score.dtype)
+            torch.tensor(float('-inf'), device=score.device, dtype=score.dtype)
         )

         score = torch.where(
             mask.view(batch_size, 1, seq_q, seq_kv),
             self.softmax(score),
-            torch.scalar_tensor(0, device=score.device, dtype=score.dtype)
+            torch.tensor(0, device=score.device, dtype=score.dtype)
         )

         score = score.view(-1, seq_q, seq_kv)
diff --git a/example/layers/embedding.py b/example/layers/embedding.py
index f62151c4..3a3be7be 100644
--- a/example/layers/embedding.py
+++ b/example/layers/embedding.py
@@ -52,11 +52,11 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,

         Examples::

-            >>> # FloatTensor containing pretrained weights
-            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
+            >>> # Tensor containing pretrained weights
+            >>> weight = torch.tensor([[1, 2.3, 3], [4, 5.1, 6.3]], dtype=torch.float32)
             >>> embedding = nn.Embedding.from_pretrained(weight)
             >>> # Get embeddings for index 1
-            >>> input = torch.LongTensor([1])
+            >>> input = torch.tensor([1], dtype=torch.long)
             >>> embedding(input)
             tensor([[ 4.0000,  5.1000,  6.3000]])
         """
diff --git a/pyproject.toml b/pyproject.toml
index b563eb32..a88cce68 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,7 +2,7 @@ requires = [
     "setuptools",
     "pybind11",
-    "nvidia-nccl-cu11 >= 2.14.3",
+    "nvidia-nccl-cu12 >= 2.14.3",
     "cmake > 3.27.0"
 ]
 build-backend = "setuptools.build_meta"
diff --git a/setup.py b/setup.py
index 70752ff6..978a6a63 100644
--- a/setup.py
+++ b/setup.py
@@ -100,11 +100,11 @@ def build_extension(self, ext):
     packages=find_packages(),
     install_requires=[
         "numpy",
-        "nvidia-nccl-cu11>=2.14.3"
+        "nvidia-nccl-cu12>=2.14.3"
     ],
     setup_requires=[
         "pybind11",
-        "nvidia-nccl-cu11>=2.14.3"
+        "nvidia-nccl-cu12>=2.14.3"
     ],
     ext_modules=ext_modules,
     cmdclass={
diff --git a/tests/test_all.py b/tests/test_all.py
index db5d2dd4..e7601cbe 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -35,7 +35,7 @@
 ])

 for t, num_gpu in tq:
-    PREFIX = f"python3 -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node={num_gpu} --master_addr=localhost --master_port=32123"
+    PREFIX = f"torchrun --nnodes=1 --nproc_per_node={num_gpu} --rdzv_backend=c10d --rdzv_endpoint=localhost:32123"
     SUFFIX = f"> test_log.txt 2>&1"
     command = f"{PREFIX} test_{t}.py {SUFFIX}"
     completedProc = subprocess.run(command, shell=True)
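`torch.distributed.launch` has been deprecated for some time, and the test_all.py hunk above moves to `torchrun` with c10d rendezvous. A hypothetical helper showing the shape of the new command (the function name is illustrative, not part of the test):

```python
def launch_cmd(num_gpu: int, script: str, port: int = 32123) -> str:
    # torchrun replaces `python3 -m torch.distributed.launch`; the c10d rendezvous
    # backend supersedes --master_addr/--master_port and needs no --node_rank.
    return (
        f"torchrun --nnodes=1 --nproc_per_node={num_gpu} "
        f"--rdzv_backend=c10d --rdzv_endpoint=localhost:{port} {script}"
    )

print(launch_cmd(4, "test_training.py"))
```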
+ PREFIX = f"torchrun --nnodes=1 --nproc_per_node={num_gpu} --rdzv_backend=c10d --rdzv_endpoint=localhost:32123" SUFFIX = f"> test_log.txt 2>&1" command = f"{PREFIX} test_{t}.py {SUFFIX}" completedProc = subprocess.run(command, shell=True) diff --git a/tests/test_column_parallel_linear.py b/tests/test_column_parallel_linear.py index 5f2fdad1..f31abba5 100644 --- a/tests/test_column_parallel_linear.py +++ b/tests/test_column_parallel_linear.py @@ -15,7 +15,7 @@ def run_bmt(x, gather_input, gather_output, ckp_path, tp_size=2): def run_torch(x, ckp_path): linear = torch.nn.Linear(8, 8) - linear_dict = torch.load(ckp_path) + linear_dict = torch.load(ckp_path, weights_only=False) linear.load_state_dict(linear_dict) linear = linear.cuda() linear.weight.requires_grad_() diff --git a/tests/test_init_parameters.py b/tests/test_init_parameters.py index b67431f2..70918361 100644 --- a/tests/test_init_parameters.py +++ b/tests/test_init_parameters.py @@ -164,6 +164,11 @@ def test_main(): m[1] = Linear_NormalInitAfter(*shape) ret[1] = (m[1].weight.data, m[1].bias.data) + # Wrap m[1] right after ret[1] so this case stays next to Linear_NormalInitAfter; + # DistributedParameter now materializes a flat clone when copying (see parameter.py). + m[5] = bmt.BMTrainModelWrapper(m[1]) + ret[5] = (m[5].weight.data, m[5].bias.data) + # bmtrain manual_seed(33) m[2] = Linear_BMTInitializer(*shape) @@ -184,10 +189,6 @@ def test_main(): m[4] = bmt.BMTrainModelWrapper(m[0]) ret[4] = (m[4].weight.data, m[4].bias.data) - manual_seed(33) - m[5] = bmt.BMTrainModelWrapper(m[1]) - ret[5] = (m[5].weight.data, m[5].bias.data) - manual_seed(33) m[6] = Linear_Pipeline(*shape) bmt.init_parameters(m[6]) diff --git a/tests/test_init_parameters_multi_gpu.py b/tests/test_init_parameters_multi_gpu.py index 1e61568c..1339b07e 100644 --- a/tests/test_init_parameters_multi_gpu.py +++ b/tests/test_init_parameters_multi_gpu.py @@ -97,8 +97,8 @@ def forward(self, input): def check(ckpt_path, ckpt_path_ref): if bmt.rank() == 0: - ckpt1 = torch.load(ckpt_path) - ckpt2 = torch.load(ckpt_path_ref) + ckpt1 = torch.load(ckpt_path, weights_only=False) + ckpt2 = torch.load(ckpt_path_ref, weights_only=False) for (k1, v1), (k2, v2) in zip(ckpt1.items(), ckpt2.items()): assert_eq(k1, k2) print(v1, v2) diff --git a/tests/test_model_wrapper.py b/tests/test_model_wrapper.py index 6f913d3c..fb1d2371 100644 --- a/tests/test_model_wrapper.py +++ b/tests/test_model_wrapper.py @@ -64,13 +64,13 @@ def forward(self, score = torch.where( mask.view(batch_size, 1, seq_q, seq_kv), score, - torch.scalar_tensor(float('-inf'), device=score.device, dtype=score.dtype) + torch.tensor(float('-inf'), device=score.device, dtype=score.dtype) ) score = torch.where( mask.view(batch_size, 1, seq_q, seq_kv), self.softmax(score), - torch.scalar_tensor(0, device=score.device, dtype=score.dtype) + torch.tensor(0, device=score.device, dtype=score.dtype) ) score = score.view(batch_size * self.num_heads, seq_q, seq_kv) diff --git a/tests/test_optim.py b/tests/test_optim.py index 0aca8c31..64d4e52d 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -46,7 +46,12 @@ def main(dtype): opt1 = bmt.optim.AdamOptimizer(model1.parameters(), lr=1) opt2 = bmt.optim.AdamOffloadOptimizer(model2.parameters(), lr=1) - opt3 = torch.optim.Adam(model3.parameters(), lr=1) + # Match bmtrain.optim.AdamOptimizer: it calls torch.optim._functional.adam with + # foreach=False, fused=False. 
diff --git a/tests/test_optim_state.py b/tests/test_optim_state.py
index 57d5d0e3..61c790de 100644
--- a/tests/test_optim_state.py
+++ b/tests/test_optim_state.py
@@ -108,7 +108,7 @@ def main():
     optim_manager.add_optimizer(opt1, lrs1)
     optim_manager.add_optimizer(opt2, lrs2)
     optim_manager.add_optimizer(opt3, lrs3)
-    optim_manager.load_state_dict(torch.load(f"test_optim_manager_{bmt.rank()}.opt"))
+    optim_manager.load_state_dict(torch.load(f"test_optim_manager_{bmt.rank()}.opt", weights_only=False))

     manual_seed()
     train(model1, model2, model3, optim_manager)
diff --git a/tests/test_parallel_projection.py b/tests/test_parallel_projection.py
index dc1e874d..93bc3740 100644
--- a/tests/test_parallel_projection.py
+++ b/tests/test_parallel_projection.py
@@ -1,11 +1,12 @@
 import torch
 import bmtrain as bmt
 from bmtrain.global_var import config
+from bmtrain.nn.parallel_projection import Projection, VPProjection
 import numpy as np
 import os

 def run_normal(x, t, ckp_path, dtype):
-    proj = bmt.nn.Projection(100, 64, dtype=dtype)
+    proj = Projection(100, 64, dtype=dtype)
     bmt.init_parameters(proj)
     bmt.save(proj, ckp_path)
     loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=False)
@@ -16,7 +17,7 @@ def run_normal(x, t, ckp_path, dtype):
     return y, loss, y.grad

 def run_vp(x, t, ckp_path, dtype):
-    proj = bmt.nn.VPProjection(100, 64, dtype=dtype)
+    proj = VPProjection(100, 64, dtype=dtype)
     bmt.load(proj, ckp_path)
     loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True)
     y = proj(x)
@@ -30,10 +31,11 @@ def run(dtype):
     torch.cuda.manual_seed(100)
     tp_size = config["tp_size"]
     tp_rank = config['tp_rank']
-    x = torch.randn(110, 64, device='cuda', dtype=dtype)
+    x_full = torch.randn(110, 64, device='cuda', dtype=dtype)
+    x_shard = x_full.chunk(tp_size, dim=-1)[tp_rank].contiguous()
     t = torch.cat([torch.arange(100).view(10, 10), torch.ones((10, 1))*-100], dim=-1).view(110).int().cuda()
-    y1, loss1, grad1 = run_normal(x, t, ckp_path, dtype)
-    y2, loss2, grad2 = run_vp(x, t, ckp_path, dtype)
+    y1, loss1, grad1 = run_normal(x_full, t, ckp_path, dtype)
+    y2, loss2, grad2 = run_vp(x_shard, t, ckp_path, dtype)
     y1 = y1.chunk(tp_size, dim=-1)[tp_rank]
     grad1 = grad1.chunk(tp_size, dim=-1)[tp_rank]
     for r in range(tp_size):
diff --git a/tests/test_row_parallel_linear.py b/tests/test_row_parallel_linear.py
index 23dce8b2..acec13de 100644
--- a/tests/test_row_parallel_linear.py
+++ b/tests/test_row_parallel_linear.py
@@ -16,7 +16,7 @@ def run_bmt(x, ckp_path, split_input=True, use_checkpoint_block=True):

 def run_torch(x, ckp_path):
     linear = torch.nn.Linear(8, 8)
-    linear_dict = torch.load(ckp_path)
+    linear_dict = torch.load(ckp_path, weights_only=False)
     linear.load_state_dict(linear_dict)
     linear = linear.cuda()
     linear.weight.requires_grad_()
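The recurring `weights_only=False` edits in these tests all answer the same PyTorch change: since 2.6, `torch.load` defaults to `weights_only=True` and refuses arbitrary pickled objects. Checkpoints that carry only tensors still load under the stricter default; a sketch (the `RunMeta` class is illustrative):

```python
import torch

class RunMeta:            # stands in for non-tensor checkpoint state
    def __init__(self, step):
        self.step = step

torch.save({"w": torch.randn(2, 2), "meta": RunMeta(3)}, "ckpt.pt")

# Under PyTorch >= 2.6 the default weights_only=True rejects RunMeta:
#   torch.load("ckpt.pt")  ->  UnpicklingError
state = torch.load("ckpt.pt", weights_only=False)  # explicit opt-in for trusted files
assert state["meta"].step == 3
```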
diff --git a/tests/test_training.py b/tests/test_training.py
index 46389802..b8a98ac4 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -33,18 +33,18 @@ def clip_grad_norm(loss_scale, param_groups, max_norm, norm_type=2, eps=1e-6, is
             grads.append(torch.zeros_like(p.data))

     if norm_type == 'inf':
-        total_norm_cuda = max(g.data.abs().max() for g in grads).detach()
+        total_norm_cuda = max(g.data.abs().max() for g in grads).detach().unsqueeze(0)
         if not is_torch:
-            bmt.nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "max", bmt.config["comm"])
-        total_norm = total_norm_cuda
+            bmt.nccl.allReduce(total_norm_cuda, total_norm_cuda, "max", bmt.config["comm"])
+        total_norm = total_norm_cuda[0]
     else:
         norm_type = float(norm_type)
-        total_norm_cuda = torch.cuda.FloatTensor([0])
+        total_norm_cuda = torch.tensor([0.0], dtype=torch.float32, device="cuda")
         for index, g in enumerate(grads):
             param_norm = g.data.float().norm(norm_type)
             total_norm_cuda += param_norm ** norm_type
         if not is_torch:
-            bmt.nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "sum", bmt.config["comm"])
+            bmt.nccl.allReduce(total_norm_cuda, total_norm_cuda, "sum", bmt.config["comm"])
         total_norm = total_norm_cuda[0] ** (1. / norm_type)
     clip_coef = float(max_norm * scale) / (total_norm + eps)
     if clip_coef < 1:
@@ -110,13 +110,13 @@ def forward(self,
         score = torch.where(
             mask.view(batch_size, 1, seq_q, seq_kv),
             score,
-            torch.scalar_tensor(float('-inf'), device=score.device, dtype=score.dtype)
+            torch.tensor(float('-inf'), device=score.device, dtype=score.dtype)
         )

         score = torch.where(
             mask.view(batch_size, 1, seq_q, seq_kv),
             self.softmax(score),
-            torch.scalar_tensor(0, device=score.device, dtype=score.dtype)
+            torch.tensor(0, device=score.device, dtype=score.dtype)
         )

         score = score.view(batch_size * self.num_heads, seq_q, seq_kv)
@@ -408,7 +408,7 @@ def make_ref_ckpt():
     ret = {}
     def torch_model():
         model = GPT(**kwargs)
-        model.load_state_dict(torch.load(ckpt_path))
+        model.load_state_dict(torch.load(ckpt_path, weights_only=False))
         model = model.cuda()
         return model
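One last detail from the clip_grad_norm hunk worth calling out: the inf-norm branch gains `.unsqueeze(0)`. `max()` over the per-grad maxima yields a 0-d tensor, which cannot be indexed with `[0]`; promoting it to shape `(1,)` gives the collective an explicit one-element buffer and lets both branches end with `total_norm_cuda[0]`. A sketch of my reading of the change (requires a CUDA device):

```python
import torch

grads = [torch.randn(8, device="cuda"), torch.randn(3, device="cuda")]

total = max(g.abs().max() for g in grads).detach()  # 0-d scalar tensor
buf = total.unsqueeze(0)                            # shape (1,) buffer for allReduce
assert buf.dim() == 1 and buf.numel() == 1

total_norm = buf[0]                                 # scalar again, as both branches expect
```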