Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions impl/torch/functions/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,52 @@ const char* diopiGetImplVersion() {
return version;
}

diopiError_t diopiEmpty(diopiContextHandle_t ctx, const diopiSize_t shape, diopiTensorHandle_t out) {
impl::aten::setCurStream(ctx);
auto atOut = impl::aten::buildATen(out);
auto atSize = impl::aten::buildAtIntArray(shape);
CALL_ATEN_FUNC(empty_out, atOut, atSize);
return diopiSuccess;
}

diopiError_t diopiFull(diopiContextHandle_t ctx, const diopiSize_t shape, const diopiScalar_t* fill_value, diopiTensorHandle_t out) {
impl::aten::setCurStream(ctx);

auto status = diopiEmpty(ctx, shape, out);
if (status != diopiSuccess) {
return status;
}
auto atOut = impl::aten::buildATen(out);
auto atSize = impl::aten::buildAtIntArray(shape);
auto atFillValue = impl::aten::buildAtScalar(fill_value);
CALL_ATEN_FUNC(full_out, atOut, atSize, atFillValue);
return diopiSuccess;
}

diopiError_t diopiRandNormal(diopiContextHandle_t ctx, const diopiScalar_t* mean, const diopiScalar_t* std, const int64_t seed, const diopiSize_t shape,
diopiTensorHandle_t out) {
auto status = diopiEmpty(ctx, shape, out);
if (status != diopiSuccess) return status;

impl::aten::setCurStream(ctx);
auto atOut = impl::aten::buildATen(out);
auto atMean = impl::aten::buildAtScalar(mean);
auto atStd = impl::aten::buildAtScalar(std);
auto atShape = impl::aten::buildAtIntArray(shape);

if (seed != 0) {
at::manual_seed(seed);
}

auto standardNormal = at::randn(atShape);

auto scaled = standardNormal.mul(atStd).add(atMean);

at::native::copy_(atOut, scaled);

return diopiSuccess;
}

diopiError_t diopiRelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
impl::aten::setCurStream(ctx);
auto atOut = impl::aten::buildATen(out);
Expand Down
40 changes: 40 additions & 0 deletions proto/include/diopi/functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,46 @@ DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetVendorName();
DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetImplVersion();
DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetLastErrorString();

/**
* @brief Creates an empty tensor with the specified data type, device, and shape.
* @param[in] ctx Context environment.
* @param[in] dtype The desired data type for the tensor.
* @param[in] device The device where the tensor will be allocated.
* @param[in] shape the shape of the tensor to be created.
* @param[out] output Pointer to a tensor handle that will hold the reference to the created tensor.
* @return diopiError_t Status of the operation; diopiSuccess on successful tensor creation.
*/
DIOPI_API diopiError_t diopiEmpty(diopiContextHandle_t ctx, const diopiDtype_t dtype, const diopiDevice_t device, const diopiSize_t* shape,
diopiTensorHandle_t out);

/**
* @brief Creates a tensor filled with a specified value, with the specified data type, device, and shape.
* @param[in] ctx Context environment.
* @param[in] dtype The desired data type for the tensor.
* @param[in] device The device where the tensor will be allocated.
* @param[in] shape The shape of the tensor to be created.
* @param[in] fill_value A pointer to the value that will fill the tensor.
* @param[out] out Pointer to a tensor handle that will hold the reference to the created tensor.
* @return diopiError_t Status of the operation; diopiSuccess on successful tensor creation.
*/
DIOPI_API diopiError_t diopiFull(diopiContextHandle_t ctx, const diopiDtype_t dtype, const diopiDevice_t device, const diopiSize_t* shape,
const diopiScalar_t* fill_value, diopiTensorHandle_t out);

/**
* @brief Creates a tensor filled with random numbers drawn from a normal distribution with specified mean and standard deviation.
* @param[in] ctx Context environment.
* @param[in] dtype The desired data type for the tensor.
* @param[in] device The device where the tensor will be allocated.
* @param[in] mean A pointer to the mean value of the normal distribution.
* @param[in] std A pointer to the standard deviation of the normal distribution.
* @param[in] seed The random seed for reproducibility.
* @param[in] shape The shape of the tensor to be created.
* @param[out] out Pointer to a tensor handle that will hold the reference to the created tensor.
* @return diopiError_t Status of the operation; diopiSuccess on successful tensor creation.
*/
DIOPI_API diopiError_t diopiRandNormal(diopiContextHandle_t ctx, const diopiDtype_t dtype, diopiDevice_t device, const diopiScalar_t* mean,
const diopiScalar_t* std, const int64_t seed, const diopiSize_t* shape, diopiTensorHandle_t out);

/**
* @brief Applies a 2D convolution over an input image composed of several input planes.
* @param[in] ctx Context environment.
Expand Down