This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit ea4c8fa

20210525 nestedtensor import

cpuhrsch authored and facebook-github-bot committed

Reviewed By: datumbox
Differential Revision: D28679810
fbshipit-source-id: 60790f532183d3cfd05ab23158bcfa0c044c62b7
1 parent 744434a commit ea4c8fa

File tree

12 files changed: +371, -85 lines

benchmarks/mha.py

Lines changed: 7 additions, 16 deletions

@@ -53,7 +53,7 @@ def from_tensor_list(cls, tensor_list):
 MODEL = torch.nn.MultiheadAttention(NDIM, NHEAD).to(DEVICE).eval()
 
 
-def run_benchmark(bsz, mean_i, mean_j, var, autograd, writer):
+def run_benchmark(bsz, mean_i, mean_j, var, writer):
     RAND_INTS = [(int(random.gauss(mean_j, var)), int(
         random.gauss(mean_i, var))) for _ in range(bsz)]
     src_ = nestedtensor.nested_tensor(
@@ -70,39 +70,31 @@ def gen_t_loop_mha(src):
         src, mask = detr_nt_src.decompose()
         src = src.flatten(2).permute(2, 0, 1).contiguous()
         mask = mask.flatten(1).contiguous()
-        if autograd:
-            src.requires_grad_()
 
         def te():
-            if autograd:
-                MODEL(src, src, src, key_padding_mask=mask,
-                      need_weights=False)[0].sum()  # .backward()
             MODEL(src, src, src, key_padding_mask=mask,
                   need_weights=False)
 
         return te
 
     def gen_nt_mha(src):
         src = nestedtensor.nested_tensor([t.flatten(1).permute(
-            1, 0) for t in src], device=DEVICE, dtype=torch.float, requires_grad=False)
+            1, 0) for t in src], device=DEVICE, dtype=torch.float)
 
         def nt():
-            if autograd:
-                MODEL(src, src, src, need_weights=False)[
-                    0].sum()  # .backward()
             MODEL(src, src, src, need_weights=False)
 
         return nt
 
     result_t = {**utils.benchmark_fn(gen_t_loop_mha(src), 5.0, cuda=True), "bsz": bsz,
-                "sparsity": sparsity, "autograd": autograd, "var": var, "mean_i": mean_i, "mean_j": mean_j}
+                "sparsity": sparsity, "var": var, "mean_i": mean_i, "mean_j": mean_j}
     result_t["numel"] = sum([x.numel() for x in src_])
     result_t["numel_div_avg_us"] = result_t["numel"] / result_t["avg_us"]
     result_t["avg_ns_div_numel"] = result_t["avg_us"] / \
         result_t["numel"] * 1000
     writer.writerow(result_t)
     result_nt = {**utils.benchmark_fn(gen_nt_mha(src), 5.0, cuda=True),
-                 "bsz": bsz, "sparsity": 0.0, "autograd": autograd, "var": var, "mean_i": mean_i, "mean_j": mean_j}
+                 "bsz": bsz, "sparsity": 0.0, "var": var, "mean_i": mean_i, "mean_j": mean_j}
     result_nt["numel"] = sum([x.numel() for x in src_])
    result_nt["numel_div_avg_us"] = result_nt["numel"] / result_nt["avg_us"]
     result_nt["avg_ns_div_numel"] = result_nt["avg_us"] / \
@@ -115,10 +107,9 @@ def nt():
 torch.manual_seed(1011)
 writer = csv.DictWriter(sys.stdout, fieldnames=[
     "name", "avg_us", "std_us", "runs", "bsz", "sparsity",
-    "autograd", "var", "mean_i", "mean_j", "numel", "numel_div_avg_us",
+    "var", "mean_i", "mean_j", "numel", "numel_div_avg_us",
     "avg_ns_div_numel"])
 writer.writeheader()
 for var in [float(i) / 10 for i in range(0, 100, 50)]:
-    for autograd in [False]:
-        for batch_size in [2, 8, 16]:
-            run_benchmark(batch_size, 30, 30, var, autograd, writer)
+    for batch_size in [2, 8, 16]:
+        run_benchmark(batch_size, 30, 30, var, writer)
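
For orientation, a minimal runnable sketch (not part of this commit; lengths and dims are illustrative) of the padded path that gen_t_loop_mha times: variable-length sequences are packed into one dense (T, B, E) tensor plus a boolean key_padding_mask, which is what MODEL(src, src, src, key_padding_mask=mask, need_weights=False) consumes above.

import torch

lengths = [3, 5, 2]                    # hypothetical per-sequence lengths
embed_dim, nhead = 8, 2
seqs = [torch.rand(t, embed_dim) for t in lengths]

# Pad into a single (T, B, E) tensor and record which slots are padding.
max_t = max(lengths)
src = torch.zeros(max_t, len(seqs), embed_dim)
key_padding_mask = torch.ones(len(seqs), max_t, dtype=torch.bool)
for b, s in enumerate(seqs):
    src[: s.size(0), b] = s
    key_padding_mask[b, : s.size(0)] = False   # False marks real tokens

mha = torch.nn.MultiheadAttention(embed_dim, nhead).eval()
with torch.inference_mode():
    out, _ = mha(src, src, src, key_padding_mask=key_padding_mask,
                 need_weights=False)
print(out.shape)  # torch.Size([5, 3, 8])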

benchmarks/mha_cuda.py

Lines changed: 65 additions, 0 deletions

@@ -0,0 +1,65 @@
+import torch
+import time
+import nestedtensor
+
+
+@torch.inference_mode()
+def benchmark_torch_function(iters, f, *args):
+    f(*args)
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+    else:
+        t0 = time.time()
+    for _ in range(iters):
+        f(*args)
+    if torch.cuda.is_available():
+        end_event.record()
+        torch.cuda.synchronize()
+        return start_event.elapsed_time(end_event) * 1e3
+    else:
+        return (time.time() - t0) * 1e6
+
+
+def run(bdim, embedding_dim, nhead, min_t, max_t, iters, device):
+    import random
+    random.seed(1010)
+
+    # The following is meant to emulate the lengths of randomly sampled tokenized sentences
+    lengths = [random.randint(min_t, max_t) for _ in range(bdim)]
+    lengths_mean = torch.tensor(lengths, dtype=torch.float).mean().item()
+    lengths_std = torch.tensor(lengths, dtype=torch.float).std().item()
+
+    # List of sentence embeddings
+    tensors = [torch.rand(i, embedding_dim) for i in lengths]
+    # Create packed NestedTensor
+    nt = nestedtensor.nested_tensor(tensors, device=device, dtype=torch.float)
+
+    # Create MHA with self-attention in mind
+    lin = torch.nn.MultiheadAttention(embedding_dim, nhead).to(device).eval()
+    nt_time = benchmark_torch_function(iters, lin, nt, nt, nt)
+
+    # Create a regular padded Tensor
+    data = nt.to_padded_tensor(padding=0)
+    # Amount of storage used for padding only
+    percentage_padded = 100 * (data.numel() - nt.numel()) / data.numel()
+    t_time = benchmark_torch_function(iters, lin, data, data, data)
+
+    print(f"batch size: {bdim:4.0f}, embedding dim: {embedding_dim}, nhead: {nhead}, T mean:{lengths_mean:5.0f}, T std: {lengths_std:4.0f}", end='')
+    print(f", padding: {percentage_padded:3.0f}%, NT: {nt_time/iters:4.0f}us, T: {t_time/iters:4.0f}us, Speedup: {t_time/nt_time:3.2f}x")
+
+
+device = torch.device('cpu')
+if torch.cuda.is_available():
+    print("CUDA device: ", torch.cuda.get_device_name(0))
+    device = torch.device('cuda')
+iters = 1000
+for nhead in [2, 4, 8]:
+    print("")
+    for embed_dim in [128, 256, 512, 1024]:
+        print("")
+        for min_t, max_t in [(16, 128), (32, 128), (64, 128), (128, 128)]:
+            run(256, embed_dim, nhead, min_t, max_t, iters, device)
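
One detail worth noting about benchmark_torch_function above: torch.cuda.Event.elapsed_time reports milliseconds, so the * 1e3 yields microseconds, consistent with the CPU branch's (time.time() - t0) * 1e6. A standalone sketch of the same timing pattern:

import time
import torch

if torch.cuda.is_available():
    a = torch.rand(1024, 1024, device="cuda")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    a @ a
    end.record()
    torch.cuda.synchronize()                         # wait for the recorded events
    print(f"{start.elapsed_time(end) * 1e3:.0f}us")  # elapsed_time is in ms
else:
    a = torch.rand(1024, 1024)
    t0 = time.time()
    a @ a
    print(f"{(time.time() - t0) * 1e6:.0f}us")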

benchmarks/utils.py

Lines changed: 3 additions, 4 deletions

@@ -1,7 +1,5 @@
-from nestedtensor import torch
+import torch
 import time
-import random
-import pprint
 
 import cProfile
 import pstats
@@ -18,7 +16,8 @@ def gen_tensor():
     # return torch.tensor([globals()['SEED']])
     return torch.rand(EMBED_DIM)
 
-def benchmark_fn(fn, run_time = 5.0, use_cprofile=False, warmup=1.0, cuda=False):
+
+def benchmark_fn(fn, run_time=5.0, use_cprofile=False, warmup=1.0, cuda=False):
     times = []
     t = 0.0
     pr = cProfile.Profile()
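
For reference, benchmark_fn is consumed in benchmarks/mha.py as utils.benchmark_fn(gen_t_loop_mha(src), 5.0, cuda=True), and returns a stats dict whose keys include "avg_us", "std_us", and "runs". A hedged usage sketch (the import assumes benchmarks/ is the working directory; the workload is arbitrary):

import torch
import utils  # assumption: run from the benchmarks/ directory


def fn():
    torch.rand(256, 256) @ torch.rand(256, 256)


stats = utils.benchmark_fn(fn, run_time=1.0, cuda=False)
print(stats["avg_us"], stats["std_us"], stats["runs"])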

nestedtensor/csrc/cuda/mha.cpp

Lines changed: 26 additions, 6 deletions

@@ -19,12 +19,23 @@ using namespace at;
 namespace torch {
 namespace nested_tensor {
 
+at::Tensor _sequence_mask(at::Tensor lengths) {
+  int64_t batch_size = lengths.numel();
+  int64_t max_len = lengths.max().item<int64_t>();
+  at::Tensor mask = torch::arange(0, max_len, torch::kFloat);
+  mask = mask.repeat({batch_size, 1});
+  mask = mask.lt(lengths.unsqueeze(1));
+  mask = mask.to(torch::kCUDA);
+  mask = mask.view({-1, 1, 1, max_len});
+  at::Tensor m2 = mask.transpose(2, 3);
+  return mask * m2;
+}
+
 at::Tensor bt_min_mha(
     int64_t num_heads,
     int64_t head_dim,
     double dropout_p,
     bool training,
-    at::Tensor input_mask,
     at::Tensor query,
     at::Tensor key,
     at::Tensor value,
@@ -36,8 +47,7 @@ at::Tensor bt_min_mha(
     at::Tensor attr_bias_V,
     double scaling,
     at::Tensor out_proj_weight,
-    at::Tensor out_proj_bias,
-    at::Tensor attr_mask) {
+    at::Tensor out_proj_bias) {
   // TODO: Assert that max seq_len is 1024!
   TORCH_CHECK(get_dim(query) == 3, "query needs to be 3 dim.");
   TORCH_CHECK(get_dim(key) == 3, "key needs to be 3 dim.");
@@ -49,15 +59,17 @@ at::Tensor bt_min_mha(
   // }
   // TODO: Add explicit check that verifies query, key and value are the same
   // auto start = std::chrono::system_clock::now();
+  auto options =
+      torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
+  at::Tensor input_mask = to_mask(query, 2);
+  input_mask = input_mask.to(options);
   int64_t batch_size = input_mask.size(0);
   int64_t seq_len = input_mask.size(1);
   int64_t embedding_dim = head_dim * num_heads; //*(opt_sizes[2]);
   int64_t head_num = num_heads;
   int64_t size_per_head = embedding_dim / head_num;
   auto float_options =
       torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);
-  auto options =
-      torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
   at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
   at::cuda::setCurrentCUDAStream(defaultStream);
 
@@ -74,6 +86,14 @@ at::Tensor bt_min_mha(
 
   at::Tensor tmp = get_buffer(query);
 
+  auto query_esize = get_efficient_nested_size(query);
+  TORCH_CHECK(query_esize.height() == 1, "Query nested dim isn't 1.");
+  auto query_esize_sizes = query_esize.sizes();
+
+  at::Tensor attr_mask = _sequence_mask(
+      at::native::select(query_esize_sizes, 1, 0).contiguous());
+  attr_mask = attr_mask.to(float_options);
+
   nteffectivetransformer::exclusiveScan_kernelLauncher(
       prefix_sum_ptr,
       input_mask.data_ptr<int>(),
@@ -175,7 +195,7 @@ at::Tensor bt_min_mha(
 
 TORCH_LIBRARY_FRAGMENT(nestedtensor, m) {
   m.def(
-      "bt_min_mha(int num_heads, int head_dim, float dropout_p, bool training, Tensor input_mask, Tensor query, Tensor key, Tensor value, Tensor attr_kernel_Q, Tensor attr_kernel_K, Tensor attr_kernel_V, Tensor attr_bias_Q, Tensor attr_bias_K, Tensor attr_bias_V, float scaling, Tensor out_proj_weight, Tensor out_proj_bias, Tensor attr_mask) -> Tensor");
+      "bt_min_mha(int num_heads, int head_dim, float dropout_p, bool training, Tensor query, Tensor key, Tensor value, Tensor attr_kernel_Q, Tensor attr_kernel_K, Tensor attr_kernel_V, Tensor attr_bias_Q, Tensor attr_bias_K, Tensor attr_bias_V, float scaling, Tensor out_proj_weight, Tensor out_proj_bias) -> Tensor");
   m.impl("bt_min_mha", NestedTensorKey, &bt_min_mha);
 }
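
To make the new _sequence_mask helper concrete, here is an equivalent rendering in Python (exposition only; the .to(torch::kCUDA) step is omitted): the per-sequence lengths become a [batch, 1, 1, max_len] validity mask, and multiplying that mask by its own transpose produces the square [batch, 1, max_len, max_len] attention mask that replaces the removed attr_mask argument.

import torch

def sequence_mask(lengths: torch.Tensor) -> torch.Tensor:
    batch_size = lengths.numel()
    max_len = int(lengths.max())
    mask = torch.arange(0, max_len, dtype=torch.float)  # [max_len]
    mask = mask.repeat(batch_size, 1)                   # [batch, max_len]
    mask = mask.lt(lengths.unsqueeze(1))                # True where pos < length
    mask = mask.view(-1, 1, 1, max_len)                 # [batch, 1, 1, max_len]
    m2 = mask.transpose(2, 3)                           # [batch, 1, max_len, 1]
    return mask * m2                                    # [batch, 1, max_len, max_len]

# Entry [b, 0, i, j] is True iff both i and j are real positions of sequence b.
print(sequence_mask(torch.tensor([2, 3])))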

nestedtensor/csrc/masking.cpp

Lines changed: 138 additions, 2 deletions

@@ -79,7 +79,20 @@ std::vector<int64_t> _get_max_size(const SizeNode& size_node) {
   return result;
 }
 
-std::vector<int64_t> get_max_size(Tensor nt) {
+std::vector<int64_t> get_max_size(const Tensor& nt) {
+  if (get_nested_dim(nt) == 1) {
+    auto nt_opt_sizes = get_opt_sizes(nt);
+    if (nt_opt_sizes.size() > 0 && *nt_opt_sizes[0] > 0) {
+      auto esize = get_efficient_nested_size(nt);
+      auto sizes = esize.sizes();
+      auto max_sizes = std::get<0>(sizes.max(0));
+      std::vector<int64_t> result;
+      for (int64_t i = 0; i < max_sizes.size(0); i++) {
+        result.push_back(max_sizes[i].item<int64_t>());
+      }
+      return result;
+    }
+  }
   return _get_max_size(get_nested_size(nt));
 }
 
@@ -203,7 +216,6 @@ std::tuple<Tensor, Tensor> to_tensor_mask(
   auto nt_opt_size = get_opt_sizes(nt);
   Tensor nt_buffer = get_buffer(nt);
   if (nt_opt_size[2] && nt_buffer.is_cuda()) {
-    std::cout << "Calling efficient to_tensor_mask" << std::endl;
     Tensor nt_sizes_ =
         get_efficient_nested_size(nt).sizes().to(torch::kInt32);
     TORCH_CHECK(nt_sizes_.dim() == 2, "NestedTensor must be of nested_dim 2.")
@@ -258,6 +270,127 @@ std::tuple<Tensor, Tensor> to_tensor_mask(
   return merge_tensor_mask(res_tensor, res_mask, mask_dim);
 }
 
+Tensor merge_mask(
+    Tensor mask,
+    c10::optional<int64_t> mask_dim) {
+  if (mask_dim && get_dim(mask) == (*mask_dim)) {
+    return mask;
+  }
+
+  if (get_dim(mask) == 0) {
+    return mask;
+  }
+
+  int64_t last_size = mask.size(-1);
+  Tensor collapsed_mask = mask.sum(-1);
+  Tensor is_last_size = (collapsed_mask == last_size);
+  Tensor is_zero = (collapsed_mask == 0);
+  int64_t is_last_size_sum = is_last_size.sum().item<int64_t>();
+  int64_t is_zero_sum = is_zero.sum().item<int64_t>();
+  if ((is_last_size_sum + is_zero_sum) == get_numel(collapsed_mask)) {
+    collapsed_mask = collapsed_mask.to(torch::kBool);
+    return merge_mask(collapsed_mask, mask_dim);
+  }
+
+  if (mask_dim && mask_dim != get_dim(mask)) {
+    throw std::runtime_error(
+        "Mask dimension is too small to represent data tensor.");
+  }
+  // This is expected to be a no-op, except in rare cases.
+  mask = mask.contiguous();
+  return mask;
+}
+
+Tensor _create_nt_mask(std::vector<int64_t> sizes, std::vector<int64_t> shape) {
+  int64_t numel = 1;
+  for (size_t i = 0; i < sizes.size(); i++) {
+    numel = numel * sizes[i];
+  }
+  TORCH_CHECK(numel > 0, "Empty tensors are not yet supported.");
+  // Don't pad in case of a scalar
+  if (sizes.size() == 0) {
+    return torch::tensor(true);
+  }
+  auto options = torch::TensorOptions().dtype(torch::kByte);
+  Tensor mask = pad_tensor_to_shape(
+      torch::full(
+          IntArrayRef(sizes),
+          true,
+          options),
+      shape);
+  return mask;
+}
+
+Tensor _create_nt_mask(SizeNode nt_size, std::vector<int64_t> shape) {
+  if (nt_size.degree() == 0) {
+    return _create_nt_mask(nt_size.payload(), shape);
+  }
+
+  std::vector<Tensor> res_mask;
+  if (nt_size.degree() == 0) {
+    return torch::tensor({false}, torch::kByte);
+  } else {
+    for (auto child : nt_size.unbind()) {
+      Tensor mask = _create_nt_mask(child, shape);
+      res_mask.push_back(mask);
+    }
+  }
+
+  return at::stack(res_mask);
+}
+
+Tensor _create_nt_mask(EfficientSizeNode nt_size, std::vector<int64_t> shape) {
+  if (nt_size.height() == 1) {
+    std::vector<at::Tensor> tmp_masks;
+    auto esizes = nt_size.sizes();
+    int64_t* esizes_ptr = esizes.data_ptr<int64_t>();
+    for (int64_t i = 0; i < esizes.size(0); i++) {
+      std::vector<int64_t> tmp_sizes;
+      for (size_t j = 0; j < shape.size(); j++) {
+        tmp_sizes.push_back(esizes_ptr[i * esizes.stride(0) + j]);
+      }
+      tmp_masks.push_back(_create_nt_mask(tmp_sizes, shape));
+    }
+    return at::stack(tmp_masks);
+  }
+  return _create_nt_mask(nt_size.to_size_node(), shape);
+}
+
+Tensor to_mask(
+    Tensor nt,
+    c10::optional<int64_t> mask_dim) {
+  TORCH_CHECK(
+      !mask_dim || *mask_dim <= get_dim(nt),
+      "Requested mask dimension ",
+      *mask_dim,
+      " is bigger than dimension ",
+      get_dim(nt),
+      " of given NestedTensor.");
+
+  auto opt_sizes = get_opt_sizes(nt);
+  if (opt_sizes.size() == 1 && *opt_sizes[0] == 1) {
+    Tensor result_mask = !mask_dim || *mask_dim == 0 ? torch::tensor(true)
                                                      : torch::tensor({true});
+    return result_mask;
+  }
+
+  std::vector<int64_t> max_size;
+  if (get_nested_dim(nt) == 1 &&
+      get_dim(nt) > 1 &&
+      mask_dim &&
+      *mask_dim > 1) {
+    auto tmp_max_size = get_max_size(nt);
+    for (int64_t i = 1; i < *mask_dim; i++) {
+      max_size.push_back(tmp_max_size[i - 1]);
+    }
+    return _create_nt_mask(get_efficient_nested_size(nt), max_size);
+  }
+  max_size = get_max_size(nt);
+  at::Tensor res_mask = _create_nt_mask(get_efficient_nested_size(nt), max_size);
+  return merge_mask(res_mask, mask_dim);
+}
+
 Tensor to_padded_tensor(Tensor nt, double padding) {
 #ifdef WITH_CUDA
   if (get_dim(nt) == 3 && get_is_contiguous(nt)) {
@@ -315,6 +448,9 @@ TORCH_LIBRARY_FRAGMENT(nestedtensor, m) {
   m.def("to_tensor_mask(Tensor nt, int? mask_dim) -> (Tensor, Tensor)");
   m.impl("to_tensor_mask", NestedTensorKey, to_tensor_mask);
 
+  m.def("to_mask(Tensor nt, int? mask_dim) -> Tensor");
+  m.impl("to_mask", NestedTensorKey, to_mask);
+
   m.def("to_padded_tensor(Tensor nt, float padding) -> Tensor");
   m.impl("to_padded_tensor", NestedTensorKey, to_padded_tensor);
 }
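
Finally, the semantics of the new to_mask/merge_mask pair, sketched in Python (illustrative only, not the registered op): to_mask builds a boolean mask over the padded shape with True marking real entries, and merge_mask repeatedly collapses the trailing dimension as long as every row is uniformly all-True or all-False, returning the smallest mask that still represents the data (or exactly mask_dim dimensions when requested).

import torch

def to_mask_dense(sizes, padded_shape):
    # One boolean mask per NestedTensor component: True inside the
    # component's extent, False in the padding.
    masks = []
    for s in sizes:
        m = torch.zeros(*padded_shape, dtype=torch.bool)
        m[tuple(slice(0, d) for d in s)] = True
        masks.append(m)
    return torch.stack(masks)

def merge_mask(mask, mask_dim=None):
    if mask_dim is not None and mask.dim() == mask_dim:
        return mask
    if mask.dim() == 0:
        return mask
    collapsed = mask.sum(-1)
    last = mask.size(-1)
    # Collapse only if each row is entirely valid or entirely padding.
    if ((collapsed == last) | (collapsed == 0)).all():
        return merge_mask(collapsed.bool(), mask_dim)
    if mask_dim is not None and mask_dim != mask.dim():
        raise RuntimeError("Mask dimension is too small to represent data tensor.")
    return mask.contiguous()

full = to_mask_dense([(2, 4), (3, 4)], (3, 4))  # every row spans all 4 columns
print(merge_mask(full, mask_dim=2).shape)       # torch.Size([2, 3])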
