Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build_whl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }}

- name: Pull Docker image
run: docker pull pytorch/manylinux-cuda113:latest
run: docker pull pytorch/manylinux2_28-builder:cuda12.4

- name: Run Docker image and execute script
run: |
version=${{ matrix.python-version }}
docker run -e BUILD_DOCKER_ENV=1 -e CUDACXX=/usr/local/cuda-11.3/bin/nvcc -e PATH="/opt/rh/devtoolset-9/root/usr/bin:$PATH" -e LD_LIBRARY_PATH="/opt/rh/devtoolset-9/root/usr/lib64:/opt/rh/devtoolset-9/root/usr/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH" -v ${{ github.workspace }}:/workspace/BMTrain -i pytorch/manylinux-cuda113:latest /bin/bash -c "cd /workspace/BMTrain;/opt/python/cp${version}*/bin/pip install build; /opt/python/cp${version}*/bin/python -m build .;for file in dist/*-linux_x86_64.whl; do mv \"\$file\" \"\${file//-linux_x86_64/-manylinux2014_x86_64}\"; done"
docker run -e BUILD_DOCKER_ENV=1 -e CUDACXX=/usr/local/cuda-12.4/bin/nvcc -e PATH="/opt/rh/devtoolset-9/root/usr/bin:$PATH" -e LD_LIBRARY_PATH="/opt/rh/devtoolset-9/root/usr/lib64:/opt/rh/devtoolset-9/root/usr/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH" -v ${{ github.workspace }}:/workspace/BMTrain -i pytorch/manylinux2_28-builder:cuda12.4 /bin/bash -c "cd /workspace/BMTrain;/opt/python/cp${version}*/bin/pip install build; /opt/python/cp${version}*/bin/python -m build .;for file in dist/*-linux_x86_64.whl; do mv \"\$file\" \"\${file//-linux_x86_64/-manylinux2014_x86_64}\"; done"

- name: Upload wheels as artifacts
uses: actions/upload-artifact@v4
Expand Down
29 changes: 20 additions & 9 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
FROM nvidia/cuda:10.2-devel
FROM nvidia/cuda:12.8.0-devel-ubuntu22.04
WORKDIR /build

RUN apt update && apt install -y --no-install-recommends \
build-essential \
python3-dev \
python3-pip \
python3-setuptools \
python3-wheel
RUN pip3 install torch==1.10.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip3 install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN apt install iputils-ping opensm libopensm-dev libibverbs1 libibverbs-dev -y --no-install-recommends
ENV TORCH_CUDA_ARCH_LIST=6.1;7.0;7.5
python3-wheel \
cmake \
ninja-build \
git \
iputils-ping opensm libopensm-dev libibverbs1 libibverbs-dev

RUN pip3 install --upgrade pip setuptools wheel -i https://pypi.tuna.tsinghua.edu.cn/simple

RUN pip3 install --break-system-packages torch==2.8.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip3 install --break-system-packages numpy -i https://pypi.tuna.tsinghua.edu.cn/simple

ENV TORCH_CUDA_ARCH_LIST="6.1;7.0;7.5;8.0;8.6;8.9;9.0"
ENV BMT_AVX512=1

ADD other_requirements.txt other_requirements.txt
RUN pip3 install --upgrade pip && pip3 install -r other_requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip3 install --break-system-packages -r other_requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

ADD . .
RUN python3 setup.py install
RUN pip3 install --break-system-packages .

WORKDIR /root
ADD example example
ADD example example
ADD tests tests
2 changes: 1 addition & 1 deletion bmtrain/benchmark/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def all_gather():
end_evt = torch.cuda.Event(enable_timing=True)

current_stream.record_event(start_evt)
nccl.allGather(partition_tensor.storage(), global_tensor.storage(), config['comm'])
nccl.allGather(partition_tensor.view(-1), global_tensor.view(-1), config['comm'])
current_stream.record_event(end_evt)
current_stream.synchronize()
time_usage = start_evt.elapsed_time(end_evt)
Expand Down
2 changes: 1 addition & 1 deletion bmtrain/benchmark/reduce_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def reduce_scatter():
end_evt = torch.cuda.Event(enable_timing=True)

current_stream.record_event(start_evt)
nccl.reduceScatter(global_tensor.storage(), partition_tensor.storage(), 'avg', config['comm'])
nccl.reduceScatter(global_tensor.view(-1), partition_tensor.view(-1), 'avg', config['comm'])
current_stream.record_event(end_evt)
current_stream.synchronize()
time_usage = start_evt.elapsed_time(end_evt)
Expand Down
4 changes: 2 additions & 2 deletions bmtrain/benchmark/send_recv.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def send_recv():
current_stream.record_event(start_evt)
nccl.groupStart()
if config['rank'] in [0,2,4,6]:
nccl.send(send_buffer.storage(), config['rank']+1, config['comm'])
nccl.send(send_buffer.view(-1), config['rank']+1, config['comm'])
else:
nccl.recv(recv_buffer.storage(), config['rank']-1, config['comm'])
nccl.recv(recv_buffer.view(-1), config['rank']-1, config['comm'])
nccl.groupEnd()
current_stream.record_event(end_evt)
current_stream.synchronize()
Expand Down
76 changes: 10 additions & 66 deletions bmtrain/block_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,6 @@
from torch.utils.checkpoint import checkpoint


