From b558b12bd8edf57df0d40f543d48da5052c433d6 Mon Sep 17 00:00:00 2001 From: Federico Peccia Date: Thu, 26 Jun 2025 14:48:40 +0200 Subject: [PATCH 1/9] Added RISC-V V extension intrinsics for LLVM --- include/tvm/meta_schedule/postproc.h | 2 + include/tvm/meta_schedule/schedule_rule.h | 2 + python/tvm/meta_schedule/tune_context.py | 8 + python/tvm/target/target.py | 12 + python/tvm/tir/tensor_intrin/__init__.py | 2 +- python/tvm/tir/tensor_intrin/riscv_cpu.py | 740 ++++++++++++++++++ src/meta_schedule/postproc/postproc.cc | 8 + .../schedule_rule/schedule_rule.cc | 116 +++ .../space_generator/space_generator.cc | 35 + src/target/parsers/aprofile.cc | 15 + src/target/parsers/cpu.cc | 23 + src/target/parsers/cpu.h | 3 + src/target/source/codegen_c.cc | 2 + 13 files changed, 967 insertions(+), 1 deletion(-) create mode 100644 python/tvm/tir/tensor_intrin/riscv_cpu.py diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index c511271d20a9..6ed7272fe9b4 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -166,6 +166,8 @@ class Postproc : public runtime::ObjectRef { TVM_DLL static Array DefaultLLVM(); /*! \brief Create default postprocessors for x86 (AVX512 and VNNI) */ TVM_DLL static Array DefaultCPUTensorization(); + /*! \brief Create default postprocessors for RISCV */ + TVM_DLL static Array DefaultRISCV(); /*! \brief Create default postprocessors for CUDA */ TVM_DLL static Array DefaultCUDA(); /*! \brief Create default postprocessors for CUDA with TensorCore */ diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 9011ebe0c12f..407914e3d074 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -301,6 +301,8 @@ class ScheduleRule : public runtime::ObjectRef { TVM_DLL static Array DefaultHexagon(); /*! 
\brief Create default schedule rules for ARM CPU (NEON and DOTPROD) */ TVM_DLL static Array DefaultARM(const String& type); + /*! \brief Create default schedule rules for RISCV CPU (RVV) */ + TVM_DLL static Array DefaultRISCV(int vlen); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode); }; diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 5512b7a2682b..488cf2712d29 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -117,6 +117,14 @@ def __init__( if target is not None: if not isinstance(target, Target): target = Target(target) + if "riscv_cpu" in target.keys: + base_features = str(target.attrs["march"]).split("_")[0].replace("rv", "") + if "v" in base_features: + # Because the RVV intrinsics depend on the target, we register them here + # pylint: disable=import-outside-toplevel + from tvm.tir.tensor_intrin.riscv_cpu import register_riscv_tensor_intrinsics + + register_riscv_tensor_intrinsics(target) if space_generator is not None: if not isinstance(space_generator, SpaceGenerator): space_generator = SpaceGenerator.create(space_generator) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 6c83ef6e5bb2..9c72b2fdab27 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -637,6 +637,18 @@ def riscv_cpu(model="sifive-u54", options=None): "-mabi=lp64d", # cc: riscv64-unknown-linux-gnu-g++ -march=rv64gc -mabi=lp64d -mcpu=sifive-u74 ], + "bpi-f3": [ + # "-model=sifive-u74", + "-mtriple=riscv64-unknown-linux-gnu", + "-mcpu=generic", + # "-march=rv64gcv_zvl256b", + # "-mcpu=generic-rv64", + "-mfloat-abi=hard", + "-num-cores=8", + "-mabi=lp64d", + "-mattr=+v,+zvl256b", + # cc: riscv64-unknown-linux-gnu-g++ -march=rv64gc -mabi=lp64d -mcpu=generic -mattr=+v + ], } pre_defined_opt = trans_table.get(model, ["-model=%s" % model]) diff --git a/python/tvm/tir/tensor_intrin/__init__.py 
b/python/tvm/tir/tensor_intrin/__init__.py index 564655455245..0a6cf5310c9c 100644 --- a/python/tvm/tir/tensor_intrin/__init__.py +++ b/python/tvm/tir/tensor_intrin/__init__.py @@ -20,4 +20,4 @@ from . import cuda if enabled("llvm"): - from . import arm_cpu, x86, rocm, hexagon + from . import arm_cpu, x86, rocm, hexagon, riscv_cpu diff --git a/python/tvm/tir/tensor_intrin/riscv_cpu.py b/python/tvm/tir/tensor_intrin/riscv_cpu.py new file mode 100644 index 000000000000..138fa66776af --- /dev/null +++ b/python/tvm/tir/tensor_intrin/riscv_cpu.py @@ -0,0 +1,740 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,missing-function-docstring,unused-import +"""Intrinsics for RVV tensorization, both for C and LLVM targets. +===================== +**Author**: `Federico Peccia `_ +""" +import re +from tvm.script import tir as T +from tvm.target.datatype import lower_call_pure_extern, register, register_op +from .. import TensorIntrin + +##################################################### +# LLVM RISC-V Intrinsic usage: +# https://llvm.org/docs//RISCV/RISCVVectorExtension.html +# +# Vector types are represented using scalable vector +# types, of the form <vscale x n x ty>. 
n and ty +# control LMUL and SEW respectively (see table in docs). +# TVM represents this with dtype = "tyxvscalexn". +# +# n is calculated as (64/SEW)*LMUL. +# VL is passed to each intrinsic. +# +# Some examples (see table in docs): +# int8 vector type with LMUL = 1 => int8xvscalex8 +# int16 vector type with LMUL = 4 => int16xvscalex16 +# int32 vector type with LMUL = 2 => int32xvscalex4 +# +##################################################### + +##################################################### +# Helper functions +##################################################### + +RISCV_MIN_VL = 4 + + +def get_vlmax(vlen: int, lmul: int, max_sew: int) -> int: + """Return VLMAX + + Args: + vlen (int): Actual VLEN + lmul (int): LMUL + max_sew (int): SEW + + Returns: + int: VLMAX + """ + return (lmul * vlen) // max_sew + + +def get_vlen_from_mattrs(mattrs: list) -> int: + """Extract VLEN from LLVM mattrs list + + Args: + mattrs (list): LLVM list of CPU mattrs + + Returns: + int: VLEN + """ + vlen_regex = r"zvl(\d+)b" + vlen = 0 + for mattr in mattrs: + match = re.search(vlen_regex, mattr) + + if match: + vlen = int(match.group(1)) + break + return vlen + + +def _dtype_to_bits(dtype: str) -> int: + """Get bits from data type + + Args: + dtype (str): Data type + + Returns: + int: bits + """ + bits_per_item = int( + re.match(r"((float)|(int)|(uint))(?P<width_bits>[0-9]+)", dtype).group("width_bits") + ) + assert bits_per_item is not None, f"don't know how to compute size of type {dtype}" + return bits_per_item + + +def _get_dtype_string(dtype: str) -> str: + """Get only type of data type, without bits + + Args: + dtype (str): Data type + + Returns: + str: only string type + """ + return str(re.match(r"[a-z]+", dtype).group(0)) + + +##################################################### +# Parameterized intrinsics +##################################################### + + +def rvv_vmacc(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: int): + # pylint: 
disable=unused-argument + input_bits = _dtype_to_bits(input_dtype) + output_bits = _dtype_to_bits(output_dtype) + + output_str_type = _get_dtype_string(output_dtype) + + output_dtype_prefix = output_str_type[0] + + input_lmul = lmul if output_dtype_prefix == "f" else lmul // 2 + + load_llvm_intrinsic = "llvm.riscv.vle" + expand_llvm_intrinsic = "llvm.riscv.vsext" + init_llvm_intrinsic = "llvm.riscv.vle" + macc_llvm_intrinsic = "llvm.riscv.vmacc" if output_dtype_prefix != "f" else "llvm.riscv.vfmacc" + store_llvm_intrinsic = "llvm.riscv.vse" + + # Calculated from https://llvm.org/docs//RISCV/RISCVVectorExtension.html + n_input_dtype = (64 // input_bits) * input_lmul + n_output_dtype = (64 // output_bits) * lmul + + llvm_input_dtype = f"{input_dtype}xvscalex{n_input_dtype}" + llvm_macc_dtype = f"{output_str_type}{output_bits}xvscalex{n_output_dtype}" + + broadcast_input = T.int16(0) if input_dtype == "int16" else T.float32(0) + broadcast_output = T.int32(0) if output_dtype == "int32" else T.float32(0) + + @T.prim_func + def rvv_vmacc_desc( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + C: T.Buffer((int(vlmax),), output_dtype, align=4, offset_factor=1), + ) -> None: + with T.block("root"): + T.reads(C[0 : int(vlmax)], A[0 : int(vlmax)], B[0 : int(vlmax)]) + T.writes(C[0 : int(vlmax)]) + for j in range(0, int(vlmax)): + with T.block("update"): + vj = T.axis.remap("S", [j]) + C[vj] = C[vj] + T.cast(A[vj], output_dtype) * T.cast(B[vj], output_dtype) + + @T.prim_func + def rvv_vmacc_llvm_impl( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + C: T.Buffer((int(vlmax),), output_dtype, align=4, offset_factor=1), + ) -> None: + + with T.block("root"): + + T.reads(A[0 : int(vlmax)], B[0 : int(vlmax)]) + T.writes(C[0 : int(vlmax)]) + + vec_A = ( + T.call_llvm_intrin( + llvm_macc_dtype, + 
expand_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_output, n_output_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + vec_B = ( + T.call_llvm_intrin( + llvm_macc_dtype, + expand_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_output, n_output_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + init = T.call_llvm_intrin( + llvm_macc_dtype, + init_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_output, n_output_dtype * T.vscale()), + C.access_ptr(access_mask=C.READ, ptr_type="handle"), + T.uint64(vlmax), + ) + + product = ( + T.call_llvm_intrin( + llvm_macc_dtype, + macc_llvm_intrinsic, + T.uint32(6), + init, + vec_A, + vec_B, + T.uint64(7), + T.uint64(vlmax), + T.uint64(3), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_macc_dtype, + macc_llvm_intrinsic, + T.uint32(5), + init, + vec_A, + vec_B, + T.uint64(vlmax), + T.uint64(3), + ) + ) + + T.call_llvm_intrin( + "", + store_llvm_intrinsic, + T.uint32(3), + product, + C.access_ptr(access_mask=C.WRITE, ptr_type="handle"), + T.uint64(vlmax), + ) + 
+ return rvv_vmacc_desc, rvv_vmacc_llvm_impl + + +def rvv_multivmul(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: int): + # pylint: disable=unused-argument + assert J > 1 + + input_bits = _dtype_to_bits(input_dtype) + kernel_bits = _dtype_to_bits(input_dtype) + output_bits = _dtype_to_bits(output_dtype) + + output_str_type = _get_dtype_string(output_dtype) + + output_dtype_prefix = ( + "i" if output_str_type == "int" else ("u" if output_str_type == "uint" else "f") + ) + + intermmediate_bits = output_bits if output_dtype_prefix == "f" else input_bits + kernel_bits + intermmediate_bits = input_bits + + load_llvm_intrinsic = "llvm.riscv.vle" + expand_llvm_intrinsic = "llvm.riscv.vsext" + init_llvm_intrinsic = ( + "llvm.riscv.vmv.v.x" if output_dtype_prefix != "f" else "llvm.riscv.vfmv.v.f" + ) + mult_llvm_intrinsic = "llvm.riscv.vmul" if output_dtype_prefix != "f" else "llvm.riscv.vfmul" + redsum_llvm_intrinsic = ( + "llvm.riscv.vwredsum" if output_dtype_prefix != "f" else "llvm.riscv.vfredusum" + ) + store_llvm_intrinsic = "llvm.riscv.vse" + + # Calculated from https://llvm.org/docs//RISCV/RISCVVectorExtension.html + # vscale = vlen // 64 + n_input_dtype = (64 // input_bits) * lmul + n_kernel_dtype = (64 // kernel_bits) * lmul + n_intermmediate_dtype = (64 // intermmediate_bits) * lmul + + n_redsum_dtype = (64 // output_bits) * 1 + + llvm_input_dtype = f"{input_dtype}xvscalex{n_input_dtype}" + llvm_kernel_dtype = f"{input_dtype}xvscalex{n_kernel_dtype}" + llvm_redsum_dtype = f"{output_dtype}xvscalex{n_redsum_dtype}" + llvm_mult_dtype = f"{output_str_type}{intermmediate_bits}xvscalex{n_intermmediate_dtype}" + + broadcast_input = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_kernel = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_intermmediate = T.int16(0) if intermmediate_bits == 16 else T.int32(0) + 
broadcast_output = T.int32(0) if output_dtype == "int32" else T.float32(0) + + @T.prim_func + def rvv_multivmul_desc( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((J, int(vlmax)), kernel_dtype, align=4, offset_factor=1), + C: T.Buffer((J,), output_dtype, align=4, offset_factor=1), + ) -> None: + with T.block("root"): + T.reads(C[0:J], A[0 : int(vlmax)], B[0:J, 0 : int(vlmax)]) + T.writes(C[0:J]) + for j in range(0, J): + for k in range(0, int(vlmax)): + with T.block("update"): + vj, vk = T.axis.remap("SR", [j, k]) + C[vj] = C[vj] + T.cast(A[vk], output_dtype) * T.cast( + B[vj, vk], output_dtype + ) + + @T.prim_func + def rvv_multivmul_llvm_impl( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer( + (J, int(vlmax)), kernel_dtype, align=4, offset_factor=1, strides=[T.int32(), T.int32()] + ), + C: T.Buffer((J,), output_dtype, align=4, offset_factor=1), + ) -> None: + + with T.block("root"): + + T.reads(A[0 : int(vlmax)], B[0:J, 0 : int(vlmax)]) + T.writes(C[0:J]) + + vec_A = ( + T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + vec_B = ( + T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_input, 
n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_kernel_dtype, + load_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_kernel, n_kernel_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + redsum = T.call_llvm_intrin( + llvm_redsum_dtype, + init_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + C[0], + T.uint64(1), + ) + + product = ( + T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.uint32(5), + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + vec_B, + T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.uint32(4), + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + vec_B, + T.uint64(vlmax), + ) + ) + + redsum_result = ( + T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.uint32(5), + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + product, + redsum, + T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.uint32(4), + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + product, + redsum, + T.uint64(vlmax), + ) + ) + + T.call_llvm_intrin( + "", + store_llvm_intrinsic, + T.uint32(3), + redsum_result, + C.access_ptr(access_mask=C.WRITE, ptr_type="handle"), + T.uint64(1), + ) + + return rvv_multivmul_desc, rvv_multivmul_llvm_impl + + +def rvv_vmul(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: int): + # pylint: disable=unused-argument + input_bits = _dtype_to_bits(input_dtype) + output_bits = _dtype_to_bits(output_dtype) + + output_str_type = _get_dtype_string(output_dtype) + + output_dtype_prefix = ( + "i" if 
output_str_type == "int" else ("u" if output_str_type == "uint" else "f") + ) + + intermmediate_bits = output_bits if output_dtype_prefix == "f" else input_bits * 2 + intermmediate_bits = input_bits + + load_llvm_intrinsic = "llvm.riscv.vle" + expand_llvm_intrinsic = "llvm.riscv.vsext" + init_llvm_intrinsic = ( + "llvm.riscv.vmv.v.x" if output_dtype_prefix != "f" else "llvm.riscv.vfmv.v.f" + ) + mult_llvm_intrinsic = "llvm.riscv.vmul" if output_dtype_prefix != "f" else "llvm.riscv.vfmul" + redsum_llvm_intrinsic = ( + "llvm.riscv.vwredsum" if output_dtype_prefix != "f" else "llvm.riscv.vfredusum" + ) + store_llvm_intrinsic = "llvm.riscv.vse" + + # Calculated from https://llvm.org/docs//RISCV/RISCVVectorExtension.html + # vscale = vlen // 64 + n_input_dtype = (64 // input_bits) * lmul + n_kernel_dtype = (64 // input_bits) * lmul + n_intermmediate_dtype = (64 // intermmediate_bits) * lmul + + n_redsum_dtype = (64 // output_bits) * 1 + + llvm_input_dtype = f"{input_dtype}xvscalex{n_input_dtype}" + llvm_kernel_dtype = f"{input_dtype}xvscalex{n_kernel_dtype}" + llvm_redsum_dtype = f"{output_dtype}xvscalex{n_redsum_dtype}" + llvm_mult_dtype = f"{output_str_type}{intermmediate_bits}xvscalex{n_intermmediate_dtype}" + + broadcast_input = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_kernel = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_intermmediate = T.int16(0) if intermmediate_bits == 16 else T.int32(0) + broadcast_output = T.int32(0) if output_dtype == "int32" else T.float32(0) + + @T.prim_func + def rvv_vmul_desc( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), kernel_dtype, align=4, offset_factor=1), + C: T.Buffer((1,), output_dtype, align=4, offset_factor=1), + ) -> None: + with T.block("root"): + T.reads(C[0], A[0 : int(vlmax)], B[0 : int(vlmax)]) + T.writes(C[0]) + for 
k in range(0, int(vlmax)): + with T.block("update"): + vk = T.axis.remap("R", [k]) + C[0] = C[0] + T.cast(A[vk], output_dtype) * T.cast(B[vk], output_dtype) + + @T.prim_func + def rvv_vmul_llvm_impl( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), kernel_dtype, align=4, offset_factor=1), + C: T.Buffer((1,), output_dtype, align=4, offset_factor=1), + ) -> None: + + with T.block("root"): + + T.reads(A[0 : int(vlmax)], B[0 : int(vlmax)]) + T.writes(C[0]) + + vec_A = ( + T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + vec_B = ( + T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_kernel_dtype, + load_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_kernel, n_kernel_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + redsum = T.call_llvm_intrin( + llvm_redsum_dtype, + init_llvm_intrinsic, + T.uint32(3), + T.broadcast(broadcast_output, n_redsum_dtype * 
T.vscale()), + C[0], + T.uint64(1), + ) + + product = ( + T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.uint32(5), + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + vec_B, + T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.uint32(4), + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + vec_B, + T.uint64(vlmax), + ) + ) + + redsum_result = ( + T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.uint32(5), + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + product, + redsum, + T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.uint32(4), + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + product, + redsum, + T.uint64(vlmax), + ) + ) + + T.call_llvm_intrin( + "", + store_llvm_intrinsic, + T.uint32(3), + redsum_result, + C.access_ptr(access_mask=C.WRITE, ptr_type="handle"), + T.uint64(1), + ) + + return rvv_vmul_desc, rvv_vmul_llvm_impl + + +##################################################### +# Registering intrinsics +##################################################### + + +def register_intrinsic_combinations( + outer_loops, initial_vlmax, lmul, input_dtype, output_dtype, prefix, generator +): + for J in outer_loops: + current_vlmax = initial_vlmax + while current_vlmax >= RISCV_MIN_VL: + + name = f"{prefix}_{J}_{current_vlmax}_m{lmul}" + + desc, impl = generator(J, current_vlmax, input_dtype, output_dtype, lmul) + + print(f"Registering intrin {name}...") + + TensorIntrin.register(name, desc, impl, override=True) + + current_vlmax = current_vlmax // 2 + + +def register_riscv_tensor_intrinsics(target): + target_kind = target.kind.name + assert target_kind in ["llvm"] + + ##################################################### + # Register custom RVV types for C code 
generation + ##################################################### + dtype_counter = 0 + for bits in [8, 16, 32, 64]: + for dtype in ["int", "uint", "float"]: + for m in [1, 2, 4, 8]: + custom_rvv_type = f"v{dtype}{bits}m{m}_t" + register(custom_rvv_type, 150 + dtype_counter) + register_op( + lower_call_pure_extern, + "Call", + "c", + custom_rvv_type, + intrinsic_name="tir.call_pure_extern", + ) + dtype_counter += 1 + + vlen = get_vlen_from_mattrs(target.mattr) + + for vmul_type, func, outer_loops in zip( + ["vmacc", "multivmul", "vmul"], + [rvv_vmacc, rvv_multivmul, rvv_vmul], + [[1], [get_vlmax(vlen, lmul=1, max_sew=32)], [1]], + ): + + for idtype, odtype in zip(["int16", "float32"], ["int32", "float32"]): + + if idtype == "float32" and vmul_type == "multivmul": + continue + + vlmax = get_vlmax(vlen, lmul=8, max_sew=32) + register_intrinsic_combinations( + outer_loops, vlmax, 8, idtype, odtype, f"rvv_{idtype}_{vmul_type}", func + ) + + print("Finished registering all intrinsics.") diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc index ccf280860d80..6d119296480a 100644 --- a/src/meta_schedule/postproc/postproc.cc +++ b/src/meta_schedule/postproc/postproc.cc @@ -69,6 +69,14 @@ Array Postproc::DefaultCPUTensorization() { }; } +Array Postproc::DefaultRISCV() { + return Array{ + Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(), + Postproc::RewriteReductionBlock(), Postproc::RewriteTensorize(/*vectorize_init_loop=*/false), + Postproc::RewriteLayout(), + }; +} + Array Postproc::DefaultCUDA() { return Array{ Postproc::DisallowDynamicLoop(), diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 9570c0d0f904..3792632ee044 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -304,6 +304,122 @@ Array ScheduleRule::DefaultHexagon() { }; } +int GetVLMAX(int vlen, int lmul, int 
max_sew) { return (lmul * vlen) / max_sew; } + +Array ScheduleRule::DefaultRISCV(int vlen) { + Array rules; + + rules.push_back(ScheduleRule::ApplyCustomRule()); + + rules.push_back(ScheduleRule::InlineConstantScalars()); + + rules.push_back(ScheduleRule::AutoInline( + /*into_producer=*/false, + /*into_consumer=*/true, + /*inline_const_tensor=*/true, + /*disallow_if_then_else=*/true, + /*require_injective=*/true, + /*require_ordered=*/true, + /*disallow_op=*/Array{"tir.exp"})); + + rules.push_back(ScheduleRule::AddRFactor( + /*max_jobs_per_core=*/16, + /*max_innermost_factor=*/Integer(64))); + + int vlmax = 0; + int RISCV_MIN_VL = 4; + std::vector vmul_types = {"multivmul", "vmul", "vmacc"}; + String intrin_name = ""; + int j = 1; + + for (const std::string& vmul_type : vmul_types) { + if (vmul_type == "multivmul") + j = GetVLMAX(vlen, 1, 32); + else + j = 1; + + // Registering for int16 + vlmax = GetVLMAX(vlen, 8, 32); + while (vlmax >= RISCV_MIN_VL) { + intrin_name = + "rvv_int16_" + vmul_type + "_" + std::to_string(j) + "_" + std::to_string(vlmax) + "_m8"; + rules.push_back(ScheduleRule::MultiLevelTilingWithIntrin( + /*intrin_name=*/intrin_name, + /*structure=*/"SSRSRS", + /*tile_binds=*/std::nullopt, + /*max_innermost_factor=*/Integer(vlmax), + /*vector_load_lens=*/std::nullopt, + /*reuse_read=*/std::nullopt, + /*reuse_write=*/ + Map{{"req", String("may")}, + {"levels", Array{1, 2}}, + {"scope", String("global")}})); + vlmax /= 2; + } + + // Registering for float16 + vlmax = GetVLMAX(vlen, 8, 16); + if (vmul_type == "multivmul") + j = GetVLMAX(vlen, 1, 32); + else + j = 1; + + while (vlmax >= RISCV_MIN_VL) { + intrin_name = "rvv_float16_" + vmul_type + "_" + std::to_string(j) + "_" + + std::to_string(vlmax) + "_m8"; + rules.push_back(ScheduleRule::MultiLevelTilingWithIntrin( + /*intrin_name=*/intrin_name, + /*structure=*/"SSRSRS", + /*tile_binds=*/std::nullopt, + /*max_innermost_factor=*/Integer(vlmax), + /*vector_load_lens=*/std::nullopt, + 
/*reuse_read=*/std::nullopt, + /*reuse_write=*/ + Map{{"req", String("may")}, + {"levels", Array{1, 2}}, + {"scope", String("global")}})); + vlmax /= 2; + } + + vlmax = GetVLMAX(vlen, 8, 32); + while (vlmax >= RISCV_MIN_VL) { + intrin_name = "rvv_float32_" + vmul_type + "_" + std::to_string(j) + "_" + + std::to_string(vlmax) + "_m8"; + rules.push_back(ScheduleRule::MultiLevelTilingWithIntrin( + /*intrin_name=*/intrin_name, + /*structure=*/"SSRSRS", + /*tile_binds=*/std::nullopt, + /*max_innermost_factor=*/Integer(vlmax), + /*vector_load_lens=*/std::nullopt, + /*reuse_read=*/std::nullopt, + /*reuse_write=*/ + Map{{"req", String("may")}, + {"levels", Array{1, 2}}, + {"scope", String("global")}})); + vlmax /= 2; + } + } + rules.push_back(ScheduleRule::MultiLevelTiling( + /*structure=*/"SSRSRS", + /*tile_binds=*/std::nullopt, + /*max_innermost_factor=*/Integer(64), + /*vector_load_lens=*/std::nullopt, + /*reuse_read=*/std::nullopt, + /*reuse_write=*/ + Map{ + {"req", String("may")}, {"levels", Array{1, 2}}, {"scope", String("global")}})); + + rules.push_back(ScheduleRule::ParallelizeVectorizeUnroll( + /*max_jobs_per_core=*/16, + /*max_vectorize_extent=*/64, + /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_explicit=*/true)); + + rules.push_back(ScheduleRule::RandomComputeLocation()); + + return rules; +} + Array GetARMNeonSpecificRules() { return { ScheduleRule::MultiLevelTilingWithIntrin( diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 709b36417c9e..456def8e6041 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -18,7 +18,9 @@ */ #include +#include "../../runtime/regex.h" #include "../../target/parsers/aprofile.h" +#include "../../target/parsers/cpu.h" #include "../utils.h" namespace tvm { @@ -43,6 +45,9 @@ String GetRuleKindFromTarget(const Target& target) { TargetJSON target_json = 
target::parsers::aprofile::ParseTarget(target->Export()); TargetFeatures afeatures = Downcast(target_json.at("features")); + if (Downcast(afeatures.at("has_rvv"))) { + return "rvv"; + } if (Downcast(afeatures.at("has_dotprod"))) { return "dotprod"; } @@ -83,6 +88,31 @@ String GetRuleKindFromTarget(const Target& target) { throw; } +std::string GetRISCVMarchFromTarget(const Target& target) { + if (target->kind->name == "c") { + if (Optional opt_march = target->GetAttr("march")) { + return opt_march.value(); + } + } + return ""; +} + +int GetRISCVVLENFromCTarget(const Target& target) { + auto march = GetRISCVMarchFromTarget(target); + int vlen = 0; + if (march.find("zvl") != std::string::npos) { + vlen = tvm::target::parsers::cpu::extractVLENFromString(march); + } + return vlen; +} + +int GetRISCVVLENFromLLVMTarget(const Target& target) { + TargetJSON target_json = target::parsers::aprofile::ParseTarget(target->Export()); + TargetFeatures afeatures = Downcast(target_json.at("features")); + int vlen = Downcast(afeatures.at("rvv_vlen"))->value; + return vlen; +} + void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { if (context->target.defined() && // !(sch_rules.defined() && // @@ -117,6 +147,11 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { default_sch_rules = ScheduleRule::DefaultX86("avx512"); default_postprocs = Postproc::DefaultCPUTensorization(); default_mutator_probs = Mutator::DefaultLLVM(); + } else if (kind == "rvv") { + int vlen = GetRISCVVLENFromLLVMTarget(context->target.value()); + default_sch_rules = ScheduleRule::DefaultRISCV(vlen); + default_postprocs = Postproc::DefaultRISCV(); + default_mutator_probs = Mutator::DefaultLLVM(); } else if (kind == "asimd") { default_sch_rules = ScheduleRule::DefaultARM("neon"); default_postprocs = Postproc::DefaultCPUTensorization(); diff --git a/src/target/parsers/aprofile.cc b/src/target/parsers/aprofile.cc index 65bd6a66aedb..868ba0dd8413 100644 --- 
a/src/target/parsers/aprofile.cc +++ b/src/target/parsers/aprofile.cc @@ -27,8 +27,10 @@ #include #include +#include "../../runtime/regex.h" #include "../../support/utils.h" #include "../llvm/llvm_instance.h" +#include "cpu.h" namespace tvm { namespace target { @@ -80,6 +82,17 @@ bool CheckContains(Array array, String predicate) { return std::any_of(array.begin(), array.end(), [&](String var) { return var == predicate; }); } +int FindRISCVVLEN(Map features) { + int vlen = 128; + for (auto const& feature : features) { + std::string feature_str = Downcast(feature.first); + if (feature_str.find("zvl") != std::string::npos) { + vlen = tvm::target::parsers::cpu::extractVLENFromString(feature_str); + } + } + return vlen; +} + static TargetFeatures GetFeatures(TargetJSON target) { #ifdef TVM_LLVM_VERSION String kind = Downcast(target.Get("kind").value()); @@ -109,6 +122,8 @@ static TargetFeatures GetFeatures(TargetJSON target) { return {{"is_aarch64", Bool(IsAArch64(mtriple))}, {"has_asimd", Bool(has_feature("neon"))}, {"has_sve", Bool(has_feature("sve"))}, + {"has_rvv", Bool(has_feature("v"))}, + {"rvv_vlen", Integer(FindRISCVVLEN(features))}, {"has_dotprod", Bool(has_feature("dotprod"))}, {"has_matmul_i8", Bool(has_feature("i8mm"))}, {"has_fp16_simd", Bool(has_feature("fullfp16"))}, diff --git a/src/target/parsers/cpu.cc b/src/target/parsers/cpu.cc index ee9bf814d323..18a699717d72 100644 --- a/src/target/parsers/cpu.cc +++ b/src/target/parsers/cpu.cc @@ -60,6 +60,29 @@ TargetJSON ParseTarget(TargetJSON target) { return target; } +int extractVLENFromString(const std::string& input) { + for (size_t i = 0; i + 4 <= input.size(); ++i) { + // Look for the starting sequence "zvl" + if (input[i] == 'z' && input[i + 1] == 'v' && input[i + 2] == 'l') { + size_t j = i + 3; + std::string number; + + // Collect digits + while (j < input.size() && std::isdigit(input[j])) { + number += input[j]; + ++j; + } + + // Check if followed by 'b' after digits + if (!number.empty() && j < 
input.size() && input[j] == 'b') { + return std::stoi(number); // Convert the number to int + } + } + } + + throw std::runtime_error("No valid pattern found"); +} + } // namespace cpu } // namespace parsers } // namespace target diff --git a/src/target/parsers/cpu.h b/src/target/parsers/cpu.h index 588f98eea043..5008ddb424b9 100644 --- a/src/target/parsers/cpu.h +++ b/src/target/parsers/cpu.h @@ -27,12 +27,15 @@ #include +#include + namespace tvm { namespace target { namespace parsers { namespace cpu { TargetJSON ParseTarget(TargetJSON target); +int extractVLENFromString(const std::string& input); } // namespace cpu } // namespace parsers diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index acc05cf96c08..15a051bc7c7c 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -268,6 +268,8 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExp << " + " << index_str << " / " << div_factor << ")"; } else if (t == buffer_element_dtype) { os << buffer_str << "[" << index_str << "]"; + } else if (t == buffer_element_dtype) { + os << buffer_str << "[" << index_str << "]"; } else { os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")"; } From 918d22e6b5b031758edea180e76f1d0a23573f04 Mon Sep 17 00:00:00 2001 From: Federico Peccia Date: Mon, 18 Aug 2025 09:49:23 +0200 Subject: [PATCH 2/9] Added diff proposal to reuse target.llvm_get_vector_width --- .../space_generator/space_generator.cc | 33 ++++--------------- src/target/parsers/aprofile.cc | 14 -------- src/target/parsers/cpu.cc | 22 ------------- src/target/parsers/cpu.h | 2 -- 4 files changed, 7 insertions(+), 64 deletions(-) diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 456def8e6041..c0163cb41943 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -18,9 +18,7 @@ 
*/ #include -#include "../../runtime/regex.h" #include "../../target/parsers/aprofile.h" -#include "../../target/parsers/cpu.h" #include "../utils.h" namespace tvm { @@ -41,13 +39,14 @@ String GetRuleKindFromTarget(const Target& target) { return "avx512"; } } + bool have_rvv = target_has_feature_fn_ptr("v", target).cast(); + if (have_rvv) { + return "rvv"; + } TargetJSON target_json = target::parsers::aprofile::ParseTarget(target->Export()); TargetFeatures afeatures = Downcast(target_json.at("features")); - if (Downcast(afeatures.at("has_rvv"))) { - return "rvv"; - } if (Downcast(afeatures.at("has_dotprod"))) { return "dotprod"; } @@ -88,28 +87,10 @@ String GetRuleKindFromTarget(const Target& target) { throw; } -std::string GetRISCVMarchFromTarget(const Target& target) { - if (target->kind->name == "c") { - if (Optional opt_march = target->GetAttr("march")) { - return opt_march.value(); - } - } - return ""; -} - -int GetRISCVVLENFromCTarget(const Target& target) { - auto march = GetRISCVMarchFromTarget(target); - int vlen = 0; - if (march.find("zvl") != std::string::npos) { - vlen = tvm::target::parsers::cpu::extractVLENFromString(march); - } - return vlen; -} - int GetRISCVVLENFromLLVMTarget(const Target& target) { - TargetJSON target_json = target::parsers::aprofile::ParseTarget(target->Export()); - TargetFeatures afeatures = Downcast(target_json.at("features")); - int vlen = Downcast(afeatures.at("rvv_vlen"))->value; + static auto llvm_get_vector_width_fn = + tvm::ffi::Function::GetGlobalRequired("target.llvm_get_vector_width"); + const int vlen = llvm_get_vector_width_fn(target).cast(); return vlen; } diff --git a/src/target/parsers/aprofile.cc b/src/target/parsers/aprofile.cc index 868ba0dd8413..c7d7f3448cea 100644 --- a/src/target/parsers/aprofile.cc +++ b/src/target/parsers/aprofile.cc @@ -27,10 +27,8 @@ #include #include -#include "../../runtime/regex.h" #include "../../support/utils.h" #include "../llvm/llvm_instance.h" -#include "cpu.h" namespace tvm { 
namespace target { @@ -82,16 +80,6 @@ bool CheckContains(Array array, String predicate) { return std::any_of(array.begin(), array.end(), [&](String var) { return var == predicate; }); } -int FindRISCVVLEN(Map features) { - int vlen = 128; - for (auto const& feature : features) { - std::string feature_str = Downcast(feature.first); - if (feature_str.find("zvl") != std::string::npos) { - vlen = tvm::target::parsers::cpu::extractVLENFromString(feature_str); - } - } - return vlen; -} static TargetFeatures GetFeatures(TargetJSON target) { #ifdef TVM_LLVM_VERSION @@ -122,8 +110,6 @@ static TargetFeatures GetFeatures(TargetJSON target) { return {{"is_aarch64", Bool(IsAArch64(mtriple))}, {"has_asimd", Bool(has_feature("neon"))}, {"has_sve", Bool(has_feature("sve"))}, - {"has_rvv", Bool(has_feature("v"))}, - {"rvv_vlen", Integer(FindRISCVVLEN(features))}, {"has_dotprod", Bool(has_feature("dotprod"))}, {"has_matmul_i8", Bool(has_feature("i8mm"))}, {"has_fp16_simd", Bool(has_feature("fullfp16"))}, diff --git a/src/target/parsers/cpu.cc b/src/target/parsers/cpu.cc index 18a699717d72..cc92af3af9fa 100644 --- a/src/target/parsers/cpu.cc +++ b/src/target/parsers/cpu.cc @@ -60,28 +60,6 @@ TargetJSON ParseTarget(TargetJSON target) { return target; } -int extractVLENFromString(const std::string& input) { - for (size_t i = 0; i + 4 <= input.size(); ++i) { - // Look for the starting sequence "zvl" - if (input[i] == 'z' && input[i + 1] == 'v' && input[i + 2] == 'l') { - size_t j = i + 3; - std::string number; - - // Collect digits - while (j < input.size() && std::isdigit(input[j])) { - number += input[j]; - ++j; - } - - // Check if followed by 'b' after digits - if (!number.empty() && j < input.size() && input[j] == 'b') { - return std::stoi(number); // Convert the number to int - } - } - } - - throw std::runtime_error("No valid pattern found"); -} } // namespace cpu } // namespace parsers diff --git a/src/target/parsers/cpu.h b/src/target/parsers/cpu.h index 
5008ddb424b9..85f082c2f648 100644 --- a/src/target/parsers/cpu.h +++ b/src/target/parsers/cpu.h @@ -27,7 +27,6 @@ #include -#include namespace tvm { namespace target { @@ -35,7 +34,6 @@ namespace parsers { namespace cpu { TargetJSON ParseTarget(TargetJSON target); -int extractVLENFromString(const std::string& input); } // namespace cpu } // namespace parsers From 49d9a0c5bde56369245d98e879d65dea0e306e85 Mon Sep 17 00:00:00 2001 From: Federico Peccia Date: Mon, 18 Aug 2025 10:11:24 +0200 Subject: [PATCH 3/9] Fixes/changes based on comments on PR --- python/tvm/meta_schedule/tune_context.py | 4 +-- python/tvm/target/target.py | 4 --- python/tvm/tir/tensor_intrin/riscv_cpu.py | 37 +++++++---------------- src/target/source/codegen_c.cc | 2 -- 4 files changed, 13 insertions(+), 34 deletions(-) diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 488cf2712d29..f2220aab43dd 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -28,6 +28,7 @@ from tvm.runtime import Object from tvm.target import Target from tvm.tir import PrimFunc, Schedule +from tvm.target.codegen import target_has_features from . 
import _ffi_api from .logging import Logger, get_logger, get_logging_func @@ -118,8 +119,7 @@ def __init__( if not isinstance(target, Target): target = Target(target) if "riscv_cpu" in target.keys: - base_features = str(target.attrs["march"]).split("_")[0].replace("rv", "") - if "v" in base_features: + if target_has_features("v", target): # Because the RVV intrinsics depend on the target, we register them here # pylint: disable=import-outside-toplevel from tvm.tir.tensor_intrin.riscv_cpu import register_riscv_tensor_intrinsics diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 9c72b2fdab27..8a8347169ca2 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -638,16 +638,12 @@ def riscv_cpu(model="sifive-u54", options=None): # cc: riscv64-unknown-linux-gnu-g++ -march=rv64gc -mabi=lp64d -mcpu=sifive-u74 ], "bpi-f3": [ - # "-model=sifive-u74", "-mtriple=riscv64-unknown-linux-gnu", "-mcpu=generic", - # "-march=rv64gcv_zvl256b", - # "-mcpu=generic-rv64", "-mfloat-abi=hard", "-num-cores=8", "-mabi=lp64d", "-mattr=+v,+zvl256b", - # cc: riscv64-unknown-linux-gnu-g++ -march=rv64gc -mabi=lp64d -mcpu=generic -mattr=+v ], } pre_defined_opt = trans_table.get(model, ["-model=%s" % model]) diff --git a/python/tvm/tir/tensor_intrin/riscv_cpu.py b/python/tvm/tir/tensor_intrin/riscv_cpu.py index 138fa66776af..b3565f4ee4f7 100644 --- a/python/tvm/tir/tensor_intrin/riscv_cpu.py +++ b/python/tvm/tir/tensor_intrin/riscv_cpu.py @@ -20,10 +20,13 @@ **Author**: `Federico Peccia `_ """ import re +import logging from tvm.script import tir as T -from tvm.target.datatype import lower_call_pure_extern, register, register_op +from tvm.target.codegen import llvm_get_vector_width from .. 
import TensorIntrin +logger = logging.getLogger(__name__) + ##################################################### # LLVM RISC-V Intrinsic usage: # https://llvm.org/docs//RISCV/RISCVVectorExtension.html @@ -327,7 +330,7 @@ def rvv_multivmul(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: @T.prim_func def rvv_multivmul_desc( A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), - B: T.Buffer((J, int(vlmax)), kernel_dtype, align=4, offset_factor=1), + B: T.Buffer((J, int(vlmax)), input_dtype, align=4, offset_factor=1), C: T.Buffer((J,), output_dtype, align=4, offset_factor=1), ) -> None: with T.block("root"): @@ -345,7 +348,7 @@ def rvv_multivmul_desc( def rvv_multivmul_llvm_impl( A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), B: T.Buffer( - (J, int(vlmax)), kernel_dtype, align=4, offset_factor=1, strides=[T.int32(), T.int32()] + (J, int(vlmax)), input_dtype, align=4, offset_factor=1, strides=[T.int32(), T.int32()] ), C: T.Buffer((J,), output_dtype, align=4, offset_factor=1), ) -> None: @@ -530,7 +533,7 @@ def rvv_vmul(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: int) @T.prim_func def rvv_vmul_desc( A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), - B: T.Buffer((int(vlmax),), kernel_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), C: T.Buffer((1,), output_dtype, align=4, offset_factor=1), ) -> None: with T.block("root"): @@ -544,7 +547,7 @@ def rvv_vmul_desc( @T.prim_func def rvv_vmul_llvm_impl( A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), - B: T.Buffer((int(vlmax),), kernel_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), C: T.Buffer((1,), output_dtype, align=4, offset_factor=1), ) -> None: @@ -690,7 +693,7 @@ def register_intrinsic_combinations( desc, impl = generator(J, current_vlmax, input_dtype, output_dtype, lmul) - print(f"Registering intrin 
{name}...") + logger.debug(f"Registering intrin {name}...") TensorIntrin.register(name, desc, impl, override=True) @@ -701,25 +704,7 @@ def register_riscv_tensor_intrinsics(target): target_kind = target.kind.name assert target_kind in ["llvm"] - ##################################################### - # Register custom RVV types for C code generation - ##################################################### - dtype_counter = 0 - for bits in [8, 16, 32, 64]: - for dtype in ["int", "uint", "float"]: - for m in [1, 2, 4, 8]: - custom_rvv_type = f"v{dtype}{bits}m{m}_t" - register(custom_rvv_type, 150 + dtype_counter) - register_op( - lower_call_pure_extern, - "Call", - "c", - custom_rvv_type, - intrinsic_name="tir.call_pure_extern", - ) - dtype_counter += 1 - - vlen = get_vlen_from_mattrs(target.mattr) + vlen = llvm_get_vector_width(target) for vmul_type, func, outer_loops in zip( ["vmacc", "multivmul", "vmul"], @@ -727,7 +712,7 @@ def register_riscv_tensor_intrinsics(target): [[1], [get_vlmax(vlen, lmul=1, max_sew=32)], [1]], ): - for idtype, odtype in zip(["int16", "float32"], ["int32", "float32"]): + for idtype, odtype in zip(["int16", "float16", "float32"], ["int32", "float16", "float32"]): if idtype == "float32" and vmul_type == "multivmul": continue diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 15a051bc7c7c..acc05cf96c08 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -268,8 +268,6 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExp << " + " << index_str << " / " << div_factor << ")"; } else if (t == buffer_element_dtype) { os << buffer_str << "[" << index_str << "]"; - } else if (t == buffer_element_dtype) { - os << buffer_str << "[" << index_str << "]"; } else { os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")"; } From f9b2667a4ee2a52d4aefbe3bf7c39479563aeec1 Mon Sep 17 00:00:00 2001 From: Federico Peccia Date: Mon, 18 Aug 2025 10:16:36 
+0200 Subject: [PATCH 4/9] Lint fixes + call_llvm_intrin nargs change --- python/tvm/tir/tensor_intrin/riscv_cpu.py | 34 ----------------------- src/target/parsers/aprofile.cc | 1 - src/target/parsers/cpu.cc | 1 - src/target/parsers/cpu.h | 1 - 4 files changed, 37 deletions(-) diff --git a/python/tvm/tir/tensor_intrin/riscv_cpu.py b/python/tvm/tir/tensor_intrin/riscv_cpu.py index b3565f4ee4f7..62e535b2b2bd 100644 --- a/python/tvm/tir/tensor_intrin/riscv_cpu.py +++ b/python/tvm/tir/tensor_intrin/riscv_cpu.py @@ -177,12 +177,10 @@ def rvv_vmacc_llvm_impl( T.call_llvm_intrin( llvm_macc_dtype, expand_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_output, n_output_dtype * T.vscale()), T.call_llvm_intrin( llvm_input_dtype, load_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_input, n_input_dtype * T.vscale()), A.access_ptr(access_mask=A.READ, ptr_type="handle"), T.int64(vlmax), @@ -193,7 +191,6 @@ def rvv_vmacc_llvm_impl( else T.call_llvm_intrin( llvm_input_dtype, load_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_input, n_input_dtype * T.vscale()), A.access_ptr(access_mask=A.READ, ptr_type="handle"), T.int64(vlmax), @@ -204,12 +201,10 @@ def rvv_vmacc_llvm_impl( T.call_llvm_intrin( llvm_macc_dtype, expand_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_output, n_output_dtype * T.vscale()), T.call_llvm_intrin( llvm_input_dtype, load_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_input, n_input_dtype * T.vscale()), B.access_ptr(access_mask=B.READ, ptr_type="handle"), T.int64(vlmax), @@ -220,7 +215,6 @@ def rvv_vmacc_llvm_impl( else T.call_llvm_intrin( llvm_input_dtype, load_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_input, n_input_dtype * T.vscale()), B.access_ptr(access_mask=B.READ, ptr_type="handle"), T.int64(vlmax), @@ -230,7 +224,6 @@ def rvv_vmacc_llvm_impl( init = T.call_llvm_intrin( llvm_macc_dtype, init_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_output, n_output_dtype * T.vscale()), C.access_ptr(access_mask=C.READ, 
ptr_type="handle"), T.uint64(vlmax), @@ -240,7 +233,6 @@ def rvv_vmacc_llvm_impl( T.call_llvm_intrin( llvm_macc_dtype, macc_llvm_intrinsic, - T.uint32(6), init, vec_A, vec_B, @@ -252,7 +244,6 @@ def rvv_vmacc_llvm_impl( else T.call_llvm_intrin( llvm_macc_dtype, macc_llvm_intrinsic, - T.uint32(5), init, vec_A, vec_B, @@ -264,7 +255,6 @@ def rvv_vmacc_llvm_impl( T.call_llvm_intrin( "", store_llvm_intrinsic, - T.uint32(3), product, C.access_ptr(access_mask=C.WRITE, ptr_type="handle"), T.uint64(vlmax), @@ -362,12 +352,10 @@ def rvv_multivmul_llvm_impl( T.call_llvm_intrin( llvm_mult_dtype, expand_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), T.call_llvm_intrin( llvm_input_dtype, load_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_input, n_input_dtype * T.vscale()), A.access_ptr(access_mask=A.READ, ptr_type="handle"), T.int64(vlmax), @@ -378,7 +366,6 @@ def rvv_multivmul_llvm_impl( else T.call_llvm_intrin( llvm_input_dtype, load_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_input, n_input_dtype * T.vscale()), A.access_ptr(access_mask=A.READ, ptr_type="handle"), T.int64(vlmax), @@ -389,12 +376,10 @@ def rvv_multivmul_llvm_impl( T.call_llvm_intrin( llvm_mult_dtype, expand_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), T.call_llvm_intrin( llvm_input_dtype, load_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_input, n_input_dtype * T.vscale()), B.access_ptr(access_mask=B.READ, ptr_type="handle"), T.int64(vlmax), @@ -405,7 +390,6 @@ def rvv_multivmul_llvm_impl( else T.call_llvm_intrin( llvm_kernel_dtype, load_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_kernel, n_kernel_dtype * T.vscale()), B.access_ptr(access_mask=B.READ, ptr_type="handle"), T.int64(vlmax), @@ -415,7 +399,6 @@ def rvv_multivmul_llvm_impl( redsum = T.call_llvm_intrin( llvm_redsum_dtype, init_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_output, n_redsum_dtype * 
T.vscale()), C[0], T.uint64(1), @@ -425,7 +408,6 @@ def rvv_multivmul_llvm_impl( T.call_llvm_intrin( llvm_mult_dtype, mult_llvm_intrinsic, - T.uint32(5), T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), vec_A, vec_B, @@ -436,7 +418,6 @@ def rvv_multivmul_llvm_impl( else T.call_llvm_intrin( llvm_mult_dtype, mult_llvm_intrinsic, - T.uint32(4), T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), vec_A, vec_B, @@ -448,7 +429,6 @@ def rvv_multivmul_llvm_impl( T.call_llvm_intrin( llvm_redsum_dtype, redsum_llvm_intrinsic, - T.uint32(5), T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), product, redsum, @@ -459,7 +439,6 @@ def rvv_multivmul_llvm_impl( else T.call_llvm_intrin( llvm_redsum_dtype, redsum_llvm_intrinsic, - T.uint32(4), T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), product, redsum, @@ -470,7 +449,6 @@ def rvv_multivmul_llvm_impl( T.call_llvm_intrin( "", store_llvm_intrinsic, - T.uint32(3), redsum_result, C.access_ptr(access_mask=C.WRITE, ptr_type="handle"), T.uint64(1), @@ -560,12 +538,10 @@ def rvv_vmul_llvm_impl( T.call_llvm_intrin( llvm_mult_dtype, expand_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), T.call_llvm_intrin( llvm_input_dtype, load_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_input, n_input_dtype * T.vscale()), A.access_ptr(access_mask=A.READ, ptr_type="handle"), T.int64(vlmax), @@ -576,7 +552,6 @@ def rvv_vmul_llvm_impl( else T.call_llvm_intrin( llvm_input_dtype, load_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_input, n_input_dtype * T.vscale()), A.access_ptr(access_mask=A.READ, ptr_type="handle"), T.int64(vlmax), @@ -587,12 +562,10 @@ def rvv_vmul_llvm_impl( T.call_llvm_intrin( llvm_mult_dtype, expand_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), T.call_llvm_intrin( llvm_input_dtype, load_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_input, 
n_input_dtype * T.vscale()), B.access_ptr(access_mask=B.READ, ptr_type="handle"), T.int64(vlmax), @@ -603,7 +576,6 @@ def rvv_vmul_llvm_impl( else T.call_llvm_intrin( llvm_kernel_dtype, load_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_kernel, n_kernel_dtype * T.vscale()), B.access_ptr(access_mask=B.READ, ptr_type="handle"), T.int64(vlmax), @@ -613,7 +585,6 @@ def rvv_vmul_llvm_impl( redsum = T.call_llvm_intrin( llvm_redsum_dtype, init_llvm_intrinsic, - T.uint32(3), T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), C[0], T.uint64(1), @@ -623,7 +594,6 @@ def rvv_vmul_llvm_impl( T.call_llvm_intrin( llvm_mult_dtype, mult_llvm_intrinsic, - T.uint32(5), T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), vec_A, vec_B, @@ -634,7 +604,6 @@ def rvv_vmul_llvm_impl( else T.call_llvm_intrin( llvm_mult_dtype, mult_llvm_intrinsic, - T.uint32(4), T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), vec_A, vec_B, @@ -646,7 +615,6 @@ def rvv_vmul_llvm_impl( T.call_llvm_intrin( llvm_redsum_dtype, redsum_llvm_intrinsic, - T.uint32(5), T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), product, redsum, @@ -657,7 +625,6 @@ def rvv_vmul_llvm_impl( else T.call_llvm_intrin( llvm_redsum_dtype, redsum_llvm_intrinsic, - T.uint32(4), T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), product, redsum, @@ -668,7 +635,6 @@ def rvv_vmul_llvm_impl( T.call_llvm_intrin( "", store_llvm_intrinsic, - T.uint32(3), redsum_result, C.access_ptr(access_mask=C.WRITE, ptr_type="handle"), T.uint64(1), diff --git a/src/target/parsers/aprofile.cc b/src/target/parsers/aprofile.cc index c7d7f3448cea..65bd6a66aedb 100644 --- a/src/target/parsers/aprofile.cc +++ b/src/target/parsers/aprofile.cc @@ -80,7 +80,6 @@ bool CheckContains(Array array, String predicate) { return std::any_of(array.begin(), array.end(), [&](String var) { return var == predicate; }); } - static TargetFeatures GetFeatures(TargetJSON target) { #ifdef TVM_LLVM_VERSION String kind = 
Downcast(target.Get("kind").value()); diff --git a/src/target/parsers/cpu.cc b/src/target/parsers/cpu.cc index cc92af3af9fa..ee9bf814d323 100644 --- a/src/target/parsers/cpu.cc +++ b/src/target/parsers/cpu.cc @@ -60,7 +60,6 @@ TargetJSON ParseTarget(TargetJSON target) { return target; } - } // namespace cpu } // namespace parsers } // namespace target diff --git a/src/target/parsers/cpu.h b/src/target/parsers/cpu.h index 85f082c2f648..588f98eea043 100644 --- a/src/target/parsers/cpu.h +++ b/src/target/parsers/cpu.h @@ -27,7 +27,6 @@ #include - namespace tvm { namespace target { namespace parsers { From a86c2145fd133b512abff06cd728e6ada9addf1e Mon Sep 17 00:00:00 2001 From: Federico Peccia Date: Thu, 21 Aug 2025 14:26:44 +0200 Subject: [PATCH 5/9] Intrinsic registration mistakes fixed --- python/tvm/tir/tensor_intrin/riscv_cpu.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/tir/tensor_intrin/riscv_cpu.py b/python/tvm/tir/tensor_intrin/riscv_cpu.py index 62e535b2b2bd..b897cc3fb77e 100644 --- a/python/tvm/tir/tensor_intrin/riscv_cpu.py +++ b/python/tvm/tir/tensor_intrin/riscv_cpu.py @@ -18,6 +18,9 @@ """Intrinsics for RVV tensorization, both for C and LLVM targets. 
===================== **Author**: `Federico Peccia `_ + [*] Tensor Program Optimization for the RISC-V Vector + Extension Using Probabilistic Programs + https://arxiv.org/abs/2507.01457 """ import re import logging @@ -680,12 +683,9 @@ def register_riscv_tensor_intrinsics(target): for idtype, odtype in zip(["int16", "float16", "float32"], ["int32", "float16", "float32"]): - if idtype == "float32" and vmul_type == "multivmul": - continue - vlmax = get_vlmax(vlen, lmul=8, max_sew=32) register_intrinsic_combinations( outer_loops, vlmax, 8, idtype, odtype, f"rvv_{idtype}_{vmul_type}", func ) - print("Finished registering all intrinsics.") + logger.debug("Finished registering all intrinsics.") From 9f363b5519d6b59e59505fa17b045781b30a744c Mon Sep 17 00:00:00 2001 From: Federico Peccia Date: Thu, 21 Aug 2025 15:55:34 +0200 Subject: [PATCH 6/9] Enabled RVV intrinsics as experimental --- python/tvm/meta_schedule/tune_context.py | 7 ------- python/tvm/tir/tensor_intrin/riscv_cpu.py | 11 +++++++++-- src/meta_schedule/space_generator/space_generator.cc | 12 +++++++++--- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index f2220aab43dd..f131cf16706e 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -118,13 +118,6 @@ def __init__( if target is not None: if not isinstance(target, Target): target = Target(target) - if "riscv_cpu" in target.keys: - if target_has_features("v", target): - # Because the RVV intrinsics depend on the target, we register them here - # pylint: disable=import-outside-toplevel - from tvm.tir.tensor_intrin.riscv_cpu import register_riscv_tensor_intrinsics - - register_riscv_tensor_intrinsics(target) if space_generator is not None: if not isinstance(space_generator, SpaceGenerator): space_generator = SpaceGenerator.create(space_generator) diff --git a/python/tvm/tir/tensor_intrin/riscv_cpu.py 
b/python/tvm/tir/tensor_intrin/riscv_cpu.py index b897cc3fb77e..9962d9c2e260 100644 --- a/python/tvm/tir/tensor_intrin/riscv_cpu.py +++ b/python/tvm/tir/tensor_intrin/riscv_cpu.py @@ -18,7 +18,7 @@ """Intrinsics for RVV tensorization, both for C and LLVM targets. ===================== **Author**: `Federico Peccia `_ - [*] Tensor Program Optimization for the RISC-V Vector + [*] Tensor Program Optimization for the RISC-V Vector Extension Using Probabilistic Programs https://arxiv.org/abs/2507.01457 """ @@ -26,6 +26,8 @@ import logging from tvm.script import tir as T from tvm.target.codegen import llvm_get_vector_width +from tvm.target import Target +from tvm.target.codegen import target_has_features from .. import TensorIntrin logger = logging.getLogger(__name__) @@ -683,9 +685,14 @@ def register_riscv_tensor_intrinsics(target): for idtype, odtype in zip(["int16", "float16", "float32"], ["int32", "float16", "float32"]): - vlmax = get_vlmax(vlen, lmul=8, max_sew=32) + vlmax = get_vlmax(vlen, lmul=8, max_sew=16 if odtype == "float16" else 32) register_intrinsic_combinations( outer_loops, vlmax, 8, idtype, odtype, f"rvv_{idtype}_{vmul_type}", func ) logger.debug("Finished registering all intrinsics.") + + +target = Target.current() +if "riscv_cpu" in target.keys and "rvv" in target.model and target_has_features("v", target): + register_riscv_tensor_intrinsics(target) diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index c0163cb41943..8359ec598860 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -129,9 +129,15 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { default_postprocs = Postproc::DefaultCPUTensorization(); default_mutator_probs = Mutator::DefaultLLVM(); } else if (kind == "rvv") { - int vlen = GetRISCVVLENFromLLVMTarget(context->target.value()); - default_sch_rules = 
ScheduleRule::DefaultRISCV(vlen); - default_postprocs = Postproc::DefaultRISCV(); + if (context->target.value()->GetAttr("model") == "rvv") { + // experimental rvv tensorization + int vlen = GetRISCVVLENFromLLVMTarget(context->target.value()); + default_sch_rules = ScheduleRule::DefaultRISCV(vlen); + default_postprocs = Postproc::DefaultRISCV(); + } else { + default_sch_rules = ScheduleRule::DefaultLLVM(); + default_postprocs = Postproc::DefaultLLVM(); + } default_mutator_probs = Mutator::DefaultLLVM(); } else if (kind == "asimd") { default_sch_rules = ScheduleRule::DefaultARM("neon"); From 91592e905bd1311471286e0318577c6e4827f6c0 Mon Sep 17 00:00:00 2001 From: Federico Peccia Date: Thu, 21 Aug 2025 16:16:31 +0200 Subject: [PATCH 7/9] Lint fixes --- python/tvm/meta_schedule/tune_context.py | 1 - python/tvm/tir/tensor_intrin/riscv_cpu.py | 7 ++++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index f131cf16706e..5512b7a2682b 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -28,7 +28,6 @@ from tvm.runtime import Object from tvm.target import Target from tvm.tir import PrimFunc, Schedule -from tvm.target.codegen import target_has_features from . 
import _ffi_api from .logging import Logger, get_logger, get_logging_func diff --git a/python/tvm/tir/tensor_intrin/riscv_cpu.py b/python/tvm/tir/tensor_intrin/riscv_cpu.py index 9962d9c2e260..fd5ca978cbf6 100644 --- a/python/tvm/tir/tensor_intrin/riscv_cpu.py +++ b/python/tvm/tir/tensor_intrin/riscv_cpu.py @@ -693,6 +693,7 @@ def register_riscv_tensor_intrinsics(target): logger.debug("Finished registering all intrinsics.") -target = Target.current() -if "riscv_cpu" in target.keys and "rvv" in target.model and target_has_features("v", target): - register_riscv_tensor_intrinsics(target) +current_target = Target.current() +if "riscv_cpu" in current_target.keys and "rvv" in current_target.model and \ + target_has_features("v", current_target): + register_riscv_tensor_intrinsics(current_target) From 10cf78114a68bc2c6a1a296bc346e63869fe1206 Mon Sep 17 00:00:00 2001 From: Federico Peccia Date: Fri, 22 Aug 2025 08:43:57 +0200 Subject: [PATCH 8/9] CI Lint fix --- python/tvm/tir/tensor_intrin/riscv_cpu.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tvm/tir/tensor_intrin/riscv_cpu.py b/python/tvm/tir/tensor_intrin/riscv_cpu.py index fd5ca978cbf6..867ceca5602a 100644 --- a/python/tvm/tir/tensor_intrin/riscv_cpu.py +++ b/python/tvm/tir/tensor_intrin/riscv_cpu.py @@ -694,6 +694,9 @@ def register_riscv_tensor_intrinsics(target): current_target = Target.current() -if "riscv_cpu" in current_target.keys and "rvv" in current_target.model and \ - target_has_features("v", current_target): +if ( + "riscv_cpu" in current_target.keys + and "rvv" in current_target.model + and target_has_features("v", current_target) +): register_riscv_tensor_intrinsics(current_target) From c56e82dc91cbaed91dfef6d9e983fc3063a6d28f Mon Sep 17 00:00:00 2001 From: Federico Peccia Date: Fri, 22 Aug 2025 09:32:25 +0200 Subject: [PATCH 9/9] CI fix --- python/tvm/tir/tensor_intrin/riscv_cpu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/python/tvm/tir/tensor_intrin/riscv_cpu.py b/python/tvm/tir/tensor_intrin/riscv_cpu.py index 867ceca5602a..68e9b3ca6212 100644 --- a/python/tvm/tir/tensor_intrin/riscv_cpu.py +++ b/python/tvm/tir/tensor_intrin/riscv_cpu.py @@ -695,7 +695,8 @@ def register_riscv_tensor_intrinsics(target): current_target = Target.current() if ( - "riscv_cpu" in current_target.keys + current_target is not None + and "riscv_cpu" in current_target.keys and "rvv" in current_target.model and target_has_features("v", current_target) ):