Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def setup_hip(args: argparse.Namespace):
parser.add_argument("--xformers", action="store_true", help="Install xformers")
parser.add_argument("--tile", action="store_true", help="install tile lang")
parser.add_argument("--aiter", action="store_true", help="install AMD's aiter")
parser.add_argument("--cutlass", action="store_true", help="install cutlass")
parser.add_argument(
"--tritonparse", action="store_true", help="Install tritonparse"
)
Expand Down Expand Up @@ -183,6 +184,11 @@ def setup_hip(args: argparse.Namespace):
if args.fa2 or args.all:
logger.info("[tritonbench] installing fa2 from source...")
install_fa2(compile=True)
if args.cutlass or args.all:
logger.info("[tritonbench] installing cutlass...")
from tools.cutlass.install import install_cutlass

install_cutlass()
if args.jax or args.all:
logger.info("[tritonbench] installing jax...")
install_jax()
Expand Down
2 changes: 1 addition & 1 deletion submodules/cutlass
Submodule cutlass updated 958 files
21 changes: 21 additions & 0 deletions tools/cutlass/install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os
import subprocess
import sys

from pathlib import Path

REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent
CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass")


def test_cutlass():
cmd = [
sys.executable,
"-c",
"import cutlass_cppgen",
]
subprocess.check_call(cmd)

def install_cutlass():
command = ["pip", "install", "-e", "."]
subprocess.check_call(command, cwd=CUTLASS_PATH)
Loading