Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,26 @@
#include <ATen/ATen.h>
#include <cachelib/allocator/CacheAllocator.h>
#include <cachelib/facebook/admin/CacheAdmin.h>
#include <cachelib/object_cache/ObjectCache.h>

#include <cstdint>
#include <iostream>
#include <vector>

namespace l2_cache {

/// Owning byte buffer that holds one embedding value when it is stored
/// through ObjectCache.
struct EmbeddingValue {
  /// Raw bytes of the embedding.
  std::vector<uint8_t> data;

  /// Creates a value with no bytes.
  EmbeddingValue() = default;

  /// Creates a value made of `size` zero-initialized bytes.
  explicit EmbeddingValue(size_t size) {
    data.resize(size);
  }

  /// Creates a value by copying `size` bytes starting at `ptr`.
  /// The caller must make sure [ptr, ptr + size) is readable.
  explicit EmbeddingValue(const void* ptr, size_t size) {
    const auto* first = static_cast<const uint8_t*>(ptr);
    const auto* last = first + size;
    data.assign(first, last);
  }
};

/// @ingroup embedding-ssd
///
/// @brief A Cachelib wrapper class for Cachelib interaction
Expand All @@ -27,16 +41,26 @@ namespace l2_cache {
/// Cachelib related optimization will be captured inside this class
/// e.g. fetch and delayed markUseful to boost up get performance
///
/// Supports both raw allocator mode (default) and object cache mode.
/// Object cache mode enables optimizations for object-oriented access
/// patterns and can be enabled via CacheConfig.use_object_cache flag.
///
/// @note This class only handles a single Cachelib read/update;
/// parallelism is handled on the caller side.
class CacheLibCache {
public:
using Cache = facebook::cachelib::LruAllocator;
using ObjectCache = facebook::cachelib::objcache2::ObjectCache<Cache>;

// Configuration bundle taken by initializeCacheLib(const CacheConfig&).
// Only use_object_cache has a default member initializer; the other fields
// must be set by whoever builds the config.
struct CacheConfig {
// NOTE(review): presumably the total cache capacity in bytes - confirm
// against initializeCacheLib().
size_t cache_size_bytes;
// NOTE(review): presumably the size in bytes of one cached item - confirm.
size_t item_size_bytes;
// NOTE(review): presumably the number of shards/pools to create - confirm.
size_t num_shards;
// NOTE(review): looks like the maximum embedding dimension; units and usage
// are not visible here - confirm.
int64_t max_D_;
// Enable object cache mode for optimized object-oriented access patterns.
// When enabled, activates background item reaper and other object cache
// specific optimizations.
// Defaults to false, i.e. object cache mode is off unless requested.
bool use_object_cache = false;
};

explicit CacheLibCache(
Expand All @@ -48,11 +72,23 @@ class CacheLibCache {
std::unique_ptr<Cache> initializeCacheLib(const CacheConfig& config);

std::unique_ptr<facebook::cachelib::CacheAdmin> createCacheAdmin(
Cache& cache);
Cache& cache,
bool is_object_cache);

// Template overload to accept ObjectCache
template <typename CacheType>
std::unique_ptr<facebook::cachelib::CacheAdmin> createCacheAdmin(
CacheType& cache,
bool is_object_cache);

/// Find the stored embedding for a given embedding index, aka key
///
/// @param key_tensor embedding index (a tensor with only one element) to
/// look up
/// @param object_cache_value_out Optional output parameter for ObjectCache
/// mode. If provided and ObjectCache is enabled, it receives the shared_ptr
/// that keeps the value alive. This prevents use-after-free in
/// multi-threaded contexts.
///
/// @return an optional value, return none on cache misses, if cache hit
/// return a pointer to the cachelib underlying storage of associated
Expand All @@ -61,7 +97,9 @@ class CacheLibCache {
/// @note that this is not thread safe, caller needs to make sure the data is
/// fully processed before doing cache insertion, otherwise the returned space
/// might be overwritten if cache is full
/// @note object_cache_value_out defaults to nullptr, so single-argument calls
/// `get(key_tensor)` resolve to this one declaration. A separate
/// one-parameter overload must not be declared alongside it, otherwise such
/// calls become ambiguous.
folly::Optional<void*> get(
    const at::Tensor& key_tensor,
    std::shared_ptr<EmbeddingValue>* object_cache_value_out = nullptr);

/// Cachelib wrapper specific hash function
///
Expand Down Expand Up @@ -152,12 +190,23 @@ class CacheLibCache {
std::unique_ptr<Cache> cache_;
std::vector<facebook::cachelib::PoolId> pool_ids_;
std::unique_ptr<facebook::cachelib::CacheAdmin> admin_;
std::unique_ptr<ObjectCache> object_cache_;

folly::Optional<at::Tensor> evicted_indices_opt_{folly::none};
folly::Optional<at::Tensor> evicted_weights_opt_{folly::none};
folly::Optional<at::ScalarType> index_dtype_{folly::none};
folly::Optional<at::ScalarType> weights_dtype_{folly::none};
std::atomic<int64_t> eviction_row_id{0};

// Keep the last retrieved ObjectCache value alive to prevent use-after-free
std::shared_ptr<EmbeddingValue> last_retrieved_value_;

// Helper methods for ObjectCache mode
// Only declared here; their definitions are not part of this header excerpt.

// NOTE(review): rough_num_items is presumably an estimate of how many items
// the object cache should be sized for - confirm against the definition.
void initializeObjectCacheInternal(int64_t rough_num_items);
// Takes the same parameters as the public get(); unlike get(),
// object_cache_value_out has no default here and must always be passed
// (nullptr is the default used by get()).
folly::Optional<void*> getFromObjectCache(
const at::Tensor& key_tensor,
std::shared_ptr<EmbeddingValue>* object_cache_value_out);
// NOTE(review): the bool result presumably reports whether the insert
// succeeded - confirm against the definition.
bool putToObjectCache(const at::Tensor& key_tensor, const at::Tensor& data);
};

} // namespace l2_cache
Loading
Loading