|
53 | 53 | ) |
54 | 54 |
|
55 | 55 | if TYPE_CHECKING: |
| 56 | + from contextlib import AbstractContextManager |
56 | 57 | from types import TracebackType |
57 | 58 |
|
58 | 59 | from paddle import IPUPlace as _IPUPlace, XPUPlace as _XPUPlace |
@@ -1788,13 +1789,59 @@ def manual_seed_all(seed: int) -> None: |
1788 | 1789 |
|
1789 | 1790 |
|
1790 | 1791 | class _AutocastMode: |
1791 | | - autocast = staticmethod(_autocast) |
| 1792 | + @staticmethod |
| 1793 | + def autocast( |
| 1794 | + enabled=True, dtype=paddle.float16, cache_enabled=True |
| 1795 | + ) -> AbstractContextManager: |
| 1796 | + """ |
| 1797 | + Create a context that enables auto-mixed-precision (AMP) for operators executed in dynamic graph mode. |
| 1798 | + If enabled, the input data type (float32, float16 or bfloat16) of each operator is decided |
| 1799 | + by the autocast algorithm for better performance. |
| 1800 | +
|
| 1801 | + Commonly, it is used together with `GradScaler` and `decorator` to achieve Auto-Mixed-Precision in |
| 1802 | + imperative mode. |
| 1803 | +
|
| 1804 | + Args: |
| 1805 | + device_type(str, optional): Device type. Because paddle does not distinguish between devices, this argument has no effect and is fixed to 'cuda' internally. |
| 1806 | + enabled(bool, optional): Whether to enable auto-mixed-precision. Default is True. |
| 1807 | + dtype(paddle.dtype|str, optional): The data type to cast to, either 'float16' or 'bfloat16'. Default is paddle.float16. |
| 1808 | + cache_enabled(bool, optional): Whether to enable the cache. Default is True. This parameter is accepted for compatibility but not used. |
| 1809 | +
|
| 1810 | + Note: |
| 1811 | + This wrapper is exposed under the ``paddle.cuda.amp`` compatibility namespace. |
| 1812 | +
|
| 1813 | + Examples: |
| 1814 | +
|
| 1815 | + .. code-block:: python |
| 1816 | +
|
| 1817 | + >>> # doctest: +REQUIRES(env:GPU) |
| 1818 | + >>> import paddle |
| 1819 | +
|
| 1820 | + >>> conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False) |
| 1821 | + >>> data = paddle.rand([10, 3, 32, 32]) |
| 1822 | +
|
| 1823 | + >>> with paddle.device.amp.auto_cast(): |
| 1824 | + ... conv = conv2d(data) |
| 1825 | + ... print(conv.dtype) |
| 1826 | + >>> # doctest: +SKIP("This has diff in xdoctest env") |
| 1827 | + paddle.float16 |
| 1828 | + >>> # doctest: -SKIP |
| 1829 | +
|
| 1830 | + >>> with paddle.device.amp.auto_cast(enable=False): |
| 1831 | + ... conv = conv2d(data) |
| 1832 | + ... print(conv.dtype) |
| 1833 | + >>> # doctest: +SKIP("This has diff in xdoctest env") |
| 1834 | + paddle.float32 |
| 1835 | + >>> # doctest: -SKIP |
| 1836 | +
|
| 1837 | + """ |
| 1838 | + return _autocast(device_type='cuda', enabled=enabled, dtype=dtype) |
1792 | 1839 |
|
1793 | 1840 |
|
1794 | 1841 | class amp: |
1795 | 1842 | """Namespace for amp marker operations.""" |
1796 | 1843 |
|
1797 | | - autocast = staticmethod(_autocast) |
| 1844 | + autocast = staticmethod(_AutocastMode.autocast) |
1798 | 1845 | autocast_mode = _AutocastMode() |
1799 | 1846 |
|
1800 | 1847 |
|
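For readers of the diff, a minimal sketch of how the new `amp.autocast` static method might be combined with `paddle.amp.GradScaler` in a training step follows. The import path of the shim (`from paddle_compat import amp`) is hypothetical, the snippet assumes a GPU is available, and it only illustrates the behavior shown above: the wrapper forwards to `_autocast` with `device_type` fixed to `'cuda'` and `cache_enabled` ignored.

```python
import paddle

# Hypothetical import path for the compatibility shim added in this diff;
# substitute the real module where `class amp` is defined.
from paddle_compat import amp

model = paddle.nn.Linear(8, 2)
opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

data = paddle.rand([4, 8])
label = paddle.rand([4, 2])

# The wrapper accepts `enabled`, `dtype`, and `cache_enabled`; internally it
# calls _autocast(device_type='cuda', ...) and drops cache_enabled.
with amp.autocast(enabled=True, dtype=paddle.float16):
    out = model(data)
    loss = paddle.nn.functional.mse_loss(out, label)

scaled = scaler.scale(loss)   # scale the loss to avoid float16 underflow
scaled.backward()             # backward pass on the scaled loss
scaler.step(opt)              # unscale gradients and apply the parameter update
scaler.update()               # adjust the loss scaling factor for the next step
opt.clear_grad()
```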
|