This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit 825729c

cpuhrsch authored and facebook-github-bot committed
20210526 nestedtensor import
Reviewed By: bhosmer

Differential Revision: D28712349

fbshipit-source-id: 559595fae8f3cff2cfc63fa3c4def5d402a1877d
1 parent ea4c8fa commit 825729c

File tree

7 files changed: +81 -160 lines changed


benchmarks/mha_cuda.py

Lines changed: 29 additions & 11 deletions
@@ -4,8 +4,8 @@


 @torch.inference_mode()
-def benchmark_torch_function(iters, f, *args):
-    f(*args)
+def benchmark_torch_function(iters, f, *args, **kwargs):
+    f(*args, **kwargs)
     if torch.cuda.is_available():
         torch.cuda.synchronize()
         start_event = torch.cuda.Event(enable_timing=True)
@@ -14,7 +14,7 @@ def benchmark_torch_function(iters, f, *args):
     else:
         t0 = time.time()
     for _ in range(iters):
-        f(*args)
+        f(*args, **kwargs)
     if torch.cuda.is_available():
         end_event.record()
         torch.cuda.synchronize()
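
For reference, a minimal self-contained sketch of the CUDA-event timing pattern this helper uses; the return and unit-conversion step is an assumption, since it falls outside the hunks shown above:

import time
import torch

def benchmark_sketch(iters, f, *args, **kwargs):
    # Warm-up call outside the timed region.
    f(*args, **kwargs)
    if torch.cuda.is_available():
        # CUDA kernels launch asynchronously; events measure device time.
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
    else:
        t0 = time.time()
    for _ in range(iters):
        f(*args, **kwargs)
    if torch.cuda.is_available():
        end_event.record()
        torch.cuda.synchronize()
        # elapsed_time returns milliseconds; convert to microseconds.
        return start_event.elapsed_time(end_event) * 1e3
    return (time.time() - t0) * 1e6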
@@ -38,15 +38,33 @@ def run(bdim, embedding_dim, nhead, min_t, max_t, iters, device):
     nt = nestedtensor.nested_tensor(tensors, device=device, dtype=torch.float)

     # Create MHA with self-attention in mind
-    lin = torch.nn.MultiheadAttention(embedding_dim, nhead).to(device).eval()
-    nt_time = benchmark_torch_function(iters, lin, nt, nt, nt)
-    # import sys; sys.exit(1)
+    mha = torch.nn.MultiheadAttention(embedding_dim, nhead).to(device).eval()
+
+    # Create regular padded Tensor with corresponding mask
+    data, mask = nt.to_tensor_mask(mask_dim=2)
+    # Prepare input for torch.nn.MHA, which is batch second for Tensor input
+    data = data.transpose(0, 1)
+    not_mask = torch.logical_not(mask)
+
+    # Comparison test to show correctness and API differences
+    with torch.inference_mode():
+        nt_output, _ = mha(nt, nt, nt, need_weights=False)
+        t_output, _ = mha(data, data, data, key_padding_mask=not_mask, need_weights=False)
+    nt_output_padded = nt_output.to_padded_tensor(padding=0)
+    t_output = t_output.transpose(0, 1)
+    # Fill in zero for masked-out values to enable comparison
+    t_output = t_output * mask.unsqueeze(-1)
+    # Tolerances taken from torch/testing/_core.py
+    assert torch.isclose(nt_output_padded, t_output, rtol=1e-4, atol=1e-5).all().item()
+
+    # Time NT version
+    nt_time = benchmark_torch_function(iters, mha, nt, nt, nt, need_weights=False)

-    # Created regular padded Tensor
-    data = nt.to_padded_tensor(padding=0)
     # Amount of storage used for padding only
     percentage_padded = 100 * (data.numel() - nt.numel()) / data.numel()
-    t_time = benchmark_torch_function(iters, lin, data, data, data)
+
+    # Time Tensor version
+    t_time = benchmark_torch_function(iters, mha, data, data, data, key_padding_mask=not_mask, need_weights=False)

     print(f"batch size: {bdim:4.0f}, embedding dim: {embedding_dim}, nhead: {nhead}, T mean:{lengths_mean:5.0f}, T std: {lengths_std:4.0f}", end='')
     print(f", padding: {percentage_padded:3.0f}%, NT: {nt_time/iters:4.0f}us, T: {t_time/iters:4.0f}us, Speedup: {t_time/nt_time:3.2f}x")
@@ -56,10 +74,10 @@ def run(bdim, embedding_dim, nhead, min_t, max_t, iters, device):
 if torch.cuda.is_available():
     print("CUDA device: ", torch.cuda.get_device_name(0))
     device = torch.device('cuda')
-    iters = 1000
+    iters = 10
     for nhead in [2, 4, 8]:
         print("")
-        for embed_dim in [128, 256, 512, 1024]:
+        for embed_dim in [1024, 512, 256, 128]:
             print("")
             for min_t, max_t in [(16, 128), (32, 128), (64, 128), (128, 128)]:
                 run(256, embed_dim, nhead, min_t, max_t, iters, device)
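
One subtlety in the comparison above: nt.to_tensor_mask returns a mask that is True for valid positions, while torch.nn.MultiheadAttention's key_padding_mask expects True for positions to ignore, hence the torch.logical_not. A small sketch of that convention, assuming the nestedtensor package built from this repo:

import torch
import nestedtensor

nt = nestedtensor.nested_tensor(
    [torch.randn(5, 8), torch.randn(3, 8)], dtype=torch.float)
data, mask = nt.to_tensor_mask(mask_dim=2)  # mask is True where data is real
key_padding_mask = torch.logical_not(mask)  # True where data is padding
print(key_padding_mask[1])                  # masks the two padded positions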

nestedtensor/csrc/cuda/mha.cpp

Lines changed: 21 additions & 43 deletions
@@ -19,18 +19,6 @@ using namespace at;
 namespace torch {
 namespace nested_tensor {

-at::Tensor _sequence_mask(at::Tensor lengths) {
-  int64_t batch_size = lengths.numel();
-  int64_t max_len = lengths.max().item<int64_t>();
-  at::Tensor mask = torch::arange(0, max_len, torch::kFloat);
-  mask = mask.repeat({batch_size, 1});
-  mask = mask.lt(lengths.unsqueeze(1));
-  mask = mask.to(torch::kCUDA);
-  mask = mask.view({-1, 1, 1, max_len});
-  at::Tensor m2 = mask.transpose(2, 3);
-  return mask * m2;
-}
-
 at::Tensor bt_min_mha(
     int64_t num_heads,
     int64_t head_dim,
@@ -39,12 +27,8 @@ at::Tensor bt_min_mha(
     at::Tensor query,
     at::Tensor key,
     at::Tensor value,
-    at::Tensor attr_kernel_Q,
-    at::Tensor attr_kernel_K,
-    at::Tensor attr_kernel_V,
-    at::Tensor attr_bias_Q,
-    at::Tensor attr_bias_K,
-    at::Tensor attr_bias_V,
+    at::Tensor attr_kernel,
+    at::Tensor attr_bias,
     double scaling,
     at::Tensor out_proj_weight,
     at::Tensor out_proj_bias) {
@@ -90,9 +74,8 @@ at::Tensor bt_min_mha(
   TORCH_CHECK(query_esize.height() == 1, "Query nested dim isn't 1.");
   auto query_esize_sizes = query_esize.sizes();

-  at::Tensor attr_mask = _sequence_mask(
-      at::native::select(query_esize_sizes, 1, 0).contiguous());
-  attr_mask = attr_mask.to(float_options);
+  at::Tensor attr_mask = input_mask.view({-1, 1, 1, seq_len}).to(float_options);
+  attr_mask = attr_mask * attr_mask.transpose(2, 3);

   nteffectivetransformer::exclusiveScan_kernelLauncher(
       prefix_sum_ptr,
@@ -111,31 +94,29 @@ at::Tensor bt_min_mha(
       (int32_t)(embedding_dim),
       defaultStream);

-  // std::cout << "input_mask: " << input_mask << std::endl;
-  // std::cout << "prefix_sum: " << prefix_sum << std::endl;
-  // std::cout << "batch_idx: " << batch_idx << std::endl;
-  // std::cout << "word_idx: " << word_idx << std::endl;
-
-  at::Tensor q, k, v;
-  q = at::addmm(attr_bias_Q, query, attr_kernel_Q.t());
-  k = at::addmm(attr_bias_K, key, attr_kernel_K.t());
-  v = at::addmm(attr_bias_V, value, attr_kernel_V.t());
-  at::Tensor q_buf = get_buffer(q);
-  at::Tensor k_buf = get_buffer(k);
-  at::Tensor v_buf = get_buffer(v);
-
-  int valid_word_num = prefix_sum.reshape({-1})[word_num - 1].item<int>();
-  int last_mask = input_mask.reshape({-1})[word_num - 1].item<int>();
-  if (last_mask == 1) {
-    valid_word_num++;
-  }
+  at::Tensor packed = at::matmul(query, attr_kernel.t());
+  at::Tensor packed_buf = get_buffer(packed).contiguous().reshape({-1, 3 * embedding_dim});
+  std::vector<at::Tensor> packed_chunks = packed_buf.chunk(3, -1);
+  at::Tensor q_buf = packed_chunks[0].contiguous().reshape({-1});
+  at::Tensor k_buf = packed_chunks[1].contiguous().reshape({-1});
+  at::Tensor v_buf = packed_chunks[2].contiguous().reshape({-1});
+
+  int valid_word_num = get_numel(query) / embedding_dim;

   at::Tensor query_buf = torch::zeros(
       {batch_size, head_num, seq_len, size_per_head}, float_options);
   at::Tensor key_buf = torch::zeros(
       {batch_size, head_num, seq_len, size_per_head}, float_options);
   at::Tensor val_buf = torch::zeros(
       {batch_size, head_num, seq_len, size_per_head}, float_options);
+  at::Tensor attr_out =
+      torch::zeros({valid_word_num, embedding_dim}, float_options);
+
+  std::vector<at::Tensor> bias_chunks = attr_bias.chunk(3);
+  at::Tensor attr_bias_Q = bias_chunks[0];
+  at::Tensor attr_bias_K = bias_chunks[1];
+  at::Tensor attr_bias_V = bias_chunks[2];
+
   nteffectivetransformer::cuda::add_QKV_bias_padding_kernelLauncher<float>(
       q_buf.data_ptr<float>(),
       attr_bias_Q.data_ptr<float>(),
@@ -169,8 +150,6 @@ at::Tensor bt_min_mha(

   auto attn_output = at::matmul(attn_output_weights, val_buf);

-  at::Tensor attr_out =
-      torch::zeros({valid_word_num, embedding_dim}, float_options);
   nteffectivetransformer::cuda::transpose_rm_padding_kernelLauncher<float>(
       attn_output.data_ptr<float>(),
       attr_out.data_ptr<float>(),
@@ -184,7 +163,6 @@ at::Tensor bt_min_mha(
       defaultStream);

   // TODO: Bias is variably sized, need to add support for that.
-  // result = at::addmm(out_proj_bias, attr_out, out_proj_weight.t());
   at::Tensor result = at::matmul(attr_out, out_proj_weight.t());
   result = result.reshape({-1});
   return wrap_buffer(
@@ -195,7 +173,7 @@ at::Tensor bt_min_mha(

 TORCH_LIBRARY_FRAGMENT(nestedtensor, m) {
   m.def(
-      "bt_min_mha(int num_heads, int head_dim, float dropout_p, bool training, Tensor query, Tensor key, Tensor value, Tensor attr_kernel_Q, Tensor attr_kernel_K, Tensor attr_kernel_V, Tensor attr_bias_Q, Tensor attr_bias_K, Tensor attr_bias_V, float scaling, Tensor out_proj_weight, Tensor out_proj_bias) -> Tensor");
+      "bt_min_mha(int num_heads, int head_dim, float dropout_p, bool training, Tensor query, Tensor key, Tensor value, Tensor attr_kernel, Tensor attr_bias, float scaling, Tensor out_proj_weight, Tensor out_proj_bias) -> Tensor");
   m.impl("bt_min_mha", NestedTensorKey, &bt_min_mha);
 }
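The main change to bt_min_mha above is the packed projection: one matmul against the fused [3 * embedding_dim, embedding_dim] kernel followed by a chunk, instead of three separate addmms. A plain-PyTorch sketch of the equivalence (variable names here are illustrative, not the C++ ones; bias addition happens later in the CUDA kernel):

import torch

embed_dim, num_words = 8, 4
x = torch.randn(num_words, embed_dim)
attr_kernel = torch.randn(3 * embed_dim, embed_dim)  # fused Q/K/V weight
attr_bias = torch.randn(3 * embed_dim)               # fused Q/K/V bias

packed = torch.matmul(x, attr_kernel.t())            # [num_words, 3 * embed_dim]
q, k, v = packed.chunk(3, dim=-1)
b_q, b_k, b_v = attr_bias.chunk(3)

# Matches the old per-projection addmm path.
w_q, w_k, w_v = attr_kernel.chunk(3)
assert torch.allclose(q + b_q, torch.addmm(b_q, x, w_q.t()), atol=1e-5)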

nestedtensor/csrc/matmul.cpp

Lines changed: 0 additions & 49 deletions
@@ -49,56 +49,7 @@ Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other) {
       other);
 }

-Tensor NestedTensor_addmm(
-    const Tensor& bias,
-    const Tensor& input,
-    const Tensor& weight,
-    const c10::Scalar& alpha,
-    const c10::Scalar& beta) {
-  if (!is_nested_tensor_impl(bias) && is_nested_tensor_impl(input) &&
-      !is_nested_tensor_impl(weight)) {
-    if (get_is_contiguous(input)) {
-      if (get_dim(bias) == 1 && get_dim(input) == 3 && get_dim(weight) == 2) {
-        auto input_opt_sizes = get_opt_sizes(input);
-        if (input_opt_sizes[2]) {
-          if (*input_opt_sizes[2] == weight.size(1)) {
-            Tensor input_buffer = get_buffer(input);
-            Tensor result_buffer =
-                at::addmm(
-                    bias,
-                    input_buffer.reshape({-1, weight.size(1)}),
-                    weight,
-                    alpha,
-                    beta)
-                    .reshape({-1});
-            int64_t weight_size_1 = weight.size(1);
-            EfficientSizeNode result_nested_size = map_efficient_size(
-                [weight_size_1](int64_t* data_ptr, int64_t size) {
-                  data_ptr[1] = weight_size_1;
-                },
-                get_efficient_nested_size(input));
-            EfficientSizeNode input_nested_stride =
-                get_efficient_nested_stride(input);
-            return wrap_buffer(
-                std::move(result_buffer),
-                result_nested_size,
-                input_nested_stride);
-          }
-        }
-      }
-    }
-  }
-  return map_nested_tensor(
-      [&alpha, &beta](at::Tensor bias, at::Tensor input, at::Tensor weight) {
-        return at::addmm(bias, input, weight, alpha, beta);
-      },
-      bias,
-      input,
-      weight);
-}
-
 TORCH_LIBRARY_IMPL(aten, NestedTensor, m) {
-  nt_impl(m, "addmm", NestedTensor_addmm);
   nt_impl(m, "matmul", NestedTensor_matmul);
 }
 } // namespace at
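
The deleted NestedTensor_addmm carried a fast path worth noting: when the nested tensor is contiguous and every component shares the same inner dimension, the projection can run as a single 2-D addmm over the flat buffer. A rough Python sketch of the idea (the C++ helpers like get_buffer/wrap_buffer are paraphrased here with cat/split):

import torch

weight = torch.randn(6, 4)   # out_features x in_features
bias = torch.randn(6)
components = [torch.randn(3, 4), torch.randn(5, 4)]  # ragged leading dims

# One addmm over the packed buffer instead of one per component.
flat = torch.cat([t.reshape(-1, 4) for t in components])
out = torch.addmm(bias, flat, weight.t())

# Split the result back along the ragged dimension.
outs = out.split([t.size(0) for t in components])
assert torch.allclose(outs[0], torch.addmm(bias, components[0], weight.t()))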

nestedtensor/csrc/mha.cpp

Lines changed: 12 additions & 10 deletions
@@ -39,21 +39,22 @@ at::Tensor min_mha(
   int64_t edim = *(opt_sizes[2]);

   at::Tensor q, k, v;
-  q = at::addmm(
-      at::slice(*in_proj_bias, 0, 0, edim).contiguous(),
+  q = at::matmul(
       query,
-      at::slice(in_proj_weight, 0, 0, edim).t().contiguous(),
-      scaling,
-      scaling);
-  k = at::addmm(
-      at::slice(*in_proj_bias, 0, edim, 2 * edim).contiguous(),
+      at::slice(in_proj_weight, 0, 0, edim).t().contiguous());
+  k = at::matmul(
       key,
       at::slice(in_proj_weight, 0, edim, 2 * edim).t().contiguous());
-  v = at::addmm(
-      at::slice(*in_proj_bias, 0, 2 * edim).contiguous(),
+  v = at::matmul(
       value,
       at::slice(in_proj_weight, 0, 2 * edim).t().contiguous());

+  q = q + at::slice(*in_proj_bias, 0, 0, edim).contiguous();
+  k = k + at::slice(*in_proj_bias, 0, edim, 2 * edim).contiguous();
+  v = v + at::slice(*in_proj_bias, 0, 2 * edim).contiguous();
+
+  q = q * torch::tensor(scaling);
+
   q = q.reshape({-1, -1, num_heads, head_dim}).transpose(1, 2);
   k = k.reshape({-1, -1, num_heads, head_dim}).transpose(1, 2);
   v = v.reshape({-1, -1, num_heads, head_dim}).transpose(1, 2);
@@ -62,7 +63,8 @@ at::Tensor min_mha(
   attn_output_weights = at::dropout(attn_output_weights, dropout_p, training);
   auto attn_output = at::matmul(attn_output_weights, v);
   attn_output = attn_output.transpose(1, 2).reshape({-1, -1, edim});
-  attn_output = at::addmm(out_proj_bias, attn_output, out_proj_weight.t());
+  attn_output = at::matmul(attn_output, out_proj_weight.t());
+  attn_output = attn_output + out_proj_bias;
   return attn_output;
 }

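Since min_mha previously folded both the bias add and the query scaling into a single addmm (alpha and beta both set to scaling), the rewrite above preserves the math: scaling * (x @ W.t() + b) equals addmm(b, x, W.t(), beta=scaling, alpha=scaling). A quick sketch of that identity:

import torch

x, w, b = torch.randn(4, 8), torch.randn(8, 8), torch.randn(8)
scaling = 8 ** -0.5

old = torch.addmm(b, x, w.t(), beta=scaling, alpha=scaling)
new = (torch.matmul(x, w.t()) + b) * scaling
assert torch.allclose(old, new, atol=1e-6)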
nestedtensor/nn/mha.py

Lines changed: 2 additions & 23 deletions
@@ -11,21 +11,6 @@
 # NT case query, key, value have nested_dim 1 and are of shape (bsz, tgt_len, embed_dim)


-def sequence_mask(lengths, max_len=None, is_2d=True):
-    batch_size = lengths.numel()
-    max_len = max_len or lengths.max()
-    mask = (torch.arange(0, max_len, device=lengths.device)
-            .type_as(lengths)
-            .repeat(batch_size, 1)
-            .lt(lengths.unsqueeze(1)))
-    if is_2d:
-        return mask
-    else:
-        mask = mask.view(-1, 1, 1, max_len)
-        m2 = mask.transpose(2, 3)
-        return mask * m2
-
-
 def multi_head_attention_forward(query,
                                  key,
                                  value,
@@ -77,21 +62,15 @@ def multi_head_attention_forward(query,
     scaling = float(head_dim) ** -0.5

     if query is key and key is value and in_proj_weight.is_cuda:
-        w_q, w_k, w_v = in_proj_weight.chunk(3)
-        b_q, b_k, b_v = in_proj_bias.chunk(3)
         return torch.ops.nestedtensor.bt_min_mha(num_heads,
                                                  head_dim,
                                                  0.5,
                                                  False,
                                                  query,
                                                  query,
                                                  query,
-                                                 w_q.contiguous(),
-                                                 w_k.contiguous(),
-                                                 w_v.contiguous(),
-                                                 b_q.contiguous(),
-                                                 b_k.contiguous(),
-                                                 b_v.contiguous(),
+                                                 in_proj_weight,
+                                                 in_proj_bias,
                                                  scaling,
                                                  out_proj_weight,
                                                  in_proj_bias), None
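
The removed sequence_mask helper (and its C++ twin _sequence_mask, removed above) built a [batch, 1, seq_len, seq_len] attention mask as the outer product of a per-position validity mask with itself, which is the same quantity the CUDA path now derives directly from input_mask. A sketch of that mask algebra:

import torch

lengths = torch.tensor([3, 5])
max_len = int(lengths.max())
valid = (torch.arange(max_len)
         .repeat(lengths.numel(), 1)
         .lt(lengths.unsqueeze(1)))           # [batch, seq_len], True if in-bounds
m = valid.view(-1, 1, 1, max_len).float()
attn_mask = m * m.transpose(2, 3)             # 1.0 where query and key both valid
print(attn_mask[0, 0])                        # 3x3 block of ones, zero elsewhere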

nestedtensor/version.py

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-__version__ = '0.1.4+bf12d17'
-git_version = 'bf12d17c3b7891c713cb16b7e36926b873813ceb'
+__version__ = '0.1.4+e1d384f'
+git_version = 'e1d384fea9d70a664b38a53768f82c81057a7d13'
 from nestedtensor import _C
 if hasattr(_C, 'CUDA_VERSION'):
     cuda = _C.CUDA_VERSION
