@@ -541,6 +541,110 @@ std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
   return std::make_tuple(at::stack(results), 0);
 }
 
+std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Scalar& value) {
+
+  // std::cout << "index_fill_int_scalar_batch_rule:" << std::endl;
+  if (!index_bdim) {
+    // Handle scalar tensors... self, other can be scalar tensors
+    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+    auto self_ = moveBatchDimToFront(self, self_bdim);
+    if (self_logical_rank == 0) {
+      self_ = self_.unsqueeze(-1);
+    }
+    dim = maybe_wrap_dim(dim, self_logical_rank);
+
+    optional<int64_t> out_bdim = nullopt;
+    if (self_bdim) {
+      const auto batch_size = self.size(*self_bdim);
+      self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+      dim = dim + 1;
+      out_bdim = 0;
+    }
+
+    // std::cout << "1 index_fill, self_: " << self_.sizes() << " index: " << index.sizes() << std::endl;
+    auto result = self_.index_fill(dim, index, value);
+    if (self_logical_rank == 0) {
+      result = result.squeeze(-1);
+    }
+    return std::make_tuple(result, out_bdim);
+  }
+
+  // SAME AS FOR index_add
+  // Index is batched. For-loop and stack is the best thing I can come up with
+  // right now. We really want generalized index_fill kernel in PyTorch
+  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
+  std::vector<Tensor> results;
+  results.reserve(batch_size);
+  // std::cout << "2 index_fill loop: " << std::endl;
+  for (const auto i : c10::irange(0, batch_size)) {
+    const auto& self_slice = self_bdim.has_value() ?
+      self.select(*self_bdim, i) : self;
+    const auto& index_slice = index_bdim.has_value() ?
+      index.select(*index_bdim, i) : index;
+    // std::cout << i << " self_: " << self_slice.sizes() << " index: " << index_slice.sizes() << std::endl;
+    results.push_back(at::index_fill(self_slice, dim, index_slice, value));
+  }
+  return std::make_tuple(at::stack(results), 0);
+}
+
+std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Tensor& value, optional<int64_t> value_bdim) {
+
+  // std::cout << "index_fill_int_tensor_batch_rule: "
+  //           << ((index_bdim) ? "true" : "false") << " "
+  //           << ((value_bdim) ? "true" : "false") << " "
+  //           << std::endl;
+  if (!index_bdim && !value_bdim) {
+    // Handle scalar tensors... self, other can be scalar tensors
+    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+    auto self_ = moveBatchDimToFront(self, self_bdim);
+    if (self_logical_rank == 0) {
+      self_ = self_.unsqueeze(-1);
+    }
+    dim = maybe_wrap_dim(dim, self_logical_rank);
+
+    optional<int64_t> out_bdim = nullopt;
+    if (self_bdim) {
+      const auto batch_size = self.size(*self_bdim);
+      self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+      dim = dim + 1;
+      out_bdim = 0;
+    }
+    // std::cout << "1 index_fill, self_: " << self_.sizes() << " index: " << index.sizes() << std::endl;
+    auto result = self_.index_fill(dim, index, value);
+    if (self_logical_rank == 0) {
+      result = result.squeeze(-1);
+    }
+    return std::make_tuple(result, out_bdim);
+  }
+
+  // SAME AS FOR index_add
+  // Index is batched. For-loop and stack is the best thing I can come up with
+  // right now. We really want generalized index_fill kernel in PyTorch
+  auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
+  std::vector<Tensor> results;
+  results.reserve(batch_size);
+  // std::cout << "2 index_fill loop: " << std::endl;
+  for (const auto i : c10::irange(0, batch_size)) {
+    const auto& self_slice = self_bdim.has_value() ?
+      self.select(*self_bdim, i) : self;
+    const auto& index_slice = index_bdim.has_value() ?
+      index.select(*index_bdim, i) : index;
+    const auto& value_slice = value_bdim.has_value() ?
+      value.select(*value_bdim, i) : value;
+    // std::cout << i << " self_: " << self_slice.sizes() << " index: " << index_slice.sizes() << " value: " << value_slice.sizes() << std::endl;
+    results.push_back(at::index_fill(self_slice, dim, index_slice, value_slice));
+  }
+  return std::make_tuple(at::stack(results), 0);
+}
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("index.Tensor", index_plumbing);
   m.impl("index_put_", index_put__plumbing);
@@ -550,6 +654,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
550654 m.impl (" index_copy" , index_copy_decomp);
551655 m.impl (" index_select" , index_select_decomp);
552656 VMAP_SUPPORT (" index_add" , index_add_batch_rule);
657+ VMAP_SUPPORT (" index_fill.int_Scalar" , index_fill_int_scalar_batch_rule);
658+ VMAP_SUPPORT (" index_fill.int_Tensor" , index_fill_int_tensor_batch_rule);
553659 VMAP_SUPPORT (" diagonal_scatter" , diagonal_scatter_batch_rule);
554660 VMAP_SUPPORT (" gather" , gather_batch_rule);
555661 VMAP_SUPPORT (" gather_backward" , gather_backward_batch_rule);