This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit fcf06d2

cpuhrsch authored and facebook-github-bot committed

20210520 nestedtensor import

Reviewed By: NicolasHug
Differential Revision: D28575245
fbshipit-source-id: d4853e8914772a42037490e4f87e575b959ba6d4

1 parent 7a07b06 commit fcf06d2

File tree

6 files changed: +87 -44

benchmarks/linear.py
nestedtensor/csrc/BinaryOps.cpp
nestedtensor/csrc/storage/EfficientSizeNode.h
nestedtensor/csrc/storage/List.h
nestedtensor/csrc/storage/Packed.h
nestedtensor/csrc/storage/StorageBase.h

benchmarks/linear.py

Lines changed: 56 additions & 38 deletions
@@ -1,43 +1,61 @@
 import torch
+import time
 import nestedtensor
-import utils
-
-import random
-random.seed(1010)
-
-BDIM=10
-
-# Performance tanks hard for lots of small Tensors as expected
-RAND_INTS = [random.randint(100, 300) for _ in range(BDIM)]
-
-OUTDIM=256
-GOALDIM=512
-
-TENSORS0 = [torch.rand(i, OUTDIM).cuda() for i in RAND_INTS]
-
-def gen_t_linear():
-    nt0 = nestedtensor.nested_tensor(TENSORS0, device=torch.device('cuda'), dtype=torch.float)
-    data, _ = nt0.to_tensor_mask()
-    lin = torch.nn.Linear(OUTDIM, GOALDIM).cuda()
-
-    def t():
-        lin(data)
-    return t
 
 
 @torch.inference_mode()
-def gen_nt_linear():
-    nt0 = nestedtensor.nested_tensor(TENSORS0, device=torch.device('cuda'), dtype=torch.float)
-    lin = torch.nn.Linear(OUTDIM, GOALDIM).cuda()
-
-    def nt():
-        lin(nt0)
-        # print("nt0.size()")
-        # print(nt0.size())
-        # import sys; sys.exit(1)
-    return nt
-
-
-if __name__ == "__main__":
-    print(utils.benchmark_fn(gen_t_linear()))
-    print(utils.benchmark_fn(gen_nt_linear()))
+def benchmark_torch_function(iters, f, *args):
+    f(*args)
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+    else:
+        t0 = time.time()
+    for _ in range(iters):
+        f(*args)
+    if torch.cuda.is_available():
+        end_event.record()
+        torch.cuda.synchronize()
+        return start_event.elapsed_time(end_event)
+    else:
+        return (time.time() - t0) * 1e3
+
+
+def run(bdim, embedding_dim, out_dim, min_t, max_t, iters, device):
+    import random
+    random.seed(1010)
+
+    # The following is meant to emulate the lengths of randomly sampled tokenized sentences
+    lengths = [random.randint(min_t, max_t) for _ in range(bdim)]
+    lengths_mean = torch.tensor(lengths, dtype=torch.float).mean().item()
+    lengths_std = torch.tensor(lengths, dtype=torch.float).std().item()
+
+    # List of sentence embeddings
+    tensors = [torch.rand(i, embedding_dim) for i in lengths]
+    # Create packed NestedTensor
+    nt = nestedtensor.nested_tensor(tensors, device=device, dtype=torch.float)
+    # Create regular padded Tensor
+    data = nt.to_padded_tensor(padding=0)
+    # Amount of storage used for padding only
+    percentage_padded = 100 * (data.numel() - nt.numel()) / data.numel()
+
+    # Project embeddings into another space
+    lin = torch.nn.Linear(embedding_dim, out_dim).to(device)
+    nt_time = benchmark_torch_function(iters, lin, nt)
+    t_time = benchmark_torch_function(iters, lin, data)
+
+    print(f"batch size: {bdim:4.0f}, embedding dim: {embedding_dim}, out_dim: {out_dim}, T mean:{lengths_mean:5.0f}, T std: {lengths_std:4.0f}", end='')
+    print(f", padding: {percentage_padded:3.0f}%, NT: {nt_time/iters:4.0f}ms, T: {t_time/iters:4.0f}ms, Speedup: {t_time/nt_time:3.2f}x")
+
+
+if torch.cuda.is_available():
+    print("CUDA device: ", torch.cuda.get_device_name(0))
+    iters = 10
+    for out_dim in [4096, 2048, 1024, 512, 256]:
+        print("")
+        for embed_dim in [4096, 2048, 1024, 512, 256]:
+            print("")
+            for min_t, max_t in [(16, 128), (32, 128), (64, 128), (128, 128)]:
+                run(256, embed_dim, out_dim, min_t, max_t, iters, torch.device('cuda'))
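The reported speedup comes from the packed layout doing no work on padding. A minimal plain-PyTorch sketch of the percentage_padded metric computed in run() above (toy lengths, no nestedtensor dependency; the names here are illustrative only):

import torch

lengths = [17, 42, 128]                  # toy sentence lengths
embedding_dim = 8
tensors = [torch.rand(t, embedding_dim) for t in lengths]

# Pad every sequence to the longest length, as to_padded_tensor(padding=0) does.
padded = torch.zeros(len(lengths), max(lengths), embedding_dim)
for i, t in enumerate(tensors):
    padded[i, : t.size(0)] = t

packed_numel = sum(t.numel() for t in tensors)
percentage_padded = 100 * (padded.numel() - packed_numel) / padded.numel()
print(f"padding overhead: {percentage_padded:.0f}%")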

nestedtensor/csrc/BinaryOps.cpp

Lines changed: 3 additions & 3 deletions
@@ -8,9 +8,8 @@ Tensor NestedTensor_add_Tensor(
     const Tensor& self_,
     const Tensor& other_,
     const Scalar& alpha) {
-  Tensor self;
-  Tensor other;
-  std::tie(self, other) = _expand_other_as(self_, other_);
+  Tensor self = self_;
+  Tensor other = other_;
   if (is_nested_tensor_impl(self) && is_nested_tensor_impl(other)) {
     EfficientSizeNode self_efficient_nested_size =
         get_efficient_nested_size(self);
@@ -49,6 +48,7 @@ Tensor NestedTensor_add_Tensor(
             get_efficient_nested_stride(self));
     }
   }
+  std::tie(self, other) = _expand_other_as(self_, other_);
   return map_nested_tensor(
       [&alpha](Tensor s, Tensor o) { return at::add(s, o, alpha); },
       self,
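For context, NestedTensor_add_Tensor backs elementwise addition on nested tensors; the change above defers _expand_other_as to the fallback path, so the nested-plus-nested fast path no longer pays for it. A hedged usage sketch of that fast path (assumes the prototype nestedtensor package is installed; shapes are illustrative):

import torch
import nestedtensor

# Two nested tensors with matching nested sizes hit the branch guarded by the
# is_nested_tensor_impl checks; _expand_other_as is only reached when no fast
# path applies.
a = nestedtensor.nested_tensor([torch.rand(3, 4), torch.rand(5, 4)])
b = nestedtensor.nested_tensor([torch.rand(3, 4), torch.rand(5, 4)])
c = a + b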

nestedtensor/csrc/storage/EfficientSizeNode.h

Lines changed: 19 additions & 3 deletions
@@ -40,7 +40,7 @@ inline std::vector<c10::optional<int64_t>> construct_efficient_size(
 }
 
 inline void _efficient_serialize(
-    SizeNode nested_node,
+    const SizeNode& nested_node,
     std::vector<int64_t>& out) {
   if (!nested_node.is_leaf()) {
     out.push_back(nested_node.degree());
@@ -50,7 +50,7 @@ inline void _efficient_serialize(
   }
 }
 
-inline std::vector<int64_t> efficient_serialize(SizeNode nested_node) {
+inline std::vector<int64_t> efficient_serialize(const SizeNode& nested_node) {
   std::vector<int64_t> out;
   _efficient_serialize(nested_node, out);
   return out;
@@ -85,7 +85,7 @@ inline SizeNode efficient_deserialize(
 } // namespace impl
 
 struct EfficientSizeNode {
-  explicit EfficientSizeNode(SizeNode size_node)
+  explicit EfficientSizeNode(const SizeNode& size_node)
       : _height(size_node.height()),
         _structure(impl::efficient_serialize(size_node)),
         _sizes(impl::stack_sizes(size_node)) {}
@@ -130,6 +130,22 @@ struct EfficientSizeNode {
   EfficientSizeNode clone() const {
     return EfficientSizeNode(_height, _structure, _sizes.clone());
   }
+  int64_t numel() const {
+    if (_sizes.dim() == 0 && _structure.size() > 0) {
+      return _structure[0];
+    }
+    if (_sizes.dim() > 0) {
+      Tensor nt_sizes = at::native::narrow(
+          _sizes, 1 /* dim */, 0 /* start */, 1 /* length */);
+      for (int64_t i = 1; i < _sizes.size(1); i++) {
+        Tensor tmp = at::native::narrow(
+            _sizes, 1 /* dim */, i /* start */, 1 /* length */);
+        nt_sizes = nt_sizes * tmp;
+      }
+      return nt_sizes.sum().item<int64_t>();
+    }
+    return 0;
+  }
 
  private:
   int64_t _height;
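The new EfficientSizeNode::numel() reads the stacked per-leaf size matrix (_sizes), multiplies the size columns row-wise via repeated narrow calls, and sums the products. A rough Python equivalent of that arithmetic (assuming _sizes is an N x ndim tensor of leaf sizes, which is what the loop over _sizes.size(1) implies):

import torch

# Each row is the size of one leaf tensor, e.g. sequences of lengths
# 3, 5 and 2 with embedding dimension 4.
sizes = torch.tensor([[3, 4], [5, 4], [2, 4]])

# Product across each row, then sum over rows: 12 + 20 + 8 = 40 elements.
numel = sizes.prod(dim=1).sum().item()
print(numel)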

nestedtensor/csrc/storage/List.h

Lines changed: 3 additions & 0 deletions
@@ -61,6 +61,9 @@ struct ListStorage : public NestedTensorStorage {
     return get_first_leaf(_structure) ? get_first_leaf(_structure)->is_cuda()
                                       : false;
   }
+  int64_t numel() const override {
+    return _nested_size.numel();
+  }
 
  private:
   TensorNode _structure;

nestedtensor/csrc/storage/Packed.h

Lines changed: 3 additions & 0 deletions
@@ -173,6 +173,9 @@ struct PackedStorage : public NestedTensorStorage {
   bool is_cuda() const override {
     return _buffer.is_cuda();
   }
+  int64_t numel() const override {
+    return _nested_size.numel();
+  }
 
  private:
   at::Tensor _buffer;

nestedtensor/csrc/storage/StorageBase.h

Lines changed: 3 additions & 0 deletions
@@ -41,6 +41,9 @@ struct NestedTensorStorage {
   virtual bool is_cuda() const {
     TORCH_CHECK(false, "Not Implemented.");
   }
+  virtual int64_t numel() const {
+    TORCH_CHECK(false, "Not Implemented.");
+  }
 };
 } // namespace nested_tensor
 } // namespace torch
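Together with the List.h and Packed.h overrides, this virtual exposes the element count of either storage kind to Python, which is what the benchmark's percentage_padded line relies on. A hedged sketch of the user-visible effect (assumes the prototype nestedtensor package; values follow the 3x4 and 5x4 example tensors):

import torch
import nestedtensor

nt = nestedtensor.nested_tensor([torch.rand(3, 4), torch.rand(5, 4)])
padded = nt.to_padded_tensor(padding=0)

print(nt.numel())      # 3*4 + 5*4 = 32, served by the new storage numel() overrides
print(padded.numel())  # 2*5*4 = 40, the padded buffer is larger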
