Skip to content
72 changes: 72 additions & 0 deletions paddle/phi/core/framework/data_type_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,78 @@ struct CastDataTypeFunctor {
}
};

template <>
struct CastDataTypeFunctor<::phi::dtype::float8_e5m2,
::phi::dtype::complex<float>> {
HOSTDEVICE inline ::phi::dtype::complex<float> operator()(
::phi::dtype::float8_e5m2 in) const {
return ::phi::dtype::complex<float>(static_cast<float>(in));
}
};

template <>
struct CastDataTypeFunctor<::phi::dtype::float8_e5m2,
::phi::dtype::complex<double>> {
HOSTDEVICE inline ::phi::dtype::complex<double> operator()(
::phi::dtype::float8_e5m2 in) const {
return ::phi::dtype::complex<double>(static_cast<double>(in));
}
};

template <>
struct CastDataTypeFunctor<::phi::dtype::float8_e4m3fn,
::phi::dtype::complex<float>> {
HOSTDEVICE inline ::phi::dtype::complex<float> operator()(
::phi::dtype::float8_e4m3fn in) const {
return ::phi::dtype::complex<float>(static_cast<float>(in));
}
};

template <>
struct CastDataTypeFunctor<::phi::dtype::float8_e4m3fn,
::phi::dtype::complex<double>> {
HOSTDEVICE inline ::phi::dtype::complex<double> operator()(
::phi::dtype::float8_e4m3fn in) const {
return ::phi::dtype::complex<double>(static_cast<double>(in));
}
};

template <>
struct CastDataTypeFunctor<::phi::dtype::bfloat16,
::phi::dtype::complex<float>> {
HOSTDEVICE inline ::phi::dtype::complex<float> operator()(
::phi::dtype::bfloat16 in) const {
return ::phi::dtype::complex<float>(static_cast<float>(in));
}
};

template <>
struct CastDataTypeFunctor<::phi::dtype::bfloat16,
::phi::dtype::complex<double>> {
HOSTDEVICE inline ::phi::dtype::complex<double> operator()(
::phi::dtype::bfloat16 in) const {
return ::phi::dtype::complex<double>(static_cast<double>(in));
}
};

template <>
struct CastDataTypeFunctor<::phi::dtype::float16,
::phi::dtype::complex<float>> {
HOSTDEVICE inline ::phi::dtype::complex<float> operator()(
::phi::dtype::float16 in) const {
return ::phi::dtype::complex<float>(static_cast<float>(in));
}
};

template <>
struct CastDataTypeFunctor<::phi::dtype::float16,
::phi::dtype::complex<double>> {
HOSTDEVICE inline ::phi::dtype::complex<double> operator()(
::phi::dtype::float16 in) const {
return ::phi::dtype::complex<double>(static_cast<double>(in));
}
};

#if defined(PADDLE_WITH_XPU)

template <typename InType, typename OutType>
Expand Down
160 changes: 160 additions & 0 deletions test/cpp/fluid/framework/data_type_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -395,4 +395,164 @@ TEST(DataTypeTransform, CPUTransform) {
EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_bool[i]));
}
}

// data type transform from lightweight float formats to complex types
{
auto kernel_float8_e5m2 = phi::KernelKey(
place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT8_E5M2);
auto kernel_float8_e4m3fn = phi::KernelKey(
place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT8_E4M3FN);
auto kernel_complex64 = phi::KernelKey(
place, phi::DataLayout::ALL_LAYOUT, phi::DataType::COMPLEX64);
auto kernel_complex128 = phi::KernelKey(
place, phi::DataLayout::ALL_LAYOUT, phi::DataType::COMPLEX128);

phi::DenseTensor in;
phi::DenseTensor out;
int data_number = 2 * 3;

// Test float16 to complex64
{
phi::dtype::float16* ptr = in.mutable_data<phi::dtype::float16>(
common::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
ptr[i] = static_cast<phi::dtype::float16>(i);
}

paddle::framework::TransDataType(kernel_fp16, kernel_complex64, in, &out);
phi::dtype::complex<float>* out_data =
out.data<phi::dtype::complex<float>>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data[i].real, static_cast<float>(ptr[i]));
EXPECT_EQ(out_data[i].imag, 0.0f);
}
}

