Skip to content

Commit 19e4d69

Browse files
committed
feat: enable parallel execution policy in scan_cpu for improved performance
1 parent adfb329 commit 19e4d69

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

torchlpc/csrc/scan_cpu.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <torch/torch.h>
44

55
#include <algorithm>
6+
#include <execution>
67
#include <utility>
78
#include <vector>
89

@@ -58,15 +59,16 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
5859
const scalar_t *weights_ptr = weights_contiguous.const_data_ptr<scalar_t>();
5960
scalar_t *output_ptr = output.mutable_data_ptr<scalar_t>();
6061

61-
std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer,
62-
[](const scalar_t &a, const scalar_t &b) {
62+
std::transform(std::execution::par, weights_ptr, weights_ptr + total_size,
63+
input_ptr, buffer, [](const scalar_t &a, const scalar_t &b) {
6364
return std::make_pair(a, b);
6465
});
6566

6667
at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end) {
6768
for (auto b = start; b < end; b++) {
6869
std::inclusive_scan(
69-
buffer + b * T, buffer + (b + 1) * T, buffer + b * T,
70+
std::execution::par, buffer + b * T, buffer + (b + 1) * T,
71+
buffer + b * T,
7072
[](const std::pair<scalar_t, scalar_t> &a,
7173
const std::pair<scalar_t, scalar_t> &b) {
7274
return std::make_pair(a.first * b.first,
@@ -77,7 +79,7 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
7779
});
7880

7981
std::transform(
80-
buffer, buffer + total_size, output_ptr,
82+
std::execution::par, buffer, buffer + total_size, output_ptr,
8183
[](const std::pair<scalar_t, scalar_t> &a) { return a.second; });
8284
}
8385

0 commit comments

Comments
 (0)