2 changes: 1 addition & 1 deletion .github/workflows/mypy.yml
@@ -27,7 +27,7 @@ jobs:
- name: Install dependencies
run: |
curl -sSL https://install.python-poetry.org | python3
poetry install --all-extras
poetry install --all-extras -vvv
- name: Type-checking package with mypy
run: |
# Run this mypy instance against our main package.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
@@ -27,7 +27,7 @@ jobs:
- name: Install dependencies
run: |
curl -sSL https://install.python-poetry.org | python3
poetry install --all-extras
poetry install --all-extras -vvv
- name: Test with pytest
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
189 changes: 189 additions & 0 deletions .gitignore
@@ -0,0 +1,189 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
docs/source/getting_started/examples/*.rst
!**/*.template.rst

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# VSCode
.vscode/

# DS Store
.DS_Store

# Results
*.csv

# Python pickle files
*.pkl

# Sphinx documentation
_build/

# vim swap files
*.swo
*.swp

# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h

# Benchmark dataset
*.json
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "3rdparty/flashinfer"]
path = 3rdparty/flashinfer
url = https://github.com/flashinfer-ai/flashinfer.git
1 change: 1 addition & 0 deletions 3rdparty/flashinfer
Submodule flashinfer added at 58d359
52 changes: 52 additions & 0 deletions CMakeLists.txt
@@ -0,0 +1,52 @@
cmake_minimum_required(VERSION 3.23.1)
project(deft
VERSION 2024
DESCRIPTION "An IO-aware fast attention kernel for efficient tree-structured interactions with LLMs"
LANGUAGES CUDA CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(FLASHINFER_DIR ${PROJECT_SOURCE_DIR}/3rdparty/flashinfer)
set(FLASHINFER_INCLUDE_DIR ${FLASHINFER_DIR}/include)

add_subdirectory(${FLASHINFER_DIR})

set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g -lineinfo -Xptxas -v")
find_package(Python3 REQUIRED)
find_package(CUDAToolkit REQUIRED)
if(NOT Python3_FOUND)
message(FATAL_ERROR "Python3 not found.")
endif()
if(NOT CUDAToolkit_FOUND)
message(FATAL_ERROR "CUDA not found.")
endif()
message(STATUS "Python3 found at ${Python3_EXECUTABLE}")
message(STATUS "CUDA version is ${CUDAToolkit_VERSION}")
set(DEFT_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/csrc)
include_directories(${DEFT_SOURCE_DIR})
include_directories(${CUDAToolkit_INCLUDE_DIRS})
include_directories(${FLASHINFER_INCLUDE_DIR})

link_directories(${CUDAToolkit_LIBRARY_DIR})

set(DEFT_ENABLE_TESTS ON CACHE BOOL "Enable tests for DEFT")

if(DEFT_ENABLE_TESTS)
enable_testing()
set(DEFT_TEST_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tests/csrc)
file(GLOB_RECURSE TEST_DUMMY_SRC ${DEFT_TEST_SOURCE_DIR}/test_dummy.cu)
add_executable(test_dummy ${TEST_DUMMY_SRC})
target_link_libraries(test_dummy gtest_main gtest)
target_link_libraries(test_dummy ${CUDAToolkit_LIBRARIES})
add_test(NAME test_dummy COMMAND test_dummy)

file(GLOB_RECURSE TEST_DEFT_SRC ${DEFT_TEST_SOURCE_DIR}/test_deft_attention.cu)
add_executable(test_deft_attention ${TEST_DEFT_SRC})
target_link_libraries(test_deft_attention gtest_main gtest)
target_link_libraries(test_deft_attention ${CUDAToolkit_LIBRARIES})
target_include_directories(test_deft_attention PRIVATE ${FLASHINFER_DIR}/src)
add_dependencies(test_deft_attention dispatch_inc)
target_link_libraries(test_deft_attention decode_kernels)
add_test(NAME test_deft_attention COMMAND test_deft_attention)
endif()
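
As a usage note, `DEFT_ENABLE_TESTS` is a cache option, so it can be overridden at configure time, e.g. `cmake -B build -DDEFT_ENABLE_TESTS=OFF` to skip building the test binaries.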
19 changes: 16 additions & 3 deletions README.md
@@ -20,7 +20,7 @@ We propose DeFT, an IO-aware attention algorithm for efficient tree-structured i

- [2024/05] We update the second version of DeFT paper with a better algorithm for general tree-structured LLM inference: [DeFT: Decoding with Flash Tree-attention for Efficient Tree-structured LLM Inference](https://arxiv.org/abs/2404.00242)!
- [2024/03] [DeFT: Flash Tree-Attention With IO-Awareness for Efficient Tree-Search-Based LLM Inference](https://openreview.net/pdf?id=HqfLHoX8bR) has been accepted as Oral presentation in [ICLR'24 AGI Workshop](https://iclr.cc/virtual/2024/23126)!

****

## Abstract
Given the increasing demand for tree-structured interactions with LLMs, we introduce DeFT (Decoding with Flash Tree-Attention), an IO-aware tree attention algorithm tailored for tree-structured inference. Unlike traditional sequence-based decoding, tree-structured decoding better accommodates modern task requirements, including self-consistency, few-shot prompting, multi-step reasoning, and multi-model/head coordination. However, existing sequence-based inference systems are ill-suited for tree-structured decoding, resulting in redundancy in computation, memory footprints, and memory access, thereby undermining inference efficiency. To address this challenge, DeFT maintains memory-efficient attention calculation with low memory footprints through two key stages: (1) QKV Preparation: we propose a KV-Guided Grouping Strategy with Tree Split to intelligently group QKV, optimizing GPU resource utilization while minimizing memory reads/writes for the KV cache between GPU global memory and on-chip shared memory; (2) Attention Calculation: we compute the partial attention of each QKV group in a fused kernel and employ a Tree-topology-aware Global Reduction strategy to obtain the final attention. By reducing KV cache IO by 73-99% and nearly 100% of the IO for partial results during attention calculation (e.g., Softmax), DeFT achieves up to 2.52x/3.82x speedup in end-to-end/attention latency across three practical tree-based workloads, namely few-shot prompting, multi-step reasoning, and speculative decoding, over state-of-the-art attention algorithms.
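
For intuition, the Tree-topology-aware Global Reduction can be viewed as the standard log-sum-exp merge of partial softmax results. The sketch below is a minimal NumPy illustration of that merge, not the DeFT kernel API; all names and shapes are our assumptions.

```python
import numpy as np

def merge_partial_attention(outs, lses):
    """Merge per-group partial attention outputs into the exact global result.

    outs: list of (head_dim,) outputs, each softmax-normalized over its own keys
    lses: matching list of scalar log-sum-exp values of the raw attention scores
    """
    lses = np.asarray(lses, dtype=np.float64)
    w = np.exp(lses - lses.max())  # stabilized weight of each KV group
    w /= w.sum()                   # each group's share of the global softmax
    return sum(wi * oi for wi, oi in zip(w, outs))

# Sanity check: splitting the keys into two groups reproduces full attention.
rng = np.random.default_rng(0)
q, K, V = rng.normal(size=8), rng.normal(size=(16, 8)), rng.normal(size=(16, 8))
s = K @ q
full = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ V

def partial(idx):
    p = np.exp(s[idx] - s[idx].max())
    return (p / p.sum()) @ V[idx], s[idx].max() + np.log(p.sum())

(o1, l1), (o2, l2) = partial(slice(0, 8)), partial(slice(8, 16))
assert np.allclose(full, merge_partial_attention([o1, o2], [l1, l2]))
```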
@@ -42,10 +42,23 @@ poetry install
CUDA_VISIBLE_DEVICES=0 python examples/
```


### Run Tests

<!-- We profile DeFT kernel performance with [nvbench](https://github.com/NVIDIA/nvbench) and you can compile and run the benchmarks with the following commands: -->

```bash
git submodule update --init --recursive # clone flashinfer
cmake -B build
cmake --build build
cd build
ctest
```

## FAQ

1. **What is the difference between two versions of DeFT papers in arXiv?**

DeFT-v1


@@ -72,4 +85,4 @@ If you find DeFT useful or relevant to your project and research, please kindly
journal={arXiv preprint arXiv:2404.00242},
year={2024}
}
```
```
29 changes: 29 additions & 0 deletions build.py
@@ -0,0 +1,29 @@
import logging
from typing import Any, Dict

from torch.utils.cpp_extension import CUDAExtension

logger = logging.getLogger(__name__)


ext_modules = []
ext_modules.append(
CUDAExtension(
name='deft._kernels',
sources=['csrc/deft_api.cpp', 'csrc/deft/ops.cu'],
include_dirs=[
'csrc',
],
)
)


def build(setup_kwargs: Dict[str, Any]) -> None:
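    # NOTE: the native extension build is currently disabled; re-enabling the
    # commented entries below also requires importing BuildExtension.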
setup_kwargs.update(
{
# 'ext_modules': ext_modules,
# 'cmdclass': {
# 'build_ext': BuildExtension.with_options(use_ninja=False),
# },
}
)
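
For context: Poetry runs a script like this through a `[tool.poetry.build]` table in pyproject.toml (typically `script = "build.py"`). That file is not part of this diff, so the exact wiring is an assumption.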
41 changes: 41 additions & 0 deletions config.cmake
@@ -0,0 +1,41 @@
# Whether to compile fp8 kernels or not.
set(FLASHINFER_ENABLE_FP8 OFF)
# Whether to compile bf16 kernels or not.
set(FLASHINFER_ENABLE_BF16 OFF)
# Whether to compile tvm bindings or not.
set(FLASHINFER_TVM_BINDING OFF)
# Whether to compile prefill kernel tests/benchmarks or not.
set(FLASHINFER_PREFILL ON)
# Whether to compile decode kernel tests/benchmarks or not.
set(FLASHINFER_DECODE ON)
# Whether to compile page kernel tests/benchmarks or not.
set(FLASHINFER_PAGE ON)
# Whether to compile cascade kernel tests/benchmarks or not.
set(FLASHINFER_CASCADE ON)
# Whether to compile sampling kernel tests/benchmarks or not.
set(FLASHINFER_SAMPLING OFF)
# Whether to compile normalization kernel tests/benchmarks or not.
set(FLASHINFER_NORMALIZATION OFF)
# Whether to compile fastdiv tests
set(FLASHINFER_FASTDIV_TEST ON)
# Whether to compile distributed tests
set(FLASHINFER_DISTRIBUTED OFF)
# The following configurations can impact the binary
# size of the generated library
set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0)
set(FLASHINFER_GEN_HEAD_DIMS 128)
set(FLASHINFER_GEN_KV_LAYOUTS 0)
set(FLASHINFER_GEN_POS_ENCODING_MODES 0)
set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false")
set(FLASHINFER_GEN_MASK_MODES 0)

# Set target cuda architectures for tests/benchmarks, defaults to native.
# "native" is a special value for CMAKE_CUDA_ARCHITECTURES which means use the architectures of the host's GPU.
# It's new in CMake 3.24; if you are using an older version of CMake or want to use a different value, you can
# set its value here. Supported CUDA architectures include 80;86;89;90.
# NOTE(Zihao): using "native" might be slow, because whenever a CUDA file is compiled with `-arch=native`, nvcc spawns
# a `__nvcc_device_query` process to query the architecture of the host's GPU, which can stall the compilation process.
# So it's recommended to set it to a specific value if you know the architecture of the target GPU.
# Example:
# set(FLASHINFER_CUDA_ARCHITECTURES 80)
set(FLASHINFER_CUDA_ARCHITECTURES 80)