diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py
index 4b3d89e6cd34..bdd0e147107b 100644
--- a/accelerator/cpu_accelerator.py
+++ b/accelerator/cpu_accelerator.py
@@ -316,9 +316,9 @@ def get_op_builder(self, class_name):
             # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
             # if successful this also means we're doing a local install and not JIT compile path
             from op_builder import __deepspeed__  # noqa: F401 # type: ignore
-            from op_builder.cpu import AsyncIOBuilder, CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder
+            from op_builder.cpu import AsyncIOBuilder, CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, ZenFlowAdamBuilder, NotImplementedBuilder
         except ImportError:
-            from deepspeed.ops.op_builder.cpu import AsyncIOBuilder, CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder
+            from deepspeed.ops.op_builder.cpu import AsyncIOBuilder, CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, ZenFlowAdamBuilder, NotImplementedBuilder
 
         if class_name == "CCLCommBuilder":
             return CCLCommBuilder
@@ -328,6 +328,8 @@ def get_op_builder(self, class_name):
             return FusedAdamBuilder
         elif class_name == "CPUAdamBuilder":
             return CPUAdamBuilder
+        elif class_name == "ZenFlowAdamBuilder":
+            return ZenFlowAdamBuilder
         elif class_name == "AsyncIOBuilder":
             return AsyncIOBuilder
         else:
diff --git a/deepspeed/ops/adam/zenflow_cpu_adam.py b/deepspeed/ops/adam/zenflow_cpu_adam.py
index 0809d7a0f7e0..d0e58a29f0e9 100644
--- a/deepspeed/ops/adam/zenflow_cpu_adam.py
+++ b/deepspeed/ops/adam/zenflow_cpu_adam.py
@@ -5,18 +5,28 @@
 
 from deepspeed.ops.adam import DeepSpeedCPUAdam
 import torch
+from deepspeed.ops.op_builder import ZenFlowAdamBuilder
 
 
 class ZenFlowCPUAdam(DeepSpeedCPUAdam):
 
     def __init__(self, *args, overlap_step=False, **kwargs):
         super(ZenFlowCPUAdam, self).__init__(*args, **kwargs)
+
+        # Destroy the optimizer instance created by DeepSpeedCPUAdam in cpu_adam_op
+        self.ds_opt_adam.destroy_adam(self.opt_id)
+
+        self.ds_opt_adam = ZenFlowAdamBuilder().load()
+        self.ds_opt_adam.create_adam(self.opt_id, self.param_groups[0]['lr'], self.param_groups[0]['betas'][0],
+                                     self.param_groups[0]['betas'][1], self.param_groups[0]['eps'],
+                                     self.param_groups[0]['weight_decay'], self.adam_w_mode, False)
+
         self.overlap_step = overlap_step
         if not self.overlap_step:
-            print("ZenFlowCPUAdam initialized with normal step.")
+            # Use sequential update logic
            self.step = self._sequential_step
         else:
-            print("ZenFlowCPUAdam initialized with overlap step.")
+            # Use parallel/overlapped update logic
             self.step = self._parallel_step
 
     @torch.no_grad()
diff --git a/deepspeed/ops/adam/zenflow_torch_adam.py b/deepspeed/ops/adam/zenflow_torch_adam.py
index 1d55210d6edc..5a52d10967a9 100644
--- a/deepspeed/ops/adam/zenflow_torch_adam.py
+++ b/deepspeed/ops/adam/zenflow_torch_adam.py
@@ -8,6 +8,7 @@
 
 from torch import Tensor
 from deepspeed.utils.torch import required_torch_version
+from deepspeed.runtime.compiler import is_compiling
 
 # Check if we have PyTorch >= 2.0 for ZenFlow features
 _ZENFLOW_AVAILABLE = required_torch_version(min_version=2.1)
@@ -568,7 +569,7 @@ def _single_tensor_adamw(
         step_t = state_steps[i]
 
         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch._utils.is_compiling() and capturable:
+        if not is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type == step_t.device.type and param.device.type in capturable_supported_devices
@@ -674,7 +675,7 @@ def _multi_tensor_adamw(
         raise RuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True")
 
     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch._utils.is_compiling() and capturable:
+    if not is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(supports_xla=False)
         assert all(
             p.device.type == step.device.type and p.device.type in capturable_supported_devices
@@ -722,7 +723,7 @@ def _multi_tensor_adamw(
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
+        if not is_compiling() and device_state_steps[0].is_cpu:
             torch._foreach_add_(device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0)
         else:
             torch._foreach_add_(device_state_steps, 1)
@@ -936,7 +937,7 @@ def adamw(
             "Please upgrade to PyTorch 2.0+ to use ZenFlow, or omit 'zenflow' "
             "from your DeepSpeed configuration to use the default ZeRO-Offload optimizer.")
 
-    if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
+    if not is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
         raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
 
     # Respect when the user inputs False/True for foreach or fused. We only want to change
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 272fe832cb2f..cd45d86b65e3 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -2331,6 +2331,13 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
 
     def _backward_prologue(self):
         self._start_timers(self.engine_timers.backward_timers)
+        # ZenFlow requires the use of engine.backward(loss) to manage its
+        # specialized backward pass and synchronization logic.
+        if self.zenflow and not self._running_engine_backward:
+            raise RuntimeError("Direct calls to loss.backward() are not currently supported when ZenFlow is enabled. "
+                               "Please use engine.backward(loss) instead to allow ZenFlow to manage "
+                               "selective updates and synchronization.")
+
         # When necessary internal APIs are not available, we disable direct calls to tensor.backward()
         # and limit to engine.backward(loss) only.
         if not self._support_torch_style_backward and not self._running_engine_backward:
@@ -2514,6 +2521,14 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
             loss = self.torch_autocast_z0_gradscaler.scale(loss)
 
         with compiled_autograd(self._is_compiled_autograd_enabled, self._compile_kwargs):
+            # ZenFlow requires exclusive control over the backward pass to manage its
+            # selective parameter updates and synchronization boundaries.
+            if self.zenflow:
+                self.optimizer.backward(loss, **backward_kwargs)
+                self._backward_epilogue()
+                self._running_engine_backward = False
+                return gas_scaled_loss
+
             if self.zero_optimization() or not self.amp_enabled():
                 loss.backward(**backward_kwargs)
             elif self.amp_enabled():
diff --git a/op_builder/cpu/__init__.py b/op_builder/cpu/__init__.py
index 7084db8469f1..69e1adfa8c1b 100644
--- a/op_builder/cpu/__init__.py
+++ b/op_builder/cpu/__init__.py
@@ -7,5 +7,6 @@
 from .comm import CCLCommBuilder, ShareMemCommBuilder
 from .fused_adam import FusedAdamBuilder
 from .cpu_adam import CPUAdamBuilder
+from .zenflow_adam import ZenFlowAdamBuilder
 from .no_impl import NotImplementedBuilder
 from .async_io import AsyncIOBuilder
diff --git a/op_builder/cpu/zenflow_adam.py b/op_builder/cpu/zenflow_adam.py
new file mode 100644
index 000000000000..c5526c5df679
--- /dev/null
+++ b/op_builder/cpu/zenflow_adam.py
@@ -0,0 +1,27 @@
+# Copyright (c) DeepSpeed Team.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .builder import CPUOpBuilder
+
+
+class ZenFlowAdamBuilder(CPUOpBuilder):
+    BUILD_VAR = "DS_BUILD_ZENFLOW_ADAM"
+    NAME = "zenflow_adam"
+
+    def __init__(self):
+        super().__init__(name=self.NAME)
+
+    def absolute_name(self):
+        return f'deepspeed.ops.adam.{self.NAME}_op'
+
+    def sources(self):
+        return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp']
+
+    def libraries_args(self):
+        args = super().libraries_args()
+        return args
+
+    def include_paths(self):
+        return ['csrc/includes']
diff --git a/op_builder/zenflow_adam.py b/op_builder/zenflow_adam.py
new file mode 100644
index 000000000000..0f510bdcf908
--- /dev/null
+++ b/op_builder/zenflow_adam.py
@@ -0,0 +1,27 @@
+# Copyright (c) DeepSpeed Team.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .builder import TorchCPUOpBuilder
+
+
+class ZenFlowAdamBuilder(TorchCPUOpBuilder):
+    BUILD_VAR = "DS_BUILD_ZENFLOW_ADAM"
+    NAME = "zenflow_adam"
+
+    def __init__(self):
+        super().__init__(name=self.NAME)
+
+    def absolute_name(self):
+        return f'deepspeed.ops.adam.{self.NAME}_op'
+
+    def sources(self):
+        return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp']
+
+    def libraries_args(self):
+        args = super().libraries_args()
+        return args
+
+    def include_paths(self):
+        return ['csrc/includes']
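
Usage note (not part of the patch): with ZenFlow enabled, the backward pass must go
through engine.backward(loss); a bare loss.backward() now trips the guard added in
_backward_prologue, while engine.backward() routes the whole pass to
self.optimizer.backward() and returns early. The sketch below illustrates only that
call contract. The contents of the "zenflow" config block, the toy model, and the
tensor shapes are illustrative assumptions, not taken from this patch. To prebuild the
new op, DeepSpeed's usual convention applies (DS_BUILD_ZENFLOW_ADAM=1, per BUILD_VAR
above); otherwise ZenFlowAdamBuilder().load() JIT-compiles it on first use.

# Minimal ZenFlow training-loop sketch -- launch with the deepspeed runner.
# The "zenflow" block is an assumed placeholder for the real ZenFlow schema;
# the other keys follow standard DeepSpeed config conventions.
import torch
import deepspeed

model = torch.nn.Linear(16, 4)
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},
    },
    "zenflow": {},  # assumed: presence of this block is what sets engine.zenflow
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

x = torch.randn(8, 16, device=engine.device)
loss = engine(x).sum()

engine.backward(loss)  # required: the engine hands the pass to optimizer.backward()
# loss.backward()      # would raise the RuntimeError added in _backward_prologue
engine.step()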