Commit d4d996d

[Compat] Autocast func (#76191)

Parent: b6521c6

3 files changed (+65, -4 lines)

python/paddle/amp/auto_cast.py

Lines changed: 5 additions & 1 deletion

@@ -1064,7 +1064,10 @@ def amp_decorate(
 
 
 def autocast(
-    enabled=True, dtype=paddle.float16, cache_enabled=True
+    device_type: str | None,
+    dtype: _DTypeLiteral = 'float16',
+    enabled: bool = True,
+    cache_enabled: bool = True,
 ) -> AbstractContextManager:
     """
     Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
@@ -1075,6 +1078,7 @@ def autocast(
     imperative mode.
 
     Args:
+        device_type(str, optional): Device type. Since Paddle does not distinguish between devices, this parameter has no effect.
         enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
         dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
         cache_enabled(bool, optional): whether to enable cache or not. Default is True. But this parameter is not used
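For context, a minimal usage sketch of the new keyword-first signature (illustrative only, not part of the commit; assumes a GPU build of Paddle so float16 kernels are available):

    import paddle

    linear = paddle.nn.Linear(4, 4)
    x = paddle.rand([2, 4])

    # device_type is accepted for API compatibility; Paddle does not
    # distinguish devices here, so 'cuda' only satisfies the signature.
    with paddle.amp.autocast(device_type='cuda', dtype='float16', enabled=True):
        y = linear(x)
        print(y.dtype)  # paddle.float16 on a GPU build, paddle.float32 otherwise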

python/paddle/device/__init__.py

Lines changed: 49 additions & 2 deletions

@@ -53,6 +53,7 @@
 )
 
 if TYPE_CHECKING:
+    from contextlib import AbstractContextManager
     from types import TracebackType
 
     from paddle import IPUPlace as _IPUPlace, XPUPlace as _XPUPlace
@@ -1788,13 +1789,59 @@ def manual_seed_all(seed: int) -> None:
 
 
 class _AutocastMode:
-    autocast = staticmethod(_autocast)
+    @staticmethod
+    def autocast(
+        enabled=True, dtype=paddle.float16, cache_enabled=True
+    ) -> AbstractContextManager:
+        """
+        Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
+        If enabled, the input data type (float32, float16 or bfloat16) of each operator is decided
+        by autocast algorithm for better performance.
+
+        Commonly, it is used together with `GradScaler` and `decorator` to achieve Auto-Mixed-Precision in
+        imperative mode.
+
+        Args:
+            device_type(str, optional): Device type. Since Paddle does not distinguish between devices, this parameter has no effect.
+            enabled(bool, optional): Enable auto-mixed-precision or not. Default is True.
+            dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
+            cache_enabled(bool, optional): Whether to enable cache or not. Default is True. This parameter is currently unused.
+
+        Note:
+            This context manager is also exposed as ``paddle.cuda.amp.autocast``.
+
+        Examples:
+
+            .. code-block:: python
+
+                >>> # doctest: +REQUIRES(env:GPU)
+                >>> import paddle
+
+                >>> conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
+                >>> data = paddle.rand([10, 3, 32, 32])
+
+                >>> with paddle.device.amp.autocast():
+                ...     conv = conv2d(data)
+                ...     print(conv.dtype)
+                >>> # doctest: +SKIP("This has diff in xdoctest env")
+                paddle.float16
+                >>> # doctest: -SKIP
+
+                >>> with paddle.device.amp.autocast(enabled=False):
+                ...     conv = conv2d(data)
+                ...     print(conv.dtype)
+                >>> # doctest: +SKIP("This has diff in xdoctest env")
+                paddle.float32
+                >>> # doctest: -SKIP
+
+        """
+        return _autocast(device_type='cuda', enabled=enabled, dtype=dtype)
 
 
 class amp:
     """Namespace for amp marker operations."""
 
-    autocast = staticmethod(_autocast)
+    autocast = staticmethod(_AutocastMode.autocast)
     autocast_mode = _AutocastMode()
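The change above follows a small forwarding pattern: a namespace class exposes a staticmethod that pins device_type to 'cuda' and delegates to the underlying implementation. A self-contained sketch of that pattern (the class and function names below are hypothetical stand-ins, not Paddle APIs):

    from contextlib import contextmanager

    # Hypothetical stand-in for the real _autocast implementation.
    @contextmanager
    def _autocast_impl(device_type='cuda', enabled=True, dtype='float16'):
        print(f"autocast(device_type={device_type!r}, enabled={enabled}, dtype={dtype!r})")
        yield

    class _AutocastModeSketch:
        # Mirrors _AutocastMode.autocast: fix device_type, forward the rest,
        # and silently ignore cache_enabled.
        @staticmethod
        def autocast(enabled=True, dtype='float16', cache_enabled=True):
            return _autocast_impl(device_type='cuda', enabled=enabled, dtype=dtype)

    class amp_sketch:
        # Namespace class re-exporting the staticmethod, as paddle.device.amp does.
        autocast = staticmethod(_AutocastModeSketch.autocast)
        autocast_mode = _AutocastModeSketch()

    with amp_sketch.autocast(enabled=False):
        pass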

test/amp/test_amp_api.py

Lines changed: 11 additions & 1 deletion

@@ -116,7 +116,17 @@ def _run_autocast_test(self, ctx):
         self.assertEqual(out3.dtype, paddle.float32)
 
     def test_amp_autocast(self):
-        self._run_autocast_test(paddle.amp.autocast())
+        self._run_autocast_test(paddle.amp.autocast(device_type='cuda'))
+
+    def test_amp_autocast2(self):
+        self._run_autocast_test(
+            paddle.amp.autocast(
+                device_type='cuda',
+                enabled=True,
+                dtype=paddle.float16,
+                cache_enabled=True,
+            )
+        )
 
     def test_cuda_amp_autocast(self):
         self._run_autocast_test(paddle.cuda.amp.autocast())
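The helper `_run_autocast_test` is defined earlier in test_amp_api.py and is only referenced by this hunk. A plausible minimal version (an assumption for illustration, not the repository's actual helper) would enter the supplied context and assert the expected dtypes:

    import unittest

    import paddle

    class AutocastCompatSketch(unittest.TestCase):
        # Hypothetical helper; the real _run_autocast_test lives above this hunk.
        def _run_autocast_test(self, ctx):
            conv = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
            data = paddle.rand([4, 3, 16, 16])
            with ctx:
                out = conv(data)
            # float16 under autocast on a GPU build, float32 otherwise.
            self.assertIn(out.dtype, (paddle.float16, paddle.float32))

        def test_amp_autocast_compat(self):
            # Requires a GPU build for the float16 path to take effect.
            self._run_autocast_test(paddle.amp.autocast(device_type='cuda'))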
