Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions accelerator/cpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,9 @@ def get_op_builder(self, class_name):
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
from op_builder import __deepspeed__ # noqa: F401 # type: ignore
from op_builder.cpu import AsyncIOBuilder, CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder
from op_builder.cpu import AsyncIOBuilder, CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, ZenFlowAdamBuilder, NotImplementedBuilder
except ImportError:
from deepspeed.ops.op_builder.cpu import AsyncIOBuilder, CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder
from deepspeed.ops.op_builder.cpu import AsyncIOBuilder, CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, ZenFlowAdamBuilder, NotImplementedBuilder

if class_name == "CCLCommBuilder":
return CCLCommBuilder
Expand All @@ -328,6 +328,8 @@ def get_op_builder(self, class_name):
return FusedAdamBuilder
elif class_name == "CPUAdamBuilder":
return CPUAdamBuilder
elif class_name == "ZenFlowAdamBuilder":
return ZenFlowAdamBuilder
elif class_name == "AsyncIOBuilder":
return AsyncIOBuilder
else:
Expand Down
14 changes: 12 additions & 2 deletions deepspeed/ops/adam/zenflow_cpu_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,28 @@

from deepspeed.ops.adam import DeepSpeedCPUAdam
import torch
from deepspeed.ops.op_builder import ZenFlowAdamBuilder


class ZenFlowCPUAdam(DeepSpeedCPUAdam):

def __init__(self, *args, overlap_step=False, **kwargs):
    """Initialize a ZenFlow CPU Adam optimizer.

    Builds on DeepSpeedCPUAdam, then swaps the underlying C++ op for the
    ZenFlow variant and selects either the sequential or the overlapped
    step implementation.

    Args:
        *args: Positional arguments forwarded to DeepSpeedCPUAdam.
        overlap_step (bool): When True, bind ``step`` to the overlapped
            (parallel) update path; otherwise use the sequential path.
        **kwargs: Keyword arguments forwarded to DeepSpeedCPUAdam.
    """
    super().__init__(*args, **kwargs)

    # The parent constructor registered this opt_id with the cpu_adam op;
    # release that registration before re-creating it with the ZenFlow op.
    self.ds_opt_adam.destroy_adam(self.opt_id)

    group0 = self.param_groups[0]
    self.ds_opt_adam = ZenFlowAdamBuilder().load()
    self.ds_opt_adam.create_adam(self.opt_id, group0['lr'], group0['betas'][0], group0['betas'][1], group0['eps'],
                                 group0['weight_decay'], self.adam_w_mode, False)

    self.overlap_step = overlap_step
    if self.overlap_step:
        print("ZenFlowCPUAdam initialized with overlap step.")
        # Overlapped (parallel) update logic
        self.step = self._parallel_step
    else:
        print("ZenFlowCPUAdam initialized with normal step.")
        # Sequential update logic
        self.step = self._sequential_step

@torch.no_grad()
Expand Down
9 changes: 5 additions & 4 deletions deepspeed/ops/adam/zenflow_torch_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch import Tensor

from deepspeed.utils.torch import required_torch_version
from deepspeed.runtime.compiler import is_compiling

