diff --git a/src/EmbeddingSpMDMAutovec.cc b/src/EmbeddingSpMDMAutovec.cc
index 41446c7c57..0b59726c9d 100644
--- a/src/EmbeddingSpMDMAutovec.cc
+++ b/src/EmbeddingSpMDMAutovec.cc
@@ -538,6 +538,170 @@ static bool ALWAYS_INLINE EmbeddingSpMDMNBit_autovec(
   return current == index_size;
 }
 
+/// Autovectorizable kernel for N-bit (2- or 4-bit) row-wise-sparse embedding
+/// lookup: dequantizes the selected rows and accumulates them per output bag.
+///
+/// Each input row is `div_up(block_size, 8 / bit_rate)` bytes of packed
+/// elements followed by an fp16 scale and an fp16 bias. Indices are mapped
+/// through `compressed_indices_table`; entries equal to -1 are skipped.
+///
+/// Only bit_rate 2 and 4 are decoded; for any other value no row data is
+/// accumulated and the output rows stay zero.
+///
+/// @return false if `uncompressed_data_size` is negative, a bag runs past
+/// `index_size`, or an index is out of range; otherwise whether exactly
+/// `index_size` indices were consumed.
+template <typename IndexType, typename OffsetType>
+static bool ALWAYS_INLINE EmbeddingSpMDMNBitRowWiseSparse_autovec(
+    const int bit_rate,
+    const int64_t block_size,
+    const int64_t output_size,
+    const int64_t index_size,
+    const int64_t uncompressed_data_size,
+    const uint8_t* input,
+    const IndexType* indices,
+    const int32_t* compressed_indices_table,
+    const OffsetType* offsets_or_lengths,
+    const float* weights,
+    const bool normalize_by_lengths,
+    float* out,
+    const bool is_weight_positional,
+    const bool use_offsets) {
+  if (uncompressed_data_size < 0) {
+    return false;
+  }
+
+  // block_size is the number of elements and fused_block_size is the size in
+  // bytes of an entire row, including scale and bias.
+  const int num_elem_per_byte = 8 / bit_rate;
+  const int64_t scale_bias_size = 2 * sizeof(float16);
+  const uint64_t scale_bias_offset = div_up(block_size, num_elem_per_byte);
+  const int64_t fused_block_size = scale_bias_offset + scale_bias_size;
+
+  int64_t current = 0;
+  float* buf = out;
+  for (int64_t m = 0; m < output_size; ++m) {
+    const OffsetType len = use_offsets
+        ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
+        : offsets_or_lengths[m];
+    const int64_t end = current + len;
+    if (end > index_size) {
+      return false;
+    }
+
+    memset(buf, 0, sizeof(float) * block_size);
+
+    // Positional weights restart at weights[0] for every bag; otherwise they
+    // are indexed in lockstep with `indices`.
+    const float* weights_addr = weights != nullptr
+        ? (is_weight_positional ? weights : weights + current)
+        : nullptr;
+    for (; current < end; ++current) {
+      int64_t uncompressed_idx = indices[current];
+      if (uncompressed_idx < 0 || uncompressed_idx >= uncompressed_data_size) {
+        return false;
+      }
+      int64_t idx = compressed_indices_table[uncompressed_idx];
+      if (idx == -1) {
+        // Pruned row: skip it but keep the weight cursor in step. Never
+        // advance a null pointer (undefined behavior).
+        if (weights != nullptr) {
+          ++weights_addr;
+        }
+        continue;
+      }
+
+      const uint8_t* input_row_base = input + fused_block_size * idx;
+      const uint8_t* scale_bias_addr = input_row_base + scale_bias_offset;
+
+      float scale =
+          cpu_half2float(*reinterpret_cast<const float16*>(scale_bias_addr));
+      float bias = cpu_half2float(
+          *reinterpret_cast<const float16*>(scale_bias_addr + sizeof(float16)));
+
+      if (weights != nullptr) {
+        float weight = *weights_addr++;
+        scale *= weight;
+        bias *= weight;
+      }
+
+      const uint8_t* input_row = input_row_base;
+      if (bit_rate == 4) {
+        int64_t j = 0;
+#ifdef FBGEMM_VECTOR_WIDTH
+        for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 2));
+             j += 2) {
+          uint8_t tmp = *input_row++;
+          float quantized1 = float(tmp & 0xf);
+          float quantized2 = float(tmp >> 4);
+          buf[j] = std::fma(scale, quantized1, buf[j] + bias);
+          buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
+        }
+#endif
+        // Whole bytes only: stop before a byte that is only partly used.
+        for (; j + 2 <= block_size; j += 2) {
+          uint8_t tmp = *input_row++;
+          float quantized1 = float(tmp & 0xf);
+          float quantized2 = float(tmp >> 4);
+          buf[j] = std::fma(scale, quantized1, buf[j] + bias);
+          buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
+        }
+        // Odd block_size: only the low nibble of the last byte is an element.
+        // Writing buf[j + 1] here would overrun this output row.
+        if (j < block_size) {
+          const uint8_t tmp = *input_row;
+          buf[j] = std::fma(scale, float(tmp & 0xf), buf[j] + bias);
+        }
+      } else if (bit_rate == 2) {
+        int64_t j = 0;
+#ifdef FBGEMM_VECTOR_WIDTH
+        for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 4));
+             j += 4) {
+          uint8_t tmp = *input_row++;
+          float quantized1 = float(tmp & 0x3);
+          float quantized2 = float((tmp & 0xC) >> 2);
+          float quantized3 = float((tmp & 0x30) >> 4);
+          float quantized4 = float(tmp >> 6);
+          buf[j] = std::fma(scale, quantized1, buf[j] + bias);
+          buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
+          buf[j + 2] = std::fma(scale, quantized3, buf[j + 2] + bias);
+          buf[j + 3] = std::fma(scale, quantized4, buf[j + 3] + bias);
+        }
+#endif
+        // Whole bytes only: stop before a byte that is only partly used.
+        for (; j + 4 <= block_size; j += 4) {
+          uint8_t tmp = *input_row++;
+          float quantized1 = float(tmp & 0x3);
+          float quantized2 = float((tmp & 0xC) >> 2);
+          float quantized3 = float((tmp & 0x30) >> 4);
+          float quantized4 = float(tmp >> 6);
+          buf[j] = std::fma(scale, quantized1, buf[j] + bias);
+          buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
+          buf[j + 2] = std::fma(scale, quantized3, buf[j + 2] + bias);
+          buf[j + 3] = std::fma(scale, quantized4, buf[j + 3] + bias);
+        }
+        // block_size not a multiple of 4: decode only the elements that exist
+        // in the final byte so the writes stay inside this output row.
+        if (j < block_size) {
+          uint8_t tmp = *input_row;
+          for (; j < block_size; ++j) {
+            buf[j] = std::fma(scale, float(tmp & 0x3), buf[j] + bias);
+            tmp >>= 2;
+          }
+        }
+      }
+    }
+    if (normalize_by_lengths && len) {
+      float scale = 1.f / len;
+      for (int64_t j = 0; j < block_size; ++j) {
+        buf[j] *= scale;
+      }
+    }
+    buf += block_size;
+  }
+  return current == index_size;
+}
+
 /// @ingroup tbe-cpu-autovec
 ///
 /// Autovectorized version of method `EmbeddingSpMDM_ref` for FP32 weight type.
