diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h
index d59b0175e4..3ab3225266 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h
@@ -10,12 +10,26 @@
 #include
 #include
 #include
+#include
 #include
 #include
+#include
 
 namespace l2_cache {
 
+// Embedding data structure for ObjectCache
+struct EmbeddingValue {
+  std::vector<uint8_t> data;
+
+  EmbeddingValue() = default;
+  explicit EmbeddingValue(size_t size) : data(size) {}
+  explicit EmbeddingValue(const void* ptr, size_t size)
+      : data(
+            static_cast<const uint8_t*>(ptr),
+            static_cast<const uint8_t*>(ptr) + size) {}
+};
+
 /// @ingroup embedding-ssd
 ///
 /// @brief A Cachelib wrapper class for Cachelib interaction
@@ -27,16 +41,26 @@ namespace l2_cache {
 /// Cachelib related optimization will be captured inside this class
 /// e.g. fetch and delayed markUseful to boost up get performance
 ///
+/// Supports both raw allocator mode (default) and object cache mode.
+/// Object cache mode enables optimizations for object-oriented access
+/// patterns and can be enabled via the CacheConfig::use_object_cache flag.
+///
 /// @note that this class only handles single Cachelib read/update.
 /// parallelism is done on the caller side
 class CacheLibCache {
  public:
   using Cache = facebook::cachelib::LruAllocator;
+  using ObjectCache = facebook::cachelib::objcache2::ObjectCache<Cache>;
+
   struct CacheConfig {
     size_t cache_size_bytes;
     size_t item_size_bytes;
     size_t num_shards;
     int64_t max_D_;
+    // Enable object cache mode for optimized object-oriented access patterns.
+    // When enabled, activates the background item reaper and other object
+    // cache specific optimizations.
+    bool use_object_cache = false;
   };
 
   explicit CacheLibCache(
@@ -48,11 +72,23 @@ class CacheLibCache {
 
   std::unique_ptr<Cache> initializeCacheLib(const CacheConfig& config);
 
   std::unique_ptr<facebook::cachelib::CacheAdmin> createCacheAdmin(
-      Cache& cache);
+      Cache& cache,
+      bool is_object_cache);
+
+  // Template overload to accept ObjectCache
+  template <typename CacheType>
+  std::unique_ptr<facebook::cachelib::CacheAdmin> createCacheAdmin(
+      CacheType& cache,
+      bool is_object_cache);
 
   /// Find the stored embeddings from a given embedding indices, aka key
   ///
   /// @param key_tensor embedding index(tensor with only one element) to look up
+  /// @param object_cache_value_out Optional output parameter for ObjectCache
+  /// mode. If provided and ObjectCache is enabled, stores the shared_ptr to
+  /// keep the value alive. This prevents use-after-free in multi-threaded
+  /// contexts.
 ///
 /// @return an optional value, return none on cache misses, if cache hit
 /// return a pointer to the cachelib underlying storage of associated
 /// embeddings
 ///
 /// @note that this is not thread safe, caller needs to make sure the data is
 /// fully processed before doing cache insertion, otherwise the returned space
 /// might be overwritten if cache is full
-  folly::Optional<void*> get(const at::Tensor& key_tensor);
+  folly::Optional<void*> get(
+      const at::Tensor& key_tensor,
+      std::shared_ptr<EmbeddingValue>* object_cache_value_out = nullptr);
 
   /// Cachelib wrapper specific hash function
   ///
@@ -152,12 +190,23 @@ class CacheLibCache {
   std::unique_ptr<Cache> cache_;
   std::vector<facebook::cachelib::PoolId> pool_ids_;
   std::unique_ptr<facebook::cachelib::CacheAdmin> admin_;
+  std::unique_ptr<ObjectCache> object_cache_;
 
   folly::Optional<at::Tensor> evicted_indices_opt_{folly::none};
   folly::Optional<at::Tensor> evicted_weights_opt_{folly::none};
   folly::Optional<at::ScalarType> index_dtype_{folly::none};
   folly::Optional<at::ScalarType> weights_dtype_{folly::none};
   std::atomic<int64_t> eviction_row_id{0};
+
+  // Keep the last retrieved ObjectCache value alive to prevent use-after-free
+  std::shared_ptr<EmbeddingValue> last_retrieved_value_;
+
+  // Helper methods for ObjectCache mode
+  void initializeObjectCacheInternal(int64_t rough_num_items);
+  folly::Optional<void*> getFromObjectCache(
+      const at::Tensor& key_tensor,
+      std::shared_ptr<EmbeddingValue>* object_cache_value_out);
+  bool putToObjectCache(const at::Tensor& key_tensor, const at::Tensor& data);
 };
 
 } // namespace l2_cache
diff --git a/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp b/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp
index fd38ca1c89..2b2b35d9af 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp
+++ b/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp
@@ -7,6 +7,7 @@
  */
 
 #include "fbgemm_gpu/split_embeddings_cache/cachelib_cache.h"
+#include
 #include
 #include "fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h"
 #include "fbgemm_gpu/utils/dispatch_macros.h"
@@ -18,19 +19,31 @@ using Cache = facebook::cachelib::LruAllocator;
 CacheLibCache::CacheLibCache(
     const CacheConfig& cache_config,
     int64_t unique_tbe_id)
-    : cache_config_(cache_config),
-      unique_tbe_id_(unique_tbe_id),
-      cache_(initializeCacheLib(cache_config_)),
-      admin_(createCacheAdmin(*cache_)) {
-  for (size_t i = 0; i < cache_config_.num_shards; i++) {
-    pool_ids_.push_back(cache_->addPool(
-        fmt::format("shard_{}", i),
-        cache_->getCacheMemoryStats().ramCacheSize / cache_config_.num_shards,
-        std::set<uint32_t>{},
-        Cache::MMConfig{
-            0, /* promote on every access*/
-            true, /*enable promotion on write*/
-            true /*enable promotion on read*/}));
+    : cache_config_(cache_config), unique_tbe_id_(unique_tbe_id) {
+  // Initialize the cache: this creates either the regular Cache or the
+  // ObjectCache. In ObjectCache mode this sets object_cache_ and returns
+  // nullptr; in regular mode it returns the Cache and object_cache_ remains
+  // nullptr.
+  cache_ = initializeCacheLib(cache_config_);
+
+  // Create admin for the underlying cache (works for both modes)
+  if (cache_config_.use_object_cache) {
+    // ObjectCache mode: create admin for the ObjectCache directly
+    admin_ = createCacheAdmin(*object_cache_, true);
+  } else {
+    // Regular mode: create admin for the cache
+    admin_ = createCacheAdmin(*cache_, false);
+
+    // Initialize pools only for regular allocator mode
+    for (size_t i = 0; i < cache_config_.num_shards; i++) {
+      pool_ids_.push_back(cache_->addPool(
+          fmt::format("shard_{}", i),
+          cache_->getCacheMemoryStats().ramCacheSize / cache_config_.num_shards,
+          std::set<uint32_t>{},
+          Cache::MMConfig{
+              0, /* promote on every access*/
+              true, 
/*enable promotion on write*/ + true /*enable promotion on read*/})); + } } } @@ -39,11 +52,17 @@ size_t CacheLibCache::get_cache_item_size() const { } Cache::AccessIterator CacheLibCache::begin() { + if (cache_config_.use_object_cache) { + // ObjectCache has its own iterator - use the underlying L1 cache + return object_cache_->begin(); + } return cache_->begin(); } std::unique_ptr CacheLibCache::initializeCacheLib( const CacheConfig& config) { + // Setup eviction callback (used for both regular Cache and ObjectCache's + // underlying cache) auto eviction_cb = [this]( const facebook::cachelib::LruAllocator::RemoveCbData& data) { @@ -79,40 +98,171 @@ std::unique_ptr CacheLibCache::initializeCacheLib( }); } }; - Cache::Config cacheLibConfig; + int64_t rough_num_items = cache_config_.cache_size_bytes / cache_config_.item_size_bytes; unsigned int bucket_power = std::log(rough_num_items) / std::log(2) + 1; - // 15 here is a magic number between 10 and 20 unsigned int lock_power = std::log(cache_config_.num_shards * 15) / std::log(2) + 1; + XLOG(INFO) << fmt::format( "[TBE_ID{}] Setting up Cachelib for L2 cache, capacity: {}GB, " - "item_size: {}B, max_num_items: {}, bucket_power: {}, lock_power: {}", + "item_size: {}B, max_num_items: {}, bucket_power: {}, lock_power: {}, " + "use_object_cache: {}", unique_tbe_id_, config.cache_size_bytes / 1024 / 1024 / 1024, cache_config_.item_size_bytes, rough_num_items, bucket_power, - lock_power); + lock_power, + cache_config_.use_object_cache); + + // For ObjectCache mode, create ObjectCache which manages its own cache + if (cache_config_.use_object_cache) { + initializeObjectCacheInternal(rough_num_items); + // Return nullptr for cache_ since ObjectCache manages its own + // We'll access the underlying cache via object_cache_->getL1Cache() when + // needed + return nullptr; + } + + // Regular allocator mode - create Cache with eviction callback + Cache::Config cacheLibConfig; cacheLibConfig.setCacheSize(static_cast(config.cache_size_bytes)) .setRemoveCallback(eviction_cb) .setCacheName("TBEL2Cache") .setAccessConfig({bucket_power, lock_power}) - .setFullCoredump(false) - .validate(); + .setFullCoredump(false); + + cacheLibConfig.validate(); return std::make_unique(cacheLibConfig); } +void CacheLibCache::initializeObjectCacheInternal(int64_t rough_num_items) { + XLOG(INFO) << fmt::format( + "[TBE_ID{}] ObjectCache mode enabled, initializing object cache", + unique_tbe_id_); + + // Configure ObjectCache + typename ObjectCache::Config objCacheConfig; + objCacheConfig.setCacheName( + fmt::format("TBEL2ObjectCache_{}", unique_tbe_id_)); + objCacheConfig.setCacheCapacity(rough_num_items); + + // Enable object size tracking to support getTotalObjectSize() + objCacheConfig.objectSizeTrackingEnabled = true; + + // Set up item destructor for ObjectCache (handles evictions and removals) + objCacheConfig.setItemDestructor([this]( + facebook::cachelib::objcache2:: + ObjectCacheDestructorData data) { + // Handle evictions with the same callback logic + if (data.context == + facebook::cachelib::objcache2::ObjectCacheDestructorContext::kEvicted) { + // Track evictions for L2 cache eviction handling + if (evicted_weights_opt_.has_value()) { + FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( + evicted_weights_opt_->scalar_type(), "l2_eviction_handling", [&] { + using value_t = scalar_t; + FBGEMM_DISPATCH_INTEGRAL_TYPES( + evicted_indices_opt_->scalar_type(), + "l2_eviction_handling", + [&] { + using index_t = scalar_t; + auto indices_data_ptr = + evicted_indices_opt_->data_ptr(); + 
auto weights_data_ptr = + evicted_weights_opt_->data_ptr(); + auto row_id = eviction_row_id++; + auto weight_dim = evicted_weights_opt_->size(1); + + // Parse key from the data.key (StringPiece) + const auto key_ptr = + reinterpret_cast(data.key.data()); + indices_data_ptr[row_id] = *key_ptr; + + // Get the embedding value and copy data + auto* embedding_val = + reinterpret_cast(data.objectPtr); + std::copy( + reinterpret_cast( + embedding_val->data.data()), + reinterpret_cast( + embedding_val->data.data()) + + weight_dim, + &weights_data_ptr[row_id * weight_dim]); + }); + }); + } + } + + // Clean up the EmbeddingValue object + if (data.context == + facebook::cachelib::objcache2::ObjectCacheDestructorContext:: + kEvicted || + data.context == + facebook::cachelib::objcache2::ObjectCacheDestructorContext:: + kRemoved) { + data.deleteObject(); + } + }); + + // Enable background item reaper for ObjectCache mode + objCacheConfig.setItemReaperInterval(std::chrono::seconds{1}); + + // Create the ObjectCache + object_cache_ = ObjectCache::create(objCacheConfig); +} + std::unique_ptr CacheLibCache::createCacheAdmin( - Cache& cache) { + Cache& cache, + bool is_object_cache) { facebook::cachelib::CacheAdmin::Config adminConfig; adminConfig.oncall = "mvai"; + // Disable background stats exporters for ObjectCache mode to avoid + // race conditions and crashes during initialization + if (is_object_cache) { + adminConfig.globalOdsInterval = std::chrono::seconds{0}; + adminConfig.serviceDataStatsInterval = std::chrono::seconds{0}; + adminConfig.poolRebalancerStatsInterval = std::chrono::seconds{0}; + } return std::make_unique( cache, std::move(adminConfig)); } -folly::Optional CacheLibCache::get(const at::Tensor& key_tensor) { +// Template implementation for ObjectCache +template +std::unique_ptr CacheLibCache::createCacheAdmin( + CacheType& cache, + bool is_object_cache) { + facebook::cachelib::CacheAdmin::Config adminConfig; + adminConfig.oncall = "mvai"; + // Disable background stats exporters for ObjectCache mode to avoid + // race conditions and crashes during initialization + if (is_object_cache) { + adminConfig.globalOdsInterval = std::chrono::seconds{0}; + adminConfig.serviceDataStatsInterval = std::chrono::seconds{0}; + adminConfig.poolRebalancerStatsInterval = std::chrono::seconds{0}; + } + return std::make_unique( + cache, std::move(adminConfig)); +} + +// Explicit template instantiation for ObjectCache +template std::unique_ptr +CacheLibCache::createCacheAdmin( + CacheLibCache::ObjectCache& cache, + bool is_object_cache); + +folly::Optional CacheLibCache::get( + const at::Tensor& key_tensor, + std::shared_ptr* object_cache_value_out) { + // Use ObjectCache if enabled + if (cache_config_.use_object_cache) { + return getFromObjectCache(key_tensor, object_cache_value_out); + } + + // Fallback to regular allocator mode folly::Optional res; FBGEMM_DISPATCH_INTEGRAL_TYPES(key_tensor.scalar_type(), "get", [&] { using index_t = scalar_t; @@ -142,9 +292,14 @@ void CacheLibCache::batchMarkUseful( if (read_handles.empty()) { return; } + + auto* cache_ptr = cache_config_.use_object_cache + ? 
&(object_cache_->getL1Cache()) + : cache_.get(); + for (auto& handle : read_handles) { if (handle) { - cache_->markUseful(handle, facebook::cachelib::AccessMode::kRead); + cache_ptr->markUseful(handle, facebook::cachelib::AccessMode::kRead); } } } @@ -156,6 +311,13 @@ bool CacheLibCache::put(const at::Tensor& key_tensor, const at::Tensor& data) { if (!weights_dtype_.has_value()) { weights_dtype_ = data.scalar_type(); } + + // Use ObjectCache if enabled + if (cache_config_.use_object_cache) { + return putToObjectCache(key_tensor, data); + } + + // Fallback to regular allocator mode bool res; FBGEMM_DISPATCH_INTEGRAL_TYPES(key_tensor.scalar_type(), "put", [&] { using index_t = scalar_t; @@ -189,6 +351,11 @@ CacheLibCache::get_n_items(int n, Cache::AccessIterator& itr) { if (!index_dtype_.has_value() || !weights_dtype_.has_value()) { return folly::none; } + + auto* cache_ptr = cache_config_.use_object_cache + ? &(object_cache_->getL1Cache()) + : cache_.get(); + auto weight_dim = cache_config_.max_D_; auto indices = at::empty( n, at::TensorOptions().dtype(index_dtype_.value()).device(at::kCPU)); @@ -204,7 +371,7 @@ CacheLibCache::get_n_items(int n, Cache::AccessIterator& itr) { using index_t = scalar_t; auto indices_data_ptr = indices.data_ptr(); auto weights_data_ptr = weights.data_ptr(); - for (; itr != cache_->end() && cnt < n; ++itr, ++cnt) { + for (; itr != cache_ptr->end() && cnt < n; ++itr, ++cnt) { const auto key_ptr = reinterpret_cast(itr->getKey().data()); indices_data_ptr[cnt] = *key_ptr; @@ -266,11 +433,82 @@ CacheLibCache::get_tensors_and_reset() { std::vector CacheLibCache::get_cache_usage() { std::vector cache_mem_stats(2, 0); // freeBytes, capacity cache_mem_stats[1] = cache_config_.cache_size_bytes; - for (auto& pool_id : pool_ids_) { - auto pool_stats = cache_->getPoolStats(pool_id); - cache_mem_stats[0] += pool_stats.freeMemoryBytes(); + + if (cache_config_.use_object_cache) { + // For ObjectCache mode, use cache-level stats instead of pool stats + // since ObjectCache manages its own internal pool structure + int64_t used_mem = object_cache_->getTotalObjectSize(); + cache_mem_stats[0] = cache_config_.cache_size_bytes - used_mem; + } else { + // For regular allocator mode, use the pool_ids we created + for (auto& pool_id : pool_ids_) { + auto pool_stats = cache_->getPoolStats(pool_id); + cache_mem_stats[0] += pool_stats.freeMemoryBytes(); + } } + return cache_mem_stats; } +folly::Optional CacheLibCache::getFromObjectCache( + const at::Tensor& key_tensor, + std::shared_ptr* object_cache_value_out) { + folly::Optional res; + FBGEMM_DISPATCH_INTEGRAL_TYPES(key_tensor.scalar_type(), "get", [&] { + using index_t = scalar_t; + auto key = *(key_tensor.data_ptr()); + + // Convert integer key to string key for ObjectCache + auto key_str = folly::StringPiece( + reinterpret_cast(&key), sizeof(index_t)); + + // Try to find the value in object cache + auto found = object_cache_->find(key_str); + if (!found) { + res = folly::none; + return; + } + + // Store the value to keep it alive and prevent use-after-free + // Cast away const since we need mutable access to the data + auto value_ptr = std::const_pointer_cast(found); + + // Store in output parameter if provided (for multi-threaded contexts) + // Otherwise fall back to member variable (for backward compatibility) + if (object_cache_value_out) { + *object_cache_value_out = value_ptr; + } else { + last_retrieved_value_ = value_ptr; + } + + res = static_cast(value_ptr->data.data()); + }); + return res; +} + +bool 
CacheLibCache::putToObjectCache( + const at::Tensor& key_tensor, + const at::Tensor& data) { + bool res = false; + FBGEMM_DISPATCH_INTEGRAL_TYPES(key_tensor.scalar_type(), "put", [&] { + using index_t = scalar_t; + auto key = *(key_tensor.data_ptr()); + + // Convert integer key to string key for ObjectCache + auto key_str = folly::StringPiece( + reinterpret_cast(&key), sizeof(index_t)); + + // Create EmbeddingValue from tensor data + auto embedding_value = + std::make_unique(data.data_ptr(), data.nbytes()); + + // Insert or replace in object cache + auto [insert_status, new_obj, old_obj] = + object_cache_->insertOrReplace( + key_str, std::move(embedding_value), data.nbytes()); + res = (insert_status == ObjectCache::AllocStatus::kSuccess); + }); + return res; +} + } // namespace l2_cache diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h index 0bdd431cf1..2a361c627c 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h @@ -46,7 +46,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { std::optional hash_size_cumsum = std::nullopt, int64_t flushing_block_size = 2000000000 /*2GB*/, bool disable_random_init = false, - bool enable_blob_db = false) + bool enable_blob_db = false, + bool use_object_cache = false) : impl_( std::make_shared( /*path=*/path, @@ -80,7 +81,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { /*hash_size_cumsum=*/hash_size_cumsum, /*flushing_block_size=*/flushing_block_size, /*disable_random_init=*/disable_random_init, - /*enable_blob_db=*/enable_blob_db)) {} + /*enable_blob_db=*/enable_blob_db, + /*use_object_cache=*/use_object_cache)) {} void set_cuda( at::Tensor indices, diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp index e1d2d0718e..4a12d655a1 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp @@ -79,8 +79,10 @@ EmbeddingKVDB::EmbeddingKVDB( std::vector table_names, std::vector table_offsets, const std::vector& table_sizes, - int64_t flushing_block_size) + int64_t flushing_block_size, + bool use_object_cache) : flushing_block_size_(flushing_block_size), + use_object_cache_(use_object_cache), unique_id_(unique_id), num_shards_(num_shards), max_D_(max_D), @@ -102,6 +104,7 @@ EmbeddingKVDB::EmbeddingKVDB( cache_config.num_shards = num_shards_; cache_config.item_size_bytes = max_D_ * ele_size_bytes; cache_config.max_D_ = max_D_; + cache_config.use_object_cache = use_object_cache; l2_cache_ = std::make_unique(cache_config, unique_id); } else { @@ -110,7 +113,8 @@ EmbeddingKVDB::EmbeddingKVDB( XLOG(INFO) << "[TBE_ID" << unique_id_ << "] L2 created with " << num_shards_ << " shards, dimension:" << max_D_ << ", enable_async_update_:" << enable_async_update_ - << ", cache_size_gb:" << cache_size_gb; + << ", cache_size_gb:" << cache_size_gb + << ", use_object_cache:" << use_object_cache; if (enable_async_update_) { cache_filling_thread_ = std::make_unique([=, this] { @@ -522,40 +526,50 @@ std::shared_ptr EmbeddingKVDB::get_cache( for (int i = 0; i < num_shards; i++) { row_ids_per_shard[i].reserve(num_lookups / num_shards * 2); } - for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) 
{ - row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])] - .emplace_back(row_id); + if (use_object_cache_) { + // ObjectCache mode: distribute rows round-robin across threads + // since ObjectCache doesn't use pool-based sharding + for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) { + row_ids_per_shard[row_id % num_shards].emplace_back(row_id); + } + } else { + // Regular cache mode: use shard_id based on key hash + for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) { + row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])] + .emplace_back(row_id); + } } for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) { auto f = folly::via(executor_tp_.get()) - .thenValue([shard_id, - indices, - row_ids_per_shard, - cache_context, - this](folly::Unit) { - FBGEMM_DISPATCH_INTEGRAL_TYPES( - indices.scalar_type(), "get_cache_inner", [&] { - using inner_index_t = scalar_t; - auto inner_indices_addr = - indices.data_ptr(); - for (const auto& row_id : row_ids_per_shard[shard_id]) { - auto emb_idx = inner_indices_addr[row_id]; - if (emb_idx < 0) { - continue; - } - auto cached_addr_opt = l2_cache_->get(indices[row_id]); - if (cached_addr_opt.has_value()) { // cache hit - cache_context->cached_addr_list[row_id] = - cached_addr_opt.value(); - inner_indices_addr[row_id] = - -1; // mark to sentinel value - } else { // cache miss - cache_context->num_misses += 1; - } - } - }); - }); + .thenValue( + [shard_id, indices, row_ids_per_shard, cache_context, this]( + folly::Unit) { + FBGEMM_DISPATCH_INTEGRAL_TYPES( + indices.scalar_type(), "get_cache_inner", [&] { + using inner_index_t = scalar_t; + auto inner_indices_addr = + indices.data_ptr(); + for (const auto& row_id : + row_ids_per_shard[shard_id]) { + auto emb_idx = inner_indices_addr[row_id]; + if (emb_idx < 0) { + continue; + } + auto cached_addr_opt = l2_cache_->get( + indices[row_id], + &cache_context->object_cache_values[row_id]); + if (cached_addr_opt.has_value()) { // cache hit + cache_context->cached_addr_list[row_id] = + cached_addr_opt.value(); + inner_indices_addr[row_id] = + -1; // mark to sentinel value + } else { // cache miss + cache_context->num_misses += 1; + } + } + }); + }); futures.push_back(std::move(f)); } folly::collect(futures).wait(); @@ -629,9 +643,18 @@ EmbeddingKVDB::set_cache( for (int i = 0; i < num_shards; i++) { row_ids_per_shard[i].reserve(num_lookups / num_shards * 2); } - for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) { - row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])] - .emplace_back(row_id); + if (use_object_cache_) { + // ObjectCache mode: distribute rows round-robin across threads + // since ObjectCache doesn't use pool-based sharding + for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) { + row_ids_per_shard[row_id % num_shards].emplace_back(row_id); + } + } else { + // Regular cache mode: use shard_id based on key hash + for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) { + row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])] + .emplace_back(row_id); + } } for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) { diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h index a8082af235..637b068a2a 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h @@ -61,11 +61,16 @@ class CacheContext { public: 
explicit CacheContext(size_t num_keys) { cached_addr_list = std::vector(num_keys, nullptr); + object_cache_values = + std::vector>( + num_keys, nullptr); } // invalid spot will stay as sentinel value, this is trading space for better // parallelism std::atomic num_misses{0}; std::vector cached_addr_list; + // For ObjectCache mode: keep shared_ptrs alive to prevent use-after-free + std::vector> object_cache_values; }; /// @ingroup embedding-ssd @@ -148,7 +153,8 @@ class EmbeddingKVDB : public std::enable_shared_from_this { std::vector table_names = {}, std::vector table_offsets = {}, const std::vector& table_sizes = {}, - int64_t flushing_block_size = 2000000000 /*2GB*/); + int64_t flushing_block_size = 2000000000 /*2GB*/, + bool use_object_cache = false); virtual ~EmbeddingKVDB(); @@ -500,6 +506,7 @@ class EmbeddingKVDB : public std::enable_shared_from_this { std::unique_ptr l2_cache_; // when flushing l2, the block size in bytes that we flush l2 progressively int64_t flushing_block_size_; + bool use_object_cache_; const int64_t unique_id_; const int64_t num_shards_; const int64_t max_D_; diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index 677278fa54..6116dc354b 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -823,6 +823,7 @@ static auto embedding_rocks_db_wrapper = std::optional, int64_t, bool, + bool, bool>(), "", { @@ -857,6 +858,7 @@ static auto embedding_rocks_db_wrapper = torch::arg("flushing_block_size") = 2000000000 /* 2GB */, torch::arg("disable_random_init") = false, torch::arg("enable_blob_db") = false, + torch::arg("use_object_cache") = false, }) .def( "set_cuda", diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index f6def371e2..5fd27e123c 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -122,7 +122,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { std::optional hash_size_cumsum = std::nullopt, int64_t flushing_block_size = 2000000000 /*2GB*/, bool disable_random_init = false, - bool enable_blob_db = false) + bool enable_blob_db = false, + bool use_object_cache = false) : kv_db::EmbeddingKVDB( num_shards, max_D, @@ -136,7 +137,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { std::move(table_names), std::move(table_offsets), table_sizes, - flushing_block_size), + flushing_block_size, + use_object_cache), auto_compaction_enabled_(true), max_D_(max_D), elem_size_(row_storage_bitwidth / 8) { diff --git a/fbgemm_gpu/test/split_embeddings_cache/cachelib_cache_test.cpp b/fbgemm_gpu/test/split_embeddings_cache/cachelib_cache_test.cpp new file mode 100644 index 0000000000..146a76cb44 --- /dev/null +++ b/fbgemm_gpu/test/split_embeddings_cache/cachelib_cache_test.cpp @@ -0,0 +1,270 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include "fbgemm_gpu/split_embeddings_cache/cachelib_cache.h" // @manual=//deeplearning/fbgemm/fbgemm_gpu/src/split_embeddings_cache:cachelib_cache + +namespace l2_cache { + +/** + * @brief Tests basic put and get operations with regular allocator mode. + */ +TEST(CacheLibCacheTest, TestPutAndGetRegularMode) { + const int64_t EMBEDDING_DIM = 8; + const int64_t NUM_SHARDS = 4; + const int64_t CACHE_SIZE = 100 * 1024 * 1024; // 100MB + + CacheLibCache::CacheConfig config; + config.cache_size_bytes = CACHE_SIZE; + config.item_size_bytes = EMBEDDING_DIM * sizeof(float); + config.num_shards = NUM_SHARDS; + config.max_D_ = EMBEDDING_DIM; + config.use_object_cache = false; // Regular allocator mode + + auto cache = std::make_unique(config, 0 /* unique_tbe_id */); + + // Create test data + auto key1 = at::tensor({100}, at::TensorOptions().dtype(at::kLong)); + auto key2 = at::tensor({200}, at::TensorOptions().dtype(at::kLong)); + auto key3 = at::tensor({300}, at::TensorOptions().dtype(at::kLong)); + + auto data1 = at::ones({EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)); + auto data2 = + at::ones({EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)) * 2.0; + auto data3 = + at::ones({EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)) * 3.0; + + // Test put operations + EXPECT_TRUE(cache->put(key1, data1)); + EXPECT_TRUE(cache->put(key2, data2)); + EXPECT_TRUE(cache->put(key3, data3)); + + // Test get operations + auto result1 = cache->get(key1); + ASSERT_TRUE(result1.has_value()); + auto result1_tensor = at::from_blob( + result1.value(), {EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)); + EXPECT_TRUE(at::allclose(result1_tensor, data1)); + + auto result2 = cache->get(key2); + ASSERT_TRUE(result2.has_value()); + auto result2_tensor = at::from_blob( + result2.value(), {EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)); + EXPECT_TRUE(at::allclose(result2_tensor, data2)); + + auto result3 = cache->get(key3); + ASSERT_TRUE(result3.has_value()); + auto result3_tensor = at::from_blob( + result3.value(), {EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)); + EXPECT_TRUE(at::allclose(result3_tensor, data3)); + + // Test cache miss + auto key_miss = at::tensor({999}, at::TensorOptions().dtype(at::kLong)); + auto result_miss = cache->get(key_miss); + EXPECT_FALSE(result_miss.has_value()); +} + +/** + * @brief Tests basic put and get operations with ObjectCache mode. 
+ */ +TEST(CacheLibCacheTest, TestPutAndGetObjectCacheMode) { + const int64_t EMBEDDING_DIM = 8; + const int64_t NUM_SHARDS = 4; + const int64_t CACHE_SIZE = 100 * 1024 * 1024; // 100MB + + CacheLibCache::CacheConfig config; + config.cache_size_bytes = CACHE_SIZE; + config.item_size_bytes = EMBEDDING_DIM * sizeof(float); + config.num_shards = NUM_SHARDS; + config.max_D_ = EMBEDDING_DIM; + config.use_object_cache = true; // ObjectCache mode + + auto cache = std::make_unique(config, 1 /* unique_tbe_id */); + + // Create test data + auto key1 = at::tensor({100}, at::TensorOptions().dtype(at::kLong)); + auto key2 = at::tensor({200}, at::TensorOptions().dtype(at::kLong)); + auto key3 = at::tensor({300}, at::TensorOptions().dtype(at::kLong)); + + auto data1 = at::ones({EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)); + auto data2 = + at::ones({EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)) * 2.0; + auto data3 = + at::ones({EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)) * 3.0; + + // Test put operations + EXPECT_TRUE(cache->put(key1, data1)); + EXPECT_TRUE(cache->put(key2, data2)); + EXPECT_TRUE(cache->put(key3, data3)); + + // Test get operations + auto result1 = cache->get(key1); + ASSERT_TRUE(result1.has_value()); + auto result1_tensor = at::from_blob( + result1.value(), {EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)); + EXPECT_TRUE(at::allclose(result1_tensor, data1)); + + auto result2 = cache->get(key2); + ASSERT_TRUE(result2.has_value()); + auto result2_tensor = at::from_blob( + result2.value(), {EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)); + EXPECT_TRUE(at::allclose(result2_tensor, data2)); + + auto result3 = cache->get(key3); + ASSERT_TRUE(result3.has_value()); + auto result3_tensor = at::from_blob( + result3.value(), {EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)); + EXPECT_TRUE(at::allclose(result3_tensor, data3)); + + // Test cache miss + auto key_miss = at::tensor({999}, at::TensorOptions().dtype(at::kLong)); + auto result_miss = cache->get(key_miss); + EXPECT_FALSE(result_miss.has_value()); +} + +/** + * @brief Tests cache update operations (put with existing key). 
+ */ +TEST(CacheLibCacheTest, TestCacheUpdate) { + const int64_t EMBEDDING_DIM = 8; + const int64_t NUM_SHARDS = 4; + const int64_t CACHE_SIZE = 100 * 1024 * 1024; // 100MB + + CacheLibCache::CacheConfig config; + config.cache_size_bytes = CACHE_SIZE; + config.item_size_bytes = EMBEDDING_DIM * sizeof(float); + config.num_shards = NUM_SHARDS; + config.max_D_ = EMBEDDING_DIM; + config.use_object_cache = true; // ObjectCache mode + + auto cache = std::make_unique(config, 2 /* unique_tbe_id */); + + auto key = at::tensor({100}, at::TensorOptions().dtype(at::kLong)); + auto data1 = at::ones({EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)); + auto data2 = + at::ones({EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)) * 5.0; + + // Insert initial value + EXPECT_TRUE(cache->put(key, data1)); + + // Verify initial value + auto result1 = cache->get(key); + ASSERT_TRUE(result1.has_value()); + auto result1_tensor = at::from_blob( + result1.value(), {EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)); + EXPECT_TRUE(at::allclose(result1_tensor, data1)); + + // Update with new value + EXPECT_TRUE(cache->put(key, data2)); + + // Verify updated value + auto result2 = cache->get(key); + ASSERT_TRUE(result2.has_value()); + auto result2_tensor = at::from_blob( + result2.value(), {EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)); + EXPECT_TRUE(at::allclose(result2_tensor, data2)); +} + +/** + * @brief Tests cache usage statistics. + */ +TEST(CacheLibCacheTest, TestCacheUsageStats) { + const int64_t EMBEDDING_DIM = 8; + const int64_t NUM_SHARDS = 4; + const int64_t CACHE_SIZE = 100 * 1024 * 1024; // 100MB + + CacheLibCache::CacheConfig config; + config.cache_size_bytes = CACHE_SIZE; + config.item_size_bytes = EMBEDDING_DIM * sizeof(float); + config.num_shards = NUM_SHARDS; + config.max_D_ = EMBEDDING_DIM; + config.use_object_cache = false; // Regular mode + + auto cache = std::make_unique(config, 3 /* unique_tbe_id */); + + auto stats = cache->get_cache_usage(); + EXPECT_EQ(stats.size(), 2); // [freeBytes, capacity] + EXPECT_EQ(stats[1], CACHE_SIZE); // capacity should match config + EXPECT_GT(stats[0], 0); // should have some free bytes +} + +/** + * @brief Tests cache usage statistics with ObjectCache mode. 
+ */ +TEST(CacheLibCacheTest, TestCacheUsageStatsObjectCache) { + const int64_t EMBEDDING_DIM = 8; + const int64_t NUM_SHARDS = 4; + const int64_t CACHE_SIZE = 100 * 1024 * 1024; // 100MB + + CacheLibCache::CacheConfig config; + config.cache_size_bytes = CACHE_SIZE; + config.item_size_bytes = EMBEDDING_DIM * sizeof(float); + config.num_shards = NUM_SHARDS; + config.max_D_ = EMBEDDING_DIM; + config.use_object_cache = true; // ObjectCache mode + + auto cache = std::make_unique(config, 5 /* unique_tbe_id */); + + // Get stats before inserting any data + auto stats_before = cache->get_cache_usage(); + EXPECT_EQ(stats_before.size(), 2); // [freeBytes, capacity] + EXPECT_EQ(stats_before[1], CACHE_SIZE); // capacity should match config + // With object size tracking enabled, should report full cache as free + EXPECT_EQ(stats_before[0], CACHE_SIZE); // all bytes should be free initially + + // Insert some data + auto key = at::tensor({100}, at::TensorOptions().dtype(at::kLong)); + auto data = at::ones({EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat)); + EXPECT_TRUE(cache->put(key, data)); + + // Get stats after inserting data + auto stats_after = cache->get_cache_usage(); + EXPECT_EQ(stats_after.size(), 2); + EXPECT_EQ(stats_after[1], CACHE_SIZE); // capacity unchanged + // Free bytes should be less than before since we inserted data + EXPECT_LT(stats_after[0], stats_before[0]); + // Used memory should equal the data size we inserted + int64_t used_memory = CACHE_SIZE - stats_after[0]; + int64_t expected_used = EMBEDDING_DIM * sizeof(float); + EXPECT_EQ(used_memory, expected_used); +} + +/** + * @brief Tests cache with different data types. + */ +TEST(CacheLibCacheTest, TestDifferentDataTypes) { + const int64_t EMBEDDING_DIM = 16; + const int64_t NUM_SHARDS = 4; + const int64_t CACHE_SIZE = 100 * 1024 * 1024; // 100MB + + CacheLibCache::CacheConfig config; + config.cache_size_bytes = CACHE_SIZE; + config.item_size_bytes = EMBEDDING_DIM * sizeof(uint8_t); + config.num_shards = NUM_SHARDS; + config.max_D_ = EMBEDDING_DIM; + config.use_object_cache = true; // ObjectCache mode + + auto cache = std::make_unique(config, 6 /* unique_tbe_id */); + + // Test with int64 keys and uint8 data + auto key = at::tensor({100}, at::TensorOptions().dtype(at::kLong)); + auto data = at::randint( + 0, 255, {EMBEDDING_DIM}, at::TensorOptions().dtype(at::kByte)); + + EXPECT_TRUE(cache->put(key, data)); + + auto result = cache->get(key); + ASSERT_TRUE(result.has_value()); + auto result_tensor = at::from_blob( + result.value(), {EMBEDDING_DIM}, at::TensorOptions().dtype(at::kByte)); + EXPECT_TRUE(at::equal(result_tensor, data)); +} + +} // namespace l2_cache diff --git a/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py b/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py index 83d0351266..9e7aa2c8c8 100644 --- a/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py +++ b/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py @@ -72,6 +72,7 @@ def generate_fbgemm_kv_backend( eviction_policy: Optional[EvictionPolicy] = None, disable_random_init: bool = False, enable_blob_db: bool = False, + use_object_cache: bool = False, ) -> object: if backend_type == BackendType.SSD: assert ssd_directory @@ -100,6 +101,7 @@ def generate_fbgemm_kv_backend( flushing_block_size=flushing_block_size, disable_random_init=disable_random_init, enable_blob_db=enable_blob_db, + use_object_cache=use_object_cache, ) elif backend_type == BackendType.DRAM: eviction_config = None @@ -1451,3 +1453,52 @@ def test_ssd_blob_db_config( atol=1e-8, rtol=1e-8, ) + + @given(**default_st) + 
@settings(**default_settings)
+    def test_ssd_object_cache_config(
+        self,
+        T: int,
+        D: int,
+        log_E: int,
+        mixed: bool,
+        weights_precision: SparseType,
+    ) -> None:
+        """
+        Test that the SSD DB backend can set and get values as expected when
+        the use_object_cache config is turned on.
+        """
+        E = int(10**log_E)
+        max_D = D * 4
+
+        with tempfile.TemporaryDirectory() as ssd_directory:
+            ssd_object_cache = self.generate_fbgemm_kv_backend(
+                max_D=max_D,
+                weight_precision=weights_precision,
+                backend_type=BackendType.SSD,
+                ssd_directory=ssd_directory,
+                disable_random_init=True,
+                use_object_cache=True,
+            )
+
+            set_indices = torch.as_tensor(
+                np.random.choice(E, replace=False, size=(5,)), dtype=torch.int64
+            )
+            get_indices = set_indices.clone()
+            values = torch.randn(5, max_D, dtype=weights_precision.as_dtype())
+            count = torch.tensor([5], dtype=torch.int64)
+            # pyre-ignore
+            ssd_object_cache.set(set_indices, values, count)
+
+            # pyre-ignore
+            ssd_object_cache.wait_util_filling_work_done()
+
+            output = torch.empty_like(values)
+            # pyre-ignore
+            ssd_object_cache.get(get_indices, output, count)
+
+            torch.testing.assert_close(
+                output,
+                values,
+                atol=1e-8,
+                rtol=1e-8,
+            )
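
Usage sketch (illustrative, not part of the patch): the snippet below shows how a caller might exercise the new ObjectCache mode end to end, including the optional shared_ptr out-parameter on CacheLibCache::get() that keeps the returned row alive while it is copied out. It assumes the API as declared in cachelib_cache.h above; the key, row values, and the 64MB capacity are arbitrary choices for illustration.

#include <memory>
#include <ATen/ATen.h>
#include "fbgemm_gpu/split_embeddings_cache/cachelib_cache.h"

int main() {
  // Configure a small L2 cache with ObjectCache mode enabled.
  l2_cache::CacheLibCache::CacheConfig config;
  config.cache_size_bytes = 64 * 1024 * 1024; // 64MB cache capacity
  config.item_size_bytes = 8 * sizeof(float); // one row of 8 floats
  config.num_shards = 1;
  config.max_D_ = 8;
  config.use_object_cache = true; // the new flag introduced by this patch

  auto cache =
      std::make_unique<l2_cache::CacheLibCache>(config, /*unique_tbe_id=*/0);

  // Insert one embedding row keyed by an int64 index tensor.
  auto key = at::tensor({42}, at::TensorOptions().dtype(at::kLong));
  auto row = at::ones({8}, at::TensorOptions().dtype(at::kFloat));
  cache->put(key, row);

  // Hold the EmbeddingValue so the pointer returned by get() cannot be freed
  // underneath us while we copy the row out (the use-after-free scenario the
  // out-parameter is designed to prevent in multi-threaded lookups).
  std::shared_ptr<l2_cache::EmbeddingValue> holder;
  auto addr = cache->get(key, &holder);
  if (addr.has_value()) {
    auto copy =
        at::from_blob(
            addr.value(), {8}, at::TensorOptions().dtype(at::kFloat))
            .clone();
    (void)copy; // row data is now owned by the caller
  }
  return 0;
}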