diff --git a/paddle/phi/core/framework/data_type_transform.cc b/paddle/phi/core/framework/data_type_transform.cc
index 6ed397d85d378e..d05d2a2a72c65f 100644
--- a/paddle/phi/core/framework/data_type_transform.cc
+++ b/paddle/phi/core/framework/data_type_transform.cc
@@ -34,6 +34,78 @@ struct CastDataTypeFunctor {
   }
 };
 
+template <>
+struct CastDataTypeFunctor<::phi::dtype::float8_e5m2,
+                           ::phi::dtype::complex<float>> {
+  HOSTDEVICE inline ::phi::dtype::complex<float> operator()(
+      ::phi::dtype::float8_e5m2 in) const {
+    return ::phi::dtype::complex<float>(static_cast<float>(in));
+  }
+};
+
+template <>
+struct CastDataTypeFunctor<::phi::dtype::float8_e5m2,
+                           ::phi::dtype::complex<double>> {
+  HOSTDEVICE inline ::phi::dtype::complex<double> operator()(
+      ::phi::dtype::float8_e5m2 in) const {
+    return ::phi::dtype::complex<double>(static_cast<double>(in));
+  }
+};
+
+template <>
+struct CastDataTypeFunctor<::phi::dtype::float8_e4m3fn,
+                           ::phi::dtype::complex<float>> {
+  HOSTDEVICE inline ::phi::dtype::complex<float> operator()(
+      ::phi::dtype::float8_e4m3fn in) const {
+    return ::phi::dtype::complex<float>(static_cast<float>(in));
+  }
+};
+
+template <>
+struct CastDataTypeFunctor<::phi::dtype::float8_e4m3fn,
+                           ::phi::dtype::complex<double>> {
+  HOSTDEVICE inline ::phi::dtype::complex<double> operator()(
+      ::phi::dtype::float8_e4m3fn in) const {
+    return ::phi::dtype::complex<double>(static_cast<double>(in));
+  }
+};
+
+template <>
+struct CastDataTypeFunctor<::phi::dtype::bfloat16,
+                           ::phi::dtype::complex<float>> {
+  HOSTDEVICE inline ::phi::dtype::complex<float> operator()(
+      ::phi::dtype::bfloat16 in) const {
+    return ::phi::dtype::complex<float>(static_cast<float>(in));
+  }
+};
+
+template <>
+struct CastDataTypeFunctor<::phi::dtype::bfloat16,
+                           ::phi::dtype::complex<double>> {
+  HOSTDEVICE inline ::phi::dtype::complex<double> operator()(
+      ::phi::dtype::bfloat16 in) const {
+    return ::phi::dtype::complex<double>(static_cast<double>(in));
+  }
+};
+
+template <>
+struct CastDataTypeFunctor<::phi::dtype::float16,
+                           ::phi::dtype::complex<float>> {
+  HOSTDEVICE inline ::phi::dtype::complex<float> operator()(
+      ::phi::dtype::float16 in) const {
+    return ::phi::dtype::complex<float>(static_cast<float>(in));
+  }
+};
+
+template <>
+struct CastDataTypeFunctor<::phi::dtype::float16,
+                           ::phi::dtype::complex<double>> {
+  HOSTDEVICE inline ::phi::dtype::complex<double> operator()(
+      ::phi::dtype::float16 in) const {
+    return ::phi::dtype::complex<double>(static_cast<double>(in));
+  }
+};
+
 #if defined(PADDLE_WITH_XPU)
 
 template <typename InType, typename OutType>
diff --git a/test/cpp/fluid/framework/data_type_transform_test.cc b/test/cpp/fluid/framework/data_type_transform_test.cc
index 6a510d21acdca4..c94b9a8d5d7da1 100644
--- a/test/cpp/fluid/framework/data_type_transform_test.cc
+++ b/test/cpp/fluid/framework/data_type_transform_test.cc
@@ -395,4 +395,164 @@ TEST(DataTypeTransform, CPUTransform) {
       EXPECT_EQ(ptr[i], static_cast(in_data_bool[i]));
     }
   }