# Names of the legacy typed-storage classes that exist under both ``torch``
# and ``torch.cuda``. Long and Bool are included so integer / boolean
# parameters are not rejected as "unknown".
_TYPED_STORAGE_NAMES = (
    "FloatStorage",
    "DoubleStorage",
    "HalfStorage",
    "BFloat16Storage",
    "CharStorage",
    "ByteStorage",
    "ShortStorage",
    "IntStorage",
    "LongStorage",
    "BoolStorage",
)


def _build_cuda_storage_map():
    """Map every CPU and CUDA typed-storage class to its CUDA class."""
    mapping = {}
    for name in _TYPED_STORAGE_NAMES:
        cuda_cls = getattr(torch.cuda, name)
        mapping[getattr(torch, name)] = cuda_cls
        # CUDA classes map to themselves so the conversion is idempotent.
        mapping[cuda_cls] = cuda_cls
    return mapping


# Built once at import time instead of being re-created on every call.
_CUDA_STORAGE_MAP = _build_cuda_storage_map()


def storage_type_cuda(storage_type):
    """Convert storage_type to cuda storage_type.

    Args:
        storage_type: A typed-storage class from ``torch`` (e.g.
            ``torch.FloatStorage``) or from ``torch.cuda`` (e.g.
            ``torch.cuda.FloatStorage``).

    Returns:
        The matching ``torch.cuda.*Storage`` class. A class that is already
        a CUDA storage class is returned unchanged.

    Raises:
        ValueError: If ``storage_type`` is not one of the known typed-storage
            classes.
    """
    try:
        return _CUDA_STORAGE_MAP[storage_type]
    except KeyError:
        raise ValueError("Unknown storage type: {}".format(storage_type)) from None


def _get_param_kw(param: DistributedParameter):
"""Get DistributedParameter kw name."""
type_name = str(param.dtype).split(".")[-1]
Expand Down Expand Up @@ -121,7 +96,6 @@ def init_param_storage(self):
"All parameters in checkpoint block must be DistributedParameter."
)

storage_type = storage_type_cuda(param.storage_type())
kw_name = _get_param_kw(param)

if kw_name not in self._storage_info:
Expand All @@ -136,7 +110,7 @@ def init_param_storage(self):

