Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

import torch
Expand Down Expand Up @@ -239,22 +245,22 @@
long col = 0;
for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) {
auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col);
auto tmp1 = tmp0 * vec_sum_scale;
auto tmp2 = tmp1.round();
auto tmp3 = tmp2 + vec_beta1;
auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1);
auto tmp3 = tmp1.round();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also apply the optimization to the below masked vectorization part?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. Thanks.

auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
store(tmp_out + col, tmp4);
auto tmp6 = at::vec::convert<int32_t>(tmp4);
auto tmp7 = at::vec::convert<scalar_t>(tmp6);
tmp7.store(tmp_out + col, vec_size);
vec_tmp_sum += tmp6;
}
if (col < kvBlockSize) {
auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col, kvBlockSize - col);
auto tmp1 = tmp0 * vec_sum_scale;
auto tmp2 = tmp1.round();
auto tmp3 = tmp2 + vec_beta1;
auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1);
auto tmp3 = tmp1.round();
auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
store(tmp_out + col, tmp4, kvBlockSize - col);
auto tmp6 = at::vec::convert<int32_t>(tmp4);
auto tmp7 = at::vec::convert<scalar_t>(tmp6);
tmp7.store(tmp_out + col, kvBlockSize - col);
vec_tmp_sum = at::vec::Vectorized<int32_t>::set(vec_tmp_sum, vec_tmp_sum + tmp6, kvBlockSize - col);
}
sum_a_ptr[row] += vec_tmp_sum.reduce_add() * beta2;
Expand Down Expand Up @@ -341,17 +347,15 @@
long col = 0;
for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) {
auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col);
auto tmp1 = tmp0 * vec_sum_scale;
auto tmp2 = tmp1.round();
auto tmp3 = tmp2 + vec_beta1;
auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1);
auto tmp3 = tmp1.round();
auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
store(tmp_out + col, tmp4);
}
if (col < kvBlockSize) {
auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col, kvBlockSize - col);
auto tmp1 = tmp0 * vec_sum_scale;
auto tmp2 = tmp1.round();
auto tmp3 = tmp2 + vec_beta1;
auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1);
auto tmp3 = tmp1.round();
auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
store(tmp_out + col, tmp4, kvBlockSize - col);
}
Expand Down Expand Up @@ -406,9 +410,8 @@
auto tmp2 = tmp1 - vec_sum_a;
auto tmp3 = tmp2 + vec_beta1;
auto tmp4 = at::vec::convert<float>(tmp3);
auto tmp5 = tmp4 * vec_alpha;
auto tmp6 = tmp5.round();
auto tmp7 = tmp6 + vec_beta2;
auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2);
auto tmp7 = tmp5.round();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. Thanks.

auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val);
store(tmp_out + col, tmp8);
}
Expand All @@ -419,9 +422,8 @@
auto tmp2 = tmp1 - vec_sum_a;
auto tmp3 = tmp2 + vec_beta1;
auto tmp4 = at::vec::convert<float>(tmp3);
auto tmp5 = tmp4 * vec_alpha;
auto tmp6 = tmp5.round();
auto tmp7 = tmp6 + vec_beta2;
auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2);
auto tmp7 = tmp5.round();
auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val);
store(tmp_out + col, tmp8, N - col);
}
Expand Down Expand Up @@ -463,19 +465,17 @@
auto tmp3 = tmp1 - vec_sum_a;
// auto tmp3 = tmp2 + vec_beta1;
auto tmp4 = at::vec::convert<float>(tmp3);
auto tmp5 = tmp4 * vec_alpha;
auto tmp6 = tmp5.round();
auto tmp7 = tmp6 + vec_beta2;
auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2);
auto tmp7 = tmp5.round();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. Thanks.

auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val);
store(tmp_out + col, tmp8);
}
if (col < N) {
auto tmp1 = at::vec::Vectorized<int32_t>::loadu(tmp_in + col, N - col);
auto tmp3 = tmp1 - vec_sum_a;
auto tmp4 = at::vec::convert<float>(tmp3);
auto tmp5 = tmp4 * vec_alpha;
auto tmp6 = tmp5.round();
auto tmp7 = tmp6 + vec_beta2;
auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2);
auto tmp7 = tmp5.round();
auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val);
store(tmp_out + col, tmp8, N - col);
}
Expand Down Expand Up @@ -1384,7 +1384,7 @@
q_sum_ptr, static_cast<int32_t>(0), qSplitSize);
{%- endif %}
const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1;

for (int64_t l = 0; l < rkvSlice; l++) {
int64_t n = l * kvSplitSize;
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
Expand Down
Loading