+
+  // data type transform from lightweight float formats to complex types
+  {
+    auto kernel_float8_e5m2 = phi::KernelKey(
+        place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT8_E5M2);
+    auto kernel_float8_e4m3fn = phi::KernelKey(
+        place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT8_E4M3FN);
+    auto kernel_complex64 = phi::KernelKey(
+        place, phi::DataLayout::ALL_LAYOUT, phi::DataType::COMPLEX64);
+    auto kernel_complex128 = phi::KernelKey(
+        place, phi::DataLayout::ALL_LAYOUT, phi::DataType::COMPLEX128);
+
+    phi::DenseTensor in;
+    phi::DenseTensor out;
+    int data_number = 2 * 3;
+
+    // Test float16 to complex64
+    {
+      phi::dtype::float16* ptr = in.mutable_data<phi::dtype::float16>(
+          common::make_ddim({2, 3}), place);
+      for (int i = 0; i < data_number; ++i) {
+        ptr[i] = static_cast<phi::dtype::float16>(i);
+      }
+
+      paddle::framework::TransDataType(kernel_fp16, kernel_complex64, in, &out);
+      phi::dtype::complex<float>* out_data =
+          out.data<phi::dtype::complex<float>>();
+      for (int i = 0; i < data_number; ++i) {
+        EXPECT_EQ(out_data[i].real, static_cast<float>(ptr[i]));
+        EXPECT_EQ(out_data[i].imag, 0.0f);
+      }
+    }
+
+    // Test float16 to complex128
+    {
+      phi::dtype::float16* ptr = in.mutable_data<phi::dtype::float16>(
+          common::make_ddim({2, 3}), place);
+      for (int i = 0; i < data_number; ++i) {
+        ptr[i] = static_cast<phi::dtype::float16>(i);
+      }
+
+      paddle::framework::TransDataType(
+          kernel_fp16, kernel_complex128, in, &out);
+      phi::dtype::complex<double>* out_data =
+          out.data<phi::dtype::complex<double>>();
+      for (int i = 0; i < data_number; ++i) {
+        EXPECT_EQ(out_data[i].real, static_cast<double>(ptr[i]));
+        EXPECT_EQ(out_data[i].imag, 0.0);
+      }
+    }
+
+    // Test bfloat16 to complex64
+    {
+      phi::dtype::bfloat16* ptr = in.mutable_data<phi::dtype::bfloat16>(
+          common::make_ddim({2, 3}), place);
+      for (int i = 0; i < data_number; ++i) {
+        ptr[i] = static_cast<phi::dtype::bfloat16>(i);
+      }
+
+      paddle::framework::TransDataType(kernel_bf16, kernel_complex64, in, &out);
+      phi::dtype::complex<float>* out_data =
+          out.data<phi::dtype::complex<float>>();
+      for (int i = 0; i < data_number; ++i) {
+        EXPECT_EQ(out_data[i].real, static_cast<float>(ptr[i]));
+        EXPECT_EQ(out_data[i].imag, 0.0f);
+      }
+    }
+
+    // Test bfloat16 to complex128
+    {
+      phi::dtype::bfloat16* ptr = in.mutable_data<phi::dtype::bfloat16>(
+          common::make_ddim({2, 3}), place);
+      for (int i = 0; i < data_number; ++i) {
+        ptr[i] = static_cast<phi::dtype::bfloat16>(i);
+      }
+
+      paddle::framework::TransDataType(
+          kernel_bf16, kernel_complex128, in, &out);
+      phi::dtype::complex<double>* out_data =
+          out.data<phi::dtype::complex<double>>();
+      for (int i = 0; i < data_number; ++i) {
+        EXPECT_EQ(out_data[i].real, static_cast<double>(ptr[i]));
+        EXPECT_EQ(out_data[i].imag, 0.0);
+      }
+    }
+
+    // Test float8_e4m3fn to complex64
+    {
+      phi::dtype::float8_e4m3fn* ptr =
+          in.mutable_data<phi::dtype::float8_e4m3fn>(common::make_ddim({2, 3}),
+                                                     place);
+      for (int i = 0; i < data_number; ++i) {
+        ptr[i] = static_cast<phi::dtype::float8_e4m3fn>(i);
+      }
+
+      paddle::framework::TransDataType(
+          kernel_float8_e4m3fn, kernel_complex64, in, &out);
+      phi::dtype::complex<float>* out_data =
+          out.data<phi::dtype::complex<float>>();
+      for (int i = 0; i < data_number; ++i) {
+        EXPECT_EQ(out_data[i].real, static_cast<float>(ptr[i]));
+        EXPECT_EQ(out_data[i].imag, 0.0f);
+      }
+    }
+
+    // Test float8_e4m3fn to complex128
+    {
+      phi::dtype::float8_e4m3fn* ptr =
+          in.mutable_data<phi::dtype::float8_e4m3fn>(common::make_ddim({2, 3}),
+                                                     place);
+      for (int i = 0; i < data_number; ++i) {
+        ptr[i] = static_cast<phi::dtype::float8_e4m3fn>(i);
+      }
+
+      paddle::framework::TransDataType(
+          kernel_float8_e4m3fn, kernel_complex128, in, &out);
+      phi::dtype::complex<double>* out_data =
+          out.data<phi::dtype::complex<double>>();
+      for (int i = 0; i < data_number; ++i) {
+        EXPECT_EQ(out_data[i].real, static_cast<double>(ptr[i]));
+        EXPECT_EQ(out_data[i].imag, 0.0);
+      }
+    }
+
+    // Test float8_e5m2 to complex64
+    {
+      phi::dtype::float8_e5m2* ptr = in.mutable_data<phi::dtype::float8_e5m2>(
+          common::make_ddim({2, 3}), place);
+      for (int i = 0; i < data_number; ++i) {
+        ptr[i] = static_cast<phi::dtype::float8_e5m2>(i);
+      }
+
+      paddle::framework::TransDataType(
+          kernel_float8_e5m2, kernel_complex64, in, &out);
+      phi::dtype::complex<float>* out_data =
+          out.data<phi::dtype::complex<float>>();
+      for (int i = 0; i < data_number; ++i) {
+        EXPECT_EQ(out_data[i].real, static_cast<float>(ptr[i]));
+        EXPECT_EQ(out_data[i].imag, 0.0f);
+      }
+    }
+
+    // Test float8_e5m2 to complex128
+    {
+      phi::dtype::float8_e5m2* ptr = in.mutable_data<phi::dtype::float8_e5m2>(
+          common::make_ddim({2, 3}), place);
+      for (int i = 0; i < data_number; ++i) {
+        ptr[i] = static_cast<phi::dtype::float8_e5m2>(i);
+      }
+
+      paddle::framework::TransDataType(
+          kernel_float8_e5m2, kernel_complex128, in, &out);
+      phi::dtype::complex<double>* out_data =
+          out.data<phi::dtype::complex<double>>();
+      for (int i = 0; i < data_number; ++i) {
+        EXPECT_EQ(out_data[i].real, static_cast<double>(ptr[i]));
+        EXPECT_EQ(out_data[i].imag, 0.0);
+      }
+    }
+  }
 }
