From 851d03eebff47306a98e640bd101d8967c4f4b64 Mon Sep 17 00:00:00 2001 From: Matthias Braun Date: Tue, 6 Jan 2026 20:09:39 -0800 Subject: [PATCH] Specialize more cases to improve EmbeddingSpMDMNBitBenchmark (#5245) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2236 `EmbeddingSpMDMNBitBenchmark` uses the `scale_bias_last == true && has_weight == true` variant, which wasn't previously specialized in the code because we did not see it in the bigger test suites. However, adding it does not cost too much code size and makes this benchmark look better. Also added block_size==576, block_size==36 and block_size==72 specializations, which were seen in some other models. This adds ~50K in code size (compared to the current ~150K for the existing specializations), which seems acceptable. Reviewed By: excelle08 Differential Revision: D87289832 --- bench/EmbeddingSpMDMNBitBenchmark.cc | 1 + src/EmbeddingSpMDMAutovec.cc | 100 +++++++++++++++++++++++---- 2 files changed, 86 insertions(+), 15 deletions(-) diff --git a/bench/EmbeddingSpMDMNBitBenchmark.cc b/bench/EmbeddingSpMDMNBitBenchmark.cc index ebfcc334c7..dca1caf891 100644 --- a/bench/EmbeddingSpMDMNBitBenchmark.cc +++ b/bench/EmbeddingSpMDMNBitBenchmark.cc @@ -486,6 +486,7 @@ static int run_benchmark( cout << ", asmjit speedup, " << t_ref / t; #endif cout << '\n'; + cout.flush(); } // flush_cache } // has_weight return 0; diff --git a/src/EmbeddingSpMDMAutovec.cc b/src/EmbeddingSpMDMAutovec.cc index 41446c7c57..afa52e242f 100644 --- a/src/EmbeddingSpMDMAutovec.cc +++ b/src/EmbeddingSpMDMAutovec.cc @@ -1302,6 +1302,7 @@ typename EmbeddingSpMDMKernelSignature:: PREFETCH, \ IS_WEIGHT_POSITIONAL, \ USE_OFFSETS, \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ IS_BF16_IN) \ @@ -1314,7 +1315,7 @@ typename EmbeddingSpMDMKernelSignature:: USE_OFFSETS, \ /*OUTPUT_STRIDE*/ var, \ /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(4, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ 
IS_BF16_IN) \ @@ -1327,7 +1328,7 @@ typename EmbeddingSpMDMKernelSignature:: USE_OFFSETS, \ /*OUTPUT_STRIDE*/ var, \ /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(24, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ IS_BF16_IN) \ @@ -1340,7 +1341,20 @@ typename EmbeddingSpMDMKernelSignature:: USE_OFFSETS, \ /*OUTPUT_STRIDE*/ var, \ /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(32, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ + SCALE_BIAS_LAST, \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{36}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(36, false)), \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ IS_BF16_IN) \ @@ -1353,7 +1367,20 @@ typename EmbeddingSpMDMKernelSignature:: USE_OFFSETS, \ /*OUTPUT_STRIDE*/ var, \ /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(64, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ + SCALE_BIAS_LAST, \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{72}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(72, false)), \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ IS_BF16_IN) \ @@ -1366,7 +1393,7 @@ typename EmbeddingSpMDMKernelSignature:: USE_OFFSETS, \ /*OUTPUT_STRIDE*/ var, \ /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(96, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ IS_BF16_IN) \ @@ -1379,7 +1406,7 @@ typename EmbeddingSpMDMKernelSignature:: USE_OFFSETS, \ /*OUTPUT_STRIDE*/ var, \ /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(124, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ IS_BF16_IN) \ @@ -1392,7 +1419,7 @@ typename 
EmbeddingSpMDMKernelSignature:: USE_OFFSETS, \ /*OUTPUT_STRIDE*/ var, \ /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(128, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ IS_BF16_IN) \ @@ -1405,7 +1432,7 @@ typename EmbeddingSpMDMKernelSignature:: USE_OFFSETS, \ /*OUTPUT_STRIDE*/ var, \ /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(252, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ IS_BF16_IN) \ @@ -1418,7 +1445,7 @@ typename EmbeddingSpMDMKernelSignature:: USE_OFFSETS, \ /*OUTPUT_STRIDE*/ var, \ /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(256, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ IS_BF16_IN) \ @@ -1431,7 +1458,7 @@ typename EmbeddingSpMDMKernelSignature:: USE_OFFSETS, \ /*OUTPUT_STRIDE*/ var, \ /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(320, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ IS_BF16_IN) \ @@ -1444,7 +1471,7 @@ typename EmbeddingSpMDMKernelSignature:: USE_OFFSETS, \ /*OUTPUT_STRIDE*/ var, \ /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(384, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ IS_BF16_IN) \ @@ -1457,7 +1484,7 @@ typename EmbeddingSpMDMKernelSignature:: USE_OFFSETS, \ /*OUTPUT_STRIDE*/ var, \ /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(508, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ IS_BF16_IN) \ @@ -1470,7 +1497,20 @@ typename EmbeddingSpMDMKernelSignature:: USE_OFFSETS, \ /*OUTPUT_STRIDE*/ var, \ /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(512, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ + SCALE_BIAS_LAST, \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{576}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ 
var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(576, false)), \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ IS_BF16_IN) \ @@ -1483,7 +1523,7 @@ typename EmbeddingSpMDMKernelSignature:: USE_OFFSETS, \ /*OUTPUT_STRIDE*/ var, \ /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(768, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ IS_BF16_IN) \ @@ -1496,7 +1536,7 @@ typename EmbeddingSpMDMKernelSignature:: USE_OFFSETS, \ /*OUTPUT_STRIDE*/ var, \ /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(1024, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ + SCALE_BIAS_LAST, \ NO_BAG, \ IS_BF16_OUT, \ IS_BF16_IN) @@ -1508,6 +1548,27 @@ typename EmbeddingSpMDMKernelSignature:: /*PREFETCH*/ var, /*IS_WEIGHT_POSITIONAL*/ fixed(false), /*USE_OFFSETS*/ fixed(true), + /*SCALE_BIAS_LAST*/ fixed(false), + /*NO_BAG*/ fixed(false), + /*IS_BF16_OUT*/ var, + /*IS_BF16_IN*/ var) + SPECIALIZE_BLOCK_SIZE( + /*HAS_WEIGHT*/ fixed(false), + /*NORMALIZE_BY_LENGTHS*/ fixed(false), + /*PREFETCH*/ var, + /*IS_WEIGHT_POSITIONAL*/ fixed(false), + /*USE_OFFSETS*/ fixed(true), + /*SCALE_BIAS_LAST*/ fixed(false), + /*NO_BAG*/ fixed(false), + /*IS_BF16_OUT*/ var, + /*IS_BF16_IN*/ var) + SPECIALIZE_BLOCK_SIZE( + /*HAS_WEIGHT*/ fixed(true), + /*NORMALIZE_BY_LENGTHS*/ fixed(false), + /*PREFETCH*/ var, + /*IS_WEIGHT_POSITIONAL*/ fixed(false), + /*USE_OFFSETS*/ fixed(true), + /*SCALE_BIAS_LAST*/ fixed(true), /*NO_BAG*/ fixed(false), /*IS_BF16_OUT*/ var, /*IS_BF16_IN*/ var) @@ -1517,6 +1578,7 @@ typename EmbeddingSpMDMKernelSignature:: /*PREFETCH*/ var, /*IS_WEIGHT_POSITIONAL*/ fixed(false), /*USE_OFFSETS*/ fixed(true), + /*SCALE_BIAS_LAST*/ fixed(true), /*NO_BAG*/ fixed(false), /*IS_BF16_OUT*/ var, /*IS_BF16_IN*/ var) @@ -1874,6 +1936,14 @@ GenerateEmbeddingSpMDMNBitWithStrides_autovec( /*SCALE_BIAS_LAST*/ fixed(false), /*IS_BF16_OUT*/ var, /*NO_BAG*/ fixed(false)) + SPECIALIZE_INPUT_RATE( + /*HAS_WEIGHT*/ fixed(true), + /*NORMALIZE_BY_LENGTHS*/ fixed(false), 
+ /*IS_WEIGHT_POSITIONAL*/ fixed(false), + /*USE_OFFSETS*/ fixed(true), + /*SCALE_BIAS_LAST*/ fixed(true), + /*IS_BF16_OUT*/ var, + /*NO_BAG*/ fixed(false)) SPECIALIZE_INPUT_RATE( /*HAS_WEIGHT*/ fixed(false), /*NORMALIZE_BY_LENGTHS*/ fixed(false),