From dc7b1b97c51487180621f9193a080ada7eb6f01a Mon Sep 17 00:00:00 2001 From: Yin Hongyun Date: Mon, 4 Nov 2024 23:25:52 +0800 Subject: [PATCH 1/2] [feat] add diopiEmpty, diopiFull, diopiRandNormal --- impl/torch/functions/functions.cpp | 43 ++++++++++++++++++++++++++++++ proto/include/diopi/functions.h | 40 +++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp index 6ee8e104e..44da17296 100644 --- a/impl/torch/functions/functions.cpp +++ b/impl/torch/functions/functions.cpp @@ -65,6 +65,49 @@ const char* diopiGetImplVersion() { return version; } +diopiError_t diopiEmpty(diopiContextHandle_t ctx, const diopiDtype_t dtype, const diopiDevice_t device, const diopiSize_t* shape, diopiTensorHandle_t out) { + impl::aten::setCurStream(ctx); + return diopiRequireTensor(ctx, &out, shape, nullptr, dtype, device); +} + +diopiError_t diopiFull(diopiContextHandle_t ctx, const diopiDtype_t dtype, const diopiDevice_t device, const diopiSize_t* shape, + const diopiScalar_t* fill_value, diopiTensorHandle_t out) { + impl::aten::setCurStream(ctx); + + auto status = diopiRequireTensor(ctx, &out, shape, nullptr, dtype, device); + if (status != diopiSuccess) { + return status; + } + auto atOut = impl::aten::buildATen(out); + auto atFillValue = impl::aten::buildAtScalar(fill_value); + CALL_ATEN_FUNC(fill_, atOut, atFillValue); + return diopiSuccess; +} + +diopiError_t diopiRandNormal(diopiContextHandle_t ctx, const diopiDtype_t dtype, diopiDevice_t device, diopiTensorHandle_t out, const diopiScalar_t* mean, + const diopiScalar_t* std, const int64_t seed, const diopiSize_t* shape) { + auto atMean = impl::aten::buildAtScalar(mean); + auto atStd = impl::aten::buildAtScalar(std); + auto atShape = impl::aten::buildAtIntArray(shape); + + auto status = diopiRequireTensor(ctx, &out, shape, nullptr, dtype, device); + if (status != diopiSuccess) return status; + + auto atOut = impl::aten::buildATen(out); + + if (seed != 0) { + at::manual_seed(seed); + } + + auto standardNormal = at::randn(atShape); + + auto scaled = standardNormal.mul(atStd).add(atMean); + + at::native::copy_(atOut, scaled); + + return diopiSuccess; +} + diopiError_t diopiRelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) { impl::aten::setCurStream(ctx); auto atOut = impl::aten::buildATen(out); diff --git a/proto/include/diopi/functions.h b/proto/include/diopi/functions.h index 4f7dfcecb..c09a20d3f 100644 --- a/proto/include/diopi/functions.h +++ b/proto/include/diopi/functions.h @@ -20,6 +20,46 @@ DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetVendorName(); DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetImplVersion(); DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetLastErrorString(); +/** + * @brief Creates an empty tensor with the specified data type, device, and shape. + * @param[in] ctx Context environment. + * @param[in] dtype The desired data type for the tensor. + * @param[in] device The device where the tensor will be allocated. + * @param[in] shape the shape of the tensor to be created. + * @param[out] output Pointer to a tensor handle that will hold the reference to the created tensor. + * @return diopiError_t Status of the operation; diopiSuccess on successful tensor creation. + */ +DIOPI_API diopiError_t diopiEmpty(diopiContextHandle_t ctx, const diopiDtype_t dtype, const diopiDevice_t device, const diopiSize_t* shape, + diopiTensorHandle_t out); + +/** + * @brief Creates a tensor filled with a specified value, with the specified data type, device, and shape. + * @param[in] ctx Context environment. + * @param[in] dtype The desired data type for the tensor. + * @param[in] device The device where the tensor will be allocated. + * @param[in] shape The shape of the tensor to be created. + * @param[in] fill_value A pointer to the value that will fill the tensor. + * @param[out] out Pointer to a tensor handle that will hold the reference to the created tensor. + * @return diopiError_t Status of the operation; diopiSuccess on successful tensor creation. + */ +DIOPI_API diopiError_t diopiFull(diopiContextHandle_t ctx, const diopiDtype_t dtype, const diopiDevice_t device, const diopiSize_t* shape, + const diopiScalar_t* fill_value, diopiTensorHandle_t out); + +/** + * @brief Creates a tensor filled with random numbers drawn from a normal distribution with specified mean and standard deviation. + * @param[in] ctx Context environment. + * @param[in] dtype The desired data type for the tensor. + * @param[in] device The device where the tensor will be allocated. + * @param[in] mean A pointer to the mean value of the normal distribution. + * @param[in] std A pointer to the standard deviation of the normal distribution. + * @param[in] seed The random seed for reproducibility. + * @param[in] shape The shape of the tensor to be created. + * @param[out] out Pointer to a tensor handle that will hold the reference to the created tensor. + * @return diopiError_t Status of the operation; diopiSuccess on successful tensor creation. + */ +DIOPI_API diopiError_t diopiRandNormal(diopiContextHandle_t ctx, const diopiDtype_t dtype, diopiDevice_t device, const diopiScalar_t* mean, + const diopiScalar_t* std, const int64_t seed, const diopiSize_t* shape, diopiTensorHandle_t out); + /** * @brief Applies a 2D convolution over an input image composed of several input planes. * @param[in] ctx Context environment. From 2345e3722c2f115e889ad30d211ac8c46573945b Mon Sep 17 00:00:00 2001 From: Yin Hongyun Date: Tue, 5 Nov 2024 16:30:35 +0800 Subject: [PATCH 2/2] [feat] add diopiEmpty, diopiFull, diopiRandNormal --- impl/torch/functions/functions.cpp | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp index 44da17296..0b510e91a 100644 --- a/impl/torch/functions/functions.cpp +++ b/impl/torch/functions/functions.cpp @@ -65,35 +65,38 @@ const char* diopiGetImplVersion() { return version; } -diopiError_t diopiEmpty(diopiContextHandle_t ctx, const diopiDtype_t dtype, const diopiDevice_t device, const diopiSize_t* shape, diopiTensorHandle_t out) { +diopiError_t diopiEmpty(diopiContextHandle_t ctx, const diopiSize_t shape, diopiTensorHandle_t out) { impl::aten::setCurStream(ctx); - return diopiRequireTensor(ctx, &out, shape, nullptr, dtype, device); + auto atOut = impl::aten::buildATen(out); + auto atSize = impl::aten::buildAtIntArray(shape); + CALL_ATEN_FUNC(empty_out, atOut, atSize); + return diopiSuccess; } -diopiError_t diopiFull(diopiContextHandle_t ctx, const diopiDtype_t dtype, const diopiDevice_t device, const diopiSize_t* shape, - const diopiScalar_t* fill_value, diopiTensorHandle_t out) { +diopiError_t diopiFull(diopiContextHandle_t ctx, const diopiSize_t shape, const diopiScalar_t* fill_value, diopiTensorHandle_t out) { impl::aten::setCurStream(ctx); - auto status = diopiRequireTensor(ctx, &out, shape, nullptr, dtype, device); + auto status = diopiEmpty(ctx, shape, out); if (status != diopiSuccess) { return status; } auto atOut = impl::aten::buildATen(out); + auto atSize = impl::aten::buildAtIntArray(shape); auto atFillValue = impl::aten::buildAtScalar(fill_value); - CALL_ATEN_FUNC(fill_, atOut, atFillValue); + CALL_ATEN_FUNC(full_out, atOut, atSize, atFillValue); return diopiSuccess; } -diopiError_t diopiRandNormal(diopiContextHandle_t ctx, const diopiDtype_t dtype, diopiDevice_t device, diopiTensorHandle_t out, const diopiScalar_t* mean, - const diopiScalar_t* std, const int64_t seed, const diopiSize_t* shape) { - auto atMean = impl::aten::buildAtScalar(mean); - auto atStd = impl::aten::buildAtScalar(std); - auto atShape = impl::aten::buildAtIntArray(shape); - - auto status = diopiRequireTensor(ctx, &out, shape, nullptr, dtype, device); +diopiError_t diopiRandNormal(diopiContextHandle_t ctx, const diopiScalar_t* mean, const diopiScalar_t* std, const int64_t seed, const diopiSize_t shape, + diopiTensorHandle_t out) { + auto status = diopiEmpty(ctx, shape, out); if (status != diopiSuccess) return status; + impl::aten::setCurStream(ctx); auto atOut = impl::aten::buildATen(out); + auto atMean = impl::aten::buildAtScalar(mean); + auto atStd = impl::aten::buildAtScalar(std); + auto atShape = impl::aten::buildAtIntArray(shape); if (seed != 0) { at::manual_seed(seed);