[feat](moe_kernel): add amd blis support #1600
base: main
Conversation
Summary of Changes

Hello @KMSorSMS, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly expands the project's hardware support by integrating AMD BLIS for optimized Mixture of Experts (MoE) kernel operations. It introduces new INT4 and INT8 quantized MoE kernel implementations and refines the build system to provide better configurability and a more robust developer experience, particularly for CUDA-enabled builds.
Code Review
This pull request introduces support for AMD BLIS within the MoE kernel, adding a new CMake preset for AMD platforms, a Python wrapper for the generic MoE kernel, and several enhancements to the build system. My review has identified a few critical issues and areas for improvement. There's a misconfiguration in the new 'amd' CMake preset that could cause runtime errors. A significant bug was found in experts.py where incorrect classes are being used for the new MoE methods. The new moe_kernel.py file contains several copy-paste errors from the AMX implementation, is missing an implementation for a required abstract method, and has some code quality issues. On the other hand, the build system improvements in setup.py, such as automatic nvcc detection and build directory cleaning, are valuable additions.
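The setup.py improvements called out above (automatic nvcc detection and build-directory cleaning) are easy to picture with a short sketch. The helper names, the environment variables checked, and the `build/` path below are illustrative assumptions, not the PR's actual implementation:

```python
import os
import shutil

def find_nvcc():
    """Locate nvcc on PATH or under CUDA_HOME/CUDA_PATH; return None if CUDA is unavailable."""
    nvcc = shutil.which("nvcc")
    if nvcc:
        return nvcc
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
    if cuda_home:
        candidate = os.path.join(cuda_home, "bin", "nvcc")
        if os.path.isfile(candidate):
            return candidate
    return None

def clean_build_dir(path="build"):
    """Remove a stale build tree so CMake reconfigures from scratch."""
    if os.path.isdir(path):
        shutil.rmtree(path)
```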
kt-kernel/python/experts.py
Outdated
```python
from .utils.amx import AMXMoEWrapper
from .utils.llamafile import LlamafileMoEWrapper

from .utils.moe_kernel import Int8_KERNEL_MOE, Int4_KERNEL_MOE
```
This import is incorrect. It should import GeneralMoEWrapper, which is the Python wrapper for the new MoE kernels. The Int8_KERNEL_MOE and Int4_KERNEL_MOE are C++ extension classes and should be used within the wrapper, not directly in the factory.
```diff
-from .utils.moe_kernel import Int8_KERNEL_MOE, Int4_KERNEL_MOE
+from .utils.moe_kernel import GeneralMoEWrapper
```
kt-kernel/python/experts.py
Outdated
```python
elif method == "MOE_INT8":
    backend_cls = Int8_KERNEL_MOE
elif method == "MOE_INT4":
    backend_cls = Int4_KERNEL_MOE
```
The backend_cls for MOE_INT8 and MOE_INT4 methods should be GeneralMoEWrapper. The current implementation incorrectly assigns the C++ extension classes Int8_KERNEL_MOE and Int4_KERNEL_MOE directly. These C++ classes have a different constructor signature (they expect a MOEConfig object) and do not inherit from BaseMoEWrapper, which will cause a runtime error when backend_cls is instantiated.
```diff
-elif method == "MOE_INT8":
-    backend_cls = Int8_KERNEL_MOE
-elif method == "MOE_INT4":
-    backend_cls = Int4_KERNEL_MOE
+elif method in ["MOE_INT8", "MOE_INT4"]:
+    backend_cls = GeneralMoEWrapper
```
```python
from typing import Optional

class GeneralMoEWrapper(BaseMoEWrapper):
```
The class GeneralMoEWrapper inherits from BaseMoEWrapper, which has an abstract method load_weights_from_tensors. This method is not implemented in GeneralMoEWrapper, which will cause a TypeError when an instance of this class is created. You should implement this method. You can likely adapt the implementation from AMXMoEWrapper.
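To see why instantiation fails, assuming BaseMoEWrapper declares the method with abc.abstractmethod (which the comment above implies), here is a minimal standalone reproduction of the TypeError:

```python
from abc import ABC, abstractmethod

class Base(ABC):
    @abstractmethod
    def load_weights_from_tensors(self, tensors):
        ...

class Wrapper(Base):
    pass  # abstract method not overridden

Wrapper()  # raises TypeError: Can't instantiate abstract class Wrapper ...
```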
| "KTRANSFORMERS_CPU_USE_AMX": "OFF", | ||
| "LLAMA_AVX512": "OFF", | ||
| "LLAMA_AVX2": "ON", | ||
| "KTRANSFORMERS_CPU_USE_AMX_AVX512": "ON", |
The KTRANSFORMERS_CPU_USE_AMX_AVX512 flag is set to "ON" in the amd preset. This flag enables AVX-512 compilation flags, which may not be supported on all AMD CPUs targeted by this preset (which enables AVX2). This could lead to "illegal instruction" errors at runtime. This flag should be set to "OFF" to align with the avx preset and avoid requiring AVX-512.
| "KTRANSFORMERS_CPU_USE_AMX_AVX512": "ON", | |
| "KTRANSFORMERS_CPU_USE_AMX_AVX512": "OFF", |
```python
if not _HAS_INT4_SUPPORT and method == "MOE_INT4":
    raise RuntimeError(
        "AMX backend not available. kt_kernel_ext was not compiled with AMX support.\n"
        "Please recompile with AMX enabled."
    )
if not _HAS_INT8_SUPPORT and method == "MOE_INT8":
    raise RuntimeError(
        "AMX backend not available. kt_kernel_ext was not compiled with AMX support.\n"
        "Please recompile with AMX enabled."
    )
```
The error messages are misleading as they refer to the "AMX backend". These should be updated to refer to the generic MoE kernel backend, as this wrapper is not specific to AMX.
```diff
 if not _HAS_INT4_SUPPORT and method == "MOE_INT4":
     raise RuntimeError(
-        "AMX backend not available. kt_kernel_ext was not compiled with AMX support.\n"
-        "Please recompile with AMX enabled."
+        "INT4 MoE kernel backend not available. kt_kernel_ext was not compiled with MoE kernel support.\n"
+        "Please recompile with KTRANSFORMERS_CPU_MOE_KERNEL=ON."
     )
 if not _HAS_INT8_SUPPORT and method == "MOE_INT8":
     raise RuntimeError(
-        "AMX backend not available. kt_kernel_ext was not compiled with AMX support.\n"
-        "Please recompile with AMX enabled."
+        "INT8 MoE kernel backend not available. kt_kernel_ext was not compiled with MoE kernel support.\n"
+        "Please recompile with KTRANSFORMERS_CPU_MOE_KERNEL=ON."
     )
```
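As background, feature flags such as _HAS_INT4_SUPPORT and _HAS_INT8_SUPPORT are typically derived from a guarded import of the compiled extension. The sketch below shows one common pattern; the import path and class locations are hypothetical and may not match the PR's actual code:

```python
# Hypothetical sketch: probe the compiled extension for the MoE kernel classes.
try:
    from kt_kernel_ext import Int4_KERNEL_MOE  # only present when the MoE kernel was compiled in
    _HAS_INT4_SUPPORT = True
except ImportError:
    _HAS_INT4_SUPPORT = False

try:
    from kt_kernel_ext import Int8_KERNEL_MOE
    _HAS_INT8_SUPPORT = True
except ImportError:
    _HAS_INT8_SUPPORT = False
```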
| """ | ||
| AMX-based MoE wrapper implementation. | ||
| Supports AMXINT4 and AMXINT8 quantization methods. | ||
| """ |
The class docstring appears to be copied from the AMX wrapper. It should be updated to reflect that this is a general MoE kernel wrapper, not specific to AMX.
| """ | |
| AMX-based MoE wrapper implementation. | |
| Supports AMXINT4 and AMXINT8 quantization methods. | |
| """ | |
| """ | |
| General MoE kernel wrapper implementation. | |
| Supports INT4 and INT8 quantization methods. | |
| """ |
| """ | ||
| Initialize AMX MoE Wrapper. | ||
| Args: | ||
| layer_idx: Layer index | ||
| num_experts: Total number of experts | ||
| num_experts_per_tok: Number of experts per token (top-k) | ||
| hidden_size: Hidden dimension size | ||
| moe_intermediate_size: MoE intermediate size | ||
| num_gpu_experts: Number of experts to run on GPU | ||
| cpuinfer_threads: Number of CPU inference threads | ||
| threadpool_count: Number of NUMA subpools | ||
| weight_path: Path to AMX weights (SafeTensor format) | ||
| chunked_prefill_size: Maximum prefill chunk size | ||
| cpu_save: Whether to save weights to CPU memory | ||
| max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0. | ||
| method: general quantization method ("MOE_INT4" or "MOE_INT8") | ||
| """ |
The docstring for __init__ seems to be copied from the AMX wrapper. It should be updated to remove AMX-specific references and describe the general MoE kernel wrapper.
| """ | |
| Initialize AMX MoE Wrapper. | |
| Args: | |
| layer_idx: Layer index | |
| num_experts: Total number of experts | |
| num_experts_per_tok: Number of experts per token (top-k) | |
| hidden_size: Hidden dimension size | |
| moe_intermediate_size: MoE intermediate size | |
| num_gpu_experts: Number of experts to run on GPU | |
| cpuinfer_threads: Number of CPU inference threads | |
| threadpool_count: Number of NUMA subpools | |
| weight_path: Path to AMX weights (SafeTensor format) | |
| chunked_prefill_size: Maximum prefill chunk size | |
| cpu_save: Whether to save weights to CPU memory | |
| max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0. | |
| method: general quantization method ("MOE_INT4" or "MOE_INT8") | |
| """ | |
| """ | |
| Initialize General MoE Wrapper. | |
| Args: | |
| layer_idx: Layer index | |
| num_experts: Total number of experts | |
| num_experts_per_tok: Number of experts per token (top-k) | |
| hidden_size: Hidden dimension size | |
| moe_intermediate_size: MoE intermediate size | |
| num_gpu_experts: Number of experts to run on GPU | |
| cpuinfer_threads: Number of CPU inference threads | |
| threadpool_count: Number of NUMA subpools | |
| weight_path: Path to weights (SafeTensor format) | |
| chunked_prefill_size: Maximum prefill chunk size | |
| cpu_save: Whether to save weights to CPU memory | |
| max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0. | |
| method: general quantization method ("MOE_INT4" or "MOE_INT8") | |
| """ |
```python
# AMX-specific: Check if we should load merged safetensor weights
self.load_merged_weight = False
import glob
```
```python
gate_ptrs = [
    [
        ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
        for et in numa_array
    ]
    for numa_array in self.gate_weights
]

up_ptrs = [
    [
        ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
        for et in numa_array
    ]
    for numa_array in self.up_weights
]

down_ptrs = [
    [
        ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
        for et in numa_array
    ]
    for numa_array in self.down_weights
]

gate_scale_ptrs = [
    [
        ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
        for et in numa_array
    ]
    for numa_array in self.gate_scales
]

up_scale_ptrs = [
    [
        ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
        for et in numa_array
    ]
    for numa_array in self.up_scales
]

down_scale_ptrs = [
    [
        ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
        for et in numa_array
    ]
    for numa_array in self.down_scales
]
```
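As a side note (not a suggestion from the review itself), the six comprehensions above repeat the same pointer-extraction pattern. Assuming the nested entries are NumPy arrays, as the excerpt suggests, a small helper could collapse the repetition; for a NumPy array the cast/addressof round trip should reduce to the raw arr.ctypes.data address:

```python
import ctypes

def _collect_data_ptrs(nested_arrays):
    """Collect raw data addresses for a [numa][expert] nested list of NumPy arrays."""
    return [
        [
            ctypes.addressof(ctypes.cast(arr.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
            for arr in numa_array
        ]
        for numa_array in nested_arrays
    ]

# Usage inside the method shown above (hypothetical placement):
# gate_ptrs = _collect_data_ptrs(self.gate_weights)
# up_ptrs = _collect_data_ptrs(self.up_weights)
# down_ptrs = _collect_data_ptrs(self.down_weights)
# gate_scale_ptrs = _collect_data_ptrs(self.gate_scales)
# up_scale_ptrs = _collect_data_ptrs(self.up_scales)
# down_scale_ptrs = _collect_data_ptrs(self.down_scales)
```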
#1582 #1601