// Test float16 to complex128
{
phi::dtype::float16* ptr = in.mutable_data<phi::dtype::float16>(
common::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
ptr[i] = static_cast<phi::dtype::float16>(i);
}

paddle::framework::TransDataType(
kernel_fp16, kernel_complex128, in, &out);
phi::dtype::complex<double>* out_data =
out.data<phi::dtype::complex<double>>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data[i].real, static_cast<double>(ptr[i]));
EXPECT_EQ(out_data[i].imag, 0.0);
}
}

// Test bfloat16 to complex64
{
phi::dtype::bfloat16* ptr = in.mutable_data<phi::dtype::bfloat16>(
common::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
ptr[i] = static_cast<phi::dtype::bfloat16>(i);
}

paddle::framework::TransDataType(kernel_bf16, kernel_complex64, in, &out);
phi::dtype::complex<float>* out_data =
out.data<phi::dtype::complex<float>>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data[i].real, static_cast<float>(ptr[i]));
EXPECT_EQ(out_data[i].imag, 0.0f);
}
}

// Test bfloat16 to complex128
{
phi::dtype::bfloat16* ptr = in.mutable_data<phi::dtype::bfloat16>(
common::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
ptr[i] = static_cast<phi::dtype::bfloat16>(i);
}

paddle::framework::TransDataType(
kernel_bf16, kernel_complex128, in, &out);
phi::dtype::complex<double>* out_data =
out.data<phi::dtype::complex<double>>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data[i].real, static_cast<double>(ptr[i]));
EXPECT_EQ(out_data[i].imag, 0.0);
}
}

// Test float8_e4m3fn to complex64
{
phi::dtype::float8_e4m3fn* ptr =
in.mutable_data<phi::dtype::float8_e4m3fn>(common::make_ddim({2, 3}),
place);
for (int i = 0; i < data_number; ++i) {
ptr[i] = static_cast<phi::dtype::float8_e4m3fn>(i);
}

paddle::framework::TransDataType(
kernel_float8_e4m3fn, kernel_complex64, in, &out);
phi::dtype::complex<float>* out_data =
out.data<phi::dtype::complex<float>>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data[i].real, static_cast<float>(ptr[i]));
EXPECT_EQ(out_data[i].imag, 0.0f);
}
}

// Test float8_e4m3fn to complex128
{
phi::dtype::float8_e4m3fn* ptr =
in.mutable_data<phi::dtype::float8_e4m3fn>(common::make_ddim({2, 3}),
place);
for (int i = 0; i < data_number; ++i) {
ptr[i] = static_cast<phi::dtype::float8_e4m3fn>(i);
}

paddle::framework::TransDataType(
kernel_float8_e4m3fn, kernel_complex128, in, &out);
phi::dtype::complex<double>* out_data =
out.data<phi::dtype::complex<double>>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data[i].real, static_cast<double>(ptr[i]));
EXPECT_EQ(out_data[i].imag, 0.0);
}
}

// Test float8_e5m2 to complex64
{
phi::dtype::float8_e5m2* ptr = in.mutable_data<phi::dtype::float8_e5m2>(
common::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
ptr[i] = static_cast<phi::dtype::float8_e5m2>(i);
}

paddle::framework::TransDataType(
kernel_float8_e5m2, kernel_complex64, in, &out);
phi::dtype::complex<float>* out_data =
out.data<phi::dtype::complex<float>>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data[i].real, static_cast<float>(ptr[i]));
EXPECT_EQ(out_data[i].imag, 0.0f);
}
}

// Test float8_e5m2 to complex128
{
phi::dtype::float8_e5m2* ptr = in.mutable_data<phi::dtype::float8_e5m2>(
common::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
ptr[i] = static_cast<phi::dtype::float8_e5m2>(i);
}

paddle::framework::TransDataType(
kernel_float8_e5m2, kernel_complex128, in, &out);
phi::dtype::complex<double>* out_data =
out.data<phi::dtype::complex<double>>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data[i].real, static_cast<double>(ptr[i]));
EXPECT_EQ(out_data[i].imag, 0.0);
}
}
}
}
Loading
Loading