Skip to content

Commit e016bf1

Browse files
committed
Add new test file for lightweight float to complex casting
1 parent 88d3dd4 commit e016bf1

File tree

1 file changed

+0
-113
lines changed

1 file changed

+0
-113
lines changed

test/legacy_test/test_complex_cast.py

Lines changed: 0 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818

1919
import paddle
2020

21-
# This test file covers casting operations between different data types,
22-
# including lightweight float formats (float8, float16, bfloat16) and complex types.
23-
2421

2522
class TestComplexCastOp(unittest.TestCase):
2623
def test_complex_to_real(self):
@@ -82,116 +79,6 @@ def test_complex64_complex128(self):
8279
c_128.cast('complex128').numpy(), c_64.numpy(), rtol=1e-05
8380
)
8481

85-
@unittest.skipIf(
86-
not paddle.is_compiled_with_cuda(),
87-
"float16/bfloat16/float8 test runs only on CUDA",
88-
)
89-
def test_float16_bfloat16_to_complex(self):
90-
# Test float16 to complex64/complex128
91-
r_fp16 = np.random.random(size=[10, 10]).astype('float16')
92-
r_fp16_t = paddle.to_tensor(r_fp16, dtype='float16')
93-
94-
self.assertEqual(r_fp16_t.cast('complex64').dtype, paddle.complex64)
95-
self.assertEqual(r_fp16_t.cast('complex128').dtype, paddle.complex128)
96-
97-
np.testing.assert_allclose(
98-
r_fp16_t.cast('complex64').real().numpy(),
99-
r_fp16.astype('float32'),
100-
rtol=1e-03,
101-
)
102-
np.testing.assert_allclose(
103-
r_fp16_t.cast('complex128').real().numpy(),
104-
r_fp16.astype('float64'),
105-
rtol=1e-03,
106-
)
107-
108-
# Test bfloat16 to complex64/complex128
109-
r_bf16 = np.random.random(size=[10, 10]).astype('float32')
110-
r_bf16_t = paddle.to_tensor(r_bf16, dtype='bfloat16')
111-
112-
self.assertEqual(r_bf16_t.cast('complex64').dtype, paddle.complex64)
113-
self.assertEqual(r_bf16_t.cast('complex128').dtype, paddle.complex128)
114-
115-
np.testing.assert_allclose(
116-
r_bf16_t.cast('complex64').real().numpy(),
117-
r_bf16_t.cast('float32').numpy(),
118-
rtol=1e-02,
119-
)
120-
np.testing.assert_allclose(
121-
r_bf16_t.cast('complex128').real().numpy(),
122-
r_bf16_t.cast('float64').numpy(),
123-
rtol=1e-02,
124-
)
125-
126-
@unittest.skipIf(
127-
not paddle.is_compiled_with_cuda(),
128-
"float8 test runs only on CUDA",
129-
)
130-
def test_float8_to_complex(self):
131-
# Test float8_e4m3fn to complex64/complex128
132-
r_fp32 = np.random.uniform(1.0, 10.0, size=[10, 10]).astype('float32')
133-
r_fp32_t = paddle.to_tensor(r_fp32)
134-
r_fp8_e4m3fn_t = r_fp32_t.astype('float8_e4m3fn')
135-
136-
self.assertEqual(
137-
r_fp8_e4m3fn_t.cast('complex64').dtype, paddle.complex64
138-
)
139-
self.assertEqual(
140-
r_fp8_e4m3fn_t.cast('complex128').dtype, paddle.complex128
141-
)
142-
143-
# Verify the real part matches the float32 version
144-
np.testing.assert_allclose(
145-
r_fp8_e4m3fn_t.cast('complex64').real().numpy(),
146-
r_fp8_e4m3fn_t.cast('float32').numpy(),
147-
rtol=1e-02,
148-
)
149-
np.testing.assert_allclose(
150-
r_fp8_e4m3fn_t.cast('complex128').real().numpy(),
151-
r_fp8_e4m3fn_t.cast('float64').numpy(),
152-
rtol=1e-02,
153-
)
154-
155-
# Verify the imaginary part is zero
156-
np.testing.assert_array_equal(
157-
r_fp8_e4m3fn_t.cast('complex64').imag().numpy(),
158-
np.zeros([10, 10], dtype='float32'),
159-
)
160-
np.testing.assert_array_equal(
161-
r_fp8_e4m3fn_t.cast('complex128').imag().numpy(),
162-
np.zeros([10, 10], dtype='float64'),
163-
)
164-
165-
# Test float8_e5m2 to complex64/complex128
166-
r_fp8_e5m2_t = r_fp32_t.astype('float8_e5m2')
167-
168-
self.assertEqual(r_fp8_e5m2_t.cast('complex64').dtype, paddle.complex64)
169-
self.assertEqual(
170-
r_fp8_e5m2_t.cast('complex128').dtype, paddle.complex128
171-
)
172-
173-
# Verify the real part matches the float32 version
174-
np.testing.assert_allclose(
175-
r_fp8_e5m2_t.cast('complex64').real().numpy(),
176-
r_fp8_e5m2_t.cast('float32').numpy(),
177-
rtol=1e-02,
178-
)
179-
np.testing.assert_allclose(
180-
r_fp8_e5m2_t.cast('complex128').real().numpy(),
181-
r_fp8_e5m2_t.cast('float64').numpy(),
182-
rtol=1e-02,
183-
)
184-
185-
# Verify the imaginary part is zero
186-
np.testing.assert_array_equal(
187-
r_fp8_e5m2_t.cast('complex64').imag().numpy(),
188-
np.zeros([10, 10], dtype='float32'),
189-
)
190-
np.testing.assert_array_equal(
191-
r_fp8_e5m2_t.cast('complex128').imag().numpy(),
192-
np.zeros([10, 10], dtype='float64'),
193-
)
194-
19582

19683
if __name__ == '__main__':
19784
unittest.main()

0 commit comments

Comments
 (0)