Skip to content
This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit d91f32c

Browse files
cpuhrschfacebook-github-bot
authored andcommitted
20210519 nestedtensor import
Reviewed By: mthrok Differential Revision: D28543070 fbshipit-source-id: 992aea23686a445a67c126ead9580db2a14a5544
1 parent a2c1a07 commit d91f32c

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

nestedtensor/csrc/BinaryOps.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,30 @@ Tensor NestedTensor_add_Tensor(
1111
Tensor self;
1212
Tensor other;
1313
std::tie(self, other) = _expand_other_as(self_, other_);
14-
if (is_nested_tensor_impl(self) && !is_nested_tensor_impl(other) &&
15-
get_is_contiguous(self)) {
14+
if (is_nested_tensor_impl(self) && is_nested_tensor_impl(other)) {
15+
EfficientSizeNode self_efficient_nested_size =
16+
get_efficient_nested_size(self);
17+
EfficientSizeNode other_efficient_nested_size =
18+
get_efficient_nested_size(other);
19+
if (efficient_size_matches(
20+
self_efficient_nested_size, other_efficient_nested_size)) {
21+
if (!get_is_contiguous(self)) {
22+
self = NestedTensor_contiguous(self);
23+
}
24+
if (!get_is_contiguous(other)) {
25+
other = NestedTensor_contiguous(other);
26+
}
27+
return wrap_buffer(
28+
at::add(
29+
get_buffer(self).reshape({-1}), get_buffer(other).reshape({-1})),
30+
self_efficient_nested_size,
31+
get_efficient_nested_stride(self));
32+
}
33+
}
34+
if (is_nested_tensor_impl(self) && !is_nested_tensor_impl(other)) {
35+
if (!get_is_contiguous(self)) {
36+
self = NestedTensor_contiguous(self);
37+
}
1638
int64_t self_dim = get_dim(self);
1739
auto self_opt_sizes = get_opt_sizes(self);
1840
if (self_opt_sizes[self_dim - 1] && other.dim() == 1 &&

nestedtensor/csrc/storage/EfficientSizeNode.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,17 @@ inline bool efficient_size_structure_matches(
153153
return true;
154154
}
155155

156+
inline bool efficient_size_matches(
157+
EfficientSizeNode& size_node0,
158+
EfficientSizeNode& size_node1) {
159+
if (!efficient_size_structure_matches(size_node0, size_node1)) {
160+
return false;
161+
}
162+
at::Tensor sizes0 = size_node0.sizes();
163+
at::Tensor sizes1 = size_node1.sizes();
164+
return at::equal(sizes0, sizes1);
165+
}
166+
156167
template <class F>
157168
inline EfficientSizeNode map_efficient_size(
158169
F&& fn,

0 commit comments

Comments
 (0)