diff --git a/install.py b/install.py index 7eb6a8703..e8cebea64 100644 --- a/install.py +++ b/install.py @@ -147,6 +147,7 @@ def setup_hip(args: argparse.Namespace): parser.add_argument("--xformers", action="store_true", help="Install xformers") parser.add_argument("--tile", action="store_true", help="install tile lang") parser.add_argument("--aiter", action="store_true", help="install AMD's aiter") + parser.add_argument("--cutlass", action="store_true", help="install cutlass") parser.add_argument( "--tritonparse", action="store_true", help="Install tritonparse" ) @@ -183,6 +184,11 @@ def setup_hip(args: argparse.Namespace): if args.fa2 or args.all: logger.info("[tritonbench] installing fa2 from source...") install_fa2(compile=True) + if args.cutlass or args.all: + logger.info("[tritonbench] installing cutlass...") + from tools.cutlass.install import install_cutlass + + install_cutlass() if args.jax or args.all: logger.info("[tritonbench] installing jax...") install_jax() diff --git a/submodules/cutlass b/submodules/cutlass index ad7b2f5e8..f3fde5837 160000 --- a/submodules/cutlass +++ b/submodules/cutlass @@ -1 +1 @@ -Subproject commit ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e +Subproject commit f3fde58372d33e9a5650ba7b80fc48b3b49d40c8 diff --git a/tools/cutlass/install.py b/tools/cutlass/install.py new file mode 100644 index 000000000..d7d337942 --- /dev/null +++ b/tools/cutlass/install.py @@ -0,0 +1,21 @@ +import os +import subprocess +import sys + +from pathlib import Path + +REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent +CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass") + + +def test_cutlass(): + cmd = [ + sys.executable, + "-c", + "import cutlass_cppgen", + ] + subprocess.check_call(cmd) + +def install_cutlass(): + command = ["pip", "install", "-e", "."] + subprocess.check_call(command, cwd=CUTLASS_PATH)