diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp index 6ee8e104e..0b510e91a 100644 --- a/impl/torch/functions/functions.cpp +++ b/impl/torch/functions/functions.cpp @@ -65,6 +65,52 @@ const char* diopiGetImplVersion() { return version; } +diopiError_t diopiEmpty(diopiContextHandle_t ctx, const diopiSize_t shape, diopiTensorHandle_t out) { + impl::aten::setCurStream(ctx); + auto atOut = impl::aten::buildATen(out); + auto atSize = impl::aten::buildAtIntArray(shape); + CALL_ATEN_FUNC(empty_out, atOut, atSize); + return diopiSuccess; +} + +diopiError_t diopiFull(diopiContextHandle_t ctx, const diopiSize_t shape, const diopiScalar_t* fill_value, diopiTensorHandle_t out) { + impl::aten::setCurStream(ctx); + + auto status = diopiEmpty(ctx, shape, out); + if (status != diopiSuccess) { + return status; + } + auto atOut = impl::aten::buildATen(out); + auto atSize = impl::aten::buildAtIntArray(shape); + auto atFillValue = impl::aten::buildAtScalar(fill_value); + CALL_ATEN_FUNC(full_out, atOut, atSize, atFillValue); + return diopiSuccess; +} + +diopiError_t diopiRandNormal(diopiContextHandle_t ctx, const diopiScalar_t* mean, const diopiScalar_t* std, const int64_t seed, const diopiSize_t shape, + diopiTensorHandle_t out) { + auto status = diopiEmpty(ctx, shape, out); + if (status != diopiSuccess) return status; + + impl::aten::setCurStream(ctx); + auto atOut = impl::aten::buildATen(out); + auto atMean = impl::aten::buildAtScalar(mean); + auto atStd = impl::aten::buildAtScalar(std); + auto atShape = impl::aten::buildAtIntArray(shape); + + if (seed != 0) { + at::manual_seed(seed); + } + + auto standardNormal = at::randn(atShape); + + auto scaled = standardNormal.mul(atStd).add(atMean); + + at::native::copy_(atOut, scaled); + + return diopiSuccess; +} + diopiError_t diopiRelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) { impl::aten::setCurStream(ctx); auto atOut = impl::aten::buildATen(out); diff --git a/proto/include/diopi/functions.h b/proto/include/diopi/functions.h index 4f7dfcecb..c09a20d3f 100644 --- a/proto/include/diopi/functions.h +++ b/proto/include/diopi/functions.h @@ -20,6 +20,46 @@ DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetVendorName(); DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetImplVersion(); DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetLastErrorString(); +/** + * @brief Creates an empty tensor with the specified data type, device, and shape. + * @param[in] ctx Context environment. + * @param[in] dtype The desired data type for the tensor. + * @param[in] device The device where the tensor will be allocated. + * @param[in] shape the shape of the tensor to be created. + * @param[out] output Pointer to a tensor handle that will hold the reference to the created tensor. + * @return diopiError_t Status of the operation; diopiSuccess on successful tensor creation. + */ +DIOPI_API diopiError_t diopiEmpty(diopiContextHandle_t ctx, const diopiDtype_t dtype, const diopiDevice_t device, const diopiSize_t* shape, + diopiTensorHandle_t out); + +/** + * @brief Creates a tensor filled with a specified value, with the specified data type, device, and shape. + * @param[in] ctx Context environment. + * @param[in] dtype The desired data type for the tensor. + * @param[in] device The device where the tensor will be allocated. + * @param[in] shape The shape of the tensor to be created. + * @param[in] fill_value A pointer to the value that will fill the tensor. + * @param[out] out Pointer to a tensor handle that will hold the reference to the created tensor. + * @return diopiError_t Status of the operation; diopiSuccess on successful tensor creation. + */ +DIOPI_API diopiError_t diopiFull(diopiContextHandle_t ctx, const diopiDtype_t dtype, const diopiDevice_t device, const diopiSize_t* shape, + const diopiScalar_t* fill_value, diopiTensorHandle_t out); + +/** + * @brief Creates a tensor filled with random numbers drawn from a normal distribution with specified mean and standard deviation. + * @param[in] ctx Context environment. + * @param[in] dtype The desired data type for the tensor. + * @param[in] device The device where the tensor will be allocated. + * @param[in] mean A pointer to the mean value of the normal distribution. + * @param[in] std A pointer to the standard deviation of the normal distribution. + * @param[in] seed The random seed for reproducibility. + * @param[in] shape The shape of the tensor to be created. + * @param[out] out Pointer to a tensor handle that will hold the reference to the created tensor. + * @return diopiError_t Status of the operation; diopiSuccess on successful tensor creation. + */ +DIOPI_API diopiError_t diopiRandNormal(diopiContextHandle_t ctx, const diopiDtype_t dtype, diopiDevice_t device, const diopiScalar_t* mean, + const diopiScalar_t* std, const int64_t seed, const diopiSize_t* shape, diopiTensorHandle_t out); + /** * @brief Applies a 2D convolution over an input image composed of several input planes. * @param[in] ctx Context environment.