@@ -11,6 +11,22 @@ Tensor NestedTensor_add_Tensor(
1111 Tensor self;
1212 Tensor other;
1313 std::tie (self, other) = _expand_other_as (self_, other_);
14+ if (is_nested_tensor_impl (self) && !is_nested_tensor_impl (other) &&
15+ get_is_contiguous (self)) {
16+ int64_t self_dim = get_dim (self);
17+ auto self_opt_sizes = get_opt_sizes (self);
18+ if (self_opt_sizes[self_dim - 1 ] && other.dim () == 1 &&
19+ (*(self_opt_sizes[self_dim - 1 ])) == other.size (0 )) {
20+ Tensor self_buffer = get_buffer (self);
21+ Tensor result_buffer =
22+ at::add (self_buffer.reshape ({-1 , other.size (0 )}), other)
23+ .reshape ({-1 });
24+ return wrap_buffer (
25+ std::move (result_buffer),
26+ get_efficient_nested_size (self),
27+ get_efficient_nested_stride (self));
28+ }
29+ }
1430 return map_nested_tensor (
1531 [&alpha](Tensor s, Tensor o) { return at::add (s, o, alpha); },
1632 self,
@@ -97,6 +113,50 @@ Tensor& NestedTensor_div_out(
97113 return out;
98114}
99115
116+ Tensor NestedTensor_floor_divide_Tensor (
117+ const Tensor& self_,
118+ const Tensor& other_) {
119+ Tensor self;
120+ Tensor other;
121+ std::tie (self, other) = _expand_other_as (self_, other_);
122+ return map_nested_tensor (
123+ [](Tensor s, Tensor o) { return at::floor_divide (s, o); }, self, other);
124+ }
125+
126+ Tensor& NestedTensor_floor_divide__Tensor (Tensor& self_, const Tensor& other_) {
127+ at::Tensor self;
128+ at::Tensor other;
129+ std::tie (self, other) = _expand_other_as (self_, other_);
130+ apply_nested_tensor (
131+ [](Tensor& tensor, const Tensor other) {
132+ tensor.floor_divide_ (other);
133+ return tensor;
134+ },
135+ self,
136+ other);
137+ return self_;
138+ }
139+
140+ Tensor& NestedTensor_floor_divide_out (
141+ const Tensor& self,
142+ const Tensor& other,
143+ Tensor& out) {
144+ TORCH_CHECK (
145+ is_nested_tensor_impl (out),
146+ " NT binary out variant requires NT as out argument." );
147+ TORCH_CHECK (
148+ is_nested_tensor_impl (out, self, other),
149+ " binary_out doesn't support non-NT arguments." )
150+ apply_nested_tensor (
151+ [](Tensor& self, Tensor& other, Tensor& out) {
152+ return at::floor_divide_out (self, other, out);
153+ },
154+ self,
155+ other,
156+ out);
157+ return out;
158+ }
159+
100160Tensor NestedTensor_mul_Tensor (const Tensor& self_, const Tensor& other_) {
101161 Tensor self;
102162 Tensor other;
@@ -270,13 +330,32 @@ Tensor& NestedTensor_pow__Tensor(Tensor& self_, const Tensor& other_) {
270330 return self_;
271331}
272332
333+ Tensor NestedTensor_pow_Scalar (const Scalar& base, const Tensor& exponent_) {
334+ Tensor exponent = exponent_;
335+ return map_nested_tensor (
336+ [&base](Tensor exponent) { return at::pow (base, exponent); }, exponent);
337+ }
338+
339+ Tensor NestedTensor_pow_Tensor_Tensor (
340+ const Tensor& self_,
341+ const Tensor& other_) {
342+ Tensor self;
343+ Tensor other;
344+ std::tie (self, other) = _expand_other_as (self_, other_);
345+ return map_nested_tensor (
346+ [](Tensor s, Tensor o) { return at::pow (s, o); }, self, other);
347+ }
348+
273349TORCH_LIBRARY_IMPL (aten, NestedTensor, m) {
274350 nt_impl (m, " add.Tensor" , NestedTensor_add_Tensor);
275351 nt_impl (m, " add_.Tensor" , NestedTensor_add__Tensor);
276352 nt_impl (m, " add.out" , NestedTensor_add_out);
277353 nt_impl (m, " div.Tensor" , NestedTensor_div_Tensor);
278354 nt_impl (m, " div_.Tensor" , NestedTensor_div__Tensor);
279355 nt_impl (m, " div.out" , NestedTensor_div_out);
356+ nt_impl (m, " floor_divide" , NestedTensor_floor_divide_Tensor);
357+ nt_impl (m, " floor_divide_.Tensor" , NestedTensor_floor_divide__Tensor);
358+ nt_impl (m, " floor_divide.out" , NestedTensor_floor_divide_out);
280359 nt_impl (m, " mul.Tensor" , NestedTensor_mul_Tensor);
281360 nt_impl (m, " mul_.Tensor" , NestedTensor_mul__Tensor);
282361 nt_impl (m, " mul.out" , NestedTensor_mul_out);
@@ -289,6 +368,8 @@ TORCH_LIBRARY_IMPL(aten, NestedTensor, m) {
289368 nt_impl (m, " atan2" , NestedTensor_atan2);
290369 nt_impl (m, " remainder.Tensor" , NestedTensor_remainder_Tensor);
291370 nt_impl (m, " pow_.Tensor" , NestedTensor_pow__Tensor);
371+ nt_impl (m, " pow.Scalar" , NestedTensor_pow_Scalar);
372+ nt_impl (m, " pow.Tensor_Tensor" , NestedTensor_pow_Tensor_Tensor);
292373}
293374
294375} // namespace at
0 commit comments