@@ -778,8 +942,8 @@ static bool ALWAYS_INLINE EmbeddingSpMDMRowWiseSparse_autovec(
   constexpr bool is8bit = std::is_same_v;
 
   if constexpr (is8bit) {
-    // block_size is the number of elements and fused_block_size is the size
-    // of an entire row, including scale and bias.
+    // block_size is the number of elements and fused_block_size is the size in
+    // bytes of an entire row, including scale and bias.
const auto scale_bias_offset = 2 * sizeof(float);
     const int64_t fused_block_size = block_size + scale_bias_offset;
     int64_t current = 0;
@@ -2153,6 +2281,105 @@ INSTANTIATE_SPMDM_INDEX_T(std::uint8_t)
 #undef INSTANTIATE_SPMDM_OUT_T
 #undef INSTANTIATE_SPMDM_BASE
 
+/// Creates the autovectorized N-bit row-wise-sparse embedding lookup kernel.
+///
+/// The returned callable forwards to `EmbeddingSpMDMNBitRowWiseSparse_autovec`.
+/// `bit_rate` is dispatched through `fixed(4)` / `fixed(2)`; every other option
+/// is dispatched as `var`.
+///
+/// @param bit_rate bits per quantized element; must be 2 or 4. This is asserted,
+/// and the function aborts if neither specialization matches.
+/// @param prefetch ignored by this implementation.
+template <typename IndexType, typename OffsetType>
+typename EmbeddingSpMDMRowWiseSparseKernelSignature<
+    uint8_t,
+    IndexType,
+    OffsetType>::Type
+GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec(
+    int bit_rate,
+    int64_t block_size,
+    bool has_weight,
+    bool normalize_by_lengths,
+    [[maybe_unused]] int prefetch,
+    bool is_weight_positional,
+    bool use_offsets) {
+  assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
+  using specialization_helper::fixed;
+  using specialization_helper::match;
+  using specialization_helper::specialize;
+  using specialization_helper::var;
+
+#define SPECIALIZE( \
+    BIT_RATE, \
+    BLOCK_SIZE, \
+    HAS_WEIGHT, \
+    NORMALIZE_BY_LENGTHS, \
+    IS_WEIGHT_POSITIONAL, \
+    USE_OFFSETS) \
+  if (match(BIT_RATE, bit_rate) && match(BLOCK_SIZE, block_size) && \
+      match(HAS_WEIGHT, has_weight) && \
+      match(NORMALIZE_BY_LENGTHS, normalize_by_lengths) && \
+      match(IS_WEIGHT_POSITIONAL, is_weight_positional) && \
+      match(USE_OFFSETS, use_offsets)) { \
+    return [=](int64_t output_size, \
+               int64_t index_size, \
+               int64_t uncompressed_data_size, \
+               const uint8_t* input, \
+               const IndexType* indices, \
+               const OffsetType* offsets_or_lengths, \
+               const float* weights, \
+               float* out, \
+               const int32_t* compressed_indices_table) { \
+      if (specialize(HAS_WEIGHT, has_weight)) { \
+        __builtin_assume(weights != nullptr); \
+      } else { \
+        weights = nullptr; \
+      } \
+      return EmbeddingSpMDMNBitRowWiseSparse_autovec( \
+          /*bit_rate=*/specialize(BIT_RATE, bit_rate), \
+          /*block_size=*/specialize(BLOCK_SIZE, block_size), \
+          /*output_size=*/output_size, \
+          /*index_size=*/index_size, \
+          /*uncompressed_data_size=*/uncompressed_data_size, \
+          /*input=*/input, \
+          /*indices=*/indices, \
+          /*compressed_indices_table=*/compressed_indices_table, \
+          /*offsets_or_lengths=*/offsets_or_lengths, \
+          /*weights=*/weights, /*normalize_by_lengths=*/ \
+          specialize(NORMALIZE_BY_LENGTHS, normalize_by_lengths), \
+          /*out=*/out, /*is_weight_positional=*/ \
+          specialize(IS_WEIGHT_POSITIONAL, is_weight_positional), \
+          /*use_offsets=*/specialize(USE_OFFSETS, use_offsets)); \
+    }; \
+  }
+
+  SPECIALIZE(/*BIT_RATE*/ fixed(4), var, var, var, var, var);
+  SPECIALIZE(/*BIT_RATE*/ fixed(2), var, var, var, var, var);
+  abort(); // should not get here
+#undef SPECIALIZE
+}
+
+#define INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(INDEX_TYPE, OFFSET_TYPE) \
+  template typename EmbeddingSpMDMRowWiseSparseKernelSignature< \
+      uint8_t, \
+      INDEX_TYPE, \
+      OFFSET_TYPE>::Type \
+  GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec<INDEX_TYPE, OFFSET_TYPE>( \
+      int bit_rate, \
+      int64_t block_size, \
+      bool has_weight, \
+      bool normalize_by_lengths, \
+      int prefetch, \
+      bool is_weight_positional, \
+      bool use_offsets);
+
+INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int32_t, int32_t)
+INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int32_t, int64_t)
+INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int64_t, int32_t)
+INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int64_t, int64_t)
+
+#undef INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE
+
 } // namespace fbgemm
 
 #endif // #ifdef __linux__
