222 changes: 220 additions & 2 deletions src/EmbeddingSpMDMAutovec.cc
@@ -538,6 +538,134 @@ static bool ALWAYS_INLINE EmbeddingSpMDMNBit_autovec(
return current == index_size;
}

template <typename IndexType, typename OffsetType>
static bool ALWAYS_INLINE EmbeddingSpMDMNBitRowWiseSparse_autovec(
const int bit_rate,
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t uncompressed_data_size,
const uint8_t* input,
const IndexType* indices,
const int32_t* compressed_indices_table,
const OffsetType* offsets_or_lengths,
const float* weights,
const bool normalize_by_lengths,
float* out,
const bool is_weight_positional,
const bool use_offsets) {
if (uncompressed_data_size < 0) {
return false;
}

// block_size is the number of elements and fused_block_size is the size in
// bytes of an entire row, including scale and bias.
const int num_elem_per_byte = 8 / bit_rate;
const int64_t scale_bias_size = 2 * sizeof(float16);
const uint64_t scale_bias_offset = div_up(block_size, num_elem_per_byte);
const int64_t fused_block_size = scale_bias_offset + scale_bias_size;
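// Packed row layout: ceil(block_size / num_elem_per_byte) bytes of quantized
// values, followed by an fp16 scale and an fp16 bias.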

int64_t current = 0;
float* buf = out;
for (int64_t m = 0; m < output_size; ++m) {
const OffsetType len = use_offsets
? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
: offsets_or_lengths[m];
const int64_t end = current + len;
if (end > index_size) {
return false;
}

memset(buf, 0, sizeof(float) * block_size);

const float* weights_addr = weights != nullptr
? (is_weight_positional ? weights : weights + current)
: nullptr;
for (; current < end; ++current) {
int64_t uncompressed_idx = indices[current];
if (uncompressed_idx < 0 || uncompressed_idx >= uncompressed_data_size) {
return false;
}
int64_t idx = compressed_indices_table[uncompressed_idx];
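// A -1 entry in the compressed-indices table marks a pruned row: skip its
// contribution but keep the optional weight stream in step.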
if (idx == -1) {
if (weights != nullptr) {
weights_addr++;
}
continue;
}

const uint8_t* input_row_base = input + fused_block_size * idx;
const uint8_t* scale_bias_addr = input_row_base + scale_bias_offset;

float scale =
cpu_half2float(*reinterpret_cast<const float16*>(scale_bias_addr));
float bias = cpu_half2float(
*reinterpret_cast<const float16*>(scale_bias_addr + sizeof(float16)));

if (weights != nullptr) {
float weight = *weights_addr++;
scale *= weight;
bias *= weight;
}

const uint8_t* input_row = input_row_base;
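// For both bit rates, the first loop below has a trip count that is a
// multiple of the vector width so the compiler can auto-vectorize the
// unpack-and-FMA body; the second loop handles the scalar remainder. Each
// input byte packs num_elem_per_byte quantized values, lowest bits first.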
if (bit_rate == 4) {
int64_t j = 0;
#ifdef FBGEMM_VECTOR_WIDTH
for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 2));
j += 2) {
uint8_t tmp = *input_row++;
float quantized1 = float(tmp & 0xf);
float quantized2 = float(tmp >> 4);
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
}
#endif
for (; j < block_size; j += 2) {
uint8_t tmp = *input_row++;
float quantized1 = float(tmp & 0xf);
float quantized2 = float(tmp >> 4);
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
}
} else if (bit_rate == 2) {
int64_t j = 0;
#ifdef FBGEMM_VECTOR_WIDTH
for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 4));
j += 4) {
uint8_t tmp = *input_row++;
float quantized1 = float(tmp & 0x3);
float quantized2 = float((tmp & 0xC) >> 2);
float quantized3 = float((tmp & 0x30) >> 4);
float quantized4 = float(tmp >> 6);
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
buf[j + 2] = std::fma(scale, quantized3, buf[j + 2] + bias);
buf[j + 3] = std::fma(scale, quantized4, buf[j + 3] + bias);
}
#endif
for (; j < block_size; j += 4) {
uint8_t tmp = *input_row++;
float quantized1 = float(tmp & 0x3);
float quantized2 = float((tmp & 0xC) >> 2);
float quantized3 = float((tmp & 0x30) >> 4);
float quantized4 = float(tmp >> 6);
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
buf[j + 2] = std::fma(scale, quantized3, buf[j + 2] + bias);
buf[j + 3] = std::fma(scale, quantized4, buf[j + 3] + bias);
}
}
}
if (normalize_by_lengths && len) {
float scale = 1.f / len;
for (int j = 0; j < block_size; ++j) {
buf[j] *= scale;
}
}
buf += block_size;
}
return current == index_size;
}

/// @ingroup tbe-cpu-autovec
///
/// Autovectorized version of method `EmbeddingSpMDM_ref` for FP32 weight type.
@@ -778,8 +906,8 @@ static bool ALWAYS_INLINE EmbeddingSpMDMRowWiseSparse_autovec(
constexpr bool is8bit = std::is_same_v<InType, uint8_t>;

if constexpr (is8bit) {
-  // block_size is the number of elements and fused_block_size is the size
-  // of an entire row, including scale and bias.
+  // block_size is the number of elements and fused_block_size is the size in
+  // bytes of an entire row, including scale and bias.
const auto scale_bias_offset = 2 * sizeof(float);
const int64_t fused_block_size = block_size + scale_bias_offset;
int64_t current = 0;
@@ -2153,6 +2281,96 @@ INSTANTIATE_SPMDM_INDEX_T(std::uint8_t)
#undef INSTANTIATE_SPMDM_OUT_T
#undef INSTANTIATE_SPMDM_BASE

template <typename IndexType, typename OffsetType>
typename EmbeddingSpMDMRowWiseSparseKernelSignature<
uint8_t,
IndexType,
OffsetType>::Type
GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec(
int bit_rate,
int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
[[maybe_unused]] int prefetch,
bool is_weight_positional,
bool use_offsets) {
assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
using specialization_helper::fixed;
using specialization_helper::match;
using specialization_helper::specialize;
using specialization_helper::var;
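// SPECIALIZE stamps out one lambda per parameter combination: arguments
// wrapped in fixed(...) are baked into the kernel as compile-time constants
// via specialize(), while var leaves them as plain runtime values.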

#define SPECIALIZE( \
BIT_RATE, \
BLOCK_SIZE, \
HAS_WEIGHT, \
NORMALIZE_BY_LENGTHS, \
IS_WEIGHT_POSITIONAL, \
USE_OFFSETS) \
if (match(BIT_RATE, bit_rate) && match(BLOCK_SIZE, block_size) && \
match(HAS_WEIGHT, has_weight) && \
match(NORMALIZE_BY_LENGTHS, normalize_by_lengths) && \
match(IS_WEIGHT_POSITIONAL, is_weight_positional) && \
match(USE_OFFSETS, use_offsets)) { \
return [=](int64_t output_size, \
int64_t index_size, \
int64_t uncompressed_data_size, \
const uint8_t* input, \
const IndexType* indices, \
const OffsetType* offsets_or_lengths, \
const float* weights, \
float* out, \
const int32_t* compressed_indices_table) { \
if (specialize(HAS_WEIGHT, has_weight)) { \
__builtin_assume(weights != nullptr); \
} else { \
weights = nullptr; \
} \
return EmbeddingSpMDMNBitRowWiseSparse_autovec( \
/*bit_rate=*/specialize(BIT_RATE, bit_rate), \
/*block_size=*/specialize(BLOCK_SIZE, block_size), \
/*output_size=*/output_size, \
/*index_size=*/index_size, \
/*uncompressed_data_size=*/uncompressed_data_size, \
/*input=*/input, \
/*indices=*/indices, \
/*compressed_indices_table=*/compressed_indices_table, \
/*offsets_or_lengths=*/offsets_or_lengths, \
/*weights=*/weights, /*normalize_by_lengths=*/ \
specialize(NORMALIZE_BY_LENGTHS, normalize_by_lengths), \
/*out=*/out, /*is_weight_positional=*/ \
specialize(IS_WEIGHT_POSITIONAL, is_weight_positional), \
/*use_offsets=*/specialize(USE_OFFSETS, use_offsets)); \
}; \
}

SPECIALIZE(/*BIT_RATE*/ fixed(4), var, var, var, var, var);
SPECIALIZE(/*BIT_RATE*/ fixed(2), var, var, var, var, var);
abort(); // should not get here
#undef SPECIALIZE
}

#define INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(INDEX_TYPE, OFFSET_TYPE) \
template typename EmbeddingSpMDMRowWiseSparseKernelSignature< \
uint8_t, \
INDEX_TYPE, \
OFFSET_TYPE>::Type \
GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec<INDEX_TYPE, OFFSET_TYPE>( \
int bit_rate, \
int64_t block_size, \
bool has_weight, \
bool normalize_by_lengths, \
int prefetch, \
bool is_weight_positional, \
bool use_offsets);

INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int32_t, int32_t)
INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int32_t, int64_t)
INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int64_t, int32_t)
INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int64_t, int64_t)

#undef INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE

} // namespace fbgemm

#endif // #ifdef __linux__
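Reviewer note: a minimal, hypothetical caller-side sketch of the new factory and the kernel it returns. Variable names, sizes, and buffers below are placeholders and not part of this change; production code reaches this path through the dispatch added in EmbeddingSpMDMNBit.cc.

// Hypothetical usage sketch; all names and sizes are illustrative.
auto kernel =
    GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec<int64_t, int32_t>(
        /*bit_rate=*/4,
        /*block_size=*/64,
        /*has_weight=*/false,
        /*normalize_by_lengths=*/true,
        /*prefetch=*/16,
        /*is_weight_positional=*/false,
        /*use_offsets=*/true);
bool ok = kernel(
    /*output_size=*/batch_size,
    /*index_size=*/num_indices,
    /*uncompressed_data_size=*/num_uncompressed_rows,
    /*input=*/packed_rows.data(),
    /*indices=*/indices.data(),
    /*offsets_or_lengths=*/offsets.data(),
    /*weights=*/nullptr,
    /*out=*/output.data(),
    /*compressed_indices_table=*/dense_to_compressed_map.data());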
14 changes: 14 additions & 0 deletions src/EmbeddingSpMDMAutovec.h
@@ -92,6 +92,20 @@ GenerateEmbeddingSpMDMRowWiseSparse_autovec(
bool is_weight_positional,
bool use_offsets);

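// Auto-vectorized factory for 2-/4-bit row-wise sparse embedding lookups;
// the implementation lives in EmbeddingSpMDMAutovec.cc.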
template <typename IndexType, typename OffsetType>
typename EmbeddingSpMDMRowWiseSparseKernelSignature<
uint8_t,
IndexType,
OffsetType>::Type
GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec(
int bit_rate,
int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
int prefetch,
bool is_weight_positional,
bool use_offsets);

} // namespace fbgemm

#endif // #ifdef __linux__
19 changes: 19 additions & 0 deletions src/EmbeddingSpMDMNBit.cc
@@ -1296,6 +1296,25 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse(
}
#endif // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64

#ifdef FBGEMM_AUTOVEC_AVAILABLE
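// Prefer the auto-vectorized kernel on CPUs with Arm SVE2 support, or
// whenever autovec is explicitly forced; otherwise fall through to the
// reference path below.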
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
if ((fbgemmHasArmSve2Support() && !is_autovec_disabled()) ||
is_autovec_forced()) {
return GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec<
/*IndexType=*/indxType,
/*OffsetType=*/offsetType>(
/*bit_rate=*/bit_rate,
/*block_size=*/block_size,
/*has_weight=*/has_weight,
/*normalize_by_lengths=*/normalize_by_lengths,
/*prefetch=*/prefetch,
/*is_weight_positional=*/is_weight_positional,
/*use_offsets=*/use_offsets);
}
#endif

#ifdef VLOG
VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
#endif