# Check if we have PyTorch >= 2.0 for ZenFlow features
_ZENFLOW_AVAILABLE = required_torch_version(min_version=2.1)
Expand Down Expand Up @@ -568,7 +569,7 @@ def _single_tensor_adamw(
step_t = state_steps[i]

# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
if not is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == step_t.device.type and param.device.type in capturable_supported_devices
Expand Down Expand Up @@ -674,7 +675,7 @@ def _multi_tensor_adamw(
raise RuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True")

# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
if not is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices(supports_xla=False)
assert all(
p.device.type == step.device.type and p.device.type in capturable_supported_devices
Expand Down Expand Up @@ -722,7 +723,7 @@ def _multi_tensor_adamw(
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
if not is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0)
else:
torch._foreach_add_(device_state_steps, 1)
Expand Down Expand Up @@ -936,7 +937,7 @@ def adamw(
"Please upgrade to PyTorch 2.0+ to use ZenFlow, or omit 'zenflow' "
"from your DeepSpeed configuration to use the default ZeRO-Offload optimizer.")

if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
if not is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

# Respect when the user inputs False/True for foreach or fused. We only want to change
Expand Down
15 changes: 15 additions & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2331,6 +2331,13 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
def _backward_prologue(self):
self._start_timers(self.engine_timers.backward_timers)

# ZenFlow requires the use of engine.backward(loss) to manage its
# specialized backward pass and synchronization logic.
if self.zenflow and not self._running_engine_backward:
raise RuntimeError("Direct calls to loss.backward() are not currently supported when ZenFlow is enabled. "
"Please use engine.backward(loss) instead to allow ZenFlow to manage "
"selective updates and synchronization.")

# When necessary internal APIs are not available, we disable direct calls to tensor.backward()
# and limit to engine.backward(loss) only.
if not self._support_torch_style_backward and not self._running_engine_backward:
Expand Down Expand Up @@ -2514,6 +2521,14 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
loss = self.torch_autocast_z0_gradscaler.scale(loss)

with compiled_autograd(self._is_compiled_autograd_enabled, self._compile_kwargs):
# ZenFlow requires exclusive control over the backward pass to manage its
# selective parameter updates and synchronization boundaries.
if self.zenflow:
self.optimizer.backward(loss, **backward_kwargs)
self._backward_epilogue()
self._running_engine_backward = False
return gas_scaled_loss

if self.zero_optimization() or not self.amp_enabled():
loss.backward(**backward_kwargs)
elif self.amp_enabled():
Expand Down
1 change: 1 addition & 0 deletions op_builder/cpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
from .comm import CCLCommBuilder, ShareMemCommBuilder
from .fused_adam import FusedAdamBuilder
from .cpu_adam import CPUAdamBuilder
from .zenflow_adam import ZenFlowAdamBuilder
from .no_impl import NotImplementedBuilder
from .async_io import AsyncIOBuilder
27 changes: 27 additions & 0 deletions op_builder/cpu/zenflow_adam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import CPUOpBuilder


class ZenFlowAdamBuilder(CPUOpBuilder):
    """Op builder for the ZenFlow CPU Adam extension.

    Compiles the same C++ sources as the regular CPU Adam op, but registers
    them under a separate op name so ZenFlow can load its own extension.
    """

    BUILD_VAR = "DS_BUILD_ZENFLOW_ADAM"
    NAME = "zenflow_adam"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        # Fully-qualified module name the compiled extension is installed under.
        return 'deepspeed.ops.adam.{}_op'.format(self.NAME)

    def sources(self):
        # Shares the cpu_adam kernel sources; only the op name differs.
        return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp']

    def libraries_args(self):
        # No extra link libraries beyond what CPUOpBuilder provides.
        return super().libraries_args()

    def include_paths(self):
        return ['csrc/includes']
27 changes: 27 additions & 0 deletions op_builder/zenflow_adam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import TorchCPUOpBuilder


class ZenFlowAdamBuilder(TorchCPUOpBuilder):
    """Op builder for the ZenFlow Adam extension (torch CPU build path).

    Reuses the cpu_adam C++ sources and registers them under a dedicated
    op name so ZenFlow can load its own compiled extension.
    """

    BUILD_VAR = "DS_BUILD_ZENFLOW_ADAM"
    NAME = "zenflow_adam"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        # Fully-qualified module name the compiled extension is installed under.
        return 'deepspeed.ops.adam.{}_op'.format(self.NAME)

    def sources(self):
        # Shares the cpu_adam kernel sources; only the op name differs.
        return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp']

    def libraries_args(self):
        # No extra link libraries beyond what TorchCPUOpBuilder provides.
        return super().libraries_args()

    def include_paths(self):
        return ['csrc/includes']