@@ -298,81 +298,77 @@ template <class T, class String>
298298size_t StringDictionary::getBulk (const std::vector<String>& string_vec,
299299 T* encoded_vec,
300300 const int64_t generation) const {
301- CHECK (!base_dict_) << " Not implemented" ;
302301 constexpr int64_t target_strings_per_thread{1000 };
303302 const int64_t num_lookup_strings = string_vec.size ();
304303 if (num_lookup_strings == 0 ) {
305304 return 0 ;
306305 }
307306
308- const ThreadInfo thread_info (
309- std::thread::hardware_concurrency (), num_lookup_strings, target_strings_per_thread);
310- CHECK_GE (thread_info.num_threads , 1L );
311- CHECK_GE (thread_info.num_elems_per_thread , 1L );
312-
313- std::vector<size_t > num_strings_not_found_per_thread (thread_info.num_threads , 0UL );
307+ size_t base_num_strings_not_found = string_vec.size ();
308+ if (base_dict_) {
309+ auto base_generation_for_bulk =
310+ generation >= 0 ? std::min (generation, base_generation_) : base_generation_;
311+ base_num_strings_not_found =
312+ base_dict_->getBulk (string_vec, encoded_vec, base_generation_for_bulk);
313+ }
314314
315315 mapd_shared_lock<mapd_shared_mutex> read_lock (rw_mutex_);
316- const int64_t num_dict_strings = generation >= 0 ? generation : storageEntryCount ();
317- const bool dictionary_is_empty = (num_dict_strings == 0 );
318- if (dictionary_is_empty) {
319- tbb::parallel_for (tbb::blocked_range<int64_t >(0 , num_lookup_strings),
320- [&](const tbb::blocked_range<int64_t >& r) {
321- const int64_t start_idx = r.begin ();
322- const int64_t end_idx = r.end ();
323- for (int64_t string_idx = start_idx; string_idx < end_idx;
324- ++string_idx) {
325- encoded_vec[string_idx] = StringDictionary::INVALID_STR_ID;
326- }
327- });
328- return num_lookup_strings;
316+ const int64_t num_dict_strings = generation >= 0 ? generation : entryCount ();
317+ const bool skip_owned_string =
318+ (num_dict_strings <= base_generation_) || !base_num_strings_not_found;
319+ if (skip_owned_string) {
320+ // Need to fill the resulting vector if it wasn't done by the base dictionary.
321+ if (!base_dict_) {
322+ tbb::parallel_for (tbb::blocked_range<int64_t >(
323+ 0 , num_lookup_strings, (size_t )64 << 10 /* 256KB chunks*/ ),
324+ [&](const tbb::blocked_range<int64_t >& r) {
325+ const int64_t start_idx = r.begin ();
326+ const int64_t end_idx = r.end ();
327+ for (int64_t string_idx = start_idx; string_idx < end_idx;
328+ ++string_idx) {
329+ encoded_vec[string_idx] = StringDictionary::INVALID_STR_ID;
330+ }
331+ });
332+ }
333+ return base_num_strings_not_found;
329334 }
330335 // If we're here the generation-capped dictionary has strings in it
331336 // that we need to look up against
332-
333- tbb::task_arena limited_arena (thread_info.num_threads );
334- limited_arena.execute ([&] {
335- CHECK_LE (tbb::this_task_arena::max_concurrency (), thread_info.num_threads );
336- tbb::parallel_for (
337- tbb::blocked_range<int64_t >(
338- 0 , num_lookup_strings, thread_info.num_elems_per_thread /* tbb grain_size */ ),
339- [&](const tbb::blocked_range<int64_t >& r) {
340- const int64_t start_idx = r.begin ();
341- const int64_t end_idx = r.end ();
342- size_t num_strings_not_found = 0 ;
343- for (int64_t string_idx = start_idx; string_idx != end_idx; ++string_idx) {
344- const auto & input_string = string_vec[string_idx];
345- if (input_string.empty ()) {
346- encoded_vec[string_idx] = inline_int_null_value<T>();
347- continue ;
348- }
349- if (input_string.size () > StringDictionary::MAX_STRLEN) {
350- throw_string_too_long_error (input_string, dict_ref_);
351- }
352- const uint32_t input_string_hash = hash_string (input_string);
353- uint32_t hash_bucket =
354- computeBucket (input_string_hash, input_string, string_id_uint32_table_);
355- // Will either be legit id or INVALID_STR_ID
356- const auto string_id = string_id_uint32_table_[hash_bucket];
357- if (string_id == StringDictionary::INVALID_STR_ID ||
358- string_id >= num_dict_strings) {
359- encoded_vec[string_idx] = StringDictionary::INVALID_STR_ID;
360- num_strings_not_found++;
361- continue ;
362- }
363- encoded_vec[string_idx] = string_id;
337+ size_t found_owned = tbb::parallel_reduce (
338+ tbb::blocked_range<int64_t >(
339+ 0 , num_lookup_strings, target_strings_per_thread /* tbb grain_size */ ),
340+ (size_t )0 ,
341+ [&](const tbb::blocked_range<int64_t >& r, size_t found) {
342+ const int64_t start_idx = r.begin ();
343+ const int64_t end_idx = r.end ();
344+ for (int64_t string_idx = start_idx; string_idx != end_idx; ++string_idx) {
345+ if (base_dict_ && encoded_vec[string_idx] != StringDictionary::INVALID_STR_ID) {
346+ continue ;
364347 }
365- const size_t tbb_thread_idx = tbb::this_task_arena::current_thread_index ();
366- num_strings_not_found_per_thread[tbb_thread_idx] = num_strings_not_found;
367- },
368- tbb::simple_partitioner ());
369- });
348+ const auto & input_string = string_vec[string_idx];
349+ if (input_string.empty ()) {
350+ encoded_vec[string_idx] = inline_int_null_value<T>();
351+ ++found;
352+ continue ;
353+ }
354+ if (input_string.size () > StringDictionary::MAX_STRLEN) {
355+ throw_string_too_long_error (input_string, dict_ref_);
356+ }
357+ // Will either be legit id or INVALID_STR_ID
358+ const auto string_id = getOwnedUnlocked (input_string);
359+ if (string_id == StringDictionary::INVALID_STR_ID ||
360+ string_id >= num_dict_strings) {
361+ encoded_vec[string_idx] = StringDictionary::INVALID_STR_ID;
362+ continue ;
363+ }
364+ encoded_vec[string_idx] = string_id;
365+ ++found;
366+ }
367+ return found;
368+ },
369+ std::plus<size_t >());
370370
371- size_t num_strings_not_found = 0 ;
372- for (int64_t thread_idx = 0 ; thread_idx < thread_info.num_threads ; ++thread_idx) {
373- num_strings_not_found += num_strings_not_found_per_thread[thread_idx];
374- }
375- return num_strings_not_found;
371+ return base_num_strings_not_found - found_owned;
376372}
377373
378374template size_t StringDictionary::getBulk (const std::vector<std::string>& string_vec,
@@ -565,6 +561,15 @@ int32_t StringDictionary::getUnlocked(const std::string_view sv,
565561 return base_res;
566562 }
567563 }
564+ return getOwnedUnlocked (sv, hash);
565+ }
566+
567+ int32_t StringDictionary::getOwnedUnlocked (const std::string_view sv) const noexcept {
568+ return getOwnedUnlocked (sv, hash_string (sv));
569+ }
570+
571+ int32_t StringDictionary::getOwnedUnlocked (const std::string_view sv,
572+ const uint32_t hash) const noexcept {
568573 auto str_id = string_id_uint32_table_[computeBucket (hash, sv, string_id_uint32_table_)];
569574 return str_id;
570575}
0 commit comments