@@ -105,7 +105,6 @@ Tensor NestedTensor_batch_norm(
     check_dims_match_num_input_features("bias", n_input, get_numel(*bias));
   }
 
-  auto scalar_shape = make_scalar_shape(get_dim(input), n_input);
   at::Tensor mean = *running_mean;
   at::Tensor var = *running_var;
 #ifdef WITH_CUDA
@@ -120,46 +119,64 @@ Tensor NestedTensor_batch_norm(
       (mean.dtype() == torch::kHalf) &&
       (var.dtype() == torch::kHalf) &&
       (bias->dtype() == torch::kHalf) &&
-      (weight->dtype() == torch::kHalf)
+      (weight->dtype() == torch::kHalf) &&
+      get_is_cuda(input)
      )
   {
-
     // Custom CUDA Half implementation.
     mean = mean.contiguous();
     Tensor bias_cont = (*bias).contiguous();
     Tensor weight_cont = (*weight).contiguous();
     Tensor running_var_cont = (*running_var).contiguous();
+
+    c10::Half* mean_ptr = mean.data_ptr<c10::Half>();
+    c10::Half* bias_ptr = bias_cont.data_ptr<c10::Half>();
+    c10::Half* weight_ptr = weight_cont.data_ptr<c10::Half>();
+    c10::Half* running_var_ptr = running_var_cont.data_ptr<c10::Half>();
+
+    if (get_is_contiguous(input, c10::MemoryFormat::ChannelsLast)) {
+      Tensor input_buffer = get_buffer(input);
+      int64_t num_channel = weight_cont.size(0);
+      at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
+      nested_tensor::cuda::batchnorm_inference_channels_last_kernelLauncher(
+          input_buffer.data_ptr<c10::Half>(),
+          mean_ptr,
+          running_var_ptr,
+          c10::Half((float)(eps)),
+          weight_ptr,
+          bias_ptr,
+          input_buffer.data_ptr<c10::Half>(),
+          num_channel,
+          input_buffer.numel(),
+          defaultStream);
+      input_buffer = input_buffer.view(-1);
+      return wrap_buffer(std::move(input_buffer), get_efficient_nested_size(input), get_efficient_nested_stride(input));
+    }
 
     Tensor output = input;
     output = NestedTensor_contiguous(output);
     Tensor input_buffer = get_buffer(output);
-    Tensor output_buffer = input_buffer.clone();
+    // Tensor output_buffer = input_buffer.clone();
 
     auto self_opt_sizes = get_opt_sizes(input);
 
     Tensor nt_sizes_ =
-        get_efficient_nested_size(input).sizes().to(torch::kInt32);
+        get_efficient_nested_size(input).sizes(); // .to(torch::kInt32);
     Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
     Tensor nt_sizes_2 = at::native::narrow(nt_sizes_, 1, 2, 1);
     Tensor nt_sizes_all = nt_sizes_1 * nt_sizes_2;
-    int* nt_sizes_all_ptr = nt_sizes_all.data_ptr<int>();
-    std::vector<int> numbers;
-    numbers.reserve(1 + (nt_sizes_all.size(0) * *self_opt_sizes[1]));
-    numbers.push_back(0);
+    int64_t* nt_sizes_all_ptr = nt_sizes_all.data_ptr<int64_t>();
+    at::Tensor numbers_t = at::empty({1 + (nt_sizes_all.size(0) * *self_opt_sizes[1])}, torch::kInt64);
+    int64_t* numbers_t_ptr = numbers_t.data_ptr<int64_t>();
+    numbers_t_ptr[0] = 0;
     int64_t index = 1;
     for (int64_t i = 0; i < nt_sizes_all.size(0); i++) {
       for (int64_t j = 0; j < *self_opt_sizes[1]; j++) {
-        numbers.push_back(numbers[index - 1] + nt_sizes_all_ptr[i]);
+        numbers_t_ptr[index] = (numbers_t_ptr[index - 1] + nt_sizes_all_ptr[i]);
         index++;
       }
     }
-    at::Tensor numbers_t = torch::tensor(numbers).to(torch::kInt32);
-    Tensor nt_sizes = numbers_t.to(torch::kCUDA);
-
-    c10::Half* mean_ptr = mean.data_ptr<c10::Half>();
-    c10::Half* running_var_ptr = running_var_cont.data_ptr<c10::Half>();
-    c10::Half* bias_ptr = bias_cont.data_ptr<c10::Half>();
-    c10::Half* weight_ptr = weight_cont.data_ptr<c10::Half>();
+    Tensor nt_sizes = numbers_t.to(at::Device(kCUDA), torch::kInt32, true, true);
 
     at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
     nested_tensor::cuda::batchnorm_inference_kernelLauncher(
@@ -169,15 +186,21 @@ Tensor NestedTensor_batch_norm(
         c10::Half((float)(eps)),
         weight_ptr,
         bias_ptr,
-        output_buffer.data_ptr<c10::Half>(),
-        (int)(*self_opt_sizes[0] * *self_opt_sizes[1]),
+        input_buffer.data_ptr<c10::Half>(),
+        // output_buffer.data_ptr<c10::Half>(),
         (int)(*self_opt_sizes[0]),
+        (int)(weight_cont.size(0)),
+        (int)(*self_opt_sizes[0] *
+              *self_opt_sizes[1] *
+              *self_opt_sizes[2] *
+              *self_opt_sizes[3]),
         nt_sizes.data_ptr<int>(),
         defaultStream
     );
-    return wrap_buffer(std::move(output_buffer), get_efficient_nested_size(output), get_efficient_nested_stride(output));
+    return wrap_buffer(std::move(input_buffer), get_efficient_nested_size(output), get_efficient_nested_stride(output));
   }
 #endif
+  auto scalar_shape = make_scalar_shape(get_dim(input), n_input);
 
   at::Tensor invstd = 1 / at::sqrt(*running_var + eps);
 
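
Aside (not part of the commit): the replacement of the std::vector<int> path above follows a common ATen pattern: preallocate the offset table as an int64 CPU tensor, fill it with a prefix sum of per-sample element counts, then move it to the GPU as int32 in one to() call. A minimal self-contained sketch under that assumption; the names build_offsets and per_sample_numel are illustrative and do not exist in the nestedtensor codebase, and running it requires a CUDA-enabled libtorch build.

// Illustrative sketch only, not code from this commit.
#include <ATen/ATen.h>
#include <vector>

at::Tensor build_offsets(const std::vector<int64_t>& per_sample_numel) {
  const int64_t n = static_cast<int64_t>(per_sample_numel.size());
  // Preallocate on CPU instead of growing a std::vector and calling torch::tensor().
  at::Tensor offsets = at::empty({n + 1}, at::kLong);
  int64_t* p = offsets.data_ptr<int64_t>();
  p[0] = 0;  // leading zero, like numbers_t_ptr[0] = 0 in the diff
  for (int64_t i = 0; i < n; i++) {
    p[i + 1] = p[i] + per_sample_numel[i];  // running prefix sum of element counts
  }
  // One call handles both the device move and the int64 -> int32 cast, mirroring
  // numbers_t.to(at::Device(kCUDA), torch::kInt32, true, true) above.
  return offsets.to(at::Device(at::kCUDA), at::kInt,
                    /*non_blocking=*/true, /*copy=*/true);
}

Note that non_blocking=true is best-effort for a host-to-device copy from ordinary (unpinned) CPU memory; the overload used is the same Tensor::to(Device, ScalarType, bool non_blocking, bool copy) the diff calls.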