diff --git a/tensorflow/lite/delegates/webnn/webnn_delegate.cc b/tensorflow/lite/delegates/webnn/webnn_delegate.cc index c01dbb025184ea..7cb89c57e9353e 100644 --- a/tensorflow/lite/delegates/webnn/webnn_delegate.cc +++ b/tensorflow/lite/delegates/webnn/webnn_delegate.cc @@ -315,6 +315,15 @@ class Subgraph { } } + + emscripten::val graph_inputs = emscripten::val::object(); + for (auto& i : compute_inputs) { + std::string name = std::to_string(i); + auto input_size = context->tensors[i].bytes / 4; + graph_inputs.set(name, emscripten::val::global("Float32Array").new_(input_size)); + } + + emscripten::val graph_outputs = emscripten::val::object(); emscripten::val named_operands = emscripten::val::object(); for (auto o : outputs) { std::string name = std::to_string(o); @@ -322,15 +331,19 @@ class Subgraph { TF_LITE_KERNEL_LOG(context, "Invalid operand"); return nullptr; } + auto output_size = context->tensors[o].bytes / 4; + graph_outputs.set(name, emscripten::val::global("Float32Array").new_(output_size)); named_operands.set(name, webnn_operands.at(o)); } - emscripten::val wnn_graph = wnn_builder.call("buildSync", named_operands); + emscripten::val wnn_graph = wnn_builder.call("build", named_operands).await(); if (!wnn_graph.as()) { TF_LITE_KERNEL_LOG(context, "failed to build WebNN graph"); return nullptr; } - return new Subgraph(delegate->wnn_context_, wnn_graph, std::move(compute_inputs), std::move(outputs)); + return new Subgraph(delegate->wnn_context_, wnn_graph, + std::move(compute_inputs), std::move(outputs), + graph_inputs, graph_outputs); } TfLiteStatus Prepare(TfLiteContext* context) { return kTfLiteOk; } @@ -356,27 +369,43 @@ class Subgraph { } } + std::unordered_map output_views; if (any_pointers_changed) { - graph_inputs_ = emscripten::val::object(); for (int t : inputs_) { std::string name = std::to_string(t); auto input_size = context->tensors[t].bytes / 4; auto input_data = context->tensors[t].data.f; emscripten::val view{ emscripten::typed_memory_view(input_size, input_data) }; - graph_inputs_.set(name, view); + graph_inputs_[name].call("set", view); } - graph_outputs_ = emscripten::val::object(); for (int t : outputs_) { std::string name = std::to_string(t); auto output_size = context->tensors[t].bytes / 4; auto output_data = context->tensors[t].data.f; emscripten::val view{emscripten::typed_memory_view(output_size, output_data)}; - graph_outputs_.set(name, view); + output_views.insert({name, view}); + } + } + + const emscripten::val results = + wnn_context_ + .call("compute", wnn_graph_, graph_inputs_, graph_outputs_) + .await(); + + if (any_pointers_changed) { + // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm ArrayBuffer. + for (int t : outputs_) { + std::string name = std::to_string(t); + emscripten::val view = output_views.at(name); + view.call("set", results["outputs"][name]); } } - wnn_context_.call("computeSync", wnn_graph_, graph_inputs_, graph_outputs_); + // WebNN compute() method would return the input and output buffers via the + // promise resolution. Reuse the buffers to avoid additional allocation. + graph_inputs_ = results["inputs"]; + graph_outputs_ = results["outputs"]; return kTfLiteOk; } @@ -3692,16 +3721,21 @@ class Subgraph { } private: - Subgraph(emscripten::val context, emscripten::val graph, std::unordered_set&& inputs, std::unordered_set&& outputs) - : wnn_context_(context), wnn_graph_(graph), inputs_(inputs), outputs_(outputs) { + Subgraph(emscripten::val context, emscripten::val graph, + std::unordered_set&& inputs, std::unordered_set&& outputs, + emscripten::val graph_inputs, emscripten::val graph_outputs) + : wnn_context_(context), + wnn_graph_(graph), + inputs_(inputs), + outputs_(outputs), + graph_inputs_(graph_inputs), + graph_outputs_(graph_outputs) { for (auto& i : inputs_) { externals_[i] = nullptr; } for (auto& o : outputs_) { externals_[o] = nullptr; } - graph_inputs_ = emscripten::val::object(); - graph_outputs_ = emscripten::val::object(); } emscripten::val wnn_context_ = emscripten::val::object(); @@ -4147,7 +4181,7 @@ TfLiteDelegate* TfLiteWebNNDelegateCreate( emscripten::val(power_preference_name)); context_options.set("numThreads", options->numThreads); emscripten::val wnn_context = - ml.call("createContextSync", context_options); + ml.call("createContext", context_options).await(); if (!wnn_context.as()) { TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_ERROR,