Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 46 additions & 12 deletions tensorflow/lite/delegates/webnn/webnn_delegate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,22 +315,35 @@ class Subgraph {
}
}


emscripten::val graph_inputs = emscripten::val::object();
for (auto& i : compute_inputs) {
std::string name = std::to_string(i);
auto input_size = context->tensors[i].bytes / 4;
graph_inputs.set(name, emscripten::val::global("Float32Array").new_(input_size));
}

emscripten::val graph_outputs = emscripten::val::object();
emscripten::val named_operands = emscripten::val::object();
for (auto o : outputs) {
std::string name = std::to_string(o);
if (!webnn_operands.at(o).as<bool>()) {
TF_LITE_KERNEL_LOG(context, "Invalid operand");
return nullptr;
}
auto output_size = context->tensors[o].bytes / 4;
graph_outputs.set(name, emscripten::val::global("Float32Array").new_(output_size));
named_operands.set(name, webnn_operands.at(o));
}

emscripten::val wnn_graph = wnn_builder.call<emscripten::val>("buildSync", named_operands);
emscripten::val wnn_graph = wnn_builder.call<emscripten::val>("build", named_operands).await();
if (!wnn_graph.as<bool>()) {
TF_LITE_KERNEL_LOG(context, "failed to build WebNN graph");
return nullptr;
}
return new Subgraph(delegate->wnn_context_, wnn_graph, std::move(compute_inputs), std::move(outputs));
return new Subgraph(delegate->wnn_context_, wnn_graph,
std::move(compute_inputs), std::move(outputs),
graph_inputs, graph_outputs);
}

// Preparation hook for this subgraph. No per-subgraph setup is performed
// here, so the call unconditionally reports success; `context` is unused.
TfLiteStatus Prepare(TfLiteContext* context) {
  return kTfLiteOk;
}
Expand All @@ -356,27 +369,43 @@ class Subgraph {
}
}

std::unordered_map<std::string, emscripten::val> output_views;
if (any_pointers_changed) {
graph_inputs_ = emscripten::val::object();
for (int t : inputs_) {
std::string name = std::to_string(t);
auto input_size = context->tensors[t].bytes / 4;
auto input_data = context->tensors[t].data.f;
emscripten::val view{ emscripten::typed_memory_view(input_size, input_data) };
graph_inputs_.set(name, view);
graph_inputs_[name].call<void>("set", view);
}

graph_outputs_ = emscripten::val::object();
for (int t : outputs_) {
std::string name = std::to_string(t);
auto output_size = context->tensors[t].bytes / 4;
auto output_data = context->tensors[t].data.f;
emscripten::val view{emscripten::typed_memory_view(output_size, output_data)};
graph_outputs_.set(name, view);
output_views.insert({name, view});
}
}

const emscripten::val results =
wnn_context_
.call<emscripten::val>("compute", wnn_graph_, graph_inputs_, graph_outputs_)
.await();

if (any_pointers_changed) {
// Copy the outputs from pre-allocated ArrayBuffers back to the Wasm ArrayBuffer.
for (int t : outputs_) {
std::string name = std::to_string(t);
emscripten::val view = output_views.at(name);
view.call<void>("set", results["outputs"][name]);
}
}

wnn_context_.call<void>("computeSync", wnn_graph_, graph_inputs_, graph_outputs_);
// WebNN compute() method would return the input and output buffers via the
// promise resolution. Reuse the buffers to avoid additional allocation.
graph_inputs_ = results["inputs"];
graph_outputs_ = results["outputs"];

return kTfLiteOk;
}
Expand Down Expand Up @@ -3692,16 +3721,21 @@ class Subgraph {
}

private:
// Private constructor: wraps an already-built WebNN graph together with the
// tensor index sets and the JS buffer objects used when computing it.
//
//   context        WebNN context handle, stored as wnn_context_.
//   graph          Built WebNN graph handle, stored as wnn_graph_.
//   inputs         Tensor indices fed to the graph; moved into inputs_.
//   outputs        Tensor indices produced by the graph; moved into outputs_.
//   graph_inputs   JS object of input buffers keyed by tensor-index string.
//   graph_outputs  JS object of output buffers keyed by tensor-index string.
//
// The buffer objects handed in must be kept as-is: the invoke path writes into
// graph_inputs_[name] via "set", so they must not be replaced with empty
// objects here.
Subgraph(emscripten::val context, emscripten::val graph,
         std::unordered_set<int>&& inputs, std::unordered_set<int>&& outputs,
         emscripten::val graph_inputs, emscripten::val graph_outputs)
    : wnn_context_(std::move(context)),
      wnn_graph_(std::move(graph)),
      // A named rvalue-reference parameter is an lvalue; without std::move
      // these would copy the sets instead of taking ownership of them.
      inputs_(std::move(inputs)),
      outputs_(std::move(outputs)),
      graph_inputs_(std::move(graph_inputs)),
      graph_outputs_(std::move(graph_outputs)) {
  // Register every input/output tensor index in externals_ with a null
  // pointer. Iterate the members, since the parameters have been moved from.
  for (const int input_index : inputs_) {
    externals_[input_index] = nullptr;
  }
  for (const int output_index : outputs_) {
    externals_[output_index] = nullptr;
  }
}

emscripten::val wnn_context_ = emscripten::val::object();
Expand Down Expand Up @@ -4147,7 +4181,7 @@ TfLiteDelegate* TfLiteWebNNDelegateCreate(
emscripten::val(power_preference_name));
context_options.set("numThreads", options->numThreads);
emscripten::val wnn_context =
ml.call<emscripten::val>("createContextSync", context_options);
ml.call<emscripten::val>("createContext", context_options).await();

if (!wnn_context.as<bool>()) {
TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_ERROR,
Expand Down