Commit 814843e

Enable bitsandbytes quantization on AMD GPUs that use warp size 32 (#27307)
Signed-off-by: sstamenk <strahinja.stamenkovic@amd.com>
1 parent: 20852c8

2 files changed (+10 -4 lines)

tests/models/quantization/test_bitsandbytes.py

Lines changed: 7 additions & 4 deletions
@@ -14,10 +14,13 @@
 from ...utils import compare_two_settings, multi_gpu_test
 from ..utils import check_embeddings_close, check_logprobs_close
 
-pytestmark = pytest.mark.skipif(
-    current_platform.is_rocm(),
-    reason="bitsandbytes quantization not supported on ROCm (CUDA-only kernels)",
-)
+if current_platform.is_rocm():
+    from vllm.platforms.rocm import on_gfx9
+
+    pytestmark = pytest.mark.skipif(
+        on_gfx9(),
+        reason="bitsandbytes not supported on gfx9 (warp size 64 limitation)",
+    )
 
 models_4bit_to_test = [
     ("facebook/opt-125m", "quantize opt model inflight"),

vllm/platforms/rocm.py

Lines changed: 3 additions & 0 deletions
@@ -185,6 +185,9 @@ class RocmPlatform(Platform):
         "petit_nvfp4",
         "torchao",
     ]
+    # bitsandbytes not supported on gfx9 (warp size 64 limitation)
+    if not on_gfx9():
+        supported_quantization += ["bitsandbytes"]
 
     @classmethod
     def get_vit_attn_backend(
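
With the platform gate in place, warp-size-32 AMD GPUs (e.g. RDNA cards) accept "bitsandbytes" as a quantization method, while gfx9 still rejects it. A minimal usage sketch, assuming a vLLM build containing this commit and a bitsandbytes install with ROCm support:

from vllm import LLM

# In-flight 4-bit quantization of the same model the test above uses.
# Older vLLM versions may also require load_format="bitsandbytes".
llm = LLM(model="facebook/opt-125m", quantization="bitsandbytes")
out = llm.generate("Hello, my name is")
print(out[0].outputs[0].text)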
