4646
4747#include " integer_advanced_indexing.hpp"
4848
49- #define INDEXING_MODES 2
50- #define WRAP_MODE 0
51- #define CLIP_MODE 1
52-
5349namespace dpctl
5450{
5551namespace tensor
@@ -62,11 +58,15 @@ namespace td_ns = dpctl::tensor::type_dispatch;
6258using dpctl::tensor::kernels::indexing::put_fn_ptr_t ;
6359using dpctl::tensor::kernels::indexing::take_fn_ptr_t ;
6460
65- static take_fn_ptr_t take_dispatch_table[INDEXING_MODES][td_ns::num_types]
66- [td_ns::num_types];
61+ static take_fn_ptr_t take_wrap_dispatch_table[td_ns::num_types]
62+ [td_ns::num_types];
63+
64+ static take_fn_ptr_t take_clip_dispatch_table[td_ns::num_types]
65+ [td_ns::num_types];
66+
67+ static put_fn_ptr_t put_wrap_dispatch_table[td_ns::num_types][td_ns::num_types];
6768
68- static put_fn_ptr_t put_dispatch_table[INDEXING_MODES][td_ns::num_types]
69- [td_ns::num_types];
69+ static put_fn_ptr_t put_clip_dispatch_table[td_ns::num_types][td_ns::num_types];
7070
7171namespace py = pybind11;
7272
@@ -486,7 +486,8 @@ py_take(const dpctl::tensor::usm_ndarray &src,
486486 std::end (pack_deps));
487487 all_deps.insert (std::end (all_deps), std::begin (depends), std::end (depends));
488488
489- auto fn = take_dispatch_table[mode][src_type_id][ind_type_id];
489+ auto fn = mode ? take_wrap_dispatch_table[src_type_id][ind_type_id]
490+ : take_clip_dispatch_table[src_type_id][ind_type_id];
490491
491492 if (fn == nullptr ) {
492493 sycl::event::wait (host_task_events);
@@ -755,7 +756,8 @@ py_put(const dpctl::tensor::usm_ndarray &dst,
755756 std::end (pack_deps));
756757 all_deps.insert (std::end (all_deps), std::begin (depends), std::end (depends));
757758
758- auto fn = put_dispatch_table[mode][dst_type_id][ind_type_id];
759+ auto fn = mode ? put_wrap_dispatch_table[dst_type_id][ind_type_id]
760+ : put_clip_dispatch_table[dst_type_id][ind_type_id];
759761
760762 if (fn == nullptr ) {
761763 sycl::event::wait (host_task_events);
@@ -790,20 +792,20 @@ void init_advanced_indexing_dispatch_tables(void)
790792 using dpctl::tensor::kernels::indexing::TakeClipFactory;
791793 DispatchTableBuilder<take_fn_ptr_t , TakeClipFactory, num_types>
792794 dtb_takeclip;
793- dtb_takeclip.populate_dispatch_table (take_dispatch_table[CLIP_MODE] );
795+ dtb_takeclip.populate_dispatch_table (take_clip_dispatch_table );
794796
795797 using dpctl::tensor::kernels::indexing::TakeWrapFactory;
796798 DispatchTableBuilder<take_fn_ptr_t , TakeWrapFactory, num_types>
797799 dtb_takewrap;
798- dtb_takewrap.populate_dispatch_table (take_dispatch_table[WRAP_MODE] );
800+ dtb_takewrap.populate_dispatch_table (take_wrap_dispatch_table );
799801
800802 using dpctl::tensor::kernels::indexing::PutClipFactory;
801803 DispatchTableBuilder<put_fn_ptr_t , PutClipFactory, num_types> dtb_putclip;
802- dtb_putclip.populate_dispatch_table (put_dispatch_table[CLIP_MODE] );
804+ dtb_putclip.populate_dispatch_table (put_clip_dispatch_table );
803805
804806 using dpctl::tensor::kernels::indexing::PutWrapFactory;
805807 DispatchTableBuilder<put_fn_ptr_t , PutWrapFactory, num_types> dtb_putwrap;
806- dtb_putwrap.populate_dispatch_table (put_dispatch_table[WRAP_MODE] );
808+ dtb_putwrap.populate_dispatch_table (put_wrap_dispatch_table );
807809}
808810
809811} // namespace py_internal
0 commit comments