diff --git a/src/EmbeddingSpMDMAutovec.h b/src/EmbeddingSpMDMAutovec.h
index 0f590552fc..81d7d90d35 100644
--- a/src/EmbeddingSpMDMAutovec.h
+++ b/src/EmbeddingSpMDMAutovec.h
@@ -92,6 +92,22 @@ GenerateEmbeddingSpMDMRowWiseSparse_autovec(
     bool is_weight_positional,
     bool use_offsets);
 
+/// Autovectorized generator for N-bit (2- or 4-bit) row-wise-sparse embedding
+/// lookup kernels. `prefetch` is not used by the implementation.
+template <typename IndexType, typename OffsetType>
+typename EmbeddingSpMDMRowWiseSparseKernelSignature<
+    uint8_t,
+    IndexType,
+    OffsetType>::Type
+GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec(
+    int bit_rate,
+    int64_t block_size,
+    bool has_weight,
+    bool normalize_by_lengths,
+    int prefetch,
+    bool is_weight_positional,
+    bool use_offsets);
+
 } // namespace fbgemm
 
 #endif // #ifdef __linux__
diff --git a/src/EmbeddingSpMDMNBit.cc
b/src/EmbeddingSpMDMNBit.cc index a732caa5f3..3db324ea76 100644 --- a/src/EmbeddingSpMDMNBit.cc +++ b/src/EmbeddingSpMDMNBit.cc @@ -1296,6 +1296,25 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse( } #endif // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +#ifdef FBGEMM_AUTOVEC_AVAILABLE + if (!cpuinfo_initialize()) { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + if ((fbgemmHasArmSve2Support() && !is_autovec_disabled()) || + is_autovec_forced()) { + return GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec< + /*IndexType=*/indxType, + /*OffsetType=*/offsetType>( + /*bit_rate=*/bit_rate, + /*block_size=*/block_size, + /*has_weight=*/has_weight, + /*normalize_by_lengths=*/normalize_by_lengths, + /*prefetch=*/prefetch, + /*is_weight_positional=*/is_weight_positional, + /*use_offsets=*/use_offsets); + } +#endif + #ifdef VLOG VLOG(0) << "AVX2 or AVX512 not found, taking the slow path"; #endif