Skip to content
10 changes: 3 additions & 7 deletions Android.bp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ cc_library_shared {
],

include_dirs: [
"packages/modules/NeuralNetworks/common/include",
"packages/modules/NeuralNetworks/common/types/include",
"packages/modules/NeuralNetworks/runtime/include",
"frameworks/ml/nn/runtime/include/",
"frameworks/native/libs/nativewindow/include",
"external/mesa3d/include/android_stub",
"external/grpc-grpc",
Expand Down Expand Up @@ -168,9 +166,8 @@ cc_binary {
srcs: ["service.cpp"],

include_dirs: [
"packages/modules/NeuralNetworks/common/include",
"packages/modules/NeuralNetworks/common/types/include",
"packages/modules/NeuralNetworks/runtime/include",
"frameworks/ml/nn/common/include",
"frameworks/ml/nn/runtime/include/",
"frameworks/native/libs/nativewindow/include",
"external/mesa3d/include/android_stub",
],
Expand All @@ -186,7 +183,6 @@ cc_binary {

shared_libs: [
"libhidlbase",
"libhidltransport",
"libhidlmemory",
"libutils",
"liblog",
Expand Down
120 changes: 61 additions & 59 deletions BasePreparedModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include <android/log.h>
#include <cutils/properties.h>
#include <log/log.h>
#include <thread>
#include "ExecutionBurstServer.h"
#include "ValidateHal.h"

Expand All @@ -33,18 +32,25 @@ namespace android::hardware::neuralnetworks::nnhal {
using namespace android::nn;

static const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
bool mRemoteCheck = false;
std::shared_ptr<DetectionClient> mDetectionClient;
uint32_t BasePreparedModel::mFileId = 0;

// Tears down per-model runtime state: unmaps the request memory pools, deletes
// the serialized OpenVINO IR files from disk, joins the background remote-load
// thread, releases the remote (gRPC) inference session if one was established,
// and finally clears the remote-enabled flag.
void BasePreparedModel::deinitialize() {
ALOGV("Entering %s", __func__);
bool is_success = false;
// Unmap the runtime memory pools that were mapped for this model.
mModelInfo->unmapRuntimeMemPools();
// Remove the serialized network files (<id>.xml / <id>.bin) written at load
// time. std::remove returns 0 on success; failure is logged but non-fatal.
auto ret_xml = std::remove(mXmlFile.c_str());
auto ret_bin = std::remove(mBinFile.c_str());
if ((ret_xml != 0) || (ret_bin != 0)) {
ALOGW("%s Deletion status of xml:%d, bin:%d", __func__, ret_xml, ret_bin);
}
// The remote model is loaded on a background thread; join it before tearing
// down the client it uses, to avoid racing with an in-flight load.
if (mRemoteLoadThread.joinable()) {
mRemoteLoadThread.join();
}
// If a remote gRPC session was established, ask the server to release its
// copy of this model. is_success is set by release() and logged below.
if (mRemoteCheck && mDetectionClient) {
auto reply = mDetectionClient->release(is_success);
ALOGI("GRPC release response is %d : %s", is_success, reply.c_str());
}
// Mark remote inference unavailable for this (now torn-down) model.
setRemoteEnabled(false);

ALOGV("Exiting %s", __func__);
}
Expand All @@ -62,11 +68,11 @@ bool BasePreparedModel::initialize() {
ALOGE("Failed to initialize Model runtime parameters!!");
return false;
}
checkRemoteConnection();

mNgraphNetCreator = std::make_shared<NgraphNetworkCreator>(mModelInfo, mTargetDevice);

if (!mNgraphNetCreator->validateOperations()) return false;
ALOGI("Generating IR Graph");
ALOGI("Generating IR Graph for Model %u", mFileId);
auto ov_model = mNgraphNetCreator->generateGraph();
if (ov_model == nullptr) {
ALOGE("%s Openvino model generation failed", __func__);
Expand All @@ -75,16 +81,15 @@ bool BasePreparedModel::initialize() {
try {
mPlugin = std::make_unique<IENetwork>(mTargetDevice, ov_model);
mPlugin->loadNetwork(mXmlFile, mBinFile);
if(mRemoteCheck) {
auto resp = loadRemoteModel(mXmlFile, mBinFile);
ALOGD("%s Load Remote Model returns %d", __func__, resp);
} else {
ALOGD("%s Remote connection unavailable", __func__);
}
} catch (const std::exception& ex) {
ALOGE("%s Exception !!! %s", __func__, ex.what());
return false;
}
{
mRemoteLoadThread = std::thread([this] {
loadRemoteModel(mXmlFile, mBinFile);
});
}

ALOGV("Exiting %s", __func__);
return true;
Expand All @@ -95,38 +100,51 @@ bool BasePreparedModel::checkRemoteConnection() {
bool is_success = false;
if(getGrpcIpPort(grpc_prop)) {
ALOGV("Attempting GRPC via TCP : %s", grpc_prop);
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(INT_MAX);
args.SetMaxSendMessageSize(INT_MAX);
mDetectionClient = std::make_shared<DetectionClient>(
grpc::CreateChannel(grpc_prop, grpc::InsecureChannelCredentials()));
grpc::CreateCustomChannel(grpc_prop, grpc::InsecureChannelCredentials(), args), mFileId);
if(mDetectionClient) {
auto reply = mDetectionClient->prepare(is_success);
ALOGI("GRPC(TCP) prepare response is %d : %s", is_success, reply.c_str());
}
}
if (!is_success && getGrpcSocketPath(grpc_prop)) {
ALOGV("Attempting GRPC via unix : %s", grpc_prop);
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(INT_MAX);
args.SetMaxSendMessageSize(INT_MAX);
mDetectionClient = std::make_shared<DetectionClient>(
grpc::CreateChannel(std::string("unix:") + grpc_prop, grpc::InsecureChannelCredentials()));
if(mDetectionClient) {
grpc::CreateCustomChannel(std::string("unix:") + grpc_prop, grpc::InsecureChannelCredentials(), args), mFileId);
if (mDetectionClient) {
auto reply = mDetectionClient->prepare(is_success);
ALOGI("GRPC(unix) prepare response is %d : %s", is_success, reply.c_str());
} else {
ALOGE("%s mDetectionClient is null", __func__);
}
}
mRemoteCheck = is_success;
return is_success;
}

bool BasePreparedModel::loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin) {
ALOGI("Entering %s", __func__);
void BasePreparedModel::loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin) {
ALOGI("Entering %s for Model %u", __func__, mFileId);
bool is_success = false;
if(mDetectionClient) {
if(checkRemoteConnection() && mDetectionClient) {
auto reply = mDetectionClient->sendIRs(is_success, ir_xml, ir_bin);
ALOGI("sendIRs response GRPC %d %s", is_success, reply.c_str());
if (reply == "status False") {
ALOGE("%s Model Load Failed",__func__);
}
setRemoteEnabled(is_success);
}
else {
ALOGE("%s mDetectionClient is null",__func__);
}

void BasePreparedModel::setRemoteEnabled(bool flag) {
if(mRemoteCheck != flag) {
ALOGD("GRPC %s Remote Connection", flag ? "ACQUIRED" : "RELEASED");
mRemoteCheck = flag;
}
mRemoteCheck = is_success;
return is_success;
}

static Return<void> notify(const sp<V1_0::IExecutionCallback>& callback, const ErrorStatus& status,
Expand Down Expand Up @@ -268,20 +286,13 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod
ALOGD("%s Run", __func__);

if (measure == MeasureTiming::YES) deviceStart = now();
if(mRemoteCheck) {
ALOGI("%s GRPC Remote Infer", __func__);
auto reply = mDetectionClient->remote_infer();
ALOGI("***********GRPC server response************* %s", reply.c_str());
}
if (!mRemoteCheck || !mDetectionClient->get_status()){
try {
plugin->infer();
} catch (const std::exception& ex) {
ALOGE("%s Exception !!! %s", __func__, ex.what());
notify(callback, ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
return;
}
}
if (measure == MeasureTiming::YES) deviceEnd = now();

tensorIndex = 0;
Expand Down Expand Up @@ -332,10 +343,7 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod
return;
}

if (mRemoteCheck && mDetectionClient && mDetectionClient->get_status()) {
mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr,
ngraphNw->getOutputShape(outIndex));
} else {
{
switch (operandType) {
case OperandType::TENSOR_INT32:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int32_t>(),
Expand Down Expand Up @@ -399,7 +407,7 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod
}

static std::tuple<ErrorStatus, hidl_vec<V1_2::OutputShape>, Timing> executeSynchronouslyBase(
const Request& request, MeasureTiming measure, BasePreparedModel* preparedModel,
const V1_3::Request& request, MeasureTiming measure, BasePreparedModel* preparedModel,
time_point driverStart) {
ALOGV("Entering %s", __func__);
auto modelInfo = preparedModel->getModelInfo();
Expand All @@ -408,7 +416,7 @@ static std::tuple<ErrorStatus, hidl_vec<V1_2::OutputShape>, Timing> executeSynch
time_point driverEnd, deviceStart, deviceEnd;
std::vector<RunTimePoolInfo> requestPoolInfos;
auto errorStatus = modelInfo->setRunTimePoolInfosFromHidlMemories(request.pools);
if (errorStatus != ErrorStatus::NONE) {
if (errorStatus != V1_3::ErrorStatus::NONE) {
ALOGE("Failed to set runtime pool info from HIDL memories");
return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming};
}
Expand All @@ -427,8 +435,9 @@ static std::tuple<ErrorStatus, hidl_vec<V1_2::OutputShape>, Timing> executeSynch
ALOGV("Input index: %d layername : %s", inIndex, inputNodeName.c_str());
//check if remote infer is available
//TODO: Need to add FLOAT16 support for remote inferencing
if(mRemoteCheck && mDetectionClient) {
mDetectionClient->add_input_data(std::to_string(i), (uint8_t*)srcPtr, ngraphNw->getOutputShape(inIndex), len);
if(preparedModel->mRemoteCheck && preparedModel->mDetectionClient) {
auto inOperandType = modelInfo->getOperandType(inIndex);
preparedModel->mDetectionClient->add_input_data(std::to_string(i), (uint8_t*)srcPtr, ngraphNw->getOutputShape(inIndex), len, inOperandType);
} else {
ov::Tensor destTensor;
try {
Expand Down Expand Up @@ -493,12 +502,15 @@ static std::tuple<ErrorStatus, hidl_vec<V1_2::OutputShape>, Timing> executeSynch
ALOGV("%s Run", __func__);

if (measure == MeasureTiming::YES) deviceStart = now();
if(mRemoteCheck) {
if(preparedModel->mRemoteCheck) {
ALOGI("%s GRPC Remote Infer", __func__);
auto reply = mDetectionClient->remote_infer();
auto reply = preparedModel->mDetectionClient->remote_infer();
ALOGI("***********GRPC server response************* %s", reply.c_str());
}
if (!mRemoteCheck || !mDetectionClient->get_status()){
if (!preparedModel->mRemoteCheck || !preparedModel->mDetectionClient->get_status()){
if(preparedModel->mRemoteCheck) {
preparedModel->setRemoteEnabled(false);
}
try {
ALOGV("%s Client Infer", __func__);
plugin->infer();
Expand Down Expand Up @@ -555,9 +567,9 @@ static std::tuple<ErrorStatus, hidl_vec<V1_2::OutputShape>, Timing> executeSynch
}
//copy output from remote infer
//TODO: Add support for other OperandType
if (mRemoteCheck && mDetectionClient && mDetectionClient->get_status()) {
mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr,
ngraphNw->getOutputShape(outIndex));
if (preparedModel->mRemoteCheck && preparedModel->mDetectionClient && preparedModel->mDetectionClient->get_status()) {
preparedModel->mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr,
ngraphNw->getOutputShape(outIndex), expectedLength);
} else {
switch (operandType) {
case OperandType::TENSOR_INT32:
Expand Down Expand Up @@ -606,8 +618,8 @@ static std::tuple<ErrorStatus, hidl_vec<V1_2::OutputShape>, Timing> executeSynch
ALOGE("Failed to update the request pool infos");
return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming};
}
if (mRemoteCheck && mDetectionClient && mDetectionClient->get_status()) {
mDetectionClient->clear_data();
if (preparedModel->mRemoteCheck && preparedModel->mDetectionClient && preparedModel->mDetectionClient->get_status()) {
preparedModel->mDetectionClient->clear_data();
}

if (measure == MeasureTiming::YES) {
Expand All @@ -631,7 +643,7 @@ Return<void> BasePreparedModel::executeSynchronously(const Request& request, Mea
return Void();
}
auto [status, outputShapes, timing] =
executeSynchronouslyBase(request, measure, this, driverStart);
executeSynchronouslyBase(convertToV1_3(request), measure, this, driverStart);
cb(status, std::move(outputShapes), timing);
ALOGV("Exiting %s", __func__);
return Void();
Expand All @@ -646,12 +658,12 @@ Return<void> BasePreparedModel::executeSynchronously_1_3(const V1_3::Request& re
time_point driverStart;
if (measure == MeasureTiming::YES) driverStart = now();

if (!validateRequest(convertToV1_0(request), convertToV1_2(mModelInfo->getModel()))) {
if (!validateRequest(request, mModelInfo->getModel())) {
cb(V1_3::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
return Void();
}
auto [status, outputShapes, timing] =
executeSynchronouslyBase(convertToV1_0(request), measure, this, driverStart);
executeSynchronouslyBase(request, measure, this, driverStart);
cb(convertToV1_3(status), std::move(outputShapes), timing);
ALOGV("Exiting %s", __func__);
return Void();
Expand Down Expand Up @@ -820,20 +832,13 @@ Return<void> BasePreparedModel::executeFenced(const V1_3::Request& request1_3,

time_point deviceStart, deviceEnd;
if (measure == MeasureTiming::YES) deviceStart = now();
if(mRemoteCheck) {
ALOGI("%s GRPC Remote Infer", __func__);
auto reply = mDetectionClient->remote_infer();
ALOGI("***********GRPC server response************* %s", reply.c_str());
}
if (!mRemoteCheck || !mDetectionClient->get_status()){
try {
mPlugin->infer();
} catch (const std::exception& ex) {
ALOGE("%s Exception !!! %s", __func__, ex.what());
cb(V1_3::ErrorStatus::GENERAL_FAILURE, hidl_handle(nullptr), nullptr);
return Void();
}
}
if (measure == MeasureTiming::YES) deviceEnd = now();

tensorIndex = 0;
Expand Down Expand Up @@ -870,10 +875,7 @@ Return<void> BasePreparedModel::executeFenced(const V1_3::Request& request1_3,
mModelInfo->updateOutputshapes(i, outDims);
}

if (mRemoteCheck && mDetectionClient && mDetectionClient->get_status()) {
mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr,
mNgraphNetCreator->getOutputShape(outIndex));
} else {
{
switch (operandType) {
case OperandType::TENSOR_INT32:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int32_t>(),
Expand Down
13 changes: 8 additions & 5 deletions BasePreparedModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <sys/mman.h>
#include <fstream>
#include <string>
#include <thread>

#include <NgraphNetworkCreator.hpp>
#include <openvino/pass/serialize.hpp>
Expand All @@ -49,14 +50,13 @@ namespace android::hardware::neuralnetworks::nnhal {
template <class T>
using vec = std::vector<T>;
typedef uint8_t* memory;
extern bool mRemoteCheck;
extern std::shared_ptr<DetectionClient> mDetectionClient;
class BasePreparedModel : public V1_3::IPreparedModel {
public:
bool mRemoteCheck = false;
BasePreparedModel(const IntelDeviceType device, const Model& model) : mTargetDevice(device) {
mModelInfo = std::make_shared<NnapiModelInfo>(model);
mXmlFile = std::string("/data/vendor/neuralnetworks/") + std::to_string(mFileId) + std::string(".xml");
mBinFile = std::string("/data/vendor/neuralnetworks/") + std::to_string(mFileId) + std::string(".bin");
mXmlFile = MODEL_DIR + std::to_string(mFileId) + std::string(".xml");
mBinFile = MODEL_DIR + std::to_string(mFileId) + std::string(".bin");
mFileId++;
}

Expand Down Expand Up @@ -89,7 +89,8 @@ class BasePreparedModel : public V1_3::IPreparedModel {

virtual bool initialize();
virtual bool checkRemoteConnection();
virtual bool loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin);
virtual void loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin);
virtual void setRemoteEnabled(bool flag);

std::shared_ptr<NnapiModelInfo> getModelInfo() { return mModelInfo; }

Expand All @@ -98,6 +99,7 @@ class BasePreparedModel : public V1_3::IPreparedModel {
std::shared_ptr<IIENetwork> getPlugin() { return mPlugin; }

std::shared_ptr<ov::Model> modelPtr;
std::shared_ptr<DetectionClient> mDetectionClient;

protected:
virtual void deinitialize();
Expand All @@ -110,6 +112,7 @@ class BasePreparedModel : public V1_3::IPreparedModel {
static uint32_t mFileId;
std::string mXmlFile;
std::string mBinFile;
std::thread mRemoteLoadThread;
};

class BaseFencedExecutionCallback : public V1_3::IFencedExecutionCallback {
Expand Down
Loading