18 | 18 |
19 | 19 | import paddle |
20 | 20 |
21 | | -# This test file covers casting operations between different data types, |
22 | | -# including lightweight float formats (float8, float16, bfloat16) and complex types. |
23 | | - |
24 | 21 |
25 | 22 | class TestComplexCastOp(unittest.TestCase): |
26 | 23 | def test_complex_to_real(self): |
@@ -82,116 +79,6 @@ def test_complex64_complex128(self): |
82 | 79 | c_128.cast('complex128').numpy(), c_64.numpy(), rtol=1e-05 |
83 | 80 | ) |
84 | 81 |
85 | | - @unittest.skipIf( |
86 | | - not paddle.is_compiled_with_cuda(), |
87 | | - "float16/bfloat16/float8 test runs only on CUDA", |
88 | | - ) |
89 | | - def test_float16_bfloat16_to_complex(self): |
90 | | - # Test float16 to complex64/complex128 |
91 | | - r_fp16 = np.random.random(size=[10, 10]).astype('float16') |
92 | | - r_fp16_t = paddle.to_tensor(r_fp16, dtype='float16') |
93 | | - |
94 | | - self.assertEqual(r_fp16_t.cast('complex64').dtype, paddle.complex64) |
95 | | - self.assertEqual(r_fp16_t.cast('complex128').dtype, paddle.complex128) |
96 | | - |
97 | | - np.testing.assert_allclose( |
98 | | - r_fp16_t.cast('complex64').real().numpy(), |
99 | | - r_fp16.astype('float32'), |
100 | | - rtol=1e-03, |
101 | | - ) |
102 | | - np.testing.assert_allclose( |
103 | | - r_fp16_t.cast('complex128').real().numpy(), |
104 | | - r_fp16.astype('float64'), |
105 | | - rtol=1e-03, |
106 | | - ) |
107 | | - |
108 | | - # Test bfloat16 to complex64/complex128 |
109 | | - r_bf16 = np.random.random(size=[10, 10]).astype('float32') |
110 | | - r_bf16_t = paddle.to_tensor(r_bf16, dtype='bfloat16') |
111 | | - |
112 | | - self.assertEqual(r_bf16_t.cast('complex64').dtype, paddle.complex64) |
113 | | - self.assertEqual(r_bf16_t.cast('complex128').dtype, paddle.complex128) |
114 | | - |
115 | | - np.testing.assert_allclose( |
116 | | - r_bf16_t.cast('complex64').real().numpy(), |
117 | | - r_bf16_t.cast('float32').numpy(), |
118 | | - rtol=1e-02, |
119 | | - ) |
120 | | - np.testing.assert_allclose( |
121 | | - r_bf16_t.cast('complex128').real().numpy(), |
122 | | - r_bf16_t.cast('float64').numpy(), |
123 | | - rtol=1e-02, |
124 | | - ) |
125 | | - |
126 | | - @unittest.skipIf( |
127 | | - not paddle.is_compiled_with_cuda(), |
128 | | - "float8 test runs only on CUDA", |
129 | | - ) |
130 | | - def test_float8_to_complex(self): |
131 | | - # Test float8_e4m3fn to complex64/complex128 |
132 | | - r_fp32 = np.random.uniform(1.0, 10.0, size=[10, 10]).astype('float32') |
133 | | - r_fp32_t = paddle.to_tensor(r_fp32) |
134 | | - r_fp8_e4m3fn_t = r_fp32_t.astype('float8_e4m3fn') |
135 | | - |
136 | | - self.assertEqual( |
137 | | - r_fp8_e4m3fn_t.cast('complex64').dtype, paddle.complex64 |
138 | | - ) |
139 | | - self.assertEqual( |
140 | | - r_fp8_e4m3fn_t.cast('complex128').dtype, paddle.complex128 |
141 | | - ) |
142 | | - |
143 | | - # Verify the real part matches the float32 version |
144 | | - np.testing.assert_allclose( |
145 | | - r_fp8_e4m3fn_t.cast('complex64').real().numpy(), |
146 | | - r_fp8_e4m3fn_t.cast('float32').numpy(), |
147 | | - rtol=1e-02, |
148 | | - ) |
149 | | - np.testing.assert_allclose( |
150 | | - r_fp8_e4m3fn_t.cast('complex128').real().numpy(), |
151 | | - r_fp8_e4m3fn_t.cast('float64').numpy(), |
152 | | - rtol=1e-02, |
153 | | - ) |
154 | | - |
155 | | - # Verify the imaginary part is zero |
156 | | - np.testing.assert_array_equal( |
157 | | - r_fp8_e4m3fn_t.cast('complex64').imag().numpy(), |
158 | | - np.zeros([10, 10], dtype='float32'), |
159 | | - ) |
160 | | - np.testing.assert_array_equal( |
161 | | - r_fp8_e4m3fn_t.cast('complex128').imag().numpy(), |
162 | | - np.zeros([10, 10], dtype='float64'), |
163 | | - ) |
164 | | - |
165 | | - # Test float8_e5m2 to complex64/complex128 |
166 | | - r_fp8_e5m2_t = r_fp32_t.astype('float8_e5m2') |
167 | | - |
168 | | - self.assertEqual(r_fp8_e5m2_t.cast('complex64').dtype, paddle.complex64) |
169 | | - self.assertEqual( |
170 | | - r_fp8_e5m2_t.cast('complex128').dtype, paddle.complex128 |
171 | | - ) |
172 | | - |
173 | | - # Verify the real part matches the float32 version |
174 | | - np.testing.assert_allclose( |
175 | | - r_fp8_e5m2_t.cast('complex64').real().numpy(), |
176 | | - r_fp8_e5m2_t.cast('float32').numpy(), |
177 | | - rtol=1e-02, |
178 | | - ) |
179 | | - np.testing.assert_allclose( |
180 | | - r_fp8_e5m2_t.cast('complex128').real().numpy(), |
181 | | - r_fp8_e5m2_t.cast('float64').numpy(), |
182 | | - rtol=1e-02, |
183 | | - ) |
184 | | - |
185 | | - # Verify the imaginary part is zero |
186 | | - np.testing.assert_array_equal( |
187 | | - r_fp8_e5m2_t.cast('complex64').imag().numpy(), |
188 | | - np.zeros([10, 10], dtype='float32'), |
189 | | - ) |
190 | | - np.testing.assert_array_equal( |
191 | | - r_fp8_e5m2_t.cast('complex128').imag().numpy(), |
192 | | - np.zeros([10, 10], dtype='float64'), |
193 | | - ) |
194 | | - |
195 | 82 |
196 | 83 | if __name__ == '__main__': |
197 | 84 | unittest.main() |