Skip to content
This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit cf952fe

Browse files
cpuhrschfacebook-github-bot
authored andcommitted
20210516 nestedtensor import
Summary: Import from GH Reviewed By: datumbox Differential Revision: D28468749 fbshipit-source-id: 63f20f81a2585910d6a45312a050768b4d373632
1 parent 746a32b commit cf952fe

File tree

10 files changed

+294
-207
lines changed

10 files changed

+294
-207
lines changed

README.md

Lines changed: 29 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,42 @@ If you are here because you ran into a runtime error due to a missing feature or
66

77
If you are new to this project, we recommend you take a look at our [whirlwind introduction](https://colab.research.google.com/github/pytorch/nestedtensor/blob/master/tutorials/notebooks/basic.ipynb) to get started.
88

9-
## Operator support
9+
## Autograd support
1010

11-
Please see [the list of currently supported operators](https://github.com/pytorch/nestedtensor/blob/master/nestedtensor/csrc/README.md) and [open an issue](https://github.com/pytorch/nestedtensor/issues/new/choose) if you find you need one for your project that's not listed.
11+
Due to missing extensibility features of PyTorch nestedtensor currently lacks autograd support. We're actively working on this and recognize that it severely limits the applicability of the project. Please run nestedtensor operations within the [inference mode](https://github.com/ailzhang/rfcs/blob/rfc0011/RFC-0011-InferenceMode.md) context to prevent any adverse interactions with the autograd system.
12+
13+
For example
14+
```
15+
sentences = [torch.randn(10, 5), torch.randn(5, 5), torch.randn(9, 5)]
16+
with torch.inference_mode():
17+
nt = nestedtensor.nested_tensor(sentences)
18+
nt.sum(1)
19+
```
1220

1321
## Binaries
1422

15-
The nestedtensor project is built on top of a torch fork for improved interoperability and also ships with torchvision binaries that were built against this fork. To use NestedTensors you need to install this version of torch, which is frequently rebased upon PyTorch's [viable/strict](https://github.com/pytorch/pytorch/tree/viable/strict) branch (most recent master where all tests pass).
23+
Due to the development velocity of PyTorch the nestedtensor project is built on top of and dependent on a fixed, recent PyTorch nightly.
1624

1725
| Version | Python | CUDA | Wheels |
1826
| --- | ---- | ------ | ---- |
19-
| 0.1.1 | 3.6 | CPU-only | [torch](https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.6/torch-1.8.0_nestedtensor_0.1.1_cpu-cp36-cp36m-linux_x86_64.whl), [nestedtensor](https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.6/nestedtensor-0.1.1_cpu-cp36-cp36m-linux_x86_64.whl), [torchvision](https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.6/torchvision-0.1.1_cpu-cp36-cp36m-linux_x86_64.whl) |
20-
| 0.1.1 | 3.7 | CPU-only | [torch](https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.7/torch-1.8.0_nestedtensor_0.1.1_cpu-cp37-cp37m-linux_x86_64.whl), [nestedtensor](https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.7/nestedtensor-0.1.1_cpu-cp37-cp37m-linux_x86_64.whl), [torchvision](https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.7/torchvision-0.1.1_cpu-cp37-cp37m-linux_x86_64.whl) |
21-
| 0.1.1 | 3.8 | CPU-only | [torch](https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.8/torch-1.8.0_nestedtensor_0.1.1_cpu-cp38-cp38m-linux_x86_64.whl), [nestedtensor](https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.8/nestedtensor-0.1.1_cpu-cp38-cp38m-linux_x86_64.whl), [torchvision](https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.8/torchvision-0.1.1_cpu-cp38-cp38m-linux_x86_64.whl) |
27+
| 0.1.1 | 3.6 | CPU-only | [nestedtensor](https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.6/nestedtensor-0.1.1_cpu-cp36-cp36m-linux_x86_64.whl) |
28+
| 0.1.1 | 3.7 | CPU-only | [nestedtensor](https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.7/nestedtensor-0.1.1_cpu-cp37-cp37m-linux_x86_64.whl) |
29+
| 0.1.1 | 3.8 | CPU-only | [nestedtensor](https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.8/nestedtensor-0.1.1_cpu-cp38-cp38m-linux_x86_64.whl) |
30+
| 0.1.1 | 3.6 | CUDA 10.2 | [nestedtensor](https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.6/nestedtensor-0.1.1_cu102-cp36-cp36m-linux_x86_64.whl) |
31+
| 0.1.1 | 3.7 | CUDA 10.2 | [nestedtensor](https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.7/nestedtensor-0.1.1_cu102-cp37-cp37m-linux_x86_64.whl) |
32+
| 0.1.1 | 3.8 | CUDA 10.2 | [nestedtensor](https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.8/nestedtensor-0.1.1_cu102-cp38-cp38m-linux_x86_64.whl) |
33+
34+
When installing a binary please specify the corresponding torch nightly link archive to automatically pull in the correct PyTorch nightly.
35+
36+
CPU
37+
```
38+
pip install https://download.pytorch.org/nestedtensor/whl/nightly/cpu/py3.7/nestedtensor-0.1.1_cpu-cp37-cp37m-linux_x86_64.whl -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
39+
```
40+
41+
CUDA 10.2
42+
```
43+
pip install https://download.pytorch.org/nestedtensor/whl/nightly/cu102/py3.7/nestedtensor-0.1.1_cu102-cp37-cp37m-linux_x86_64.whl -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html
44+
```
2245

2346
## Why consider using this? / Dealing with dynamic shapes
2447

@@ -63,52 +86,12 @@ a NestedTensor is still a Tensor. That means it needs to have a single dimension
6386

6487
The nestedtensor package is a prototype intended for early stage feedback and testing. It is on the road to a beta classification, but there is no definitive timeline yet. See [PyTorch feature classification](https://pytorch.org/docs/stable/index.html) for what prototype, beta and stale means.
6588

66-
## Supported platforms
67-
68-
It is developed [against a fork](https://github.com/cpuhrsch/pytorchnestedtensor) of PyTorch to enable cutting-edge features such as improved performance or better `torch.vmap` integration.
69-
70-
Developers will thus need to build from source, but users can use the binary we will start shipping soon ([see the related issue](https://github.com/pytorch/nestedtensor/issues/262)).
71-
72-
If you want to use the binaries you need to run on Linux, use Python 3.8+ and have a CUDA-11 toolkit installed.
73-
74-
If you want to build from source you can probably get it to work on many platforms, but supporting other platforms won't take priority over Linux. We're happy to review community contributions that achieve this however.
75-
7689
## Dependencies
7790

7891
- pytorch (installed from nestedtensor/third_party/pytorch submodule)
7992
- torchvision (needed for examples and tests)
8093
- ipython (needed for examples)
8194
- notebook (needed for examples)
8295

83-
## Build for development
84-
85-
Get the source
86-
87-
```
88-
git clone --recursive https://github.com/pytorch/nestedtensor
89-
cd nestedtensor
90-
# if you are updating an existing checkout
91-
git submodule sync
92-
git submodule update --init --recursive
93-
```
94-
95-
Install the build tools
96-
97-
```
98-
conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests
99-
conda install -c pytorch magma-cuda110
100-
```
101-
102-
Build from scratch
103-
```
104-
./clean_build_with_submodule.sh
105-
```
106-
107-
Incremental builds
108-
```
109-
./build_with_submodule.sh
110-
```
111-
112-
11396
## Contribution
11497
The project is under active development. If you have a suggestions or found a bug, please file an issue!

benchmarks/matmul.py

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,47 +5,35 @@
55
import random
66
random.seed(1010)
77

8+
BDIM=10
9+
810
# Performance tanks hard for lots of small Tensors as expected
911
RAND_INTS = [random.randint(10, 30) for _ in range(2000)]
10-
RAND_INTS = [random.randint(1000, 3000) for _ in range(20)]
1112

12-
TENSORS0 = [torch.rand(9, 245, 2560, requires_grad=True).cuda() for i in RAND_INTS]
13-
TENSORS1 = [torch.rand(9, 2560, 245, requires_grad=True).cuda() for i in RAND_INTS]
13+
OUTDIM=256
14+
15+
TENSORS0 = [torch.rand(i, OUTDIM).cuda() for i in RAND_INTS]
1416

1517
def gen_t_matmul():
16-
tensor0 = torch.stack(TENSORS0)
17-
tensor1 = torch.stack(TENSORS1)
18+
nt0 = nestedtensor.nested_tensor(TENSORS0, device=torch.device('cuda'), dtype=torch.float)
19+
data, _ = nt0.to_tensor_mask()
20+
t1 = torch.randn(OUTDIM, 512).cuda()
1821

1922
def t():
20-
tensor0.requires_grad_()
21-
tensor1.requires_grad_()
22-
torch.matmul(tensor0, tensor1).sum().backward()
23-
tensor0.detach_()
24-
tensor1.detach_()
23+
torch.matmul(data, t1)
2524
return t
2625

2726

28-
def gen_t_loop_matmul():
29-
tensors = [torch.rand(i, 2560).cuda() for i in RAND_INTS]
30-
31-
def t_loop():
32-
for (t0, t1) in zip(TENSORS0, TENSORS1):
33-
torch.matmul(t0, t1).sum().backward()
34-
t0.grad = None
35-
t1.grad = None
36-
return t_loop
37-
38-
27+
@torch.inference_mode()
3928
def gen_nt_matmul():
40-
nt0 = nestedtensor.nested_tensor(TENSORS0, device=torch.device('cuda'), dtype=torch.float, requires_grad=True)
41-
nt1 = nestedtensor.nested_tensor(TENSORS1, device=torch.device('cuda'), dtype=torch.float, requires_grad=True)
29+
nt0 = nestedtensor.nested_tensor(TENSORS0, device=torch.device('cuda'), dtype=torch.float)
30+
t1 = torch.randn(OUTDIM, 512).cuda()
4231

4332
def nt():
44-
torch.matmul(nt0, nt1).sum().backward()
33+
torch.matmul(nt0, t1)
4534
return nt
4635

4736

4837
if __name__ == "__main__":
49-
# print(utils.benchmark_fn(gen_t_matmul()))
50-
# print(utils.benchmark_fn(gen_t_loop_matmul()))
38+
print(utils.benchmark_fn(gen_t_matmul()))
5139
print(utils.benchmark_fn(gen_nt_matmul()))

nestedtensor/csrc/matmul.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,40 @@ namespace F = torch::nn::functional;
99
namespace at {
1010

1111
Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other) {
12+
if (is_nested_tensor_impl(self) && !is_nested_tensor_impl(other)) {
13+
if (get_is_contiguous(self) && get_is_contiguous(other)) {
14+
if (get_dim(self) == 3 && get_dim(other) == 2) {
15+
auto self_opt_sizes = get_opt_sizes(self);
16+
if (self_opt_sizes[2]) {
17+
if (*self_opt_sizes[2] == other.size(0)) {
18+
Tensor self_buffer = get_buffer(self);
19+
Tensor result_buffer =
20+
at::matmul(self_buffer.reshape({-1, other.size(0)}), other);
21+
result_buffer = result_buffer.reshape({-1});
22+
int64_t other_size_1 = other.size(1);
23+
EfficientSizeNode new_nested_size =
24+
get_efficient_nested_size(self).clone();
25+
EfficientSizeNode new_nested_stride =
26+
get_efficient_nested_stride(self).clone();
27+
apply_efficient_size(
28+
[other_size_1](
29+
int64_t* size_ptr,
30+
int64_t size_size,
31+
int64_t* stride_ptr,
32+
int64_t stride_size) {
33+
size_ptr[1] = other_size_1;
34+
stride_ptr[1] = 1;
35+
stride_ptr[0] = other_size_1;
36+
},
37+
new_nested_size,
38+
new_nested_stride);
39+
return wrap_buffer(
40+
std::move(result_buffer), new_nested_size, new_nested_stride);
41+
}
42+
}
43+
}
44+
}
45+
}
1246
return map_nested_tensor(
1347
[](at::Tensor self, at::Tensor other) { return at::matmul(self, other); },
1448
self,

nestedtensor/csrc/storage/EfficientSizeNode.h

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,7 @@ struct EfficientSizeNode {
9292
_opt_sizes(impl::construct_efficient_size(
9393
impl::efficient_deserialize(_structure, _height),
9494
_sizes)) {
95-
// for (size_t i = 0; i < _structure.size(); i++) {
96-
// std::cout << "_structure[" << i << "]: " << _structure[i] << std::endl;
97-
// }
98-
// std::cout << "---" << std::endl;
99-
}
95+
}
10096

10197
explicit EfficientSizeNode(
10298
int64_t height,
@@ -138,6 +134,9 @@ struct EfficientSizeNode {
138134
const std::vector<int64_t>& structure() const {
139135
return _structure;
140136
}
137+
EfficientSizeNode clone() const {
138+
return EfficientSizeNode(_height, _structure, _sizes.clone(), _opt_sizes);
139+
}
141140

142141
private:
143142
int64_t _height;
@@ -159,5 +158,32 @@ static inline EfficientSizeNode map_efficient_size(
159158
size_node.height(), size_node.structure(), sizes, size_node.opt_sizes());
160159
}
161160

161+
template <class F>
162+
static inline void apply_efficient_size(
163+
F&& fn,
164+
EfficientSizeNode& size_node0,
165+
EfficientSizeNode& size_node1) {
166+
at::Tensor sizes0 = size_node0.sizes();
167+
at::Tensor sizes1 = size_node1.sizes();
168+
int64_t* sizes0_ptr = sizes0.data_ptr<int64_t>();
169+
int64_t* sizes1_ptr = sizes1.data_ptr<int64_t>();
170+
const std::vector<int64_t>& structure0 = size_node0.structure();
171+
const std::vector<int64_t>& structure1 = size_node1.structure();
172+
TORCH_CHECK(
173+
structure0.size() == structure1.size(),
174+
"Tree structure doesn't match. Size.");
175+
for (size_t i = 0; i < structure0.size(); i++) {
176+
TORCH_CHECK(
177+
structure0[i] == structure1[i],
178+
"Tree structure doesn't match. Values.");
179+
}
180+
for (int64_t i = 0; i < sizes0.size(0); i++) {
181+
fn(sizes0_ptr + i * sizes0.size(1),
182+
sizes0.size(0),
183+
sizes1_ptr + i * sizes1.size(1),
184+
sizes1.size(0));
185+
}
186+
}
187+
162188
} // namespace nested_tensor
163189
} // namespace torch

nestedtensor/nested/nested.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,4 +512,5 @@ def to_padded_tensor(self, mask_dim=None, padding=-1):
512512
tensor, mask = masking.to_tensor_mask(self, mask_dim)
513513
while mask.dim() < tensor.dim():
514514
mask = mask.unsqueeze(-1)
515+
mask = mask.to(torch.bool)
515516
return tensor.masked_fill(~mask, padding)

nestedtensor/version.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
__version__ = '0.1.4+291a8a1'
2-
git_version = '291a8a10d7de34c02ce2616db4eb8cf95ec27df9'
1+
__version__ = '0.1.4+fbdd335'
2+
git_version = 'fbdd335e410c7b3cf7970fbd65db181e9302e07d'
33
from nestedtensor import _C
44
if hasattr(_C, 'CUDA_VERSION'):
55
cuda = _C.CUDA_VERSION

setup.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,12 @@ def write_version_file():
6363

6464
pytorch_dep = "torch"
6565

66-
requirements = [
67-
pytorch_dep,
68-
]
69-
7066
if os.getenv("PYTORCH_VERSION"):
7167
pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")
7268

69+
requirements = [
70+
pytorch_dep,
71+
]
7372

7473
def get_extensions():
7574

test/test_nested_tensor_functional.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ def test_addmm(self):
2929
[torch.rand(1, 4), torch.rand(1, 4), torch.rand(4, 4)]
3030
)
3131

32+
@torch.inference_mode()
33+
def test_conv2d(self):
34+
nt = ntnt_nograd(
35+
[torch.rand(3, 35, 56), torch.rand(3, 43, 23), torch.rand(3, 24, 52)]
36+
)
37+
weight = torch.randn(5, 5).repeat(3, 3, 1, 1)
38+
torch.conv2d(nt, weight)
39+
3240
def test_contiguousity(self):
3341
initial_t = torch.rand(2, 5, 10, 15)
3442
self.assertEqual(True, initial_t.is_contiguous())

test/test_nested_tensor_masking.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def test_scalar_and_empty_nt_cuda(self):
181181

182182
# TODO: Fix this case together with C++ rewrite.
183183
self.assertRaisesRegex(
184-
RuntimeError, "Expected all tensors to be on the same device, but found at least two devices, cpu and cuda", lambda: a.to_tensor_mask())
184+
RuntimeError, "Expected all tensors to be on the same device, but found at least two devices, cpu and cuda", lambda: a.to_tensor_mask())
185185
# tensor, mask = a.to_tensor_mask()
186186
# TestCase.assertEqual(self, tensor, torch.tensor([[0], [11]], dtype=torch.long, device='cuda'))
187187
# TestCase.assertEqual(self, mask, torch.tensor([False, True], device='cuda'))
@@ -1105,6 +1105,31 @@ def test_ntftm_mask_dim_cuda(self):
11051105
TestCase.assertEqual(self, a, res_nt)
11061106
TestCase.assertEqual(self, res_nt.nested_dim(), a.nested_dim())
11071107

1108+
def test_to_padded_tensor(self):
1109+
data1 = torch.tensor(
1110+
[[[0.8413, 0.7325, 0.0000, 0.0000],
1111+
[0.0000, 0.0000, 0.0000, 0.0000],
1112+
[0.0000, 0.0000, 0.0000, 0.0000]],
1113+
1114+
[[0.6334, 0.5473, 0.3273, 0.0564],
1115+
[0.3023, 0.6826, 0.3519, 0.1804],
1116+
[0.8431, 0.1645, 0.1821, 0.9185]]])
1117+
mask1 = torch.tensor(
1118+
[[[True, True, False, False],
1119+
[False, False, False, False],
1120+
[False, False, False, False]],
1121+
1122+
[[True, True, True, True],
1123+
[True, True, True, True],
1124+
[True, True, True, True]]])
1125+
nt2 = nt.nested_tensor_from_tensor_mask(data1, mask1)
1126+
data2, mask2 = nt2.to_tensor_mask()
1127+
self.assertEqual(data1, data2)
1128+
self.assertEqual(mask1, mask2)
1129+
data3 = nt2.to_padded_tensor(padding=-10)
1130+
data1 = data1 + ~mask1 * -10
1131+
self.assertEqual(data1, data3)
1132+
11081133

11091134
if __name__ == "__main__":
11101135
unittest.main()

0 commit comments

Comments
 (0)