diff --git a/test/legacy_test/test_lightweight_float_to_complex.py b/test/legacy_test/test_lightweight_float_to_complex.py
new file mode 100644
index 00000000000000..cf3d32148bae94
--- /dev/null
+++ b/test/legacy_test/test_lightweight_float_to_complex.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+
+
+class TestLightweightFloatToComplex(unittest.TestCase):
+    """Test casting from lightweight float formats (float8, float16, bfloat16) to complex types."""
+
+    def test_float16_to_complex(self):
+        """Test float16 to complex64/complex128 conversion."""
+        paddle.set_device('cpu')
+
+        r_fp16 = np.random.random(size=[10, 10]).astype('float16')
+        r_fp16_t = paddle.to_tensor(r_fp16, dtype='float16')
+
+        # Test dtype conversion
+        self.assertEqual(r_fp16_t.cast('complex64').dtype, paddle.complex64)
+        self.assertEqual(r_fp16_t.cast('complex128').dtype, paddle.complex128)
+
+        # Verify the real part is correct
+        np.testing.assert_allclose(
+            r_fp16_t.cast('complex64').real().numpy(),
+            r_fp16.astype('float32'),
+            rtol=1e-03,
+        )
+        np.testing.assert_allclose(
+            r_fp16_t.cast('complex128').real().numpy(),
+            r_fp16.astype('float64'),
+            rtol=1e-03,
+        )
+
+        # Verify the imaginary part is zero
+        np.testing.assert_array_equal(
+            r_fp16_t.cast('complex64').imag().numpy(),
+            np.zeros([10, 10], dtype='float32'),
+        )
+        np.testing.assert_array_equal(
+            r_fp16_t.cast('complex128').imag().numpy(),
+            np.zeros([10, 10], dtype='float64'),
+        )
+
+    def test_bfloat16_to_complex(self):
+        """Test bfloat16 to complex64/complex128 conversion."""
+        paddle.set_device('cpu')
+
+        r_bf16 = np.random.random(size=[10, 10]).astype('float32')
+        r_bf16_t = paddle.to_tensor(r_bf16, dtype='bfloat16')
+
+        # Test dtype conversion
+        self.assertEqual(r_bf16_t.cast('complex64').dtype, paddle.complex64)
+        self.assertEqual(r_bf16_t.cast('complex128').dtype, paddle.complex128)
+
+        # Verify the real part is correct
+        np.testing.assert_allclose(
+            r_bf16_t.cast('complex64').real().numpy(),
+            r_bf16_t.cast('float32').numpy(),
+            rtol=1e-02,
+        )
+        np.testing.assert_allclose(
+            r_bf16_t.cast('complex128').real().numpy(),
+            r_bf16_t.cast('float64').numpy(),
+            rtol=1e-02,
+        )
+
+        # Verify the imaginary part is zero
+        np.testing.assert_array_equal(
+            r_bf16_t.cast('complex64').imag().numpy(),
+            np.zeros([10, 10], dtype='float32'),
+        )
+        np.testing.assert_array_equal(
+            r_bf16_t.cast('complex128').imag().numpy(),
+            np.zeros([10, 10], dtype='float64'),
+        )
+
+    def test_float8_e4m3fn_to_complex(self):
+        """Test float8_e4m3fn to complex64/complex128 conversion."""
+        paddle.set_device('cpu')
+
+        r_fp32 = np.random.uniform(1.0, 10.0, size=[10, 10]).astype('float32')
+        r_fp32_t = paddle.to_tensor(r_fp32)
+        r_fp8_e4m3fn_t = r_fp32_t.astype('float8_e4m3fn')
+
+        # Test dtype conversion
+        self.assertEqual(
+            r_fp8_e4m3fn_t.cast('complex64').dtype, paddle.complex64
+        )
+        self.assertEqual(
+            r_fp8_e4m3fn_t.cast('complex128').dtype, paddle.complex128
+        )
+
+        # Verify the real part matches the float32 version
+        np.testing.assert_allclose(
+            r_fp8_e4m3fn_t.cast('complex64').real().numpy(),
+            r_fp8_e4m3fn_t.cast('float32').numpy(),
+            rtol=1e-02,
+        )
+        np.testing.assert_allclose(
+            r_fp8_e4m3fn_t.cast('complex128').real().numpy(),
+            r_fp8_e4m3fn_t.cast('float64').numpy(),
+            rtol=1e-02,
+        )
+
+        # Verify the imaginary part is zero
+        np.testing.assert_array_equal(
+            r_fp8_e4m3fn_t.cast('complex64').imag().numpy(),
+            np.zeros([10, 10], dtype='float32'),
+        )
+        np.testing.assert_array_equal(
+            r_fp8_e4m3fn_t.cast('complex128').imag().numpy(),
+            np.zeros([10, 10], dtype='float64'),
+        )
+
+    def test_float8_e5m2_to_complex(self):
+        """Test float8_e5m2 to complex64/complex128 conversion."""
+        paddle.set_device('cpu')
+
+        r_fp32 = np.random.uniform(1.0, 10.0, size=[10, 10]).astype('float32')
+        r_fp32_t = paddle.to_tensor(r_fp32)
+        r_fp8_e5m2_t = r_fp32_t.astype('float8_e5m2')
+
+        # Test dtype conversion
+        self.assertEqual(r_fp8_e5m2_t.cast('complex64').dtype, paddle.complex64)
+        self.assertEqual(
+            r_fp8_e5m2_t.cast('complex128').dtype, paddle.complex128
+        )
+
+        # Verify the real part matches the float32 version
+        np.testing.assert_allclose(
+            r_fp8_e5m2_t.cast('complex64').real().numpy(),
+            r_fp8_e5m2_t.cast('float32').numpy(),
+            rtol=1e-02,
+        )
+        np.testing.assert_allclose(
+            r_fp8_e5m2_t.cast('complex128').real().numpy(),
+            r_fp8_e5m2_t.cast('float64').numpy(),
+            rtol=1e-02,
+        )
+
+        # Verify the imaginary part is zero
+        np.testing.assert_array_equal(
+            r_fp8_e5m2_t.cast('complex64').imag().numpy(),
+            np.zeros([10, 10], dtype='float32'),
+        )
+        np.testing.assert_array_equal(
+            r_fp8_e5m2_t.cast('complex128').imag().numpy(),
+            np.zeros([10, 10], dtype='float64'),
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()