From 2e856639633251c181cdafde4f1e9adc1ed97714 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Thu, 26 Feb 2026 23:18:09 -0800 Subject: [PATCH 1/3] Replace get_accelerator().amp().autocast() with torch.amp.autocast(device_type=...) Signed-off-by: Ma, Guokai --- tests/unit/runtime/test_autocast.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/unit/runtime/test_autocast.py b/tests/unit/runtime/test_autocast.py index 9176770afda7..682a98ae38bb 100644 --- a/tests/unit/runtime/test_autocast.py +++ b/tests/unit/runtime/test_autocast.py @@ -26,8 +26,6 @@ def test_missing_amp_autocast(self, half_op): assert output.dtype == ds_linear.weight.dtype def test_disable_autocast_linear(self, half_op): - amp = get_accelerator().amp() - hidden_dim = 4 if half_op: input = torch.randn(hidden_dim).to(get_accelerator().device_name()).half() @@ -36,18 +34,15 @@ def test_disable_autocast_linear(self, half_op): input = torch.randn(hidden_dim).to(get_accelerator().device_name()) ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name()) - with amp.autocast(False): + with torch.amp.autocast(device_type=get_accelerator().device_name(), enabled=False): output = ds_linear(input) assert output.dtype == ds_linear.weight.dtype -@pytest.mark.skipif(get_accelerator().amp() is None, reason='amp is not installed') @pytest.mark.parametrize('half_input, half_weight', [(False, False), (False, True), (True, False), (True, True)]) class TestAutoCastEnable(DistributedTest): def test_autocast_linear(self, tmpdir, half_input, half_weight): - amp = get_accelerator().amp() - hidden_dim = 4 input = torch.randn(hidden_dim).to(get_accelerator().device_name()) ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name()) @@ -58,6 +53,6 @@ def test_autocast_linear(self, tmpdir, half_input, half_weight): if half_weight: ds_linear = ds_linear.half() - with amp.autocast(): + with 
torch.amp.autocast(device_type=get_accelerator().device_name()): output = ds_linear(input) assert output.dtype == torch.half or output.dtype == torch.bfloat16 From a5b04f024f0204129b3b64a9e5a7e3f5eefccc5a Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Thu, 26 Feb 2026 23:39:22 -0800 Subject: [PATCH 2/3] Remove get_accelerator().amp() API and use torch.amp directly The amp() method on each accelerator returned a device-specific torch.<device>.amp module, but since PyTorch 2.4 the unified torch.amp API (torch.amp.custom_fwd, torch.amp.custom_bwd, torch.amp.autocast) accepts a device_type argument and works across all backends. The previous commit already migrated the two call sites; this commit removes the now-unused amp() abstract method and all 8 accelerator implementations, plus simplifies the custom_fwd/custom_bwd setup in zero/linear.py by dropping the pre-2.4 fallback path. Signed-off-by: Ma, Guokai --- accelerator/abstract_accelerator.py | 4 ---- accelerator/cpu_accelerator.py | 3 --- accelerator/cuda_accelerator.py | 5 ----- accelerator/hpu_accelerator.py | 3 --- accelerator/mlu_accelerator.py | 5 ----- accelerator/mps_accelerator.py | 3 --- accelerator/npu_accelerator.py | 5 ----- accelerator/sdaa_accelerator.py | 5 ----- accelerator/xpu_accelerator.py | 3 --- deepspeed/runtime/zero/linear.py | 15 ++------------- 10 files changed, 2 insertions(+), 49 deletions(-) diff --git a/accelerator/abstract_accelerator.py b/accelerator/abstract_accelerator.py index 2a0770ac681b..4db8ae5ebdb7 100644 --- a/accelerator/abstract_accelerator.py +++ b/accelerator/abstract_accelerator.py @@ -178,10 +178,6 @@ def supported_dtypes(self): ... # Misc - @abc.abstractmethod - def amp(self): - ... - @abc.abstractmethod def is_available(self): ...
diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index d933041bed55..6b414a6e04ef 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -199,9 +199,6 @@ def available_memory(self, device_index=None): return psutil.virtual_memory().available # Misc - def amp(self): - return torch.cpu.amp - def is_available(self): return True diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index c45903421324..42cb93f9581d 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -222,11 +222,6 @@ def supported_dtypes(self): return supported_dtypes # Misc - def amp(self): - if hasattr(torch.cuda, 'amp'): - return torch.cuda.amp - return None - def is_available(self): return torch.cuda.is_available() diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py index 9d82eb590902..c6bc94f06149 100644 --- a/accelerator/hpu_accelerator.py +++ b/accelerator/hpu_accelerator.py @@ -173,9 +173,6 @@ def supported_dtypes(self): return supported_dtypes # Misc - def amp(self): - return None - def is_available(self): return self.hpu.is_available() diff --git a/accelerator/mlu_accelerator.py b/accelerator/mlu_accelerator.py index bef716f0ee4e..4689034692d1 100644 --- a/accelerator/mlu_accelerator.py +++ b/accelerator/mlu_accelerator.py @@ -162,11 +162,6 @@ def supported_dtypes(self): return supported_dtypes # Misc - def amp(self): - if hasattr(torch.mlu, 'amp'): - return torch.mlu.amp - return None - def is_available(self): return torch.mlu.is_available() diff --git a/accelerator/mps_accelerator.py b/accelerator/mps_accelerator.py index aa8e86ef1ce0..0d67c9cc1f7e 100644 --- a/accelerator/mps_accelerator.py +++ b/accelerator/mps_accelerator.py @@ -156,9 +156,6 @@ def supported_dtypes(self): return [torch.float] # Misc - def amp(self): - return - def is_available(self): return hasattr(torch.backends, "mps") and torch.backends.mps.is_available() diff --git 
a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py index 7cf30a349c57..421050d19f22 100644 --- a/accelerator/npu_accelerator.py +++ b/accelerator/npu_accelerator.py @@ -163,11 +163,6 @@ def supported_dtypes(self): return [torch.float, torch.half, torch.bfloat16] # Misc - def amp(self): - if hasattr(torch.npu, 'amp'): - return torch.npu.amp - return None - def is_available(self): return torch.npu.is_available() diff --git a/accelerator/sdaa_accelerator.py b/accelerator/sdaa_accelerator.py index 26113d38dd15..f185731d9385 100755 --- a/accelerator/sdaa_accelerator.py +++ b/accelerator/sdaa_accelerator.py @@ -192,11 +192,6 @@ def supported_dtypes(self): return supported_dtypes # Misc - def amp(self): - if hasattr(torch.sdaa, 'amp'): - return torch.sdaa.amp - return None - def is_available(self): return torch.sdaa.is_available() diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py index 09f58abdd95b..0095c5d951d5 100644 --- a/accelerator/xpu_accelerator.py +++ b/accelerator/xpu_accelerator.py @@ -166,9 +166,6 @@ def available_memory(self, device_index=None): return self.total_memory(device_index) - self.memory_allocated(device_index) # Misc - def amp(self): - return torch.xpu.amp - def is_available(self): return torch.xpu.is_available() diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py index 0a86e3c389b1..0fd02cdc67ef 100644 --- a/deepspeed/runtime/zero/linear.py +++ b/deepspeed/runtime/zero/linear.py @@ -23,7 +23,6 @@ from torch.nn.parameter import Parameter from torch.nn import init from torch.nn.modules.module import Module -from deepspeed.runtime.utils import noop_decorator from deepspeed import comm as dist from deepspeed.accelerator import get_accelerator @@ -33,18 +32,8 @@ def print_rank_0(message, debug=False, force=False): print(message) -try: - # Fix `torch.[device].amp.custom_fwd/bwd` FutureWarning in torch 2.4 - if hasattr(torch, 'amp') and hasattr(torch.amp, 'custom_fwd') and 
hasattr(torch.amp, 'custom_bwd'): - autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name()) - autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=get_accelerator().device_name()) - else: - # original implementation - autocast_custom_fwd = get_accelerator().amp().custom_fwd - autocast_custom_bwd = get_accelerator().amp().custom_bwd -except (ImportError, AttributeError) as exp: - autocast_custom_fwd = noop_decorator - autocast_custom_bwd = noop_decorator +autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name()) +autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=get_accelerator().device_name()) class LinearFunctionForZeroStage3(torch.autograd.Function): From a31a5db5fbaf37e8b7d4a95c85c007d6042218e5 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Sun, 1 Mar 2026 22:23:08 -0800 Subject: [PATCH 3/3] update pytorch version in workflow nv-pre-compile-ops.yml Signed-off-by: Ma, Guokai --- .github/workflows/nv-pre-compile-ops.yml | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml index 53e2aad85a6b..9a90cc6602ba 100644 --- a/.github/workflows/nv-pre-compile-ops.yml +++ b/.github/workflows/nv-pre-compile-ops.yml @@ -23,11 +23,20 @@ jobs: unit-tests: runs-on: ubuntu-24.04 container: - image: deepspeed/gh-builder:ubuntu1804-py38-torch1131-cu116 + image: nvidia/cuda:12.6.3-devel-ubuntu22.04 steps: + - name: Install system dependencies + run: | + apt-get update && apt-get install -y git python3 python3-pip libaio-dev ninja-build + ln -sf /usr/bin/python3 /usr/bin/python + - uses: actions/checkout@v4 + - name: Install PyTorch + run: | + pip install torch==2.10.0 --index-url https://download.pytorch.org/whl/cu126 + - name: environment run: | which python @@ -36,7 +45,7 @@ jobs: #python -c "import torch; print('CUDA 
available:', torch.cuda.is_available())" - name: Compile DeepSpeed Ops run: | - DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_GDS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 DS_BUILD_DEEP_COMPILE=0 pip3 install . + DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;8.9;9.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_GDS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 DS_BUILD_DEEP_COMPILE=0 pip3 install . - name: DS Report run: | - ds_report + DS_ACCELERATOR=cuda ds_report