self._storage_info[kw_name] = {
"total": 0,
"storage_type": storage_type,
"dtype": param.dtype,
"requires_grad": param.requires_grad,
"group": param.group,
"zero_comm": zero_comm,
Expand Down Expand Up @@ -165,16 +139,10 @@ def init_param_storage(self):
val["end"] = (rank + 1) * partition_size
offsets[kw] = 0

storage_type = val["storage_type"]

storage_param_buffer = storage_type(partition_size)

dtype = storage_param_buffer.dtype
device = storage_param_buffer.device

dtype = val["dtype"]
# bind storage to buffer tensor
storage_param = torch.nn.Parameter(
torch.tensor([], dtype=dtype, device=device).set_(storage_param_buffer)
torch.empty(partition_size, dtype=dtype, device="cuda")
)
if val["requires_grad"]:
storage_param.requires_grad_(True)
Expand Down Expand Up @@ -223,23 +191,12 @@ def init_param_storage(self):
to_offset_end = offset_end + param_st - storage_st

# copy to buffer
# PyTorch 1.11 changed the API of storage.__getitem__
d_dtype = self._storage_params[kw_name].dtype
d_device = self._storage_params[kw_name].device
param.data = torch.tensor(
[], dtype=param.dtype, device=param.device
).set_(
self._storage_params[kw_name].storage(),
to_offset_st,
(to_offset_end - to_offset_st,),
)
param.data = self._storage_params[kw_name].view(-1)[to_offset_st:to_offset_end]
self._param_info[-1]["begin"] = to_offset_st
self._param_info[-1]["end"] = (to_offset_end - to_offset_st,)
setattr(param, "_start_partition", offset_st)
setattr(param, "_end_partition", offset_end)
param.data[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_(
contiguous_param.storage(), offset_st, (offset_end - offset_st,)
)[:]
param.data[:] = contiguous_param.view(-1)[offset_st:offset_end]
del contiguous_param
else:
param.data = torch.tensor([], dtype=param.dtype, device=param.device)
Expand Down Expand Up @@ -424,18 +381,10 @@ def _load_from_state_dict(
to_offset_end = offset_end + param_st - storage_st

# copy to buffer
# PyTorch 1.11 changed the API of storage.__getitem__
d_dtype = self._storage_params[kw_name].dtype
d_device = self._storage_params[kw_name].device
torch.tensor([], dtype=d_dtype, device=d_device).set_(
self._storage_params[kw_name].storage(),
to_offset_st,
(to_offset_end - to_offset_st,),
)[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_(
contiguous_param.storage(), offset_st, (offset_end - offset_st,)
)[
:
]
with torch.no_grad():
self._storage_params[kw_name].data.view(-1)[
to_offset_st:to_offset_end
].copy_(contiguous_param.view(-1)[offset_st:offset_end])
del contiguous_param
elif strict:
missing_keys.append(key)
Expand Down Expand Up @@ -549,12 +498,7 @@ def init_parameters(self):
assert offset_st < offset_end

# copy to buffer
# PyTorch 1.11 changed the API of storage.__getitem__
d_dtype = self._storage_params[kw_name].dtype
d_device = self._storage_params[kw_name].device
param.data[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_(
tmp_tensor.storage(), offset_st, (offset_end - offset_st,)
)[:]
param.data[:] = tmp_tensor.view(-1)[offset_st:offset_end]
del tmp_tensor

def _named_members(self, get_members_fn, prefix="", recurse=True, **kwargs):
Expand Down
31 changes: 17 additions & 14 deletions bmtrain/distributed/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
]
def send_activations(hidden_state, next_rank, comm):
send_meta(hidden_state, next_rank, comm)
ncclSend(hidden_state.storage(), next_rank, comm)
ncclSend(hidden_state.contiguous().view(-1), next_rank, comm)

def recv_activations(prev_rank, comm):
dtype, shape = recv_meta(prev_rank, comm)
hidden_state = torch.empty(shape, dtype=dtype, device="cuda")
ncclRecv(hidden_state.storage(), prev_rank, comm)
ncclRecv(hidden_state.view(-1), prev_rank, comm)
return hidden_state

def send_meta(x, next_rank, comm):
Expand All @@ -34,11 +34,11 @@ def send_meta(x, next_rank, comm):
meta_data[1] = DTYPE_LIST.index(x.dtype)
meta_data[2:len(x.size())+2] = torch.tensor(x.size(), device="cuda", dtype=torch.int)
meta_data = meta_data.contiguous()
ncclSend(meta_data.storage(), next_rank, comm)
ncclSend(meta_data, next_rank, comm)

def recv_meta(prev_rank, comm):
meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int)
ncclRecv(meta_data.storage(), prev_rank, comm)
ncclRecv(meta_data, prev_rank, comm)
n_dims = meta_data[0].item()
dtype = DTYPE_LIST[meta_data[1].item()]
shape = meta_data[2:n_dims+2].tolist()
Expand All @@ -52,7 +52,7 @@ def forward(ctx, src, root, comm = None):
comm = config["comm"]
ctx.comm = comm
outputs = torch.empty_like(src, dtype = src.dtype, device = src.device)
ncclBroadcast(src.storage(), outputs.storage(), root, comm)
ncclBroadcast(src.contiguous().view(-1), outputs.view(-1), root, comm)
return outputs

@staticmethod
Expand All @@ -74,13 +74,14 @@ def forward(ctx, input : torch.Tensor, comm = None):
world_size = commCount(comm)
if not input.is_contiguous():
input = input.contiguous()
if input.storage_offset() != 0 or input.storage().size() != input.numel():
# Clone when storage_offset != 0 so the tensor owns fresh storage that begins at its first element.
if input.storage_offset() != 0:
input = input.clone()
output = torch.empty( (world_size,) + input.size(), dtype=input.dtype, device=input.device)
ctx.comm = comm
ncclAllGather(
input.storage(),
output.storage(),
input.view(-1),
output.view(-1),
comm
)
return output
Expand Down Expand Up @@ -115,13 +116,14 @@ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None)
assert input.shape[0] % commCount(comm) == 0, "The dimension 0 must be divisible by the number of communication processes"
if not input.is_contiguous():
input = input.contiguous()
if input.storage_offset() != 0 or input.storage().size() != input.numel():
# Clone views with a non-zero storage offset so the buffer handed to NCCL starts at the beginning of its own storage.
if input.storage_offset() != 0:
input = input.clone()
output_shape = (input.shape[0] // commCount(comm), *input.shape[1:])
output = torch.empty( output_shape, dtype=input.dtype, device=input.device )
ncclReduceScatter(
input.storage(),
output.storage(),
input.view(-1),
output.view(-1),
op,
comm
)
Expand Down Expand Up @@ -171,13 +173,14 @@ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None)
ctx.comm = comm
if not input.is_contiguous():
input = input.contiguous()
if input.storage_offset() != 0 or input.storage().size() != input.numel():
# Clone views with a non-zero storage offset so the buffer handed to NCCL starts at the beginning of its own storage.
if input.storage_offset() != 0:
input = input.clone()
output = torch.empty( input.size(), dtype=input.dtype, device=input.device)

ncclAllReduce(
input.storage(),
output.storage(),
input.view(-1),
output.view(-1),
op,
comm
)
Expand Down
Loading
Loading