33#include < torch/torch.h>
44
55#include < algorithm>
6+ #include < execution>
67#include < utility>
78#include < vector>
89
@@ -58,15 +59,16 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
5859 const scalar_t *weights_ptr = weights_contiguous.const_data_ptr <scalar_t >();
5960 scalar_t *output_ptr = output.mutable_data_ptr <scalar_t >();
6061
61- std::transform (weights_ptr, weights_ptr + total_size, input_ptr, buffer ,
62- [](const scalar_t &a, const scalar_t &b) {
62+ std::transform (std::execution::par, weights_ptr, weights_ptr + total_size,
63+ input_ptr, buffer, [](const scalar_t &a, const scalar_t &b) {
6364 return std::make_pair (a, b);
6465 });
6566
6667 at::parallel_for (0 , n_batch, 1 , [&](int64_t start, int64_t end) {
6768 for (auto b = start; b < end; b++) {
6869 std::inclusive_scan (
69- buffer + b * T, buffer + (b + 1 ) * T, buffer + b * T,
70+ std::execution::par, buffer + b * T, buffer + (b + 1 ) * T,
71+ buffer + b * T,
7072 [](const std::pair<scalar_t , scalar_t > &a,
7173 const std::pair<scalar_t , scalar_t > &b) {
7274 return std::make_pair (a.first * b.first ,
@@ -77,7 +79,7 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
7779 });
7880
7981 std::transform (
80- buffer, buffer + total_size, output_ptr,
82+ std::execution::par, buffer, buffer + total_size, output_ptr,
8183 [](const std::pair<scalar_t , scalar_t > &a) { return a.second ; });
8284}
8385
0 commit comments