@@ -26,7 +26,7 @@ std::tuple<Tensor, Tensor> merge_tensor_mask(
2626 Tensor is_zero = (collapsed_mask == 0 );
2727 int64_t is_last_size_sum = is_last_size.sum ().item <int64_t >();
2828 int64_t is_zero_sum = is_zero.sum ().item <int64_t >();
29- if ((is_last_size_sum + is_zero_sum) == collapsed_mask. numel ( )) {
29+ if ((is_last_size_sum + is_zero_sum) == get_numel (collapsed_mask )) {
3030 collapsed_mask = collapsed_mask.to (torch::kBool );
3131 return merge_tensor_mask (tensor, collapsed_mask, mask_dim);
3232 }
@@ -85,7 +85,7 @@ std::vector<int64_t> get_max_size(Tensor nt) {
8585
8686std::tuple<Tensor, Tensor> pad_nt (Tensor nt, std::vector<int64_t > shape) {
8787 if (!is_nested_tensor_impl (nt)) {
88- if (nt. numel ( ) == 0 ) {
88+ if (get_numel (nt ) == 0 ) {
8989 TORCH_CHECK (false , " Empty tensors are not yet supported." );
9090 }
9191 // Dont pad in case of a scalar
@@ -131,7 +131,7 @@ c10::optional<Tensor> nt_from_tensor_mask(
131131 Tensor mask,
132132 int64_t nested_dim) {
133133 if (nested_dim == 0 ) {
134- if ((mask. numel ( ) == 0 ) || (mask. numel ( ) == 1 && mask.item <bool >())) {
134+ if ((get_numel (mask ) == 0 ) || (get_numel (mask ) == 1 && mask.item <bool >())) {
135135 return tensor;
136136 }
137137
@@ -153,7 +153,7 @@ c10::optional<Tensor> nt_from_tensor_mask(
153153 bool all_zero = true ;
154154 for (int64_t i = 0 ; i < mask.size (0 ); i++) {
155155 Tensor tmp = *nt_from_tensor_mask (tensor[i], mask[i], nested_dim);
156- if (tmp. numel ( ) > 0 ) {
156+ if (get_numel (tmp ) > 0 ) {
157157 all_zero = false ;
158158 tensors.push_back (tmp);
159159 }
@@ -172,12 +172,12 @@ c10::optional<Tensor> nt_from_tensor_mask(
172172 return c10::nullopt ;
173173 }
174174 std::vector<c10::optional<Tensor>> inner_tensors;
175- if ((mask. numel ( ) == 0 ) || (mask. numel ( ) == 1 && mask.item <bool >())) {
175+ if ((get_numel (mask ) == 0 ) || (get_numel (mask ) == 1 && mask.item <bool >())) {
176176 for (int64_t i = 0 ; i < tensor.size (0 ); i++) {
177177 inner_tensors.push_back (
178178 nt_from_tensor_mask (tensor[i], mask, nested_dim - 1 ));
179179 }
180- } else if (mask. numel ( ) == 1 && !mask.item <bool >()) {
180+ } else if (get_numel (mask ) == 1 && !mask.item <bool >()) {
181181 inner_tensors.push_back (c10::nullopt );
182182 } else {
183183 for (int64_t i = 0 ; i < tensor.size (0 ); i++) {
@@ -198,6 +198,41 @@ c10::optional<Tensor> nt_from_tensor_mask(
198198std::tuple<Tensor, Tensor> to_tensor_mask (
199199 Tensor nt,
200200 c10::optional<int64_t > mask_dim) {
201+ #ifdef WITH_CUDA
202+ if (get_dim (nt) == 3 && get_is_contiguous (nt) && mask_dim && *mask_dim == 2 ) {
203+ auto nt_opt_size = get_opt_sizes (nt);
204+ Tensor nt_buffer = get_buffer (nt);
205+ if (nt_opt_size[2 ] && nt_buffer.is_cuda ()) {
206+ std::cout << " Calling efficient to_tensor_mask" << std::endl;
207+ Tensor nt_sizes_ =
208+ get_efficient_nested_size (nt).sizes ().to (torch::kInt32 );
209+ TORCH_CHECK (nt_sizes_.dim () == 2 , " NestedTensor must be of nested_dim 2." )
210+ Tensor nt_sizes = at::native::narrow (nt_sizes_, 1 , 0 , 1 );
211+ int max_size_1 = nt_sizes.max ().item <int >();
212+ nt_sizes =
213+ at::native::cumsum (nt_sizes, 0 ).to (torch::kInt32 ).reshape ({-1 });
214+ nt_sizes = at::cat ({torch::tensor ({0 }, torch::kInt32 ), nt_sizes});
215+ Tensor output = torch::zeros (
216+ {*nt_opt_size[0 ], max_size_1, *nt_opt_size[2 ]}, nt_buffer.options ());
217+ nt_sizes = nt_sizes.to (torch::kCUDA );
218+ Tensor output_mask = torch::zeros (
219+ {*nt_opt_size[0 ], max_size_1}, nt_buffer.options ());
220+ output_mask = output_mask.to (torch::kInt32 );
221+ at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream ();
222+ nested_tensor::cuda::add_padding_mask_kernelLauncher (
223+ nt_buffer.data_ptr <float >(),
224+ output.data_ptr <float >(),
225+ output_mask.data_ptr <int >(),
226+ nt_sizes.data_ptr <int >(),
227+ *nt_opt_size[0 ],
228+ output_mask.stride (0 ),
229+ output.stride (0 ),
230+ *nt_opt_size[2 ],
231+ defaultStream);
232+ return std::make_tuple (output, output_mask.to (torch::kBool ));
233+ }
234+ }
235+ #endif
201236 TORCH_CHECK (
202237 !mask_dim || *mask_dim <= get_dim (nt),
203238 " Requested mask dimension " ,
@@ -225,10 +260,10 @@ std::tuple<Tensor, Tensor> to_tensor_mask(
225260
226261Tensor to_padded_tensor (Tensor nt, double padding) {
227262#ifdef WITH_CUDA
228- if (get_dim (nt) == 3 ) {
263+ if (get_dim (nt) == 3 && get_is_contiguous (nt) ) {
229264 auto nt_opt_size = get_opt_sizes (nt);
230- if (nt_opt_size[ 2 ]) {
231- Tensor nt_buffer = get_buffer (nt);
265+ Tensor nt_buffer = get_buffer (nt);
266+ if (nt_opt_size[ 2 ] && nt_buffer. is_cuda ()) {
232267 Tensor nt_sizes_ =
233268 get_efficient_nested_size (nt).sizes ().to (torch::kInt32 );
234269 TORCH_CHECK (nt_sizes_.dim () == 2 , " NestedTensor must be of nested_dim 2." )
0 commit comments