Skip to content
This repository was archived by the owner on May 13, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Installation
downloads
install

.vscode
# T2S generated tmp
*.out
tmp.txt
Expand All @@ -14,7 +14,7 @@ profile.mon
*-interface.*
*_genx.cpp
temp*
*.png
.png
*.o
*.isa
signed*
Expand Down
8 changes: 8 additions & 0 deletions Halide/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ WITH_PTX ?= $(findstring nvptx, $(LLVM_COMPONENTS))
WITH_AMDGPU ?= $(findstring amdgpu, $(LLVM_COMPONENTS))
WITH_WEBASSEMBLY ?= $(findstring webassembly, $(LLVM_COMPONENTS))
WITH_OPENCL ?= not-empty
WITH_ONEAPI ?= not-empty
WITH_METAL ?= not-empty
WITH_OPENGL ?= not-empty
WITH_D3D12 ?= not-empty
Expand Down Expand Up @@ -193,6 +194,9 @@ AMDGPU_LLVM_CONFIG_LIB=$(if $(WITH_AMDGPU), amdgpu, )
OPENCL_CXX_FLAGS=$(if $(WITH_OPENCL), -DWITH_OPENCL, )
OPENCL_LLVM_CONFIG_LIB=$(if $(WITH_OPENCL), , )

ONEAPI_CXX_FLAGS=$(if $(WITH_ONEAPI), -DWITH_OPENCL, )
ONEAPI_LLVM_CONFIG_LIB=$(if $(WITH_ONEAPI), , )

METAL_CXX_FLAGS=$(if $(WITH_METAL), -DWITH_METAL, )
METAL_LLVM_CONFIG_LIB=$(if $(WITH_METAL), , )

Expand Down Expand Up @@ -249,6 +253,7 @@ CXX_FLAGS += $(HEXAGON_CXX_FLAGS)
CXX_FLAGS += $(AARCH64_CXX_FLAGS)
CXX_FLAGS += $(X86_CXX_FLAGS)
CXX_FLAGS += $(OPENCL_CXX_FLAGS)
CXX_FLAGS += $(ONEAPI_CXX_FLAGS)
CXX_FLAGS += $(METAL_CXX_FLAGS)
CXX_FLAGS += $(OPENGL_CXX_FLAGS)
CXX_FLAGS += $(D3D12_CXX_FLAGS)
Expand Down Expand Up @@ -281,6 +286,7 @@ LLVM_STATIC_LIBFILES = \
$(X86_LLVM_CONFIG_LIB) \
$(ARM_LLVM_CONFIG_LIB) \
$(OPENCL_LLVM_CONFIG_LIB) \
$(ONEAPI_LLVM_CONFIG_LIB) \
$(METAL_LLVM_CONFIG_LIB) \
$(PTX_LLVM_CONFIG_LIB) \
$(AARCH64_LLVM_CONFIG_LIB) \
Expand Down Expand Up @@ -474,6 +480,7 @@ SOURCE_FILES = \
CanonicalizeGPUVars.cpp \
Closure.cpp \
CodeGen_ARM.cpp \
CodeGen_DPC_Dev.cpp \
CodeGen_C.cpp \
CodeGen_D3D12Compute_Dev.cpp \
CodeGen_GPU_Dev.cpp \
Expand Down Expand Up @@ -647,6 +654,7 @@ HEADER_FILES = \
CanonicalizeGPUVars.h \
Closure.h \
CodeGen_ARM.h \
CodeGen_DPC_Dev.h \
CodeGen_C.h \
CodeGen_D3D12Compute_Dev.h \
CodeGen_GPU_Dev.h \
Expand Down
37 changes: 34 additions & 3 deletions Halide/src/CodeGen_CM_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,10 @@ void CodeGen_CM_Dev::CodeGen_CM_C::print_media_block_rw(Type t, vector<Expr> arg
for (int j = 0; j < cols; j += max_cols_at_once) {
int cols_at_once = j + max_cols_at_once <= cols ? max_cols_at_once : cols-j;

// Replace the buffer name with the one specified in stensors
string name = is_write ? args[7].as<StringImm>()->value : print_expr(args[0]);
stream << get_indent() << (is_write ? "write(" : "read(");
stream << print_name(print_expr(args[0])) << ", ";
stream << print_name(name) << ", ";
stream << print_expr(args[1] * bytes) << ", ";
stream << print_expr(args[2] + i) << ", ";
auto ramp = args[4].as<Ramp>();
Expand Down Expand Up @@ -683,6 +685,28 @@ void CodeGen_CM_Dev::add_kernel(Stmt s,
src_stream.str(str);
}

class FindRefName : public IRVisitor
{
const string &buf_name;
public:
using IRVisitor::visit;
string ref_name;

void visit(const Call *op) override {
if (op->is_intrinsic(Call::cm_store_2d)) {
internal_assert(op->args[0].as<Variable>());
auto &name = op->args[0].as<Variable>()->name;
if (name == buf_name && op->args.size() == 8) {
internal_assert(op->args[7].as<StringImm>());
ref_name = op->args[7].as<StringImm>()->value;
}
}
}

FindRefName(const string &_b)
: buf_name(_b) {}
};

void CodeGen_CM_Dev::CodeGen_CM_C::add_kernel(Stmt s,
const string &name,
const vector<DeviceArgument> &args) {
Expand All @@ -692,7 +716,14 @@ void CodeGen_CM_Dev::CodeGen_CM_C::add_kernel(Stmt s,
stream << "extern \"C\" _GENX_MAIN_ void " << name << "(\n";
for (size_t i = 0; i < args.size(); i++) {
if (args[i].is_buffer) {
stream << "SurfaceIndex " << print_name(args[i].name)
string name = args[i].name;
// Trick: replace the buffer name with the one specified in stensor
FindRefName frn(name);
s.accept(&frn);
if (!frn.ref_name.empty()) {
name = frn.ref_name;
}
stream << "SurfaceIndex " << print_name(name)
<< " [[type(\"image2d_t " << print_type(args[i].type) << "\")]]";
Allocation alloc;
alloc.type = args[i].type;
Expand Down Expand Up @@ -760,4 +791,4 @@ vector<char> CodeGen_CM_Dev::compile_to_src() {
}

} // namespace Internal
} // namespace Halide
} // namespace Halide
Loading