diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py new file mode 100644 index 000000000000..93bf297d0ba0 --- /dev/null +++ b/deepspeed/utils/parallel_state.py @@ -0,0 +1,1161 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) DeepSpeed Team + +# DeepSpeed Team + +# The file has been adapted from https://github.com/NVIDIA/Megatron-LM and retains the following license from the original file + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Refactored Model and data parallel groups with class-based design.""" + +import logging +from datetime import timedelta +from typing import Callable, List, Optional + +import numpy as np +import torch + +from deepspeed.accelerator import get_accelerator +import deepspeed.comm as dist + +logger = logging.getLogger(__name__) + +try: + import einops + HAVE_EINOPS = True +except ImportError: + HAVE_EINOPS = False + + +class GlobalMemoryBuffer: + """Global buffer to avoid dynamic memory allocations.""" + + def __init__(self): + self.buffer = {} + + def get_tensor(self, tensor_shape, dtype, name, mem_alloc_context=None): + """Returns a sub-tensor from the buffer for the given shape.""" + from functools import reduce + import operator + + required_len = reduce(operator.mul, tensor_shape, 1) + if (self.buffer.get((name, dtype), None) is None or self.buffer[(name, dtype)].numel() < required_len): + from contextlib import nullcontext + mem_alloc_context = mem_alloc_context if mem_alloc_context else nullcontext + with mem_alloc_context(): + self.buffer[(name, dtype)] = torch.empty( + required_len, + dtype=dtype, + device=get_accelerator().current_device(), + requires_grad=False, + ) + + return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) + + +def generate_masked_orthogonal_rank_groups(world_size: int, parallel_size: List[int], + mask: List[bool]) -> List[List[int]]: + r"""Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + parallel_size (List[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + mask (List[bool]): + The mask controls which parallel methods the generated groups represent. If mask[i] is + True, it means the generated group contains the i-th parallelism method. 
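+
+    Example (illustrative):
+        With world_size = 24, parallel_size = [2, 3, 4] (tp-pp-dp) and
+        mask = [True, False, False], the generated groups are the tensor-parallel
+        pairs [[0, 1], [2, 3], ..., [22, 23]], one group per (pp, dp) combination.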
+ + Algorithm: + For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and + local_rank satisfy the following equation: + global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size + """ + + def prefix_product(a: List[int], init=1) -> List[int]: + r = [init] + for v in a: + init = init * v + r.append(init) + return r + + def inner_product(a: List[int], b: List[int]) -> int: + return sum([x * y for x, y in zip(a, b)]) + + def decompose(index, shape, stride=None): + """Solve: index = sum(idx[i] * stride[i])""" + if stride is None: + stride = prefix_product(shape) + idx = [(index // d) % s for s, d in zip(shape, stride)] + assert (sum([x * y for x, y in zip(idx, stride[:-1])]) == index), f"idx {index} with shape {shape} mismatch" + return idx + + masked_shape = [s for s, m in zip(parallel_size, mask) if m] + unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] + + global_stride = prefix_product(parallel_size) + masked_stride = [d for d, m in zip(global_stride, mask) if m] + unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] + + group_size = prefix_product(masked_shape)[-1] + num_of_group = world_size // group_size + + ranks = [] + for group_index in range(num_of_group): + decomposed_group_idx = decompose(group_index, unmasked_shape) + rank = [] + for rank_in_group in range(group_size): + decomposed_rank_idx = decompose(rank_in_group, masked_shape) + rank.append( + inner_product(decomposed_rank_idx, masked_stride) + + inner_product(decomposed_group_idx, unmasked_stride)) + ranks.append(rank) + return ranks + + +class RankGenerator: + """A class for generating rank groups for different modes of parallelism.""" + + def __init__(self, tp: int, ep: int, dp: int, pp: int, cp: int, sp: int, order: str, rank_offset: int = 0) -> None: + assert (ep == 1 or cp == 1), "Both EP and CP > 1 is not allowed in one rank generator." + + # Check SP compatibility: SP cannot be used with TP, PP, or EP + if sp > 1: + if tp > 1: + raise RuntimeError(f"Sequence Parallel (SP) cannot be used together with Tensor Parallel (TP). " + f"SP size: {sp}, TP size: {tp}. " + "Please set tp=1 when using SP.") + if pp > 1: + raise RuntimeError(f"Sequence Parallel (SP) cannot be used together with Pipeline Parallel (PP). " + f"SP size: {sp}, PP size: {pp}. " + "Please set pp=1 when using SP.") + if ep > 1: + raise RuntimeError(f"Sequence Parallel (SP) cannot be used together with Expert Parallel (EP). " + f"SP size: {sp}, EP size: {ep}. 
" + "Please set ep=1 when using SP.") + + self.tp = tp + self.ep = ep + self.dp = dp + self.pp = pp + self.cp = cp + self.sp = sp + self.rank_offset = rank_offset + self.world_size = tp * dp * pp * cp * ep * sp + + self.name_to_size = { + "tp": self.tp, + "pp": self.pp, + "dp": self.dp, + "ep": self.ep, + "cp": self.cp, + "sp": self.sp, + } + self.order = order + order = order.lower() + + for name in self.name_to_size.keys(): + if name not in order and self.name_to_size[name] != 1: + raise RuntimeError(f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't" + f"specified the order ({self.order}).") + elif name not in order: + order = order + "-" + name + + self.order = order + self.ordered_size = [] + + for token in order.split("-"): + self.ordered_size.append(self.name_to_size[token]) + + def get_mask(self, order: str, token: str): + """Create a mask for the specified tokens based on the given order.""" + ordered_token = order.split("-") + token_list = token.split("-") + mask = [False] * len(ordered_token) + for t in token_list: + mask[ordered_token.index(t)] = True + return mask + + def get_ranks(self, token): + """Get rank group by input token. + + Args: + token (str): Specify the ranks type (e.g., 'tp-dp') + """ + mask = self.get_mask(self.order, token) + ranks = generate_masked_orthogonal_rank_groups(self.world_size, self.ordered_size, mask) + if self.rank_offset > 0: + for rank_group in ranks: + for i in range(len(rank_group)): + rank_group[i] += self.rank_offset + return ranks + + +class ParallelState: + """Encapsulates all parallel state and operations. + + This class replaces the global variables and functions from the original + parallel_state.py, providing a cleaner, more maintainable interface. + """ + + def __init__(self): + # Process groups + self.tensor_model_parallel_group = None + self.pipeline_model_parallel_group = None + self.model_parallel_group = None + self.embedding_group = None + self.position_embedding_group = None + self.data_parallel_group = None + self.data_parallel_group_gloo = None + self.tensor_and_data_parallel_group = None + self.context_parallel_group = None + self.tensor_and_context_parallel_group = None + self.tensor_and_data_parallel_group_with_cp = None + self.data_parallel_group_with_cp = None + self.data_parallel_group_with_cp_gloo = None + + # Sequence parallel groups + self.sequence_parallel_group = None + self.sequence_and_data_parallel_group = None + + # Expert-related groups + self.expert_model_parallel_group = None + self.expert_tensor_parallel_group = None + self.expert_tensor_and_model_parallel_group = None + self.expert_tensor_model_pipeline_parallel_group = None + self.expert_data_parallel_group = None + self.expert_data_parallel_group_gloo = None + self.intra_partial_expert_data_parallel_group = None + self.intra_partial_expert_data_parallel_group_gloo = None + self.inter_partial_expert_data_parallel_group = None + + # All-to-All groups for ZeRO++ quantized gradients + self.all_to_all_groups = {} + self.all_to_all_initialized = False + + # Global ranks lists + self.embedding_global_ranks = None + self.position_embedding_global_ranks = None + self.pipeline_global_ranks = None + self.data_parallel_global_ranks = None + self.tensor_model_parallel_global_ranks = None + self.model_parallel_global_ranks = None + self.context_parallel_global_ranks = None + self.data_parallel_global_ranks_with_cp = None + self.hierarchical_context_parallel_groups = None + + # Parallel state values + self.virtual_pipeline_model_parallel_rank = None 
+ self.virtual_pipeline_model_parallel_world_size = None + self.mpu_tensor_model_parallel_world_size = None + self.mpu_pipeline_model_parallel_world_size = None + self.mpu_data_parallel_world_size = None + self.mpu_data_parallel_rank = None + self.mpu_tensor_model_parallel_rank = None + self.mpu_pipeline_model_parallel_rank = None + + # Expert parallel state values + self.mpu_expert_model_parallel_world_size = None + self.mpu_expert_model_parallel_rank = None + self.mpu_expert_tensor_parallel_world_size = None + self.mpu_expert_tensor_parallel_rank = None + + # Other + self.global_memory_buffer = None + self.global_process_group_list = None + self.intra_partial_data_parallel_group_with_cp = None + self.intra_partial_data_parallel_group_with_cp_gloo = None + self.intra_distributed_optimizer_instance_group = None + + # Rank generators + self.decoder_rank_generator = None + self.expert_decoder_rank_generator = None + + def _get_pg_options(self, pg_name: str, pg_comm_cfgs: dict): + """Get the options for a specific process group.""" + # TODO: construct process group options from json config + # + # As of PyTorch 2.9, the only backend that supports pg options is nccl, + # and a nccl-specific class, namely ProcessGroupNCCL.Options, is + # required to construct the options. + # + # To enable configuring such options in DeepSpeed, we need to define the + # interface for users to specify them and also figure out whether we + # want to export ProcessGroupNCCL.Options in deepspeed.comm or allow + # using torch distributed for this specific case in check-torchdist.py. + # Those are left as future work. + return None + + def _create_group( + self, + ranks, + timeout=None, + backend=None, + pg_options=None, + use_local_synchronization=False, + group_desc=None, + ): + """Creates a ProcessGroup.""" + if backend is not None and backend != "nccl": + logger.warning(f"{backend} backend is not supported for new_group. Using deepspeed.comm directly.") + return None + + # TODO: Currently using deepspeed.comm.new_group() which only supports 'ranks' parameter. + # The following parameters are commented out and will be enabled once DeepSpeed's + # comm interface supports them: + # - timeout: Timeout for process group operations + # - backend: Communication backend (e.g., 'nccl', 'gloo') + # - pg_options: Process group options + # - use_local_synchronization: Enable local synchronization + # - group_desc: Group description for debugging (requires PyTorch >= 2.4) + kwargs = { + "ranks": ranks, + # "timeout": timeout, + # "backend": backend, + # "pg_options": pg_options, + # "use_local_synchronization": use_local_synchronization, + # "group_desc": group_desc, + } + + group = dist.new_group(**kwargs) + if self.global_process_group_list is None: + self.global_process_group_list = [None] + if dist.get_rank() in ranks: + self.global_process_group_list.append(group) + return group + + def _create_hierarchical_groups( + self, + rank, + ranks, + hierarchical_group_sizes, + create_gloo_process_groups=False, + pg_options=None, + timeout=None, + group_desc=None, + ): + """Create hierarchical groups for a set of ranks.""" + if not HAVE_EINOPS: + raise ImportError("einops is not installed. 
Please install it with `pip install einops`.") + + hierarchical_groups = [] + hierarchical_groups_gloo = [] + if not isinstance(pg_options, list): + pg_options = [pg_options] * len(hierarchical_group_sizes) + + for level in range(len(hierarchical_group_sizes)): + rearranged_ranks = einops.rearrange( + np.array(ranks), + "(l s u) -> (l u) s", + u=int(np.prod(hierarchical_group_sizes[:level])), + s=hierarchical_group_sizes[level], + l=int(np.prod(hierarchical_group_sizes[level + 1:])), + ).tolist() + for sub_ranks in rearranged_ranks: + sub_group = self._create_group( + sub_ranks, + timeout=timeout, + pg_options=pg_options[level], + group_desc=f"HIERARCHICAL_{group_desc}_L{level}", + ) + if create_gloo_process_groups: + sub_group_gloo = self._create_group( + sub_ranks, + timeout=timeout, + backend="gloo", + pg_options=pg_options[level], + group_desc=f"HIERARCHICAL_{group_desc}_GLOO_L{level}", + ) + else: + sub_group_gloo = None + if rank in sub_ranks: + hierarchical_groups.append(sub_group) + hierarchical_groups_gloo.append(sub_group_gloo) + + assert rank not in ranks or len(hierarchical_groups) == len(hierarchical_group_sizes) + assert rank not in ranks or len(hierarchical_groups_gloo) == len(hierarchical_group_sizes) + return hierarchical_groups, hierarchical_groups_gloo + + def initialize_model_parallel( + self, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + virtual_pipeline_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_comm_backend: Optional[str] = None, + context_parallel_size: int = 1, + sequence_parallel_size: int = 1, + hierarchical_context_parallel_sizes: Optional[List[int]] = None, + expert_model_parallel_size: int = 1, + num_distributed_optimizer_instances: int = 1, + expert_tensor_parallel_size: Optional[int] = None, + distributed_timeout_minutes: int = 30, + order: str = "tp-ep-dp-pp", + get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, + get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, + create_gloo_process_groups: bool = False, + ) -> None: + """Initialize model data parallel groups. + + This is the main initialization method that sets up all parallel groups. 
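+
+        Example (a minimal sketch; the sizes are illustrative, not defaults, and
+        deepspeed.comm is assumed to be initialized across 8 ranks):
+
+            parallel_state = ParallelState()
+            parallel_state.initialize_model_parallel(
+                tensor_model_parallel_size=2,
+                pipeline_model_parallel_size=2,
+            )
+            # data_parallel_size is derived as world_size // (tp * pp * cp * sp) == 2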
+ """ + + def default_embedding_ranks(pp_ranks): + """Return the default ranks that constitute the stages on which the word embeddings live.""" + if len(pp_ranks) == 1: + return [pp_ranks[0]] + else: + return [pp_ranks[0], pp_ranks[-1]] + + def default_position_embedding_ranks(pp_ranks): + """Return the default ranks that constitute the stages on which the position embeddings live.""" + return [pp_ranks[0]] + + if get_embedding_ranks is None: + get_embedding_ranks = default_embedding_ranks + if get_position_embedding_ranks is None: + get_position_embedding_ranks = default_position_embedding_ranks + + # Get world size and rank + assert dist.is_initialized() + world_size: int = dist.get_world_size() + rank = dist.get_rank() + + model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size * sequence_parallel_size + if world_size % model_size != 0: + raise RuntimeError(f"world_size ({world_size}) is not divisible by {model_size}") + + data_parallel_size: int = world_size // model_size + + if virtual_pipeline_model_parallel_size is not None: + if not pipeline_model_parallel_size > 1: + raise RuntimeError("pipeline-model-parallel size should be greater than 1 with interleaved schedule") + self.virtual_pipeline_model_parallel_rank = 0 + self.virtual_pipeline_model_parallel_world_size = virtual_pipeline_model_parallel_size + + # TODO: Collect process group options from configs + # + # Check _get_pg_options for details. + pg_comm_cfgs = {} + + # Create rank generators + self.decoder_rank_generator = RankGenerator( + tp=tensor_model_parallel_size, + ep=1, + dp=data_parallel_size, + pp=pipeline_model_parallel_size, + cp=context_parallel_size, + order=order, + rank_offset=0, + sp=sequence_parallel_size, + ) + + # Build expert rank generator + if expert_tensor_parallel_size is None: + expert_tensor_parallel_size = tensor_model_parallel_size + expert_tensor_model_pipeline_parallel_size = (expert_tensor_parallel_size * expert_model_parallel_size * + pipeline_model_parallel_size) + expert_data_parallel_size = world_size // expert_tensor_model_pipeline_parallel_size + if world_size % expert_tensor_model_pipeline_parallel_size != 0: + raise RuntimeError( + f"world_size ({world_size}) is not divisible by expert_tensor_model_pipeline_parallel size ({expert_tensor_model_pipeline_parallel_size})" + ) + + self.expert_decoder_rank_generator = RankGenerator( + tp=expert_tensor_parallel_size, + ep=expert_model_parallel_size, + dp=expert_data_parallel_size, + pp=pipeline_model_parallel_size, + cp=1, + order=order, + rank_offset=0, + sp=1, + ) + + timeout = timedelta(minutes=distributed_timeout_minutes) + + # Build data-parallel groups with context parallel + assert self.data_parallel_group is None, "data parallel group is already initialized" + assert (data_parallel_size * context_parallel_size) % num_distributed_optimizer_instances == 0, ( + "Data parallel size should be divisible by partial DistOpt shard factor") + intra_partial_data_parallel_size = (data_parallel_size * + context_parallel_size) // num_distributed_optimizer_instances + + for ranks_with_cp in self.decoder_rank_generator.get_ranks('dp-cp'): + group_with_cp = self._create_group( + ranks_with_cp, + timeout=timeout, + pg_options=self._get_pg_options("dp_cp", pg_comm_cfgs), + group_desc="DATA_PARALLEL_GROUP_WITH_CP", + ) + if create_gloo_process_groups: + group_with_cp_gloo = self._create_group( + ranks_with_cp, + timeout=timeout, + backend="gloo", + group_desc="DATA_PARALLEL_GROUP_WITH_CP_GLOO", + ) + else: + 
group_with_cp_gloo = None + if rank in ranks_with_cp: + self.data_parallel_group_with_cp = group_with_cp + self.data_parallel_group_with_cp_gloo = group_with_cp_gloo + self.data_parallel_global_ranks_with_cp = ranks_with_cp + + if num_distributed_optimizer_instances > 1: + for i in range(num_distributed_optimizer_instances): + intra_partial_dp_ranks_with_cp = ranks_with_cp[( + i * intra_partial_data_parallel_size):((i + 1) * intra_partial_data_parallel_size)] + intra_partial_dp_group_with_cp = self._create_group( + intra_partial_dp_ranks_with_cp, + timeout=timeout, + pg_options=self._get_pg_options("intra_dp_cp", pg_comm_cfgs), + group_desc="INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP", + ) + if create_gloo_process_groups: + intra_partial_dp_group_with_cp_gloo = self._create_group( + intra_partial_dp_ranks_with_cp, + timeout=timeout, + backend="gloo", + group_desc="INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO", + ) + else: + intra_partial_dp_group_with_cp_gloo = None + if rank in intra_partial_dp_ranks_with_cp: + self.intra_partial_data_parallel_group_with_cp = intra_partial_dp_group_with_cp + self.intra_partial_data_parallel_group_with_cp_gloo = (intra_partial_dp_group_with_cp_gloo) + else: + self.intra_partial_data_parallel_group_with_cp = self.data_parallel_group_with_cp + self.intra_partial_data_parallel_group_with_cp_gloo = self.data_parallel_group_with_cp_gloo + + # Build data-parallel groups + for ranks in self.decoder_rank_generator.get_ranks('dp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("dp", pg_comm_cfgs), + group_desc="DATA_PARALLEL_GROUP", + ) + if create_gloo_process_groups: + group_gloo = self._create_group(ranks, + timeout=timeout, + backend="gloo", + group_desc="DATA_PARALLEL_GROUP_GLOO") + else: + group_gloo = None + if rank in ranks: + self.data_parallel_group = group + self.data_parallel_group_gloo = group_gloo + self.data_parallel_global_ranks = ranks + + # Build context-parallel groups + assert self.context_parallel_group is None, 'context parallel group is already initialized' + for ranks in self.decoder_rank_generator.get_ranks('cp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("cp", pg_comm_cfgs), + group_desc="CONTEXT_PARALLEL_GROUP", + ) + if rank in ranks: + self.context_parallel_group = group + self.context_parallel_global_ranks = ranks + if hierarchical_context_parallel_sizes: + assert np.prod(hierarchical_context_parallel_sizes) == context_parallel_size + hierarchical_groups, _ = self._create_hierarchical_groups( + rank, + ranks, + hierarchical_context_parallel_sizes, + create_gloo_process_groups=False, + pg_options=self._get_pg_options("hcp", pg_comm_cfgs), + timeout=timeout, + group_desc="CONTEXT_PARALLEL_GROUP", + ) + if rank in ranks: + self.hierarchical_context_parallel_groups = hierarchical_groups + + # Build model-parallel groups + assert self.model_parallel_group is None, 'model parallel group is already initialized' + for ranks in self.decoder_rank_generator.get_ranks('tp-pp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("mp", pg_comm_cfgs), + group_desc="MODEL_PARALLEL_GROUP", + ) + if rank in ranks: + self.model_parallel_group = group + self.model_parallel_global_ranks = ranks + + # Build tensor model-parallel groups + assert self.tensor_model_parallel_group is None, 'tensor model parallel group is already initialized' + for ranks in self.decoder_rank_generator.get_ranks('tp'): + group = 
self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("tp", pg_comm_cfgs), + group_desc="TENSOR_MODEL_PARALLEL_GROUP", + ) + if rank in ranks: + self.tensor_model_parallel_group = group + self.tensor_model_parallel_global_ranks = ranks + + # Build pipeline model-parallel groups and embedding groups + assert self.pipeline_model_parallel_group is None, "pipeline model parallel group is already initialized" + assert self.embedding_group is None, "embedding group is already initialized" + assert self.position_embedding_group is None, "position embedding group is already initialized" + + for ranks in self.decoder_rank_generator.get_ranks('pp'): + group = self._create_group( + ranks, + timeout=timeout, + backend=pipeline_model_parallel_comm_backend, + pg_options=(None if pipeline_model_parallel_comm_backend == "ucc" else self._get_pg_options( + "pp", pg_comm_cfgs)), + group_desc="PIPELINE_MODEL_PARALLEL_GROUP", + ) + assert ( + pipeline_model_parallel_comm_backend == None or pipeline_model_parallel_comm_backend == "nccl" + or pipeline_model_parallel_comm_backend == "ucc" + ), f'"{pipeline_model_parallel_comm_backend}" backend for PP communication is currently not supported' + + if rank in ranks: + if self.pipeline_model_parallel_group is None: + self.pipeline_model_parallel_group = group + self.pipeline_global_ranks = ranks + elif isinstance(self.pipeline_global_ranks[0], list): + if not isinstance(self.pipeline_model_parallel_group, list): + self.pipeline_model_parallel_group = [self.pipeline_model_parallel_group] + self.pipeline_model_parallel_group.append(group) + self.pipeline_global_ranks.append(ranks) + else: + self.pipeline_model_parallel_group = [self.pipeline_model_parallel_group, group] + self.pipeline_global_ranks = [self.pipeline_global_ranks, ranks] + + embedding_ranks = get_embedding_ranks(ranks) + group = self._create_group( + embedding_ranks, + timeout=timeout, + pg_options=self._get_pg_options("embd", pg_comm_cfgs), + group_desc="EMBEDDING_GROUP", + ) + if rank in embedding_ranks: + self.embedding_group = group + self.embedding_global_ranks = embedding_ranks + + position_embedding_ranks = get_position_embedding_ranks(ranks) + group = self._create_group( + position_embedding_ranks, + timeout=timeout, + pg_options=self._get_pg_options("pos_embd", pg_comm_cfgs), + group_desc="POSITION_EMBEDDING_GROUP", + ) + if rank in position_embedding_ranks: + self.position_embedding_group = group + self.position_embedding_global_ranks = position_embedding_ranks + + # Build tensor + data parallel groups + assert self.tensor_and_data_parallel_group is None, 'Tensor + data parallel group is already initialized' + for ranks in self.decoder_rank_generator.get_ranks('tp-dp-cp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("tp_dp_cp", pg_comm_cfgs), + group_desc="TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP", + ) + if rank in ranks: + self.tensor_and_data_parallel_group_with_cp = group + for ranks in self.decoder_rank_generator.get_ranks('tp-dp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("tp_dp", pg_comm_cfgs), + group_desc="TENSOR_AND_DATA_PARALLEL_GROUP", + ) + if rank in ranks: + self.tensor_and_data_parallel_group = group + + assert self.tensor_and_context_parallel_group is None, 'Tensor + context parallel group is already initialized' + for ranks in self.decoder_rank_generator.get_ranks('tp-cp'): + group = self._create_group( + ranks, + timeout=timeout, + 
pg_options=self._get_pg_options("tp_cp", pg_comm_cfgs), + group_desc="TENSOR_AND_CONTEXT_PARALLEL_GROUP", + ) + if rank in ranks: + self.tensor_and_context_parallel_group = group + + # Build expert-related parallel groups + assert self.expert_model_parallel_group is None, 'Expert parallel group is already initialized' + for ranks in self.expert_decoder_rank_generator.get_ranks('ep'): + group = self._create_group( + ranks, + pg_options=self._get_pg_options("ep", pg_comm_cfgs), + group_desc="EXPERT_MODEL_PARALLEL_GROUP", + ) + if rank in ranks: + self.expert_model_parallel_group = group + + assert self.expert_tensor_parallel_group is None, 'Expert tensor model parallel group is already initialized' + for ranks in self.expert_decoder_rank_generator.get_ranks('tp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("ep_tp", pg_comm_cfgs), + group_desc="EXPERT_TENSOR_PARALLEL_GROUP", + ) + if rank in ranks: + self.expert_tensor_parallel_group = group + + assert self.expert_tensor_and_model_parallel_group is None, 'Expert tensor + model parallel group is already initialized' + for ranks in self.expert_decoder_rank_generator.get_ranks('tp-ep'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("tp_ep_mp", pg_comm_cfgs), + group_desc="EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP", + ) + if rank in ranks: + self.expert_tensor_and_model_parallel_group = group + + assert self.expert_tensor_model_pipeline_parallel_group is None, 'The expert_tensor_model_pipeline parallel group is already initialized' + for ranks in self.expert_decoder_rank_generator.get_ranks('tp-ep-pp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("tp_ep_pp", pg_comm_cfgs), + group_desc="EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP", + ) + if rank in ranks: + self.expert_tensor_model_pipeline_parallel_group = group + + assert self.expert_data_parallel_group is None, "Expert data group is already initialized" + assert self.expert_data_parallel_group_gloo is None, "Expert data group-gloo is already initialized" + assert self.intra_partial_expert_data_parallel_group is None, "Intra partial expert data group is already initialized" + assert self.intra_partial_expert_data_parallel_group_gloo is None, "Intra partial expert data group-gloo is already initialized" + assert self.inter_partial_expert_data_parallel_group is None, "Inter partial expert data group is already initialized" + + assert (expert_data_parallel_size % num_distributed_optimizer_instances == 0 + ), "Expert data parallel size should be divisible by partial DistOpt shard factor" + intra_partial_expert_data_parallel_size = (expert_data_parallel_size // num_distributed_optimizer_instances) + + for ranks in self.expert_decoder_rank_generator.get_ranks('dp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("ep_dp", pg_comm_cfgs), + group_desc="EXPERT_DATA_PARALLEL_GROUP", + ) + if create_gloo_process_groups: + group_gloo = self._create_group(ranks, backend="gloo", group_desc="EXPERT_DATA_PARALLEL_GROUP_GLOO") + else: + group_gloo = None + if rank in ranks: + self.expert_data_parallel_group = group + self.expert_data_parallel_group_gloo = group_gloo + + if num_distributed_optimizer_instances > 1: + hierarchical_groups, hierarchical_groups_gloo = self._create_hierarchical_groups( + rank, + ranks, + [intra_partial_expert_data_parallel_size, num_distributed_optimizer_instances], + 
create_gloo_process_groups=create_gloo_process_groups, + pg_options=[ + self._get_pg_options("intra_ep_dp", pg_comm_cfgs), + self._get_pg_options("inter_ep_dp", pg_comm_cfgs), + ], + timeout=timeout, + group_desc="EXPERT_DATA_PARALLEL_GROUP", + ) + if rank in ranks: + self.intra_partial_expert_data_parallel_group = hierarchical_groups[0] + self.intra_partial_expert_data_parallel_group_gloo = hierarchical_groups_gloo[0] + self.inter_partial_expert_data_parallel_group = hierarchical_groups[1] + else: + self.intra_partial_expert_data_parallel_group = self.expert_data_parallel_group + self.intra_partial_expert_data_parallel_group_gloo = self.expert_data_parallel_group_gloo + + # Build intra distributed optimizer instance group + assert self.intra_distributed_optimizer_instance_group is None, "Intra distributed optimizer instance group is already initialized" + model_parallel_group_id = 0 + intra_dist_opt_ranks = [] + for ranks in self.expert_decoder_rank_generator.get_ranks('tp-ep-pp'): + model_parallel_group_id += 1 + intra_dist_opt_ranks.extend(ranks) + if model_parallel_group_id % intra_partial_expert_data_parallel_size == 0: + intra_dist_opt_instance_group = self._create_group( + intra_dist_opt_ranks, + timeout=timeout, + pg_options=self._get_pg_options("intra_dist_opt_instance", pg_comm_cfgs), + group_desc="INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP", + ) + if rank in intra_dist_opt_ranks: + self.intra_distributed_optimizer_instance_group = intra_dist_opt_instance_group + intra_dist_opt_ranks = [] + + # Build sequence parallel groups using RankGenerator + if sequence_parallel_size > 1: + assert self.sequence_parallel_group is None, "sequence parallel group is already initialized" + assert self.sequence_and_data_parallel_group is None, "sequence and data parallel group is already initialized" + + # Build SP groups using RankGenerator + for ranks in self.decoder_rank_generator.get_ranks('sp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("sp", pg_comm_cfgs), + group_desc="SEQUENCE_PARALLEL_GROUP", + ) + if rank in ranks: + self.sequence_parallel_group = group + + # Build SP+DP combined groups using RankGenerator + for ranks in self.decoder_rank_generator.get_ranks('sp-dp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("sp_dp", pg_comm_cfgs), + group_desc="SEQUENCE_AND_DATA_PARALLEL_GROUP", + ) + if rank in ranks: + self.sequence_and_data_parallel_group = group + + # Initialize global memory buffer + self._set_global_memory_buffer() + + def _set_global_memory_buffer(self): + """Initialize global buffer.""" + assert self.global_memory_buffer is None, "global memory buffer is already initialized" + self.global_memory_buffer = GlobalMemoryBuffer() + + # Getter methods for process groups + def get_model_parallel_group(self, check_initialized=True): + """Get the model-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.model_parallel_group is not None, "model parallel group is not initialized" + return self.model_parallel_group + + def get_tensor_model_parallel_group(self, check_initialized=True): + """Get the tensor-model-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.tensor_model_parallel_group is not None, "tensor model parallel group is not initialized" + return self.tensor_model_parallel_group + + def get_pipeline_model_parallel_group(self, check_initialized=True): + """Get the pipeline-model-parallel group the caller rank belongs 
to.""" + if check_initialized: + assert self.pipeline_model_parallel_group is not None, "pipeline_model parallel group is not initialized" + return self.pipeline_model_parallel_group + + def get_data_parallel_group(self, with_context_parallel=False, partial_data_parallel=False): + """Get the data-parallel group the caller rank belongs to.""" + if with_context_parallel: + if partial_data_parallel: + assert self.intra_partial_data_parallel_group_with_cp is not None, "Intra partial data parallel group is not initialized" + return self.intra_partial_data_parallel_group_with_cp + assert self.data_parallel_group_with_cp is not None, "data parallel group with context parallel combined is not initialized" + return self.data_parallel_group_with_cp + else: + assert self.data_parallel_group is not None, "data parallel group is not initialized" + assert partial_data_parallel == False, "Partial DP for Optimizer needs to include CP" + return self.data_parallel_group + + def get_context_parallel_group(self, check_initialized=True): + """Get the context-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.context_parallel_group is not None, "context parallel group is not initialized" + return self.context_parallel_group + + def get_sequence_parallel_group(self, check_initialized=True): + """Get the sequence-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.sequence_parallel_group is not None, "sequence parallel group is not initialized" + return self.sequence_parallel_group + + def get_sequence_and_data_parallel_group(self, check_initialized=True): + """Get the sequence and data parallel group the caller rank belongs to.""" + if check_initialized: + assert self.sequence_and_data_parallel_group is not None, "sequence and data parallel group is not initialized" + return self.sequence_and_data_parallel_group + + def get_embedding_group(self, check_initialized=True): + """Get the embedding group the caller rank belongs to.""" + if check_initialized: + assert self.embedding_group is not None, "embedding group is not initialized" + return self.embedding_group + + def get_tensor_and_data_parallel_group(self, check_initialized=True, with_context_parallel=False): + """Get the tensor- and data-parallel group the caller rank belongs to.""" + if with_context_parallel: + if check_initialized: + assert self.tensor_and_data_parallel_group_with_cp is not None, 'tensor and data parallel group is not initialized' + return self.tensor_and_data_parallel_group_with_cp + else: + if check_initialized: + assert self.tensor_and_data_parallel_group is not None, 'tensor and data parallel group is not initialized' + return self.tensor_and_data_parallel_group + + def get_tensor_and_context_parallel_group(self, check_initialized=True): + """Get the tensor- and context-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.tensor_and_context_parallel_group is not None, "tensor and context parallel group is not initialized" + return self.tensor_and_context_parallel_group + + # Getter methods for world sizes and ranks + def get_tensor_model_parallel_world_size(self): + """Return world size for the tensor-model-parallel group.""" + if self.mpu_tensor_model_parallel_world_size is not None: + return self.mpu_tensor_model_parallel_world_size + return self.get_tensor_model_parallel_group().size() + + def get_pipeline_model_parallel_world_size(self): + """Return world size for the pipeline-model-parallel group.""" + if 
self.mpu_pipeline_model_parallel_world_size is not None: + return self.mpu_pipeline_model_parallel_world_size + return self.get_pipeline_model_parallel_group().size() + + def get_tensor_model_parallel_rank(self): + """Return caller's rank for the tensor-model-parallel group.""" + if self.mpu_tensor_model_parallel_rank is not None: + return self.mpu_tensor_model_parallel_rank + return self.get_tensor_model_parallel_group().rank() + + def get_pipeline_model_parallel_rank(self): + """Return caller's rank for the pipeline-model-parallel group.""" + if self.mpu_pipeline_model_parallel_rank is not None: + return self.mpu_pipeline_model_parallel_rank + return dist.get_rank(group=self.get_pipeline_model_parallel_group()) + + def get_data_parallel_world_size(self, with_context_parallel=False, partial_data_parallel=False): + """Return world size for the data parallel group.""" + if self.mpu_data_parallel_world_size is not None: + return self.mpu_data_parallel_world_size + if dist.is_available() and dist.is_initialized(): + return self.get_data_parallel_group(with_context_parallel=with_context_parallel, + partial_data_parallel=partial_data_parallel).size() + else: + return 0 + + def get_data_parallel_rank(self, with_context_parallel=False, partial_data_parallel=False): + """Return caller's rank in the data-parallel group.""" + if self.mpu_data_parallel_rank is not None: + return self.mpu_data_parallel_rank + if dist.is_available() and dist.is_initialized(): + return self.get_data_parallel_group(with_context_parallel=with_context_parallel, + partial_data_parallel=partial_data_parallel).rank() + else: + return 0 + + def get_context_parallel_world_size(self): + """Return world size for the context parallel group.""" + if dist.is_available() and dist.is_initialized(): + return self.get_context_parallel_group().size() + else: + return 0 + + def get_context_parallel_rank(self): + """Return caller's rank in the context-parallel group.""" + if dist.is_available() and dist.is_initialized(): + return self.get_context_parallel_group().rank() + else: + return 0 + + def get_sequence_parallel_world_size(self): + """Return world size for the sequence parallel group.""" + if dist.is_available() and dist.is_initialized(): + if self.sequence_parallel_group is not None: + return self.get_sequence_parallel_group().size() + return 1 + + def get_sequence_parallel_rank(self): + """Return caller's rank in the sequence-parallel group.""" + if dist.is_available() and dist.is_initialized(): + if self.sequence_parallel_group is not None: + return self.get_sequence_parallel_group().rank() + return 0 + + def get_sequence_and_data_parallel_world_size(self): + """Return world size for the sequence and data parallel group.""" + if dist.is_available() and dist.is_initialized(): + if self.sequence_and_data_parallel_group is not None: + return self.get_sequence_and_data_parallel_group().size() + return 0 + + def get_sequence_and_data_parallel_rank(self): + """Return caller's rank in the sequence and data parallel group.""" + if dist.is_available() and dist.is_initialized(): + if self.sequence_and_data_parallel_group is not None: + return self.get_sequence_and_data_parallel_group().rank() + return 0 + + def is_initialized(self): + """Check if parallel state has been initialized""" + return self.data_parallel_group is not None + + def initialize_all_to_all_groups(self): + """Initialize All-to-All groups for quantized gradient communication. 
+ + Creates local and global All-to-All groups based on node topology: + - Local groups: intra-node communication (NVLink/NVSwitch) + - Global groups: inter-node communication (cross-node) + + Used by ZeRO++ when zero_quantized_gradients is enabled. + + Returns: + Dictionary of All-to-All groups + """ + if self.all_to_all_initialized: + return self.all_to_all_groups + + assert dist.is_initialized(), 'dist is not initialized' + + device_per_node = get_accelerator().device_count() + world_size = dist.get_world_size() + num_nodes = world_size // device_per_node + + if num_nodes == 0 and world_size > 0: + # Single incomplete node + assert world_size >= 1, 'num_gpus must >=1, cannot initialize All-To-All' + ranks = list(range(world_size)) + self.all_to_all_groups['local_0'] = self._create_group(ranks) + + elif num_nodes == 1: + # Exactly one node + assert world_size == device_per_node, 'num_gpus not equal to device per node, cannot initialize All-To-All' + ranks = list(range(device_per_node)) + self.all_to_all_groups['local_0'] = self._create_group(ranks) + + else: + # Multiple nodes: create both local and global groups + assert world_size > device_per_node, 'num_nodes<2 cannot initialize All-To-All' + + # Local groups (intra-node) + for node_id in range(num_nodes): + local_ranks = [j + device_per_node * node_id for j in range(device_per_node)] + self.all_to_all_groups[f"local_{node_id}"] = self._create_group(local_ranks) + + # Global groups (inter-node) + for device_id in range(device_per_node): + global_ranks = [device_id + j * device_per_node for j in range(num_nodes)] + self.all_to_all_groups[f"global_{device_id}"] = self._create_group(global_ranks) + + self.all_to_all_initialized = True + return self.all_to_all_groups + + def get_all_to_all_groups(self): + """Get All-to-All groups dictionary. + + Initializes the groups if not already initialized. 
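+
+        Keys follow the naming used by initialize_all_to_all_groups:
+        "local_{node_id}" for intra-node groups and "global_{device_id}" for
+        inter-node groups. For example, on 2 nodes with 4 devices each the local
+        groups are [0, 1, 2, 3] and [4, 5, 6, 7], and the global groups are
+        [0, 4], [1, 5], [2, 6] and [3, 7]. On a single node only "local_0" is created.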
+ + Returns: + Dictionary of All-to-All groups + """ + if not self.all_to_all_initialized: + self.initialize_all_to_all_groups() + return self.all_to_all_groups + + def get_global_memory_buffer(self): + """Return the global GlobalMemoryBuffer object""" + assert self.global_memory_buffer is not None, "global memory buffer is not initialized" + return self.global_memory_buffer + + # Expert-related getter methods + def get_expert_model_parallel_group(self, check_initialized=True): + """Get the expert-model-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.expert_model_parallel_group is not None, "expert model parallel group is not initialized" + return self.expert_model_parallel_group + + def get_expert_model_parallel_world_size(self): + """Return world size for the expert-model-parallel group.""" + if self.mpu_expert_model_parallel_world_size is not None: + return self.mpu_expert_model_parallel_world_size + if dist.is_available() and dist.is_initialized(): + return self.get_expert_model_parallel_group().size() + else: + return 0 + + def get_expert_model_parallel_rank(self): + """Return caller's rank in the expert-model-parallel group.""" + if self.mpu_expert_model_parallel_rank is not None: + return self.mpu_expert_model_parallel_rank + if dist.is_available() and dist.is_initialized(): + return self.get_expert_model_parallel_group().rank() + else: + return 0 + + def get_expert_tensor_parallel_group(self, check_initialized=True): + """Get the expert-tensor-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.expert_tensor_parallel_group is not None, "Expert tensor parallel group is not initialized" + return self.expert_tensor_parallel_group + + def get_expert_tensor_parallel_world_size(self): + """Return world size for the expert tensor parallel group.""" + if self.mpu_expert_tensor_parallel_world_size is not None: + return self.mpu_expert_tensor_parallel_world_size + if not self.expert_tensor_parallel_group: + return self.mpu_tensor_model_parallel_world_size + else: + return self.get_expert_tensor_parallel_group().size() + + def get_expert_tensor_parallel_rank(self): + """Return my rank for the expert tensor parallel group.""" + if self.mpu_expert_tensor_parallel_rank is not None: + return self.mpu_expert_tensor_parallel_rank + if not self.expert_tensor_parallel_group: + return self.mpu_tensor_model_parallel_rank + else: + return self.get_expert_tensor_parallel_group().rank() + + def get_expert_data_parallel_group(self, check_initialized=True, partial_expert_data_parallel=False): + """Get expert data parallel group.""" + if partial_expert_data_parallel: + if check_initialized: + assert self.intra_partial_expert_data_parallel_group is not None, "Intra partial expert data parallel group is not initialized" + return self.intra_partial_expert_data_parallel_group + else: + if check_initialized: + assert self.expert_data_parallel_group is not None, "Expert data parallel group is not initialized" + return self.expert_data_parallel_group + + def get_expert_data_parallel_rank(self, partial_expert_data_parallel=False): + """Return caller's rank in the expert data parallel group.""" + if dist.is_available() and dist.is_initialized(): + return self.get_expert_data_parallel_group( + partial_expert_data_parallel=partial_expert_data_parallel).rank() + else: + return 0 + + def get_expert_data_parallel_world_size(self, partial_expert_data_parallel=False): + """Return world size for the expert data parallel group.""" + if dist.is_available() and 
dist.is_initialized(): + return self.get_expert_data_parallel_group( + partial_expert_data_parallel=partial_expert_data_parallel).size() + else: + return 0 + + +# Convenience function to create a singleton instance +_parallel_state_instance = None + + +def get_parallel_state() -> ParallelState: + """Get or create the global ParallelState instance.""" + global _parallel_state_instance + if _parallel_state_instance is None: + _parallel_state_instance = ParallelState() + return _parallel_state_instance diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py new file mode 100644 index 000000000000..eb768fd83815 --- /dev/null +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -0,0 +1,908 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) DeepSpeed Team + +# DeepSpeed Team + +# The file has been adapted from https://github.com/NVIDIA/Megatron-LM and retains the following license from the original file + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +DeepSpeed Compatibility Layer for parallel_state. + +This module provides module-level functions compatible with DeepSpeed's +groups.py API, allowing code written for DeepSpeed to work with the +refactored parallel_state module. 
+ +Key Features: +- Supports multiple parallel state instances (for RL scenarios with different models) +- Backward compatible with single global instance +- Context manager for switching between different parallel configurations +- Configuration-based initialization from config.json + +Usage: + # Basic usage (single global instance): + from parallel_state_deepspeed import get_data_parallel_group + dp_group = get_data_parallel_group() + + # Multi-instance usage (for RL scenarios): + from parallel_state_deepspeed import ( + get_parallel_state_instance, + set_current_parallel_state, + get_data_parallel_group, + ) + + # Create different instances for different models + actor_state = get_parallel_state_instance("actor") + critic_state = get_parallel_state_instance("critic") + + # Initialize with different DP sizes + actor_state.initialize_model_parallel(tensor_model_parallel_size=2, data_parallel_size=4) + critic_state.initialize_model_parallel(tensor_model_parallel_size=1, data_parallel_size=8) + + # Use context manager to switch + with set_current_parallel_state("actor"): + actor_dp_group = get_data_parallel_group() # Uses actor's DP group + + with set_current_parallel_state("critic"): + critic_dp_group = get_data_parallel_group() # Uses critic's DP group + + # Initialize from config.json: + from deepspeed import DeepSpeedConfig + ds_config = DeepSpeedConfig("config.json") + initialize_parallel_state_from_config(ds_config) +""" + +from contextlib import contextmanager +from typing import Optional, Union, Dict, Any, List +from .parallel_state import ParallelState, get_parallel_state as _get_default_parallel_state + +# Registry for multiple parallel state instances +_parallel_state_registry = {} +_default_instance_name = "__default__" + +# Current active instance name (thread-local would be better, but using global for simplicity) +_current_instance_name = _default_instance_name + + +def get_parallel_state_instance(name: Optional[str] = None) -> ParallelState: + """Get or create a named ParallelState instance. + + Args: + name: Name of the instance. If None, returns the default global instance. + Use different names for different models in RL scenarios. + + Returns: + ParallelState instance + + Example: + # For RL with actor and critic models + actor_state = get_parallel_state_instance("actor") + critic_state = get_parallel_state_instance("critic") + """ + if name is None: + return _get_default_parallel_state() + + if name not in _parallel_state_registry: + _parallel_state_registry[name] = ParallelState() + + return _parallel_state_registry[name] + + +def set_current_parallel_state(name: Optional[str] = None): + """Set the current active parallel state instance. + + Args: + name: Name of the instance to activate. If None, uses the default instance. + + Returns: + Context manager for temporarily switching the active instance + + Example: + with set_current_parallel_state("actor"): + dp_group = get_data_parallel_group() # Uses actor's DP group + """ + + @contextmanager + def _context(): + global _current_instance_name + old_name = _current_instance_name + _current_instance_name = name if name is not None else _default_instance_name + try: + yield + finally: + _current_instance_name = old_name + + return _context() + + +def get_current_parallel_state() -> ParallelState: + """Get the currently active parallel state instance. 
+ + Returns: + The currently active ParallelState instance + """ + return get_parallel_state_instance(_current_instance_name) + + +def get_parallel_state(name: Optional[str] = None) -> ParallelState: + """Get parallel state instance (backward compatible). + + If name is provided, returns the named instance. + Otherwise, returns the currently active instance. + + Args: + name: Optional name of the instance. If None, returns current active instance. + + Returns: + ParallelState instance + """ + if name is not None: + return get_parallel_state_instance(name) + return get_current_parallel_state() + + +# ============================================================================ +# Core Tensor/Model/Data Parallel Functions +# ============================================================================ + + +def get_tensor_model_parallel_group(name: Optional[str] = None): + """Get the tensor model parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + Use this in RL scenarios to specify which model's parallel groups to use. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_model_parallel_group() + + +def get_model_parallel_group(name: Optional[str] = None): + """Get the model parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_model_parallel_group() + + +def get_data_parallel_group(name: Optional[str] = None, + with_context_parallel: bool = False, + partial_data_parallel: bool = False): + """Get the data parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + Use this in RL scenarios to specify which model's DP group to use. + For example, "actor" vs "critic" may have different DP sizes. + with_context_parallel: Whether to include context parallel in the group. + partial_data_parallel: Whether to use partial data parallel group. + + DeepSpeed-compatible interface. + + Example: + # In RL scenario with different DP sizes: + actor_dp = get_data_parallel_group("actor") # Actor's DP group + critic_dp = get_data_parallel_group("critic") # Critic's DP group + + # Or use context manager: + with set_current_parallel_state("actor"): + dp_group = get_data_parallel_group() # Uses actor's DP group + """ + return get_parallel_state(name).get_data_parallel_group(with_context_parallel=with_context_parallel, + partial_data_parallel=partial_data_parallel) + + +def get_tensor_model_parallel_world_size(name: Optional[str] = None): + """Return world size for the tensor model parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_model_parallel_world_size() + + +def get_model_parallel_world_size(name: Optional[str] = None): + """Return world size for the model parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_model_parallel_world_size() + + +def get_tensor_model_parallel_rank(name: Optional[str] = None): + """Return caller's rank for the tensor-model-parallel group. 
+ + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_model_parallel_rank() + + +def get_model_parallel_rank(name: Optional[str] = None): + """Return caller's rank for the model parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_model_parallel_rank() + + +def get_data_parallel_world_size(name: Optional[str] = None, + with_context_parallel: bool = False, + partial_data_parallel: bool = False): + """Return world size for the data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + with_context_parallel: Whether to include context parallel. + partial_data_parallel: Whether to use partial data parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_data_parallel_world_size(with_context_parallel=with_context_parallel, + partial_data_parallel=partial_data_parallel) + + +def get_data_parallel_rank(name: Optional[str] = None, + with_context_parallel: bool = False, + partial_data_parallel: bool = False): + """Return caller's rank in the data-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + with_context_parallel: Whether to include context parallel. + partial_data_parallel: Whether to use partial data parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_data_parallel_rank(with_context_parallel=with_context_parallel, + partial_data_parallel=partial_data_parallel) + + +def get_tensor_model_parallel_src_rank(name: Optional[str] = None): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + import deepspeed.comm as dist + global_rank = dist.get_rank() + local_world_size = get_tensor_model_parallel_world_size(name) + return (global_rank // local_world_size) * local_world_size + + +def set_tensor_model_parallel_world_size(world_size, name: Optional[str] = None): + """Set the tensor model parallel size. + + Args: + world_size: World size to set. + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + ps = get_parallel_state(name) + ps.mpu_tensor_model_parallel_world_size = world_size + + +def set_tensor_model_parallel_rank(rank, name: Optional[str] = None): + """Set tensor model parallel rank. + + Args: + rank: Rank to set. + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + ps = get_parallel_state(name) + ps.mpu_tensor_model_parallel_rank = rank + + +# ============================================================================ +# Pipeline Parallel Functions +# ============================================================================ + + +def get_pipeline_model_parallel_group(name: Optional[str] = None): + """Get the pipeline-model-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. 
+ """ + return get_parallel_state(name).get_pipeline_model_parallel_group() + + +def get_pipeline_model_parallel_world_size(name: Optional[str] = None): + """Return world size for the pipeline-model-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_pipeline_model_parallel_world_size() + + +def get_pipeline_model_parallel_rank(name: Optional[str] = None): + """Return caller's rank for the pipeline-model-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_pipeline_model_parallel_rank() + + +# ============================================================================ +# Context Parallel Functions +# ============================================================================ + + +def get_context_parallel_group(name: Optional[str] = None): + """Get the context-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_context_parallel_group() + + +def get_context_parallel_world_size(name: Optional[str] = None): + """Return world size for the context parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_context_parallel_world_size() + + +def get_context_parallel_rank(name: Optional[str] = None): + """Return caller's rank in the context-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_context_parallel_rank() + + +# ============================================================================ +# Sequence Parallel Functions +# ============================================================================ + + +def get_sequence_parallel_group(name: Optional[str] = None): + """Get the sequence-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_parallel_group() + + +def get_sequence_parallel_world_size(name: Optional[str] = None): + """Return world size for the sequence parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_parallel_world_size() + + +def get_sequence_parallel_rank(name: Optional[str] = None): + """Return caller's rank in the sequence-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_parallel_rank() + + +def get_sequence_and_data_parallel_group(name: Optional[str] = None): + """Get the sequence and data parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. 
+ """ + return get_parallel_state(name).get_sequence_and_data_parallel_group() + + +def get_sequence_and_data_parallel_world_size(name: Optional[str] = None): + """Return world size for the sequence and data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_and_data_parallel_world_size() + + +def get_sequence_and_data_parallel_rank(name: Optional[str] = None): + """Return caller's rank in the sequence and data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_and_data_parallel_rank() + + +# ============================================================================ +# Expert Parallel Functions +# ============================================================================ + + +def get_expert_model_parallel_group(name: Optional[str] = None): + """Get the expert-model-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_model_parallel_group() + + +def get_expert_model_parallel_world_size(name: Optional[str] = None): + """Return world size for the expert-model-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_model_parallel_world_size() + + +def get_expert_model_parallel_rank(name: Optional[str] = None): + """Return caller's rank in the expert-model-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_model_parallel_rank() + + +def get_expert_tensor_parallel_group(name: Optional[str] = None): + """Get the expert-tensor-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_tensor_parallel_group() + + +def get_expert_tensor_parallel_world_size(name: Optional[str] = None): + """Return world size for the expert tensor parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_tensor_parallel_world_size() + + +def get_expert_tensor_parallel_rank(name: Optional[str] = None): + """Return my rank for the expert tensor parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_tensor_parallel_rank() + + +def get_expert_data_parallel_group(name: Optional[str] = None, partial_expert_data_parallel: bool = False): + """Get expert data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + partial_expert_data_parallel: Whether to use partial expert data parallel. + + DeepSpeed-compatible interface. 
+ """ + return get_parallel_state(name).get_expert_data_parallel_group( + partial_expert_data_parallel=partial_expert_data_parallel) + + +def get_expert_data_parallel_world_size(name: Optional[str] = None, partial_expert_data_parallel: bool = False): + """Return world size for the expert data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + partial_expert_data_parallel: Whether to use partial expert data parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_data_parallel_world_size( + partial_expert_data_parallel=partial_expert_data_parallel) + + +def get_expert_data_parallel_rank(name: Optional[str] = None, partial_expert_data_parallel: bool = False): + """Return caller's rank in the expert data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + partial_expert_data_parallel: Whether to use partial expert data parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_data_parallel_rank( + partial_expert_data_parallel=partial_expert_data_parallel) + + +# ============================================================================ +# Additional Helper Functions +# ============================================================================ + + +def get_embedding_group(name: Optional[str] = None): + """Get the embedding group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_embedding_group() + + +def get_tensor_and_data_parallel_group(name: Optional[str] = None, with_context_parallel: bool = False): + """Get the tensor- and data-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + with_context_parallel: Whether to include context parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_and_data_parallel_group(with_context_parallel=with_context_parallel) + + +def get_tensor_and_context_parallel_group(name: Optional[str] = None): + """Get the tensor- and context-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_and_context_parallel_group() + + +def is_initialized(name: Optional[str] = None): + """Check if parallel state has been initialized. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).is_initialized() + + +# ============================================================================ +# All-to-All Groups for ZeRO++ Quantized Gradients +# ============================================================================ + + +def initialize_all_to_all_groups(name: Optional[str] = None): + """Initialize All-to-All groups for quantized gradient communication. + + Creates local and global All-to-All groups based on node topology. + Used by ZeRO++ when zero_quantized_gradients is enabled. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. 
+ + Returns: + Dictionary of All-to-All groups + + Example: + # Initialize for default instance + all_to_all_groups = initialize_all_to_all_groups() + + # Initialize for named instance (RL scenario) + actor_groups = initialize_all_to_all_groups("actor") + critic_groups = initialize_all_to_all_groups("critic") + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).initialize_all_to_all_groups() + + +def get_all_to_all_groups(name: Optional[str] = None): + """Get All-to-All groups dictionary. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + Returns: + Dictionary of All-to-All groups + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_all_to_all_groups() + + +def _get_local_all_to_all_group(name: Optional[str] = None): + """Get All-to-All groups for current rank (backward compatible with groups.py). + + This function provides backward compatibility with the groups.py interface. + It returns all All-to-All groups (both local and global). + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + Returns: + Dictionary of All-to-All groups + + Note: + This is a compatibility wrapper. New code should use get_all_to_all_groups() instead. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_all_to_all_groups() + + +# ============================================================================ +# Configuration-based Initialization +# ============================================================================ + + +def initialize_parallel_state_from_config( + config: Union[Dict[str, Any], Any], + name: Optional[str] = None, + # Optional parameters to override config values + tensor_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_size: Optional[int] = None, + virtual_pipeline_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_comm_backend: Optional[str] = None, + context_parallel_size: Optional[int] = None, + hierarchical_context_parallel_sizes: Optional[List[int]] = None, + expert_model_parallel_size: Optional[int] = None, + num_distributed_optimizer_instances: Optional[int] = None, + expert_tensor_parallel_size: Optional[int] = None, + sequence_parallel_size: Optional[int] = None, + nccl_communicator_config_path: Optional[str] = None, + distributed_timeout_minutes: Optional[int] = None, + order: Optional[str] = None, + create_gloo_process_groups: Optional[bool] = None, + high_priority_stream_groups: Optional[List[str]] = None, +) -> ParallelState: + """Initialize parallel state from DeepSpeed config.json with optional parameter overrides. + + Reads parallelism configuration from the DeepSpeed config (including nested dicts) + and initializes the ParallelState instance. Returns the instance so it can be used + directly as the ``mpu`` argument to ``deepspeed.initialize``. + + Configuration priority: function parameters > config file values > default values + + Config keys support dot-separated paths for nested dicts. For example, + ``tensor_model_parallel_size`` is resolved from either the top-level key + ``"tensor_model_parallel_size"`` or the nested key ``"tensor_parallel.autotp_size"``. + + Args: + config: Either a DeepSpeedConfig object or a config dictionary. + name: Optional name of the parallel state instance to initialize. + If None, initializes the default global instance. + tensor_model_parallel_size: Size of tensor model parallel group. Default: 1. 
+            Also read from ``tensor_parallel.autotp_size`` in config.
+        pipeline_model_parallel_size: Size of pipeline model parallel group. Default: 1
+        virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size. Default: None
+        pipeline_model_parallel_comm_backend: Communication backend for pipeline. Default: None
+        context_parallel_size: Size of context parallel group. Default: 1 (MUST be 1, CP not supported)
+        hierarchical_context_parallel_sizes: Hierarchical context parallel sizes. Default: None (NOT supported)
+        expert_model_parallel_size: Size of expert model parallel group. Default: 1
+        num_distributed_optimizer_instances: Number of distributed optimizer instances. Default: 1
+        expert_tensor_parallel_size: Size of expert tensor parallel group. Default: None
+        sequence_parallel_size: Size of sequence parallel group. Default: 1
+        nccl_communicator_config_path: Path to NCCL communicator config. Default: None
+        distributed_timeout_minutes: Timeout for distributed operations. Default: 30
+        order: Order of parallelism dimensions. Default: "tp-ep-dp-pp"
+        create_gloo_process_groups: Whether to create Gloo process groups. Default: False
+        high_priority_stream_groups: High priority stream groups. Default: None
+
+    Returns:
+        The initialized (or already-initialized) ParallelState instance.
+
+    Example usage::
+
+        # Use return value as mpu:
+        ps = initialize_parallel_state_from_config(config_dict)
+        model, optimizer, _, _ = deepspeed.initialize(
+            model=model, model_parameters=model.parameters(),
+            config=config_dict, mpu=ps)
+
+        # AutoTP config (nested dict):
+        config_dict = {
+            "tensor_parallel": {"autotp_size": 4},
+            ...
+        }
+        ps = initialize_parallel_state_from_config(config_dict)
+
+        # Override specific parameters:
+        ps = initialize_parallel_state_from_config(
+            ds_config,
+            tensor_model_parallel_size=4,
+            expert_model_parallel_size=2
+        )
+
+        # Named instances (RL scenarios):
+        actor_ps = initialize_parallel_state_from_config(ds_config, name="actor")
+        critic_ps = initialize_parallel_state_from_config(
+            ds_config, name="critic", tensor_model_parallel_size=2)
+    """
+    # Extract config dictionary
+    if hasattr(config, '_param_dict'):
+        # DeepSpeedConfig object
+        config_dict = config._param_dict
+    elif isinstance(config, dict):
+        # Already a dictionary
+        config_dict = config
+    else:
+        raise ValueError(f"config must be a DeepSpeedConfig object or a dict, got {type(config)}")
+
+    # Get the parallel state instance
+    ps = get_parallel_state_instance(name)
+
+    if ps.is_initialized():
+        return ps
+
+    def _resolve_nested_key(d, dotted_key):
+        """Resolve a dot-separated key path in a nested dict. Returns (found, value)."""
+        keys = dotted_key.split(".")
+        cur = d
+        for k in keys:
+            if isinstance(cur, dict) and k in cur:
+                cur = cur[k]
+            else:
+                return False, None
+        return True, cur
+
+    def get_value(param_value, config_key, default_value):
+        """
+        Get value with priority: function parameter > config value > default.
+
+        config_key can be a single dot-separated string (e.g. "tensor_parallel.autotp_size")
+        or a list of candidate keys tried in order.
+ """ + if param_value is not None: + return param_value + + candidates = config_key if isinstance(config_key, (list, tuple)) else [config_key] + for key in candidates: + found, value = _resolve_nested_key(config_dict, key) + if found: + return value + + return default_value + + init_kwargs = { + "tensor_model_parallel_size": + get_value(tensor_model_parallel_size, + ["tensor_model_parallel_size", "tensor_parallel.autotp_size"], 1), + "pipeline_model_parallel_size": + get_value(pipeline_model_parallel_size, "pipeline_model_parallel_size", 1), + "virtual_pipeline_model_parallel_size": + get_value(virtual_pipeline_model_parallel_size, "virtual_pipeline_model_parallel_size", None), + "pipeline_model_parallel_comm_backend": + get_value(pipeline_model_parallel_comm_backend, "pipeline_model_parallel_comm_backend", None), + "context_parallel_size": + get_value(context_parallel_size, "context_parallel_size", 1), + "sequence_parallel_size": + get_value(sequence_parallel_size, "sequence_parallel_size", 1), + "hierarchical_context_parallel_sizes": + get_value(hierarchical_context_parallel_sizes, "hierarchical_context_parallel_sizes", None), + "expert_model_parallel_size": + get_value(expert_model_parallel_size, "expert_model_parallel_size", 1), + "num_distributed_optimizer_instances": + get_value(num_distributed_optimizer_instances, "num_distributed_optimizer_instances", 1), + "expert_tensor_parallel_size": + get_value(expert_tensor_parallel_size, "expert_tensor_parallel_size", None), + "nccl_communicator_config_path": + get_value(nccl_communicator_config_path, "nccl_communicator_config_path", None), + "distributed_timeout_minutes": + get_value(distributed_timeout_minutes, "distributed_timeout_minutes", 30), + "order": + get_value(order, "order", "tp-ep-dp-pp"), + "create_gloo_process_groups": + get_value(create_gloo_process_groups, "create_gloo_process_groups", False), + "high_priority_stream_groups": + get_value(high_priority_stream_groups, "high_priority_stream_groups", None), + } + + # Validate context_parallel_size + cp_size = init_kwargs["context_parallel_size"] + if cp_size != 1: + raise NotImplementedError( + f"DeepSpeed currently does not support context_parallel_size > 1. " + f"Got context_parallel_size={cp_size}. Please set context_parallel_size=1 in your config.") + + # Validate hierarchical_context_parallel_sizes + hcp_sizes = init_kwargs["hierarchical_context_parallel_sizes"] + if hcp_sizes is not None: + raise NotImplementedError( + f"DeepSpeed currently does not support hierarchical_context_parallel_sizes. " + f"Got hierarchical_context_parallel_sizes={hcp_sizes}. 
Please remove this configuration.") + + # Remove None values for optional parameters (except those that can be None) + # Keep None for: virtual_pipeline_model_parallel_size, pipeline_model_parallel_comm_backend, + # hierarchical_context_parallel_sizes, expert_tensor_parallel_size + # Note: nccl_communicator_config_path and high_priority_stream_groups are not supported by initialize_model_parallel + filtered_kwargs = {} + supported_params = { + "tensor_model_parallel_size", "pipeline_model_parallel_size", "virtual_pipeline_model_parallel_size", + "pipeline_model_parallel_comm_backend", "context_parallel_size", "sequence_parallel_size", + "hierarchical_context_parallel_sizes", "expert_model_parallel_size", "num_distributed_optimizer_instances", + "expert_tensor_parallel_size", "distributed_timeout_minutes", "order", "create_gloo_process_groups" + } + + for key, value in init_kwargs.items(): + # Skip unsupported parameters + if key not in supported_params: + continue + # Keep None for parameters that can be None + if value is not None or key in [ + "virtual_pipeline_model_parallel_size", "pipeline_model_parallel_comm_backend", + "hierarchical_context_parallel_sizes", "expert_tensor_parallel_size" + ]: + filtered_kwargs[key] = value + + ps.initialize_model_parallel(**filtered_kwargs) + return ps diff --git a/docs/_pages/training.md b/docs/_pages/training.md index e31651cc487a..bdae8b563807 100644 --- a/docs/_pages/training.md +++ b/docs/_pages/training.md @@ -244,6 +244,100 @@ mpu.get_data_parallel_group() mpu.get_data_parallel_world_size() ``` +### Built-in Parallel State Management + +DeepSpeed provides a built-in `ParallelState` class that implements the `mpu` interface +with Megatron-style process group management. It supports tensor parallelism (TP), +pipeline parallelism (PP), data parallelism (DP), sequence parallelism (SP), +context parallelism (CP), and expert parallelism (EP). + +#### Basic Usage + +You can initialize the parallel state either explicitly or from a DeepSpeed config: + +```python +from deepspeed.utils import parallel_state_deepspeed as ps + +# Option 1: Initialize from config dict (also works with DeepSpeedConfig objects) +config_dict = { + "train_micro_batch_size_per_gpu": 1, + "tensor_parallel": {"autotp_size": 4}, + "zero_optimization": {"stage": 1} +} +parallel_state = ps.initialize_parallel_state_from_config(config_dict) + +# The returned ParallelState can be passed directly as mpu +model_engine, optimizer, _, _ = deepspeed.initialize( + model=model, + model_parameters=model.parameters(), + config=config_dict, + mpu=parallel_state +) +``` + +```python +# Option 2: Initialize explicitly with parallelism dimensions +parallel_state = ps.get_parallel_state_instance() +parallel_state.initialize_model_parallel( + tensor_model_parallel_size=4, + pipeline_model_parallel_size=2, + sequence_parallel_size=1, +) +``` + +#### Configuration-based Initialization + +`initialize_parallel_state_from_config` resolves parallelism parameters with +the following priority: **function parameters > config values > defaults**. + +Config keys support dot-separated paths for nested dictionaries. For example, +`tensor_model_parallel_size` can be read from `"tensor_model_parallel_size"` at +the top level or `"tensor_parallel.autotp_size"` in a nested config. 
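+
+For example, the short sketch below (illustrative values; it assumes a distributed run
+whose world size is divisible by the tensor-parallel size) exercises all three priority
+levels: no function parameters are passed, the nested AutoTP key supplies the
+tensor-parallel size, and every remaining dimension falls back to its default:
+
+```python
+from deepspeed.utils import parallel_state_deepspeed as ps
+
+# "tensor_parallel.autotp_size" is resolved through the dot-separated key path,
+# so tensor_model_parallel_size becomes 2; parameters that are neither passed
+# explicitly nor present in the config fall back to their defaults (e.g. pipeline size 1).
+config_dict = {
+    "train_micro_batch_size_per_gpu": 1,
+    "tensor_parallel": {"autotp_size": 2},
+}
+parallel_state = ps.initialize_parallel_state_from_config(config_dict)
+```
+
+To override a value read from the config at call time, pass it explicitly as a function parameter: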
+ +```python +from deepspeed.utils import parallel_state_deepspeed as ps + +# Override specific parameters while reading others from config +parallel_state = ps.initialize_parallel_state_from_config( + config_dict, + tensor_model_parallel_size=4, # Override config value + expert_model_parallel_size=2, # Override config value +) +``` + +#### Multiple Instances (RL Scenarios) + +In reinforcement learning scenarios where multiple models (e.g., actor and critic) +require different parallelism configurations, you can create named instances: + +```python +from deepspeed.utils import parallel_state_deepspeed as ps + +# Create separate parallel state instances +actor_ps = ps.initialize_parallel_state_from_config( + actor_config, name="actor", + tensor_model_parallel_size=4, +) +critic_ps = ps.initialize_parallel_state_from_config( + critic_config, name="critic", + tensor_model_parallel_size=2, +) + +# Use context manager to switch between instances +with ps.set_current_parallel_state("actor"): + dp_group = ps.get_data_parallel_group() # Uses actor's groups + +with ps.set_current_parallel_state("critic"): + dp_group = ps.get_data_parallel_group() # Uses critic's groups +``` + +#### Compatibility with Existing Code + +The module-level functions in `parallel_state_deepspeed` (such as +`get_data_parallel_group()`, `get_tensor_model_parallel_world_size()`, etc.) +operate on the current active `ParallelState` instance, preserving backward +compatibility with code written against the previous `groups.py` API. + ### Integration with Megatron-LM DeepSpeed is fully compatible with [Megatron](https://github.com/NVIDIA/Megatron-LM). Please see the [Megatron-LM tutorial](/tutorials/megatron/) for details. diff --git a/docs/code-docs/source/initialize.rst b/docs/code-docs/source/initialize.rst index dd69a5dec4d2..172376043229 100644 --- a/docs/code-docs/source/initialize.rst +++ b/docs/code-docs/source/initialize.rst @@ -42,3 +42,41 @@ Distributed Initialization Optional distributed backend initialization separate from ``deepspeed.initialize()``. Useful in scenarios where the user wants to use torch distributed calls before calling ``deepspeed.initialize()``, such as when using model parallelism, pipeline parallelism, or certain data loader scenarios. .. autofunction:: deepspeed.init_distributed + + +.. _parallel-state-init: + +Parallel State Initialization +----------------------------- +DeepSpeed provides a built-in ``ParallelState`` class for Megatron-style process group management +covering tensor, pipeline, data, sequence, context, and expert parallelism. + +Use ``initialize_parallel_state_from_config`` to create and initialize a ``ParallelState`` from +a DeepSpeed config dictionary (or ``DeepSpeedConfig`` object). The returned instance implements +the ``mpu`` interface and can be passed directly to ``deepspeed.initialize(mpu=...)``. + +Example usage: + +.. code-block:: python + + from deepspeed.utils import parallel_state_deepspeed as ps + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "tensor_parallel": {"autotp_size": 4}, + } + + # Initialize and use as mpu + parallel_state = ps.initialize_parallel_state_from_config(config_dict) + model_engine, optimizer, _, _ = deepspeed.initialize( + model=model, + model_parameters=model.parameters(), + config=config_dict, + mpu=parallel_state, + ) + +.. autofunction:: deepspeed.utils.parallel_state_deepspeed.initialize_parallel_state_from_config + +.. 
autoclass:: deepspeed.utils.parallel_state.ParallelState + :members: initialize_model_parallel, is_initialized, get_tensor_model_parallel_group, get_data_parallel_group, get_pipeline_model_parallel_group, get_sequence_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, get_data_parallel_world_size, get_data_parallel_rank, get_pipeline_model_parallel_world_size, get_pipeline_model_parallel_rank + :noindex: diff --git a/tests/unit/utils/test_parallel_state_deepspeed.py b/tests/unit/utils/test_parallel_state_deepspeed.py new file mode 100644 index 000000000000..d77793716382 --- /dev/null +++ b/tests/unit/utils/test_parallel_state_deepspeed.py @@ -0,0 +1,459 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Integration tests for using ParallelState as mpu in deepspeed.initialize() + +Tests the full workflow: +1. Initialize parallel_state_deepspeed with parallel configurations +2. Pass it as mpu parameter to deepspeed.initialize() +3. Verify DeepSpeed Engine correctly uses the parallel state +""" + +import pytest +import torch +import deepspeed +import deepspeed.comm as dist +from unit.common import DistributedTest +from unit.simple_model import SimpleModel, random_dataloader + + +class TestParallelStateAsMPU(DistributedTest): + """Test parallel_state_deepspeed as mpu parameter in deepspeed.initialize()""" + + world_size = 8 + + def _get_base_config(self): + """Get base DeepSpeed config""" + return {"train_batch_size": 8, "optimizer": {"type": "Adam", "params": {"lr": 0.001}}} + + def _verify_mpu_integration(self, engine, mpu, expected_tp=1, expected_pp=1, expected_sp=1): + """Verify mpu is correctly integrated in engine""" + # 1. Engine holds mpu reference + assert engine.mpu == mpu + + # 2. Parallel configuration is correct + assert mpu.get_tensor_model_parallel_world_size() == expected_tp + assert mpu.get_pipeline_model_parallel_world_size() == expected_pp + + # 3. Data parallel world size is correctly calculated + world_size = dist.get_world_size() + expected_dp = world_size // (expected_tp * expected_pp * expected_sp) + assert mpu.get_data_parallel_world_size() == expected_dp + + # 4. 
Config uses mpu for world_size calculation + assert engine.config.world_size == expected_dp + + return expected_dp + + def test_basic_mpu_usage(self): + """Test basic mpu parameter usage with TP and PP""" + from deepspeed.utils import parallel_state_deepspeed as ps + + # Use named instance to avoid test interference + state = ps.get_parallel_state_instance("test_basic") + state.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=2) + + config = self._get_base_config() + model = SimpleModel(hidden_dim=16) + + # Pass parallel_state module as mpu (the module provides compatibility layer) + with ps.set_current_parallel_state("test_basic"): + engine, optimizer, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=ps, + model_parameters=model.parameters()) + + # Verify integration + with ps.set_current_parallel_state("test_basic"): + self._verify_mpu_integration(engine, ps, expected_tp=2, expected_pp=2) + + # Verify optimizer is created + assert optimizer is not None + + # Test training for 5 batches + data_loader = random_dataloader(model=engine.module, total_samples=20, hidden_dim=16, device=engine.device) + for i, batch in enumerate(data_loader): + if i >= 5: + break + loss = engine(batch[0], batch[1]) + assert loss is not None + engine.backward(loss) + engine.step() + + def test_config_driven_mpu(self): + """Test mpu initialized from config with sequence_parallel_size""" + from deepspeed.utils import parallel_state_deepspeed as ps + + config = { + "train_batch_size": 8, + "sequence_parallel_size": 2, + "order": "tp-sp-dp-pp", # Need to specify order when using sp + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001 + } + } + } + + # Initialize from config + ps.initialize_parallel_state_from_config(config, name="config_driven_test") + + model = SimpleModel(hidden_dim=16) + + # Set current instance + ps.set_current_parallel_state("config_driven_test") + + engine, _, _, _ = deepspeed.initialize(model=model, config=config, mpu=ps, model_parameters=model.parameters()) + + # Verify SP group is created + with ps.set_current_parallel_state("config_driven_test"): + sp_world_size = ps.get_sequence_parallel_world_size() + assert sp_world_size == 2 + + # Verify integration + with ps.set_current_parallel_state("config_driven_test"): + self._verify_mpu_integration(engine, ps, expected_sp=2) + + # Test training for 5 batches + data_loader = random_dataloader(model=engine.module, total_samples=20, hidden_dim=16, device=engine.device) + for i, batch in enumerate(data_loader): + if i >= 5: + break + loss = engine(batch[0], batch[1]) + engine.backward(loss) + engine.step() + + def test_multi_instance_mpu(self): + """Test multiple named instances as mpu (Actor-Critic scenario)""" + from deepspeed.utils import parallel_state_deepspeed as ps + + # Initialize Actor with TP=2 + actor_state = ps.get_parallel_state_instance("actor") + actor_state.initialize_model_parallel(tensor_model_parallel_size=2) + + # Initialize Critic with TP=1 (no parallelism) + critic_state = ps.get_parallel_state_instance("critic") + critic_state.initialize_model_parallel(tensor_model_parallel_size=1) + + config = self._get_base_config() + + # Create Actor engine + actor_model = SimpleModel(hidden_dim=16) + with ps.set_current_parallel_state("actor"): + actor_engine, _, _, _ = deepspeed.initialize(model=actor_model, + config=config, + mpu=ps, + model_parameters=actor_model.parameters()) + + # Create Critic engine + critic_model = SimpleModel(hidden_dim=16) + with 
ps.set_current_parallel_state("critic"): + critic_engine, _, _, _ = deepspeed.initialize(model=critic_model, + config=config, + mpu=ps, + model_parameters=critic_model.parameters()) + + # Verify Actor uses TP=2 + with ps.set_current_parallel_state("actor"): + assert ps.get_tensor_model_parallel_world_size() == 2 + assert actor_engine.mpu == ps + + # Verify Critic uses TP=1 + with ps.set_current_parallel_state("critic"): + assert ps.get_tensor_model_parallel_world_size() == 1 + assert critic_engine.mpu == ps + + # Test training for 5 batches on both engines + actor_loader = random_dataloader(model=actor_engine.module, total_samples=20, hidden_dim=16, device=actor_engine.device) + critic_loader = random_dataloader(model=critic_engine.module, total_samples=20, hidden_dim=16, device=critic_engine.device) + for i, (actor_batch, critic_batch) in enumerate(zip(actor_loader, critic_loader)): + if i >= 5: + break + actor_loss = actor_engine(actor_batch[0], actor_batch[1]) + assert actor_loss is not None + actor_engine.backward(actor_loss) + actor_engine.step() + + critic_loss = critic_engine(critic_batch[0], critic_batch[1]) + assert critic_loss is not None + critic_engine.backward(critic_loss) + critic_engine.step() + + def test_mpu_with_zero_stage1(self): + """Test mpu integration with ZeRO Stage 1""" + from deepspeed.utils import parallel_state_deepspeed as ps + + # Use named instance to avoid test interference + state = ps.get_parallel_state_instance("test_zero") + state.initialize_model_parallel(tensor_model_parallel_size=2) + + config = { + "train_batch_size": 8, + "zero_optimization": { + "stage": 1 + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001 + } + } + } + + model = SimpleModel(hidden_dim=16) + + with ps.set_current_parallel_state("test_zero"): + engine, optimizer, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=ps, + model_parameters=model.parameters()) + + # Verify ZeRO optimizer is created + assert optimizer is not None + from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer + assert isinstance(optimizer, DeepSpeedZeroOptimizer) + + # Verify mpu integration + with ps.set_current_parallel_state("test_zero"): + self._verify_mpu_integration(engine, ps, expected_tp=2) + + # Verify optimizer uses correct DP group + assert optimizer.mpu == ps + + # Test training for 5 batches + data_loader = random_dataloader(model=engine.module, total_samples=20, hidden_dim=16, device=engine.device) + for i, batch in enumerate(data_loader): + if i >= 5: + break + loss = engine(batch[0], batch[1]) + engine.backward(loss) + engine.step() + + def test_deepspeed_config_uses_mpu(self): + """Test DeepSpeedConfig correctly uses mpu for world_size calculation""" + from deepspeed.utils import parallel_state_deepspeed as ps + from deepspeed.runtime.config import DeepSpeedConfig + + # Use named instance to avoid test interference + state = ps.get_parallel_state_instance("test_config") + state.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=2) + + config_dict = self._get_base_config() + + # Create DeepSpeedConfig with mpu + with ps.set_current_parallel_state("test_config"): + ds_config = DeepSpeedConfig(config_dict, mpu=ps) + + # Verify world_size calculation uses mpu + expected_dp = dist.get_world_size() // (2 * 2) + assert ds_config.world_size == expected_dp + + # Verify it matches mpu's calculation + with ps.set_current_parallel_state("test_config"): + assert ds_config.world_size == ps.get_data_parallel_world_size() + + def 
test_mpu_without_parallelism(self): + """Test mpu with all parallelism dimensions = 1 (no parallelism)""" + from deepspeed.utils import parallel_state_deepspeed as ps + + # Use named instance to avoid test interference + state = ps.get_parallel_state_instance("test_no_parallel") + state.initialize_model_parallel() + + config = self._get_base_config() + model = SimpleModel(hidden_dim=16) + + with ps.set_current_parallel_state("test_no_parallel"): + engine, _, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=ps, + model_parameters=model.parameters()) + + # Verify all dimensions are 1 + with ps.set_current_parallel_state("test_no_parallel"): + assert ps.get_tensor_model_parallel_world_size() == 1 + assert ps.get_pipeline_model_parallel_world_size() == 1 + + # DP should equal world_size + assert ps.get_data_parallel_world_size() == dist.get_world_size() + + # Test training for 5 batches + data_loader = random_dataloader(model=engine.module, total_samples=20, hidden_dim=16, device=engine.device) + for i, batch in enumerate(data_loader): + if i >= 5: + break + loss = engine(batch[0], batch[1]) + engine.backward(loss) + engine.step() + + def test_mpu_with_different_orders(self): + """Test mpu with different parallel dimension orders""" + from deepspeed.utils import parallel_state_deepspeed as ps + + # Use named instance to avoid test interference + state = ps.get_parallel_state_instance("test_order") + state.initialize_model_parallel(tensor_model_parallel_size=2, + expert_model_parallel_size=2, + order="tp-ep-dp-pp") + + config = self._get_base_config() + model = SimpleModel(hidden_dim=16) + + with ps.set_current_parallel_state("test_order"): + engine, _, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=ps, + model_parameters=model.parameters()) + + # Verify parallel configuration + with ps.set_current_parallel_state("test_order"): + assert ps.get_tensor_model_parallel_world_size() == 2 + assert ps.get_expert_model_parallel_world_size() == 2 + + # Verify DP world_size: world_size / (tp * ep) + expected_dp = dist.get_world_size() // (2 * 2) + assert ps.get_data_parallel_world_size() == expected_dp + + +class TestParallelStateConfigPriority(DistributedTest): + """Test configuration priority: params > config > defaults""" + + world_size = 4 + + def test_param_overrides_config(self): + """Function parameter should override config value""" + from deepspeed.utils import parallel_state_deepspeed as ps + + config = { + "train_batch_size": 4, + "sequence_parallel_size": 2, # Config says 2 + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001 + } + } + } + + # Override with param: sp=1 + ps.initialize_parallel_state_from_config( + config, + name="param_override_test", + sequence_parallel_size=1 # Parameter overrides config + ) + + model = SimpleModel(hidden_dim=16) + + with ps.set_current_parallel_state("param_override_test"): + engine, _, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=ps, + model_parameters=model.parameters()) + + # With sp=1, SP group should not have special effect + assert engine is not None + assert engine.mpu == ps + + def test_config_overrides_default(self): + """Config value should override default value""" + from deepspeed.utils import parallel_state_deepspeed as ps + + config = { + "train_batch_size": 4, + "sequence_parallel_size": 2, # Override default (1) + "order": "tp-sp-dp-pp", # Need to specify order when using sp + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001 + } + } + } + + # Don't pass 
sequence_parallel_size parameter + ps.initialize_parallel_state_from_config(config, name="config_override_test") + + model = SimpleModel(hidden_dim=16) + + with ps.set_current_parallel_state("config_override_test"): + engine, _, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=ps, + model_parameters=model.parameters()) + + # Verify SP is configured from config + # Since sp_size = 2, SP group should be initialized + with ps.set_current_parallel_state("config_override_test"): + sp_world_size = ps.get_sequence_parallel_world_size() + assert sp_world_size == 2 + + +class TestParallelStateValidation(DistributedTest): + """Test validation and error handling""" + + world_size = 4 + + def test_context_parallel_not_supported(self): + """Test that CP > 1 raises NotImplementedError""" + from deepspeed.utils import parallel_state_deepspeed as ps + + # CP > 1 should raise error via initialize_parallel_state_from_config + with pytest.raises(NotImplementedError, match="does not support context_parallel_size"): + ps.initialize_parallel_state_from_config({"context_parallel_size": 2}, name="cp_test") + + def test_hierarchical_cp_not_supported(self): + """Test that hierarchical CP raises NotImplementedError""" + from deepspeed.utils import parallel_state_deepspeed as ps + + with pytest.raises(NotImplementedError, match="does not support hierarchical_context_parallel_sizes"): + ps.initialize_parallel_state_from_config({"hierarchical_context_parallel_sizes": [2, 2]}, name="hcp_test") + + +class TestAllToAllGroupsWithMPU(DistributedTest): + """Test All-to-All groups initialization with mpu""" + + world_size = 8 + + def test_all_to_all_groups_with_mpu(self): + """Test All-to-All groups work with mpu in initialize""" + from deepspeed.utils import parallel_state_deepspeed as ps + + # Use named instance to avoid test interference + state = ps.get_parallel_state_instance("test_all_to_all") + state.initialize_model_parallel() + + config = {"train_batch_size": 8, "optimizer": {"type": "Adam", "params": {"lr": 0.001}}} + + model = SimpleModel(hidden_dim=16) + + with ps.set_current_parallel_state("test_all_to_all"): + engine, _, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=ps, + model_parameters=model.parameters()) + + # Initialize All-to-All groups + with ps.set_current_parallel_state("test_all_to_all"): + all_to_all_groups = ps.initialize_all_to_all_groups() + + # Verify groups are created + assert isinstance(all_to_all_groups, dict) + assert len(all_to_all_groups) > 0 + + # Test backward compatibility interface + with ps.set_current_parallel_state("test_all_to_all"): + compat_groups = ps._get_local_all_to_all_group() + assert compat_groups == all_to_all_groups + + # Test training for 5 batches + data_loader = random_dataloader(model=engine.module, total_samples=20, hidden_dim=16, device=engine.device) + for i, batch in enumerate(data_loader): + if i >= 5: + break + loss = engine(batch[0], batch[1]) + assert loss is not None + engine.backward(loss) + engine.step()