 #include <nestedtensor/csrc/BinaryOps.h>
+#ifdef WITH_CUDA
+#include <c10/cuda/CUDAStream.h>
+#include <nestedtensor/csrc/cuda/add.h>
+#include <c10/util/Half.h>
+#endif

 namespace at {

@@ -31,11 +36,56 @@ Tensor NestedTensor_add_Tensor(
     }
   }
   if (is_nested_tensor_impl(self) && !is_nested_tensor_impl(other)) {
-    if (!get_is_contiguous(self)) {
-      self = NestedTensor_contiguous(self);
-    }
+    self = NestedTensor_contiguous(self);
     int64_t self_dim = get_dim(self);
     auto self_opt_sizes = get_opt_sizes(self);
+#ifdef WITH_CUDA
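+    // Fused CUDA fp16 fast path: "other" is a dense (1, C, 1, 1) half tensor, so
+    // the broadcasted add reduces to adding one scalar per channel to the
+    // contiguous nested buffer, instead of mapping at::add over every constituent.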
+    if (self_dim == 4 && other.dim() == 4 &&
+        self_opt_sizes[0] &&
+        self_opt_sizes[1] &&
+        (*self_opt_sizes[1]) == other.size(1) &&
+        other.size(0) == 1 &&
+        other.size(2) == 1 &&
+        other.size(3) == 1 &&
+        self.dtype() == c10::ScalarType::Half &&
+        other.dtype() == c10::ScalarType::Half) {
+      other = other.contiguous();
+      at::Tensor self_buffer = get_buffer(self);
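+      // Build cumulative element counts: each example contributes C copies of
+      // H_i * W_i, and the prefix sum (with a leading 0) gives the
+      // per-(example, channel) segment boundaries consumed by the kernel.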
+      Tensor nt_sizes_ =
+          get_efficient_nested_size(self).sizes().to(torch::kInt32);
+      Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
+      Tensor nt_sizes_2 = at::native::narrow(nt_sizes_, 1, 2, 1);
+      Tensor nt_sizes_all = nt_sizes_1 * nt_sizes_2;
+      std::vector<int> numbers;
+      for (int64_t i = 0; i < nt_sizes_all.size(0); i++) {
+        for (int64_t j = 0; j < *self_opt_sizes[1]; j++) {
+          numbers.push_back(nt_sizes_all[i].item<int>());
+        }
+      }
+      at::Tensor numbers_t = torch::tensor(numbers).to(torch::kInt32);
+      Tensor nt_sizes_cumsum =
+          at::native::cumsum(numbers_t, 0).to(torch::kInt32).reshape({-1});
+      TORCH_CHECK(nt_sizes_.dim() == 2, "NestedTensor metadata of unexpected dimension.");
+      Tensor nt_sizes = at::cat({torch::tensor({0}, torch::kInt32), nt_sizes_cumsum});
+      nt_sizes = nt_sizes.to(torch::kCUDA);
+      at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
+      at::Tensor result_buffer = self_buffer.clone();
+
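+      // Launch the fused kernel on the default stream; it is handed N * C
+      // segments, the batch size N, and the offset table (see
+      // nestedtensor/csrc/cuda/add.h for the launcher's parameter meanings).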
+      c10::Half* self_ptr = self_buffer.data_ptr<c10::Half>();
+      c10::Half* other_ptr = other.data_ptr<c10::Half>();
+      c10::Half* result_ptr = result_buffer.data_ptr<c10::Half>();
+      nested_tensor::cuda::add_scalar_kernelLauncher(
+          self_ptr,
+          other_ptr,
+          result_ptr,
+          (int)(*self_opt_sizes[0] * *self_opt_sizes[1]),
+          (int)(*self_opt_sizes[0]),
+          nt_sizes.data_ptr<int>(),
+          defaultStream);
+      return wrap_buffer(std::move(result_buffer), get_efficient_nested_size(self),
+          get_efficient_nested_stride(self));
+    }
+#endif
     if (self_opt_sizes[self_dim - 1] && other.dim() == 1 &&
         (*(self_opt_sizes[self_dim - 1])) == other.size(0)) {
       Tensor self_buffer = get_buffer(self);
@@ -50,7 +100,8 @@ Tensor NestedTensor_add_Tensor(
   }
   std::tie(self, other) = _expand_other_as(self_, other_);
   return map_nested_tensor(
-      [&alpha](Tensor s, Tensor o) { return at::add(s, o, alpha); },
+      [&alpha](Tensor s, Tensor o) {
+        return at::add(s, o, alpha); },
       self,
       other);
 }
@@ -180,11 +231,64 @@ Tensor& NestedTensor_floor_divide_out(
 }

 Tensor NestedTensor_mul_Tensor(const Tensor& self_, const Tensor& other_) {
-  Tensor self;
-  Tensor other;
+  Tensor self = self_;
+  Tensor other = other_;
+  if (is_nested_tensor_impl(self) && !is_nested_tensor_impl(other)) {
+    self = NestedTensor_contiguous(self);
+    int64_t self_dim = get_dim(self);
+    auto self_opt_sizes = get_opt_sizes(self);
+#ifdef WITH_CUDA
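+    // Same fused CUDA fp16 fast path as in NestedTensor_add_Tensor above, but
+    // dispatching to the mul kernel: each channel of the contiguous nested
+    // buffer is scaled by the matching scalar of the dense (1, C, 1, 1) "other".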
+    if (self_dim == 4 && other.dim() == 4 &&
+        self_opt_sizes[0] &&
+        self_opt_sizes[1] &&
+        (*self_opt_sizes[1]) == other.size(1) &&
+        other.size(0) == 1 &&
+        other.size(2) == 1 &&
+        other.size(3) == 1 &&
+        self.dtype() == c10::ScalarType::Half &&
+        other.dtype() == c10::ScalarType::Half) {
+      other = other.contiguous();
+      at::Tensor self_buffer = get_buffer(self);
+      Tensor nt_sizes_ =
+          get_efficient_nested_size(self).sizes().to(torch::kInt32);
+      Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
+      Tensor nt_sizes_2 = at::native::narrow(nt_sizes_, 1, 2, 1);
+      Tensor nt_sizes_all = nt_sizes_1 * nt_sizes_2;
+      std::vector<int> numbers;
+      for (int64_t i = 0; i < nt_sizes_all.size(0); i++) {
+        for (int64_t j = 0; j < *self_opt_sizes[1]; j++) {
+          numbers.push_back(nt_sizes_all[i].item<int>());
+        }
+      }
+      at::Tensor numbers_t = torch::tensor(numbers).to(torch::kInt32);
+      Tensor nt_sizes_cumsum =
+          at::native::cumsum(numbers_t, 0).to(torch::kInt32).reshape({-1});
+      TORCH_CHECK(nt_sizes_.dim() == 2, "NestedTensor metadata of unexpected dimension.");
+      Tensor nt_sizes = at::cat({torch::tensor({0}, torch::kInt32), nt_sizes_cumsum});
+      nt_sizes = nt_sizes.to(torch::kCUDA);
+      at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
+      at::Tensor result_buffer = self_buffer.clone();
+
+      c10::Half* self_ptr = self_buffer.data_ptr<c10::Half>();
+      c10::Half* other_ptr = other.data_ptr<c10::Half>();
+      c10::Half* result_ptr = result_buffer.data_ptr<c10::Half>();
+      nested_tensor::cuda::mul_scalar_kernelLauncher(
+          self_ptr,
+          other_ptr,
+          result_ptr,
+          (int)(*self_opt_sizes[0] * *self_opt_sizes[1]),
+          (int)(*self_opt_sizes[0]),
+          nt_sizes.data_ptr<int>(),
+          defaultStream);
+      return wrap_buffer(std::move(result_buffer), get_efficient_nested_size(self),
+          get_efficient_nested_stride(self));
+    }
+#endif
+  }
   std::tie(self, other) = _expand_other_as(self_, other_);
   return map_nested_tensor(
-      [](Tensor s, Tensor o) { return at::mul(s, o); }, self, other);
+      [](Tensor s, Tensor o) {
+        return at::mul(s, o); }, self, other);
 }

 Tensor& NestedTensor_mul__Tensor(Tensor& self_, const Tensor& other_) {
@@ -246,11 +350,64 @@ Tensor NestedTensor_sub_Tensor(
     const Tensor& self_,
     const Tensor& other_,
     const Scalar& alpha) {
-  Tensor self;
-  Tensor other;
+  Tensor self = self_;
+  Tensor other = other_;
+  if (is_nested_tensor_impl(self) && !is_nested_tensor_impl(other)) {
+    self = NestedTensor_contiguous(self);
+    int64_t self_dim = get_dim(self);
+    auto self_opt_sizes = get_opt_sizes(self);
+#ifdef WITH_CUDA
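+    // Same fused CUDA fp16 fast path as in NestedTensor_add_Tensor above, but
+    // dispatching to the sub kernel: one (1, C, 1, 1) scalar per channel is
+    // subtracted from the contiguous nested buffer (alpha is not consulted on
+    // this path, mirroring the add fast path).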
+    if (self_dim == 4 && other.dim() == 4 &&
+        self_opt_sizes[0] &&
+        self_opt_sizes[1] &&
+        (*self_opt_sizes[1]) == other.size(1) &&
+        other.size(0) == 1 &&
+        other.size(2) == 1 &&
+        other.size(3) == 1 &&
+        self.dtype() == c10::ScalarType::Half &&
+        other.dtype() == c10::ScalarType::Half) {
+      other = other.contiguous();
+      at::Tensor self_buffer = get_buffer(self);
+      Tensor nt_sizes_ =
+          get_efficient_nested_size(self).sizes().to(torch::kInt32);
+      Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
+      Tensor nt_sizes_2 = at::native::narrow(nt_sizes_, 1, 2, 1);
+      Tensor nt_sizes_all = nt_sizes_1 * nt_sizes_2;
+      std::vector<int> numbers;
+      for (int64_t i = 0; i < nt_sizes_all.size(0); i++) {
+        for (int64_t j = 0; j < *self_opt_sizes[1]; j++) {
+          numbers.push_back(nt_sizes_all[i].item<int>());
+        }
+      }
+      at::Tensor numbers_t = torch::tensor(numbers).to(torch::kInt32);
+      Tensor nt_sizes_cumsum =
+          at::native::cumsum(numbers_t, 0).to(torch::kInt32).reshape({-1});
+      TORCH_CHECK(nt_sizes_.dim() == 2, "NestedTensor metadata of unexpected dimension.");
+      Tensor nt_sizes = at::cat({torch::tensor({0}, torch::kInt32), nt_sizes_cumsum});
+      nt_sizes = nt_sizes.to(torch::kCUDA);
+      at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
+      at::Tensor result_buffer = self_buffer.clone();
+
+      c10::Half* self_ptr = self_buffer.data_ptr<c10::Half>();
+      c10::Half* other_ptr = other.data_ptr<c10::Half>();
+      c10::Half* result_ptr = result_buffer.data_ptr<c10::Half>();
+      nested_tensor::cuda::sub_scalar_kernelLauncher(
+          self_ptr,
+          other_ptr,
+          result_ptr,
+          (int)(*self_opt_sizes[0] * *self_opt_sizes[1]),
+          (int)(*self_opt_sizes[0]),
+          nt_sizes.data_ptr<int>(),
+          defaultStream);
+      return wrap_buffer(std::move(result_buffer), get_efficient_nested_size(self),
+          get_efficient_nested_stride(self));
+    }
+#endif
+  }
   std::tie(self, other) = _expand_other_as(self_, other_);
   return map_nested_tensor(
-      [&alpha](Tensor s, Tensor o) { return at::sub(s, o, alpha); },
+      [&alpha](Tensor s, Tensor o) {
+        return at::sub(s, o, alpha); },
       self,
       other);
 }