#include <nestedtensor/csrc/masking.h>
#include <chrono>
+#ifdef WITH_CUDA
+#include <c10/cuda/CUDAStream.h>
+#include <nestedtensor/csrc/cuda/padding.h>
+#endif

using namespace torch::nested_tensor;
using namespace at;
@@ -40,7 +44,7 @@ std::tuple<Tensor, Tensor> merge_tensor_mask(
Tensor pad_tensor_to_shape(Tensor t, std::vector<int64_t> goal_shape) {
  std::vector<int64_t> padd;
  auto tup = t.sizes();
-  if (get_dim(t) != goal_shape.size()) {
+  if (get_dim(t) != (int64_t)(goal_shape.size())) {
    throw std::runtime_error("dimension doesn't match length of goal shape.");
  }
  for (int64_t i = tup.size() - 1; i >= 0; i--) {
@@ -182,7 +186,7 @@ c10::optional<Tensor> nt_from_tensor_mask(
    }
  }
  std::vector<TensorNode> inner_tensor_nodes;
-  for (int64_t i = 0; i < inner_tensors.size(); i++) {
+  for (size_t i = 0; i < inner_tensors.size(); i++) {
    if (inner_tensors[i]) {
      TensorNode node = get_nested_tensor_structure(*inner_tensors[i]);
      inner_tensor_nodes.push_back(node);
@@ -194,15 +198,68 @@ c10::optional<Tensor> nt_from_tensor_mask(
std::tuple<Tensor, Tensor> to_tensor_mask(
    Tensor nt,
    c10::optional<int64_t> mask_dim) {
-  // TODO: Cover if not isinstance(nt, list) and nt.size() == (1,):
-  // TODO: Move to_tensor_mask entirely into C++
-
-  std::vector<int64_t> max_size = get_max_size(nt);
-  Tensor tensor;
-  Tensor mask;
-  std::tie(tensor, mask) = pad_nt(nt, max_size);
-  std::tie(tensor, mask) = merge_tensor_mask(tensor, mask, mask_dim);
-  return std::make_tuple(tensor, mask);
+  TORCH_CHECK(
+      !mask_dim || *mask_dim <= get_dim(nt),
+      "Requested mask dimension ",
+      *mask_dim,
+      " is bigger than dimension ",
+      get_dim(nt),
+      " of given NestedTensor.");
+
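+  // Shortcut for the nt.size() == (1,) case from the old TODO: return the
+  // flattened buffer together with a trivial mask.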
+  auto opt_sizes = get_opt_sizes(nt);
+  if (opt_sizes.size() == 1 && *opt_sizes[0] == 1) {
+    nt = NestedTensor_contiguous(nt);
+    Tensor nt_buffer = get_buffer(nt);
+    nt_buffer = nt_buffer.reshape({-1});
+    Tensor result_mask = !mask_dim || *mask_dim == 0 ? torch::tensor(true)
+                                                     : torch::tensor({true});
+    return std::make_tuple(nt_buffer, result_mask);
+  }
+
+  auto max_size = get_max_size(nt);
+  at::Tensor res_tensor;
+  at::Tensor res_mask;
+  std::tie(res_tensor, res_mask) = pad_nt(nt, max_size);
+  return merge_tensor_mask(res_tensor, res_mask, mask_dim);
+}
+
+Tensor to_padded_tensor(Tensor nt, double padding) {
+#ifdef WITH_CUDA
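+  // Fast path: a 3-dim NestedTensor with a regular (non-ragged) last
+  // dimension is padded by the dedicated CUDA kernel.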
+  if (get_dim(nt) == 3) {
+    auto nt_opt_size = get_opt_sizes(nt);
+    if (nt_opt_size[2]) {
+      Tensor nt_buffer = get_buffer(nt);
+      Tensor nt_sizes_ =
+          get_efficient_nested_size(nt).sizes().to(torch::kInt32);
+      TORCH_CHECK(nt_sizes_.dim() == 2, "NestedTensor must be of nested_dim 2.");
+      Tensor nt_sizes = at::native::narrow(nt_sizes_, 1, 0, 1);
+      int max_size_1 = nt_sizes.max().item<int>();
+      nt_sizes =
+          at::native::cumsum(nt_sizes, 0).to(torch::kInt32).reshape({-1});
+      nt_sizes = at::cat({torch::tensor({0}, torch::kInt32), nt_sizes});
+      Tensor output = torch::empty(
+          {*nt_opt_size[0], max_size_1, *nt_opt_size[2]}, nt_buffer.options());
+      output.fill_(padding);
+      nt_sizes = nt_sizes.to(torch::kCUDA);
+      at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
+      nested_tensor::cuda::add_padding_kernelLauncher(
+          nt_buffer.data_ptr<float>(),
+          output.data_ptr<float>(),
+          nt_sizes.data_ptr<int>(),
+          *nt_opt_size[0],
+          output.stride(0),
+          *nt_opt_size[2],
+          defaultStream);
+      return output;
+    }
+  }
+#endif
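+  // Fallback: pad via to_tensor_mask and fill the masked-out entries with
+  // the padding value.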
+  at::Tensor tensor;
+  at::Tensor mask;
+  std::tie(tensor, mask) = to_tensor_mask(nt, get_dim(nt));
+  mask = mask.to(torch::kBool);
+  tensor.masked_fill_(at::logical_not(mask), padding);
+  return tensor;
}

TORCH_LIBRARY_FRAGMENT(nestedtensor, m) {
@@ -219,4 +276,10 @@ TORCH_LIBRARY_FRAGMENT(nestedtensor, m) {

  m.def("get_max_size(Tensor nt) -> int[]");
  m.impl("get_max_size", NestedTensorKey, TORCH_FN(get_max_size));
+
+  m.def("to_tensor_mask(Tensor nt, int? mask_dim) -> (Tensor, Tensor)");
+  m.impl("to_tensor_mask", NestedTensorKey, to_tensor_mask);
+
+  m.def("to_padded_tensor(Tensor nt, float padding) -> Tensor");
+  m.impl("to_padded_tensor", NestedTensorKey, to_padded_tensor);
}
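
For reference, below is a minimal CPU sketch of what the CUDA fast path above is expected to compute, inferred only from the arguments handed to add_padding_kernelLauncher (packed buffer, padding-filled output, cumulative sizes with a leading 0, batch size, output stride, and the regular last dimension). The function name and contract here are assumptions; the actual kernel lives in nestedtensor/csrc/cuda/padding.h and may differ in details.

// Assumed contract of add_padding_kernelLauncher: copy each constituent's
// entries from the packed buffer into its row of the pre-filled output;
// everything past its length keeps the padding value.
void add_padding_reference(
    const float* input,   // packed NestedTensor buffer
    float* output,        // batch_size x max_size_1 x inner, pre-filled with padding
    const int* offsets,   // cumulative first-dim sizes, prefixed with 0 (nt_sizes above)
    int batch_size,
    int output_stride0,   // output.stride(0) == max_size_1 * inner
    int inner) {          // the regular last dimension (*nt_opt_size[2])
  for (int b = 0; b < batch_size; b++) {
    const int begin = offsets[b] * inner;
    const int end = offsets[b + 1] * inner;
    for (int i = begin; i < end; i++) {
      output[b * output_stride0 + (i - begin)] = input[i];
    }
  }
}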