6 changes: 5 additions & 1 deletion python/paddle/amp/auto_cast.py
@@ -1064,7 +1064,10 @@ def amp_decorate(


def autocast(
enabled=True, dtype=paddle.float16, cache_enabled=True
device_type: str | None,
dtype: _DTypeLiteral = 'float16',
enabled: bool = True,
cache_enabled: bool = True,
) -> AbstractContextManager:
"""
Create a context which enables auto-mixed-precision (AMP) of operators executed in dynamic graph mode.
@@ -1075,6 +1078,7 @@ def autocast(
imperative mode.

Args:
device_type(str|None): Device type. Because Paddle does not distinguish between devices here, this parameter currently has no effect.
enabled(bool, optional): Whether to enable auto-mixed-precision. Default is True.
dtype(str, optional): The data type to autocast to, either 'float16' or 'bfloat16'. Default is 'float16'.
cache_enabled(bool, optional): Whether to enable the autocast cache. Default is True. This parameter is currently unused.
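For reference, a minimal usage sketch of the updated signature (illustrative only, not part of this diff; it assumes a GPU build of Paddle, and note that `device_type`, while required by the new signature, currently has no effect):

    import paddle

    conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
    data = paddle.rand([10, 3, 32, 32])

    # device_type has no default in the new signature, so it must be supplied.
    with paddle.amp.autocast(device_type='cuda', dtype='float16', enabled=True):
        out = conv2d(data)  # computed in float16 while autocast is active
    print(out.dtype)        # expected: paddle.float16 on a GPU build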
51 changes: 49 additions & 2 deletions python/paddle/device/__init__.py
@@ -53,6 +53,7 @@
)

if TYPE_CHECKING:
from contextlib import AbstractContextManager
from types import TracebackType

from paddle import IPUPlace as _IPUPlace, XPUPlace as _XPUPlace
@@ -1788,13 +1789,59 @@ def manual_seed_all(seed: int) -> None:


class _AutocastMode:
autocast = staticmethod(_autocast)
@staticmethod
def autocast(
enabled=True, dtype=paddle.float16, cache_enabled=True
) -> AbstractContextManager:
"""
Create a context which enables auto-mixed-precision (AMP) of operators executed in dynamic graph mode.
If enabled, the input data type (float32, float16 or bfloat16) of each operator is decided
by the autocast algorithm for better performance.

Commonly, it is used together with `GradScaler` and `decorator` to achieve Auto-Mixed-Precision in
imperative mode.

Args:
enabled(bool, optional): Whether to enable auto-mixed-precision. Default is True.
dtype(str, optional): The data type to autocast to, either 'float16' or 'bfloat16'. Default is 'float16'.
cache_enabled(bool, optional): Whether to enable the autocast cache. Default is True. This parameter is currently unused.

Note:
``device_type`` is always passed as ``'cuda'`` to the underlying autocast; see also ``paddle.cuda.amp``.

Examples:

.. code-block:: python

>>> # doctest: +REQUIRES(env:GPU)
>>> import paddle

>>> conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
>>> data = paddle.rand([10, 3, 32, 32])

>>> with paddle.device.amp.autocast():
... conv = conv2d(data)
... print(conv.dtype)
>>> # doctest: +SKIP("This has diff in xdoctest env")
paddle.float16
>>> # doctest: -SKIP

>>> with paddle.device.amp.autocast(enabled=False):
... conv = conv2d(data)
... print(conv.dtype)
>>> # doctest: +SKIP("This has diff in xdoctest env")
paddle.float32
>>> # doctest: -SKIP

"""
return _autocast(device_type='cuda', enabled=enabled, dtype=dtype)


class amp:
"""Namespace for amp marker operations."""

autocast = staticmethod(_autocast)
autocast = staticmethod(_AutocastMode.autocast)
autocast_mode = _AutocastMode()


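Both docstrings mention combining autocast with `GradScaler`; a hedged sketch of that pattern follows (illustrative only, not part of this diff; it assumes a GPU build and the existing `paddle.amp.GradScaler` API):

    import paddle

    model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
    optimizer = paddle.optimizer.SGD(
        learning_rate=0.01, parameters=model.parameters()
    )
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    data = paddle.rand([1, 3, 32, 32])
    with paddle.amp.autocast(device_type='cuda'):  # forward pass under AMP
        loss = model(data).mean()

    scaled = scaler.scale(loss)         # scale the loss to limit fp16 underflow
    scaled.backward()                   # backward pass on the scaled loss
    scaler.minimize(optimizer, scaled)  # unscale gradients and apply the update
    optimizer.clear_grad()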
12 changes: 11 additions & 1 deletion test/amp/test_amp_api.py
@@ -116,7 +116,17 @@ def _run_autocast_test(self, ctx):
self.assertEqual(out3.dtype, paddle.float32)

def test_amp_autocast(self):
self._run_autocast_test(paddle.amp.autocast())
self._run_autocast_test(paddle.amp.autocast(device_type='cuda'))

def test_amp_autocast2(self):
self._run_autocast_test(
paddle.amp.autocast(
device_type='cuda',
enabled=True,
dtype=paddle.float16,
cache_enabled=True,
)
)

def test_cuda_amp_autocast(self):
self._run_autocast_test(paddle.cuda.amp.autocast())